# GNU GPLv3, see LICENSE

import contextlib
import os
import random
from unittest import TestCase

from gf256 import _gfmul, evaluate
from shamira import generateRaw, generate
from condensed import *


class TestCondensed(TestCase):
    """Cross-check the `condensed` module against the gf256/shamira implementations."""

    # Original os.urandom, captured at class-definition time so it can be restored.
    _urandom = os.urandom

    @classmethod
    def setUpClass(cls):
        """Replace os.urandom with a seeded PRNG so the tests are deterministic.

        NOTE(review): presumably share generation draws randomness from
        os.urandom; this patch is process-wide until tearDownClass runs.
        """
        random.seed(17)
        os.urandom = lambda n: bytes(random.randint(0, 255) for _ in range(n))

    @classmethod
    def tearDownClass(cls):
        """Restore the real os.urandom."""
        os.urandom = cls._urandom

    def testGfmul(self):
        """gfmul must agree with gf256._gfmul on every pair of byte values."""
        for a in range(256):
            for b in range(256):
                self.assertEqual(_gfmul(a, b), gfmul(a, b))

    def testGetConstantCoef(self):
        """getConstantCoef recovers coefs[0] from >= k points, but not from k-1."""
        self.assertEqual(getConstantCoef((1, 1), (2, 2), (3, 3)), 0)

        random.seed(17)
        randomMatches = 0
        for _ in range(10):
            k = random.randint(2, 255)
            # exact: k points of a degree-(k-1) polynomial
            res = self.checkCoefsMatch(k, k)
            self.assertEqual(res[0], res[1])
            # overdetermined: all 255 available points (slicing caps m=256 at 255)
            res = self.checkCoefsMatch(k, 256)
            self.assertEqual(res[0], res[1])
            # underdetermined: the result is effectively random
            res = self.checkCoefsMatch(k, k - 1)
            if res[0] == res[1]:
                randomMatches += 1
        # Each underdetermined trial matches by chance with probability ~1/256.
        # P(no match in 10 trials) = (255/256)**10 ~= 0.96; allowing a single
        # accidental match keeps the false-failure probability below 0.1 %.
        self.assertLess(randomMatches, 2)

    def checkCoefsMatch(self, k, m):
        """Return (value from getConstantCoef on m random points, coefs[0]).

        Builds k random coefficients, evaluates them with gf256.evaluate at
        x = 1..255, shuffles the points and passes the first m to
        getConstantCoef.
        """
        coefs = [random.randint(0, 255) for _ in range(k)]
        points = [(j, evaluate(coefs, j)) for j in range(1, 256)]
        random.shuffle(points)
        return (getConstantCoef(*points[:m]), coefs[0])

    def testGenerateReconstructRaw(self):
        """k raw shares reconstruct the secret; k-1 shares do not."""
        for (k, n) in [(2, 3), (254, 254)]:
            shares = generateRaw(b"abcd", k, n)
            random.shuffle(shares)
            self.assertEqual(reconstructRaw(*shares[:k]), b"abcd")
            self.assertNotEqual(reconstructRaw(*shares[:k - 1]), b"abcd")

    def testGenerateReconstruct(self):
        """String round-trip through generate/reconstruct, including non-ASCII text."""
        for secret in ["abcde", "ěščřžý"]:
            for (k, n) in [(2, 3), (254, 254)]:
                with self.subTest(sec=secret, k=k, n=n):
                    shares = generate(secret, k, n)
                    random.shuffle(shares)
                    self.assertEqual(reconstruct(*shares[:k]), secret)
                    # With too few shares the recovered bytes are garbage and may
                    # not be valid UTF-8; a decode failure is an acceptable outcome.
                    with contextlib.suppress(UnicodeDecodeError):
                        self.assertNotEqual(reconstruct(*shares[:k - 1]), secret)

        # A byte secret that is not valid UTF-8 must surface a decode error.
        shares = generate(b"\xfeaa", 2, 3)
        with self.assertRaises(UnicodeDecodeError):
            reconstruct(*shares)

    def testDecode(self):
        """decode rejects malformed share strings and parses a well-formed one."""
        # Each case gets its own assertRaises: inside one shared `with` block the
        # first raise would exit the block and the remaining cases would never run.
        for encoded in [
            "AAA",
            "1.",
            ".AAA",
            "1AAA",
            "1.AAAQEAY",
            "1.AAAQEAy=",
            "256.AAAQEAY=",
        ]:
            with self.subTest(encoded=encoded):
                with self.assertRaises(ValueError):
                    decode(encoded)

        self.assertEqual(decode("2.AAAQEAY="), (2, b"\x00\x01\x02\x03"))