"""SNAC (Multi-Scale Neural Audio Codec) — wraps the snac package.

Registers a single ``SNACCodec`` instance with ``compare_codec``; it exposes
one config per pretrained SNAC checkpoint listed in ``_MODELS``.
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
import torchaudio

from compare_codec import CodecConfig, register


def _pick_device() -> torch.device:
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Chosen once at import time; every loaded model and input tensor lives here.
_device = _pick_device()

# (Hugging Face repo id, sample rate in Hz that the input is resampled to).
_MODELS = [
    ("hubertsiuzdak/snac_24khz", 24_000),
    ("hubertsiuzdak/snac_32khz", 32_000),
    ("hubertsiuzdak/snac_44khz", 44_100),
]


class SNACCodec:
    """Codec adapter around SNAC; checkpoints are loaded on first use and cached."""

    def __init__(self) -> None:
        # Cache of loaded models keyed by Hugging Face repo id.
        self._models: dict[str, object] = {}

    @property
    def name(self) -> str:
        """Display name of this codec."""
        return "SNAC"

    @property
    def sample_rate(self) -> int:
        """Fixed rate reported for this codec.

        NOTE(review): this is 24 kHz regardless of config, while the 32 kHz and
        44.1 kHz configs resample the input to their own rate before encoding.
        Confirm how compare_codec uses this value for those configs.
        """
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Build one config per entry in ``_MODELS``, labelled e.g. ``"24kHz"``."""
        return [
            CodecConfig(
                name=f"{rate_hz // 1000}kHz",
                params={"model_id": repo_id, "sample_rate": rate_hz},
            )
            for repo_id, rate_hz in _MODELS
        ]

    def _get_model(self, model_id: str) -> object:
        """Return the SNAC model for *model_id*, loading it onto ``_device`` once."""
        cached = self._models.get(model_id)
        if cached is None:
            # Deferred import: the snac package is only needed when a model is used.
            from snac import SNAC

            cached = SNAC.from_pretrained(model_id).to(_device)
            self._models[model_id] = cached
        return cached

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip the audio file at *audio_path* through the configured SNAC model.

        The file is loaded with torchaudio, averaged across channels when it
        has more than one, and resampled to ``config.params["sample_rate"]``
        if its native rate differs. Then it is passed through the model. The
        reconstruction is returned as a NumPy array with the leading batch and
        channel dimensions squeezed away. It is presumably at the config's
        sample rate, which may differ from ``self.sample_rate``.
        """
        params = config.params
        repo_id: str = params["model_id"]
        rate_hz: int = params["sample_rate"]
        model = self._get_model(repo_id)

        waveform, native_rate = torchaudio.load(str(audio_path))

        # Collapse multi-channel audio to a single channel by averaging.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if native_rate != rate_hz:
            waveform = torchaudio.functional.resample(waveform, native_rate, rate_hz)

        # Add a leading batch axis: SNAC expects (B, 1, T).
        batch = waveform[None].to(_device)

        # The forward pass returns a pair; the second item is discarded.
        reconstructed, _codes = model(batch)
        return reconstructed.squeeze(0).squeeze(0).cpu().numpy()


register(SNACCodec())