USAD2-Small / modeling_usad2.py
vectominist's picture
Add USAD2 model
94d3a9f verified
from typing import List, Tuple
import torch
from transformers import PreTrainedModel
from .configuration_usad2 import Usad2Config
from .usad_model import UsadModel
class Usad2Model(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`UsadModel`.

    The wrapped network is stored on ``self.model``; the forward pass and the
    audio helpers are delegated to it unchanged, while a handful of read-only
    properties surface values taken from the configuration.
    """

    # Hooks read by the ``transformers`` loading/saving machinery.
    config_class = Usad2Config
    base_model_prefix = "model"
    main_input_name = "wavs"

    def __init__(self, config: Usad2Config):
        """Build the wrapped ``UsadModel`` from ``config``."""
        super().__init__(config)
        self.model = UsadModel(config)

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its output."""
        outputs = self.model(*args, **kwargs)
        return outputs

    @property
    def sample_rate(self) -> int:
        """Audio sample rate in Hz (fixed at 16000)."""
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        """Encoder output frame rate in Hz, rounded to the nearest integer.

        Computed as ``100 / config.conv_subsample_rate``.
        NOTE(review): the constant 100 is presumably the input feature frame
        rate in Hz -- confirm against the feature extractor.
        """
        subsample_rate = self.config.conv_subsample_rate
        return round(100 / subsample_rate)  # Hz

    @property
    def mel_dim(self) -> int:
        """Value of ``config.input_dim``."""
        cfg = self.config
        return cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        """Value of ``config.encoder_dim``."""
        cfg = self.config
        return cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        """Value of ``config.num_layers``."""
        cfg = self.config
        return cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Device of the first parameter yielded by ``self.parameters()``.

        Raises ``StopIteration`` if the module has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter yielded by ``self.parameters()``.

        Raises ``StopIteration`` if the module has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.dtype

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Forward the chunk size (in seconds) to the wrapped model."""
        self.model.set_audio_chunk_size(seconds)

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load one audio file through the wrapped model's ``load_audio``."""
        waveform = self.model.load_audio(audio_path)
        return waveform

    def load_audio_batch(
        self, audio_paths: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load several audio files through the wrapped model's ``load_audio_batch``.

        NOTE(review): the meaning of the two returned tensors is defined by
        ``UsadModel.load_audio_batch`` -- presumably a padded batch plus
        per-item lengths; confirm there.
        """
        batch = self.model.load_audio_batch(audio_paths)
        return batch