diff --git a/src/shamira/fft.py b/src/shamira/fft.py --- a/src/shamira/fft.py +++ b/src/shamira/fft.py @@ -11,7 +11,7 @@ SQUARE_ROOTS = {3: 189, 5: 12, 15: 225, def precompute_x(n): """Return a geometric sequence [1, w, w**2, ..., w**(n-1)], where w**n==1. This can be done only for certain values of n.""" - assert n in SQUARE_ROOTS + assert n in SQUARE_ROOTS, n w = SQUARE_ROOTS[n] # primitive N-th square root of 1 return list(itertools.accumulate([1]+[w]*(n-1), gfmul)) @@ -46,10 +46,10 @@ def compute_inverse(N1, N2): raise ValueError("Failed to find an inverse to {0} mod {1}.".format(N1, N2)) -def prime_fft(p, divisors): +def prime_fft(p, divisors, basic_dft=dft): """https://en.wikipedia.org/wiki/Prime-factor_FFT_algorithm""" if len(divisors) == 1: - return complex_dft(p) + return basic_dft(p) N = len(p) N1 = divisors[0] N2 = N//N1 @@ -59,11 +59,11 @@ def prime_fft(p, divisors): ys = [] for n1 in range(N1): # compute rows p_ = [p[(n2*N1+n1*N2) % N] for n2 in range(N2)] - ys.append(prime_fft(p_, divisors[1:])) + ys.append(prime_fft(p_, divisors[1:], basic_dft)) for k2 in range(N2): # compute cols p_ = [row[k2] for row in ys] - y_ = complex_dft(p_) + y_ = basic_dft(p_) for (yi, row) in zip(y_, ys): # update col row[k2] = yi