| | |
| | |
| | |
| |
|
| | use crate::{Error, Result}; |
| | use ndarray::{Array1, Array2, Axis}; |
| | use num_complex::Complex; |
| | use realfft::RealFftPlanner; |
| | use std::f32::consts::PI; |
| |
|
| | use super::AudioConfig; |
| |
|
| | |
| | #[derive(Debug, Clone)] |
| | pub struct MelFilterbank { |
| | |
| | pub filters: Array2<f32>, |
| | |
| | pub sample_rate: u32, |
| | |
| | pub n_mels: usize, |
| | |
| | pub n_fft: usize, |
| | } |
| |
|
| | impl MelFilterbank { |
| | |
| | pub fn new(sample_rate: u32, n_fft: usize, n_mels: usize, fmin: f32, fmax: f32) -> Self { |
| | let filters = create_mel_filterbank(sample_rate, n_fft, n_mels, fmin, fmax); |
| | Self { |
| | filters, |
| | sample_rate, |
| | n_mels, |
| | n_fft, |
| | } |
| | } |
| |
|
| | |
| | pub fn apply(&self, spectrogram: &Array2<f32>) -> Array2<f32> { |
| | |
| | |
| | |
| | self.filters.dot(spectrogram) |
| | } |
| | } |
| |
|
| | |
/// Converts a frequency in Hz to the mel scale using
/// `mel = 2595 * log10(1 + hz / 700)`.
pub fn hz_to_mel(hz: f32) -> f32 {
    const MEL_FACTOR: f32 = 2595.0;
    const CORNER_HZ: f32 = 700.0;

    let ratio = 1.0 + hz / CORNER_HZ;
    MEL_FACTOR * ratio.log10()
}
| |
|
| | |
/// Converts a mel-scale value back to Hz using
/// `hz = 700 * (10^(mel / 2595) - 1)`; the inverse of `hz_to_mel`.
pub fn mel_to_hz(mel: f32) -> f32 {
    let exponent = mel / 2595.0;
    let ratio = 10f32.powf(exponent);
    700.0 * (ratio - 1.0)
}
| |
|
| | |
| | fn create_mel_filterbank( |
| | sample_rate: u32, |
| | n_fft: usize, |
| | n_mels: usize, |
| | fmin: f32, |
| | fmax: f32, |
| | ) -> Array2<f32> { |
| | let n_freqs = n_fft / 2 + 1; |
| |
|
| | |
| | let mel_min = hz_to_mel(fmin); |
| | let mel_max = hz_to_mel(fmax); |
| |
|
| | |
| | let mel_points: Vec<f32> = (0..=n_mels + 1) |
| | .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32) |
| | .collect(); |
| |
|
| | |
| | let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect(); |
| |
|
| | |
| | let bin_points: Vec<usize> = hz_points |
| | .iter() |
| | .map(|&hz| ((n_fft as f32 + 1.0) * hz / sample_rate as f32).floor() as usize) |
| | .collect(); |
| |
|
| | |
| | let mut filters = Array2::zeros((n_mels, n_freqs)); |
| |
|
| | for m in 0..n_mels { |
| | let f_left = bin_points[m]; |
| | let f_center = bin_points[m + 1]; |
| | let f_right = bin_points[m + 2]; |
| |
|
| | |
| | for k in f_left..f_center { |
| | if k < n_freqs { |
| | filters[[m, k]] = (k - f_left) as f32 / (f_center - f_left).max(1) as f32; |
| | } |
| | } |
| |
|
| | |
| | for k in f_center..f_right { |
| | if k < n_freqs { |
| | filters[[m, k]] = (f_right - k) as f32 / (f_right - f_center).max(1) as f32; |
| | } |
| | } |
| | } |
| |
|
| | filters |
| | } |
| |
|
| | |
/// Returns a Hann window of `size` samples,
/// `w[n] = 0.5 * (1 - cos(2 * pi * n / size))`.
///
/// The denominator is `size` rather than `size - 1`, so the window starts at
/// zero and peaks at `n = size / 2`. A `size` of zero yields an empty vector.
fn hann_window(size: usize) -> Vec<f32> {
    let period = size as f32;
    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = 2.0 * PI * n as f32 / period;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | pub fn stft( |
| | signal: &[f32], |
| | n_fft: usize, |
| | hop_length: usize, |
| | win_length: usize, |
| | ) -> Result<Array2<Complex<f32>>> { |
| | if signal.is_empty() { |
| | return Err(Error::Audio("Empty signal".into())); |
| | } |
| |
|
| | |
| | let window = hann_window(win_length); |
| |
|
| | |
| | let pad_length = n_fft / 2; |
| | let mut padded = vec![0.0f32; pad_length]; |
| | padded.extend_from_slice(signal); |
| | padded.extend(vec![0.0f32; pad_length]); |
| |
|
| | |
| | let num_frames = (padded.len() - n_fft) / hop_length + 1; |
| | let n_freqs = n_fft / 2 + 1; |
| |
|
| | |
| | let mut planner = RealFftPlanner::<f32>::new(); |
| | let fft = planner.plan_fft_forward(n_fft); |
| |
|
| | |
| | let mut stft_matrix = Array2::zeros((n_freqs, num_frames)); |
| |
|
| | |
| | let mut input_buffer = vec![0.0f32; n_fft]; |
| | let mut output_buffer = vec![Complex::new(0.0f32, 0.0f32); n_freqs]; |
| |
|
| | for (frame_idx, start) in (0..padded.len() - n_fft + 1) |
| | .step_by(hop_length) |
| | .enumerate() |
| | { |
| | if frame_idx >= num_frames { |
| | break; |
| | } |
| |
|
| | |
| | for i in 0..win_length { |
| | input_buffer[i] = padded[start + i] * window[i]; |
| | } |
| | |
| | for i in win_length..n_fft { |
| | input_buffer[i] = 0.0; |
| | } |
| |
|
| | |
| | fft.process(&mut input_buffer, &mut output_buffer) |
| | .map_err(|e| Error::Audio(format!("FFT failed: {}", e)))?; |
| |
|
| | |
| | for (freq_idx, &val) in output_buffer.iter().enumerate() { |
| | stft_matrix[[freq_idx, frame_idx]] = val; |
| | } |
| | } |
| |
|
| | Ok(stft_matrix) |
| | } |
| |
|
| | |
| | pub fn magnitude_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> { |
| | stft_matrix.mapv(|c| c.norm()) |
| | } |
| |
|
| | |
| | pub fn power_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> { |
| | stft_matrix.mapv(|c| c.norm_sqr()) |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | pub fn mel_spectrogram(signal: &[f32], config: &AudioConfig) -> Result<Array2<f32>> { |
| | |
| | let stft_matrix = stft(signal, config.n_fft, config.hop_length, config.win_length)?; |
| |
|
| | |
| | let power_spec = power_spectrogram(&stft_matrix); |
| |
|
| | |
| | let mel_fb = MelFilterbank::new( |
| | config.sample_rate, |
| | config.n_fft, |
| | config.n_mels, |
| | config.fmin, |
| | config.fmax, |
| | ); |
| |
|
| | |
| | let mel_spec = mel_fb.apply(&power_spec); |
| |
|
| | |
| | let log_mel_spec = mel_spec.mapv(|x| (x.max(1e-10)).ln()); |
| |
|
| | Ok(log_mel_spec) |
| | } |
| |
|
| | |
| | pub fn mel_spectrogram_normalized( |
| | signal: &[f32], |
| | config: &AudioConfig, |
| | mean: Option<f32>, |
| | std: Option<f32>, |
| | ) -> Result<Array2<f32>> { |
| | let mut mel_spec = mel_spectrogram(signal, config)?; |
| |
|
| | |
| | if let (Some(m), Some(s)) = (mean, std) { |
| | mel_spec.mapv_inplace(|x| (x - m) / s); |
| | } else { |
| | |
| | let m = mel_spec.mean().unwrap_or(0.0); |
| | let s = mel_spec.std(0.0); |
| | if s > 1e-8 { |
| | mel_spec.mapv_inplace(|x| (x - m) / s); |
| | } |
| | } |
| |
|
| | Ok(mel_spec) |
| | } |
| |
|
| | |
| | pub fn mel_to_linear(mel_spec: &Array2<f32>, mel_fb: &MelFilterbank) -> Array2<f32> { |
| | |
| | let filters_t = mel_fb.filters.t(); |
| | let gram = mel_fb.filters.dot(&filters_t); |
| |
|
| | |
| | filters_t.dot(mel_spec) |
| | } |
| |
|
| | |
| | pub fn frame_energy(mel_spec: &Array2<f32>) -> Array1<f32> { |
| | mel_spec.sum_axis(Axis(0)) |
| | } |
| |
|
| | |
| | pub fn voice_activity_detection(mel_spec: &Array2<f32>, threshold_db: f32) -> Vec<bool> { |
| | let energy = frame_energy(mel_spec); |
| | let max_energy = energy.iter().cloned().fold(f32::NEG_INFINITY, f32::max); |
| | let threshold = max_energy + threshold_db; |
| |
|
| | energy.iter().map(|&e| e > threshold).collect() |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hz_to_mel() {
        // 0 Hz maps to 0 mel, and 1000 Hz lands close to 1000 mel.
        let at_zero = hz_to_mel(0.0);
        let at_1k = hz_to_mel(1000.0);
        assert!((at_zero - 0.0).abs() < 1e-6);
        assert!((at_1k - 1000.0).abs() < 50.0);
    }

    #[test]
    fn test_mel_to_hz() {
        // Hz -> mel -> Hz must round-trip.
        let hz = 440.0;
        let hz_back = mel_to_hz(hz_to_mel(hz));
        assert!((hz - hz_back).abs() < 1e-4);
    }

    #[test]
    fn test_mel_filterbank_creation() {
        let fb = MelFilterbank::new(22050, 1024, 80, 0.0, 8000.0);
        assert_eq!(fb.filters.shape(), &[80, 513]);

        // At least some of the filter weights must be non-zero.
        let total_sum: f32 = fb.filters.iter().sum();
        assert!(total_sum > 0.0, "Filterbank should have some non-zero values");
    }

    #[test]
    fn test_hann_window() {
        let window = hann_window(1024);
        assert_eq!(window.len(), 1024);
        // Starts at zero ...
        assert!(window[0].abs() < 1e-6);
        // ... and peaks at the midpoint.
        assert!((window[512] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_stft_basic() {
        // 100 ms of a 440 Hz sine at 22.05 kHz.
        let sr = 22050;
        let freq = 440.0;
        let duration = 0.1;
        let num_samples = (sr as f32 * duration) as usize;

        let mut signal: Vec<f32> = Vec::with_capacity(num_samples);
        for i in 0..num_samples {
            signal.push((2.0 * PI * freq * i as f32 / sr as f32).sin());
        }

        let result = stft(&signal, 1024, 256, 1024);
        assert!(result.is_ok());

        let stft_matrix = result.unwrap();
        let shape = stft_matrix.shape();
        assert_eq!(shape[0], 513);
        assert!(shape[1] > 0);
    }

    #[test]
    fn test_mel_spectrogram() {
        let config = AudioConfig::default();
        let num_samples = (config.sample_rate as f32 * 0.1) as usize;
        let signal: Vec<f32> = (0..num_samples).map(|i| (i as f32 * 0.01).sin()).collect();

        let result = mel_spectrogram(&signal, &config);
        assert!(result.is_ok());

        let mel_spec = result.unwrap();
        let shape = mel_spec.shape();
        assert_eq!(shape[0], config.n_mels);
        assert!(shape[1] > 0);
    }
}
| |
|