Spaces:
Sleeping
Sleeping
File size: 2,068 Bytes
"""Mimi (Kyutai) — wraps the HuggingFace transformers implementation."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
import torchaudio
from compare_codec import CodecConfig, register
class MimiCodec:
    """Mimi (Kyutai) neural audio codec backed by HuggingFace ``transformers``.

    The model and its feature extractor are loaded lazily on the first call
    to :meth:`encode_decode`, so constructing the codec is cheap and does not
    import ``transformers``.
    """

    # Hub checkpoint used for both the model and its feature extractor.
    _CHECKPOINT = "kyutai/mimi"

    # Codebooks used for the "1.1 kbps" operating point:
    # 8 codebooks x 11 bits (2048-entry codebooks) x 12.5 frames/s = 1100 bps.
    # When ``num_quantizers`` is not given, ``MimiModel.encode`` uses every
    # quantizer in the checkpoint config, which would not match the label.
    # NOTE(review): figures come from the Mimi model card / transformers docs;
    # confirm against the installed transformers version.
    _NUM_QUANTIZERS = 8

    def __init__(self) -> None:
        self._model = None  # MimiModel instance, populated by _load()
        self._fe = None  # matching feature extractor, populated by _load()

    @property
    def name(self) -> str:
        """Human-readable codec name."""
        return "Mimi"

    @property
    def sample_rate(self) -> int:
        """Native sample rate of the codec in Hz."""
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Return the operating points to evaluate for this codec."""
        return [
            CodecConfig(
                name="1.1 kbps",
                params={
                    # Single source of truth: reuse the property instead of
                    # repeating the literal.
                    "sample_rate": self.sample_rate,
                    "num_quantizers": self._NUM_QUANTIZERS,
                },
            )
        ]

    def _load(self) -> None:
        """Load the model and feature extractor once, on first use.

        Both objects are built in locals and published together, so a failure
        while loading the feature extractor cannot leave the instance with a
        model but no extractor (which would make every later call skip
        loading and then fail on ``self._fe``).
        """
        if self._model is not None and self._fe is not None:
            return

        # Imported here so that importing this module stays lightweight.
        from transformers import AutoFeatureExtractor, MimiModel

        model = MimiModel.from_pretrained(self._CHECKPOINT, device_map="auto")
        model.eval()
        feature_extractor = AutoFeatureExtractor.from_pretrained(self._CHECKPOINT)

        self._model, self._fe = model, feature_extractor

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip an audio file through Mimi and return the decoded audio.

        Args:
            audio_path: Path of the audio file to load with ``torchaudio``.
            config: Operating point. ``params["sample_rate"]`` is the rate the
                audio is resampled to (defaults to :attr:`sample_rate` when
                absent); the optional ``params["num_quantizers"]`` limits the
                number of codebooks used by the encoder (``None``/absent
                keeps the model default).

        Returns:
            The decoded waveform as a NumPy array, trimmed to at most the
            number of samples of the (resampled, mono) input.
        """
        self._load()

        params = config.params
        target_sr: int = params.get("sample_rate", self.sample_rate)
        num_quantizers = params.get("num_quantizers")

        wav, sr = torchaudio.load(str(audio_path))
        # Downmix multi-channel input to mono by averaging the channels.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)
        original_len = wav.shape[-1]

        inputs = self._fe(
            raw_audio=wav.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt",
        )
        # device_map="auto" decides placement, so follow the model's device.
        device = self._model.device
        inputs = {key: value.to(device) for key, value in inputs.items()}

        enc = self._model.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            num_quantizers=num_quantizers,
        )
        audio_out = self._model.decode(enc.audio_codes, inputs["padding_mask"])[0]

        # Drop the two leading singleton axes, then trim to the original
        # length (Mimi may pad).
        audio_out = audio_out.squeeze(0).squeeze(0).cpu().numpy()[:original_len]
        return audio_out
# Module-level side effect: hand a MimiCodec instance to compare_codec's
# register() when this module is executed. The instance starts with no model
# loaded (its __init__ only sets placeholders to None).
register(MimiCodec())
|