| |
| |
| |
|
|
| """Mimi Endpointer — DiscriminativeModel for the TURN benchmark. |
| |
| Two-stream LSTM over Mimi embeddings, streamed one 20ms chunk at a time. |
| Mimi operates in 1920-sample (80ms) chunks → 2 LSTM frames per chunk. |
| Four harness steps are buffered before each Mimi run; floor bit is held |
| between updates. Inherent latency: ~80ms. |
| |
| floor = 1 if P(user) > threshold else 0 |
| |
| subject is always fed as channel 0 (user); other as channel 1 (system). |
| |
| Debug mode (MIMI_DEBUG=1): saves debug_pass{N}.npz per conversation pass; |
| run plot_debug.py afterwards to render PNGs. |
| |
| Sweep mode (MIMI_SWEEP=1): runs the harness sweep over thresholds 0.05–0.95 |
| in a single inference pass; the floor is reported per-step as a list[int], |
| one bit per threshold. |
| """ |
| from __future__ import annotations |
|
|
| import atexit |
| import os |
| import sys |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
# Directory that contains this file.
_HERE = Path(__file__).resolve().parent

# Prepend the parent directory and then this directory to the module search
# path (parent ends up first), presumably so the `model` import that follows
# can be resolved from either location.
sys.path[:0] = [str(_HERE.parent), str(_HERE)]
|
|
| from model import ( |
| AudioFeatureExtractor, |
| AUDIO_DEFAULTS, |
| load_model as load_mimi_model, |
| ) |
|
|
# Index into the model's softmax output that is read as P(user).
IDX_USER = 4

# Checkpoint locations, in priority order: next to this file, then one
# directory up.
_CHECKPOINT_CANDIDATES = (
    _HERE / "checkpoint.pt",
    _HERE.parent / "checkpoint.pt",
)

# First candidate that exists on disk. A default is passed to next() so a
# missing checkpoint produces a descriptive FileNotFoundError instead of a
# bare StopIteration escaping at import time.
CHECKPOINT = next((p for p in _CHECKPOINT_CANDIDATES if p.exists()), None)
if CHECKPOINT is None:
    raise FileNotFoundError(
        "checkpoint.pt not found; looked in: "
        + ", ".join(str(p) for p in _CHECKPOINT_CANDIDATES)
    )

# Number of harness steps buffered before each Mimi run.
_CHUNK_STEPS = 4
# Audio sample rate in Hz; exposed as the model's input_sample_rate and
# written to the debug archive as "sr".
_SR = 24_000
# Written to the debug archive as "frame_rate"; presumably the harness step
# rate (one floor/prob entry is logged per step) — confirm with plot_debug.py.
_FRAME_RATE = 50
|
|
|
|
class MimiEndpointerModel:
    """Streaming floor classifier: Mimi latents fed to a two-stream recurrent model.

    Audio arrives one harness step at a time through :meth:`step`. Steps are
    buffered until ``_CHUNK_STEPS`` have accumulated, then the concatenated
    subject/other audio is encoded by Mimi as a batch of two, upsampled, and
    each resulting frame is pushed through ``infer_ar_step``. The softmax of
    the last frame's logits gives P(user) (index ``IDX_USER``), which is
    compared against every configured threshold to produce the floor bit(s).
    Between model runs the previous floor bit(s) are held.

    Attributes:
        input_sample_rate: Sample rate (Hz) of the audio passed to ``step``.
        threshold: Single decision threshold (set only in single mode).
        thresholds: Threshold list (set only in sweep mode).
        debug: Whether per-step audio, probabilities and floor bits are
            logged and dumped to ``debug_pass{N}.npz``.
    """

    input_sample_rate = _SR

    def __init__(
        self,
        threshold: float = 0.5,
        thresholds: list[float] | None = None,
        debug: bool = False,
    ) -> None:
        """Load the checkpoint and feature extractor and open a stream.

        Args:
            threshold: Decision threshold used when ``thresholds`` is None.
            thresholds: If given, enables sweep mode: ``step`` returns one
                floor bit per entry instead of a single int.
            debug: If True, log every step and save the log on ``reset`` and
                at interpreter exit.
        """
        # Assigned before anything that can raise (checkpoint / extractor
        # loading) so that __del__ never sees a missing attribute.
        self._ctx = None

        # Only the attribute matching the active mode is created; this is
        # kept as-is in case callers use its presence to detect sweep mode.
        if thresholds is not None:
            self.thresholds = thresholds
            self._thresholds_arr = thresholds
        else:
            self.threshold = threshold
            self._thresholds_arr = [threshold]
        self._sweep = thresholds is not None
        self.debug = debug

        device = "cuda" if torch.cuda.is_available() else "cpu"
        self._device = device
        self._model = load_mimi_model(str(CHECKPOINT), device=device)
        self._extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)

        # Debug logs must exist before reset(), which inspects them.
        self._debug_idx = 0
        self._log_subj: list[np.ndarray] = []
        self._log_other: list[np.ndarray] = []
        self._log_floor: list[int] = []
        self._log_probs: list[np.ndarray] = []
        if debug:
            # Flush the final pass, which is never followed by a reset().
            atexit.register(self._save_npz)
        self.reset()

    def reset(self) -> None:
        """Start a new pass: flush debug logs and reinitialise all state."""
        if self.debug:
            self._save_npz()

        # Replace the Mimi streaming context with a fresh one (batch of two:
        # subject and other). The attribute is only set once __enter__ has
        # succeeded, so a failed enter never leaves a half-open context
        # behind to be exited later.
        self._close_stream()
        ctx = self._extractor.mimi.streaming(batch_size=2)
        ctx.__enter__()
        self._ctx = ctx

        # Two independent recurrent states, one per stream.
        self._h1, self._c1 = self._model.init_hidden(1, self._device)
        self._h2, self._c2 = self._model.init_hidden(1, self._device)

        self._buf_subj: list[np.ndarray] = []
        self._buf_other: list[np.ndarray] = []
        self._floor_bits: list[int] = [0] * len(self._thresholds_arr)

        self._log_subj = []
        self._log_other = []
        self._log_floor = []
        self._log_probs = []

    def __del__(self) -> None:
        """Leave the Mimi streaming context when the instance is collected."""
        self._close_stream()

    def _close_stream(self) -> None:
        """Exit the current streaming context, if any, exactly once."""
        # getattr guards against instances whose __init__ never got as far
        # as creating the attribute.
        ctx = getattr(self, "_ctx", None)
        if ctx is None:
            return
        # Clear first so a later reset()/__del__ cannot exit it again.
        self._ctx = None
        ctx.__exit__(None, None, None)

    def _save_npz(self) -> None:
        """Write the current pass's debug log to ``debug_pass{N}.npz``.

        Does nothing when no step has been logged. The archive holds the
        concatenated subject/other audio, per-step floor bits and
        probabilities, the first threshold, the sample rate and frame rate.
        """
        if not self._log_floor:
            return
        out = _HERE / f"debug_pass{self._debug_idx}.npz"
        np.savez(
            out,
            subj=np.concatenate(self._log_subj),
            other=np.concatenate(self._log_other),
            floor=np.array(self._log_floor, dtype=np.int8),
            probs=np.array(self._log_probs, dtype=np.float32),
            threshold=np.float32(self._thresholds_arr[0]),
            sr=np.int32(_SR),
            frame_rate=np.int32(_FRAME_RATE),
        )
        sys.stderr.write(f"[debug] saved → {out}\n")
        self._debug_idx += 1

    def _run_chunk(self) -> np.ndarray | None:
        """Run Mimi and the recurrent model on the buffered audio.

        Drains both step buffers, advances the recurrent state by one
        ``infer_ar_step`` per upsampled frame, and returns the softmax of the
        last frame's logits as a NumPy array. Returns None if the encoder
        produced no frames.
        """
        chunk_s = torch.from_numpy(np.concatenate(self._buf_subj)).to(self._device)
        chunk_o = torch.from_numpy(np.concatenate(self._buf_other)).to(self._device)
        self._buf_subj.clear()
        self._buf_other.clear()

        # Batch row 0 = subject, row 1 = other; unsqueeze adds a dimension at
        # position 1 (presumably the audio channel axis Mimi expects).
        chunk = torch.stack([chunk_s, chunk_o]).unsqueeze(1)
        with torch.no_grad():
            emb = self._extractor.mimi.encode_to_latent(chunk, quantize=True)
            emb = self._extractor.mimi.upsample(emb)
            logits = None
            for frame in range(emb.shape[-1]):
                logits, self._h1, self._c1, self._h2, self._c2 = \
                    self._model.infer_ar_step(
                        emb[0:1, :, frame], emb[1:2, :, frame],
                        self._h1, self._c1, self._h2, self._c2,
                    )
            if logits is None:
                # No frames came out of the encoder: nothing to update.
                return None
            return torch.softmax(logits[0], dim=-1).cpu().numpy()

    def step(
        self, subject_audio: np.ndarray, other_audio: np.ndarray
    ) -> int | list[int]:
        """Consume one harness step of audio and return the floor decision.

        Args:
            subject_audio: Audio for the subject (fed to the model as the
                first stream).
            other_audio: Audio for the other party (second stream).

        Returns:
            In sweep mode, a list with one 0/1 floor bit per threshold;
            otherwise a single 0/1 int. The value only changes on steps that
            complete a ``_CHUNK_STEPS``-step buffer; otherwise the previous
            decision is repeated.
        """
        self._buf_subj.append(subject_audio)
        self._buf_other.append(other_audio)

        if self.debug:
            self._log_subj.append(subject_audio)
            self._log_other.append(other_audio)

        new_probs: np.ndarray | None = None
        if len(self._buf_subj) == _CHUNK_STEPS:
            new_probs = self._run_chunk()
            if new_probs is not None:
                p_user = new_probs[IDX_USER]
                self._floor_bits = [
                    1 if p_user > thr else 0 for thr in self._thresholds_arr
                ]

        if self.debug:
            # Log one entry per step: fresh probabilities when available,
            # otherwise repeat the last ones (zeros before the first run).
            if new_probs is not None:
                logged = new_probs
            elif self._log_probs:
                logged = self._log_probs[-1]
            else:
                logged = np.zeros(5, dtype=np.float32)
            self._log_probs.append(logged)
            self._log_floor.append(self._floor_bits[0])

        return self._floor_bits if self._sweep else self._floor_bits[0]
|
|
|
|
def load_model() -> MimiEndpointerModel:
    """Build a MimiEndpointerModel configured from environment variables.

    MIMI_DEBUG=1 turns on debug logging. MIMI_SWEEP=1 selects sweep mode with
    thresholds 0.05, 0.10, ..., 0.95; otherwise a single threshold is read
    from MIMI_THRESHOLD (default 0.1).
    """
    env = os.environ
    debug = env.get("MIMI_DEBUG", "0") == "1"

    # Single-threshold mode unless the sweep flag is explicitly "1".
    if env.get("MIMI_SWEEP", "0") != "1":
        single = float(env.get("MIMI_THRESHOLD", "0.1"))
        return MimiEndpointerModel(threshold=single, debug=debug)

    # Rounded to two decimals to strip float noise from the arange steps.
    grid = np.round(np.arange(0.05, 1.0, 0.05), 2)
    return MimiEndpointerModel(thresholds=grid.tolist(), debug=debug)
|
|