# HG changeset patch # User Laman # Date 2017-09-23 13:00:15 # Node ID 9ccd379021d5601d3e28c5c0f71654821eb7b039 # Parent db65075fe7e050c401b75890228e4c8ee793754d input/output wrappers diff --git a/src/gf256.py b/src/gf256.py --- a/src/gf256.py +++ b/src/gf256.py @@ -1,4 +1,8 @@ +"""Arithmetic operations on Galois Field 2**8.""" + + def _ffmul(a, b): + """Basic multiplication.""" r=0 while a!=0: if (a&1)!=0: r^=b @@ -9,19 +13,20 @@ def _ffmul(a, b): return r -g=3 -E=[None]*256 -L=[None]*256 +g=3 # generator +E=[None]*256 # exponentials +L=[None]*256 # logarithms acc=1 for i in range(256): E[i]=acc L[acc]=i acc=_ffmul(acc, g) L[1]=0 -inv=[E[255-L[i]] if i!=0 else None for i in range(256)] +inv=[E[255-L[i]] if i!=0 else None for i in range(256)] # multiplicative inverse def ffmul(a, b): + """Fast multiplication. Basic multiplication is expensive. a*b==g**(log(a)+log(b))""" if a==0 or b==0: return 0 t=L[a]+L[b] if t>255: t-=255 @@ -29,6 +34,9 @@ def ffmul(a, b): def evaluate(coefs,x): + """Evaluate polynomial's value at x. + + :param coefs: [a0, a1, ...].""" res=0 xK=1 for a in coefs: @@ -37,7 +45,9 @@ def evaluate(coefs,x): return res -def getConstantCoef(k,*points): +def getConstantCoef(*points): + """Compute constant polynomial coefficient given the points.""" + k=len(points) res=0 for i in range(k): (x,y)=points[i] diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -1,4 +1,7 @@ import os +import sys +import re +import base64 import gf256 @@ -10,16 +13,70 @@ def shareByte(secretB,k,n): return points -def generate(secret,k,n): +def generateRaw(secret,k,n): shares=[shareByte(b,k,n) for b in secret] - return [(i+1, [s[i] for s in shares]) for i in range(n)] + return [(i+1, bytes([s[i] for s in shares])) for i in range(n)] -def reconstruct(*shares): - k=len(shares) +def reconstructRaw(*shares): secretLen=len(shares[0][1]) res=[None]*secretLen for i in range(secretLen): bs=[(x,s[i]) for (x,s) in shares] - res[i]=(gf256.getConstantCoef(k,*bs)) + res[i]=(gf256.getConstantCoef(*bs)) return bytes(res) + + +def generate(secret,k,n,encoding="b32"): + if isinstance(secret,str): + secret=secret.encode("utf-8") + shares=generateRaw(secret,k,n) + return [encode(s,encoding) for s in shares] + + +def reconstruct(*shares,encoding="",raw=False): + if not encoding: + encoding=detectEncoding(shares) + + bs=reconstructRaw(decode(s,encoding) for s in shares) + return bs if raw else bs.decode(encoding="utf-8") + + +def encode(share,encoding="b32"): + if encoding=="hex": f=base64.b16encode + elif encoding=="b32": f=base64.b32encode + else: f=base64.b64encode + return ["{0}.{1}".format(i,f(bs).decode("utf8")) for (i,bs) in share] + + +def decode(share,encoding="b32"): + (i,_,shareStr)=share.partition(".") + if not shareStr: + raise ValueError("bad share format") + i=int(i) + if encoding=="hex": f=base64.b16decode + elif encoding=="b32": f=base64.b32decode + else: f=base64.b32decode + shareBytes=f(shareStr) + return (i,shareBytes) + + +def detectEncoding(shares): + classes=[ + (re.compile(r"\d+\.[0-9a-f]+=*"), "hex"), + (re.compile(r"\d+\.[A-Z2-7]+=*"), "b32"), + (re.compile(r"\d+\.[A-Za-z0-9+/]+=*"), "b64") + ] + for (regexp, res) in classes: + if all(regexp.fullmatch(share) for share in shares): + return res + raise ValueError("no expected encoding detected") + + +if __name__=="__main__": + secret=sys.argv[1].encode("utf8") + k=int(sys.argv[2]) + n=int(sys.argv[3]) + output=sys.argv[4] if len(sys.argv)>4 else "raw" + for share in generate(secret,k,n,output): + print(share)