| | import torch |
| |
|
| | from pathlib import Path |
| | from typing import Dict, Any, Optional |
| |
|
| | from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier |
| |
|
| |
|
class ModelFactory:
    """
    Factory class for creating and loading models.

    Maps a model-type name to one of the classifier classes imported from
    ``src.models.models`` and optionally restores saved weights into it.
    """

    @staticmethod
    def create_model(
        model_type: str,
        model_params: Dict[str, Any],
        state_dict_path: Optional[Path] = None
    ) -> torch.nn.Module:
        """
        Create a model and optionally load saved weights into it.

        Args:
            model_type: Type of model ('Transformer', 'Mamba', 'LSTM').
            model_params: Keyword arguments forwarded unchanged to the
                selected model class's constructor.
            state_dict_path: Optional path to a saved state dictionary. When
                ``None`` (the default), the freshly constructed weights are
                kept.

        Returns:
            The constructed model, switched to evaluation mode via
            ``model.eval()``.

        Raises:
            ValueError: If model_type is unknown. The message lists the
                supported types.

            Errors raised by ``torch.load`` or ``model.load_state_dict``
            (for example, a missing file or mismatched keys) are not caught
            and propagate to the caller.
        """
        # Kept inside the method, so the classifier names are resolved at
        # call time rather than at class-definition time.
        model_classes = {
            "Transformer": TransformerClassifier,
            "Mamba": MambaClassifier,
            "LSTM": LSTMClassifier,
        }

        model_cls = model_classes.get(model_type)
        if model_cls is None:
            supported = ", ".join(sorted(model_classes))
            raise ValueError(
                f"Unknown model type: {model_type} (supported: {supported})"
            )

        model = model_cls(**model_params)

        # Explicit None check: "no checkpoint requested" should not be
        # inferred from the truthiness of the argument.
        if state_dict_path is not None:
            # NOTE(review): depending on the installed torch version,
            # torch.load may fully unpickle the file (weights_only=False),
            # which can execute arbitrary code. Only load checkpoints from
            # trusted sources, and consider passing weights_only=True once
            # the minimum supported torch version allows it.
            # map_location="cpu" lets tensors saved from another device
            # (e.g. GPU) be loaded into CPU memory.
            state_dict = torch.load(state_dict_path, map_location="cpu")
            model.load_state_dict(state_dict)

        model.eval()
        return model
| |
|