DPDFNetDemo / app.py
danielr-ceva's picture
Update app.py
44cdb79 verified
import io
import json
import math
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple
import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import soundfile as sf
from filelock import FileLock
from PIL import Image
# -----------------------------
# Configuration
# -----------------------------
# Longest clip, in seconds, that is processed; longer inputs are truncated.
MAX_SECONDS = 10.0
# Directory scanned for local ``*.onnx`` model files.
ONNX_DIR = Path("./onnx")
# Keep the usage counter under ./data when that directory exists, otherwise
# next to the app (relative to the current working directory).
_DATA_DIR = Path("./data")
COUNTER_DIR = _DATA_DIR if _DATA_DIR.exists() else Path("./")
COUNTER_PATH = COUNTER_DIR / "dpdfnet_usage_counter.json"
# Scratch file that is written first and then renamed over COUNTER_PATH.
TMP_PATH = COUNTER_PATH.with_name(COUNTER_PATH.name + ".tmp")
# Lock file path (a plain string) guarding read-modify-write of the counter.
LOCK_PATH = f"{COUNTER_PATH}.lock"
print(COUNTER_PATH)
@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one local ONNX model."""

    # Model identifier; discover_model_presets() fills this with the file stem.
    name: str
    # Sample rate, in Hz, that the model runs at.
    sr: int
    # Location of the ``.onnx`` file, kept as a plain string.
    onnx_path: str
@contextmanager
def _maybe_lock():
    """Hold the file lock at ``LOCK_PATH`` for the duration of the ``with`` body."""
    lock = FileLock(LOCK_PATH)
    lock.acquire()
    try:
        yield
    finally:
        # Always release, even when the guarded body raises.
        lock.release()
def _read_count() -> int:
    """Return the persisted ``enhance_runs`` counter.

    Best effort: a missing, unreadable or malformed counter file yields 0
    rather than an exception.
    """
    try:
        # A missing file raises here and is handled by the fallback below.
        payload = json.loads(COUNTER_PATH.read_text(encoding="utf-8"))
        return int(payload.get("enhance_runs", 0))
    except Exception:
        return 0
def _atomic_write_json(obj: dict) -> None:
    """Write *obj* as compact JSON to ``COUNTER_PATH`` via a temp file.

    The data is written to ``TMP_PATH``, flushed and fsynced, and only then
    renamed over the real counter file with ``os.replace``.
    """
    COUNTER_DIR.mkdir(parents=True, exist_ok=True)
    with open(TMP_PATH, "w", encoding="utf-8") as tmp_file:
        tmp_file.write(json.dumps(obj, separators=(",", ":")))
        tmp_file.flush()
        # Push the bytes to disk before the rename makes them visible.
        os.fsync(tmp_file.fileno())
    os.replace(TMP_PATH, COUNTER_PATH)
def increment_and_get_count() -> int:
    """Increase the persisted run counter by one and return the new total.

    The read-modify-write happens while holding the counter's file lock.
    """
    with _maybe_lock():
        updated = 1 + _read_count()
        _atomic_write_json({"enhance_runs": updated})
    return updated
def usage_markdown_text() -> str:
    """Return the Markdown line showing the total number of enhancement runs."""
    total = _read_count()
    return f"**Total enhancements run:** {total:,}"
# -----------------------------
# Model discovery and metadata
# -----------------------------
def _infer_model_meta(model_name: str) -> int:
normalized = model_name.lower().replace("-", "_")
if "48khz" in normalized or "48k" in normalized or "48hr" in normalized:
return 48000
# Fallback for unknown 16 kHz DPDFNet variants
return 16000
def _display_label(spec: ModelSpec) -> str:
    """Build the dropdown label for *spec*: its name plus sample rate in kHz."""
    kilohertz = int(spec.sr // 1000)
    label = f"{spec.name} ({kilohertz} kHz)"
    return label
def discover_model_presets() -> Dict[str, ModelSpec]:
    """Scan ``ONNX_DIR`` for ``*.onnx`` files and build the model presets.

    Returns:
        A dict mapping display label to ``ModelSpec``. Known model names come
        first, in their canonical order; any other ONNX files follow, sorted
        by name.
    """
    canonical = (
        "baseline",
        "dpdfnet2",
        "dpdfnet4",
        "dpdfnet8",
        "dpdfnet2_48khz_hr",
    )
    discovered = {f.stem: f for f in ONNX_DIR.glob("*.onnx") if f.is_file()}

    # Canonical names that are actually present, then every remaining file.
    ordering = [stem for stem in canonical if stem in discovered]
    ordering.extend(sorted(stem for stem in discovered if stem not in canonical))

    presets: Dict[str, ModelSpec] = {}
    for stem in ordering:
        spec = ModelSpec(
            name=stem,
            sr=_infer_model_meta(stem),
            onnx_path=str(discovered[stem]),
        )
        presets[_display_label(spec)] = spec
    return presets
# Presets are discovered once, at import time.
MODEL_PRESETS = discover_model_presets()
# First discovered preset, or None when no model files were found.
DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS)) if MODEL_PRESETS else None
# -----------------------------
# ONNX Runtime + frontend cache
# -----------------------------
# Per-model caches, keyed by the preset's display label.
_SESSIONS: Dict[str, ort.InferenceSession] = {}
_INIT_STATES: Dict[str, np.ndarray] = {}
def resolve_model_path(local_path: str) -> str:
    """Return *local_path* as a string if it exists on disk.

    Raises:
        gr.Error: when nothing exists at *local_path*.
    """
    candidate = Path(local_path)
    if not candidate.exists():
        raise gr.Error(
            f"ONNX model not found at: {local_path}. "
            "Expected local models under ./onnx/."
        )
    return str(candidate)
def get_ort_session(model_key: str) -> ort.InferenceSession:
    """Return the cached ONNX Runtime session for *model_key*, creating it on first use.

    New sessions run on the CPU execution provider with a single intra-op and
    a single inter-op thread.
    """
    cached = _SESSIONS.get(model_key)
    if cached is not None:
        return cached

    model_file = resolve_model_path(MODEL_PRESETS[model_key].onnx_path)

    sess_opts = ort.SessionOptions()
    sess_opts.intra_op_num_threads = 1
    sess_opts.inter_op_num_threads = 1

    session = ort.InferenceSession(
        model_file,
        sess_options=sess_opts,
        providers=["CPUExecutionProvider"],
    )
    _SESSIONS[model_key] = session
    return session
def _resolve_state_path(model_key: str) -> Path:
    """Locate the ``<stem>_state.npz`` file that sits next to the model's ONNX file.

    Raises:
        gr.Error: when that state file does not exist.
    """
    onnx_file = Path(MODEL_PRESETS[model_key].onnx_path)
    candidate = onnx_file.with_name(f"{onnx_file.stem}_state.npz")
    if candidate.is_file():
        return candidate
    raise gr.Error(f"State file not found: {candidate}")
def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
    """Load, validate and cache the model's initial streaming state.

    The state is read from the ``init_state`` array of the model's ``.npz``
    state file, converted to contiguous float32, and checked against the
    shape declared by the session's second input.

    Raises:
        gr.Error: if the key is missing, or the rank or any static dimension
            disagrees with the model input.
    """
    cached = _INIT_STATES.get(model_key)
    if cached is not None:
        return cached

    state_path = _resolve_state_path(model_key)
    with np.load(state_path) as archive:
        if "init_state" not in archive:
            raise gr.Error(f"Missing 'init_state' key in state file: {state_path}")
        init_state = np.ascontiguousarray(
            archive["init_state"].astype(np.float32, copy=False)
        )

    expected_shape = session.get_inputs()[1].shape
    if len(expected_shape) != init_state.ndim:
        raise gr.Error(
            f"Initial state rank mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
        )

    # Only integer dimensions are compared; non-integer entries are skipped.
    has_mismatch = any(
        isinstance(want, int) and want != have
        for want, have in zip(expected_shape, init_state.shape)
    )
    if has_mismatch:
        raise gr.Error(
            f"Initial state shape mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
        )

    _INIT_STATES[model_key] = init_state
    return init_state
# -----------------------------
# STFT/iSTFT (module-free)
# -----------------------------
def vorbis_window(window_len: int) -> np.ndarray:
    """Return a float32 Vorbis window of length *window_len*.

    Computes ``sin(pi/2 * sin^2(pi * (n + 0.5) / window_len))`` for
    ``n = 0 .. window_len - 1``.
    """
    half_len = window_len / 2
    n = np.arange(window_len)
    inner = np.sin(0.5 * np.pi * (n + 0.5) / half_len)
    outer = np.sin(0.5 * np.pi * inner * inner)
    return outer.astype(np.float32)
def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return the spectrum scaling factor ``1 / (window_len**2 / (2 * frame_size))``."""
    denominator = window_len ** 2 / (2 * frame_size)
    return 1.0 / denominator
def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, float, np.ndarray]:
    """Derive STFT parameters for a model.

    Returns:
        ``(win_len, hop, wnorm, win)``: window length, hop (half the window),
        the spectrum normalisation factor and the analysis window.
    """
    # ONNX spec input is [B, T, F, 2] (or dynamic variants).
    spec_shape = session.get_inputs()[0].shape
    freq_bins = None
    if len(spec_shape) >= 2:
        freq_bins = spec_shape[-2]

    has_static_bins = isinstance(freq_bins, int) and freq_bins > 1
    if has_static_bins:
        # F = win_len / 2 + 1 for a one-sided spectrum.
        win_len = int((freq_bins - 1) * 2)
    else:
        # 20 ms windows for DPDFNet family.
        win_len = int(round(MODEL_PRESETS[model_key].sr * 0.02))

    hop = win_len // 2
    window = vorbis_window(win_len)
    scale = get_wnorm(win_len, hop)
    return win_len, hop, scale, window
def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
    """Turn a waveform into the model's spectrogram input.

    The flattened float32 signal is zero-padded by ``win_len`` samples at the
    end, transformed with a centred STFT, scaled by ``wnorm`` and split into
    real and imaginary parts.

    Returns:
        A float32 array shaped ``[1, T, F, 2]``.
    """
    samples = np.asarray(waveform, dtype=np.float32).reshape(-1)
    padded = np.pad(samples, (0, win_len), mode="constant")
    stft = librosa.stft(
        y=padded,
        n_fft=win_len,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        pad_mode="reflect",
    )
    frames = (stft.T * wnorm).astype(np.complex64, copy=False)  # [T, F]
    real_imag = np.stack([frames.real, frames.imag], axis=-1)  # [T, F, 2]
    real_imag = real_imag.astype(np.float32, copy=False)
    return real_imag[np.newaxis, ...]  # [1, T, F, 2]
def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
    """Convert an enhanced ``[1, T, F, 2]`` spectrogram back to a float32 waveform.

    After the inverse STFT and undoing the ``wnorm`` scaling, the first
    ``2 * win_len`` samples are dropped and the same number of zeros is
    appended, so the output length equals the iSTFT length.
    """
    frames = np.asarray(spec_e[0], dtype=np.float32)  # [T, F, 2]
    complex_spec = (frames[..., 0] + 1j * frames[..., 1]).T.astype(np.complex64, copy=False)  # [F, T]
    audio = librosa.istft(
        complex_spec,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    audio = audio / wnorm

    # NOTE(review): presumably compensates the pipeline's delay — confirm.
    shift = win_len * 2
    tail = np.zeros(shift, dtype=np.float32)
    return np.concatenate([audio[shift:], tail], axis=0)
# -----------------------------
# ONNX inference (non-streaming pre/post, streaming ONNX state loop)
# -----------------------------
def enhance_audio_onnx(
    audio_mono: np.ndarray,
    model_key: str,
) -> np.ndarray:
    """Enhance a mono waveform with the selected ONNX model.

    The full signal is transformed to a spectrogram up front, then fed to the
    model one frame at a time while the recurrent state is carried from each
    call to the next. The enhanced frames are concatenated and inverted back
    to a waveform.

    Returns:
        The enhanced waveform as a flat float32 array. If the spectrogram has
        no frames, the flattened input is returned unchanged.

    Raises:
        gr.Error: if the model does not expose at least two inputs and two outputs.
    """
    sess = get_ort_session(model_key)
    model_inputs = sess.get_inputs()
    model_outputs = sess.get_outputs()
    if min(len(model_inputs), len(model_outputs)) < 2:
        raise gr.Error(
            "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
        )
    spec_in, state_in = model_inputs[0].name, model_inputs[1].name
    fetches = [model_outputs[0].name, model_outputs[1].name]

    waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
    win_len, hop, wnorm, win = _infer_stft_params(model_key, sess)
    noisy_spec = _preprocess_waveform(waveform, win_len=win_len, hop=hop, wnorm=wnorm, win=win)

    # Work on a copy so the cached initial state is never modified.
    state = _load_initial_state(model_key, sess).copy()

    enhanced_frames = []
    for t in range(int(noisy_spec.shape[1])):
        frame = np.ascontiguousarray(noisy_spec[:, t : t + 1, :, :], dtype=np.float32)
        enhanced, state = sess.run(fetches, {spec_in: frame, state_in: state})
        enhanced_frames.append(np.ascontiguousarray(enhanced, dtype=np.float32))

    if len(enhanced_frames) == 0:
        return waveform

    enhanced_spec = np.concatenate(enhanced_frames, axis=1)
    enhanced_wave = _postprocess_spec(enhanced_spec, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
    return np.asarray(enhanced_wave, dtype=np.float32).reshape(-1)
# -----------------------------
# Audio utilities
# -----------------------------
def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
    """Read the audio file at *path* and return ``(samples, sample_rate)``.

    Samples are float32 and always 2-D (``always_2d=True``).
    """
    samples, rate = sf.read(path, always_2d=True)
    return samples.astype(np.float32, copy=False), int(rate)
def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]:
if x.ndim == 1:
return x.astype(np.float32, copy=False), 1
if x.shape[1] == 1:
return x[:, 0], 1
return x.mean(axis=1), int(x.shape[1])
def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
if sr_in == sr_out:
return y
return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False)
def _match_length(y: np.ndarray, target_len: int) -> np.ndarray:
if len(y) == target_len:
return y
if len(y) > target_len:
return y[:target_len]
out = np.zeros((target_len,), dtype=y.dtype)
out[: len(y)] = y
return out
def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
    """Write *y* at *sr* Hz to a new temporary ``.wav`` file and return its path.

    The file is created with ``delete=False``, so it is not removed automatically.
    """
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False) as handle:
        out_path = handle.name
    # The handle is closed before soundfile writes to the same path.
    sf.write(out_path, y, sr)
    return out_path
def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
    """Render a dB-scaled magnitude spectrogram of *y* as a PIL image.

    Uses a window of 32 ms (at least 256 samples) and a hop of 8 ms (at least
    64 samples); the FFT size is the window length rounded up to a power of
    two. The plot has no axes or margins.

    Args:
        y: Audio samples (assumed 1-D, as produced by the pipeline in this file).
        sr: Sample rate of *y* in Hz.

    Returns:
        The rendered PNG opened as a ``PIL.Image.Image``.
    """
    win_length = max(256, int(0.032 * sr))
    hop_length = max(64, int(0.008 * sr))
    n_fft = 1 << (int(math.ceil(math.log2(win_length))))

    # With center=False the STFT needs at least n_fft samples to produce a
    # single frame. Zero-pad very short (or empty) clips so they render
    # instead of failing.
    shortfall = n_fft - len(y)
    if shortfall > 0:
        y = np.pad(y, (0, shortfall), mode="constant")

    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
    S_db = librosa.amplitude_to_db(np.abs(S) + 1e-10, ref=np.max)

    fig, ax = plt.subplots(figsize=(8.4, 3.2))
    try:
        ax.imshow(S_db, origin="lower", aspect="auto")
        ax.set_axis_off()
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=160)
    finally:
        # Close the figure even if rendering fails, so figures do not
        # accumulate in pyplot's global state.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# -----------------------------
# Main pipeline
# -----------------------------
def run_enhancement(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Run the full enhancement pipeline for one request.

    Steps: load the chosen audio, mix to mono, trim to ``MAX_SECONDS``,
    resample to the model rate, enhance, resample back, match the input
    length, then write both signals to WAV files and render spectrograms.

    Args:
        source: ``"Microphone"`` selects *mic_path*; any other value selects *file_path*.
        mic_path: Path of the recorded audio, if any.
        file_path: Path of the uploaded audio, if any.
        model_key: Key into ``MODEL_PRESETS`` (the dropdown label).

    Returns:
        ``(noisy_wav_path, enhanced_wav_path, noisy_image, enhanced_image, status_markdown)``.

    Raises:
        gr.Error: if no models are available, the model key is unknown, no
            audio was provided, or the audio contains no samples.
    """
    if not MODEL_PRESETS:
        raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")
    # Report an unknown or missing selection to the user instead of letting
    # the lookup below fail with a bare KeyError.
    if model_key not in MODEL_PRESETS:
        raise gr.Error("Please select one of the available models.")

    chosen_path = mic_path if source == "Microphone" else file_path
    if not chosen_path:
        raise gr.Error("Please provide audio either from the microphone or by uploading a file.")

    x, sr_orig = _load_wav_from_gradio_path(chosen_path)
    y_mono, n_ch = _to_mono(x)
    # A clip without samples cannot be resampled or transformed.
    if len(y_mono) == 0:
        raise gr.Error("The provided audio is empty. Please record or upload a non-empty clip.")

    max_samples = int(MAX_SECONDS * sr_orig)
    was_trimmed = len(y_mono) > max_samples
    if was_trimmed:
        y_mono = y_mono[:max_samples]
    dur = len(y_mono) / float(sr_orig)

    spec = MODEL_PRESETS[model_key]
    sr_model = spec.sr
    y_model = _resample(y_mono, sr_orig, sr_model)
    y_enh_model = enhance_audio_onnx(y_model, model_key)
    y_enh = _resample(y_enh_model, sr_model, sr_orig)
    # Resampling round-trips can change the sample count slightly.
    y_enh = _match_length(y_enh, len(y_mono))

    noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
    enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")
    noisy_img = _spectrogram_image(y_mono, sr_orig)
    enh_img = _spectrogram_image(y_enh, sr_orig)

    status = (
        f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
        f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
        + (
            # Separate the three rates; without separators they run together.
            f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
            if sr_orig != sr_model
            else "**Resampling:** none\n\n"
        )
        + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
        + "\n✅ Done."
    )
    return noisy_out, enh_out, noisy_img, enh_img, status
def run_enhancement_with_count(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Run ``run_enhancement`` and append the updated usage-counter Markdown.

    The counter is only incremented after the enhancement has succeeded.
    """
    results = run_enhancement(
        source=source,
        mic_path=mic_path,
        file_path=file_path,
        model_key=model_key,
    )
    total = increment_and_get_count()
    counter_md = f"**Total enhancements run:** {total:,}"
    return (*results, counter_md)
def set_source_visibility(source: str):
    """Return visibility updates for the (microphone, upload) audio inputs."""
    show_mic = source == "Microphone"
    show_upload = source == "Upload"
    return gr.update(visible=show_mic), gr.update(visible=show_upload)
# -----------------------------
# UI (light polish)
# -----------------------------
# Font fallback chain handed to the Gradio theme, in order of preference.
_FONT_STACK = [
    "Arial",
    "ui-sans-serif",
    "system-ui",
    "Segoe UI",
    "Roboto",
    "Helvetica Neue",
    "Noto Sans",
    "Liberation Sans",
    "sans-serif",
]
# Soft theme with orange accents on a slate neutral palette.
THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="slate",
    font=_FONT_STACK,
)
# Custom stylesheet passed to gr.Blocks: caps the page width, styles the
# #header banner, rounds the audio/spectrogram widgets and the run button,
# and tightens paragraph spacing in the status Markdown.
CSS = """
.gradio-container{
max-width: min(96vw, 1500px) !important;
margin: 0 auto !important;
font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important;
}
#header {
padding: 14px 16px;
border-radius: 16px;
border: 1px solid rgba(0,0,0,0.08);
background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
}
#header h1{
margin: 0;
font-size: 24px;
font-weight: 800;
letter-spacing: -0.2px;
}
#header p{
margin: 6px 0 0 0;
color: var(--body-text-color-subdued);
font-size: 13.5px;
line-height: 1.35;
}
.spec img { border-radius: 14px; }
.audio { border-radius: 14px !important; overflow: hidden; }
#run_btn{
border-radius: 12px !important;
font-weight: 800 !important;
}
#status_md p{ margin: 0.35rem 0; }
"""
# Build the Gradio interface: header, usage counter, model/source selectors,
# the two audio inputs, the run button, and the result widgets.
# NOTE(review): the Row nesting below is reconstructed from a listing whose
# indentation was lost — confirm it matches the intended layout.
with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
    gr.HTML(
        # Previous header markup, kept for reference:
        # """
        # <div id="header">
        # <h1>DPDFNet Speech Enhancement</h1>
        # <p>
        # Upload or record up to 10 seconds. Multi-channel inputs are averaged to mono.
        # Choose any local ONNX model from <code>./onnx</code>.
        # Pre/postprocessing uses the same non-streaming STFT/iSTFT flow as <code>streaming/infer_dpdfnet_onnx.py</code>.
        # </p>
        # </div>
        # """
        """
<div id="header" style="text-align: center; margin-bottom: 25px;">
<h1 style="margin-bottom: 6px;">DPDFNet Speech Enhancement</h1>
<p style="font-size: 14px; letter-spacing: 1px; margin-bottom: 14px; color: #555;">
Causal • Real-Time • Edge-Ready
</p>
<p style="max-width: 720px; margin: 0 auto; font-size: 15px; line-height: 1.6;">
DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve
long-range temporal and cross-band modeling while preserving low latency.
Designed for single-channel streaming speech enhancement under challenging noise conditions.
</p>
<hr style="margin-top: 22px; border: none; height: 1px; background: linear-gradient(to right, transparent, #ddd, transparent);">
</div>
"""
    )
    # Usage counter; refreshed on page load and after every run (see below).
    usage_md = gr.Markdown(usage_markdown_text())
    with gr.Row():
        # Model choices are the display labels discovered at import time.
        model_key = gr.Dropdown(
            choices=list(MODEL_PRESETS.keys()),
            value=DEFAULT_MODEL_KEY,
            label="Model",
            # info="Audio is resampled to model SR, enhanced with ONNX, then resampled back.",
            interactive=True,
        )
        source = gr.Radio(
            choices=["Microphone", "Upload"],
            value="Upload",
            label="Input source",
        )
    with gr.Row():
        # Both inputs hand a file path to the handler. The microphone input
        # starts hidden because the default source is "Upload".
        mic_audio = gr.Audio(
            sources=["microphone"],
            type="filepath",
            format="wav",
            label="Microphone (max 10s)",
            visible=False,
            buttons=["download"],
            elem_classes=["audio"],
        )
        file_audio = gr.Audio(
            sources=["upload"],
            type="filepath",
            format="wav",
            label="Upload file (WAV/MP3/FLAC etc., max 10s)",
            visible=True,
            buttons=["download"],
            elem_classes=["audio"],
        )
    run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
    status = gr.Markdown(elem_id="status_md")
    gr.Markdown("## Results")
    with gr.Row():
        out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
        out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
    with gr.Row():
        img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
        img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])
    # Show only the audio input that matches the selected source.
    source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])
    # The output order must match the tuple returned by run_enhancement_with_count.
    run_btn.click(
        fn=run_enhancement_with_count,
        inputs=[source, mic_audio, file_audio, model_key],
        outputs=[out_noisy, out_enh, img_noisy, img_enh, status, usage_md],
        api_name="enhance",
    )
    # Re-read the counter whenever the page is loaded.
    demo.load(fn=usage_markdown_text, outputs=usage_md)
# Script entry point: enable the request queue (max_size=32) and start the app.
if __name__ == "__main__":
    demo.queue(max_size=32).launch()