compare-codec / compare_codec /snac_codec.py
twangodev's picture
feat: add Reconstruct button and optimize codec processing logic
458a0e7 verified
"""SNAC (Multi-Scale Neural Audio Codec) — wraps the snac package."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
import torchaudio
from compare_codec import CodecConfig, register
# Compute device, chosen once at import time: CUDA when available, otherwise
# Apple's MPS backend when available, otherwise the CPU.
if torch.cuda.is_available():
    _device = torch.device("cuda")
elif torch.backends.mps.is_available():
    _device = torch.device("mps")
else:
    _device = torch.device("cpu")

# Supported checkpoints as (model identifier, sample rate in Hz) pairs. The
# identifier is what gets handed to ``SNAC.from_pretrained``; the sample rate
# is the rate input audio is resampled to before it is fed to that model.
_MODELS = [
    ("hubertsiuzdak/snac_24khz", 24_000),
    ("hubertsiuzdak/snac_32khz", 32_000),
    ("hubertsiuzdak/snac_44khz", 44_100),
]
class SNACCodec:
    """SNAC codec adapter with lazy, per-checkpoint model loading.

    Every entry of the module-level ``_MODELS`` table is exposed as one
    ``CodecConfig``. The network behind a config is only loaded the first
    time that config is used and is then cached on this instance.
    """

    def __init__(self) -> None:
        # model_id -> loaded network; populated on demand by _get_model().
        self._models: dict[str, torch.nn.Module] = {}

    @property
    def name(self) -> str:
        """Display name of this codec."""
        return "SNAC"

    @property
    def sample_rate(self) -> int:
        """Nominal sample rate of the codec in Hz (always 24 000).

        NOTE(review): ``encode_decode`` returns audio at the sample rate of
        the selected config (24 000, 32 000 or 44 100 Hz), not necessarily
        at this value -- confirm how callers use this property.
        """
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Return one ``CodecConfig`` per entry in ``_MODELS``.

        Each config is named after the sample rate in whole kHz (for example
        ``"24kHz"``; 44 100 Hz becomes ``"44kHz"``) and carries the
        ``"model_id"`` and ``"sample_rate"`` params that ``encode_decode``
        reads.
        """
        return [
            CodecConfig(
                name=f"{sr // 1000}kHz",
                params={"model_id": model_id, "sample_rate": sr},
            )
            for model_id, sr in _MODELS
        ]

    def _get_model(self, model_id: str) -> torch.nn.Module:
        """Return the model for *model_id*, loading and caching it on first use."""
        if model_id not in self._models:
            # Imported lazily so the snac package is only needed once a model
            # is actually requested.
            from snac import SNAC

            # This codec only ever runs inference, so switch the network to
            # evaluation mode explicitly before moving it to the device.
            model = SNAC.from_pretrained(model_id).eval().to(_device)
            self._models[model_id] = model
        return self._models[model_id]

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip an audio file through the SNAC model selected by *config*.

        Args:
            audio_path: Audio file to load with ``torchaudio.load``.
            config: Config whose ``params`` provide ``"model_id"`` (which
                model to use) and ``"sample_rate"`` (the rate, in Hz, the
                input is resampled to before it is passed to the model).

        Returns:
            The model's reconstruction as a NumPy array on the CPU, with the
            two leading (batch and channel) axes squeezed away -- presumably
            a 1-D array of samples at ``config.params["sample_rate"]``;
            confirm against the SNAC model's output shape.
        """
        model_id: str = config.params["model_id"]
        target_sr: int = config.params["sample_rate"]
        model = self._get_model(model_id)

        wav, sr = torchaudio.load(str(audio_path))

        # Mix to mono if needed.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Resample if needed.
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        # SNAC expects (B, 1, T).
        wav = wav.unsqueeze(0).to(_device)

        # The second return value of the forward pass is not needed here.
        audio_hat, _ = model(wav)
        return audio_hat.squeeze(0).squeeze(0).cpu().numpy()
# Import-time side effect: build one SNACCodec and hand it to
# compare_codec.register (presumably adding it to that package's codec
# registry -- confirm in compare_codec).
register(SNACCodec())