Spaces:
Sleeping
Sleeping
| """SNAC (Multi-Scale Neural Audio Codec) — wraps the snac package.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from compare_codec import CodecConfig, register | |
| _device = torch.device( | |
| "cuda" | |
| if torch.cuda.is_available() | |
| else "mps" | |
| if torch.backends.mps.is_available() | |
| else "cpu" | |
| ) | |
| _MODELS = [ | |
| ("hubertsiuzdak/snac_24khz", 24_000), | |
| ("hubertsiuzdak/snac_32khz", 32_000), | |
| ("hubertsiuzdak/snac_44khz", 44_100), | |
| ] | |
class SNACCodec:
    """SNAC codec adapter with lazy, per-checkpoint model loading.

    One configuration is exposed per entry in ``_MODELS``. Each checkpoint is
    loaded on first use and then cached on the instance.
    """

    def __init__(self) -> None:
        # model_id -> loaded SNAC model, filled in lazily by _get_model().
        self._models: dict[str, object] = {}

    def name(self) -> str:
        """Return the display name of this codec."""
        return "SNAC"

    def sample_rate(self) -> int:
        """Return the codec's reported sample rate in Hz.

        NOTE(review): this is fixed at 24 kHz, but ``encode_decode`` resamples
        input to each config's own ``sample_rate`` (24 / 32 / 44.1 kHz). Confirm
        how compare_codec uses this value for the 32 kHz and 44 kHz configs.
        """
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Return one configuration per checkpoint listed in ``_MODELS``.

        Each config is named by its sample rate in whole kHz (e.g. "44kHz") and
        carries the ``model_id`` to load plus the ``sample_rate`` that input
        audio is resampled to before encoding.
        """
        return [
            CodecConfig(
                name=f"{sr // 1000}kHz",
                params={"model_id": model_id, "sample_rate": sr},
            )
            for model_id, sr in _MODELS
        ]

    def _get_model(self, model_id: str) -> object:
        """Return the SNAC model for *model_id*, loading and caching it on first use."""
        if model_id not in self._models:
            # Imported here rather than at module top, deferring the snac
            # dependency until a model is actually requested.
            from snac import SNAC

            # eval() so any train-only layer behaviour is disabled for inference.
            model = SNAC.from_pretrained(model_id).eval().to(_device)
            self._models[model_id] = model
        return self._models[model_id]

    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip *audio_path* through the SNAC checkpoint chosen by *config*.

        The file is mixed down to mono and resampled to
        ``config.params["sample_rate"]``. The result is passed through the
        model, and the reconstruction is returned as a NumPy array with the
        batch and channel axes squeezed away.
        """
        model_id: str = config.params["model_id"]
        target_sr: int = config.params["sample_rate"]
        model = self._get_model(model_id)

        wav, sr = torchaudio.load(str(audio_path))  # channels on dim 0

        # Mix to mono if needed.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Resample if needed.
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        # SNAC expects (B, 1, T).
        wav = wav.unsqueeze(0).to(_device)

        # Pure inference. Without this, autograd records the whole forward
        # pass (wasting memory), and the output tensor requires grad, which
        # makes the .numpy() call below raise a RuntimeError.
        with torch.inference_mode():
            audio_hat, _ = model(wav)

        return audio_hat.squeeze(0).squeeze(0).cpu().numpy()
# Import-time side effect: hand one shared SNACCodec instance to
# compare_codec.register.
register(SNACCodec())