"""Gradio demo for DPDFNet streaming speech enhancement.

Audio is resampled to the selected model's rate, converted to an STFT,
enhanced frame-by-frame through a stateful ONNX model, inverted back to a
waveform, and resampled to the original rate for playback/comparison.
"""

import io
import math
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import soundfile as sf
from PIL import Image

# -----------------------------
# Configuration
# -----------------------------
MAX_SECONDS = 10.0          # hard cap on processed input duration
ONNX_DIR = Path("./onnx")   # local directory scanned for *.onnx models


@dataclass(frozen=True)
class ModelSpec:
    """Static description of one ONNX model preset."""

    name: str       # file stem of the .onnx model
    sr: int         # sample rate the model runs at (Hz)
    onnx_path: str  # path to the local .onnx file


# -----------------------------
# Model discovery and metadata
# -----------------------------
def _infer_model_meta(model_name: str) -> int:
    """Infer a model's native sample rate (Hz) from its file name."""
    normalized = model_name.lower().replace("-", "_")
    if "48khz" in normalized or "48k" in normalized or "48hr" in normalized:
        return 48000
    # Fallback for unknown 16 kHz DPDFNet variants
    return 16000


def _display_label(spec: ModelSpec) -> str:
    """Human-readable dropdown label, e.g. ``dpdfnet2 (16 kHz)``."""
    khz = int(spec.sr // 1000)
    return f"{spec.name} ({khz} kHz)"


def discover_model_presets() -> Dict[str, ModelSpec]:
    """Scan ``ONNX_DIR`` for models and build the label -> spec mapping.

    Known model names are listed first in a canonical order; any other
    ``*.onnx`` files found are appended alphabetically.
    """
    ordered_names = [
        "baseline",
        "dpdfnet2",
        "dpdfnet4",
        "dpdfnet8",
        "dpdfnet2_48khz_hr",
        "dpdfnet8_48khz_hr",
    ]
    found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
    presets: Dict[str, ModelSpec] = {}
    for name in ordered_names:
        p = found_paths.get(name)
        if p is None:
            continue
        sr = _infer_model_meta(name)
        spec = ModelSpec(name=name, sr=sr, onnx_path=str(p))
        presets[_display_label(spec)] = spec
    # Include any additional ONNX files not in the canonical order list.
    for name, p in sorted(found_paths.items()):
        if name in ordered_names:
            continue
        sr = _infer_model_meta(name)
        spec = ModelSpec(name=name, sr=sr, onnx_path=str(p))
        presets[_display_label(spec)] = spec
    return presets


MODEL_PRESETS = discover_model_presets()
DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS), None)

# -----------------------------
# ONNX Runtime + frontend cache
# -----------------------------
_SESSIONS: Dict[str, ort.InferenceSession] = {}     # model key -> cached session
_INIT_STATES: Dict[str, np.ndarray] = {}            # model key -> cached initial state


def resolve_model_path(local_path: str) -> str:
    """Return *local_path* if it exists, else raise a user-facing error."""
    p = Path(local_path)
    if p.exists():
        return str(p)
    raise gr.Error(
        f"ONNX model not found at: {local_path}. "
        "Expected local models under ./onnx/."
    )


def get_ort_session(model_key: str) -> ort.InferenceSession:
    """Create (or return a cached) single-threaded CPU inference session."""
    if model_key in _SESSIONS:
        return _SESSIONS[model_key]
    spec = MODEL_PRESETS[model_key]
    onnx_path = resolve_model_path(spec.onnx_path)
    options = ort.SessionOptions()
    # Single-threaded execution: the model is small and the app may serve
    # several requests concurrently via the Gradio queue.
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1
    sess = ort.InferenceSession(
        onnx_path,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )
    _SESSIONS[model_key] = sess
    return sess


def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
    """Build the model's initial streaming state vector from ONNX metadata.

    The exported model embeds its state layout in ``custom_metadata_map``:
    a total ``state_size`` plus the sizes and initial values of the ERB-
    and spectral-normalisation sub-states, which occupy the first slots of
    the state vector. The result is cached per model key.

    Raises:
        gr.Error: if the model is not a 2-input streaming export or any
            required metadata key is missing.
    """
    if model_key in _INIT_STATES:
        return _INIT_STATES[model_key]
    if len(session.get_inputs()) < 2:
        raise gr.Error("Expected streaming ONNX model with two inputs: (spec, state).")
    meta = session.get_modelmeta().custom_metadata_map
    try:
        state_size = int(meta["state_size"])
        erb_norm_state_size = int(meta["erb_norm_state_size"])
        spec_norm_state_size = int(meta["spec_norm_state_size"])
        erb_norm_init = np.array(
            [float(x) for x in meta["erb_norm_init"].split(",")], dtype=np.float32
        )
        spec_norm_init = np.array(
            [float(x) for x in meta["spec_norm_init"].split(",")], dtype=np.float32
        )
    except KeyError as exc:
        # FIX: chain the original KeyError so the missing key is debuggable.
        raise gr.Error(
            f"ONNX model is missing required metadata key: {exc}. "
            "Re-export the model to embed state initialisation metadata."
        ) from exc
    # Layout: [erb_norm_init | spec_norm_init | zeros...]
    init_state = np.zeros(state_size, dtype=np.float32)
    init_state[0:erb_norm_state_size] = erb_norm_init
    init_state[erb_norm_state_size:erb_norm_state_size + spec_norm_state_size] = spec_norm_init
    init_state = np.ascontiguousarray(init_state)
    _INIT_STATES[model_key] = init_state
    return init_state


# -----------------------------
# STFT/iSTFT (module-free)
# -----------------------------
def vorbis_window(window_len: int) -> np.ndarray:
    """Return the Vorbis (sin-of-sin-squared) analysis window of given length."""
    window_size_h = window_len / 2
    indices = np.arange(window_len)
    sin = np.sin(0.5 * np.pi * (indices + 0.5) / window_size_h)
    window = np.sin(0.5 * np.pi * sin * sin)
    return window.astype(np.float32)


def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, np.ndarray]:
    """Derive (win_len, hop, window) from the model's spec-input shape.

    When the frequency dimension is static, the window length follows from
    ``F = win_len / 2 + 1``; otherwise fall back to a 20 ms window at the
    model's sample rate.
    """
    # ONNX spec input is [B, T, F, 2] (or dynamic variants).
    spec_shape = session.get_inputs()[0].shape
    freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
    if isinstance(freq_bins, int) and freq_bins > 1:
        win_len = int((freq_bins - 1) * 2)
    else:
        # 20 ms windows for DPDFNet family.
        sr = MODEL_PRESETS[model_key].sr
        win_len = int(round(sr * 0.02))
    hop = win_len // 2  # 50% overlap
    win = vorbis_window(win_len)
    return win_len, hop, win


def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
    """STFT a mono waveform into the model's [1, T, F, 2] real/imag layout."""
    audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
    # Trailing zero-pad so the final frames are fully covered by the iSTFT.
    audio_pad = np.pad(audio, (0, win_len), mode="constant")
    spec = librosa.stft(
        y=audio_pad,
        n_fft=win_len,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        pad_mode="reflect",
    )
    spec = spec.T.astype(np.complex64, copy=False)  # [T, F]
    spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False)  # [T, F, 2]
    return np.ascontiguousarray(spec_ri[None, ...], dtype=np.float32)  # [1, T, F, 2]


def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
    """Invert an enhanced [1, T, F, 2] spectrum back to a mono waveform.

    The first ``2 * win_len`` samples are discarded and replaced with
    trailing zeros to compensate the pipeline's algorithmic delay while
    keeping the output length unchanged.
    """
    spec_c = np.asarray(spec_e[0], dtype=np.float32)  # [T, F, 2]
    spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False)  # [F, T]
    waveform_e = librosa.istft(
        spec,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    return np.concatenate(
        [waveform_e[win_len * 2:], np.zeros(win_len * 2, dtype=np.float32)],
        axis=0,
    )


# -----------------------------
# ONNX inference (non-streaming pre/post, streaming ONNX state loop)
# -----------------------------
def enhance_audio_onnx(
    audio_mono: np.ndarray,
    model_key: str,
) -> np.ndarray:
    """Enhance a mono waveform (at the model's sample rate) frame by frame.

    Runs the streaming ONNX model one STFT frame at a time, threading the
    recurrent state through successive calls, then reassembles and inverts
    the enhanced spectrum.

    Raises:
        gr.Error: if the model does not expose the expected streaming
            (spec, state) -> (spec_e, state_out) signature.
    """
    sess = get_ort_session(model_key)
    inputs = sess.get_inputs()
    outputs = sess.get_outputs()
    if len(inputs) < 2 or len(outputs) < 2:
        raise gr.Error(
            "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
        )
    in_spec_name = inputs[0].name
    in_state_name = inputs[1].name
    out_spec_name = outputs[0].name
    out_state_name = outputs[1].name
    waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
    win_len, hop, win = _infer_stft_params(model_key, sess)
    spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, win=win)
    # Copy so the cached initial state is never mutated by the run loop.
    state = _load_initial_state(model_key, sess).copy()
    spec_e_frames = []
    num_frames = int(spec_r_np.shape[1])
    for t in range(num_frames):
        spec_t = np.ascontiguousarray(spec_r_np[:, t:t + 1, :, :], dtype=np.float32)
        spec_e_t, state = sess.run(
            [out_spec_name, out_state_name],
            {in_spec_name: spec_t, in_state_name: state},
        )
        spec_e_frames.append(np.ascontiguousarray(spec_e_t, dtype=np.float32))
    if not spec_e_frames:
        # Degenerate zero-frame input: return the input untouched.
        return waveform
    spec_e_np = np.concatenate(spec_e_frames, axis=1)
    waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, win=win)
    return np.asarray(waveform_e, dtype=np.float32).reshape(-1)


# -----------------------------
# Audio utilities
# -----------------------------
def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file as float32 [samples, channels] plus sample rate."""
    data, sr = sf.read(path, always_2d=True)
    data = data.astype(np.float32, copy=False)
    return data, int(sr)


def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]:
    """Downmix to mono by channel averaging; return (mono, original channels)."""
    if x.ndim == 1:
        return x.astype(np.float32, copy=False), 1
    if x.shape[1] == 1:
        return x[:, 0], 1
    return x.mean(axis=1), int(x.shape[1])


def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    """Resample *y* from sr_in to sr_out (no-op when rates match)."""
    if sr_in == sr_out:
        return y
    return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False)


def _match_length(y: np.ndarray, target_len: int) -> np.ndarray:
    """Trim or zero-pad *y* to exactly *target_len* samples."""
    if len(y) == target_len:
        return y
    if len(y) > target_len:
        return y[:target_len]
    out = np.zeros((target_len,), dtype=y.dtype)
    out[:len(y)] = y
    return out


def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
    """Write *y* to a uniquely named temp WAV file and return its path."""
    # delete=False: Gradio serves the file after this function returns.
    tmp = tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False)
    tmp.close()
    sf.write(tmp.name, y, sr)
    return tmp.name


def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
    """Render a dB-scaled magnitude spectrogram of *y* as a PIL image."""
    win_length = max(256, int(0.032 * sr))
    hop_length = max(64, int(0.008 * sr))
    n_fft = 1 << (int(math.ceil(math.log2(win_length))))
    # FIX: librosa.stft(center=False) rejects inputs shorter than n_fft;
    # zero-pad very short clips so rendering never crashes.
    if y.size < n_fft:
        y = np.pad(np.asarray(y, dtype=np.float32), (0, n_fft - y.size))
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
    S_db = librosa.amplitude_to_db(np.abs(S) + 1e-10, ref=np.max)
    fig, ax = plt.subplots(figsize=(8.4, 3.2))
    ax.imshow(S_db, origin="lower", aspect="auto")
    ax.set_axis_off()
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=160)
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    # FIX: force decode now so the image does not lazily depend on `buf`.
    img.load()
    return img


# -----------------------------
# Main pipeline
# -----------------------------
def run_enhancement(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Full request pipeline: load -> mono -> trim -> enhance -> render.

    Returns (noisy wav path, enhanced wav path, noisy spectrogram image,
    enhanced spectrogram image, markdown status string).
    """
    if not MODEL_PRESETS:
        raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")
    chosen_path = mic_path if source == "Microphone" else file_path
    if not chosen_path:
        raise gr.Error("Please provide audio either from the microphone or by uploading a file.")
    # FIX: an unknown dropdown value previously surfaced as a raw KeyError
    # traceback; report it as a user-facing error instead.
    spec = MODEL_PRESETS.get(model_key)
    if spec is None:
        raise gr.Error(f"Unknown model selection: {model_key}")
    x, sr_orig = _load_wav_from_gradio_path(chosen_path)
    y_mono, n_ch = _to_mono(x)
    max_samples = int(MAX_SECONDS * sr_orig)
    was_trimmed = len(y_mono) > max_samples
    if was_trimmed:
        y_mono = y_mono[:max_samples]
    dur = len(y_mono) / float(sr_orig)
    sr_model = spec.sr
    # Enhance at the model's native rate, then return to the input rate.
    y_model = _resample(y_mono, sr_orig, sr_model)
    y_enh_model = enhance_audio_onnx(y_model, model_key)
    y_enh = _resample(y_enh_model, sr_model, sr_orig)
    y_enh = _match_length(y_enh, len(y_mono))
    noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
    enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")
    noisy_img = _spectrogram_image(y_mono, sr_orig)
    enh_img = _spectrogram_image(y_enh, sr_orig)
    status = (
        f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
        f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
        + (
            f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
            if sr_orig != sr_model
            else "**Resampling:** none\n\n"
        )
        + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
        + "\n✅ Done."
    )
    return noisy_out, enh_out, noisy_img, enh_img, status


def set_source_visibility(source: str):
    """Toggle mic/upload widget visibility from the source radio choice."""
    return (
        gr.update(visible=(source == "Microphone")),
        gr.update(visible=(source == "Upload")),
    )


# -----------------------------
# UI (light polish)
# -----------------------------
THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="slate",
    font=[
        "Arial",
        "ui-sans-serif",
        "system-ui",
        "Segoe UI",
        "Roboto",
        "Helvetica Neue",
        "Noto Sans",
        "Liberation Sans",
        "sans-serif",
    ],
)

CSS = """
.gradio-container{ max-width: 1040px !important; margin: 0 auto !important; font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important; }
#header { padding: 14px 16px; border-radius: 16px; border: 1px solid rgba(0,0,0,0.08); background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04)); text-align: center; }
#header h1{ margin: 0 0 6px 0; font-size: 24px; font-weight: 800; letter-spacing: -0.2px; }
#header p{ margin: 6px auto 0 auto; max-width: 720px; color: var(--body-text-color-subdued); font-size: 14px; line-height: 1.6; }
#header hr{ margin-top: 18px; border: none; height: 1px; background: linear-gradient(to right, transparent, #ddd, transparent); }
.spec img { border-radius: 14px; }
.audio { border-radius: 14px !important; overflow: hidden; }
#run_btn{ border-radius: 12px !important; font-weight: 800 !important; }
#status_md p{ margin: 0.35rem 0; }
"""

with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
    gr.Markdown(
        "# DPDFNet Speech Enhancement\n\n"
        "Causal · Real-Time · Edge-Ready\n\n"
        "DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve "
        "long-range temporal and cross-band modeling while preserving low latency. "
        "Designed for single-channel streaming speech enhancement under challenging noise conditions.\n\n"
        "---",
        elem_id="header",
    )
    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_PRESETS.keys()),
            value=DEFAULT_MODEL_KEY,
            label="Model",
            # info="Audio is resampled to model SR, enhanced with ONNX, then resampled back.",
            interactive=True,
        )
        source = gr.Radio(
            choices=["Microphone", "Upload"],
            value="Upload",
            label="Input source",
        )
    with gr.Row():
        mic_audio = gr.Audio(
            sources=["microphone"],
            type="filepath",
            format="wav",
            label="Microphone (max 10s)",
            visible=False,
            buttons=["download"],
            elem_classes=["audio"],
        )
        file_audio = gr.Audio(
            sources=["upload"],
            type="filepath",
            format="wav",
            label="Upload file (WAV/MP3/FLAC etc., max 10s)",
            visible=True,
            buttons=["download"],
            elem_classes=["audio"],
        )
    run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
    status = gr.Markdown(elem_id="status_md")
    gr.Markdown("## Results")
    with gr.Row():
        out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
        out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
    with gr.Row():
        img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
        img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])

    source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])
    run_btn.click(
        fn=run_enhancement,
        inputs=[source, mic_audio, file_audio, model_key],
        outputs=[out_noisy, out_enh, img_noisy, img_enh, status],
        api_name="enhance",
    )

if __name__ == "__main__":
    demo.queue(max_size=32).launch()