diff --git a/src/shamira/fft.py b/src/shamira/fft.py --- a/src/shamira/fft.py +++ b/src/shamira/fft.py @@ -1,17 +1,41 @@ import math import cmath +import itertools + +from .gf256 import gfmul, gfpow + +# values of n-th square roots +SQUARE_ROOTS = {3: 189, 5: 12, 15: 225, 17: 53, 51: 51, 85: 15, 255: 3} + + +def precompute_x(n): + """Return a geometric sequence [1, w, w**2, ..., w**(n-1)], where w**n==1. + This can be done only for certain values of n.""" + assert n in SQUARE_ROOTS + w = SQUARE_ROOTS[n] # primitive N-th square root of 1 + return list(itertools.accumulate([1]+[w]*(n-1), gfmul)) + + +def complex_dft(p): + """Quadratic formula from the definition. The basic case in complex numbers.""" + N = len(p) + w = cmath.exp(-2*math.pi*1j/N) # primitive N-th square root of 1 + y = [0]*N + for k in range(N): + xk = w**k + for n in range(N): + y[k] += p[n] * xk**n + return y def dft(p): - """Quadratic formula from the definition.""" + """Quadratic formula from the definition. In GF256.""" N = len(p) - w = cmath.exp(-2*math.pi*1j/N) # primitive N-th square root of 1 - x = [0]*N + x = precompute_x(N) y = [0]*N for k in range(N): - x[k] = w**k for n in range(N): - y[k] += p[n] * x[k]**n + y[k] ^= gfmul(p[n], gfpow(x[k], n)) return y @@ -25,7 +49,7 @@ def compute_inverse(N1, N2): def prime_fft(p, divisors): """https://en.wikipedia.org/wiki/Prime-factor_FFT_algorithm""" if len(divisors) == 1: - return dft(p) + return complex_dft(p) N = len(p) N1 = divisors[0] N2 = N//N1 @@ -39,7 +63,7 @@ def prime_fft(p, divisors): for k2 in range(N2): # compute cols p_ = [row[k2] for row in ys] - y_ = dft(p_) + y_ = complex_dft(p_) for (yi, row) in zip(y_, ys): # update col row[k2] = yi @@ -49,58 +73,3 @@ def prime_fft(p, divisors): for k2 in range(N2): res[(k1*N2*N2_inv+k2*N1*N1_inv) % N] = ys[k1][k2] return res - - -# arr = [9,8,5,3,8,4,9,7,0,9,5,6,6,2,4] -# print(dft(arr)) -# print() -# print(prime_fft(arr,[3,5])) - -""" -def fft(x): - N = len(x) - if N <= 1: return x - even = fft(x[0::2]) - odd = fft(x[1::2]) - T = [cmath.exp(-2j*math.pi*k/N)*odd[k] for k in range(N//2)] - return [even[k] + T[k] for k in range(N//2)] + \ - [even[k] - T[k] for k in range(N//2)] - - -def fft_mixed(x, r1, r2): - A = dft(x) - for k0 in range(0, r2): - pass - - -def bit_reverse(x, width): - y = 0 - for i in range(width): - y <<= 1 - y |= x&1 - x >>= 1 - return y - - -def bit_reverse_seq(x): - n = len(x) - width = round(math.log2(n)) - return [x[bit_reverse(i, width)] for i in range(n)] - - -def fft2(x): - y = bit_reverse_seq(x) - n = len(x) - omegas = [(-2*math.pi*1j/n)**i for i in range(0, n)] - b = 1 - while b>1, x>>2, x>>3)