import io
import json
import math
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple

import gradio as gr
import librosa
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import soundfile as sf
from filelock import FileLock
from PIL import Image

# -----------------------------
# Configuration
# -----------------------------
MAX_SECONDS = 10.0
ONNX_DIR = Path("./onnx")
# Keep the usage counter under ./data when that directory exists, otherwise
# next to the app.
COUNTER_DIR = Path("./data") if Path("./data").exists() else Path("./")
COUNTER_PATH = COUNTER_DIR / "dpdfnet_usage_counter.json"
TMP_PATH = COUNTER_DIR / "dpdfnet_usage_counter.json.tmp"
LOCK_PATH = str(COUNTER_PATH) + ".lock"

print(COUNTER_PATH)


@dataclass(frozen=True)
class ModelSpec:
    """Immutable description of one ONNX model found under ONNX_DIR."""

    name: str  # ONNX file stem, e.g. "dpdfnet2"
    sr: int  # sample rate (Hz) inferred from the name by _infer_model_meta
    onnx_path: str  # path of the .onnx file as found by the directory scan


@contextmanager
def _maybe_lock():
    """Hold the file lock at LOCK_PATH for the duration of the ``with`` body."""
    with FileLock(LOCK_PATH):
        yield


def _read_count() -> int:
    """Return the persisted enhancement count, or 0 if it cannot be read.

    A missing, unreadable or malformed counter file yields 0.
    """
    try:
        if not COUNTER_PATH.exists():
            return 0
        with COUNTER_PATH.open("r", encoding="utf-8") as f:
            data = json.load(f)
        return int(data.get("enhance_runs", 0))
    except Exception:
        # Deliberately broad: the counter is purely cosmetic and must never
        # prevent the UI from rendering.
        return 0


def _atomic_write_json(obj: dict) -> None:
    """Write *obj* as compact JSON to COUNTER_PATH via temp file + ``os.replace``.

    The data is flushed and fsync'ed before the rename. If anything fails,
    the temporary file is removed and the original exception is re-raised.
    """
    COUNTER_DIR.mkdir(parents=True, exist_ok=True)
    try:
        with TMP_PATH.open("w", encoding="utf-8") as f:
            json.dump(obj, f, separators=(",", ":"))
            f.flush()
            os.fsync(f.fileno())
        os.replace(TMP_PATH, COUNTER_PATH)
    except BaseException:
        # Don't leave a half-written temp file behind; keep the original error.
        try:
            TMP_PATH.unlink()
        except OSError:
            pass
        raise


def increment_and_get_count() -> int:
    """Increment the persisted counter under the file lock and return the new value."""
    with _maybe_lock():
        new_value = _read_count() + 1
        _atomic_write_json({"enhance_runs": new_value})
    return new_value


def usage_markdown_text() -> str:
    """Render the current usage counter as a Markdown line for the UI."""
    return f"**Total enhancements run:** {_read_count():,}"


# -----------------------------
# Model discovery and metadata
# -----------------------------
def _infer_model_meta(model_name: str) -> int:
    """Infer a model's sample rate (Hz) from its name.

    Names containing "48khz", "48k" or "48hr" (case-insensitive, with "-"
    treated as "_") map to 48000; everything else falls back to 16000.
    """
    normalized = model_name.lower().replace("-", "_")
    if "48khz" in normalized or "48k" in normalized or "48hr" in normalized:
        return 48000
    # Fallback for unknown 16 kHz DPDFNet variants
    return 16000


def _display_label(spec: ModelSpec) -> str:
    """Return the dropdown label for *spec*, e.g. ``"dpdfnet2 (16 kHz)"``."""
    khz = int(spec.sr // 1000)
    return f"{spec.name} ({khz} kHz)"


def discover_model_presets() -> Dict[str, ModelSpec]:
    """Scan ONNX_DIR for ``*.onnx`` files and build the model preset table.

    The known DPDFNet names come first, in their canonical order. Any other
    ONNX files follow, sorted by file stem. Keys are display labels (see
    ``_display_label``), and the dict's insertion order is the dropdown
    order. The result is empty when no model files are found.
    """
    ordered_names = [
        "baseline",
        "dpdfnet2",
        "dpdfnet4",
        "dpdfnet8",
        "dpdfnet2_48khz_hr",
    ]
    found_paths = {p.stem: p for p in ONNX_DIR.glob("*.onnx") if p.is_file()}

    # Canonical models that are present, then any additional models by name.
    known = [name for name in ordered_names if name in found_paths]
    extras = sorted(name for name in found_paths if name not in ordered_names)

    presets: Dict[str, ModelSpec] = {}
    for name in known + extras:
        spec = ModelSpec(
            name=name,
            sr=_infer_model_meta(name),
            onnx_path=str(found_paths[name]),
        )
        presets[_display_label(spec)] = spec
    return presets


MODEL_PRESETS = discover_model_presets()
DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS), None)


# -----------------------------
# ONNX Runtime + frontend cache
# -----------------------------
# Per-process caches keyed by display label (a MODEL_PRESETS key).
_SESSIONS: Dict[str, ort.InferenceSession] = {}
_INIT_STATES: Dict[str, np.ndarray] = {}


def resolve_model_path(local_path: str) -> str:
    """Return *local_path* (normalized through ``Path``) if it exists.

    Raises:
        gr.Error: if nothing exists at *local_path*.
    """
    p = Path(local_path)
    if p.exists():
        return str(p)
    raise gr.Error(
        f"ONNX model not found at: {local_path}. "
        "Expected local models under ./onnx/."
    )


def get_ort_session(model_key: str) -> ort.InferenceSession:
    """Return the cached CPU ONNX Runtime session for *model_key*, creating it once.

    Raises:
        KeyError: if *model_key* is not in MODEL_PRESETS.
        gr.Error: if the model file is missing.
    """
    if model_key in _SESSIONS:
        return _SESSIONS[model_key]

    spec = MODEL_PRESETS[model_key]
    onnx_path = resolve_model_path(spec.onnx_path)

    options = ort.SessionOptions()
    # Restrict ONNX Runtime to a single thread for both intra- and inter-op work.
    options.intra_op_num_threads = 1
    options.inter_op_num_threads = 1

    sess = ort.InferenceSession(
        onnx_path,
        sess_options=options,
        providers=["CPUExecutionProvider"],
    )
    _SESSIONS[model_key] = sess
    return sess


def _resolve_state_path(model_key: str) -> Path:
    """Return the ``<stem>_state.npz`` file stored next to the model's ONNX file.

    Raises:
        gr.Error: if that file does not exist.
    """
    spec = MODEL_PRESETS[model_key]
    model_path = Path(spec.onnx_path)
    state_path = model_path.with_name(f"{model_path.stem}_state.npz")
    if not state_path.is_file():
        raise gr.Error(f"State file not found: {state_path}")
    return state_path


def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
    """Load, validate and cache the model's initial state array.

    The array comes from the ``init_state`` key of the model's state file and
    is converted to C-contiguous float32. It is checked against the session's
    second input: the rank must match, as must every statically known
    (integer) dimension. The cached array is shared between calls, so callers
    must copy it before mutating it.

    Raises:
        gr.Error: if the key is missing or the shape does not match.
    """
    if model_key in _INIT_STATES:
        return _INIT_STATES[model_key]

    state_path = _resolve_state_path(model_key)
    with np.load(state_path) as data:
        if "init_state" not in data:
            raise gr.Error(f"Missing 'init_state' key in state file: {state_path}")
        init_state = np.ascontiguousarray(data["init_state"].astype(np.float32, copy=False))

    expected_shape = session.get_inputs()[1].shape
    if len(expected_shape) != init_state.ndim:
        raise gr.Error(
            f"Initial state rank mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
        )
    for exp_dim, act_dim in zip(expected_shape, init_state.shape):
        # Only integer dims are static; symbolic/None dims are dynamic and skipped.
        if isinstance(exp_dim, int) and exp_dim != act_dim:
            raise gr.Error(
                f"Initial state shape mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
            )

    _INIT_STATES[model_key] = init_state
    return init_state


# -----------------------------
# STFT/iSTFT (module-free)
# -----------------------------
def vorbis_window(window_len: int) -> np.ndarray:
    """Return a float32 Vorbis window of *window_len* samples."""
    window_size_h = window_len / 2
    indices = np.arange(window_len)
    sin = np.sin(0.5 * np.pi * (indices + 0.5) / window_size_h)
    window = np.sin(0.5 * np.pi * sin * sin)
    return window.astype(np.float32)


def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return the STFT scaling factor ``1 / (window_len**2 / (2 * frame_size))``."""
    return 1.0 / (window_len ** 2 / (2 * frame_size))


def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, float, np.ndarray]:
    """Derive ``(win_len, hop, wnorm, window)`` for the model's STFT frontend.

    When the frequency-bin count F (second-to-last dim of the session's first
    input) is a static integer > 1, ``win_len = 2 * (F - 1)``. Otherwise it
    falls back to 20 ms at the model's sample rate. The hop is always half
    the window.
    """
    # ONNX spec input is [B, T, F, 2] (or dynamic variants).
    spec_shape = session.get_inputs()[0].shape
    freq_bins = spec_shape[-2] if len(spec_shape) >= 2 else None
    if isinstance(freq_bins, int) and freq_bins > 1:
        win_len = int((freq_bins - 1) * 2)
    else:
        # 20 ms windows for DPDFNet family.
        sr = MODEL_PRESETS[model_key].sr
        win_len = int(round(sr * 0.02))
    hop = win_len // 2
    win = vorbis_window(win_len)
    wnorm = get_wnorm(win_len, hop)
    return win_len, hop, wnorm, win


def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
    """Convert a waveform to the model's real/imag spectrogram input.

    Steps:
    1. Flatten the signal to float32.
    2. Append ``win_len`` zeros.
    3. Apply a centered, reflect-padded STFT.
    4. Scale by ``wnorm``.

    Returns a float32 array of shape ``[1, T, F, 2]``, where the last axis
    holds (real, imag).
    """
    audio = np.asarray(waveform, dtype=np.float32).reshape(-1)
    audio_pad = np.pad(audio, (0, win_len), mode="constant")
    spec = librosa.stft(
        y=audio_pad,
        n_fft=win_len,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        pad_mode="reflect",
    )
    spec = (spec.T * wnorm).astype(np.complex64, copy=False)  # [T, F]
    spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32, copy=False)  # [T, F, 2]
    return spec_ri[None, ...]  # [1, T, F, 2]
def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
    """Invert a ``[1, T, F, 2]`` (real, imag) spectrogram back to a float32 waveform.

    Applies a centered iSTFT and undoes the ``wnorm`` scaling. It then shifts
    the signal left by ``2 * win_len`` samples, zero-filling the tail.
    """
    spec_c = np.asarray(spec_e[0], dtype=np.float32)  # [T, F, 2]
    spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64, copy=False)  # [F, T]
    waveform_e = librosa.istft(
        spec,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    waveform_e = waveform_e / wnorm
    # Drop the first 2 * win_len samples and zero-pad the end. Presumably this
    # compensates for the model's delay (TODO confirm against the model).
    waveform_e = np.concatenate(
        [waveform_e[win_len * 2 :], np.zeros(win_len * 2, dtype=np.float32)],
        axis=0,
    )
    return waveform_e


# -----------------------------
# ONNX inference (non-streaming pre/post, streaming ONNX state loop)
# -----------------------------
def enhance_audio_onnx(
    audio_mono: np.ndarray,
    model_key: str,
) -> np.ndarray:
    """Enhance a mono waveform, given at the model's sample rate, with the ONNX model.

    The whole signal is converted to the STFT domain up front. It is then
    fed to the model one frame at a time, passing each call's state output
    into the next call. Finally the enhanced frames are inverted back to a
    float32 waveform.

    Raises:
        gr.Error: if the model does not expose at least 2 inputs
            (spec, state) and 2 outputs (spec_e, state_out).
    """
    sess = get_ort_session(model_key)
    inputs = sess.get_inputs()
    outputs = sess.get_outputs()
    if len(inputs) < 2 or len(outputs) < 2:
        raise gr.Error(
            "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
        )

    in_spec_name = inputs[0].name
    in_state_name = inputs[1].name
    out_spec_name = outputs[0].name
    out_state_name = outputs[1].name

    waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
    win_len, hop, wnorm, win = _infer_stft_params(model_key, sess)
    spec_r_np = _preprocess_waveform(waveform, win_len=win_len, hop=hop, wnorm=wnorm, win=win)

    # Copy: the cached initial state is shared across requests.
    state = _load_initial_state(model_key, sess).copy()
    spec_e_frames = []
    num_frames = int(spec_r_np.shape[1])
    for t in range(num_frames):
        spec_t = np.ascontiguousarray(spec_r_np[:, t : t + 1, :, :], dtype=np.float32)
        spec_e_t, state = sess.run(
            [out_spec_name, out_state_name],
            {in_spec_name: spec_t, in_state_name: state},
        )
        spec_e_frames.append(np.ascontiguousarray(spec_e_t, dtype=np.float32))

    if not spec_e_frames:
        return waveform

    spec_e_np = np.concatenate(spec_e_frames, axis=1)
    waveform_e = _postprocess_spec(spec_e_np, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
    return np.asarray(waveform_e, dtype=np.float32).reshape(-1)


# -----------------------------
# Audio utilities
# -----------------------------
def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file as a float32 ``[samples, channels]`` array plus its sample rate."""
    data, sr = sf.read(path, always_2d=True)
    data = data.astype(np.float32, copy=False)
    return data, int(sr)


def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]:
    """Downmix to mono by averaging channels; return ``(mono, original_channel_count)``."""
    if x.ndim == 1:
        return x.astype(np.float32, copy=False), 1
    if x.shape[1] == 1:
        return x[:, 0], 1
    return x.mean(axis=1), int(x.shape[1])


def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray:
    """Resample *y* from *sr_in* to *sr_out* Hz; return it unchanged if the rates match."""
    if sr_in == sr_out:
        return y
    return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False)


def _match_length(y: np.ndarray, target_len: int) -> np.ndarray:
    """Truncate or zero-pad *y* at the end so it has exactly *target_len* samples."""
    if len(y) == target_len:
        return y
    if len(y) > target_len:
        return y[:target_len]
    out = np.zeros((target_len,), dtype=y.dtype)
    out[: len(y)] = y
    return out


def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
    """Write *y* to a new temporary ``.wav`` file and return its path.

    Samples are clipped to [-1, 1] first. soundfile writes WAV as 16-bit PCM
    by default, and out-of-range float samples (possible in model output)
    cannot be represented in that format. The file is created with
    ``delete=False`` and is not removed here.
    """
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False) as tmp:
        path = tmp.name
    sf.write(path, np.clip(y, -1.0, 1.0), sr)
    return path


def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
    """Render a log-magnitude (dB) spectrogram of *y* as a borderless PNG-backed PIL image.

    Window and hop settings:
    - Window: about 32 ms, at least 256 samples.
    - Hop: about 8 ms, at least 64 samples.
    - FFT size: the window rounded up to the next power of two.

    Inputs shorter than one FFT frame (including empty input) are zero-padded,
    because ``librosa.stft(center=False)`` rejects signals shorter than ``n_fft``.
    """
    win_length = max(256, int(0.032 * sr))
    hop_length = max(64, int(0.008 * sr))
    n_fft = 1 << (int(math.ceil(math.log2(win_length))))
    if len(y) < n_fft:
        y = np.pad(y, (0, n_fft - len(y)))
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
    S_db = librosa.amplitude_to_db(np.abs(S) + 1e-10, ref=np.max)

    fig, ax = plt.subplots(figsize=(8.4, 3.2))
    ax.imshow(S_db, origin="lower", aspect="auto")
    ax.set_axis_off()
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=160)
    plt.close(fig)  # free the figure; pyplot otherwise keeps it alive
    buf.seek(0)
    return Image.open(buf)


# -----------------------------
# Main pipeline
# -----------------------------
def run_enhancement(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Run the full enhancement pipeline for one request.

    Pipeline:
    1. Load the selected input (microphone or upload) and downmix it to mono.
    2. Keep at most MAX_SECONDS of audio.
    3. Resample to the model rate, enhance, then resample back.
    4. Match the original length.

    Returns:
        ``(noisy_wav_path, enhanced_wav_path, noisy_spectrogram,
        enhanced_spectrogram, status_markdown)``.

    Raises:
        gr.Error: if no models are installed, the model key is unknown, no
            audio was provided, or the audio contains no samples.
    """
    if not MODEL_PRESETS:
        raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")
    if model_key not in MODEL_PRESETS:
        raise gr.Error(f"Unknown model: {model_key!r}. Please select a model from the list.")

    chosen_path = mic_path if source == "Microphone" else file_path
    if not chosen_path:
        raise gr.Error("Please provide audio either from the microphone or by uploading a file.")

    x, sr_orig = _load_wav_from_gradio_path(chosen_path)
    y_mono, n_ch = _to_mono(x)
    if len(y_mono) == 0:
        # Resampling and plotting cannot handle a zero-length signal.
        raise gr.Error("The provided audio contains no samples.")

    max_samples = int(MAX_SECONDS * sr_orig)
    was_trimmed = len(y_mono) > max_samples
    if was_trimmed:
        y_mono = y_mono[:max_samples]
    dur = len(y_mono) / float(sr_orig)

    spec = MODEL_PRESETS[model_key]
    sr_model = spec.sr

    y_model = _resample(y_mono, sr_orig, sr_model)
    y_enh_model = enhance_audio_onnx(y_model, model_key)
    y_enh = _resample(y_enh_model, sr_model, sr_orig)
    y_enh = _match_length(y_enh, len(y_mono))

    noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
    enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")

    noisy_img = _spectrogram_image(y_mono, sr_orig)
    enh_img = _spectrogram_image(y_enh, sr_orig)

    status = (
        f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
        f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
        + (
            f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
            if sr_orig != sr_model
            else "**Resampling:** none\n\n"
        )
        + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
        + "\n✅ Done."
    )

    return noisy_out, enh_out, noisy_img, enh_img, status


def run_enhancement_with_count(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Run ``run_enhancement`` and then bump the usage counter.

    Returns the five ``run_enhancement`` outputs plus the updated usage
    Markdown. A failure to persist the counter does not discard the
    enhancement; the last readable count is shown instead.
    """
    noisy_out, enh_out, noisy_img, enh_img, status = run_enhancement(
        source=source,
        mic_path=mic_path,
        file_path=file_path,
        model_key=model_key,
    )
    try:
        total = increment_and_get_count()
    except OSError:
        # The counter is cosmetic; a failed write (e.g. read-only disk)
        # must not throw away a successful result.
        total = _read_count()
    return noisy_out, enh_out, noisy_img, enh_img, status, f"**Total enhancements run:** {total:,}"


def set_source_visibility(source: str):
    """Show only the audio widget matching the selected input source."""
    return (
        gr.update(visible=(source == "Microphone")),
        gr.update(visible=(source == "Upload")),
    )


# -----------------------------
# UI (light polish)
# -----------------------------
THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="slate",
    font=[
        "Arial",
        "ui-sans-serif",
        "system-ui",
        "Segoe UI",
        "Roboto",
        "Helvetica Neue",
        "Noto Sans",
        "Liberation Sans",
        "sans-serif",
    ],
)

CSS = """
.gradio-container{
  max-width: min(96vw, 1500px) !important;
  margin: 0 auto !important;
  font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important;
}
#header {
  padding: 14px 16px;
  border-radius: 16px;
  border: 1px solid rgba(0,0,0,0.08);
  background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
}
#header h1{
  margin: 0;
  font-size: 24px;
  font-weight: 800;
  letter-spacing: -0.2px;
}
#header p{
  margin: 6px 0 0 0;
  color: var(--body-text-color-subdued);
  font-size: 13.5px;
  line-height: 1.35;
}
.spec img { border-radius: 14px; }
.audio { border-radius: 14px !important; overflow: hidden; }
#run_btn{
  border-radius: 12px !important;
  font-weight: 800 !important;
}
#status_md p{ margin: 0.35rem 0; }
"""

with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
    # NOTE(review): this header string appears to be empty, although CSS above
    # styles a #header element. The markup may have been lost; confirm.
    gr.HTML(
        """
        """
    )

    usage_md = gr.Markdown(usage_markdown_text())

    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_PRESETS.keys()),
            value=DEFAULT_MODEL_KEY,
            label="Model",
            interactive=True,
        )
        source = gr.Radio(
            choices=["Microphone", "Upload"],
            value="Upload",
            label="Input source",
        )

    with gr.Row():
        mic_audio = gr.Audio(
            sources=["microphone"],
            type="filepath",
            format="wav",
            label="Microphone (max 10s)",
            visible=False,
            buttons=["download"],
            elem_classes=["audio"],
        )
        file_audio = gr.Audio(
            sources=["upload"],
            type="filepath",
            format="wav",
            label="Upload file (WAV/MP3/FLAC etc., max 10s)",
            visible=True,
            buttons=["download"],
            elem_classes=["audio"],
        )

    run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
    status = gr.Markdown(elem_id="status_md")

    gr.Markdown("## Results")
    with gr.Row():
        out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
        out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])

    with gr.Row():
        img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
        img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])

    source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])

    run_btn.click(
        fn=run_enhancement_with_count,
        inputs=[source, mic_audio, file_audio, model_key],
        outputs=[out_noisy, out_enh, img_noisy, img_enh, status, usage_md],
        api_name="enhance",
    )

    # Refresh the counter on every page load so each visitor sees the current total.
    demo.load(fn=usage_markdown_text, outputs=usage_md)


if __name__ == "__main__":
    demo.queue(max_size=32).launch()