diff --git a/src/cli.py b/src/cli.py --- a/src/cli.py +++ b/src/cli.py @@ -6,47 +6,47 @@ from shamira import generate, reconstruc def run(): - parser=ArgumentParser() - subparsers=parser.add_subparsers() + parser = ArgumentParser() + subparsers = parser.add_subparsers() buildSplitParser(subparsers.add_parser("split")) buildJoinParser(subparsers.add_parser("join")) parser.set_defaults(func=lambda _: parser.error("missing command")) - args=parser.parse_args() + args = parser.parse_args() args.func(args) def buildSplitParser(parser): - parser.add_argument("-k",type=int,required=True,help="number of shares necessary for recovering the secret") - parser.add_argument("-n",type=int,required=True,help="number of generated shares") + parser.add_argument("-k", type=int, required=True, help="number of shares necessary for recovering the secret") + parser.add_argument("-n", type=int, required=True, help="number of generated shares") encoding=parser.add_mutually_exclusive_group() - encoding.add_argument("--hex",action="store_true",help="encode shares' bytes as a hexadecimal string") - encoding.add_argument("--b32",action="store_true",help="encode shares' bytes as a base32 string") - encoding.add_argument("--b64",action="store_true",help="encode shares' bytes as a base64 string") + encoding.add_argument("--hex", action="store_true", help="encode shares' bytes as a hexadecimal string") + encoding.add_argument("--b32", action="store_true", help="encode shares' bytes as a base32 string") + encoding.add_argument("--b64", action="store_true", help="encode shares' bytes as a base64 string") - parser.add_argument("secret",help="secret to be parsed") + parser.add_argument("secret", help="secret to be parsed") parser.set_defaults(func=_generate) def buildJoinParser(parser): encoding=parser.add_mutually_exclusive_group() - encoding.add_argument("--hex",action="store_true",help="decode shares' bytes from a hexadecimal string") - encoding.add_argument("--b32",action="store_true",help="decode shares' bytes from a base32 string") - encoding.add_argument("--b64",action="store_true",help="decode shares' bytes from a base64 string") + encoding.add_argument("--hex", action="store_true", help="decode shares' bytes from a hexadecimal string") + encoding.add_argument("--b32", action="store_true", help="decode shares' bytes from a base32 string") + encoding.add_argument("--b64", action="store_true", help="decode shares' bytes from a base64 string") - parser.add_argument("-r","--raw",action="store_true",help="return secret as raw bytes") - parser.add_argument("share",nargs="+",help="shares to be joined") + parser.add_argument("-r", "--raw", action="store_true", help="return secret as raw bytes") + parser.add_argument("share", nargs="+", help="shares to be joined") parser.set_defaults(func=_reconstruct) def _generate(args): - encoding=getEncoding(args) or "b32" + encoding = getEncoding(args) or "b32" try: - shares=generate(args.secret,args.k,args.n,encoding) + shares = generate(args.secret, args.k, args.n, encoding) for s in shares: print(s) except SException as e: @@ -54,9 +54,9 @@ def _generate(args): def _reconstruct(args): - encoding=getEncoding(args) + encoding = getEncoding(args) try: - print(reconstruct(*args.share,encoding=encoding,raw=args.raw)) + print(reconstruct(*args.share, encoding=encoding, raw=args.raw)) except SException as e: print(e) diff --git a/src/condensed.py b/src/condensed.py --- a/src/condensed.py +++ b/src/condensed.py @@ -7,43 +7,43 @@ """Arithmetic operations on Galois Field 2**8. See https://en.wikipedia.org/wiki/Finite_field_arithmetic""" -def gfmul(a,b): +def gfmul(a, b): """Basic multiplication. Russian peasant algorithm.""" - res=0 + res = 0 while a and b: - if b&1: res^=a - if a&0x80: a=0xff&(a<<1)^0x1b - else: a<<=1 - b>>=1 + if b&1: res ^= a + if a&0x80: a = 0xff&(a<<1)^0x1b + else: a <<= 1 + b >>= 1 return res -g=3 # generator -E=[None]*256 # exponentials -L=[None]*256 # logarithms -acc=1 +g = 3 # generator +E = [None]*256 # exponentials +L = [None]*256 # logarithms +acc = 1 for i in range(256): - E[i]=acc - L[acc]=i - acc=gfmul(acc, g) -L[1]=0 -inv=[E[255-L[i]] if i!=0 else None for i in range(256)] # multiplicative inverse + E[i] = acc + L[acc] = i + acc = gfmul(acc, g) +L[1] = 0 +inv = [E[255-L[i]] if i!=0 else None for i in range(256)] # multiplicative inverse def getConstantCoef(*points): """Compute constant polynomial coefficient given the points. See https://en.wikipedia.org/wiki/Shamir's_Secret_Sharing#Computationally_Efficient_Approach""" - k=len(points) - res=0 + k = len(points) + res = 0 for i in range(k): - (x,y)=points[i] - prod=1 + (x, y) = points[i] + prod = 1 for j in range(k): if i==j: continue - (xj,yj)=points[j] - prod=gfmul(prod, (gfmul(xj,inv[xj^x]))) - res^=gfmul(y,prod) + (xj, yj) = points[j] + prod = gfmul(prod, (gfmul(xj, inv[xj^x]))) + res ^= gfmul(y, prod) return res ### @@ -60,11 +60,11 @@ def reconstructRaw(*shares): :param shares: ((i, (bytes) share), ...) :return: (bytes) reconstructed secret. Too few shares returns garbage.""" - secretLen=len(shares[0][1]) - res=[None]*secretLen + secretLen = len(shares[0][1]) + res = [None]*secretLen for i in range(secretLen): - points=[(x,s[i]) for (x,s) in shares] - res[i]=(getConstantCoef(*points)) + points = [(x, s[i]) for (x, s) in shares] + res[i] = (getConstantCoef(*points)) return bytes(res) @@ -74,7 +74,7 @@ def reconstruct(*shares): :param shares: ((str) share, ...) :return: (str) reconstructed secret. Too few shares returns garbage.""" - bs=reconstructRaw(*(decode(s) for s in shares)) + bs = reconstructRaw(*(decode(s) for s in shares)) try: return bs.decode(encoding="utf-8") except UnicodeDecodeError: @@ -83,14 +83,14 @@ def reconstruct(*shares): def decode(share): try: - (i,_,shareStr)=share.partition(".") - i=int(i) + (i, _, shareStr) = share.partition(".") + i = int(i) if not 1<=i<=255: raise SException("Malformed share: Failed 1<=k<=255, k={0}".format(i)) - shareBytes=base64.b32decode(shareStr) - return (i,shareBytes) - except (ValueError,binascii.Error): + shareBytes = base64.b32decode(shareStr) + return (i, shareBytes) + except (ValueError, binascii.Error): raise SException('Malformed share: share="{0}"'.format(share)) ### @@ -99,16 +99,16 @@ from argparse import ArgumentParser def run(): - parser=ArgumentParser() - subparsers=parser.add_subparsers() + parser = ArgumentParser() + subparsers = parser.add_subparsers() - joiner=subparsers.add_parser("join") - joiner.add_argument("share",nargs="+",help="shares to be joined") + joiner = subparsers.add_parser("join") + joiner.add_argument("share", nargs="+", help="shares to be joined") joiner.set_defaults(func=_reconstruct) parser.set_defaults(func=lambda: parser.error("missing command")) - args=parser.parse_args() + args = parser.parse_args() args.func(args) diff --git a/src/gf256.py b/src/gf256.py --- a/src/gf256.py +++ b/src/gf256.py @@ -3,47 +3,47 @@ """Arithmetic operations on Galois Field 2**8. See https://en.wikipedia.org/wiki/Finite_field_arithmetic""" -def _gfmul(a,b): +def _gfmul(a, b): """Basic multiplication. Russian peasant algorithm.""" - res=0 + res = 0 while a and b: - if b&1: res^=a - if a&0x80: a=0xff&(a<<1)^0x1b - else: a<<=1 - b>>=1 + if b&1: res ^= a + if a&0x80: a = 0xff&(a<<1)^0x1b + else: a <<= 1 + b >>= 1 return res -g=3 # generator -E=[None]*256 # exponentials -L=[None]*256 # logarithms -acc=1 +g = 3 # generator +E = [None]*256 # exponentials +L = [None]*256 # logarithms +acc = 1 for i in range(256): - E[i]=acc - L[acc]=i - acc=_gfmul(acc, g) -L[1]=0 -inv=[E[255-L[i]] if i!=0 else None for i in range(256)] # multiplicative inverse + E[i] = acc + L[acc] = i + acc = _gfmul(acc, g) +L[1] = 0 +inv = [E[255-L[i]] if i != 0 else None for i in range(256)] # multiplicative inverse def gfmul(a, b): """Fast multiplication. Basic multiplication is expensive. a*b==g**(log(a)+log(b))""" assert 0<=a<=255, 0<=b<=255 if a==0 or b==0: return 0 - t=L[a]+L[b] - if t>255: t-=255 + t = L[a]+L[b] + if t>255: t -= 255 return E[t] -def evaluate(coefs,x): +def evaluate(coefs, x): """Evaluate polynomial's value at x. :param coefs: [a0, a1, ...].""" - res=0 - xK=1 + res = 0 + xK = 1 for a in coefs: - res^=gfmul(a,xK) - xK=gfmul(xK,x) + res ^= gfmul(a, xK) + xK = gfmul(xK, x) return res @@ -51,14 +51,14 @@ def getConstantCoef(*points): """Compute constant polynomial coefficient given the points. See https://en.wikipedia.org/wiki/Shamir's_Secret_Sharing#Computationally_Efficient_Approach""" - k=len(points) - res=0 + k = len(points) + res = 0 for i in range(k): - (x,y)=points[i] - prod=1 + (x, y) = points[i] + prod = 1 for j in range(k): if i==j: continue - (xj,yj)=points[j] - prod=gfmul(prod, (gfmul(xj,inv[xj^x]))) - res^=gfmul(y,prod) + (xj, yj) = points[j] + prod = gfmul(prod, (gfmul(xj, inv[xj^x]))) + res ^= gfmul(y, prod) return res diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -15,23 +15,23 @@ class DecodingException(SException): pas class MalformedShare(SException): pass -def _shareByte(secretB,k,n): +def _shareByte(secretB, k, n): if not k<=n<255: - raise InvalidParams("Failed k<=n<255, k={0}, n={1}".format(k,n)) + raise InvalidParams("Failed k<=n<255, k={0}, n={1}".format(k, n)) # we might be concerned with zero coefficients degenerating our polynomial, but there's no reason - we still need k shares to determine it is the case - coefs=[int(secretB)]+[int(b) for b in os.urandom(k-1)] - points=[gf256.evaluate(coefs,i) for i in range(1,n+1)] + coefs = [int(secretB)]+[int(b) for b in os.urandom(k-1)] + points = [gf256.evaluate(coefs, i) for i in range(1, n+1)] return points -def generateRaw(secret,k,n): +def generateRaw(secret, k, n): """Splits secret into shares. :param secret: (bytes) :param k: number of shares necessary for secret recovery. 1 <= k <= n :param n: (int) number of shares generated. 1 <= n < 255 :return: [(i, (bytes) share), ...]""" - shares=[_shareByte(b,k,n) for b in secret] + shares = [_shareByte(b, k, n) for b in secret] return [(i+1, bytes([s[i] for s in shares])) for i in range(n)] @@ -40,15 +40,15 @@ def reconstructRaw(*shares): :param shares: ((i, (bytes) share), ...) :return: (bytes) reconstructed secret. Too few shares returns garbage.""" - secretLen=len(shares[0][1]) - res=[None]*secretLen + secretLen = len(shares[0][1]) + res = [None]*secretLen for i in range(secretLen): - points=[(x,s[i]) for (x,s) in shares] - res[i]=(gf256.getConstantCoef(*points)) + points = [(x, s[i]) for (x, s) in shares] + res[i] = (gf256.getConstantCoef(*points)) return bytes(res) -def generate(secret,k,n,encoding="b32"): +def generate(secret, k, n, encoding="b32"): """Wraps generateRaw(). :param secret: (str or bytes) @@ -57,12 +57,12 @@ def generate(secret,k,n,encoding="b32"): :param encoding: {hex, b32, b64} desired output encoding. Hexadecimal, Base32 or Base64. :return: [(str) share, ...]""" if isinstance(secret,str): - secret=secret.encode("utf-8") - shares=generateRaw(secret,k,n) - return [encode(s,encoding) for s in shares] + secret = secret.encode("utf-8") + shares = generateRaw(secret, k, n) + return [encode(s, encoding) for s in shares] -def reconstruct(*shares,encoding="",raw=False): +def reconstruct(*shares, encoding="", raw=False): """Wraps reconstructRaw. :param shares: ((str) share, ...) @@ -70,40 +70,40 @@ def reconstruct(*shares,encoding="",raw= :param raw: (bool) whether to return bytes (True) or str (False) :return: (str or bytes) reconstructed secret. Too few shares returns garbage.""" if not encoding: - encoding=detectEncoding(shares) + encoding = detectEncoding(shares) - bs=reconstructRaw(*(decode(s,encoding) for s in shares)) + bs = reconstructRaw(*(decode(s, encoding) for s in shares)) try: return bs if raw else bs.decode(encoding="utf-8") except UnicodeDecodeError: raise DecodingException('Failed to decode bytes to utf-8. Either you supplied invalid shares, or you missed the "raw" flag. Offending value: {0}'.format(bs)) -def encode(share,encoding="b32"): - if encoding=="hex": f=base64.b16encode - elif encoding=="b32": f=base64.b32encode - else: f=base64.b64encode - (i,bs)=share - return "{0}.{1}".format(i,f(bs).decode("utf-8")) +def encode(share, encoding="b32"): + if encoding=="hex": f = base64.b16encode + elif encoding=="b32": f = base64.b32encode + else: f = base64.b64encode + (i, bs) = share + return "{0}.{1}".format(i, f(bs).decode("utf-8")) -def decode(share,encoding="b32"): +def decode(share, encoding="b32"): try: - (i,_,shareStr)=share.partition(".") - i=int(i) + (i, _, shareStr) = share.partition(".") + i = int(i) if not 1<=i<=255: raise MalformedShare("Malformed share: Failed 1<=k<=255, k={0}".format(i)) - if encoding=="hex": f=base64.b16decode - elif encoding=="b32": f=base64.b32decode - else: f=base64.b64decode - shareBytes=f(shareStr) - return (i,shareBytes) - except (ValueError,binascii.Error): - raise MalformedShare('Malformed share: share="{0}", encoding="{1}"'.format(share,encoding)) + if encoding=="hex": f = base64.b16decode + elif encoding=="b32": f = base64.b32decode + else: f = base64.b64decode + shareBytes = f(shareStr) + return (i, shareBytes) + except (ValueError, binascii.Error): + raise MalformedShare('Malformed share: share="{0}", encoding="{1}"'.format(share, encoding)) def detectEncoding(shares): - classes=[ + classes = [ (re.compile(r"\d+\.([0-9A-F]{2})+"), "hex"), (re.compile(r"\d+\.([A-Z2-7]{8})*([A-Z2-7]{8}|[A-Z2-7]{2}={6}|[A-Z2-7]{4}={4}|[A-Z2-7]{5}={3}|[A-Z2-7]{7}={1})"), "b32"), (re.compile(r"\d+\.([A-Za-z0-9+/]{4})*([A-Za-z0-9+/]{4}|[A-Za-z0-9+/]{2}={2}|[A-Za-z0-9+/]{3}={1})"), "b64") diff --git a/src/tests/test_condensed.py b/src/tests/test_condensed.py --- a/src/tests/test_condensed.py +++ b/src/tests/test_condensed.py @@ -4,75 +4,75 @@ import os import random from unittest import TestCase -from gf256 import _gfmul,evaluate -from shamira import generateRaw,generate +from gf256 import _gfmul, evaluate +from shamira import generateRaw, generate from condensed import * class TestCondensed(TestCase): - _urandom=os.urandom + _urandom = os.urandom @classmethod def setUpClass(cls): random.seed(17) - os.urandom=lambda n: bytes(random.randint(0,255) for i in range(n)) + os.urandom = lambda n: bytes(random.randint(0, 255) for i in range(n)) @classmethod def tearDownClass(cls): - os.urandom=cls._urandom + os.urandom = cls._urandom def testGfmul(self): for a in range(256): for b in range(256): - self.assertEqual(_gfmul(a,b), gfmul(a,b)) + self.assertEqual(_gfmul(a, b), gfmul(a, b)) def testGetConstantCoef(self): - self.assertEqual(getConstantCoef((1,1),(2,2),(3,3)), 0) + self.assertEqual(getConstantCoef((1, 1), (2, 2), (3, 3)), 0) random.seed(17) - randomMatches=0 + randomMatches = 0 for i in range(10): - k=random.randint(2,255) + k = random.randint(2, 255) # exact - res=self.checkCoefsMatch(k,k) + res = self.checkCoefsMatch(k, k) self.assertEqual(res[0], res[1]) # overdetermined - res=self.checkCoefsMatch(k,256) + res = self.checkCoefsMatch(k, 256) self.assertEqual(res[0], res[1]) # underdetermined => random - res=self.checkCoefsMatch(k,k-1) + res = self.checkCoefsMatch(k, k-1) if res[0]==res[1]: - randomMatches+=1 - self.assertLess(randomMatches, 2) # with a chance (255/256)**10=0.96 there should be no match + randomMatches += 1 + self.assertLess(randomMatches, 2) # with a chance (255/256)**10=0.96 there should be no match - def checkCoefsMatch(self,k,m): - coefs=[random.randint(0,255) for i in range(k)] - points=[(j, evaluate(coefs,j)) for j in range(1,256)] + def checkCoefsMatch(self, k, m): + coefs = [random.randint(0, 255) for i in range(k)] + points = [(j, evaluate(coefs, j)) for j in range(1, 256)] random.shuffle(points) return (getConstantCoef(*points[:m]), coefs[0]) def testGenerateReconstructRaw(self): - for (k,n) in [(2,3), (254,254)]: - shares=generateRaw(b"abcd",k,n) + for (k, n) in [(2, 3), (254, 254)]: + shares = generateRaw(b"abcd", k, n) random.shuffle(shares) self.assertEqual(reconstructRaw(*shares[:k]), b"abcd") self.assertNotEqual(reconstructRaw(*shares[:k-1]), b"abcd") def testGenerateReconstruct(self): - for secret in ["abcde","ěščřžý"]: - for (k,n) in [(2,3), (254,254)]: - with self.subTest(sec=secret,k=k,n=n): - shares=generate(secret,k,n) + for secret in ["abcde", "ěščřžý"]: + for (k, n) in [(2, 3), (254, 254)]: + with self.subTest(sec=secret, k=k, n=n): + shares = generate(secret, k, n) random.shuffle(shares) self.assertEqual(reconstruct(*shares[:k]), secret) try: self.assertNotEqual(reconstruct(*shares[:k-1]), secret) except SException: pass - shares=generate(b"\xfeaa",2,3) + shares = generate(b"\xfeaa", 2, 3) with self.assertRaises(SException): reconstruct(*shares) @@ -85,4 +85,4 @@ class TestCondensed(TestCase): decode("1.AAAQEAY") decode("1.AAAQEAy=") decode("256.AAAQEAY=") - self.assertEqual(decode("2.AAAQEAY="), (2,b"\x00\x01\x02\x03")) + self.assertEqual(decode("2.AAAQEAY="), (2, b"\x00\x01\x02\x03")) diff --git a/src/tests/test_gf256.py b/src/tests/test_gf256.py --- a/src/tests/test_gf256.py +++ b/src/tests/test_gf256.py @@ -10,53 +10,53 @@ from gf256 import * class TestGF256(TestCase): def test_gfmul(self): - self.assertEqual(_gfmul(0,0), 0) - self.assertEqual(_gfmul(1,1), 1) - self.assertEqual(_gfmul(2,2), 4) - self.assertEqual(_gfmul(0,21), 0) - self.assertEqual(_gfmul(0x53,0xca), 0x01) - self.assertEqual(_gfmul(0xff,0xff), 0x13) + self.assertEqual(_gfmul(0, 0), 0) + self.assertEqual(_gfmul(1, 1), 1) + self.assertEqual(_gfmul(2, 2), 4) + self.assertEqual(_gfmul(0, 21), 0) + self.assertEqual(_gfmul(0x53, 0xca), 0x01) + self.assertEqual(_gfmul(0xff, 0xff), 0x13) def testGfmul(self): for a in range(256): for b in range(256): - self.assertEqual(_gfmul(a,b), gfmul(a,b)) + self.assertEqual(_gfmul(a, b), gfmul(a, b)) def testEvaluate(self): for x in range(256): - (a0,a1,a2,a3)=(x,x>>1,x>>2,x>>3) - self.assertEqual(evaluate([17],x), 17) # constant polynomial - self.assertEqual(evaluate([a0,a1,a2,a3],0), x) # any polynomial at 0 - self.assertEqual(evaluate([a0,a1,a2,a3],1), a0^a1^a2^a3) # polynomial at 1 == sum of coefficients + (a0, a1, a2, a3) = (x, x>>1, x>>2, x>>3) + self.assertEqual(evaluate([17], x), 17) # constant polynomial + self.assertEqual(evaluate([a0, a1, a2, a3], 0), x) # any polynomial at 0 + self.assertEqual(evaluate([a0, a1, a2, a3], 1), a0^a1^a2^a3) # polynomial at 1 == sum of coefficients def testGetConstantCoef(self): - self.assertEqual(getConstantCoef((1,1),(2,2),(3,3)), 0) + self.assertEqual(getConstantCoef((1, 1), (2, 2), (3, 3)), 0) random.seed(17) - randomMatches=0 + randomMatches = 0 for i in range(10): - k=random.randint(2,255) + k = random.randint(2, 255) # exact - res=self.checkCoefsMatch(k,k) + res = self.checkCoefsMatch(k, k) self.assertEqual(res[0], res[1]) # overdetermined - res=self.checkCoefsMatch(k,256) + res = self.checkCoefsMatch(k, 256) self.assertEqual(res[0], res[1]) # underdetermined => random - res=self.checkCoefsMatch(k,k-1) + res = self.checkCoefsMatch(k, k-1) if res[0]==res[1]: - randomMatches+=1 - self.assertLess(randomMatches, 2) # with a chance (255/256)**10=0.96 there should be no match + randomMatches += 1 + self.assertLess(randomMatches, 2) # with a chance (255/256)**10=0.96 there should be no match - def checkCoefsMatch(self,k,m): - coefs=[random.randint(0,255) for i in range(k)] - points=[(j, evaluate(coefs,j)) for j in range(1,256)] + def checkCoefsMatch(self, k, m): + coefs = [random.randint(0, 255) for i in range(k)] + points = [(j, evaluate(coefs, j)) for j in range(1, 256)] random.shuffle(points) return (getConstantCoef(*points[:m]), coefs[0]) -if __name__ == '__main__': +if __name__=='__main__': unittest.main() diff --git a/src/tests/test_shamira.py b/src/tests/test_shamira.py --- a/src/tests/test_shamira.py +++ b/src/tests/test_shamira.py @@ -8,59 +8,59 @@ from shamira import * class TestShamira(TestCase): - _urandom=os.urandom + _urandom = os.urandom @classmethod def setUpClass(cls): random.seed(17) - os.urandom=lambda n: bytes(random.randint(0,255) for i in range(n)) + os.urandom = lambda n: bytes(random.randint(0, 255) for i in range(n)) @classmethod def tearDownClass(cls): - os.urandom=cls._urandom + os.urandom = cls._urandom def test_shareByte(self): - with self.assertRaises(InvalidParams): # too few shares - _shareByte(b"a",5,4) - with self.assertRaises(InvalidParams): # too many shares - _shareByte(b"a",5,255) - with self.assertRaises(ValueError): # not castable to int - _shareByte("x",2,3) + with self.assertRaises(InvalidParams): # too few shares + _shareByte(b"a", 5, 4) + with self.assertRaises(InvalidParams): # too many shares + _shareByte(b"a", 5, 255) + with self.assertRaises(ValueError): # not castable to int + _shareByte("x", 2, 3) - vals=_shareByte(ord(b"a"),2,3) - points=list(zip(range(1,256), vals)) + vals = _shareByte(ord(b"a"), 2, 3) + points = list(zip(range(1, 256), vals)) self.assertEqual(gf256.getConstantCoef(*points), ord(b"a")) self.assertEqual(gf256.getConstantCoef(*points[:2]), ord(b"a")) - self.assertNotEqual(gf256.getConstantCoef(*points[:1]), ord(b"a")) # underdetermined => random + self.assertNotEqual(gf256.getConstantCoef(*points[:1]), ord(b"a")) # underdetermined => random def testGenerateReconstructRaw(self): - for (k,n) in [(2,3), (254,254)]: - shares=generateRaw(b"abcd",k,n) + for (k, n) in [(2, 3), (254, 254)]: + shares = generateRaw(b"abcd", k, n) random.shuffle(shares) self.assertEqual(reconstructRaw(*shares[:k]), b"abcd") self.assertNotEqual(reconstructRaw(*shares[:k-1]), b"abcd") def testGenerateReconstruct(self): - for encoding in ["hex","b32","b64"]: - for secret in [b"abcd","abcde","ěščřžý"]: - for (k,n) in [(2,3), (254,254)]: - raw=isinstance(secret,bytes) - with self.subTest(enc=encoding,r=raw,sec=secret,k=k,n=n): - shares=generate(secret,k,n,encoding) + for encoding in ["hex", "b32", "b64"]: + for secret in [b"abcd", "abcde", "ěščřžý"]: + for (k, n) in [(2, 3), (254, 254)]: + raw = isinstance(secret, bytes) + with self.subTest(enc=encoding, r=raw, sec=secret, k=k, n=n): + shares = generate(secret, k, n, encoding) random.shuffle(shares) - self.assertEqual(reconstruct(*shares[:k],encoding=encoding,raw=raw), secret) - self.assertEqual(reconstruct(*shares[:k],raw=raw), secret) - s=secret if raw else secret.encode("utf-8") - self.assertNotEqual(reconstruct(*shares[:k-1],encoding=encoding,raw=True), s) - shares=generate(b"\xfeaa",2,3) + self.assertEqual(reconstruct(*shares[:k], encoding=encoding, raw=raw), secret) + self.assertEqual(reconstruct(*shares[:k], raw=raw), secret) + s = secret if raw else secret.encode("utf-8") + self.assertNotEqual(reconstruct(*shares[:k-1], encoding=encoding, raw=True), s) + shares = generate(b"\xfeaa", 2, 3) with self.assertRaises(DecodingException): reconstruct(*shares) def testEncode(self): - share=(2,b"\x00\x01\x02") - for (encoding,encodedStr) in [("hex",'000102'),("b32",'AAAQE==='),("b64",'AAEC')]: + share = (2, b"\x00\x01\x02") + for (encoding, encodedStr) in [("hex", '000102'), ("b32", 'AAAQE==='), ("b64", 'AAEC')]: with self.subTest(enc=encoding): - self.assertEqual(encode(share,encoding), "2."+encodedStr) + self.assertEqual(encode(share, encoding), "2."+encodedStr) def testDecode(self): with self.assertRaises(MalformedShare): @@ -68,31 +68,30 @@ class TestShamira(TestCase): decode("1.") decode(".AAA") decode("1AAA") - decode("1.0001020f","hex") - decode("1.000102030","hex") + decode("1.0001020f", "hex") + decode("1.000102030", "hex") decode("1.AAAQEAY") decode("1.AAAQEAy=") - decode("1.AAECAw=","b64") - decode("1.AAECA?==","b64") - decode("256.00010203","hex") - self.assertEqual(decode("1.00010203","hex"), (1,b"\x00\x01\x02\x03")) - self.assertEqual(decode("2.AAAQEAY=","b32"), (2,b"\x00\x01\x02\x03")) - self.assertEqual(decode("3.AAECAw==","b64"), (3,b"\x00\x01\x02\x03")) - + decode("1.AAECAw=", "b64") + decode("1.AAECA?==", "b64") + decode("256.00010203", "hex") + self.assertEqual(decode("1.00010203", "hex"), (1, b"\x00\x01\x02\x03")) + self.assertEqual(decode("2.AAAQEAY=", "b32"), (2, b"\x00\x01\x02\x03")) + self.assertEqual(decode("3.AAECAw==", "b64"), (3, b"\x00\x01\x02\x03")) def testDetectEncoding(self): for shares in [ - ["1.00010f"], # bad case - ["1.000102030"], # bad char count - ["1.AAAQEAY"], # no padding - ["1.AAAQe==="], # bad case - ["1.AAECA?=="], # bad char - ["1.AAECAw="], # bad padding - ["1.000102","2.AAAQEAY="], # mixed encoding - ["1.000102","2.AAECAw=="], - ["1.AAECAw==","2.AAAQE==="], - [".00010203"], # no index - ["00010203"] # no index + ["1.00010f"], # bad case + ["1.000102030"], # bad char count + ["1.AAAQEAY"], # no padding + ["1.AAAQe==="], # bad case + ["1.AAECA?=="], # bad char + ["1.AAECAw="], # bad padding + ["1.000102", "2.AAAQEAY="], # mixed encoding + ["1.000102", "2.AAECAw=="], + ["1.AAECAw==", "2.AAAQE==="], + [".00010203"], # no index + ["00010203"] # no index ]: with self.subTest(shares=shares): with self.assertRaises(DetectionException): @@ -100,4 +99,4 @@ class TestShamira(TestCase): self.assertEqual(detectEncoding(["10.00010203"]), "hex") self.assertEqual(detectEncoding(["2.AAAQEAY="]), "b32") self.assertEqual(detectEncoding(["3.AAECAw=="]), "b64") - self.assertEqual(detectEncoding(["3.AAECAwQF","1.00010203"]), "b64") + self.assertEqual(detectEncoding(["3.AAECAwQF", "1.00010203"]), "b64")