diff --git a/src/shamira/fft.py b/src/shamira/fft.py
--- a/src/shamira/fft.py
+++ b/src/shamira/fft.py
@@ -1,17 +1,59 @@
 import math
 import cmath
+import itertools
+
+from .gf256 import gfmul, gfpow
+
+# Primitive n-th roots of unity in GF(256), keyed by n. Each key divides 255,
+# the order of GF(256)'s multiplicative group, so such a root exists.
+SQUARE_ROOTS = {3: 189, 5: 12, 15: 225, 17: 53, 51: 51, 85: 15, 255: 3}
+
+
+def precompute_x(n):
+    """Return the geometric sequence [1, w, w**2, ..., w**(n-1)] in GF(256).
+
+    w = SQUARE_ROOTS[n] is a primitive n-th root of unity (w**n == 1), so the
+    n returned values are distinct and serve as the evaluation points of a
+    length-n DFT. Raises ValueError if SQUARE_ROOTS has no entry for n.
+    """
+    if n not in SQUARE_ROOTS:
+        # An explicit check survives `python -O`, unlike an assert.
+        raise ValueError("unsupported DFT length: {}".format(n))
+    w = SQUARE_ROOTS[n]
+    return list(itertools.accumulate([1]+[w]*(n-1), gfmul))
+
+
+def complex_dft(p):
+    """Return the DFT of p over the complex numbers, straight from the definition.
+
+    y[k] = sum(p[n] * w**(k*n) for n in range(N)), with w = exp(-2*pi*i/N).
+    Takes O(N**2) operations; prime_fft uses it as its base case.
+    """
+    N = len(p)
+    w = cmath.exp(-2*math.pi*1j/N)  # primitive N-th root of unity
+    y = [0]*N
+    for k in range(N):
+        xk = w**k
+        for n in range(N):
+            y[k] += p[n] * xk**n
+    return y
 
 
 def dft(p):
-    """Quadratic formula from the definition."""
+    """Return the DFT of p over GF(256), straight from the definition.
+
+    y[k] is the XOR-sum (field addition) of gfmul(p[n], w**(k*n)) over all n,
+    with w = SQUARE_ROOTS[len(p)]; i.e. the polynomial whose coefficient of
+    x**n is p[n], evaluated at precompute_x(len(p))[k]. Takes O(N**2) gfmul
+    calls. Raises ValueError (via precompute_x) for unsupported lengths.
+    """
     N = len(p)
-    w = cmath.exp(-2*math.pi*1j/N)  # primitive N-th square root of 1
-    x = [0]*N
+    x = precompute_x(N)
     y = [0]*N
     for k in range(N):
-        x[k] = w**k
         for n in range(N):
-            y[k] += p[n] * x[k]**n
+            # x[k]**n == w**(k*n) == x[k*n % N], because w**N == 1.
+            y[k] ^= gfmul(p[n], x[k*n % N])
     return y
 
 
@@ -25,7 +67,7 @@ def compute_inverse(N1, N2):
 def prime_fft(p, divisors):
     """https://en.wikipedia.org/wiki/Prime-factor_FFT_algorithm"""
     if len(divisors) == 1:
-        return dft(p)
+        return complex_dft(p)
     N = len(p)
     N1 = divisors[0]
     N2 = N//N1
@@ -39,7 +81,7 @@ def prime_fft(p, divisors):
 
     for k2 in range(N2):  # compute cols
         p_ = [row[k2] for row in ys]
-        y_ = dft(p_)
+        y_ = complex_dft(p_)
         for (yi, row) in zip(y_, ys):  # update col
             row[k2] = yi
 
@@ -49,58 +91,3 @@ def prime_fft(p, divisors):
         for k2 in range(N2):
             res[(k1*N2*N2_inv+k2*N1*N1_inv) % N] = ys[k1][k2]
     return res
-
-
-# arr = [9,8,5,3,8,4,9,7,0,9,5,6,6,2,4]
-# print(dft(arr))
-# print()
-# print(prime_fft(arr,[3,5]))
-
-"""
-def fft(x):
-    N = len(x)
-    if N <= 1: return x
-    even = fft(x[0::2])
-    odd = fft(x[1::2])
-    T = [cmath.exp(-2j*math.pi*k/N)*odd[k] for k in range(N//2)]
-    return [even[k] + T[k] for k in range(N//2)] + \
-           [even[k] - T[k] for k in range(N//2)]
-
-
-def fft_mixed(x, r1, r2):
-    A = dft(x)
-    for k0 in range(0, r2):
-        pass
-
-
-def bit_reverse(x, width):
-    y = 0
-    for i in range(width):
-        y <<= 1
-        y |= x&1
-        x >>= 1
-    return y
-
-
-def bit_reverse_seq(x):
-    n = len(x)
-    width = round(math.log2(n))
-    return [x[bit_reverse(i, width)] for i in range(n)]
-
-
-def fft2(x):
-    y = bit_reverse_seq(x)
-    n = len(x)
-    omegas = [(-2*math.pi*1j/n)**i for i in range(0, n)]
-    b = 1
-    while b<n:
-        for j in range(0, n, 2*b):
-            for k in range(0, b):
-                alpha = omegas[n*k//(2*b)]
-                u = y[j+k]
-                t = alpha*y[j+k+b]
-                y[j+k] = u+t
-                y[j+k+b] = u-t
-        b *= 2
-    return y
-"""
diff --git a/src/shamira/gf256.py b/src/shamira/gf256.py
--- a/src/shamira/gf256.py
+++ b/src/shamira/gf256.py
@@ -38,6 +38,29 @@ def gfmul(a, b):
     return E[t]
 
 
+def gfpow(x, k):
+    """Return x**k in GF(256), computed by square-and-multiply.
+
+    k must be a non-negative int; gfpow(x, 0) == 1 for every x, including 0.
+    Needs O(log k) calls to gfmul instead of the k - 1 of naive repeated
+    multiplication.
+
+    Raises ValueError for negative k, which the loop below would otherwise
+    silently answer with 1.
+    """
+    if k < 0:
+        raise ValueError("negative exponent: {}".format(k))
+    i = 1  # single-bit mask selecting the bit of k being processed
+    res = 1
+    while i <= k:
+        if k&i:
+            res = gfmul(res, x)
+        x = gfmul(x, x)  # x is now the original x raised to the power 2*i
+        i <<= 1
+
+    return res
+
+
 def evaluate(coefs, x):
     """Evaluate polynomial's value at x.
 
diff --git a/src/shamira/tests/test_fft.py b/src/shamira/tests/test_fft.py
--- a/src/shamira/tests/test_fft.py
+++ b/src/shamira/tests/test_fft.py
@@ -1,19 +1,63 @@
 # GNU GPLv3, see LICENSE
 
+import random
+import functools
+import operator
 from unittest import TestCase
 
+from ..gf256 import evaluate
 from ..fft import *
 
 
+def batch_evaluate(coefs, xs):
+    """Evaluate the polynomial coefs at every point of xs with gf256.evaluate."""
+    return [evaluate(coefs, x) for x in xs]
+
+
 class TestFFT(TestCase):
-    def test_dft(self):
-        self.assertEqual(dft([0]), [0+0j])
-        self.assertEqual(dft([1]), [1+0j])
-        self.assertEqual(dft([2]), [2+0j])
-        all(self.assertAlmostEqual(a, b) for (a, b) in zip(dft([3, 1]), [4+0j, 2+0j]))
-        all(self.assertAlmostEqual(a, b) for (a, b) in zip(dft([3, 1, 4]), [8+0j, 0.5+2.59807621j, 0.5-2.59807621j]))
-        all(self.assertAlmostEqual(a, b) for (a, b) in zip(dft([3, 1, 4, 1]), [9+0j, -1+0j, 5+0j, -1+0j]))
-        all(self.assertAlmostEqual(a, b) for (a, b) in zip(
-            dft([3, 1, 4, 1, 5]),
-            [14+0j, 0.80901699+2.04087031j, -0.30901699+5.20431056j, -0.30901699-5.20431056j, 0.80901699-2.04087031j]
-        ))
+    def assertAllAlmostEqual(self, actual, expected, **kwargs):
+        """Assert that actual and expected have equal length and agree element-wise.
+
+        kwargs (e.g. places or delta) are forwarded to assertAlmostEqual. A plain
+        loop is used because all(self.assertAlmostEqual(...) for ...) stops after
+        the first pair (assertAlmostEqual returns None, which is falsy) and zip()
+        silently hides a length mismatch.
+        """
+        actual, expected = list(actual), list(expected)
+        self.assertEqual(len(actual), len(expected))
+        for (a, b) in zip(actual, expected):
+            self.assertAlmostEqual(a, b, **kwargs)
+
+    def test_complex_dft(self):
+        self.assertEqual(complex_dft([0]), [0+0j])
+        self.assertEqual(complex_dft([1]), [1+0j])
+        self.assertEqual(complex_dft([2]), [2+0j])
+        self.assertAllAlmostEqual(complex_dft([3, 1]), [4+0j, 2+0j])
+        self.assertAllAlmostEqual(complex_dft([3, 1, 4]), [8+0j, 0.5+2.59807621j, 0.5-2.59807621j])
+        self.assertAllAlmostEqual(complex_dft([3, 1, 4, 1]), [9+0j, -1+0j, 5+0j, -1+0j])
+        self.assertAllAlmostEqual(
+            complex_dft([3, 1, 4, 1, 5]),
+            [14+0j, 0.80901699+2.04087031j, -0.30901699+5.20431056j, -0.30901699-5.20431056j, 0.80901699-2.04087031j]
+        )
+
+    def test_complex_prime_fft(self):
+        random.seed(1918)
+        for divisors in [[3], [2, 3], [3, 5], [3, 5, 17], [2, 3, 5, 7, 11]]:
+            n = functools.reduce(operator.mul, divisors)
+            coefficients = [random.randint(-128, 127) for i in range(n)]
+            a = prime_fft(coefficients, divisors)
+            b = complex_dft(coefficients)
+            # complex_dft's floating-point error grows with n (up to 2310 here), so
+            # compare with an absolute tolerance looser than the default 7 places.
+            self.assertAllAlmostEqual(a, b, delta=1e-4)
+
+    def test_dft(self):
+        random.seed(1918)
+        x = {i: precompute_x(i) for i in [3, 5, 15, 17]}  # all sets of xs
+
+        for n in [3, 5, 15, 17]:
+            coefficients = [random.randint(0, 255) for i in range(n)]
+            self.assertEqual(
+                dft(coefficients),
+                batch_evaluate(coefficients[::-1], x[n])
+            )
diff --git a/src/shamira/tests/test_gf256.py b/src/shamira/tests/test_gf256.py
--- a/src/shamira/tests/test_gf256.py
+++ b/src/shamira/tests/test_gf256.py
@@ -22,6 +22,27 @@ class TestGF256(TestCase):
             for b in range(256):
                 self.assertEqual(_gfmul(a, b), gfmul(a, b))
 
+    def test_gfpow(self):
+        self.assertEqual(gfpow(0, 0), 1)
+
+        for i in range(1, 256):
+            self.assertEqual(gfpow(i, 0), 1)
+            self.assertEqual(gfpow(i, 1), i)
+            self.assertEqual(gfpow(0, i), 0)
+            self.assertEqual(gfpow(1, i), 1)
+            self.assertEqual(gfpow(i, 255), 1)
+            self.assertEqual(gfpow(i, 256), i)
+            self.assertEqual(gfpow(i, 2), gfmul(i, i))
+
+        random.seed(1918)
+        for i in range(256):
+            j = random.randint(2, 255)
+            k = random.randint(3, 255)
+            y = 1
+            for m in range(k):
+                y = gfmul(y, j)
+            self.assertEqual(gfpow(j, k), y)
+
     def test_evaluate(self):
         for x in range(256):
             (a0, a1, a2, a3) = (x, x>>1, x>>2, x>>3)