Spaces:
Running
Running
| import io | |
| import json | |
| import math | |
| import os | |
| import tempfile | |
| from contextlib import contextmanager | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple | |
| import gradio as gr | |
| import librosa | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import onnxruntime as ort | |
| import soundfile as sf | |
| from filelock import FileLock | |
| from PIL import Image | |
| # ----------------------------- | |
| # Configuration | |
| # ----------------------------- | |
# Longest clip, in seconds, that the pipeline processes; longer inputs are cut.
MAX_SECONDS = 10.0
# Directory that is scanned for ``*.onnx`` model files.
ONNX_DIR = Path("./onnx")
# The usage counter lives in ./data when that directory exists, otherwise in
# the current working directory.
COUNTER_DIR = Path("./data" if Path("./data").exists() else "./")
COUNTER_PATH = COUNTER_DIR / "dpdfnet_usage_counter.json"
# Scratch file that is written first and then renamed over COUNTER_PATH.
TMP_PATH = COUNTER_DIR / "dpdfnet_usage_counter.json.tmp"
# Lock file path, kept as a string.
LOCK_PATH = f"{COUNTER_PATH}.lock"
print(COUNTER_PATH)
@dataclass
class ModelSpec:
    """Description of one ONNX model preset.

    The class is a dataclass so that it can be constructed with the keyword
    arguments ``name``, ``sr`` and ``onnx_path``; a plain class that only
    carries annotations has no such constructor.
    """

    # Model name.
    name: str
    # Sample rate, in Hz, that the model runs at.
    sr: int
    # Path of the ``.onnx`` file, stored as a string.
    onnx_path: str
@contextmanager
def _maybe_lock():
    """Hold the counter's file lock for the duration of a ``with`` block.

    The function must be wrapped by ``contextmanager``: a bare generator
    function cannot be used in a ``with`` statement.
    """
    with FileLock(LOCK_PATH):
        yield
def _read_count() -> int:
    """Return the stored ``enhance_runs`` value, or 0 when it cannot be read.

    A missing counter file and any failure while opening, parsing or
    converting the stored value all yield 0 (best-effort read).
    """
    try:
        if COUNTER_PATH.exists():
            with COUNTER_PATH.open("r", encoding="utf-8") as handle:
                payload = json.load(handle)
            return int(payload.get("enhance_runs", 0))
    except Exception:
        # Deliberately swallowed: an unreadable counter is treated as zero.
        pass
    return 0
def _atomic_write_json(obj: dict) -> None:
    """Write ``obj`` as compact JSON to the counter file via a temp file.

    The data is written to ``TMP_PATH``, flushed and fsynced, and only then
    renamed over ``COUNTER_PATH`` with ``os.replace``.
    """
    COUNTER_DIR.mkdir(parents=True, exist_ok=True)
    with TMP_PATH.open("w", encoding="utf-8") as handle:
        json.dump(obj, handle, separators=(",", ":"))
        # Push the bytes out of Python's buffer and the OS cache before the
        # rename makes them visible under the final name.
        handle.flush()
        os.fsync(handle.fileno())
    os.replace(TMP_PATH, COUNTER_PATH)
def increment_and_get_count() -> int:
    """Add one to the persisted run counter and return the new value.

    The read-modify-write sequence runs while ``_maybe_lock`` is held.
    """
    with _maybe_lock():
        bumped = 1 + _read_count()
        _atomic_write_json({"enhance_runs": bumped})
    return bumped
def usage_markdown_text() -> str:
    """Return the Markdown line showing the current total run count."""
    total = _read_count()
    return f"**Total enhancements run:** {total:,}"
| # ----------------------------- | |
| # Model discovery and metadata | |
| # ----------------------------- | |
| def _infer_model_meta(model_name: str) -> int: | |
| normalized = model_name.lower().replace("-", "_") | |
| if "48khz" in normalized or "48k" in normalized or "48hr" in normalized: | |
| return 48000 | |
| # Fallback for unknown 16 kHz DPDFNet variants | |
| return 16000 | |
def _display_label(spec: ModelSpec) -> str:
    """Return the label for ``spec``: its name plus the sample rate in whole kHz."""
    return f"{spec.name} ({int(spec.sr // 1000)} kHz)"
def discover_model_presets() -> Dict[str, ModelSpec]:
    """Scan ``ONNX_DIR`` for ``*.onnx`` files and build the preset table.

    Returns:
        A dict mapping display label to ``ModelSpec``. Models from the
        canonical list come first, in that list's order; any other ONNX
        files follow, sorted by file stem.
    """
    canonical = [
        "baseline",
        "dpdfnet2",
        "dpdfnet4",
        "dpdfnet8",
        "dpdfnet2_48khz_hr",
    ]

    onnx_files = {}
    for candidate in ONNX_DIR.glob("*.onnx"):
        if candidate.is_file():
            onnx_files[candidate.stem] = candidate

    # Canonical names that were actually found, then every remaining file
    # in alphabetical order.
    load_order = [stem for stem in canonical if stem in onnx_files]
    load_order += sorted(stem for stem in onnx_files if stem not in canonical)

    presets: Dict[str, ModelSpec] = {}
    for stem in load_order:
        spec = ModelSpec(
            name=stem,
            sr=_infer_model_meta(stem),
            onnx_path=str(onnx_files[stem]),
        )
        presets[_display_label(spec)] = spec
    return presets
# Display label -> ModelSpec, built once when the module is imported.
MODEL_PRESETS = discover_model_presets()
# First label in discovery order, or None when no model was found.
DEFAULT_MODEL_KEY = next(iter(MODEL_PRESETS), None)
# -----------------------------
# ONNX Runtime + frontend cache
# -----------------------------
# Inference sessions that were already created, keyed by display label.
_SESSIONS: Dict[str, ort.InferenceSession] = {}
# Initial state arrays that were already loaded, keyed by display label.
_INIT_STATES: Dict[str, np.ndarray] = {}
def resolve_model_path(local_path: str) -> str:
    """Return ``local_path`` as a string if it exists on disk.

    Raises:
        gr.Error: If nothing exists at ``local_path``.
    """
    candidate = Path(local_path)
    if not candidate.exists():
        raise gr.Error(
            f"ONNX model not found at: {local_path}. "
            "Expected local models under ./onnx/."
        )
    return str(candidate)
def get_ort_session(model_key: str) -> ort.InferenceSession:
    """Return the ONNX Runtime session for ``model_key``, creating it once.

    New sessions use the CPU execution provider with one intra-op and one
    inter-op thread, and are stored in ``_SESSIONS`` for later calls.
    """
    cached = _SESSIONS.get(model_key)
    if cached is not None:
        return cached

    model_file = resolve_model_path(MODEL_PRESETS[model_key].onnx_path)

    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 1
    session_options.inter_op_num_threads = 1

    session = ort.InferenceSession(
        model_file,
        sess_options=session_options,
        providers=["CPUExecutionProvider"],
    )
    _SESSIONS[model_key] = session
    return session
def _resolve_state_path(model_key: str) -> Path:
    """Return the ``<stem>_state.npz`` file that sits next to the model file.

    Raises:
        gr.Error: If that file does not exist.
    """
    onnx_file = Path(MODEL_PRESETS[model_key].onnx_path)
    state_path = onnx_file.with_name(f"{onnx_file.stem}_state.npz")
    if state_path.is_file():
        return state_path
    raise gr.Error(f"State file not found: {state_path}")
def _load_initial_state(model_key: str, session: ort.InferenceSession) -> np.ndarray:
    """Load, validate and cache the model's initial state array.

    The ``init_state`` entry of the model's ``.npz`` state file is converted
    to contiguous float32 and checked against the shape of the session's
    second input: the ranks must match, and every dimension the model
    declares as an integer must equal the loaded one.

    Raises:
        gr.Error: If the key is missing or the rank or shape does not match.
    """
    cached = _INIT_STATES.get(model_key)
    if cached is not None:
        return cached

    state_path = _resolve_state_path(model_key)
    with np.load(state_path) as data:
        if "init_state" not in data:
            raise gr.Error(f"Missing 'init_state' key in state file: {state_path}")
        init_state = np.ascontiguousarray(data["init_state"].astype(np.float32, copy=False))

    expected_shape = session.get_inputs()[1].shape
    if init_state.ndim != len(expected_shape):
        raise gr.Error(
            f"Initial state rank mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
        )

    # Non-integer entries in the expected shape are not compared.
    has_fixed_dim_mismatch = any(
        isinstance(exp_dim, int) and exp_dim != act_dim
        for exp_dim, act_dim in zip(expected_shape, init_state.shape)
    )
    if has_fixed_dim_mismatch:
        raise gr.Error(
            f"Initial state shape mismatch for {state_path.name}: expected={expected_shape}, got={tuple(init_state.shape)}"
        )

    _INIT_STATES[model_key] = init_state
    return init_state
| # ----------------------------- | |
| # STFT/iSTFT (module-free) | |
| # ----------------------------- | |
def vorbis_window(window_len: int) -> np.ndarray:
    """Return a float32 window of ``window_len`` samples.

    Each sample is ``sin(pi/2 * sin(pi/2 * (n + 0.5) / (window_len / 2))**2)``.
    """
    half_len = window_len / 2
    positions = np.arange(window_len)
    inner = np.sin(0.5 * np.pi * (positions + 0.5) / half_len)
    return np.sin(0.5 * np.pi * inner * inner).astype(np.float32)
def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return ``1 / (window_len**2 / (2 * frame_size))``."""
    denominator = window_len ** 2 / (2 * frame_size)
    return 1.0 / denominator
def _infer_stft_params(model_key: str, session: ort.InferenceSession) -> Tuple[int, int, float, np.ndarray]:
    """Derive the STFT settings for a model.

    The window length is ``2 * (F - 1)`` when the second-to-last dimension
    of the session's first input is a fixed integer ``F > 1``; otherwise it
    is 20 ms at the model's sample rate. The hop is half the window.

    Returns:
        ``(win_len, hop, wnorm, win)``.
    """
    # ONNX spec input is [B, T, F, 2] (or dynamic variants).
    spec_shape = session.get_inputs()[0].shape
    freq_bins = None if len(spec_shape) < 2 else spec_shape[-2]

    has_fixed_bins = isinstance(freq_bins, int) and freq_bins > 1
    if not has_fixed_bins:
        # 20 ms windows for DPDFNet family.
        win_len = int(round(MODEL_PRESETS[model_key].sr * 0.02))
    else:
        win_len = int((freq_bins - 1) * 2)

    hop = win_len // 2
    win = vorbis_window(win_len)
    wnorm = get_wnorm(win_len, hop)
    return win_len, hop, wnorm, win
def _preprocess_waveform(waveform: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
    """Turn a waveform into the scaled real/imag STFT array fed to the model.

    The flattened float32 signal gets ``win_len`` trailing zeros, is
    transformed with a centered, reflect-padded STFT, scaled by ``wnorm``
    and split into real and imaginary parts.

    Returns:
        A float32 array of shape ``[1, T, F, 2]``.
    """
    samples = np.asarray(waveform, dtype=np.float32).reshape(-1)
    padded = np.pad(samples, (0, win_len), mode="constant")
    stft_ft = librosa.stft(
        y=padded,
        n_fft=win_len,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        pad_mode="reflect",
    )
    scaled_tf = (stft_ft.T * wnorm).astype(np.complex64, copy=False)  # [T, F]
    real_imag = np.stack([scaled_tf.real, scaled_tf.imag], axis=-1)  # [T, F, 2]
    real_imag = real_imag.astype(np.float32, copy=False)
    return real_imag[None, ...]  # [1, T, F, 2]
def _postprocess_spec(spec_e: np.ndarray, win_len: int, hop: int, wnorm: float, win: np.ndarray) -> np.ndarray:
    """Convert the model's real/imag output back to a float32 waveform.

    The first batch entry is rebuilt as a complex ``[F, T]`` spectrogram,
    inverted with a centered iSTFT and divided by ``wnorm``. The first
    ``2 * win_len`` samples are then dropped and the same number of zeros is
    appended, so the length is unchanged.
    """
    frames = np.asarray(spec_e[0], dtype=np.float32)  # [T, F, 2]
    real_part = frames[..., 0]
    imag_part = frames[..., 1]
    complex_ft = (real_part + 1j * imag_part).T.astype(np.complex64, copy=False)  # [F, T]

    audio = librosa.istft(
        complex_ft,
        hop_length=hop,
        win_length=win_len,
        window=win,
        center=True,
        length=None,
    ).astype(np.float32, copy=False)
    audio = audio / wnorm

    # NOTE(review): the shift presumably compensates for model/STFT delay - confirm.
    shift = win_len * 2
    return np.concatenate(
        [audio[shift:], np.zeros(shift, dtype=np.float32)],
        axis=0,
    )
| # ----------------------------- | |
| # ONNX inference (non-streaming pre/post, streaming ONNX state loop) | |
| # ----------------------------- | |
def enhance_audio_onnx(
    audio_mono: np.ndarray,
    model_key: str,
) -> np.ndarray:
    """Enhance a mono waveform by running the model one STFT frame at a time.

    The whole signal is transformed up front. Each time frame is then passed
    to the session together with the current state, and the returned state
    is fed into the next call. The collected output frames are inverted back
    to a waveform.

    Returns:
        The enhanced signal as a flat float32 array, or the flattened input
        when the spectrogram has no frames.

    Raises:
        gr.Error: If the session has fewer than two inputs or two outputs.
    """
    sess = get_ort_session(model_key)
    inputs, outputs = sess.get_inputs(), sess.get_outputs()
    if min(len(inputs), len(outputs)) < 2:
        raise gr.Error(
            "Expected streaming ONNX signature with 2 inputs (spec, state) and 2 outputs (spec_e, state_out)."
        )

    in_spec_name, in_state_name = inputs[0].name, inputs[1].name
    fetch_names = [outputs[0].name, outputs[1].name]

    waveform = np.asarray(audio_mono, dtype=np.float32).reshape(-1)
    win_len, hop, wnorm, win = _infer_stft_params(model_key, sess)
    spec_in = _preprocess_waveform(waveform, win_len=win_len, hop=hop, wnorm=wnorm, win=win)

    # Work on a copy so the cached initial state is never modified.
    state = _load_initial_state(model_key, sess).copy()

    enhanced_frames = []
    for t in range(int(spec_in.shape[1])):
        frame = np.ascontiguousarray(spec_in[:, t : t + 1, :, :], dtype=np.float32)
        frame_out, state = sess.run(
            fetch_names,
            {in_spec_name: frame, in_state_name: state},
        )
        enhanced_frames.append(np.ascontiguousarray(frame_out, dtype=np.float32))

    if not enhanced_frames:
        return waveform

    spec_out = np.concatenate(enhanced_frames, axis=1)
    enhanced = _postprocess_spec(spec_out, win_len=win_len, hop=hop, wnorm=wnorm, win=win)
    return np.asarray(enhanced, dtype=np.float32).reshape(-1)
| # ----------------------------- | |
| # Audio utilities | |
| # ----------------------------- | |
def _load_wav_from_gradio_path(path: str) -> Tuple[np.ndarray, int]:
    """Read an audio file with soundfile.

    Returns:
        ``(samples, sample_rate)``, where ``samples`` is a 2-D float32 array
        (``always_2d=True``) and ``sample_rate`` is an ``int``.
    """
    samples, rate = sf.read(path, always_2d=True)
    return samples.astype(np.float32, copy=False), int(rate)
| def _to_mono(x: np.ndarray) -> Tuple[np.ndarray, int]: | |
| if x.ndim == 1: | |
| return x.astype(np.float32, copy=False), 1 | |
| if x.shape[1] == 1: | |
| return x[:, 0], 1 | |
| return x.mean(axis=1), int(x.shape[1]) | |
| def _resample(y: np.ndarray, sr_in: int, sr_out: int) -> np.ndarray: | |
| if sr_in == sr_out: | |
| return y | |
| return librosa.resample(y, orig_sr=sr_in, target_sr=sr_out).astype(np.float32, copy=False) | |
| def _match_length(y: np.ndarray, target_len: int) -> np.ndarray: | |
| if len(y) == target_len: | |
| return y | |
| if len(y) > target_len: | |
| return y[:target_len] | |
| out = np.zeros((target_len,), dtype=y.dtype) | |
| out[: len(y)] = y | |
| return out | |
def _save_wav(y: np.ndarray, sr: int, prefix: str) -> str:
    """Write ``y`` at ``sr`` Hz to a new temporary ``.wav`` file; return its path.

    The file is created with ``delete=False``, so it stays on disk after
    this function returns.
    """
    with tempfile.NamedTemporaryFile(prefix=prefix, suffix=".wav", delete=False) as handle:
        wav_path = handle.name
    # The handle is closed before soundfile writes to the reserved path.
    sf.write(wav_path, y, sr)
    return wav_path
def _spectrogram_image(y: np.ndarray, sr: int) -> Image.Image:
    """Render a dB-scaled magnitude spectrogram of ``y`` as a PIL image.

    The analysis window is 32 ms (at least 256 samples), the hop is 8 ms (at
    least 64 samples), and the FFT size is the window length rounded up to a
    power of two. The picture has no axes or margins.

    Args:
        y: Mono signal. NOTE(review): assumed 1-D, as produced by the
            pipeline - confirm for other callers.
        sr: Sample rate of ``y`` in Hz.

    Returns:
        The PNG rendering, opened as a PIL image.
    """
    win_length = max(256, int(0.032 * sr))
    hop_length = max(64, int(0.008 * sr))
    n_fft = 1 << (int(math.ceil(math.log2(win_length))))

    # With center=False the STFT needs at least one full n_fft frame, so very
    # short clips are zero-padded instead of raising inside librosa.
    if len(y) < n_fft:
        y = np.pad(y, (0, n_fft - len(y)))

    S = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length, center=False)
    S_db = librosa.amplitude_to_db(np.abs(S) + 1e-10, ref=np.max)

    buf = io.BytesIO()
    fig, ax = plt.subplots(figsize=(8.4, 3.2))
    try:
        ax.imshow(S_db, origin="lower", aspect="auto")
        ax.set_axis_off()
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(buf, format="png", dpi=160)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
| # ----------------------------- | |
| # Main pipeline | |
| # ----------------------------- | |
def run_enhancement(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Run the full enhancement pipeline for one request.

    The selected recording is loaded, mixed down to mono, cut to
    ``MAX_SECONDS``, resampled to the model's rate, enhanced, resampled back
    and matched to the input length. Both signals are saved as WAV files and
    rendered as spectrograms.

    Args:
        source: ``"Microphone"`` selects ``mic_path``; any other value
            selects ``file_path``.
        mic_path: Path of the microphone recording, if any.
        file_path: Path of the uploaded file, if any.
        model_key: Display label of the model; must be a key of
            ``MODEL_PRESETS``.

    Returns:
        ``(noisy_wav_path, enhanced_wav_path, noisy_image, enhanced_image,
        status_markdown)``.

    Raises:
        gr.Error: If no models are available, no audio was provided, the
            model label is unknown, or the audio has no samples.
    """
    if not MODEL_PRESETS:
        raise gr.Error("No ONNX models found under ./onnx/. Add models and retry.")

    chosen_path = mic_path if source == "Microphone" else file_path
    if not chosen_path:
        raise gr.Error("Please provide audio either from the microphone or by uploading a file.")

    # The label also arrives through the public API endpoint, so reject
    # unknown values with a readable error instead of a bare KeyError.
    if model_key not in MODEL_PRESETS:
        raise gr.Error(f"Unknown model: {model_key!r}. Please choose one of the listed models.")
    spec = MODEL_PRESETS[model_key]
    sr_model = spec.sr

    x, sr_orig = _load_wav_from_gradio_path(chosen_path)
    y_mono, n_ch = _to_mono(x)
    if len(y_mono) == 0:
        raise gr.Error("The provided audio contains no samples.")

    max_samples = int(MAX_SECONDS * sr_orig)
    was_trimmed = len(y_mono) > max_samples
    if was_trimmed:
        y_mono = y_mono[:max_samples]
    dur = len(y_mono) / float(sr_orig)

    y_model = _resample(y_mono, sr_orig, sr_model)
    y_enh_model = enhance_audio_onnx(y_model, model_key)
    y_enh = _resample(y_enh_model, sr_model, sr_orig)
    # Resampling can change the sample count slightly; align with the input.
    y_enh = _match_length(y_enh, len(y_mono))

    noisy_out = _save_wav(y_mono, sr_orig, prefix="noisy_mono_")
    enh_out = _save_wav(y_enh, sr_orig, prefix="enhanced_")
    noisy_img = _spectrogram_image(y_mono, sr_orig)
    enh_img = _spectrogram_image(y_enh, sr_orig)

    status = (
        f"**Input:** {sr_orig} Hz, {dur:.2f}s, channels={n_ch} ⭢ mono\n\n"
        f"**Model:** {spec.name} (runs at {sr_model} Hz)\n\n"
        + (
            f"**Resampling:** {sr_orig} ⭢ {sr_model} ⭢ {sr_orig}\n\n"
            if sr_orig != sr_model
            else "**Resampling:** none\n\n"
        )
        + (f"**Trimmed:** first {MAX_SECONDS:.0f}s used\n" if was_trimmed else "")
        + "\n✅ Done."
    )
    return noisy_out, enh_out, noisy_img, enh_img, status
def run_enhancement_with_count(
    source: str,
    mic_path: Optional[str],
    file_path: Optional[str],
    model_key: str,
):
    """Run ``run_enhancement`` and then bump the usage counter.

    The counter is only incremented when the enhancement did not raise.

    Returns:
        The five values from ``run_enhancement`` followed by the Markdown
        line with the new total.
    """
    results = run_enhancement(
        source=source,
        mic_path=mic_path,
        file_path=file_path,
        model_key=model_key,
    )
    noisy_out, enh_out, noisy_img, enh_img, status = results
    total = increment_and_get_count()
    usage_line = f"**Total enhancements run:** {total:,}"
    return noisy_out, enh_out, noisy_img, enh_img, status, usage_line
def set_source_visibility(source: str):
    """Return two visibility updates: microphone widget first, upload widget second.

    Each widget is shown only when ``source`` equals its own option name.
    """
    show_microphone = source == "Microphone"
    show_upload = source == "Upload"
    return gr.update(visible=show_microphone), gr.update(visible=show_upload)
| # ----------------------------- | |
| # UI (light polish) | |
| # ----------------------------- | |
# Gradio theme: Soft base with an orange accent, slate neutrals and an
# Arial-first font stack.
THEME = gr.themes.Soft(
    primary_hue="orange",
    neutral_hue="slate",
    font=[
        "Arial",
        "ui-sans-serif",
        "system-ui",
        "Segoe UI",
        "Roboto",
        "Helvetica Neue",
        "Noto Sans",
        "Liberation Sans",
        "sans-serif",
    ],
)
# Extra stylesheet passed to gr.Blocks. The selectors target the elem_id /
# elem_classes values assigned to components in the layout (header, spec,
# audio, run_btn, status_md).
CSS = """
.gradio-container{
max-width: min(96vw, 1500px) !important;
margin: 0 auto !important;
font-family: Arial, ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica Neue, Noto Sans, Liberation Sans, sans-serif !important;
}
#header {
padding: 14px 16px;
border-radius: 16px;
border: 1px solid rgba(0,0,0,0.08);
background: linear-gradient(135deg, rgba(255,152,0,0.14), rgba(255,152,0,0.04));
}
#header h1{
margin: 0;
font-size: 24px;
font-weight: 800;
letter-spacing: -0.2px;
}
#header p{
margin: 6px 0 0 0;
color: var(--body-text-color-subdued);
font-size: 13.5px;
line-height: 1.35;
}
.spec img { border-radius: 14px; }
.audio { border-radius: 14px !important; overflow: hidden; }
#run_btn{
border-radius: 12px !important;
font-weight: 800 !important;
}
#status_md p{ margin: 0.35rem 0; }
"""
# Page layout and event wiring.
# NOTE(review): the nesting of components inside the rows below was
# reconstructed from a flattened listing - confirm against the deployed layout.
with gr.Blocks(theme=THEME, css=CSS, title="DPDFNet Speech Enhancement") as demo:
    gr.HTML(
        # Earlier header markup, kept commented out:
        # """
        # <div id="header">
        # <h1>DPDFNet Speech Enhancement</h1>
        # <p>
        # Upload or record up to 10 seconds. Multi-channel inputs are averaged to mono.
        # Choose any local ONNX model from <code>./onnx</code>.
        # Pre/postprocessing uses the same non-streaming STFT/iSTFT flow as <code>streaming/infer_dpdfnet_onnx.py</code>.
        # </p>
        # </div>
        # """
        """
<div id="header" style="text-align: center; margin-bottom: 25px;">
<h1 style="margin-bottom: 6px;">DPDFNet Speech Enhancement</h1>
<p style="font-size: 14px; letter-spacing: 1px; margin-bottom: 14px; color: #555;">
Causal • Real-Time • Edge-Ready
</p>
<p style="max-width: 720px; margin: 0 auto; font-size: 15px; line-height: 1.6;">
DPDFNet extends DeepFilterNet2 with Dual-Path RNN blocks to improve
long-range temporal and cross-band modeling while preserving low latency.
Designed for single-channel streaming speech enhancement under challenging noise conditions.
</p>
<hr style="margin-top: 22px; border: none; height: 1px; background: linear-gradient(to right, transparent, #ddd, transparent);">
</div>
"""
    )
    # Usage counter line; filled at build time and refreshed on page load
    # and after every run.
    usage_md = gr.Markdown(usage_markdown_text())
    # Model choice and input-source selector.
    with gr.Row():
        model_key = gr.Dropdown(
            choices=list(MODEL_PRESETS.keys()),
            value=DEFAULT_MODEL_KEY,
            label="Model",
            # info="Audio is resampled to model SR, enhanced with ONNX, then resampled back.",
            interactive=True,
        )
        source = gr.Radio(
            choices=["Microphone", "Upload"],
            value="Upload",
            label="Input source",
        )
    # Only one of the two inputs is visible at a time; "Upload" is the
    # initial choice, so the microphone widget starts hidden.
    with gr.Row():
        mic_audio = gr.Audio(
            sources=["microphone"],
            type="filepath",
            format="wav",
            label="Microphone (max 10s)",
            visible=False,
            buttons=["download"],
            elem_classes=["audio"],
        )
        file_audio = gr.Audio(
            sources=["upload"],
            type="filepath",
            format="wav",
            label="Upload file (WAV/MP3/FLAC etc., max 10s)",
            visible=True,
            buttons=["download"],
            elem_classes=["audio"],
        )
    run_btn = gr.Button("Enhance", variant="primary", elem_id="run_btn")
    status = gr.Markdown(elem_id="status_md")
    gr.Markdown("## Results")
    # Before/after audio players.
    with gr.Row():
        out_noisy = gr.Audio(label="Before (mono)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
        out_enh = gr.Audio(label="After (enhanced)", interactive=False, format="wav", buttons=["download"], elem_classes=["audio"])
    # Before/after spectrograms.
    with gr.Row():
        img_noisy = gr.Image(label="Noisy spectrogram", elem_classes=["spec"])
        img_enh = gr.Image(label="Enhanced spectrogram", elem_classes=["spec"])
    # Toggle which input widget is shown when the source selector changes.
    source.change(fn=set_source_visibility, inputs=source, outputs=[mic_audio, file_audio])
    # Run the pipeline; the six outputs match the six values returned by
    # run_enhancement_with_count, in the same order.
    run_btn.click(
        fn=run_enhancement_with_count,
        inputs=[source, mic_audio, file_audio, model_key],
        outputs=[out_noisy, out_enh, img_noisy, img_enh, status, usage_md],
        api_name="enhance",
    )
    # Re-read the counter each time the page is loaded.
    demo.load(fn=usage_markdown_text, outputs=usage_md)
# Start the app with a request queue capped at 32 pending jobs when run as a script.
if __name__ == "__main__":
    demo.queue(max_size=32).launch()