"""EnCodec (Meta) — wraps the HuggingFace transformers implementation."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
import torchaudio

from compare_codec import CodecConfig, register

# Hub identifier of the checkpoint used for both the model and its processor.
_MODEL_ID = "facebook/encodec_24khz"
# Sample rate (Hz) reported by the codec and written into every config.
_SAMPLE_RATE = 24_000
# Bandwidths offered as configurations; labelled "kbps" in the config names.
_BANDWIDTHS = [1.5, 3.0, 6.0, 12.0, 24.0]


class EnCodecCodec:
    """EnCodec 24kHz codec with lazy model loading.

    The model and processor are not loaded at construction time; they are
    fetched on the first call to :meth:`encode_decode` and then reused.
    """

    def __init__(self) -> None:
        # Both stay ``None`` until ``_load`` has completed successfully.
        self._model = None
        self._processor = None

    @property
    def name(self) -> str:
        """Return the display name of this codec."""
        return "EnCodec"

    @property
    def sample_rate(self) -> int:
        """Return the codec's sample rate in Hz."""
        return _SAMPLE_RATE

    def configs(self) -> list[CodecConfig]:
        """Return one ``CodecConfig`` per entry in ``_BANDWIDTHS``.

        Each config carries the bandwidth and the sample rate in ``params``,
        which is what :meth:`encode_decode` reads back.
        """
        return [
            CodecConfig(
                name=f"{bw:g} kbps",
                params={"bandwidth": bw, "sample_rate": _SAMPLE_RATE},
            )
            for bw in _BANDWIDTHS
        ]

    def _load(self) -> None:
        """Load the model and processor once; later calls are no-ops.

        Both objects are loaded into locals and only stored on ``self`` after
        both loads succeeded. If either load raises, the instance is left
        untouched so that the next call retries instead of finding a model
        without a processor.
        """
        if self._model is not None and self._processor is not None:
            return

        # Imported here so that importing this module does not require
        # transformers to be importable.
        from transformers import AutoProcessor, EncodecModel

        model = EncodecModel.from_pretrained(_MODEL_ID, device_map="auto")
        model.eval()
        processor = AutoProcessor.from_pretrained(_MODEL_ID)

        self._model = model
        self._processor = processor

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip an audio file through the codec and return the result.

        Args:
            audio_path: Audio file to read with ``torchaudio.load``.
            config: Configuration whose ``params`` provide ``"bandwidth"``
                and ``"sample_rate"`` (see :meth:`configs`).

        Returns:
            The decoded audio as a NumPy array on the CPU, with the first
            two dimensions of the decoder output squeezed away.
        """
        self._load()
        bandwidth: float = config.params["bandwidth"]
        target_sr: int = config.params["sample_rate"]

        wav, sr = torchaudio.load(str(audio_path))
        # Average over the first axis when it has more than one entry
        # (multi-channel input), keeping that axis with size 1.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)

        inputs = self._processor(
            raw_audio=wav.squeeze(0).numpy(),
            sampling_rate=target_sr,
            return_tensors="pt",
        )
        # Move every processor output to wherever the model was placed.
        device = self._model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        enc = self._model.encode(
            inputs["input_values"],
            inputs["padding_mask"],
            bandwidth=bandwidth,
        )
        # First element of the decode result is the reconstructed audio.
        audio_out = self._model.decode(
            enc.audio_codes,
            enc.audio_scales,
            padding_mask=inputs["padding_mask"],
        )[0]
        # NOTE(review): assumes the decoder output is (batch=1, channels=1,
        # samples), so the two squeezes leave a 1-D signal — confirm.
        return audio_out.squeeze(0).squeeze(0).cpu().numpy()


register(EnCodecCodec())