Files
@ b52e197db5a8
Branch filter:
Location: Shamira/src/tests/test_shamira.py
b52e197db5a8
3.4 KiB
text/x-python
better exception handling
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 | # GNU GPLv3, see LICENSE
import os
import random
from unittest import TestCase

from shamira import _shareByte
from shamira import *
class TestShamira(TestCase):
    """Tests for the shamira secret-sharing API.

    For the whole class, os.urandom is replaced by a byte generator driven by
    a seeded ``random`` so that generated shares are reproducible; the
    original function is restored in tearDownClass.
    """

    # Original os.urandom, captured at class-definition time so tearDownClass
    # can restore it.
    _urandom = os.urandom

    @classmethod
    def setUpClass(cls):
        # Seeded once for the whole class, so results depend on test order.
        # NOTE(review): assumes the library draws its randomness through
        # os.urandom at call time — confirm in shamira.
        random.seed(17)
        os.urandom = lambda n: bytes(random.randint(0, 255) for _ in range(n))

    @classmethod
    def tearDownClass(cls):
        os.urandom = cls._urandom

    def test_shareByte(self):
        """_shareByte validates k/n, and k of its points recover the byte."""
        with self.assertRaises(InvalidParams):  # too few shares
            _shareByte(b"a", 5, 4)
        with self.assertRaises(InvalidParams):  # too many shares
            _shareByte(b"a", 5, 255)
        with self.assertRaises(ValueError):  # not castable to int
            _shareByte("x", 2, 3)

        vals = _shareByte(ord(b"a"), 2, 3)
        # Share values are paired with x-coordinates 1, 2, 3, ...
        points = list(zip(range(1, 256), vals))
        self.assertEqual(gf256.getConstantCoef(*points), ord(b"a"))
        self.assertEqual(gf256.getConstantCoef(*points[:2]), ord(b"a"))
        # A single point is underdetermined for k=2, so the result is random.
        self.assertNotEqual(gf256.getConstantCoef(*points[:1]), ord(b"a"))

    def testGenerateReconstructRaw(self):
        """k raw shares reconstruct the secret; k-1 shares do not."""
        for (k, n) in [(2, 3), (254, 254)]:
            with self.subTest(k=k, n=n):
                shares = generateRaw(b"abcd", k, n)
                random.shuffle(shares)
                self.assertEqual(reconstructRaw(*shares[:k]), b"abcd")
                self.assertNotEqual(reconstructRaw(*shares[:k - 1]), b"abcd")

    def testGenerateReconstruct(self):
        """Encoded shares round-trip for every encoding and secret type."""
        for encoding in ["hex", "b32", "b64"]:
            for secret in [b"abcd", "abcde", "ěščřžý"]:
                # bytes secrets are reconstructed raw; str secrets are
                # compared as text, or as their UTF-8 bytes in raw mode.
                raw = isinstance(secret, bytes)
                expectedBytes = secret if raw else secret.encode("utf-8")
                for (k, n) in [(2, 3), (254, 254)]:
                    with self.subTest(enc=encoding, r=raw, sec=secret, k=k, n=n):
                        shares = generate(secret, k, n, encoding)
                        random.shuffle(shares)
                        self.assertEqual(
                            reconstruct(*shares[:k], encoding=encoding, raw=raw),
                            secret)
                        # No encoding given: presumably auto-detected.
                        self.assertEqual(reconstruct(*shares[:k], raw=raw), secret)
                        # k-1 shares are underdetermined.
                        self.assertNotEqual(
                            reconstruct(*shares[:k - 1], encoding=encoding, raw=True),
                            expectedBytes)

        # The byte 0xfe never occurs in valid UTF-8, so reconstructing this
        # secret as text must fail.
        shares = generate(b"\xfeaa", 2, 3)
        with self.assertRaises(DecodingException):
            reconstruct(*shares)

    def testEncode(self):
        """encode renders a share as '<index>.<payload>'."""
        share = (2, b"\x00\x01\x02")
        for (encoding, encodedStr) in [("hex", '000102'), ("b32", 'AAAQE==='), ("b64", 'AAEC')]:
            with self.subTest(enc=encoding):
                self.assertEqual(encode(share, encoding), "2." + encodedStr)

    def testDecode(self):
        """decode rejects malformed shares and parses well-formed ones."""
        # Each malformed input needs its own assertRaises: a single `with`
        # block exits at the first raise, so later calls would never run.
        malformed = [
            ("AAA",),                 # no index separator
            ("1.",),                  # empty payload
            (".AAA",),                # no index
            ("1AAA",),                # no separator
            ("1.0001020f", "hex"),    # bad case
            ("1.000102030", "hex"),   # odd char count
            ("1.AAAQEAY",),           # missing padding
            ("1.AAAQEAy=",),          # bad case
            ("1.AAECAw=", "b64"),     # bad padding
            ("1.AAECA?==", "b64"),    # bad char
            ("256.00010203", "hex"),  # index 256, presumably out of range
        ]
        for args in malformed:
            with self.subTest(args=args):
                with self.assertRaises(MalformedShare):
                    decode(*args)

        self.assertEqual(decode("1.00010203", "hex"), (1, b"\x00\x01\x02\x03"))
        self.assertEqual(decode("2.AAAQEAY=", "b32"), (2, b"\x00\x01\x02\x03"))
        self.assertEqual(decode("3.AAECAw==", "b64"), (3, b"\x00\x01\x02\x03"))

    def testDetectEncoding(self):
        """detectEncoding rejects ambiguous/invalid share sets."""
        for shares in [
            ["1.00010f"],  # bad case
            ["1.000102030"],  # bad char count
            ["1.AAAQEAY"],  # no padding
            ["1.AAAQe==="],  # bad case
            ["1.AAECA?=="],  # bad char
            ["1.AAECAw="],  # bad padding
            ["1.000102", "2.AAAQEAY="],  # mixed encoding
            ["1.000102", "2.AAECAw=="],
            ["1.AAECAw==", "2.AAAQE==="],
            [".00010203"],  # no index
            ["00010203"],  # no index
        ]:
            with self.subTest(shares=shares):
                with self.assertRaises(DetectionException):
                    detectEncoding(shares)

        self.assertEqual(detectEncoding(["10.00010203"]), "hex")
        self.assertEqual(detectEncoding(["2.AAAQEAY="]), "b32")
        self.assertEqual(detectEncoding(["3.AAECAw=="]), "b64")
        # Payloads valid as both hex and b64 are resolved to b64 when a
        # b64-only share is present.
        self.assertEqual(detectEncoding(["3.AAECAwQF", "1.00010203"]), "b64")
|