Files @ 3c3a529119dd
Branch filter:

Location: Shamira/src/shamira/tests/test_condensed.py

Laman
updated performance.md
# GNU GPLv3, see LICENSE

import os
import random
from unittest import TestCase

from ..gf256 import _gfmul, evaluate
from .. import generate_raw, generate
from ..condensed import *


class TestCondensed(TestCase):
	"""Tests for the names star-imported from ``condensed``.

	The helpers under test (``gfmul``, ``get_constant_coef``, ``reconstruct_raw``,
	``reconstruct``, ``decode``, ``SException``) are presumably provided by
	``from ..condensed import *``; they are checked against ``gf256`` and the
	package-level ``generate`` / ``generate_raw``.
	"""

	# the real os.urandom, kept so tearDownClass can restore it
	_urandom = os.urandom

	@classmethod
	def setUpClass(cls):
		"""Seed ``random`` and replace ``os.urandom`` with a seeded stand-in."""
		random.seed(17)
		# bytes drawn from the seeded PRNG make os.urandom reproducible for the tests
		os.urandom = lambda n: bytes(random.randint(0, 255) for _ in range(n))

	@classmethod
	def tearDownClass(cls):
		"""Restore the original ``os.urandom`` saved at class creation."""
		os.urandom = cls._urandom

	def test_gfmul(self):
		"""``gfmul`` must agree with ``gf256._gfmul`` for every pair of byte values."""
		for a in range(256):
			for b in range(256):
				self.assertEqual(_gfmul(a, b), gfmul(a, b))

	def test_get_constant_coef(self):
		"""Check ``get_constant_coef`` on a fixed case and on random polynomials."""
		self.assertEqual(get_constant_coef((1, 1), (2, 2), (3, 3)), 0)

		random.seed(17)
		random_matches = 0
		for _ in range(10):
			k = random.randint(2, 255)

			# exact: k points for k coefficients
			computed, expected = self.check_coefs_match(k, k)
			self.assertEqual(computed, expected)

			# overdetermined: the slice bound 256 exceeds the 255 available points,
			# so every point is used
			computed, expected = self.check_coefs_match(k, 256)
			self.assertEqual(computed, expected)

			# underdetermined => random
			computed, expected = self.check_coefs_match(k, k-1)
			if computed == expected:
				random_matches += 1
		self.assertLess(random_matches, 2)  # with a chance (255/256)**10=0.96 there should be no match

	def check_coefs_match(self, k, m):
		"""Return ``(computed, expected)`` for a random polynomial with k coefficients.

		Draws k random byte coefficients, evaluates them with ``evaluate`` at
		x = 1..255, shuffles the resulting points and feeds the first m of them
		to ``get_constant_coef``. The second element of the returned pair is
		``coefs[-1]`` (presumably the constant term under ``evaluate``'s
		coefficient ordering).
		"""
		coefs = [random.randint(0, 255) for _ in range(k)]
		points = [(x, evaluate(coefs, x)) for x in range(1, 256)]
		random.shuffle(points)
		return (get_constant_coef(*points[:m]), coefs[-1])

	def test_generate_reconstruct_raw(self):
		"""k raw shares recover the secret; k-1 shares must not."""
		for (k, n) in [(2, 3), (254, 254)]:
			shares = generate_raw(b"abcd", k, n)
			random.shuffle(shares)
			self.assertEqual(reconstruct_raw(*shares[:k]), b"abcd")
			self.assertNotEqual(reconstruct_raw(*shares[:k-1]), b"abcd")

	def test_generate_reconstruct(self):
		"""Round-trip text secrets through ``generate`` and ``reconstruct``."""
		for secret in ["abcde", "ěščřžý"]:
			for (k, n) in [(2, 3), (254, 254)]:
				with self.subTest(sec=secret, k=k, n=n):
					shares = generate(secret, k, n)
					random.shuffle(shares)
					self.assertEqual(reconstruct(*shares[:k]), secret)
					# With k-1 shares reconstruct may either raise SException or
					# return some value; only a value equal to the secret is a failure.
					try:
						partial = reconstruct(*shares[:k-1])
					except SException:
						pass
					else:
						self.assertNotEqual(partial, secret)
		# a bytes secret starting with 0xfe is expected to make reconstruct raise
		shares = generate(b"\xfeaa", 2, 3)
		with self.assertRaises(SException):
			reconstruct(*shares)

	def test_decode(self):
		"""``decode`` rejects malformed shares and parses a well-formed one."""
		malformed = [
			"AAA",
			"1.",
			".AAA",
			"1AAA",
			"1.AAAQEAY",
			"1.AAAQEAy=",
			"256.AAAQEAY=",
		]
		# Each input needs its own assertRaises: a single shared context would
		# exit on the first exception and never execute the remaining calls.
		for share in malformed:
			with self.subTest(share=share):
				with self.assertRaises(SException):
					decode(share)
		self.assertEqual(decode("2.AAAQEAY="), (2, b"\x00\x01\x02\x03"))