# GNU GPLv3, see LICENSE
import os
import random
from unittest import TestCase
from gf256 import _gfmul, evaluate
from shamira import generateRaw, generate
from condensed import *
class TestCondensed(TestCase):
    """Tests for the names star-imported from `condensed`.

    Results are cross-checked against `gf256._gfmul` / `gf256.evaluate` and
    against shares produced by `shamira.generateRaw` / `shamira.generate`.
    """

    # os.urandom as it was before patching; restored in tearDownClass.
    _urandom = os.urandom

    @classmethod
    def setUpClass(cls):
        """Make os.urandom deterministic by routing it through a seeded `random`."""
        super().setUpClass()
        random.seed(17)
        # Re-capture at patch time so tearDownClass restores whatever
        # os.urandom was when the patch was applied.
        cls._urandom = os.urandom
        os.urandom = lambda n: bytes(random.randint(0, 255) for _ in range(n))

    @classmethod
    def tearDownClass(cls):
        """Restore the os.urandom captured in setUpClass."""
        os.urandom = cls._urandom
        super().tearDownClass()

    def testGfmul(self):
        """gfmul (from condensed) must agree with gf256._gfmul on every byte pair."""
        for a in range(256):
            for b in range(256):
                self.assertEqual(_gfmul(a, b), gfmul(a, b), msg=f"a={a}, b={b}")

    def testGetConstantCoef(self):
        """getConstantCoef recovers a polynomial's constant term from its points."""
        # The points lie on y = x, so the constant term is 0.
        self.assertEqual(getConstantCoef((1, 1), (2, 2), (3, 3)), 0)
        random.seed(17)
        randomMatches = 0
        for _ in range(10):
            k = random.randint(2, 255)
            # exact: as many points as coefficients
            res = self.checkCoefsMatch(k, k)
            self.assertEqual(res[0], res[1])
            # overdetermined: all 255 available points (exact when k == 255)
            res = self.checkCoefsMatch(k, 256)
            self.assertEqual(res[0], res[1])
            # underdetermined => the result is effectively random
            res = self.checkCoefsMatch(k, k-1)
            if res[0] == res[1]:
                randomMatches += 1
        # A random byte matches with probability 1/256, so with probability
        # (255/256)**10 ~= 0.96 there is no match at all; tolerate one accident.
        self.assertLess(randomMatches, 2)

    def checkCoefsMatch(self, k, m):
        """Interpolate a random k-coefficient polynomial from m of its points.

        Points are sampled at x = 1..255 in shuffled order, so any m > 255
        simply uses all of them.

        Returns:
            (constant coefficient from getConstantCoef, true coefs[0]).
        """
        coefs = [random.randint(0, 255) for _ in range(k)]
        points = [(x, evaluate(coefs, x)) for x in range(1, 256)]
        random.shuffle(points)
        return (getConstantCoef(*points[:m]), coefs[0])

    def testGenerateReconstructRaw(self):
        """Any k raw shares rebuild the secret; k-1 shares should not."""
        for (k, n) in [(2, 3), (254, 254)]:
            with self.subTest(k=k, n=n):
                shares = generateRaw(b"abcd", k, n)
                random.shuffle(shares)
                self.assertEqual(reconstructRaw(*shares[:k]), b"abcd")
                self.assertNotEqual(reconstructRaw(*shares[:k-1]), b"abcd")

    def testGenerateReconstruct(self):
        """Encoded shares round-trip text secrets; too few shares must not."""
        for secret in ["abcde", "ěščřžý"]:
            for (k, n) in [(2, 3), (254, 254)]:
                with self.subTest(sec=secret, k=k, n=n):
                    shares = generate(secret, k, n)
                    random.shuffle(shares)
                    self.assertEqual(reconstruct(*shares[:k]), secret)
                    try:
                        partial = reconstruct(*shares[:k-1])
                    except SException:
                        pass  # rejecting too few shares outright is also fine
                    else:
                        self.assertNotEqual(partial, secret)

        # reconstruct returns text (compared against str secrets above), and
        # b"\xfe" is never valid UTF-8, so this secret must be rejected.
        shares = generate(b"\xfeaa", 2, 3)
        with self.assertRaises(SException):
            reconstruct(*shares)

    def testDecode(self):
        """decode rejects each malformed share string and parses a valid one."""
        malformed = [
            "AAA",
            "1.",
            ".AAA",
            "1AAA",
            "1.AAAQEAY",
            "1.AAAQEAy=",
            "256.AAAQEAY=",
        ]
        for share in malformed:
            # One assertRaises per input: inside a single context manager only
            # the first call would run, because it raises and leaves the block.
            with self.subTest(share=share), self.assertRaises(SException):
                decode(share)
        self.assertEqual(decode("2.AAAQEAY="), (2, b"\x00\x01\x02\x03"))