"""Hugging Face ``PreTrainedModel`` wrapper around ``UsadModel``."""

from typing import List, Tuple

import torch
from transformers import PreTrainedModel

from .configuration_usad2 import Usad2Config
from .usad_model import UsadModel

# Fixed audio sample rate reported by ``Usad2Model.sample_rate``.
_SAMPLE_RATE_HZ = 16000
# Numerator used by ``Usad2Model.encoder_frame_rate``; presumably the frame
# rate of the input features before convolutional subsampling — TODO confirm.
_PRE_SUBSAMPLE_FRAME_RATE_HZ = 100


class Usad2Model(PreTrainedModel):
    """Expose a ``UsadModel`` through the ``transformers`` model API.

    The wrapped network lives in ``self.model`` (matching
    ``base_model_prefix``). Calls, audio loading and chunk-size configuration
    are forwarded to it. The descriptive properties read either from
    ``self.config`` or from the model's registered parameters.
    """

    config_class = Usad2Config
    base_model_prefix = "model"
    main_input_name = "wavs"

    def __init__(self, config: Usad2Config):
        super().__init__(config)
        self.model = UsadModel(config)

    def forward(self, *args, **kwargs):
        """Run the wrapped model, passing every argument through unchanged."""
        # Invoke the submodule itself (not ``.forward``) so nn.Module hooks fire.
        return self.model(*args, **kwargs)

    # ------------------------------------------------------------------
    # Fixed and config-derived metadata
    # ------------------------------------------------------------------

    @property
    def sample_rate(self) -> int:
        """Audio sample rate in Hz; always 16000."""
        return _SAMPLE_RATE_HZ

    @property
    def encoder_frame_rate(self) -> int:
        """Encoder frame rate in Hz, rounded to the nearest integer.

        Computed as 100 divided by ``config.conv_subsample_rate``.
        """
        subsample = self.config.conv_subsample_rate
        rate = _PRE_SUBSAMPLE_FRAME_RATE_HZ / subsample
        return round(rate)

    @property
    def mel_dim(self) -> int:
        """Input feature dimension, read from ``config.input_dim``."""
        return self.config.input_dim

    @property
    def encoder_dim(self) -> int:
        """Encoder width, read from ``config.encoder_dim``."""
        return self.config.encoder_dim

    @property
    def num_layers(self) -> int:
        """Layer count, read from ``config.num_layers``."""
        return self.config.num_layers

    # ------------------------------------------------------------------
    # Placement (overrides the base-class properties of the same name)
    # ------------------------------------------------------------------

    @property
    def device(self) -> torch.device:
        """Device of the first registered parameter.

        Raises ``StopIteration`` if the model has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first registered parameter.

        Raises ``StopIteration`` if the model has no parameters.
        """
        first_param = next(self.parameters())
        return first_param.dtype

    # ------------------------------------------------------------------
    # Audio helpers, delegated to the wrapped ``UsadModel``
    # ------------------------------------------------------------------

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Forward the chunk length, in seconds, to the wrapped model."""
        self.model.set_audio_chunk_size(seconds)

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load one audio file via ``UsadModel.load_audio``."""
        return self.model.load_audio(audio_path)

    def load_audio_batch(
        self, audio_paths: List[str]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load several audio files via ``UsadModel.load_audio_batch``.

        Returns a pair of tensors. Presumably these are (padded audio,
        lengths); confirm against ``UsadModel.load_audio_batch``.
        """
        return self.model.load_audio_batch(audio_paths)