twangodev's picture
feat: add Reconstruct button and optimize codec processing logic
458a0e7 verified
"""DAC (Descript Audio Codec) — wraps the descript-audio-codec package."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
from compare_codec import CodecConfig, register
# Select the torch device once at import time, preferring CUDA, then Apple
# MPS, and falling back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    _device = torch.device("cuda")
elif torch.backends.mps.is_available():
    _device = torch.device("mps")
else:
    _device = torch.device("cpu")
class DACCodec:
    """DAC codec with lazy model loading.

    Models are downloaded and loaded the first time a given model type is
    requested and are then cached on the instance, keyed by model type.
    """

    def __init__(self) -> None:
        # Cache of loaded DAC models, keyed by model type (e.g. "44khz").
        self._models: dict[str, object] = {}

    @property
    def name(self) -> str:
        """Return the display name of this codec."""
        return "DAC"

    @property
    def sample_rate(self) -> int:
        """Return the codec's nominal sample rate in Hz.

        NOTE(review): ``encode_decode`` returns audio at the sample rate of
        the selected config (44.1, 24 or 16 kHz), which differs from this
        value for the 24khz/16khz configs -- confirm how callers use it.
        """
        return 44_100

    def configs(self) -> list[CodecConfig]:
        """Return every selectable (model type, quantizer count) combination.

        For each model type, one config is emitted for the maximum quantizer
        count followed by 6, 4 and 2 quantizers, in that order.
        """
        model_specs = (
            ("44khz", 44_100, 9),
            ("24khz", 24_000, 9),
            ("16khz", 16_000, 9),
        )
        return [
            CodecConfig(
                name=f"{model_type} / {nq} quantizers",
                params={
                    "model_type": model_type,
                    "n_quantizers": nq,
                    "sample_rate": sr,
                },
            )
            for model_type, sr, max_nq in model_specs
            for nq in (max_nq, 6, 4, 2)
        ]

    def _get_model(self, model_type: str) -> object:
        """Return the DAC model for *model_type*, loading it on first use.

        The ``dac`` package is imported lazily so that importing this module
        does not require it until a model is actually needed.
        """
        if model_type in self._models:
            return self._models[model_type]

        import dac as _dac

        model_path = _dac.utils.download(model_type=model_type)
        model = _dac.DAC.load(model_path)
        model.eval().to(_device)
        self._models[model_type] = model
        return model

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip an audio file through DAC and return the reconstruction.

        The file is loaded, averaged down to a single channel if it has more
        than one, resampled to the config's sample rate, encoded with the
        configured number of quantizers and decoded again.

        Args:
            audio_path: Path of the audio file to reconstruct.
            config: Config whose ``params`` provide ``model_type``,
                ``n_quantizers`` and ``sample_rate``.

        Returns:
            The reconstructed waveform as a NumPy array (after squeezing the
            batch and channel dimensions), at the config's sample rate and
            trimmed to the length of the resampled input.
        """
        from audiotools import AudioSignal

        params = config.params
        model_type: str = params["model_type"]
        n_quantizers: int = params["n_quantizers"]
        target_sr: int = params["sample_rate"]

        model = self._get_model(model_type)

        signal = AudioSignal(str(audio_path))
        if signal.audio_data.shape[1] > 1:
            # Downmix by averaging over the channel dimension.
            signal.audio_data = signal.audio_data.mean(dim=1, keepdim=True)
        if signal.sample_rate != target_sr:
            signal = signal.resample(target_sr)
        signal = signal.to(model.device)

        # Remember the un-padded length: DAC's preprocess pads the waveform
        # to a multiple of the model hop length, so the decoded output would
        # otherwise be longer than the input. The reference DAC.forward
        # applies the same trim.
        n_samples = signal.audio_data.shape[-1]

        x = model.preprocess(signal.audio_data, signal.sample_rate)
        z, *_ = model.encode(x, n_quantizers=n_quantizers)
        y = model.decode(z)[..., :n_samples]
        return y.squeeze(0).squeeze(0).cpu().numpy()
# Pass a DACCodec instance to compare_codec's register() when this module is
# imported.
register(DACCodec())