diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -15,23 +15,23 @@ class DecodingException(SException): pas class MalformedShare(SException): pass -def _shareByte(secretB,k,n): +def _shareByte(secretB, k, n): if not k<=n<255: - raise InvalidParams("Failed k<=n<255, k={0}, n={1}".format(k,n)) + raise InvalidParams("Failed k<=n<255, k={0}, n={1}".format(k, n)) # we might be concerned with zero coefficients degenerating our polynomial, but there's no reason - we still need k shares to determine it is the case - coefs=[int(secretB)]+[int(b) for b in os.urandom(k-1)] - points=[gf256.evaluate(coefs,i) for i in range(1,n+1)] + coefs = [int(secretB)]+[int(b) for b in os.urandom(k-1)] + points = [gf256.evaluate(coefs, i) for i in range(1, n+1)] return points -def generateRaw(secret,k,n): +def generateRaw(secret, k, n): """Splits secret into shares. :param secret: (bytes) :param k: number of shares necessary for secret recovery. 1 <= k <= n :param n: (int) number of shares generated. 1 <= n < 255 :return: [(i, (bytes) share), ...]""" - shares=[_shareByte(b,k,n) for b in secret] + shares = [_shareByte(b, k, n) for b in secret] return [(i+1, bytes([s[i] for s in shares])) for i in range(n)] @@ -40,15 +40,15 @@ def reconstructRaw(*shares): :param shares: ((i, (bytes) share), ...) :return: (bytes) reconstructed secret. Too few shares returns garbage.""" - secretLen=len(shares[0][1]) - res=[None]*secretLen + secretLen = len(shares[0][1]) + res = [None]*secretLen for i in range(secretLen): - points=[(x,s[i]) for (x,s) in shares] - res[i]=(gf256.getConstantCoef(*points)) + points = [(x, s[i]) for (x, s) in shares] + res[i] = (gf256.getConstantCoef(*points)) return bytes(res) -def generate(secret,k,n,encoding="b32"): +def generate(secret, k, n, encoding="b32"): """Wraps generateRaw(). :param secret: (str or bytes) @@ -57,12 +57,12 @@ def generate(secret,k,n,encoding="b32"): :param encoding: {hex, b32, b64} desired output encoding. Hexadecimal, Base32 or Base64. :return: [(str) share, ...]""" if isinstance(secret,str): - secret=secret.encode("utf-8") - shares=generateRaw(secret,k,n) - return [encode(s,encoding) for s in shares] + secret = secret.encode("utf-8") + shares = generateRaw(secret, k, n) + return [encode(s, encoding) for s in shares] -def reconstruct(*shares,encoding="",raw=False): +def reconstruct(*shares, encoding="", raw=False): """Wraps reconstructRaw. :param shares: ((str) share, ...) @@ -70,40 +70,40 @@ def reconstruct(*shares,encoding="",raw= :param raw: (bool) whether to return bytes (True) or str (False) :return: (str or bytes) reconstructed secret. Too few shares returns garbage.""" if not encoding: - encoding=detectEncoding(shares) + encoding = detectEncoding(shares) - bs=reconstructRaw(*(decode(s,encoding) for s in shares)) + bs = reconstructRaw(*(decode(s, encoding) for s in shares)) try: return bs if raw else bs.decode(encoding="utf-8") except UnicodeDecodeError: raise DecodingException('Failed to decode bytes to utf-8. Either you supplied invalid shares, or you missed the "raw" flag. Offending value: {0}'.format(bs)) -def encode(share,encoding="b32"): - if encoding=="hex": f=base64.b16encode - elif encoding=="b32": f=base64.b32encode - else: f=base64.b64encode - (i,bs)=share - return "{0}.{1}".format(i,f(bs).decode("utf-8")) +def encode(share, encoding="b32"): + if encoding=="hex": f = base64.b16encode + elif encoding=="b32": f = base64.b32encode + else: f = base64.b64encode + (i, bs) = share + return "{0}.{1}".format(i, f(bs).decode("utf-8")) -def decode(share,encoding="b32"): +def decode(share, encoding="b32"): try: - (i,_,shareStr)=share.partition(".") - i=int(i) + (i, _, shareStr) = share.partition(".") + i = int(i) if not 1<=i<=255: raise MalformedShare("Malformed share: Failed 1<=k<=255, k={0}".format(i)) - if encoding=="hex": f=base64.b16decode - elif encoding=="b32": f=base64.b32decode - else: f=base64.b64decode - shareBytes=f(shareStr) - return (i,shareBytes) - except (ValueError,binascii.Error): - raise MalformedShare('Malformed share: share="{0}", encoding="{1}"'.format(share,encoding)) + if encoding=="hex": f = base64.b16decode + elif encoding=="b32": f = base64.b32decode + else: f = base64.b64decode + shareBytes = f(shareStr) + return (i, shareBytes) + except (ValueError, binascii.Error): + raise MalformedShare('Malformed share: share="{0}", encoding="{1}"'.format(share, encoding)) def detectEncoding(shares): - classes=[ + classes = [ (re.compile(r"\d+\.([0-9A-F]{2})+"), "hex"), (re.compile(r"\d+\.([A-Z2-7]{8})*([A-Z2-7]{8}|[A-Z2-7]{2}={6}|[A-Z2-7]{4}={4}|[A-Z2-7]{5}={3}|[A-Z2-7]{7}={1})"), "b32"), (re.compile(r"\d+\.([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{2}={2}|[A-Za-z0-9+/]{3}={1})"), "b64")