"""k-of-n secret sharing applied byte by byte, with field arithmetic delegated to gf256.

Raw shares are (index, bytes) pairs; text shares are "<index>.<encoded bytes>"
strings using hex, Base32 or Base64.
"""
import os
import sys
import re
import base64
import gf256


# Text codecs for share payloads, keyed by encoding name.
_ENCODERS = {
    "hex": base64.b16encode,
    "b32": base64.b32encode,
    "b64": base64.b64encode,
}
_DECODERS = {
    # b16encode emits uppercase; casefold=True also accepts lowercase input.
    "hex": lambda s: base64.b16decode(s, casefold=True),
    "b32": base64.b32decode,
    "b64": base64.b64decode,
}
# Checked in order, most restrictive alphabet first: any string matching the
# hex or base32 pattern also matches the base64 one. The hex pattern is
# uppercase because that is what b16encode produces.
_ENCODING_PATTERNS = [
    (re.compile(r"\d+\.[0-9A-F]+"), "hex"),
    (re.compile(r"\d+\.[A-Z2-7]+=*"), "b32"),
    (re.compile(r"\d+\.[A-Za-z0-9+/]+=*"), "b64"),
]


def _checkParams(k, n):
    """Validate the threshold parameters.

    :param k: number of shares required for recovery; must satisfy 1 <= k <= n
    :param n: number of shares generated; must satisfy 1 <= n < 255
    :raises ValueError: if either bound is violated
    """
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if not 1 <= n < 255:
        raise ValueError("n must satisfy 1 <= n < 255, got {0}".format(n))
    if not 1 <= k <= n:
        raise ValueError(
            "k must satisfy 1 <= k <= n, got k={0}, n={1}".format(k, n))


def _shareByte(secretB, k, n):
    """Split one secret byte into n share values.

    Builds the coefficient list [secretB, r1, ..., r(k-1)] with the r's drawn
    from os.urandom, and evaluates it with gf256.evaluate at x = 1..n.
    NOTE(review): assumes gf256.evaluate treats coefs[0] as the constant term,
    which is what reconstructRaw recovers via gf256.getConstantCoef; verify
    against gf256.

    :param secretB: (int) secret byte value
    :param k: number of shares necessary for recovery, 1 <= k <= n
    :param n: number of shares generated, 1 <= n < 255
    :return: [value at x=1, ..., value at x=n]
    :raises ValueError: on invalid k or n
    """
    _checkParams(k, n)
    coefs = [int(secretB)] + list(os.urandom(k - 1))
    return [gf256.evaluate(coefs, x) for x in range(1, n + 1)]


def generateRaw(secret, k, n):
    """Splits secret into shares.

    Each secret byte is shared independently (see _shareByte). Share number i
    collects the i-th value of every byte, so every share is as long as the
    secret.

    :param secret: (bytes)
    :param k: number of shares necessary for secret recovery. 1 <= k <= n
    :param n: (int) number of shares generated. 1 <= n < 255
    :return: [(i, (bytes) share), ...] with i = 1..n
    :raises ValueError: on invalid k or n
    """
    # Validate up front so that an empty secret (no _shareByte calls) is checked too.
    _checkParams(k, n)
    perByte = [_shareByte(b, k, n) for b in secret]
    return [(i + 1, bytes(values[i] for values in perByte)) for i in range(n)]


def reconstructRaw(*shares):
    """Tries to recover the secret from its shares.

    :param shares: ((i, (bytes) share), ...) as produced by generateRaw()
    :return: (bytes) reconstructed secret. Too few shares returns garbage.
    :raises ValueError: if no shares are given, their lengths differ, or two
        shares carry the same index
    """
    if not shares:
        raise ValueError("at least one share is required")
    secretLen = len(shares[0][1])
    if any(len(s) != secretLen for (_, s) in shares):
        raise ValueError("shares have different lengths")
    # Recovering a polynomial from points requires distinct x-coordinates;
    # a repeated index adds no information and makes interpolation degenerate.
    if len({x for (x, _) in shares}) != len(shares):
        raise ValueError("duplicate share indices")
    return bytes(
        gf256.getConstantCoef(*[(x, s[i]) for (x, s) in shares])
        for i in range(secretLen)
    )


def generate(secret, k, n, encoding="b32"):
    """Wraps generateRaw() and encodes every share as text.

    :param secret: (str or bytes); a str is encoded as UTF-8 first
    :param k: number of shares necessary for secret recovery
    :param n: number of shares generated
    :param encoding: {hex, b32, b64} desired output encoding. Hexadecimal, Base32 or Base64.
    :return: [(str) share, ...], each formatted "<index>.<encoded bytes>"
    :raises ValueError: on invalid k, n or encoding
    """
    if isinstance(secret, str):
        secret = secret.encode("utf-8")
    return [encode(s, encoding) for s in generateRaw(secret, k, n)]


def reconstruct(*shares, encoding="", raw=False):
    """Wraps reconstructRaw().

    :param shares: ((str) share, ...)
    :param encoding: {hex, b32, b64, ""} encoding of share strings. If not provided or empty, the function tries to guess it.
    :param raw: (bool) whether to return bytes (True) or str (False)
    :return: (str or bytes) reconstructed secret. Too few shares returns garbage.
    :raises ValueError: on malformed shares or an undetectable/unknown encoding;
        UnicodeDecodeError if raw is False and the result is not valid UTF-8
    """
    if not encoding:
        encoding = detectEncoding(shares)
    # reconstructRaw takes the decoded shares as separate positional arguments.
    bs = reconstructRaw(*(decode(s, encoding) for s in shares))
    return bs if raw else bs.decode(encoding="utf-8")


def encode(share, encoding="b32"):
    """Format one raw share as text.

    :param share: (i, (bytes) share) as produced by generateRaw()
    :param encoding: {hex, b32, b64}
    :return: (str) "<i>.<encoded bytes>"
    :raises ValueError: on an unknown encoding
    """
    (i, shareBytes) = share
    try:
        f = _ENCODERS[encoding]
    except KeyError:
        raise ValueError("unknown encoding: {0!r}".format(encoding)) from None
    return "{0}.{1}".format(i, f(shareBytes).decode("utf8"))


def decode(share, encoding="b32"):
    """Parse one text share back into (index, bytes).

    :param share: (str) "<i>.<encoded bytes>" as produced by encode()
    :param encoding: {hex, b32, b64}
    :return: (i, (bytes) share)
    :raises ValueError: on a malformed share or an unknown encoding
        (base64's binascii.Error for a bad payload is a ValueError subclass)
    """
    (i, _, shareStr) = share.partition(".")
    if not shareStr:
        raise ValueError("bad share format")
    try:
        f = _DECODERS[encoding]
    except KeyError:
        raise ValueError("unknown encoding: {0!r}".format(encoding)) from None
    return (int(i), f(shareStr))


def detectEncoding(shares):
    """Guess the encoding common to all of the given text shares.

    :param shares: iterable of (str) shares
    :return: "hex", "b32" or "b64", the first pattern every share fully matches
    :raises ValueError: if shares is empty or no single encoding fits them all
    """
    # Materialize once: the shares are scanned again for every candidate pattern.
    shares = tuple(shares)
    if shares:
        for (regexp, res) in _ENCODING_PATTERNS:
            if all(regexp.fullmatch(share) for share in shares):
                return res
    raise ValueError("no expected encoding detected")


if __name__ == "__main__":
    if len(sys.argv) < 4:
        sys.exit("usage: {0} SECRET K N [hex|b32|b64]".format(sys.argv[0]))
    secret = sys.argv[1].encode("utf8")
    k = int(sys.argv[2])
    n = int(sys.argv[3])
    # Default matches generate()'s default encoding.
    output = sys.argv[4] if len(sys.argv) > 4 else "b32"
    for share in generate(secret, k, n, output):
        print(share)