Spaces:
Sleeping
Sleeping
| """Mimi (Kyutai) — wraps the HuggingFace transformers implementation.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from compare_codec import CodecConfig, register | |
class MimiCodec:
    """Mimi codec (Kyutai) with lazy model loading.

    The HuggingFace ``kyutai/mimi`` model and its feature extractor are not
    created in ``__init__``; they are loaded on the first call to
    :meth:`encode_decode`.
    """

    # Sample rate (Hz) the codec is run at; audio is resampled to this rate.
    _SAMPLE_RATE = 24_000
    # Number of RVQ codebooks for the "1.1 kbps" configuration. Per the Mimi
    # model card (12.5 Hz frame rate, 2048-entry / 11-bit codebooks):
    # 12.5 * 8 * 11 = 1100 bps. NOTE(review): without an explicit value,
    # MimiModel.encode uses every quantizer in the model config (32 for
    # kyutai/mimi per its config.json) -- confirm against the HF config.
    _NUM_QUANTIZERS_1_1_KBPS = 8

    def __init__(self) -> None:
        # Populated lazily by _load() so constructing/registering the codec
        # does not download or instantiate the model.
        self._model = None
        self._fe = None

    def name(self) -> str:
        """Return the display name of this codec."""
        return "Mimi"

    def sample_rate(self) -> int:
        """Return the sample rate (Hz) this codec operates at."""
        return self._SAMPLE_RATE

    def configs(self) -> list[CodecConfig]:
        """Return the codec configurations to evaluate.

        ``num_quantizers`` is forwarded to ``MimiModel.encode`` so the actual
        bitrate matches the configuration's label.
        """
        return [
            CodecConfig(
                name="1.1 kbps",
                params={
                    "sample_rate": self._SAMPLE_RATE,
                    "num_quantizers": self._NUM_QUANTIZERS_1_1_KBPS,
                },
            )
        ]

    def _load(self) -> None:
        """Load the model and feature extractor on first use (no-op afterwards)."""
        if self._model is not None:
            return
        from transformers import AutoFeatureExtractor, MimiModel

        # Load into locals first so a failure part-way through (e.g. the
        # feature extractor download) does not leave the instance with a model
        # but no feature extractor, which would make later calls skip loading.
        model = MimiModel.from_pretrained("kyutai/mimi", device_map="auto")
        model.eval()
        fe = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
        self._model = model
        self._fe = fe

    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip an audio file through Mimi and return the reconstruction.

        Args:
            audio_path: Audio file to read with ``torchaudio.load``.
            config: Codec configuration. ``params["sample_rate"]`` is required;
                ``params["num_quantizers"]`` is optional (absent/``None`` lets
                the model use all of its quantizers).

        Returns:
            The decoded samples as a numpy array, with the size-1 leading
            dimensions squeezed away and trimmed to at most the length of the
            (downmixed, resampled) input.
        """
        self._load()
        target_sr: int = config.params["sample_rate"]
        num_quantizers = config.params.get("num_quantizers")

        wav, sr = torchaudio.load(str(audio_path))
        if wav.shape[0] > 1:
            # Downmix multi-channel audio to a single channel.
            wav = wav.mean(dim=0, keepdim=True)
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)
        original_len = wav.shape[-1]

        inputs = self._fe(
            raw_audio=wav.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt",
        )
        device = self._model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: without no_grad the outputs track gradients (the
        # model parameters require grad), wasting memory and making the
        # ``.numpy()`` call below raise.
        with torch.no_grad():
            enc = self._model.encode(
                inputs["input_values"],
                inputs["padding_mask"],
                num_quantizers=num_quantizers,
            )
            audio_out = self._model.decode(enc.audio_codes, inputs["padding_mask"])[0]

        # Trim to original length (Mimi may pad).
        return audio_out.squeeze(0).squeeze(0).cpu().numpy()[:original_len]
# Register one MimiCodec instance via compare_codec.register at import time.
# Construction is cheap: the model is only loaded on the first encode_decode().
register(MimiCodec())