# Author: Sathvik Udupa (2026)
# Email: udupa@fit.vutbr.cz
# Paper: Streaming Endpointer for Spoken Dialogue using Neural Audio Codecs and Label-Delayed Training, https://arxiv.org/abs/2506.07081, ASRU 2025
"""Mimi Endpointer — DiscriminativeModel for the TURN benchmark.

Two-stream LSTM over Mimi embeddings, streamed 20ms chunk at a time.
Mimi operates in 1920-sample (80ms) chunks → 2 LSTM frames per chunk.
Four harness steps are buffered before each Mimi run; floor bit is held
between updates. Inherent latency: ~80ms.

    floor = 1 if P(user) > threshold else 0

subject is always fed as channel 0 (user); other as channel 1 (system).

Debug mode (MIMI_DEBUG=1): saves debug_pass{N}.npz per conversation pass;
run plot_debug.py afterwards to render PNGs.

Sweep mode (MIMI_SWEEP=1): runs the harness sweep over thresholds 0.05–0.95
in a single inference pass; the floor bit is reported per-step as a
list[int] with one entry per threshold.

Single mode (default): the threshold is read from MIMI_THRESHOLD
(default 0.1) and the floor bit is reported per-step as a single int.
"""

from __future__ import annotations

import atexit
import os
import sys
from pathlib import Path

import numpy as np
import torch

_HERE = Path(__file__).resolve().parent
# model.py lives alongside predict.py in the HF flat layout, or one level up in the
# local nested layout (turn-bench-submission/ inside baselines/mimi_endpointer/)
sys.path.insert(0, str(_HERE))
sys.path.insert(0, str(_HERE.parent))
from model import (  # noqa: E402
    AudioFeatureExtractor,
    AUDIO_DEFAULTS,
    load_model as load_mimi_model,
)

IDX_USER = 4  # from training config: {bos:0, system_end:1, user_end:2, system:3, user:4}


def _find_checkpoint() -> Path:
    """Return the first existing ``checkpoint.pt`` among the known locations.

    checkpoint.pt is alongside predict.py (HF) or one level up (local).

    Raises:
        FileNotFoundError: if neither location holds a checkpoint. Raised
            explicitly so a missing file is reported with the searched paths
            instead of a bare ``StopIteration`` from an exhausted generator.
    """
    candidates = (_HERE / "checkpoint.pt", _HERE.parent / "checkpoint.pt")
    for candidate in candidates:
        if candidate.exists():
            return candidate
    searched = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(f"checkpoint.pt not found; searched: {searched}")


CHECKPOINT = _find_checkpoint()

_CHUNK_STEPS = 4  # 4 × 20ms = 80ms = one Mimi frame_size (1920 samples at 24kHz)
_SR = 24_000
_FRAME_RATE = 50  # harness step rate (Hz)


class MimiEndpointerModel:
    """Streaming endpointer that turns two audio streams into a floor bit.

    ``step`` is called once per harness step with one slice of audio per
    stream. Slices are buffered until ``_CHUNK_STEPS`` of them are available,
    then encoded and fed through the model; between model runs the last
    floor decision is held.

    Exactly one of the following public attributes exists on an instance:

    * ``threshold`` — single operating point (single mode).
    * ``thresholds`` — list of operating points (sweep mode); the harness
      detects sweep mode via ``hasattr``.
    """

    input_sample_rate = _SR  # Mimi native rate; 24000 % 50 == 0 → 480 samples/step

    def __init__(
        self,
        threshold: float = 0.5,
        thresholds: list[float] | None = None,
        debug: bool = False,
    ) -> None:
        """Load the model and open a fresh streaming session.

        Args:
            threshold: decision threshold on P(user); used only when
                ``thresholds`` is None.
            thresholds: when given, enables sweep mode and ``step`` returns
                one floor bit per entry.
            debug: when True, per-step audio, probabilities and floor bits
                are logged and written to ``debug_pass{N}.npz``.
        """
        # Assigned before anything that can raise, so __del__ always finds it
        # even if model loading below fails.
        self._ctx = None

        # sweep mode: thresholds is a list; single mode: scalar threshold
        if thresholds is not None:
            self.thresholds = thresholds  # harness detects sweep mode via hasattr
            self._thresholds_arr = thresholds
        else:
            self.threshold = threshold  # single operating point
            self._thresholds_arr = [threshold]
        self._sweep = thresholds is not None
        self.debug = debug

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = device
        self._model = load_mimi_model(str(CHECKPOINT), device=device)
        self._extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)

        self._debug_idx = 0
        self._log_subj: list[np.ndarray] = []
        self._log_other: list[np.ndarray] = []
        self._log_floor: list[int] = []
        self._log_probs: list[np.ndarray] = []  # all 5 class probs per step (T, 5)

        if debug:
            # Flush whatever the last pass logged when the interpreter exits.
            atexit.register(self._save_npz)

        self.reset()

    def reset(self) -> None:
        """Start a new conversation pass.

        In debug mode the previous pass is written to disk first. The
        streaming context is then closed and reopened, the LSTM hidden state
        is re-initialised, and all buffers, floor bits and debug logs are
        cleared.
        """
        if self.debug:
            self._save_npz()

        if self._ctx is not None:
            self._ctx.__exit__(None, None, None)
            # Drop the reference so a failure below cannot lead to a second
            # __exit__ on the same context from __del__.
            self._ctx = None
        ctx = self._extractor.mimi.streaming(batch_size=2)
        ctx.__enter__()
        self._ctx = ctx

        self._h1, self._c1 = self._model.init_hidden(1, self._device)
        self._h2, self._c2 = self._model.init_hidden(1, self._device)
        self._buf_subj: list[np.ndarray] = []
        self._buf_other: list[np.ndarray] = []
        self._floor_bits: list[int] = [0] * len(self._thresholds_arr)
        self._log_subj = []
        self._log_other = []
        self._log_floor = []
        self._log_probs = []

    def __del__(self) -> None:
        """Close the streaming context if one is still open."""
        # getattr: __del__ also runs on instances whose __init__ raised early.
        ctx = getattr(self, "_ctx", None)
        if ctx is not None:
            self._ctx = None
            ctx.__exit__(None, None, None)

    def _save_npz(self) -> None:
        """Write the current debug log to ``debug_pass{N}.npz`` next to this file.

        Does nothing when no step has been logged. Only the first threshold
        is stored, matching the floor bits that ``step`` logs.
        """
        if not self._log_floor:
            return
        out = _HERE / f"debug_pass{self._debug_idx}.npz"
        np.savez(
            out,
            subj=np.concatenate(self._log_subj),
            other=np.concatenate(self._log_other),
            floor=np.array(self._log_floor, dtype=np.int8),
            probs=np.array(self._log_probs, dtype=np.float32),  # (T, 5)
            threshold=np.float32(self._thresholds_arr[0]),
            sr=np.int32(_SR),
            frame_rate=np.int32(_FRAME_RATE),
        )
        sys.stderr.write(f"[debug] saved → {out}\n")
        self._debug_idx += 1

    def step(
        self, subject_audio: np.ndarray, other_audio: np.ndarray
    ) -> int | list[int]:
        """Consume one harness step of audio and return the floor decision.

        Args:
            subject_audio: audio slice for the subject stream (channel 0).
            other_audio: audio slice for the other stream (channel 1).
            Both are presumably float32 arrays of 480 samples (one 20ms step
            at 24kHz) — TODO confirm against the harness.

        Returns:
            In sweep mode, a list with one floor bit per threshold; otherwise
            a single floor bit. The value only changes on steps that complete
            a ``_CHUNK_STEPS``-step chunk; on other steps the previous
            decision is returned.
        """
        self._buf_subj.append(subject_audio)
        self._buf_other.append(other_audio)

        if self.debug:
            self._log_subj.append(subject_audio)
            self._log_other.append(other_audio)

        new_probs: np.ndarray | None = None

        if len(self._buf_subj) == _CHUNK_STEPS:
            chunk_s = torch.from_numpy(np.concatenate(self._buf_subj)).to(self._device)
            chunk_o = torch.from_numpy(np.concatenate(self._buf_other)).to(self._device)
            self._buf_subj.clear()
            self._buf_other.clear()

            # (2, 1, 1920) — subject=channel 0 (user), other=channel 1 (system)
            chunk = torch.stack([chunk_s, chunk_o]).unsqueeze(1)

            with torch.no_grad():
                emb = self._extractor.mimi.encode_to_latent(chunk, quantize=True)  # (2, feat, 1)
                emb = self._extractor.mimi.upsample(emb)  # (2, feat, 2)

                # Feed every upsampled frame through the model so the hidden
                # state advances; only the last frame's logits are kept.
                logits = None
                for frame in range(emb.shape[-1]):
                    logits, self._h1, self._c1, self._h2, self._c2 = (
                        self._model.infer_ar_step(
                            emb[0:1, :, frame],
                            emb[1:2, :, frame],
                            self._h1,
                            self._c1,
                            self._h2,
                            self._c2,
                        )
                    )

            new_probs = torch.softmax(logits[0], dim=-1).cpu().numpy()  # (5,)
            p_user = new_probs[IDX_USER]
            self._floor_bits = [
                1 if p_user > thr else 0 for thr in self._thresholds_arr
            ]

        if self.debug:
            if new_probs is not None:
                logged = new_probs
            elif self._log_probs:
                # No model run this step: repeat the last logged probabilities.
                logged = self._log_probs[-1]
            else:
                logged = np.zeros(5, dtype=np.float32)
            self._log_probs.append(logged)
            self._log_floor.append(self._floor_bits[0])  # first threshold for debug plot

        return self._floor_bits if self._sweep else self._floor_bits[0]


def load_model() -> MimiEndpointerModel:
    """Build the endpointer from environment variables.

    * ``MIMI_DEBUG=1`` enables debug logging.
    * ``MIMI_SWEEP=1`` enables sweep mode over thresholds 0.05–0.95 in
      steps of 0.05.
    * ``MIMI_THRESHOLD`` sets the single-mode threshold (default ``0.1``);
      it is ignored in sweep mode.
    """
    debug = os.environ.get("MIMI_DEBUG", "0") == "1"
    sweep = os.environ.get("MIMI_SWEEP", "0") == "1"

    if sweep:
        # Rounded to 2 decimals to strip float accumulation noise from arange.
        thresholds = np.round(np.arange(0.05, 1.0, 0.05), 2).tolist()
        return MimiEndpointerModel(thresholds=thresholds, debug=debug)

    thr = float(os.environ.get("MIMI_THRESHOLD", "0.1"))
    return MimiEndpointerModel(threshold=thr, debug=debug)