compare-codec / tests /test_snac_codec.py
twangodev's picture
feat: implement SNAC codec and integrate into the codec registry
287431b verified
"""Tests for the SNAC codec wrapper."""
from __future__ import annotations
from pathlib import Path
import numpy as np
import pytest
import torch
import torchaudio
@pytest.fixture()
def wav_file(tmp_path: Path) -> Path:
    """Write a short, reproducible mono WAV at 24 kHz and return its path.

    The signal is 2 seconds of Gaussian noise drawn from a fixed-seed local
    generator, so every run feeds the codec the same input. The noise is
    scaled down and clamped into [-1, 1] (the valid range for float audio),
    so it is not clipped or distorted if the backend writes an integer PCM
    encoding.
    """
    sr = 24_000
    duration_s = 2
    # Local generator: deterministic input without mutating torch's global RNG.
    gen = torch.Generator().manual_seed(0)
    # Shape (channels=1, frames) -- torchaudio.save expects channels-first by default.
    samples = 0.1 * torch.randn(1, sr * duration_s, generator=gen)
    # Hard guarantee of the valid amplitude range rather than relying on the scale.
    samples = samples.clamp(-1.0, 1.0)
    path = tmp_path / "test.wav"
    torchaudio.save(str(path), samples, sr)
    return path
def test_snac_name():
    """SNACCodec reports its display name as "SNAC"."""
    from compare_codec.snac_codec import SNACCodec as Codec

    assert Codec().name == "SNAC"
def test_snac_sample_rate():
    """SNACCodec.sample_rate is 24 000 Hz."""
    from compare_codec.snac_codec import SNACCodec as Codec

    expected_sr = 24_000
    assert Codec().sample_rate == expected_sr
def test_snac_configs_not_empty():
    """SNACCodec.configs() returns at least three configurations."""
    from compare_codec.snac_codec import SNACCodec as Codec

    # Lower bound of three: the intent is at least one config per model variant.
    min_expected = 3
    assert len(Codec().configs()) >= min_expected
def test_snac_configs_have_sample_rate():
    """Every SNAC config exposes ``sample_rate`` and ``model_id`` in its params."""
    from compare_codec.snac_codec import SNACCodec

    configs = SNACCodec().configs()
    # Guard against a vacuous pass: with an empty list the loop below never runs.
    assert configs, "SNACCodec.configs() returned no configs"
    for cfg in configs:
        assert "sample_rate" in cfg.params, f"missing sample_rate in {cfg.params!r}"
        assert "model_id" in cfg.params, f"missing model_id in {cfg.params!r}"
def test_snac_encode_decode_returns_float32_array(wav_file: Path):
    """Round-tripping a WAV through the 24 kHz SNAC config yields 1-D float32 audio."""
    from compare_codec.snac_codec import SNACCodec

    codec = SNACCodec()
    # next() with a default turns "no matching config" into a readable failure
    # instead of a bare IndexError from indexing an empty list.
    cfg = next(
        (c for c in codec.configs() if c.params["sample_rate"] == 24_000),
        None,
    )
    assert cfg is not None, "no SNAC config with sample_rate == 24_000"

    result = codec.encode_decode(wav_file, cfg)

    assert isinstance(result, np.ndarray)
    assert result.dtype == np.float32
    assert result.ndim == 1
    assert len(result) > 0
    # A reconstruction full of NaN/inf would still pass the shape/dtype checks.
    assert np.all(np.isfinite(result)), "decoded audio contains NaN or inf"