| import torch |
| import torch.nn as nn |
| from huggingface_hub import hf_hub_download |
| from moshi.models import loaders as moshi_loaders |
|
|
| |
# Default hyper-parameters for MimiLSTM; unpacked as keyword arguments into
# its constructor (MimiLSTM(**MODEL_DEFAULTS)).
MODEL_DEFAULTS = dict(
    input_size=512,       # in_features of the per-frame embedding Linear
    hidden_size=512,      # hidden size of each speaker's LSTM
    num_layers=2,         # stacked LSTM layers per speaker
    output_size=5,        # logits per frame (CLASS_NAMES lists five labels)
    dropout=0.1,          # forwarded to nn.LSTM as its dropout argument
    bidirectional=False,  # forwarded to nn.LSTM
    project=128,          # embedding width, i.e. the LSTM input size
    num_project=1,        # not a named MimiLSTM parameter; absorbed by its **_
)
|
|
|
|
# Default audio settings; unpacked as keyword arguments into
# AudioFeatureExtractor and also read directly by the evaluation script.
AUDIO_DEFAULTS = dict(
    sr=24000,            # sample rate the script resamples input audio to
    num_quantizers=8,    # passed to the Mimi loader as num_codebooks
    frame_rate_hz=25.0,  # used by the script to convert frame index -> seconds
)
|
|
|
|
class MimiLSTM(nn.Module):
    """Two-stream LSTM over Mimi embeddings.

    Each speaker's frames go through a shared Linear+ReLU embedding and then
    through that speaker's own LSTM. The two LSTM outputs are concatenated
    per frame and mapped to ``output_size`` logits by a final Linear layer.

    Input x: (batch, 2, feat_size, T) — speaker_1 at index 0, speaker_2 at index 1.
    Output: (batch, T, output_size)
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size,
                 dropout, bidirectional, project, **_):
        """Build the embedding, the two per-speaker LSTMs and the classifier.

        Args:
            input_size: Feature size of one input frame.
            hidden_size: Hidden size of each LSTM.
            num_layers: Number of stacked layers in each LSTM.
            output_size: Number of logits produced per frame.
            dropout: Dropout value forwarded to ``nn.LSTM``.
            bidirectional: Whether the LSTMs are bidirectional.
            project: Width of the frame embedding (the LSTM input size).
            **_: Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        self.mel_embed = nn.Sequential(
            nn.Linear(input_size, project),
            nn.ReLU(),
        )
        lstm_kwargs = dict(
            input_size=project,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )
        self.model1 = nn.LSTM(**lstm_kwargs)
        self.model2 = nn.LSTM(**lstm_kwargs)
        # A bidirectional LSTM emits 2 * hidden_size features per frame, and
        # two streams are concatenated, so the classifier input must scale
        # with the number of directions (unchanged for the unidirectional case).
        num_directions = 2 if bidirectional else 1
        self.linear = nn.Linear(hidden_size * num_directions * 2, output_size)

    def init_hidden(self, batch_size, device):
        """Return zero-initialised ``(h, c)`` LSTM state tensors on ``device``.

        Each tensor has shape
        ``(num_layers * num_directions, batch_size, hidden_size)``.
        """
        num_directions = 2 if self.model1.bidirectional else 1
        shape = (
            self.model1.num_layers * num_directions,
            batch_size,
            self.model1.hidden_size,
        )
        # Allocate directly on the target device instead of CPU-then-copy.
        h = torch.zeros(shape, device=device)
        c = torch.zeros(shape, device=device)
        return h, c

    def infer(self, x):
        """Full-sequence inference. x: (1, 2, feat, T).

        Returns logits of shape (batch, T, output_size).
        """
        x = x.permute(0, 1, 3, 2)  # -> (batch, 2, T, feat)
        x = self.mel_embed(x)
        # Both streams start from the same zero state.
        h, c = self.init_hidden(x.size(0), x.device)
        x1, _ = self.model1(x[:, 0], (h, c))
        x2, _ = self.model2(x[:, 1], (h, c))
        return self.linear(torch.cat([x1, x2], dim=-1))

    def infer_ar_step(self, feat1, feat2, h1, c1, h2, c2):
        """Single-frame AR step.

        feat1, feat2: (1, feat_size) — one frame per speaker.
        h1, c1 / h2, c2: LSTM state for speaker 1 / speaker 2, e.g. from
        ``init_hidden`` or from the previous call.
        Returns: logits (1, output_size), updated (h1, c1, h2, c2).
        """
        # unsqueeze(1) adds the length-1 time axis expected by batch_first LSTMs.
        f1 = self.mel_embed(feat1).unsqueeze(1)
        f2 = self.mel_embed(feat2).unsqueeze(1)
        out1, (h1, c1) = self.model1(f1, (h1, c1))
        out2, (h2, c2) = self.model2(f2, (h2, c2))
        logits = self.linear(torch.cat([out1, out2], dim=-1)).squeeze(1)
        return logits, h1, c1, h2, c2

    def infer_ar(self, x):
        """Frame-by-frame AR inference. x: (1, 2, feat, T).

        Equivalent to infer() for a unidirectional LSTM but simulates real-time decoding.
        Returns logits of shape (batch, T, output_size).
        """
        x = x.permute(0, 1, 3, 2)  # -> (batch, 2, T, feat)
        x = self.mel_embed(x)
        T = x.size(2)
        h1, c1 = self.init_hidden(x.size(0), x.device)
        h2, c2 = self.init_hidden(x.size(0), x.device)
        outputs = []
        for t in range(T):
            # Slice with t:t+1 to keep the time axis for the LSTM call.
            f1 = x[:, 0, t:t + 1, :]
            f2 = x[:, 1, t:t + 1, :]
            out1, (h1, c1) = self.model1(f1, (h1, c1))
            out2, (h2, c2) = self.model2(f2, (h2, c2))
            outputs.append(self.linear(torch.cat([out1, out2], dim=-1)))
        return torch.cat(outputs, dim=1)
|
|
|
|
class AudioFeatureExtractor:
    """Extracts Mimi embeddings from raw waveform using the moshi library."""

    def __init__(self, sr, num_quantizers, device="cuda", **_):
        """Download the Mimi weights and build the model in eval mode.

        Args:
            sr: Sample rate of the audio the caller will pass in; stored on
                the instance as ``self.sr``.
            num_quantizers: Forwarded to the Mimi loader as ``num_codebooks``.
            device: Device the Mimi model and input tensors are placed on.
            **_: Extra keyword arguments are accepted and ignored.
        """
        mimi_weight = hf_hub_download(moshi_loaders.DEFAULT_REPO, moshi_loaders.MIMI_NAME)
        self.mimi = moshi_loaders.get_mimi(mimi_weight, device=device, num_codebooks=num_quantizers)
        self.mimi.eval()
        self.sr = sr
        self.device = device

    def _to_tensor(self, wav):
        """1D numpy array → (1, 1, T) tensor on device."""
        return torch.from_numpy(wav).float().to(self.device).unsqueeze(0).unsqueeze(0)

    @torch.no_grad()
    def __call__(self, wav):
        """Full-sequence. wav: 1D numpy array. Returns (1, feat_size, T) at 25Hz."""
        x = self._to_tensor(wav)
        emb = self.mimi.encode_to_latent(x, quantize=True)
        return self.mimi.upsample(emb)

    @torch.no_grad()
    def stream(self, wav1, wav2):
        """Stream both speakers batched together, one Mimi frame at a time.

        The chunk length is ``self.mimi.frame_size`` samples. Both waveforms
        (1D numpy arrays) are zero-padded to the same whole number of chunks,
        so inputs of different lengths are handled: the shorter one is
        extended with silence rather than truncating the longer one or
        failing when the per-speaker chunks are stacked.

        Yields (1, feat_size, frames) for speaker1 and speaker2 per chunk.
        Both speakers share a single ``streaming(batch_size=2)`` context, so
        the streaming context is never entered twice.
        """
        chunk_samples = self.mimi.frame_size

        w1 = torch.from_numpy(wav1).float().to(self.device)
        w2 = torch.from_numpy(wav2).float().to(self.device)

        # Ceil-divide the longer waveform into chunks, then right-pad both
        # signals with zeros up to that common length.
        n_chunks = (max(len(w1), len(w2)) + chunk_samples - 1) // chunk_samples
        total = n_chunks * chunk_samples
        w1 = torch.nn.functional.pad(w1, (0, total - len(w1)))
        w2 = torch.nn.functional.pad(w2, (0, total - len(w2)))

        with self.mimi.streaming(batch_size=2):
            for i in range(n_chunks):
                s = i * chunk_samples
                # (2, 1, chunk_samples): batch axis = speaker, one audio channel.
                chunk = torch.stack([w1[s:s + chunk_samples], w2[s:s + chunk_samples]]).unsqueeze(1)
                emb = self.mimi.encode_to_latent(chunk, quantize=True)
                emb = self.mimi.upsample(emb)
                yield emb[0:1], emb[1:2]
|
|
|
|
def load_model(checkpoint_path, device="cpu"):
    """Build a MimiLSTM with MODEL_DEFAULTS and restore weights from a checkpoint.

    The checkpoint must be a mapping that stores the weights under the
    "model_state_dict" key. The model is returned in eval mode on ``device``.
    """
    net = MimiLSTM(**MODEL_DEFAULTS)
    # NOTE(review): torch.load deserializes the file (pickle-based unless the
    # installed PyTorch restricts it to weights only) — load trusted files only.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    weights = checkpoint["model_state_dict"]
    net.load_state_dict(weights)
    net.eval()
    net = net.to(device)
    return net
|
|
|
|
# Labels for the model's output logits, in index order; the evaluation script
# uses them as legend entries for the per-class probability curves.
CLASS_NAMES = [
    "bos",
    "system_end",
    "user_end",
    "system",
    "user",
]
|
|
|
|
if __name__ == "__main__":
    # Evaluation script: loads two speaker recordings from --sample_dir, runs
    # the model (full-sequence, or streaming + step-by-step with --ar) and
    # saves a figure with the waveforms and the per-class softmax curves.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default=None)
    parser.add_argument("--sample_dir", required=True)
    parser.add_argument("--ar", action="store_true", help="Streaming Mimi (chunk-by-chunk) + AR LSTM step-by-step.")
    parser.add_argument("--out", default="eval_output.png")
    args = parser.parse_args()

    # Script-only dependencies, each imported exactly once.
    from pathlib import Path

    import matplotlib.pyplot as plt
    import numpy as np
    import soundfile as sf
    import torchaudio

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.checkpoint:
        model = load_model(args.checkpoint, device=device)
        print(f"Loaded: {args.checkpoint}")
    else:
        model = MimiLSTM(**MODEL_DEFAULTS).to(device).eval()
        print("No checkpoint — using random weights.")

    extractor = AudioFeatureExtractor(**AUDIO_DEFAULTS, device=device)
    target_sr = AUDIO_DEFAULTS["sr"]
    d = Path(args.sample_dir)

    # Read both speakers and resample to the target rate when needed.
    wav1, sr1 = sf.read(d / "speaker_1_audio.wav", dtype="float32")
    wav2, sr2 = sf.read(d / "speaker_2_audio.wav", dtype="float32")
    if sr1 != target_sr:
        wav1 = torchaudio.functional.resample(torch.from_numpy(wav1), sr1, target_sr).numpy()
    if sr2 != target_sr:
        wav2 = torchaudio.functional.resample(torch.from_numpy(wav2), sr2, target_sr).numpy()

    if args.ar:
        print("Running streaming Mimi + AR LSTM...")
        from tqdm import tqdm

        h1, c1 = model.init_hidden(1, device)
        h2, c2 = model.init_hidden(1, device)
        all_logits = []
        # Progress-bar estimate only: chunk size comes from the Mimi model
        # instead of a hard-coded 1920.
        frame_size = extractor.mimi.frame_size
        n_chunks = (max(len(wav1), len(wav2)) + frame_size - 1) // frame_size
        with torch.no_grad():
            for feat1_chunk, feat2_chunk in tqdm(
                extractor.stream(wav1, wav2),
                total=n_chunks, unit="chunk", desc="Streaming"
            ):
                feat1_chunk, feat2_chunk = feat1_chunk.to(device), feat2_chunk.to(device)
                # Feed the LSTM one feature frame at a time, carrying state.
                for t in range(feat1_chunk.shape[-1]):
                    logits, h1, c1, h2, c2 = model.infer_ar_step(
                        feat1_chunk[:, :, t], feat2_chunk[:, :, t], h1, c1, h2, c2
                    )
                    all_logits.append(logits)
        out = torch.stack(all_logits, dim=1)
    else:
        feat1 = extractor(wav1)
        feat2 = extractor(wav2)
        # Trim both feature sequences to the shorter one before stacking.
        T = min(feat1.shape[-1], feat2.shape[-1])
        x = torch.cat([feat1[:, :, :T], feat2[:, :, :T]], dim=0).unsqueeze(0)
        with torch.no_grad():
            out = model.infer(x.to(device))

    T = out.shape[1]
    print(f"Audio: {len(wav1)/target_sr:.1f}s, {T} frames @ {AUDIO_DEFAULTS['frame_rate_hz']} Hz")
    probs = torch.softmax(out[0], dim=-1).cpu().numpy()

    frame_times = np.arange(T) / AUDIO_DEFAULTS["frame_rate_hz"]
    # Each waveform gets its own time axis so recordings of different lengths
    # can be plotted without a shape mismatch.
    wav1_times = np.arange(len(wav1)) / target_sr
    wav2_times = np.arange(len(wav2)) / target_sr
    longest_times = wav1_times if len(wav1) >= len(wav2) else wav2_times

    duration = max(len(wav1), len(wav2)) / target_sr
    fig_width = max(28, int(duration * 0.2))
    fig, (ax_wav, ax_pred) = plt.subplots(
        2, 1, figsize=(fig_width, 6),
        gridspec_kw={"hspace": 0.08, "height_ratios": [1, 1]},
    )

    ax_wav.plot(wav1_times, wav1, linewidth=0.3, color="steelblue", alpha=0.7, label="Speaker 1")
    ax_wav.plot(wav2_times, wav2, linewidth=0.3, color="darkorange", alpha=0.7, label="Speaker 2")
    ax_wav.set_ylabel("Amplitude")
    ax_wav.set_xlim(longest_times[0], longest_times[-1])
    ax_wav.legend(loc="upper right", fontsize=8)
    ax_wav.set_xticklabels([])

    for i, name in enumerate(CLASS_NAMES):
        ax_pred.plot(frame_times, probs[:, i], label=name, linewidth=1.0)
    ax_pred.set_ylabel("Softmax probability")
    ax_pred.set_xlabel("Time (s)")
    ax_pred.set_xlim(frame_times[0], frame_times[-1])
    ax_pred.set_ylim(0, 1)
    ax_pred.legend(loc="upper right", fontsize=8)

    plt.savefig(args.out, dpi=150, bbox_inches="tight")
    print(f"Saved: {args.out}")
|
|