File size: 2,161 Bytes
287431b
 
 
 
 
 
 
 
 
 
 
 
458a0e7
 
 
 
 
 
 
 
287431b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458a0e7
287431b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458a0e7
287431b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""SNAC (Multi-Scale Neural Audio Codec) — wraps the snac package."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
import torchaudio

from compare_codec import CodecConfig, register

# Choose the compute device once, at import time: CUDA is preferred, then MPS,
# and the CPU is the fallback when neither backend reports itself available.
_device = torch.device("cpu")
if torch.cuda.is_available():
    _device = torch.device("cuda")
elif torch.backends.mps.is_available():
    _device = torch.device("mps")

# (pretrained model id, sample rate in Hz) pairs; SNACCodec.configs() turns
# each pair into one codec config.
_MODELS = [
    ("hubertsiuzdak/snac_24khz", 24_000),
    ("hubertsiuzdak/snac_32khz", 32_000),
    ("hubertsiuzdak/snac_44khz", 44_100),
]


class SNACCodec:
    """Round-trip audio through pretrained SNAC models.

    One codec config is exposed per entry in ``_MODELS``. Model weights are
    loaded lazily, on the first request for a given model id, and are then
    kept in a per-instance cache.
    """

    def __init__(self) -> None:
        # Loaded models keyed by model id; filled on demand by _get_model().
        self._models: dict[str, torch.nn.Module] = {}

    @property
    def name(self) -> str:
        """Display name of this codec."""
        return "SNAC"

    @property
    def sample_rate(self) -> int:
        """Nominal sample rate in Hz (always 24 kHz).

        NOTE(review): encode_decode() resamples its input to the rate stored
        in the *config*, so the 32 kHz and 44.1 kHz configs work at a rate
        different from this value. Confirm which of the two the caller of
        this property relies on.
        """
        return 24_000

    def configs(self) -> list[CodecConfig]:
        """Return one config per entry in ``_MODELS``.

        Each config is labelled with the rate in whole kHz (44_100 becomes
        "44kHz") and carries "model_id" and "sample_rate" in its params.
        """
        return [
            CodecConfig(
                name=f"{sr // 1000}kHz",
                params={"model_id": model_id, "sample_rate": sr},
            )
            for model_id, sr in _MODELS
        ]

    def _get_model(self, model_id: str) -> torch.nn.Module:
        """Return the model for *model_id*, loading and caching it on first use."""
        model = self._models.get(model_id)
        if model is None:
            # Imported here so the snac package is only required once a model
            # is actually requested.
            from snac import SNAC

            # eval() makes inference mode explicit instead of relying on
            # whatever state from_pretrained() leaves the module in.
            model = SNAC.from_pretrained(model_id).eval().to(_device)
            self._models[model_id] = model
        return model

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Encode and decode one audio file with the model selected by *config*.

        The file is loaded with torchaudio, averaged down to one channel when
        it has several, and resampled to the config's "sample_rate" when its
        native rate differs, before being passed through the model.

        Args:
            audio_path: Audio file to load.
            config: Codec config whose params hold "model_id" and
                "sample_rate".

        Returns:
            The model's reconstruction as a NumPy array on the CPU, with the
            leading batch and channel axes squeezed away.

        Raises:
            KeyError: If the config params lack "model_id" or "sample_rate".
        """
        model_id: str = config.params["model_id"]
        target_sr: int = config.params["sample_rate"]

        model = self._get_model(model_id)

        wav, sr = torchaudio.load(str(audio_path))
        # Mix to mono if needed.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        # Resample if needed.
        if sr != target_sr:
            wav = torchaudio.functional.resample(wav, sr, target_sr)
        # SNAC expects (B, 1, T): add the batch axis, then move to the device.
        wav = wav.unsqueeze(0).to(_device)

        # The second return value is unused here; only the reconstruction is
        # needed.
        audio_hat, _ = model(wav)

        return audio_hat.squeeze(0).squeeze(0).cpu().numpy()


# Instantiate the codec and pass it to compare_codec.register when this module
# is imported (presumably adding it to the harness's set of codecs — confirm
# in compare_codec).
register(SNACCodec())