diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -1,4 +1,7 @@ import os +import sys +import re +import base64 import gf256 @@ -10,16 +13,70 @@ def shareByte(secretB,k,n): return points -def generate(secret,k,n): +def generateRaw(secret,k,n): shares=[shareByte(b,k,n) for b in secret] - return [(i+1, [s[i] for s in shares]) for i in range(n)] + return [(i+1, bytes([s[i] for s in shares])) for i in range(n)] -def reconstruct(*shares): - k=len(shares) +def reconstructRaw(*shares): secretLen=len(shares[0][1]) res=[None]*secretLen for i in range(secretLen): bs=[(x,s[i]) for (x,s) in shares] - res[i]=(gf256.getConstantCoef(k,*bs)) + res[i]=(gf256.getConstantCoef(*bs)) return bytes(res) + + +def generate(secret,k,n,encoding="b32"): + if isinstance(secret,str): + secret=secret.encode("utf-8") + shares=generateRaw(secret,k,n) + return [encode(s,encoding) for s in shares] + + +def reconstruct(*shares,encoding="",raw=False): + if not encoding: + encoding=detectEncoding(shares) + + bs=reconstructRaw(decode(s,encoding) for s in shares) + return bs if raw else bs.decode(encoding="utf-8") + + +def encode(share,encoding="b32"): + if encoding=="hex": f=base64.b16encode + elif encoding=="b32": f=base64.b32encode + else: f=base64.b64encode + return ["{0}.{1}".format(i,f(bs).decode("utf8")) for (i,bs) in share] + + +def decode(share,encoding="b32"): + (i,_,shareStr)=share.partition(".") + if not shareStr: + raise ValueError("bad share format") + i=int(i) + if encoding=="hex": f=base64.b16decode + elif encoding=="b32": f=base64.b32decode + else: f=base64.b32decode + shareBytes=f(shareStr) + return (i,shareBytes) + + +def detectEncoding(shares): + classes=[ + (re.compile(r"\d+\.[0-9a-f]+=*"), "hex"), + (re.compile(r"\d+\.[A-Z2-7]+=*"), "b32"), + (re.compile(r"\d+\.[A-Za-z0-9+/]+=*"), "b64") + ] + for (regexp, res) in classes: + if all(regexp.fullmatch(share) for share in shares): + return res + raise ValueError("no expected encoding detected") + + +if __name__=="__main__": + secret=sys.argv[1].encode("utf8") + k=int(sys.argv[2]) + n=int(sys.argv[3]) + output=sys.argv[4] if len(sys.argv)>4 else "raw" + for share in generate(secret,k,n,output): + print(share)