# HG changeset patch # User Laman # Date 2017-10-01 18:45:34 # Node ID b52e197db5a8ae1e1368fd1b4509e9c26ed73952 # Parent 735b9c2a61e94b93c67d17490920a7ee00667740 better exception handling diff --git a/src/cli.py b/src/cli.py --- a/src/cli.py +++ b/src/cli.py @@ -2,7 +2,7 @@ from argparse import ArgumentParser -from shamira import generate, reconstruct +from shamira import generate, reconstruct, SException def run(): @@ -49,16 +49,16 @@ def _generate(args): shares=generate(args.secret,args.k,args.n,encoding) for s in shares: print(s) - except ValueError as e: - print("operation failed: ",e) + except SException as e: + print(e) def _reconstruct(args): encoding=getEncoding(args) try: print(reconstruct(*args.share,encoding=encoding,raw=args.raw)) - except ValueError as e: - print("operation failed: ",e) + except SException as e: + print(e) def getEncoding(args): diff --git a/src/condensed.py b/src/condensed.py --- a/src/condensed.py +++ b/src/condensed.py @@ -52,6 +52,9 @@ import base64 import binascii +class SException(Exception): pass + + def reconstructRaw(*shares): """Tries to recover the secret from its shares. @@ -72,8 +75,10 @@ def reconstruct(*shares): :return: (str) reconstructed secret. Too few shares returns garbage.""" bs=reconstructRaw(*(decode(s) for s in shares)) - return bs.decode(encoding="utf-8") - + try: + return bs.decode(encoding="utf-8") + except UnicodeDecodeError: + raise SException('Failed to decode bytes to utf-8. Either you supplied invalid shares, or you missed the "raw" flag. Offending value: "{0}"'.format(bs)) def decode(share): @@ -81,12 +86,12 @@ def decode(share): (i,_,shareStr)=share.partition(".") i=int(i) if not 1<=i<=255: - raise ValueError() + raise SException("Malformed share: Failed 1<=k<=255, k={0}".format(i)) shareBytes=base64.b32decode(shareStr) return (i,shareBytes) except (ValueError,binascii.Error): - raise ValueError('bad share format: share="{0}"'.format(share)) + raise SException('Malformed share: share="{0}"'.format(share)) ### @@ -109,7 +114,7 @@ def run(): def _reconstruct(args): try: print(reconstruct(*args.share)) - except ValueError as e: print("operation failed: ",e) + except SException as e: print(e) if __name__=="__main__": diff --git a/src/shamira.py b/src/shamira.py --- a/src/shamira.py +++ b/src/shamira.py @@ -8,9 +8,16 @@ import binascii import gf256 +class SException(Exception): pass +class InvalidParams(SException): pass +class DetectionException(SException): pass +class DecodingException(SException): pass +class MalformedShare(SException): pass + + def _shareByte(secretB,k,n): if not k<=n<255: - raise ValueError("failing k<=n<255, k={0}, n={1}".format(k,n)) + raise InvalidParams("Failed k<=n<255, k={0}, n={1}".format(k,n)) coefs=[int(secretB)]+[int(b) for b in os.urandom(k-1)] points=[gf256.evaluate(coefs,i) for i in range(1,n+1)] return points @@ -65,7 +72,10 @@ def reconstruct(*shares,encoding="",raw= encoding=detectEncoding(shares) bs=reconstructRaw(*(decode(s,encoding) for s in shares)) - return bs if raw else bs.decode(encoding="utf-8") + try: + return bs if raw else bs.decode(encoding="utf-8") + except UnicodeDecodeError: + raise DecodingException('Failed to decode bytes to utf-8. Either you supplied invalid shares, or you missed the "raw" flag. Offending value: {0}'.format(bs)) def encode(share,encoding="b32"): @@ -81,14 +91,14 @@ def decode(share,encoding="b32"): (i,_,shareStr)=share.partition(".") i=int(i) if not 1<=i<=255: - raise ValueError() + raise MalformedShare("Malformed share: Failed 1<=k<=255, k={0}".format(i)) if encoding=="hex": f=base64.b16decode elif encoding=="b32": f=base64.b32decode else: f=base64.b64decode shareBytes=f(shareStr) return (i,shareBytes) except (ValueError,binascii.Error): - raise ValueError('bad share format: share="{0}", encoding="{1}"'.format(share,encoding)) + raise MalformedShare('Malformed share: share="{0}", encoding="{1}"'.format(share,encoding)) def detectEncoding(shares): @@ -100,7 +110,7 @@ def detectEncoding(shares): for (regexp, res) in classes: if all(regexp.fullmatch(share) for share in shares): return res - raise ValueError("no expected encoding detected") + raise DetectionException("No expected encoding detected") if __name__=="__main__": diff --git a/src/tests/test_condensed.py b/src/tests/test_condensed.py --- a/src/tests/test_condensed.py +++ b/src/tests/test_condensed.py @@ -70,14 +70,14 @@ class TestCondensed(TestCase): self.assertEqual(reconstruct(*shares[:k]), secret) try: self.assertNotEqual(reconstruct(*shares[:k-1]), secret) - except UnicodeDecodeError: + except SException: pass shares=generate(b"\xfeaa",2,3) - with self.assertRaises(UnicodeDecodeError): + with self.assertRaises(SException): reconstruct(*shares) def testDecode(self): - with self.assertRaises(ValueError): + with self.assertRaises(SException): decode("AAA") decode("1.") decode(".AAA") diff --git a/src/tests/test_shamira.py b/src/tests/test_shamira.py --- a/src/tests/test_shamira.py +++ b/src/tests/test_shamira.py @@ -20,9 +20,9 @@ class TestShamira(TestCase): os.urandom=cls._urandom def test_shareByte(self): - with self.assertRaises(ValueError): # too few shares + with self.assertRaises(InvalidParams): # too few shares _shareByte(b"a",5,4) - with self.assertRaises(ValueError): # too many shares + with self.assertRaises(InvalidParams): # too many shares _shareByte(b"a",5,255) with self.assertRaises(ValueError): # not castable to int _shareByte("x",2,3) @@ -53,7 +53,7 @@ class TestShamira(TestCase): s=secret if raw else secret.encode("utf-8") self.assertNotEqual(reconstruct(*shares[:k-1],encoding=encoding,raw=True), s) shares=generate(b"\xfeaa",2,3) - with self.assertRaises(UnicodeDecodeError): + with self.assertRaises(DecodingException): reconstruct(*shares) def testEncode(self): @@ -63,7 +63,7 @@ class TestShamira(TestCase): self.assertEqual(encode(share,encoding), "2."+encodedStr) def testDecode(self): - with self.assertRaises(ValueError): + with self.assertRaises(MalformedShare): decode("AAA") decode("1.") decode(".AAA") @@ -95,7 +95,7 @@ class TestShamira(TestCase): ["00010203"] # no index ]: with self.subTest(shares=shares): - with self.assertRaises(ValueError): + with self.assertRaises(DetectionException): detectEncoding(shares) self.assertEqual(detectEncoding(["10.00010203"]), "hex") self.assertEqual(detectEncoding(["2.AAAQEAY="]), "b32")