Res2TCNGuard / evaluate.py
korallll's picture
Add model code (_net.py, evaluate.py, res2tcnguard.py); fix README usage; precise params
f2beec2 verified
"""Standalone evaluation for Res2TCNGuard.
The network definition lives in ``_net.py`` (in this repo). This script loads
the pretrained checkpoint ``best_1.495.pth`` and scores audio, returning a
bona-fide score where **higher = more bona fide**.
Dependencies: torch and numpy (plus soundfile and scipy for the command-line
demo below).

Usage::

    python evaluate.py path/to/audio.wav
"""
from __future__ import annotations
import numpy as np
import torch
from _net import TestModel
CUT = 64600  # fixed input length the classifier head requires
SAMPLE_RATE = 16000  # model operates on 16 kHz mono audio


def pad_fixed(x: np.ndarray, max_len: int = CUT) -> np.ndarray:
    """Return a deterministic fixed-length window of ``x``.

    The waveform is flattened to 1-D float32. If it has at least ``max_len``
    samples, the first ``max_len`` samples are kept; if shorter, it is
    tile-repeated and truncated to exactly ``max_len``. This is exactly the
    windowing used to produce the Arena scores (no random crop), so results
    are reproducible.

    Args:
        x: Waveform samples; any shape, flattened before windowing.
        max_len: Output length in samples (defaults to ``CUT``).

    Returns:
        A float32 array of shape ``(max_len,)``.

    Raises:
        ValueError: If ``max_len`` is negative, or if ``x`` has no samples
            while ``max_len`` is positive (an empty signal cannot be tiled).
    """
    x = np.asarray(x, dtype=np.float32).reshape(-1)
    n = x.shape[0]
    if max_len < 0:
        # A negative slice bound would silently drop samples from the end.
        raise ValueError(f"max_len must be non-negative, got {max_len}")
    if n >= max_len:
        return x[:max_len]
    if n == 0:
        # Previously this crashed with ZeroDivisionError in the repeat count.
        raise ValueError("cannot pad an empty waveform")
    reps = -(-max_len // n)  # ceil(max_len / n): enough copies to cover max_len
    # np.tile preserves the float32 dtype, so no extra astype copy is needed.
    return np.tile(x, reps)[:max_len]
def load_model(ckpt: str = "best_1.495.pth", device: str = "cpu") -> TestModel:
    """Build ``TestModel``, load checkpoint weights into it, and return it in eval mode.

    Args:
        ckpt: Path to the checkpoint file. It may hold either a bare
            state_dict or a dict that wraps one under the ``"state_dict"`` key.
        device: Device the returned model is moved to.

    Returns:
        The model, switched to eval mode and placed on ``device``.
    """
    net = TestModel()
    # Deserialize onto CPU first; the finished model is moved to ``device``
    # at the end.
    # NOTE(review): torch.load unpickles the file (unless the installed torch
    # defaults to weights_only=True); only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt, map_location="cpu")
    weights = checkpoint.get("state_dict", checkpoint)
    # strict=True: checkpoint keys must match the network's keys exactly.
    net.load_state_dict(weights, strict=True)
    net.eval()
    return net.to(device)
@torch.no_grad()
def score(model: TestModel, audio: np.ndarray, device: str = "cpu") -> float:
    """Return the bona-fide score of a single utterance (higher = bona fide).

    ``audio`` is expected to be a float32 mono 16 kHz waveform. It is windowed
    to a fixed length by ``pad_fixed``, run through ``model`` with gradients
    disabled, and the logit at index ``[0, 1]`` is returned as a Python float.
    """
    window = torch.from_numpy(pad_fixed(audio))
    batch = window.unsqueeze(0).to(device)  # prepend batch axis -> (1, 64600)
    # The model returns exactly two outputs; only the second (logits) is used.
    _unused, logits = model(batch)  # logits expected (1, 2)
    bona_fide_logit = logits[0, 1]
    return float(bona_fide_logit)
def main(argv: list[str] | None = None) -> int:
    """Score one audio file from the command line and print its bona-fide score.

    The file is downmixed to mono, converted to float32 and resampled to
    ``SAMPLE_RATE`` if needed before scoring with the default checkpoint.

    Args:
        argv: Command-line arguments *excluding* the program name; defaults
            to ``sys.argv[1:]``. Exactly one argument, the audio path, is
            expected.

    Returns:
        Process exit status: 0 on success, 2 on bad usage.
    """
    import sys

    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) != 1:
        # Fail with a usage message instead of an IndexError traceback.
        print("usage: python evaluate.py path/to/audio.wav", file=sys.stderr)
        return 2

    # Demo-only dependencies are imported lazily so importing this module
    # needs only torch + numpy.
    from math import gcd

    import soundfile as sf
    from scipy.signal import resample_poly

    audio, sr = sf.read(args[0])
    if audio.ndim == 2:  # (frames, channels) -> downmix to mono
        audio = audio.mean(axis=1)
    audio = audio.astype(np.float32)
    if sr != SAMPLE_RATE:
        # Polyphase resampling by the reduced integer ratio SAMPLE_RATE / sr.
        g = gcd(int(sr), SAMPLE_RATE)
        audio = resample_poly(audio, SAMPLE_RATE // g, int(sr) // g).astype(np.float32)

    model = load_model(device="cpu")
    print(f"bona-fide score: {score(model, audio):.6f} (higher = more bona fide)")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())