compare-codec / compare_codec /encodec_codec.py
twangodev's picture
feat: add Reconstruct button and optimize codec processing logic
458a0e7 verified
"""EnCodec (Meta) — wraps the HuggingFace transformers implementation."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import torch
import torchaudio
from compare_codec import CodecConfig, register
# Target bandwidths offered as selectable configurations. EnCodecCodec.configs()
# labels each value in kbps and the chosen value is forwarded to the model's
# encode() call as ``bandwidth``.
_BANDWIDTHS = [1.5, 3.0, 6.0, 12.0, 24.0]
class EnCodecCodec:
    """EnCodec 24kHz codec with lazy model loading.

    The model and its processor are fetched through HuggingFace
    ``transformers`` the first time :meth:`encode_decode` runs, so creating
    an instance does not load any weights.
    """

    # Hub identifier shared by the model and its processor.
    _MODEL_ID = "facebook/encodec_24khz"

    def __init__(self) -> None:
        # Both stay ``None`` until ``_load`` succeeds; they are always set
        # together so the instance is never half-initialised.
        self._model = None
        self._processor = None

    @property
    def name(self) -> str:
        """Name of the codec."""
        return "EnCodec"

    @property
    def sample_rate(self) -> int:
        """Sample rate, in Hz, that this codec operates at."""
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Return one configuration per entry in ``_BANDWIDTHS``.

        Each configuration is named ``"<bw> kbps"`` and carries the
        ``bandwidth`` value plus the codec's ``sample_rate`` in its params.
        """
        return [
            CodecConfig(
                name=f"{bw:g} kbps",
                params={"bandwidth": bw, "sample_rate": self.sample_rate},
            )
            for bw in _BANDWIDTHS
        ]

    def _load(self) -> None:
        """Load the model and processor on first use.

        Both objects are loaded into locals and only published on ``self``
        once both loads have succeeded. If either load raises, the instance
        is left untouched and the next call retries from scratch instead of
        running with a model but no processor.
        """
        if self._model is not None and self._processor is not None:
            return

        # Imported lazily so that importing this module does not require
        # pulling in transformers.
        from transformers import AutoProcessor, EncodecModel

        model = EncodecModel.from_pretrained(self._MODEL_ID, device_map="auto")
        model.eval()
        processor = AutoProcessor.from_pretrained(self._MODEL_ID)

        self._model, self._processor = model, processor

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip an audio file through EnCodec and return the result.

        Args:
            audio_path: Path of the audio file to read with ``torchaudio``.
            config: Configuration whose ``params`` provide ``"bandwidth"``
                (forwarded to the model's ``encode``) and ``"sample_rate"``
                (the rate the audio is resampled to before encoding).

        Returns:
            The decoded audio as a NumPy array on the CPU, with the two
            leading dimensions squeezed away (presumably batch and channel,
            leaving a 1-D waveform -- confirm against the model output).

        Raises:
            KeyError: If ``config.params`` lacks ``"bandwidth"`` or
                ``"sample_rate"``.
        """
        self._load()
        bandwidth: float = config.params["bandwidth"]
        target_sr: int = config.params["sample_rate"]

        wav, sr = torchaudio.load(str(audio_path))
        # Collapse multi-channel audio by averaging over the first dimension
        # (torchaudio.load returns channels first by default).
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        inputs = self._processor(
            raw_audio=wav.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt",
        )
        # The processor's tensors must live on the same device as the model.
        device = self._model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        enc = self._model.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            bandwidth=bandwidth,
        )
        # decode() returns the audio values as its first element.
        audio_out = self._model.decode(
            enc.audio_codes,
            enc.audio_scales,
            padding_mask=inputs["padding_mask"],
        )[0]
        return audio_out.squeeze(0).squeeze(0).cpu().numpy()
# Register a codec instance at import time. Constructing it is cheap: the
# model and processor are only loaded on the first encode_decode() call.
register(EnCodecCodec())