import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from moshi.models import loaders as moshi_loaders

# Config (from ms_lstm_mimi-25hz-nq8_delay1f training run)
MODEL_DEFAULTS = {
    "input_size": 512,  # Mimi feat_size
    "hidden_size": 512,
    "num_layers": 2,
    "output_size": 5,
    "dropout": 0.1,
    "bidirectional": False,
    "project": 128,  # projects 512 → 128 before LSTMs
    "num_project": 1,  # accepted but unused by MimiLSTM (swallowed by **_)
}

AUDIO_DEFAULTS = {
    "sr": 24000,  # Mimi native sample rate
    "num_quantizers": 8,
    "frame_rate_hz": 25.0,  # after ×2 upsample from Mimi's native 12.5Hz
}


class MimiLSTM(nn.Module):
    """Two-stream LSTM over Mimi embeddings.

    Both speakers share one projection (``mel_embed``: Linear + ReLU). Each speaker
    then has its own LSTM (``model1`` for speaker 1, ``model2`` for speaker 2). The
    two LSTM outputs are concatenated per frame and mapped to ``output_size`` logits.

    Input x: (batch, 2, feat_size, T) — speaker_1 at index 0, speaker_2 at index 1.
    Output: (batch, T, output_size)

    NOTE: ``init_hidden`` and ``linear`` are sized for a unidirectional LSTM only.
    With ``bidirectional=True``, PyTorch expects states with num_layers * 2 layers,
    and the concatenated output would be hidden_size * 4 wide. This class does not
    build either.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size,
                 dropout, bidirectional, project, **_):
        super().__init__()
        # Shared per-frame projection input_size -> project, applied to both speakers.
        self.mel_embed = nn.Sequential(
            nn.Linear(input_size, project),
            nn.ReLU(),
        )
        lstm_kwargs = dict(
            input_size=project,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.model1 = nn.LSTM(**lstm_kwargs)  # speaker 1 stream
        self.model2 = nn.LSTM(**lstm_kwargs)  # speaker 2 stream
        self.linear = nn.Linear(hidden_size * 2, output_size)

    def init_hidden(self, batch_size, device):
        """Return zero-filled (h, c) LSTM states.

        Each state has shape (num_layers, batch_size, hidden_size) and lives on ``device``.
        """
        shape = (self.model1.num_layers, batch_size, self.model1.hidden_size)
        return torch.zeros(shape, device=device), torch.zeros(shape, device=device)

    def infer(self, x):
        """Full-sequence inference. x: (1, 2, feat, T).

        Both LSTMs start from the same zero state; returns (1, T, output_size) logits.
        """
        x = x.permute(0, 1, 3, 2)  # (1, 2, T, feat)
        x = self.mel_embed(x)  # (1, 2, T, project)
        h, c = self.init_hidden(x.size(0), x.device)
        x1, _ = self.model1(x[:, 0], (h, c))
        x2, _ = self.model2(x[:, 1], (h, c))
        return self.linear(torch.cat([x1, x2], dim=-1))  # (1, T, output_size)

    def infer_ar_step(self, feat1, feat2, h1, c1, h2, c2):
        """Single-frame AR step.

        feat1, feat2: (1, feat_size) — one frame per speaker.
        h1, c1 / h2, c2: LSTM states for speaker 1 / speaker 2, e.g. from ``init_hidden``.
        Returns: logits (1, output_size), updated (h1, c1, h2, c2).
        """
        f1 = self.mel_embed(feat1).unsqueeze(1)  # (1, 1, project)
        f2 = self.mel_embed(feat2).unsqueeze(1)
        out1, (h1, c1) = self.model1(f1, (h1, c1))  # (1, 1, hidden)
        out2, (h2, c2) = self.model2(f2, (h2, c2))
        logits = self.linear(torch.cat([out1, out2], dim=-1)).squeeze(1)  # (1, output_size)
        return logits, h1, c1, h2, c2

    def infer_ar(self, x):
        """Frame-by-frame AR inference. x: (1, 2, feat, T).

        Equivalent to infer() for a unidirectional LSTM but simulates real-time decoding:
        the LSTM state is carried across single-frame calls.
        """
        x = x.permute(0, 1, 3, 2)  # (1, 2, T, feat)
        x = self.mel_embed(x)  # (1, 2, T, project)
        T = x.size(2)
        h1, c1 = self.init_hidden(x.size(0), x.device)
        h2, c2 = self.init_hidden(x.size(0), x.device)
        outputs = []
        for t in range(T):
            f1 = x[:, 0, t:t + 1, :]  # (1, 1, project)
            f2 = x[:, 1, t:t + 1, :]
            out1, (h1, c1) = self.model1(f1, (h1, c1))
            out2, (h2, c2) = self.model2(f2, (h2, c2))
            outputs.append(self.linear(torch.cat([out1, out2], dim=-1)))
        return torch.cat(outputs, dim=1)  # (1, T, output_size)


class AudioFeatureExtractor:
    """Extracts Mimi embeddings from raw waveform using the moshi library.

    Waveforms must already be at ``sr``. This class stores ``sr`` but does not resample.
    """

    def __init__(self, sr, num_quantizers, device="cuda", **_):
        mimi_weight = hf_hub_download(moshi_loaders.DEFAULT_REPO, moshi_loaders.MIMI_NAME)
        self.mimi = moshi_loaders.get_mimi(mimi_weight, device=device, num_codebooks=num_quantizers)
        self.mimi.eval()
        self.sr = sr
        self.device = device

    def _to_tensor(self, wav):
        """1D numpy array → (1, 1, T) tensor on device."""
        return torch.from_numpy(wav).float().to(self.device).unsqueeze(0).unsqueeze(0)

    @torch.no_grad()
    def __call__(self, wav):
        """Full-sequence. wav: 1D numpy array. Returns (1, feat_size, T) at 25Hz."""
        x = self._to_tensor(wav)
        emb = self.mimi.encode_to_latent(x, quantize=True)  # (1, feat, T) at 12.5Hz
        return self.mimi.upsample(emb)  # (1, feat, T) at 25Hz

    @torch.no_grad()
    def stream(self, wav1, wav2):
        """Stream both speakers batched together, one Mimi frame (1920 samples) at a time.

        Yields (1, feat_size, frames) for speaker1 and speaker2 per chunk.
        Both speakers run in a single streaming context, so the KV cache is shared and
        there is no double-enter conflict. Both waveforms are zero-padded to the longer
        one, rounded up to a whole number of frames, so every chunk pair can be stacked.
        """
        chunk_samples = self.mimi.frame_size  # 1920
        w1 = torch.from_numpy(wav1).float().to(self.device)
        w2 = torch.from_numpy(wav2).float().to(self.device)
        # Ceil division: a trailing partial frame still produces a (zero-padded) chunk.
        n_chunks = -(-max(len(w1), len(w2)) // chunk_samples)
        total = n_chunks * chunk_samples
        w1 = torch.nn.functional.pad(w1, (0, total - len(w1)))
        w2 = torch.nn.functional.pad(w2, (0, total - len(w2)))
        with self.mimi.streaming(batch_size=2):
            for i in range(n_chunks):
                s = i * chunk_samples
                chunk = torch.stack(
                    [w1[s:s + chunk_samples], w2[s:s + chunk_samples]]
                ).unsqueeze(1)  # (2, 1, 1920)
                emb = self.mimi.encode_to_latent(chunk, quantize=True)  # (2, feat, 1)
                emb = self.mimi.upsample(emb)  # (2, feat, 2)
                yield emb[0:1], emb[1:2]  # (1, feat, 2) each


def load_model(checkpoint_path, device="cpu"):
    """Build a MimiLSTM from MODEL_DEFAULTS and load trained weights.

    The checkpoint must be a dict with a "model_state_dict" entry. Returns the
    model in eval mode on ``device``.
    """
    model = MimiLSTM(**MODEL_DEFAULTS)
    # NOTE(review): torch.load relies on pickle; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    return model.to(device)


CLASS_NAMES = ["bos", "system_end", "user_end", "system", "user"]


if __name__ == "__main__":
    import argparse
    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import soundfile as sf
    import torchaudio

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--sample_dir", required=True)
    parser.add_argument("--ar", action="store_true",
                        help="Streaming Mimi (chunk-by-chunk) + AR LSTM step-by-step.")
    parser.add_argument("--out", default="eval_output.png")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.checkpoint:
        model = load_model(args.checkpoint, device=device)
        print(f"Loaded: {args.checkpoint}")
    else:
        model = MimiLSTM(**MODEL_DEFAULTS).to(device).eval()
        print("No checkpoint — using random weights.")

    extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)
    target_sr = AUDIO_DEFAULTS["sr"]
    frame_rate = AUDIO_DEFAULTS["frame_rate_hz"]

    def load_mono(path):
        """Read ``path`` as a 1-D float32 array resampled to ``target_sr``."""
        wav, sr = sf.read(path, dtype="float32")
        if wav.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel files; Mimi takes
            # one channel, so downmix here rather than resampling/encoding the wrong axis.
            wav = wav.mean(axis=1)
        if sr != target_sr:
            wav = torchaudio.functional.resample(torch.from_numpy(wav), sr, target_sr).numpy()
        return np.ascontiguousarray(wav, dtype=np.float32)

    d = Path(args.sample_dir)
    wav1 = load_mono(d / "speaker_1_audio.wav")
    wav2 = load_mono(d / "speaker_2_audio.wav")

    if args.ar:
        # streaming Mimi chunk-by-chunk + AR LSTM step-by-step
        print("Running streaming Mimi + AR LSTM...")
        from tqdm import tqdm

        h1, c1 = model.init_hidden(1, device)
        h2, c2 = model.init_hidden(1, device)
        all_logits = []
        chunk_samples = extractor.mimi.frame_size
        # Same chunk count as stream(): longer speaker, rounded up to whole frames.
        n_chunks = -(-max(len(wav1), len(wav2)) // chunk_samples)
        with torch.no_grad():
            for feat1_chunk, feat2_chunk in tqdm(
                extractor.stream(wav1, wav2), total=n_chunks, unit="chunk", desc="Streaming"
            ):
                for t in range(feat1_chunk.shape[-1]):
                    logits, h1, c1, h2, c2 = model.infer_ar_step(
                        feat1_chunk[:, :, t], feat2_chunk[:, :, t], h1, c1, h2, c2
                    )
                    all_logits.append(logits)
        if not all_logits:
            raise SystemExit("No frames produced — is the input audio empty?")
        out = torch.stack(all_logits, dim=1)  # (1, T, output_size)
    else:
        feat1 = extractor(wav1)  # (1, feat, T)
        feat2 = extractor(wav2)
        T = min(feat1.shape[-1], feat2.shape[-1])
        if T == 0:
            raise SystemExit("No frames produced — is the input audio empty?")
        x = torch.cat([feat1[:, :, :T], feat2[:, :, :T]], dim=0).unsqueeze(0)  # (1, 2, feat, T)
        with torch.no_grad():
            out = model.infer(x.to(device))

    T = out.shape[1]
    duration = max(len(wav1), len(wav2)) / target_sr
    print(f"Audio: {duration:.1f}s, {T} frames @ {frame_rate} Hz")

    probs = torch.softmax(out[0], dim=-1).cpu().numpy()  # (T, output_size)
    frame_times = np.arange(T) / frame_rate

    fig_width = max(28, int(duration * 0.2))  # ~0.2 inches per second, at least 28 inches
    fig, (ax_wav, ax_pred) = plt.subplots(
        2, 1, figsize=(fig_width, 6),
        gridspec_kw={"hspace": 0.08, "height_ratios": [1, 1]},
    )

    # Each speaker gets its own time axis so files of different lengths still plot.
    for wav, color, label in (
        (wav1, "steelblue", "Speaker 1"),
        (wav2, "darkorange", "Speaker 2"),
    ):
        ax_wav.plot(np.arange(len(wav)) / target_sr, wav,
                    linewidth=0.3, color=color, alpha=0.7, label=label)
    ax_wav.set_ylabel("Amplitude")
    ax_wav.set_xlim(0, duration)
    ax_wav.legend(loc="upper right", fontsize=8)
    ax_wav.set_xticklabels([])

    for i, name in enumerate(CLASS_NAMES):
        ax_pred.plot(frame_times, probs[:, i], label=name, linewidth=1.0)
    ax_pred.set_ylabel("Softmax probability")
    ax_pred.set_xlabel("Time (s)")
    ax_pred.set_xlim(frame_times[0], frame_times[-1])
    ax_pred.set_ylim(0, 1)
    ax_pred.legend(loc="upper right", fontsize=8)

    plt.savefig(args.out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved: {args.out}")