# HG changeset patch # User Laman # Date 2020-11-29 21:10:38 # Node ID 25c5d4c877c607964497908cd9aab3a9b516fdf4 # Parent d19e877af29d6bb5a45bc769067fdd12fdf372c9 dft pluggable into prime_fft diff --git a/src/shamira/fft.py b/src/shamira/fft.py --- a/src/shamira/fft.py +++ b/src/shamira/fft.py @@ -11,7 +11,7 @@ SQUARE_ROOTS = {3: 189, 5: 12, 15: 225, def precompute_x(n): """Return a geometric sequence [1, w, w**2, ..., w**(n-1)], where w**n==1. This can be done only for certain values of n.""" - assert n in SQUARE_ROOTS + assert n in SQUARE_ROOTS, n w = SQUARE_ROOTS[n] # primitive N-th square root of 1 return list(itertools.accumulate([1]+[w]*(n-1), gfmul)) @@ -46,10 +46,10 @@ def compute_inverse(N1, N2): raise ValueError("Failed to find an inverse to {0} mod {1}.".format(N1, N2)) -def prime_fft(p, divisors): +def prime_fft(p, divisors, basic_dft=dft): """https://en.wikipedia.org/wiki/Prime-factor_FFT_algorithm""" if len(divisors) == 1: - return complex_dft(p) + return basic_dft(p) N = len(p) N1 = divisors[0] N2 = N//N1 @@ -59,11 +59,11 @@ def prime_fft(p, divisors): ys = [] for n1 in range(N1): # compute rows p_ = [p[(n2*N1+n1*N2) % N] for n2 in range(N2)] - ys.append(prime_fft(p_, divisors[1:])) + ys.append(prime_fft(p_, divisors[1:], basic_dft)) for k2 in range(N2): # compute cols p_ = [row[k2] for row in ys] - y_ = complex_dft(p_) + y_ = basic_dft(p_) for (yi, row) in zip(y_, ys): # update col row[k2] = yi diff --git a/src/shamira/tests/test_fft.py b/src/shamira/tests/test_fft.py --- a/src/shamira/tests/test_fft.py +++ b/src/shamira/tests/test_fft.py @@ -31,11 +31,11 @@ class TestFFT(TestCase): for divisors in [[3], [2, 3], [3, 5], [3, 5, 17], [2, 3, 5, 7, 11]]: n = functools.reduce(operator.mul, divisors) coefficients = [random.randint(-128, 127) for i in range(n)] - a = prime_fft(coefficients, divisors) + a = prime_fft(coefficients, divisors, complex_dft) b = complex_dft(coefficients) all(self.assertAlmostEqual(ai, bi) for (ai, bi) in zip(a, b)) - def test_dft(self): + def test_finite_dft(self): random.seed(1918) x = {i: precompute_x(i) for i in [3, 5, 15, 17]} # all sets of xs @@ -45,3 +45,12 @@ class TestFFT(TestCase): dft(coefficients), batch_evaluate(coefficients[::-1], x[n]) ) + + def test_finite_prime_fft(self): + random.seed(1918) + for divisors in [[3], [3, 5], [3, 17], [5, 17], [3, 5, 17]]: + n = functools.reduce(operator.mul, divisors) + coefficients = [random.randint(0, 255) for i in range(n)] + a = prime_fft(coefficients, divisors) + b = dft(coefficients) + all(self.assertAlmostEqual(ai, bi) for (ai, bi) in zip(a, b))