diff --git a/out.png b/out.png index c20cb65..0c9e62c 100644 Binary files a/out.png and b/out.png differ diff --git a/src/complex.rs b/src/complex.rs index b43f157..05b892c 100644 --- a/src/complex.rs +++ b/src/complex.rs @@ -119,7 +119,8 @@ impl Complex32 { } else if self.re == 0. { if self.im >= 0. { PI / 2.0 } else { -PI / 2.0 } } else { - (self.im / self.re).atan() + //(self.im / self.re).atan() + self.im.atan2(self.re) } } } diff --git a/src/fft.rs b/src/fft.rs index cf4a9d9..d25dbc5 100644 --- a/src/fft.rs +++ b/src/fft.rs @@ -1,13 +1,18 @@ pub mod dft; pub mod mixed_radix; pub mod rader; +pub mod rader2; pub mod radix2; +pub mod windows; use std::iter::Map; use crate::{ complex::Complex32, - fft::{dft::NaiveDFT, mixed_radix::MixedRadixFFT, rader::RaderFFT, radix2::Radix2FFT}, + fft::{ + dft::NaiveDFT, mixed_radix::MixedRadixFFT, rader::RaderFFT, rader2::Rader2FFT, + radix2::Radix2FFT, + }, }; pub trait DFT { @@ -18,22 +23,30 @@ pub trait DFT { fn get_input(&mut self) -> &mut [Complex32]; fn get_output(&self) -> &[Complex32]; - fn execute(&mut self); + fn execute(&mut self, window: fn(f32) -> f32); +} + +pub trait DFTWindow { + fn eval(t: f32) -> f32; } pub fn create_fft(size: usize) -> Box { - if size == 1 { + if size == 1 || size < 16 { + println!("Naive {size}"); return Box::new(NaiveDFT::create(size)); } if size.count_ones() == 1 { // TODO: Return hardcoded fft for small sized + println!("Radix 2 {size}"); return Box::new(Radix2FFT::create(size)); } if is_prime(size) { - return Box::new(NaiveDFT::create(size)); + println!("Prime rader {size}"); + return Box::new(RaderFFT::create(size)); } + println!("Mixed radix {size}"); Box::new(MixedRadixFFT::create(size)) } diff --git a/src/fft/dft.rs b/src/fft/dft.rs index f50ea89..df28ebd 100644 --- a/src/fft/dft.rs +++ b/src/fft/dft.rs @@ -10,23 +10,41 @@ pub struct NaiveDFT { impl DFT for NaiveDFT { fn create(size: usize) -> Self - where Self: Sized + where + Self: Sized, { - NaiveDFT - { + NaiveDFT { output_buffer: vec![Complex32::zero(); size].into_boxed_slice(), input_buffer: vec![Complex32::zero(); size].into_boxed_slice(), - size + size, } } - fn execute(&mut self) - { - self.output_buffer.iter_mut().enumerate().for_each(|(freq, out)| - *out = self.input_buffer.iter().enumerate().fold(Complex32::zero(), |acc, (i, s)| - acc + (*s * Complex32::cexp(- 2. * PI * (freq as f32 * i as f32 / self.size as f32))) - ) - ) + fn execute(&mut self, window: fn(f32) -> f32) { + for (freq, out) in self.output_buffer.iter_mut().enumerate() { + *out = Complex32::zero(); + for (i, inp) in self.input_buffer.iter().enumerate() { + *out = *out + + ((*inp * Complex32::cexp(-2. * PI * (i * freq) as f32 / self.size as f32)) + * window(i as f32 / self.size as f32)); + } + } + + /* + self.output_buffer + .iter_mut() + .enumerate() + .for_each(|(freq, out)| { + *out = self + .input_buffer + .iter() + .enumerate() + .map(|(i, s)| { + (*s) * Complex32::cexp(-2. * PI * (i * freq) as f32 / self.size as f32) + }) + .sum() + }) + */ } fn get_input(&mut self) -> &mut [Complex32] { diff --git a/src/fft/mixed_radix.rs b/src/fft/mixed_radix.rs index d4643ac..ba97098 100644 --- a/src/fft/mixed_radix.rs +++ b/src/fft/mixed_radix.rs @@ -4,7 +4,7 @@ use std::f32::consts::PI; use crate::{ complex::Complex32, - fft::{DFT, create_fft, dft::NaiveDFT, prime_factors}, + fft::{DFT, create_fft, dft::NaiveDFT, prime_factors, windows}, }; pub struct MixedRadixFFT { @@ -26,13 +26,14 @@ impl DFT for MixedRadixFFT { fn create(size: usize) -> Self { let q = decide_radix_factor(size); let p = size / q; + println!("{} {}", p, q); // TODO: Figure out why it does not work in the other direction ... let (p, q) = (q, p); - let qfft = create_fft(q); - let pfft = create_fft(p); + //let qfft = create_fft(q); + //let pfft = create_fft(p); - //let qfft = Box::new(NaiveDFT::create(q)); - //let pfft = Box::new(NaiveDFT::create(p)); + let qfft = Box::new(NaiveDFT::create(q)); + let pfft = Box::new(NaiveDFT::create(p)); MixedRadixFFT { input_buffer: vec![Complex32::zero(); size].into_boxed_slice(), @@ -48,15 +49,17 @@ impl DFT for MixedRadixFFT { } } - fn execute(&mut self) { + fn execute(&mut self, window: fn(f32) -> f32) { // Perform p ffts of size q for k0 in 0..self.p { // Copy samples into input buffer for k1 in 0..self.q { - self.qfft.get_input()[k1] = self.input_buffer[k1 * self.p + k0]; + let k = k1 * self.p + k0; + self.qfft.get_input()[k1] = + self.input_buffer[k] * window(k as f32 / self.size as f32); } - self.qfft.execute(); + self.qfft.execute(windows::rectanguar); for j0 in 0..self.q { // "Store j0'th of k0'th fft into staging buffer" @@ -72,7 +75,7 @@ impl DFT for MixedRadixFFT { self.pfft.get_input()[k0] = self.staging_buffer[j0 * self.p + k0]; } - self.pfft.execute(); + self.pfft.execute(windows::rectanguar); for j1 in 0..self.p { self.output_buffer[j1 * self.q + j0] = self.pfft.get_output()[j1]; @@ -93,7 +96,7 @@ fn compute_twiddle_factors(size: usize) -> Box<[Complex32]> { let mut factors = vec![Complex32::zero(); size].into_boxed_slice(); for i in 0..size { - factors[i] = Complex32::cexp(2. * PI * i as f32 / (size as f32)); + factors[i] = Complex32::cexp(-2. * PI * i as f32 / (size as f32)); } factors } diff --git a/src/fft/rader.rs b/src/fft/rader.rs index 3f0f6c8..111fdb0 100644 --- a/src/fft/rader.rs +++ b/src/fft/rader.rs @@ -5,19 +5,18 @@ use std::{f32::consts::PI, ops::Deref}; use super::mixed_radix; use crate::{ complex::Complex32, - fft::{DFT, create_fft, dft::NaiveDFT, is_prime}, + fft::{DFT, create_fft, dft::NaiveDFT, is_prime, windows}, }; pub struct RaderFFT { input_buffer: Box<[Complex32]>, output_buffer: Box<[Complex32]>, - size: usize, + permutations: Box<[usize]>, + convolution_op: Box<[Complex32]>, + conv_fft: Box, - // Fourrier transform of the exponential terms - convolution_operand: Box<[Complex32]>, - convolution_fft: Box, // TODO: Use fft - permutation: Box<[usize]>, + size: usize, } impl DFT for RaderFFT { @@ -25,46 +24,57 @@ impl DFT for RaderFFT { where Self: Sized, { + assert!(is_prime(size)); let g = compute_prime_primitive_root(size); - let permutation: Box<[usize]> = (0..(size - 1)).map(|i| exp_mod(g, i + 1, size)).collect(); + let permutations: Box<[usize]> = (0..(size - 1)).map(|i| exp_mod(g, i, size)).collect(); + + let mut conv_fft = Box::new(NaiveDFT::create(size - 1)); + conv_fft + .get_input() + .iter_mut() + .enumerate() + .for_each(|(i, x)| { + *x = Complex32::cexp(-2. * PI * (permutations[i] as f32) / (size as f32)) + }); + conv_fft.execute(windows::rectanguar); + RaderFFT { - input_buffer: vec![Complex32::zero(); size].into_boxed_slice(), - output_buffer: vec![Complex32::zero(); size].into_boxed_slice(), + input_buffer: vec![Complex32::zero(); size].into(), + output_buffer: vec![Complex32::zero(); size].into(), + + permutations, + convolution_op: conv_fft.get_output().iter().copied().collect(), + conv_fft, size, - - convolution_operand: compute_convolution_operand(size, &permutation), - convolution_fft: create_fft(size - 1), - permutation, } } - fn execute(&mut self) { + fn execute(&mut self, window: fn(f32) -> f32) { + // Compute fft of input signal for i in 0..(self.size - 1) { - self.convolution_fft.get_input()[i] = - self.input_buffer[self.permutation[self.size - 1 - i - 1]] + let k = self.permutations[i]; + self.conv_fft.get_input()[i] = self.input_buffer[k]; } - self.convolution_fft.execute(); + self.conv_fft.execute(windows::rectanguar); - // Use output buffer as staging buffer for i in 0..(self.size - 1) { self.output_buffer[i] = - self.convolution_fft.get_output()[i] * self.convolution_operand[i]; + self.conv_fft.get_output()[self.size - 1 - i - 1] * self.convolution_op[i]; } for i in 0..(self.size - 1) { - self.convolution_fft.get_input()[i] = self.output_buffer[self.size - 1 - i - 1]; + //self.conv_fft.get_input()[i] = self.output_buffer[self.size - 1 - i - 1]; + self.conv_fft.get_input()[i] = self.output_buffer[i]; } - self.convolution_fft.get_input()[0] = - self.convolution_fft.get_input()[0] + self.input_buffer[0]; - // Compute ifft to obtain convolution - self.convolution_fft.execute(); + self.conv_fft.execute(windows::rectanguar); for i in 0..(self.size - 1) { - self.output_buffer[self.permutation[i]] = - self.convolution_fft.get_output()[i] / (self.size - 1) as f32; + let k = self.permutations[i]; + self.output_buffer[k] = + (self.conv_fft.get_output()[i] / (self.size - 1) as f32) + self.input_buffer[0]; } self.output_buffer[0] = self.input_buffer.iter().copied().sum(); @@ -79,18 +89,6 @@ impl DFT for RaderFFT { } } -pub fn compute_convolution_operand(n: usize, permutation: &[usize]) -> Box<[Complex32]> { - //let mut fft = create_fft(n - 1); - let mut fft = NaiveDFT::create(n - 1); - - fft.get_input() - .iter_mut() - .enumerate() - .for_each(|(i, x)| *x = Complex32::cexp(-2. * PI * (permutation[i] as f32) / (n as f32))); - fft.execute(); - fft.get_output().iter().copied().collect() -} - pub fn compute_prime_primitive_root(n: usize) -> usize { assert!(is_prime(n)); diff --git a/src/fft/rader2.rs b/src/fft/rader2.rs new file mode 100644 index 0000000..77f36e4 --- /dev/null +++ b/src/fft/rader2.rs @@ -0,0 +1,167 @@ +// Implementation of raders's fft for prime sized ffts + +use std::{f32::consts::PI, ops::Deref}; + +use super::mixed_radix; +use crate::{ + complex::Complex32, + fft::{DFT, create_fft, dft::NaiveDFT, is_prime, windows}, +}; + +pub struct Rader2FFT { + input_buffer: Box<[Complex32]>, + output_buffer: Box<[Complex32]>, + + size: usize, + sub_size: usize, + + // Fourrier transform of the exponential terms + convolution_operand: Box<[Complex32]>, + convolution_fft: Box, // TODO: Use fft + permutation: Box<[usize]>, +} + +impl DFT for Rader2FFT { + fn create(size: usize) -> Self + where + Self: Sized, + { + let g = compute_prime_primitive_root(size); + let permutation: Box<[usize]> = (0..(size - 1)).map(|i| exp_mod(g, i + 1, size)).collect(); + let sub_size = next_pow2((2 * size - 4) - 1); + Rader2FFT { + input_buffer: vec![Complex32::zero(); size].into_boxed_slice(), + output_buffer: vec![Complex32::zero(); sub_size].into_boxed_slice(), + + size, + sub_size, + + convolution_operand: compute_convolution_operand(size, sub_size, &permutation), + //convolution_fft: create_fft(next_pow2((2 * size - 4) - 1)), + convolution_fft: Box::new(NaiveDFT::create(sub_size)), + permutation, + } + } + + fn execute(&mut self, window: fn(f32) -> f32) { + self.convolution_fft.get_input()[0] = self.input_buffer[self.permutation[self.size - 2]]; + + for i in 0..(self.sub_size - self.size + 1) { + self.convolution_fft.get_input()[i + 1] = Complex32::zero(); + } + for i in 1..(self.size - 1) { + let k = self.permutation[self.size - 1 - i - 1]; + self.convolution_fft.get_input()[i + self.sub_size - self.size + 1] = + self.input_buffer[k] * window(k as f32 / self.size as f32) + } + + self.convolution_fft.execute(windows::rectanguar); + + // Use output buffer as staging buffer + for i in 0..(self.sub_size) { + self.output_buffer[i] = + self.convolution_fft.get_output()[i] * self.convolution_operand[i]; + } + + for i in 0..(self.sub_size) { + self.convolution_fft.get_input()[i] = self.output_buffer[i]; + } + /* + self.convolution_fft.get_input()[0] = + self.convolution_fft.get_input()[0] + self.input_buffer[0] * window(0.); + */ + + // Compute ifft to obtain convolution + self.convolution_fft.execute(window); + + for i in 0..(self.size - 1) { + self.output_buffer[self.permutation[i]] = + self.convolution_fft.get_output()[i] / self.sub_size as f32; + } + + self.output_buffer[0] = self + .input_buffer + .iter() + .copied() + .enumerate() + .map(|(i, x)| x * window(i as f32 / self.size as f32)) + .sum(); + } + + fn get_input(&mut self) -> &mut [Complex32] { + &mut self.input_buffer + } + + fn get_output(&self) -> &[Complex32] { + &self.output_buffer + } +} + +pub fn compute_convolution_operand( + n: usize, + sub_size: usize, + permutation: &[usize], +) -> Box<[Complex32]> { + //let mut fft = create_fft(sub_size); + let mut fft = NaiveDFT::create(sub_size); + + fft.get_input().iter_mut().enumerate().for_each(|(i, x)| { + *x = Complex32::cexp(-2. * PI * (permutation[i % (n - 1)] as f32) / (n as f32)) + }); + fft.execute(windows::rectanguar); + fft.get_output().iter().copied().collect() +} + +pub fn compute_prime_primitive_root(n: usize) -> usize { + assert!(is_prime(n)); + + let phi = n - 1; // Euler's totient for n prime + + // Test all candidates + for i in 1..(n + 1) { + // Find multiplicative order of i + let mut val = i; + let mut order = 1; + for j in 0..n { + if val == 1 { + break; + } + val = (val * i) % n; + order += 1; + } + + if order == phi { + return i; + } + } + + unreachable!("Prime must have primitive root"); +} + +pub fn exp_mod(mut n: usize, mut exp: usize, m: usize) -> usize { + if m == 1 { + return 0; + } + + n %= m; + let mut r = 1; + while exp > 0 { + if exp % 2 == 1 { + r = (r * n) % m; + } + + n = (n * n) % m; + exp >>= 1; + } + + r +} + +pub fn next_pow2(mut n: usize) -> usize { + let mut pow = 0; + while n > 0 { + n >>= 1; + pow += 1; + } + 1 << pow +} diff --git a/src/fft/radix2.rs b/src/fft/radix2.rs index c169236..94c8334 100644 --- a/src/fft/radix2.rs +++ b/src/fft/radix2.rs @@ -26,10 +26,11 @@ impl DFT for Radix2FFT { } } - fn execute(&mut self) { + fn execute(&mut self, window: fn(f32) -> f32) { // Reorder samples for (i, x) in self.output_buffer.iter_mut().enumerate() { - *x = self.input_buffer[reverse_bits(i, self.size as u32)]; + let k = reverse_bits(i, self.size as u32); + *x = self.input_buffer[k] * window(k as f32 / self.size as f32); } for step in 1..(self.size + 1) { @@ -40,7 +41,7 @@ impl DFT for Radix2FFT { // Compute current polynomial at each unit root let a = self.output_buffer[s + i]; let b = self.output_buffer[s + i + mid_point]; - let angle = - 2. * PI * (i as f32) / (pol_length as f32); + let angle = -2. * PI * (i as f32) / (pol_length as f32); let phasor = Complex32::cexp(angle); self.output_buffer[i + s] = a + phasor * b; self.output_buffer[i + s + mid_point] = a - phasor * b; diff --git a/src/fft/windows.rs b/src/fft/windows.rs new file mode 100644 index 0000000..1893288 --- /dev/null +++ b/src/fft/windows.rs @@ -0,0 +1,7 @@ +pub fn rectanguar(t: f32) -> f32 { + 1. +} + +pub fn bartlett(t: f32) -> f32 { + if t < 0.5 { 2. * t } else { 2. - 2. * t } +} diff --git a/src/main.rs b/src/main.rs index 538c2f2..0213a8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,9 +17,15 @@ use fft::rader; use nco::Nco; use plotters::prelude::*; - use crate::fft::{ - create_fft, dft::NaiveDFT, mixed_radix::MixedRadixFFT, prime_factors, rader::{compute_prime_primitive_root, exp_mod, RaderFFT}, radix2::Radix2FFT, DFT + DFT, create_fft, + dft::NaiveDFT, + mixed_radix::MixedRadixFFT, + prime_factors, + rader::{RaderFFT, compute_prime_primitive_root, exp_mod}, + rader2::{Rader2FFT, next_pow2}, + radix2::Radix2FFT, + windows, }; // Utilities @@ -43,23 +49,25 @@ fn test() { let freq2 = 2. * PI / 8.0; //let sample_count = 71*71; - //let sample_count = 71*71; - let sample_count = 4800; + //let sample_count = 71 * 71; + //let sample_count = 4804; + let sample_count = 4799; let mut o1 = Nco::new(freq1); let mut o2 = Nco::new(freq2); - - let mut fft = create_fft(sample_count); - //let mut fft = create_fft(sample_count); - for x in fft.get_input().iter_mut() { - *x = o1.cexp() + o2.cexp(); + let mut fft = RaderFFT::create(sample_count); + let mut dft = RaderFFT::create(sample_count); + for (x, y) in fft.get_input().iter_mut().zip(dft.get_input().iter_mut()) { + *y = o1.cexp();// + o2.cexp(); + //*y = *x; o1.step(); o2.step(); } - fft.execute(); + //fft.execute(windows::rectanguar); + dft.execute(windows::rectanguar); let root = BitMapBackend::new("out.png", (640, 480)).into_drawing_area(); root.fill(&WHITE).unwrap(); @@ -68,22 +76,37 @@ fn test() { .margin(5) .x_label_area_size(30) .y_label_area_size(30) - .build_cartesian_2d(0f32..(sample_count as f32), 0f32..(sample_count as f32)).unwrap(); + .build_cartesian_2d(0f32..(sample_count as f32), -PI..PI) + .unwrap(); //chart.configure_mesh().draw()?; chart .draw_series(LineSeries::new( - (0..sample_count).zip(fft.get_output().iter()).map(|(x, y)| (x as f32, y.mag())), + (0..sample_count) + .zip(dft.get_output().iter()) + .map(|(x, y)| (x as f32, (*y).arg() * (*y).mag())), &RED, - )).unwrap() - .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], &RED)); + )) + .unwrap() + .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], RED)); + + chart + .draw_series(LineSeries::new( + (0..sample_count) + .zip(dft.get_output().iter()) + .map(|(x, y)| (x as f32, (*y).mag() / sample_count as f32)), + &BLUE, + )) + .unwrap() + .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], BLUE)); chart .configure_series_labels() .background_style(&WHITE.mix(0.8)) .border_style(&BLACK) - .draw().unwrap(); + .draw() + .unwrap(); root.present().unwrap(); }