diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -1,12 +1,13 @@ import os import re import base64 +import binascii import gf256 def _shareByte(secretB,k,n): - assert n<255 + assert k<=n<255 coefs=[int(secretB)]+[int(b) for b in os.urandom(k-1)] points=[gf256.evaluate(coefs,i) for i in range(1,n+1)] return points @@ -31,8 +32,8 @@ def reconstructRaw(*shares): secretLen=len(shares[0][1]) res=[None]*secretLen for i in range(secretLen): - bs=[(x,s[i]) for (x,s) in shares] - res[i]=(gf256.getConstantCoef(*bs)) + points=[(x,s[i]) for (x,s) in shares] + res[i]=(gf256.getConstantCoef(*points)) return bytes(res) @@ -73,22 +74,25 @@ def encode(share,encoding="b32"): def decode(share,encoding="b32"): - (i,_,shareStr)=share.partition(".") - if not shareStr: - raise ValueError("bad share format") - i=int(i) - if encoding=="hex": f=base64.b16decode - elif encoding=="b32": f=base64.b32decode - else: f=base64.b64decode - shareBytes=f(shareStr) - return (i,shareBytes) + try: + (i,_,shareStr)=share.partition(".") + i=int(i) + if not 1<=i<=255: + raise ValueError() + if encoding=="hex": f=base64.b16decode + elif encoding=="b32": f=base64.b32decode + else: f=base64.b64decode + shareBytes=f(shareStr) + return (i,shareBytes) + except (ValueError,binascii.Error): + raise ValueError('bad share format: share="{0}", encoding="{1}"'.format(share,encoding)) def detectEncoding(shares): classes=[ - (re.compile(r"\d+\.[0-9A-F]+=*"), "hex"), - (re.compile(r"\d+\.[A-Z2-7]+=*"), "b32"), - (re.compile(r"\d+\.[A-Za-z0-9+/]+=*"), "b64") + (re.compile(r"\d+\.([0-9A-F]{2})+"), "hex"), + (re.compile(r"\d+\.([A-Z2-7]{8})*([A-Z2-7]{8}|[A-Z2-7]{2}={6}|[A-Z2-7]{4}={4}|[A-Z2-7]{5}={3}|[A-Z2-7]{7}={1})"), "b32"), + (re.compile(r"\d+\.([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{2}={2}|[A-Za-z0-9+/]{3}={1})"), "b64") ] for (regexp, res) in classes: if all(regexp.fullmatch(share) for share in shares):