"""Mimi (Kyutai) — wraps the HuggingFace transformers implementation."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
import torchaudio

from compare_codec import CodecConfig, register


class MimiCodec:
    """Mimi codec with lazy model loading.

    The model and feature extractor are downloaded/instantiated on first
    use (see `_load`), so constructing a `MimiCodec` is cheap.
    """

    def __init__(self) -> None:
        self._model = None  # transformers MimiModel, set lazily by _load()
        self._fe = None     # transformers feature extractor, set lazily by _load()

    @property
    def name(self) -> str:
        return "Mimi"

    @property
    def sample_rate(self) -> int:
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Return the operating points to evaluate.

        Mimi emits frames at 12.5 Hz with 11-bit codebooks, so 8 quantizers
        give 12.5 * 11 * 8 = 1100 bps ≈ 1.1 kbps. Without an explicit
        `num_quantizers`, HF Mimi encodes with its full quantizer stack and
        the advertised bitrate would be wrong.
        """
        return [
            CodecConfig(
                name="1.1 kbps",
                params={"sample_rate": 24_000, "num_quantizers": 8},
            )
        ]

    def _load(self) -> None:
        """Instantiate the model and feature extractor on first call (idempotent)."""
        if self._model is None:
            # Imported here so the module stays importable without transformers.
            from transformers import AutoFeatureExtractor, MimiModel

            self._model = MimiModel.from_pretrained("kyutai/mimi", device_map="auto")
            self._model.eval()
            self._fe = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip the file at `audio_path` through Mimi.

        Loads the audio, downmixes to mono, resamples to the configured
        rate, encodes/decodes with the configured number of quantizers,
        and returns the reconstruction as a 1-D float numpy array trimmed
        to the original sample count.
        """
        self._load()
        target_sr: int = config.params["sample_rate"]
        # None -> model default (all quantizers); 8 -> the 1.1 kbps point.
        num_quantizers = config.params.get("num_quantizers")

        wav, sr = torchaudio.load(str(audio_path))
        if wav.shape[0] > 1:  # downmix multi-channel input to mono
            wav = wav.mean(dim=0, keepdim=True)
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)
        original_len = wav.shape[-1]

        inputs = self._fe(
            raw_audio=wav.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt",
        )
        device = self._model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        enc = self._model.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            num_quantizers=num_quantizers,
        )
        audio_out = self._model.decode(enc.audio_codes, inputs["padding_mask"])[0]
        # Trim to original length (Mimi may pad to a whole number of frames).
        return audio_out.squeeze(0).squeeze(0).cpu().numpy()[:original_len]


register(MimiCodec())