# DPDFNetDemo / app.py
# Uploaded by danielr-ceva (commit 0e6e1be, verified)
import io
import math
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple
import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import soundfile as sf
from PIL import Image
# -----------------------------
# Configuration
# -----------------------------
# Maximum input duration accepted by the demo, in seconds; longer audio is trimmed.
MAX_SECONDS = 10.0
# Directory scanned for *.onnx model files at startup.
ONNX_DIR = Path("./onnx")
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one ONNX model preset discovered on disk."""

    # ONNX file stem, e.g. "dpdfnet2_48khz_hr".
    name: str
    # Operating sample rate in Hz (16000 or 48000, inferred from the name).
    sr: int
    # Filesystem path to the .onnx file.
    onnx_path: str
# -----------------------------
# Model discovery and metadata
# -----------------------------
def _infer_model_meta(model_name: str) -> int:
normalized = model_name.lower().replace("-", "_")
if "48khz" in normalized or "48k" in normalized or "48hr" in normalized:
return 48000
# Fallback for unknown 16 kHz DPDFNet variants
return 16000
def _display_label(spec: ModelSpec) -> str:
    """Format a preset as 'name (NN kHz)' for the model dropdown."""
    return f"{spec.name} ({spec.sr // 1000} kHz)"
def discover_model_presets() -> Dict[str, ModelSpec]:
    """Scan ONNX_DIR for *.onnx files and build the preset table.

    Returns:
        Dict mapping display label -> ModelSpec. Canonical DPDFNet variants
        come first, in the preferred display order, followed by any other
        discovered models sorted by file stem.
    """
    ordered_names = [
        "baseline",
        "dpdfnet2",
        "dpdfnet4",
        "dpdfnet8",
        "dpdfnet2_48khz_hr",
        "dpdfnet8_48khz_hr",
    ]
    found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}
    presets: Dict[str, ModelSpec] = {}

    def _register(name: str, path) -> None:
        # Single construction point (previously duplicated in both loops).
        spec = ModelSpec(
            name=name,
            sr=_infer_model_meta(name),
            onnx_path=str(path),
        )
        presets[_display_label(spec)] = spec

    # Canonical variants first, in the preferred display order.
    for name in ordered_names:
        path = found_paths.get(name)
        if path is not None:
            _register(name, path)
    # Include any additional ONNX files not in the canonical order list.
    for name, path in sorted(found_paths.items()):
        if name not in ordered_names:
            _register(name, path)
    return presets
# Preset table built once at import time; maps display label -> ModelSpec.
MODEL_PRESETS = discover_model_presets()
# First discovered label, or None when ./onnx contains no models.
DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS), None)
# -----------------------------
# ONNX Runtime + frontend cache
# -----------------------------
# Per-model caches keyed by the preset (dropdown) label: ONNX Runtime
# sessions and the model's initial streaming-state vector.
_SESSIONS: Dict[str, ort.InferenceSession] = {}
_INIT_STATES: Dict[str, np.ndarray] = {}
def resolve_model_path(local_path: str) -> str:
    """Return local_path if it exists on disk, else raise a user-facing error."""
    candidate = Path(local_path)
    if not candidate.exists():
        raise gr.Error(
            f"ONNX model not found at: {local_path}. "
            "Expected local models under ./onnx/."
        )
    return str(candidate)
def get_ort_session(model_key: str) -> ort.InferenceSession:
    """Create (or fetch from cache) the CPU inference session for a preset."""
    cached = _SESSIONS.get(model_key)
    if cached is not None:
        return cached
    spec = MODEL_PRESETS[model_key]
    model_path = resolve_model_path(spec.onnx_path)
    opts = ort.SessionOptions()
    # Single-threaded execution keeps CPU usage predictable on shared hosts.
    opts.intra_op_num_threads = 1
    opts.inter_op_num_threads = 1
    session = ort.InferenceSession(
        model_path,
        sess_options=opts,
        providers=["CPUExecutionProvider"],
    )
    _SESSIONS[model_key] = session
    return session
def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
    """Build (and cache) the initial streaming-state vector for a model.

    The exported ONNX embeds, in its custom metadata, the total state size
    plus the sizes and initial values of the two normalisation slices (ERB
    and spectral) that occupy the front of the state; the remainder starts
    at zero.

    Raises:
        gr.Error: if the model is not a 2-input streaming export or a
            required metadata key is missing.
    """
    if model_key in _INIT_STATES:
        return _INIT_STATES[model_key]
    if len(session.get_inputs()) < 2:
        raise gr.Error("Expected streaming ONNX model with two inputs: (spec, state).")
    meta = session.get_modelmeta().custom_metadata_map
    try:
        state_size = int(meta["state_size"])
        erb_norm_state_size = int(meta["erb_norm_state_size"])
        spec_norm_state_size = int(meta["spec_norm_state_size"])
        erb_norm_init = np.array(
            [float(x) for x in meta["erb_norm_init"].split(",")], dtype=np.float32
        )
        spec_norm_init = np.array(
            [float(x) for x in meta["spec_norm_init"].split(",")], dtype=np.float32
        )
    except KeyError as exc:
        # Chain the original KeyError so the missing key stays visible in logs.
        raise gr.Error(
            f"ONNX model is missing required metadata key: {exc}. "
            "Re-export the model to embed state initialisation metadata."
        ) from exc
    # Zero state with the two normalisation slices spliced in at the front:
    # [erb_norm | spec_norm | zeros...].
    init_state = np.zeros(state_size, dtype=np.float32)
    init_state[0:erb_norm_state_size] = erb_norm_init
    init_state[erb_norm_state_size:erb_norm_state_size + spec_norm_state_size] = spec_norm_init
    init_state = np.ascontiguousarray(init_state)
    _INIT_STATES[model_key] = init_state
    return init_state
# -----------------------------
# STFT/iSTFT (module-free)
# -----------------------------
def vorbis_window(window_len: int) -> np.ndarray:
    """Return the Vorbis (power-complementary) window of the given length.

    w[n] = sin(pi/2 * sin^2(pi * (n + 0.5) / N)); this satisfies the
    Princen-Bradley condition, so 50%-overlapped frames reconstruct exactly.
    """
    n = np.arange(window_len)
    inner = np.sin(np.pi * (n + 0.5) / window_len)
    return np.sin(0.5 * np.pi * inner * inner).astype(np.float32)
def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, np.ndarray]:
    """Derive (win_len, hop, window) for the model's STFT front-end.

    The ONNX spec input is laid out [B, T, F, 2]; when the frequency axis F
    is a static integer we invert win_len = (F - 1) * 2. With dynamic axes
    we fall back to 20 ms windows at the model's sample rate.
    """
    input_shape = session.get_inputs()[0].shape
    freq_axis = input_shape[-2] if len(input_shape) >= 2 else None
    if isinstance(freq_axis, int) and freq_axis > 1:
        win_len = (freq_axis - 1) * 2
    else:
        # 20 ms windows for the DPDFNet family.
        win_len = int(round(MODEL_PRESETS[model_key].sr * 0.02))
    return win_len, win_len // 2, vorbis_window(win_len)
def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
    """STFT a mono waveform into the model's [1, T, F, 2] real/imag layout."""
    mono = np.asarray(waveform, dtype=np.float32).reshape(-1)
    # Trailing zero-pad of one full window so the final samples are analysed.
    padded = np.pad(mono, (0, win_len), mode="constant")
    stft = librosa.stft(
        y=padded,
        n_fft=win_len,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        pad_mode="reflect",
    )
    frames = stft.T.astype(np.complex64, copy=False)  # [T, F]
    # Split complex values into a trailing real/imag axis.
    ri = np.stack([frames.real, frames.imag], axis=-1).astype(np.float32, copy=False)  # [T, F, 2]
    return np.ascontiguousarray(ri[None, ...], dtype=np.float32)  # [1, T, F, 2]
def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, win: np.ndarray) -> np.ndarray:
    """Invert an enhanced [1, T, F, 2] spectrum back to a mono float32 waveform.

    Mirrors _preprocess_waveform: iSTFT with the same window, then the first
    2 * win_len samples are dropped and the same amount of trailing zeros is
    appended, keeping the overall length equal to the iSTFT output.
    NOTE(review): the 2 * win_len shift presumably compensates the pre/post
    padding and model delay — confirm against the export's framing.
    """
    spec_c = np.asarray(spec_e[0], dtype=np.float32)  # [T, F, 2]
    # Recombine real/imag planes and transpose to librosa's [F, T] layout.
    spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False)  # [F, T]
    waveform_e = librosa.istft(
        spec,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    return np.concatenate(
        [waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
        axis=0,
    )
# -----------------------------
# ONNX inference (non-streaming pre/post, streaming ONNX state loop)
# -----------------------------
def enhance_audio_onnx(
    audio_mono: np.ndarray,
    model_key: str,
) -> np.ndarray:
    """Enhance a mono waveform with the streaming ONNX model.

    STFT pre/post-processing happens offline; the model itself is driven one
    spectral frame at a time, threading its recurrent state vector through
    every call. Returns a flat float32 waveform (the unmodified input when
    the STFT yields zero frames).
    """
    sess = get_ort_session(model_key)
    inputs = sess.get_inputs()
    outputs = sess.get_outputs()
    if len(inputs) < 2 or len(outputs) < 2:
        raise gr.Error(
            "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
        )
    # Positional convention: (spec, state) in, (spec_e, state_out) out.
    in_spec_name = inputs[0].name
    in_state_name = inputs[1].name
    out_spec_name = outputs[0].name
    out_state_name = outputs[1].name
    waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
    win_len, hop, win = _infer_stft_params(model_key, sess)
    spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, win=win)
    # Copy so the cached initial state is never mutated across requests.
    state = _load_initial_state(model_key, sess).copy()
    spec_e_frames = []
    num_frames = int(spec_r_np.shape[1])
    for t in range(num_frames):
        # One [1, 1, F, 2] frame per call; the state output feeds the next step.
        spec_t = np.ascontiguousarray(spec_r_np[:, t : t + 1, :, :], dtype=np.float32)
        spec_e_t, state = sess.run(
            [out_spec_name, out_state_name],
            {in_spec_name: spec_t, in_state_name: state},
        )
        spec_e_frames.append(np.ascontiguousarray(spec_e_t, dtype=np.float32))
    if not spec_e_frames:
        return waveform
    # Reassemble the per-frame outputs along the time axis: [1, T, F, 2].
    spec_e_np = np.concatenate(spec_e_frames, axis=1)
    waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, win=win)
    return np.asarray(waveform_e, dtype=np.float32).reshape(-1)
# -----------------------------
# Audio utilities
# -----------------------------
def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file as a float32 [samples, channels] array plus its rate."""
    samples, rate = sf.read(path, always_2d=True)
    return samples.astype(np.float32, copy=False), int(rate)
def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]:
if x.ndim == 1:
return x.astype(np.float32, copy=False), 1
if x.shape[1] == 1:
return x[:, 0], 1
return x.mean(axis=1), int(x.shape[1])
def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
if sr_in == sr_out:
return y
return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False)
def _match_length(y: np.ndarray, target_len: int) -> np.ndarray:
if len(y) == target_len:
return y
if len(y) > target_len:
return y[:target_len]
out = np.zeros((target_len,), dtype=y.dtype)
out[: len(y)] = y
return out
def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
    """Write y to a fresh temporary WAV file and return its path."""
    # Reserve a uniquely named path; the file must be closed before soundfile
    # writes to it, hence delete=False and the immediate context exit.
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    sf.write(out_path, y, sr)
    return out_path
def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
    """Render a log-magnitude spectrogram of y as a borderless PIL image."""
    win_length = max(256, int(0.032 * sr))
    hop_length = max(64, int(0.008 * sr))
    # Round the FFT size up to the next power of two.
    n_fft = 1 << (int(math.ceil(math.log2(win_length))))
    stft = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
    db = librosa.amplitude_to_db(np.abs(stft) + 1e-10, ref=np.max)
    fig, ax = plt.subplots(figsize=(8.4, 3.2))
    ax.imshow(db, origin="lower", aspect="auto")
    ax.set_axis_off()
    # Strip all margins so the image is pure spectrogram.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    png = io.BytesIO()
    fig.savefig(png, format="png", dpi=160)
    plt.close(fig)
    png.seek(0)
    return Image.open(png)
# -----------------------------
# Main pipeline
# -----------------------------
def run_enhancement(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """End-to-end pipeline: load -> mono -> trim -> resample -> enhance -> report.

    Returns (noisy_wav_path, enhanced_wav_path, noisy_spectrogram,
    enhanced_spectrogram, status_markdown) matching the Gradio outputs.

    Raises:
        gr.Error: when no models were discovered or no audio was provided.
    """
    if not MODEL_PRESETS:
        raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")
    chosen_path = mic_path if source == "Microphone" else file_path
    if not chosen_path:
        raise gr.Error("Please provide audio either from the microphone or by uploading a file.")
    x, sr_orig = _load_wav_from_gradio_path(chosen_path)
    y_mono, n_ch = _to_mono(x)
    # Cap input length at MAX_SECONDS to bound CPU inference time.
    max_samples = int(MAX_SECONDS * sr_orig)
    was_trimmed = len(y_mono) > max_samples
    if was_trimmed:
        y_mono = y_mono[:max_samples]
    dur = len(y_mono) / float(sr_orig)
    spec = MODEL_PRESETS[model_key]
    sr_model = spec.sr
    # Enhance at the model's native rate, then return to the input rate and
    # trim/pad away any resampling length drift.
    y_model = _resample(y_mono, sr_orig, sr_model)
    y_enh_model = enhance_audio_onnx(y_model, model_key)
    y_enh = _resample(y_enh_model, sr_model, sr_orig)
    y_enh = _match_length(y_enh, len(y_mono))
    noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
    enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")
    noisy_img = _spectrogram_image(y_mono, sr_orig)
    enh_img = _spectrogram_image(y_enh, sr_orig)
    status = (
        f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
        f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
        + (
            # Fix: the rates previously had no separators and rendered as one
            # concatenated number (e.g. "160004800016000").
            f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
            if sr_orig != sr_model
            else "**Resampling:** none\n\n"
        )
        + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
        + "\n✅ Done."
    )
    return noisy_out, enh_out, noisy_img, enh_img, status
def set_source_visibility(source: str):
    """Toggle which audio input widget is visible based on the radio choice."""
    show_mic = source == "Microphone"
    show_upload = source == "Upload"
    return (
        gr.update(visible=show_mic),
        gr.update(visible=show_upload),
    )
# -----------------------------
# UI (light polish)
# -----------------------------
# Gradio theme: orange accent on the Soft base with an Arial-first font stack.
THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="slate",
    font=[
        "Arial",
        "ui-sans-serif",
        "system-ui",
        "Segoe UI",
        "Roboto",
        "Helvetica Neue",
        "Noto Sans",
        "Liberation Sans",
        "sans-serif",
    ],
)
CSS = """
.gradio-container{
max-width: 1040px !important;
margin: 0 auto !important;
font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important;
}
#header {
padding: 14px 16px;
border-radius: 16px;
border: 1px solid rgba(0,0,0,0.08);
background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
text-align: center;
}
#header h1{
margin: 0 0 6px 0;
font-size: 24px;
font-weight: 800;
letter-spacing: -0.2px;
}
#header p{
margin: 6px auto 0 auto;
max-width: 720px;
color: var(--body-text-color-subdued);
font-size: 14px;
line-height: 1.6;
}
#header hr{
margin-top: 18px;
border: none;
height: 1px;
background: linear-gradient(to right, transparent, #ddd, transparent);
}
.spec img { border-radius: 14px; }
.audio { border-radius: 14px !important; overflow: hidden; }
#run_btn{
border-radius: 12px !important;
font-weight: 800 !important;
}
#status_md p{ margin: 0.35rem 0; }
"""
with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
    # Header card (styled via #header in CSS).
    gr.Markdown(
        "# DPDFNet Speech Enhancement\n\n"
        "Causal · Real-Time · Edge-Ready\n\n"
        "DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve "
        "long-range temporal and cross-band modeling while preserving low latency. "
        "Designed for single-channel streaming speech enhancement under challenging noise conditions.\n\n"
        "---",
        elem_id="header",
    )
    # Model selection and input-source choice, side by side.
    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_PRESETS.keys()),
            value=DEFAULT_MODEL_KEY,
            label="Model",
            # info="Audio is resampled to model SR, enhanced with ONNX, then resampled back.",
            interactive=True,
        )
        source = gr.Radio(
            choices=["Microphone", "Upload"],
            value="Upload",
            label="Input source",
        )
    # Two mutually-exclusive input widgets; visibility is toggled by the
    # source radio (see set_source_visibility).
    with gr.Row():
        mic_audio = gr.Audio(
            sources=["microphone"],
            type="filepath",
            format="wav",
            label="Microphone (max 10s)",
            visible=False,
            buttons=["download"],
            elem_classes=["audio"],
        )
        file_audio = gr.Audio(
            sources=["upload"],
            type="filepath",
            format="wav",
            label="Upload file (WAV/MP3/FLAC etc., max 10s)",
            visible=True,
            buttons=["download"],
            elem_classes=["audio"],
        )
    run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
    status = gr.Markdown(elem_id="status_md")
    # Results: before/after audio and their spectrograms.
    gr.Markdown("## Results")
    with gr.Row():
        out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
        out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
    with gr.Row():
        img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
        img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])
    # Event wiring: radio toggles input visibility; button runs the pipeline.
    source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])
    run_btn.click(
        fn=run_enhancement,
        inputs=[source, mic_audio, file_audio, model_key],
        outputs=[out_noisy, out_enh, img_noisy, img_enh, status],
        api_name="enhance",
    )
if __name__ == "__main__":
    # Limit the request queue to 32 pending jobs before launching the app.
    demo.queue(max_size=32).launch()