File size: 2,573 Bytes
1df078a
 
 
 
 
 
 
 
 
 
 
458a0e7
 
 
 
 
 
 
 
1df078a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458a0e7
1df078a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""DAC (Descript Audio Codec) — wraps the descript-audio-codec package."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch

from compare_codec import CodecConfig, register

# Pick the best available accelerator once at import time:
# CUDA first, then Apple-silicon MPS, then plain CPU.
if torch.cuda.is_available():
    _device = torch.device("cuda")
elif torch.backends.mps.is_available():
    _device = torch.device("mps")
else:
    _device = torch.device("cpu")


class DACCodec:
    """Wrapper around the Descript Audio Codec (DAC) with lazy model loading.

    Models are downloaded and loaded on first use, then cached per model
    type so repeated encode/decode calls reuse the same weights.
    """

    def __init__(self) -> None:
        # Cache of loaded DAC models, keyed by model type ("44khz", ...).
        self._models: dict[str, object] = {}

    @property
    def name(self) -> str:
        """Human-readable codec name."""
        return "DAC"

    @property
    def sample_rate(self) -> int:
        """Native sample rate of the default (44khz) DAC model."""
        return 44_100

    def configs(self) -> list[CodecConfig]:
        """Enumerate one config per (model variant, quantizer count) pair."""
        variants = [
            ("44khz", 44_100, 9),
            ("24khz", 24_000, 9),
            ("16khz", 16_000, 9),
        ]
        return [
            CodecConfig(
                name=f"{model_type} / {nq} quantizers",
                params={
                    "model_type": model_type,
                    "n_quantizers": nq,
                    "sample_rate": sr,
                },
            )
            for model_type, sr, max_nq in variants
            for nq in (max_nq, 6, 4, 2)
        ]

    def _get_model(self, model_type: str) -> object:
        """Return the cached DAC model for *model_type*, loading it on first use."""
        try:
            return self._models[model_type]
        except KeyError:
            # Import lazily so the module loads even when dac is absent.
            import dac as _dac

            weights_path = _dac.utils.download(model_type=model_type)
            loaded = _dac.DAC.load(weights_path)
            loaded.eval().to(_device)
            self._models[model_type] = loaded
            return loaded

    @torch.no_grad()
    def encode_decode(self, audio_path: Path, config: CodecConfig) -> np.ndarray:
        """Round-trip *audio_path* through DAC per *config*; return the waveform.

        The input is downmixed to mono and resampled to the model's rate
        before encoding. Returns a 1-D numpy array of decoded samples.
        """
        from audiotools import AudioSignal

        params = config.params
        model = self._get_model(params["model_type"])

        sig = AudioSignal(str(audio_path))
        # Downmix multichannel audio to mono, then resample — same order
        # as the original implementation.
        if sig.audio_data.shape[1] > 1:
            sig.audio_data = sig.audio_data.mean(dim=1, keepdim=True)
        target_sr: int = params["sample_rate"]
        if sig.sample_rate != target_sr:
            sig = sig.resample(target_sr)

        sig = sig.to(model.device)
        preprocessed = model.preprocess(sig.audio_data, sig.sample_rate)
        z, _codes, _latents, _, _ = model.encode(
            preprocessed, n_quantizers=params["n_quantizers"]
        )
        decoded = model.decode(z)

        return decoded.squeeze(0).squeeze(0).cpu().numpy()


register(DACCodec())