mimi-endpointer / predict.py
viks66's picture
add predict.py
a96f2f1 verified
Raw
History Blame Contribute Delete
6.99 kB
# Author: Sathvik Udupa (2026)
# Email: udupa@fit.vutbr.cz
# Paper: Streaming Endpointer for Spoken Dialogue using Neural Audio Codecs and Label-Delayed Training, https://arxiv.org/abs/2506.07081, ASRU 2025
"""Mimi Endpointer — DiscriminativeModel for the TURN benchmark.
Two-stream LSTM over Mimi embeddings, streamed 20ms chunk at a time.
Mimi operates in 1920-sample (80ms) chunks → 2 LSTM frames per chunk.
Four harness steps are buffered before each Mimi run; floor bit is held
between updates. Inherent latency: ~80ms.
floor = 1 if P(user) > threshold else 0
subject is always fed as channel 0 (user); other as channel 1 (system).
Debug mode (MIMI_DEBUG=1): saves debug_pass{N}.npz per conversation pass;
run plot_debug.py afterwards to render PNGs.
Sweep mode (MIMI_SWEEP=1): evaluates thresholds 0.05–0.95 in a single
inference pass; the floor bit is reported per-step as a list[int], one entry per threshold.
"""
from __future__ import annotations
import atexit
import os
import sys
from pathlib import Path
import numpy as np
import torch
# Directory containing this file; used below to locate model.py and checkpoint.pt.
_HERE = Path(__file__).resolve().parent
# model.py lives alongside predict.py in the HF flat layout, or one level up in the
# local nested layout (turn-bench-submission/ inside baselines/mimi_endpointer/).
# Both directories are pushed to the front of sys.path. The parent is inserted
# last, so it ends up at index 0 and is searched before _HERE.
sys.path.insert(0, str(_HERE))
sys.path.insert(0, str(_HERE.parent))
from model import ( # noqa: E402
AudioFeatureExtractor,
AUDIO_DEFAULTS,
load_model as load_mimi_model,
)
IDX_USER = 4  # from training config: {bos:0, system_end:1, user_end:2, system:3, user:4}

# checkpoint.pt is alongside predict.py (HF) or one level up (local).
# The first existing candidate wins.
_CHECKPOINT_CANDIDATES = (_HERE / "checkpoint.pt", _HERE.parent / "checkpoint.pt")
CHECKPOINT = next((p for p in _CHECKPOINT_CANDIDATES if p.exists()), None)
if CHECKPOINT is None:
    # Fail with an actionable message instead of a bare StopIteration.
    raise FileNotFoundError(
        "checkpoint.pt not found; looked in: "
        + ", ".join(str(p) for p in _CHECKPOINT_CANDIDATES)
    )

_CHUNK_STEPS = 4  # 4 × 20ms = 80ms = one Mimi frame_size (1920 samples at 24kHz)
_SR = 24_000
_FRAME_RATE = 50  # harness step rate (Hz)
class MimiEndpointerModel:
    """Streaming endpointer that thresholds P(user) from a two-stream model.

    Each ``step()`` call receives one harness step of audio for the subject
    and the other speaker. Steps are buffered until ``_CHUNK_STEPS`` of them
    are available; the concatenated chunk is then encoded with the Mimi
    extractor, upsampled, and every resulting frame is fed through
    ``infer_ar_step``. The softmax probability at ``IDX_USER`` of the last
    frame is compared against each threshold to produce the floor bit(s),
    which are held unchanged until the next chunk has been processed.

    Attributes:
        input_sample_rate: Sample rate the harness should deliver audio at.
        threshold: Present only in single-threshold mode.
        thresholds: Present only in sweep mode (the harness detects sweep
            mode via ``hasattr``).
        debug: When true, per-step audio, floor bits and class probabilities
            are logged and written to ``debug_pass{N}.npz`` on ``reset()``
            and at interpreter exit.
    """

    input_sample_rate = _SR  # Mimi native rate; 24000 % 50 == 0 → 480 samples/step

    def __init__(
        self,
        threshold: float = 0.5,
        thresholds: list[float] | None = None,
        debug: bool = False,
    ) -> None:
        """Load the checkpoint and feature extractor and start a fresh stream.

        Args:
            threshold: Operating point used when ``thresholds`` is None.
            thresholds: If given, enables sweep mode; ``step()`` then returns
                one floor bit per entry.
            debug: Enable per-step logging and ``.npz`` dumps.
        """
        # Defined before anything that can raise, so that __del__ is safe on
        # a partially constructed instance (e.g. checkpoint load failure).
        self._ctx = None

        # sweep mode: thresholds is a list; single mode: scalar threshold
        if thresholds is not None:
            self.thresholds = thresholds  # harness detects sweep mode via hasattr
            self._thresholds_arr = thresholds
        else:
            self.threshold = threshold  # single operating point
            self._thresholds_arr = [threshold]
        self._sweep = thresholds is not None
        self.debug = debug

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = device
        self._model = load_mimi_model(str(CHECKPOINT), device=device)
        self._extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)

        # Debug logs must exist before reset(), which may call _save_npz().
        self._debug_idx = 0
        self._log_subj: list[np.ndarray] = []
        self._log_other: list[np.ndarray] = []
        self._log_floor: list[int] = []
        self._log_probs: list[np.ndarray] = []  # all 5 class probs per step (T, 5)
        if debug:
            # Flush the final pass, which never sees another reset().
            atexit.register(self._save_npz)
        self.reset()

    def _close_stream(self) -> None:
        """Exit the active Mimi streaming context, if any, exactly once."""
        ctx = getattr(self, "_ctx", None)
        if ctx is not None:
            # Clear the handle first so a failing __exit__ is never retried.
            self._ctx = None
            ctx.__exit__(None, None, None)

    def reset(self) -> None:
        """Start a new conversation pass.

        In debug mode the previous pass is dumped first. The Mimi streaming
        context is re-created, the recurrent state of both streams is
        re-initialised, and all buffers, floor bits and debug logs are
        cleared.
        """
        if self.debug:
            self._save_npz()
        self._close_stream()
        ctx = self._extractor.mimi.streaming(batch_size=2)
        ctx.__enter__()
        # Record the context only once it has actually been entered.
        self._ctx = ctx
        self._h1, self._c1 = self._model.init_hidden(1, self._device)
        self._h2, self._c2 = self._model.init_hidden(1, self._device)
        self._buf_subj: list[np.ndarray] = []
        self._buf_other: list[np.ndarray] = []
        self._floor_bits: list[int] = [0] * len(self._thresholds_arr)
        self._log_subj = []
        self._log_other = []
        self._log_floor = []
        self._log_probs = []

    def __del__(self) -> None:
        # Python may call __del__ even when __init__ raised part-way through;
        # _close_stream tolerates a missing or already-closed context.
        self._close_stream()

    def _save_npz(self) -> None:
        """Write the current pass's debug logs to ``debug_pass{N}.npz``.

        Does nothing when no step has been logged since the last reset.
        """
        if not self._log_floor:
            return
        out = _HERE / f"debug_pass{self._debug_idx}.npz"
        np.savez(
            out,
            subj=np.concatenate(self._log_subj),
            other=np.concatenate(self._log_other),
            floor=np.array(self._log_floor, dtype=np.int8),
            probs=np.array(self._log_probs, dtype=np.float32),  # (T, 5)
            threshold=np.float32(self._thresholds_arr[0]),
            sr=np.int32(_SR),
            frame_rate=np.int32(_FRAME_RATE),
        )
        sys.stderr.write(f"[debug] saved → {out}\n")
        self._debug_idx += 1

    def step(self, subject_audio: np.ndarray, other_audio: np.ndarray) -> list[int] | int:
        """Consume one harness step of audio and return the current floor.

        Args:
            subject_audio: Audio for the subject (fed as channel 0 / user).
            other_audio: Audio for the other party (fed as channel 1 / system).

        Returns:
            In sweep mode, a list with one 0/1 floor bit per threshold;
            otherwise a single 0/1 floor bit. The value only changes on
            steps that complete a ``_CHUNK_STEPS`` buffer.
        """
        self._buf_subj.append(subject_audio)
        self._buf_other.append(other_audio)
        if self.debug:
            self._log_subj.append(subject_audio)
            self._log_other.append(other_audio)

        new_probs: np.ndarray | None = None
        if len(self._buf_subj) == _CHUNK_STEPS:
            chunk_s = torch.from_numpy(np.concatenate(self._buf_subj)).to(self._device)
            chunk_o = torch.from_numpy(np.concatenate(self._buf_other)).to(self._device)
            self._buf_subj.clear()
            self._buf_other.clear()
            # (2, 1, 1920) — subject=channel 0 (user), other=channel 1 (system)
            chunk = torch.stack([chunk_s, chunk_o]).unsqueeze(1)
            with torch.no_grad():
                emb = self._extractor.mimi.encode_to_latent(chunk, quantize=True)  # (2, feat, 1)
                emb = self._extractor.mimi.upsample(emb)  # (2, feat, 2)
                logits = None
                for frame in range(emb.shape[-1]):
                    logits, self._h1, self._c1, self._h2, self._c2 = \
                        self._model.infer_ar_step(
                            emb[0:1, :, frame], emb[1:2, :, frame],
                            self._h1, self._c1, self._h2, self._c2,
                        )
                # If the encoder produced no frames there is nothing new to
                # threshold; keep the previously held floor bits.
                if logits is not None:
                    new_probs = torch.softmax(logits[0], dim=-1).cpu().numpy()  # (5,)
                    p_user = new_probs[IDX_USER]
                    self._floor_bits = [
                        1 if p_user > thr else 0 for thr in self._thresholds_arr
                    ]

        if self.debug:
            # Log one probability row per step: the fresh one if available,
            # else repeat the last row, else zeros before the first update.
            if new_probs is not None:
                p = new_probs
            elif self._log_probs:
                p = self._log_probs[-1]
            else:
                p = np.zeros(5, dtype=np.float32)
            self._log_probs.append(p)
            self._log_floor.append(self._floor_bits[0])  # first threshold for debug plot
        return self._floor_bits if self._sweep else self._floor_bits[0]
def load_model() -> MimiEndpointerModel:
    """Build a ``MimiEndpointerModel`` configured from environment variables.

    Environment:
        MIMI_DEBUG: "1" enables debug logging and ``.npz`` dumps.
        MIMI_SWEEP: "1" enables sweep mode over thresholds 0.05, 0.10, …, 0.95;
            MIMI_THRESHOLD is ignored in this mode.
        MIMI_THRESHOLD: single operating point (default "0.1").

    Returns:
        The constructed model.

    Raises:
        ValueError: If MIMI_THRESHOLD cannot be parsed as a float.
    """
    debug = os.environ.get("MIMI_DEBUG", "0") == "1"
    sweep = os.environ.get("MIMI_SWEEP", "0") == "1"
    if sweep:
        # Derive the grid from integers: a float-step arange has a
        # rounding-dependent length, whereas this always yields 19 points.
        thresholds = [pct / 100 for pct in range(5, 100, 5)]
        return MimiEndpointerModel(thresholds=thresholds, debug=debug)

    raw = os.environ.get("MIMI_THRESHOLD", "0.1")
    try:
        thr = float(raw)
    except ValueError as err:
        raise ValueError(f"MIMI_THRESHOLD must be a float, got {raw!r}") from err
    return MimiEndpointerModel(threshold=thr, debug=debug)