diff --git a/src/cli.py b/src/cli.py --- a/src/cli.py +++ b/src/cli.py @@ -28,6 +28,9 @@ def build_split_parser(parser): encoding.add_argument("--b32", action="store_true", help="encode shares' bytes as a base32 string") encoding.add_argument("--b64", action="store_true", help="encode shares' bytes as a base64 string") + parser.add_argument("--label", help="any label to prefix the shares with") + parser.add_argument("--omit_k_n", action="store_true", help="suppress the default shares prefix") + parser.add_argument("secret", nargs="?", help="a secret to be split. Can be provided on the command line," " redirected through stdin, or will be asked for interactively.") parser.set_defaults(func=_generate) @@ -56,7 +59,7 @@ def _generate(args): secret = sys.stdin.read() try: - shares = generate(secret, args.k, args.n, encoding) + shares = generate(secret, args.k, args.n, encoding, label=args.label, omit_k_n=args.omit_k_n) for s in shares: print(s) except SException as e: diff --git a/src/condensed.py b/src/condensed.py --- a/src/condensed.py +++ b/src/condensed.py @@ -58,7 +58,7 @@ class SException(Exception): pass def reconstruct_raw(*shares): """Tries to recover the secret from its shares. - :param shares: ((i, (bytes) share), ...) + :param shares: (((int) i, (bytes) share), ...) :return: (bytes) reconstructed secret. Too few shares returns garbage.""" secret_len = len(shares[0][1]) res = [None]*secret_len @@ -83,7 +83,7 @@ def reconstruct(*shares): def decode(share): try: - (i, _, share_str) = share.partition(".") + (*_, i, share_str) = share.split(".") i = int(i) if not 1<=i<=255: raise SException("Malformed share: Failed 1<=k<=255, k={0}".format(i)) diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -38,8 +38,8 @@ def generate_raw(secret, k, n): def reconstruct_raw(*shares): """Tries to recover the secret from its shares. - :param shares: ((i, (bytes) share), ...) - :return: (bytes) reconstructed secret. Too few shares returns garbage.""" + :param shares: (((int) i, (bytes) share), ...) + :return: (bytes) reconstructed secret. Too few shares return garbage.""" secret_len = len(shares[0][1]) res = [None]*secret_len for i in range(secret_len): @@ -48,18 +48,27 @@ def reconstruct_raw(*shares): return bytes(res) -def generate(secret, k, n, encoding="b32"): +def generate(secret, k, n, encoding="b32", label="", omit_k_n=False): """Wraps generate_raw(). :param secret: (str or bytes) :param k: number of shares necessary for secret recovery :param n: number of shares generated :param encoding: {hex, b32, b64} desired output encoding. Hexadecimal, Base32 or Base64. + :param label: (str) any label to prefix the shares with + :param omit_k_n: (boolean) suppress the default shares prefix :return: [(str) share, ...]""" if isinstance(secret,str): secret = secret.encode("utf-8") shares = generate_raw(secret, k, n) - return [encode(s, encoding) for s in shares] + + prefix = "" + if label: + prefix = label + "." + if not omit_k_n: + prefix += "{0}.{1}.".format(k, n) + + return [prefix + encode(s, encoding) for s in shares] def reconstruct(*shares, encoding="", raw=False): @@ -89,7 +98,7 @@ def encode(share, encoding="b32"): def decode(share, encoding="b32"): try: - (i, _, share_str) = share.partition(".") + (*_, i, share_str) = share.split(".") i = int(i) if not 1<=i<=255: raise MalformedShare("Malformed share: Failed 1<=k<=255, k={0}".format(i)) @@ -104,9 +113,9 @@ def decode(share, encoding="b32"): def detect_encoding(shares): classes = [ - (re.compile(r"\d+\.([0-9A-F]{2})+"), "hex"), - (re.compile(r"\d+\.([A-Z2-7]{8})*([A-Z2-7]{8}|[A-Z2-7]{2}={6}|[A-Z2-7]{4}={4}|[A-Z2-7]{5}={3}|[A-Z2-7]{7}={1})"), "b32"), - (re.compile(r"\d+\.([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{2}={2}|[A-Za-z0-9+/]{3}={1})"), "b64") + (re.compile(r"(.*\.)?\d+\.([0-9A-F]{2})+"), "hex"), + (re.compile(r"(.*\.)?\d+\.([A-Z2-7]{8})*([A-Z2-7]{8}|[A-Z2-7]{2}={6}|[A-Z2-7]{4}={4}|[A-Z2-7]{5}={3}|[A-Z2-7]{7}={1})"), "b32"), + (re.compile(r"(.*\.)?\d+\.([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{2}={2}|[A-Za-z0-9+/]{3}={1})"), "b64") ] for (regexp, res) in classes: if all(regexp.fullmatch(share) for share in shares):