import os
import re
import base64
import binascii

import gf256


def _checkParams(k, n):
    """Validate the threshold parameters used when splitting a secret.

    Explicit checks are used instead of ``assert`` because assertions are
    stripped when Python runs with ``-O``.

    :param k: number of shares necessary for secret recovery
    :param n: number of shares generated
    :raises ValueError: unless 1 <= k <= n < 255"""
    if not 1 <= k <= n < 255:
        raise ValueError(
            "invalid parameters: expected 1 <= k <= n < 255, got k={0}, n={1}".format(k, n))


def _shareByte(secretB, k, n):
    """Split a single secret byte into n points over GF(256).

    The byte is the first coefficient handed to gf256.evaluate (it is the value
    later recovered by gf256.getConstantCoef); the remaining k-1 coefficients
    are random bytes from os.urandom.

    :param secretB: (int) secret byte value
    :param k: number of points necessary for recovery. 1 <= k <= n
    :param n: number of points generated. n < 255
    :return: [y at x=1, y at x=2, ..., y at x=n]
    :raises ValueError: on invalid k or n"""
    _checkParams(k, n)
    coefs = [int(secretB)] + list(os.urandom(k - 1))
    return [gf256.evaluate(coefs, x) for x in range(1, n + 1)]


def generateRaw(secret, k, n):
    """Splits secret into shares.

    Every byte of the secret is shared independently, so each share is exactly
    as long as the secret.

    :param secret: (bytes)
    :param k: number of shares necessary for secret recovery. 1 <= k <= n
    :param n: (int) number of shares generated. 1 <= n < 255
    :return: [(i, (bytes) share), ...] with i running from 1 to n
    :raises ValueError: on invalid k or n (also for an empty secret)"""
    # Validate up front so an empty secret (no calls to _shareByte) is still checked.
    _checkParams(k, n)
    perByte = [_shareByte(b, k, n) for b in secret]
    # Transpose: share number i collects the i-th point of every byte's polynomial.
    return [(i + 1, bytes(points[i] for points in perByte)) for i in range(n)]


def reconstructRaw(*shares):
    """Tries to recover the secret from its shares.

    :param shares: ((i, (bytes) share), ...)
    :return: (bytes) reconstructed secret. Too few shares returns garbage.
    :raises ValueError: if no shares are given, if two shares carry the same
        index, or if the shares differ in length"""
    if not shares:
        raise ValueError("no shares provided")
    xs = [x for (x, _) in shares]
    # Interpolation needs distinct x coordinates.
    if len(set(xs)) != len(xs):
        raise ValueError("duplicate share indices: {0}".format(sorted(xs)))
    secretLen = len(shares[0][1])
    # Unequal lengths would otherwise truncate silently or raise IndexError.
    if any(len(s) != secretLen for (_, s) in shares):
        raise ValueError("shares have different lengths")
    return bytes(
        gf256.getConstantCoef(*[(x, s[pos]) for (x, s) in shares])
        for pos in range(secretLen))


def generate(secret, k, n, encoding="b32"):
    """Wraps generateRaw().

    :param secret: (str or bytes) str is encoded as UTF-8 first
    :param k: number of shares necessary for secret recovery
    :param n: number of shares generated
    :param encoding: {hex, b32, b64} desired output encoding. Hexadecimal, Base32 or Base64.
    :return: [(str) share, ...] each formatted as "<index>.<encoded bytes>"
    :raises ValueError: on invalid k or n"""
    if isinstance(secret, str):
        secret = secret.encode("utf-8")
    shares = generateRaw(secret, k, n)
    return [encode(s, encoding) for s in shares]


def reconstruct(*shares, encoding="", raw=False):
    """Wraps reconstructRaw.

    :param shares: ((str) share, ...)
    :param encoding: {hex, b32, b64, ""} encoding of share strings. If not provided or empty, the function tries to guess it.
    :param raw: (bool) whether to return bytes (True) or str (False)
    :return: (str or bytes) reconstructed secret. Too few shares returns garbage;
        with raw=False, garbage that is not valid UTF-8 raises UnicodeDecodeError.
    :raises ValueError: on malformed shares or undetectable encoding"""
    if not encoding:
        encoding = detectEncoding(shares)
    bs = reconstructRaw(*(decode(s, encoding) for s in shares))
    return bs if raw else bs.decode(encoding="utf-8")


# Encoding name -> codec. Unknown names fall back to Base64 (historical behaviour).
_ENCODERS = {
    "hex": base64.b16encode,
    "b32": base64.b32encode,
    "b64": base64.b64encode,
}


def _b64decodeStrict(s):
    """Base64-decode, rejecting non-alphabet characters instead of silently dropping them."""
    return base64.b64decode(s, validate=True)


_DECODERS = {
    "hex": base64.b16decode,
    "b32": base64.b32decode,
    "b64": _b64decodeStrict,
}


def encode(share, encoding="b32"):
    """Format a raw share as "<index>.<encoded bytes>".

    :param share: (i, (bytes) share)
    :param encoding: {hex, b32, b64}; any other value is treated as b64.
        hex produces uppercase Base16.
    :return: (str) encoded share"""
    (i, bs) = share
    f = _ENCODERS.get(encoding, base64.b64encode)
    return "{0}.{1}".format(i, f(bs).decode("utf-8"))


def decode(share, encoding="b32"):
    """Parse a "<index>.<encoded bytes>" share string.

    :param share: (str) encoded share
    :param encoding: {hex, b32, b64}; any other value is treated as b64
    :return: (i, (bytes) share)
    :raises ValueError: if the "." separator is missing, the index is not an
        integer in 1..255, or the payload is not valid for the encoding"""
    f = _DECODERS.get(encoding, _b64decodeStrict)
    try:
        (iStr, sep, shareStr) = share.partition(".")
        if not sep:
            raise ValueError("missing '.' separator")
        i = int(iStr)
        if not 1 <= i <= 255:
            raise ValueError("share index out of range")
        return (i, f(shareStr))
    except (ValueError, binascii.Error) as err:
        raise ValueError(
            'bad share format: share="{0}", encoding="{1}"'.format(share, encoding)) from err


# Compiled once; checked in order, the first encoding matching every share wins.
# The hex payload may be empty so that shares of an empty secret ("1.") are detected.
_ENCODING_PATTERNS = [
    (re.compile(r"\d+\.([0-9A-F]{2})*"), "hex"),
    (re.compile(r"\d+\.([A-Z2-7]{8})*([A-Z2-7]{8}|[A-Z2-7]{2}={6}|[A-Z2-7]{4}={4}|[A-Z2-7]{5}={3}|[A-Z2-7]{7}={1})"), "b32"),
    (re.compile(r"\d+\.([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{2}={2}|[A-Za-z0-9+/]{3}={1})"), "b64"),
]


def detectEncoding(shares):
    """Guess the encoding shared by all share strings.

    NOTE: detection is heuristic. A share that happens to be valid in an
    earlier-checked alphabet (hex, then b32, then b64) is reported as that
    encoding, e.g. a Base32 payload using only A-F and 2-7. Pass the encoding
    explicitly to reconstruct() when it is known.

    :param shares: ((str) share, ...)
    :return: "hex", "b32" or "b64"
    :raises ValueError: if no encoding matches every share"""
    for (regexp, res) in _ENCODING_PATTERNS:
        if all(regexp.fullmatch(share) for share in shares):
            return res
    raise ValueError("no expected encoding detected")


if __name__ == "__main__":
    import cli
    cli.run()