// Implementation of raders's fft for prime sized ffts use std::f32::consts::PI; use crate::{ complex::Complex32, fft::{ DFTAlgorithm, FFTDirection, create_fft, dft::NaiveDFT, is_prime, mixed_radix::MixedRadixFFT, }, }; pub struct RaderFFT { permutations: Box<[usize]>, convolution_operand: Box<[Complex32]>, convolution_ifft: Box, convolution_fft: Box, output: Box<[Complex32]>, size: usize, } impl DFTAlgorithm for RaderFFT { fn create(size: usize, direction: FFTDirection) -> Self where Self: Sized, { assert!(is_prime(size)); // Primitive root and its powers let g = compute_prime_primitive_root(size); let permutations: Box<[usize]> = (0..(size - 1)).map(|i| exp_mod(g, i + 1, size)).collect(); // Compute fourrier transform of twiddle factors let twiddle_factors = (0..(size - 1)) .map(|i| { Complex32::cexp( -2. * PI * direction.sign() * (permutations[i] as f32) / (size as f32), ) }) .collect::>(); let mut convolution_fft = create_fft(size - 1, FFTDirection::Forward); convolution_fft.execute(&twiddle_factors); RaderFFT { permutations, convolution_operand: convolution_fft.get_output().iter().copied().collect(), //convolution_fft, convolution_fft, convolution_ifft: create_fft(size - 1, FFTDirection::Inverse), output: vec![Complex32::zero(); size].into(), size, } } fn execute(&mut self, input: &[Complex32]) { // Compute fft of input signal for i in 0..(self.size - 1) { let k = self.permutations[self.size - 1 - i - 1]; // Using output as staging buffer self.output[i] = input[k]; } self.convolution_fft.execute(&self.output); // Compute convolution by multiplying in freq domain for i in 0..(self.size - 1) { // Using output as staging buffer self.output[i] = self.convolution_fft.get_output()[i] * self.convolution_operand[i]; } self.convolution_ifft.execute(&self.output); self.output[0] = Complex32::zero(); for x in input { self.output[0] = self.output[0] + *x; } for i in 0..(self.size - 1) { // Actually compute the output let k = self.permutations[i]; self.output[k] = (self.convolution_ifft.get_output()[i] / (self.size - 1) as f32) + input[0]; } } fn get_output(&self) -> &[Complex32] { &self.output } } pub fn compute_prime_primitive_root(n: usize) -> usize { assert!(is_prime(n)); let phi = n - 1; // Euler's totient for n prime // Test all candidates for i in 1..(n + 1) { // Find multiplicative order of i let mut val = i; let mut order = 1; for _ in 0..n { if val == 1 { break; } val = (val * i) % n; order += 1; } if order == phi { return i; } } unreachable!("Prime must have primitive root"); } pub fn exp_mod(mut n: usize, mut exp: usize, m: usize) -> usize { if m == 1 { return 0; } n %= m; let mut r = 1; while exp > 0 { if exp % 2 == 1 { r = (r * n) % m; } n = (n * n) % m; exp >>= 1; } r }