from dataclasses import dataclass

from .supported_models import currently_supported_models, standard_models, experimental_models


def _display_name_from_path(model_path: str) -> str:
    """Return the last path component of ``model_path`` for use as a display name.

    Forward and backward slashes are both treated as separators and trailing
    separators are ignored, so ``'models/esm2/'`` yields ``'esm2'`` instead of an
    empty string. If nothing remains (e.g. ``'/'``), the original string is returned.
    """
    normalized = model_path.replace('\\', '/').rstrip('/')
    return normalized.split('/')[-1] or model_path


@dataclass
class BaseModelArguments:
    """Resolve which base models to run and how each one is dispatched.

    Two modes:

    * Path mode (``model_paths`` given): ``model_types`` is required and must have
      the same length; display names are the last component of each path.
    * Preset mode: ``model_names`` is used as-is, except that a first entry equal to
      ``'standard'`` selects ``standard_models`` and a first entry containing
      ``'exp'`` (case-insensitive) selects ``experimental_models``.

    NOTE: ``__init__`` is hand-written and no fields are annotated at class level,
    so ``@dataclass`` generates no fields; it is kept only for backward compatibility.
    """

    def __init__(self, model_names: list[str] = None, model_paths: list[str] = None,
                 model_types: list[str] = None, model_dtype=None, **kwargs):
        # **kwargs is accepted and ignored (presumably so a full CLI namespace can be
        # splatted in — confirm with callers).
        if model_paths is not None:
            assert model_types is not None, "model_types is required when model_paths is provided."
            assert len(model_paths) == len(model_types), f"model_paths ({len(model_paths)}) and model_types ({len(model_types)}) must have the same length."
            self.model_names = [_display_name_from_path(p) for p in model_paths]
            self._model_types = list(model_types)
            self._model_paths = list(model_paths)
        else:
            # Rejects None and also an empty list (which would otherwise fail on model_names[0]).
            assert model_names, "Either model_names or model_paths/model_types must be provided."
            if model_names[0] == 'standard':
                # Copy so mutating self.model_names cannot alter the shared module-level list.
                self.model_names = list(standard_models)
            elif 'exp' in model_names[0].lower():
                self.model_names = list(experimental_models)
            else:
                self.model_names = model_names
            self._model_types = None
            self._model_paths = None
        self.model_dtype = model_dtype

    def model_entries(self):
        """Yields (display_name, dispatch_type, model_path) tuples for each model.

        In preset mode: dispatch_type is the preset name, model_path is None.
        In path mode: dispatch_type is the model type keyword, model_path is the explicit path.
        """
        if self._model_paths is not None:
            for name, mtype, mpath in zip(self.model_names, self._model_types, self._model_paths):
                yield name, mtype, mpath
        else:
            for name in self.model_names:
                yield name, name, None


def get_base_model(model_name: str, masked_lm: bool = False, dtype=None, model_path: str = None):
    """Build a base model by matching keywords in ``model_name`` (case-insensitive).

    Keywords are tested in a fixed order and the first match wins, so order matters
    (e.g. ``'dplm2'`` must be tested before ``'dplm'``). Backend modules are imported
    lazily inside each branch.

    Args:
        model_name: Preset name or model-type keyword used for dispatch.
        masked_lm: Forwarded to the backend builder.
        dtype: Forwarded to the backend builder.
        model_path: Optional explicit path; required for ``custom`` models.

    Returns:
        Whatever the selected ``build_*`` function returns (the ``__main__`` block
        of this module unpacks it as ``(model, tokenizer)``).

    Raises:
        AssertionError: A ``custom`` model was requested without ``model_path``.
        ValueError: No keyword matched ``model_name``.
    """
    key = model_name.lower()
    # 'custom' is checked first to agree with get_tokenizer, so a name containing
    # 'custom' plus another keyword gets a matching model and tokenizer.
    if 'custom' in key:
        from .custom_model import build_custom_model
        assert model_path is not None, "model_path is required for custom models. Use --model_paths and --model_types custom."
        return build_custom_model(model_path, masked_lm=masked_lm, dtype=dtype)
    elif 'random' in key:
        from .random import build_random_model
        return build_random_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif key.count('esm2') == 1 or 'dsm' in key:
        # Names mentioning 'esm2' more than once deliberately skip this branch and fall
        # through to later keywords (presumably vec2vec-style names — TODO confirm).
        from .esm2 import build_esm2_model
        return build_esm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'esmc' in key:
        from .esmc import build_esmc_model
        return build_esmc_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'protbert' in key:
        from .protbert import build_protbert_model
        return build_protbert_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'prott5' in key:
        from .prott5 import build_prott5_model
        return build_prott5_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'ankh' in key:
        from .ankh import build_ankh_model
        return build_ankh_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'glm' in key:
        from .glm import build_glm2_model
        return build_glm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'dplm2' in key:
        from .dplm2 import build_dplm2_model
        return build_dplm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'dplm' in key:
        from .dplm import build_dplm_model
        return build_dplm_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'protclm' in key:
        from .protCLM import build_protCLM
        return build_protCLM(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'onehot' in key:
        from .one_hot import build_one_hot_model
        return build_one_hot_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'amplify' in key:
        from .amplify import build_amplify_model
        return build_amplify_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'e1' in key:
        # NOTE(review): very short substring; any unrelated name containing 'e1' that
        # reaches this point is routed to E1.
        from .e1 import build_e1_model
        return build_e1_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'vec2vec' in key:
        from .vec2vec import build_vec2vec_model
        return build_vec2vec_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    elif 'calm' in key:
        from .calm import build_calm_model
        return build_calm_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    else:
        raise ValueError(f"Model {model_name} not supported")


def get_base_model_for_training(model_name: str, tokenwise: bool = False, num_labels: int = None,
                                hybrid: bool = False, dtype=None, model_path: str = None):
    """Return a backend model for training, dispatched by keywords in ``model_name``.

    Uses the same first-match keyword ordering as ``get_base_model`` but supports fewer
    backends (no random, onehot, vec2vec or custom). ``tokenwise``, ``num_labels`` and
    ``hybrid`` are forwarded positionally to the backend's ``get_*_for_training``.

    Raises:
        ValueError: No supported keyword matched ``model_name``.
    """
    key = model_name.lower()
    if 'esm2' in key or 'dsm' in key:
        from .esm2 import get_esm2_for_training
        return get_esm2_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'esmc' in key:
        from .esmc import get_esmc_for_training
        return get_esmc_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'protbert' in key:
        from .protbert import get_protbert_for_training
        return get_protbert_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'prott5' in key:
        from .prott5 import get_prott5_for_training
        return get_prott5_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'ankh' in key:
        from .ankh import get_ankh_for_training
        return get_ankh_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'glm' in key:
        from .glm import get_glm2_for_training
        return get_glm2_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'dplm2' in key:
        from .dplm2 import get_dplm2_for_training
        return get_dplm2_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'dplm' in key:
        from .dplm import get_dplm_for_training
        return get_dplm_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'e1' in key:
        from .e1 import get_e1_for_training
        return get_e1_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'protclm' in key:
        from .protCLM import get_protCLM_for_training
        return get_protCLM_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'amplify' in key:
        from .amplify import get_amplify_for_training
        return get_amplify_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    elif 'calm' in key:
        from .calm import get_calm_for_training
        return get_calm_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
    else:
        raise ValueError(f"Model {model_name} not supported")


def get_tokenizer(model_name: str, model_path: str = None):
    """Return the tokenizer matching ``model_name`` (case-insensitive keyword dispatch).

    ``'random'`` and ``'dsm'`` names share the ESM2 tokenizer. ``custom`` models build
    their tokenizer from ``model_path``.

    Raises:
        AssertionError: A ``custom`` tokenizer was requested without ``model_path``.
        ValueError: No keyword matched ``model_name``.
    """
    key = model_name.lower()
    if 'custom' in key:
        from .custom_model import build_custom_tokenizer
        assert model_path is not None, "model_path is required for custom models. Use --model_paths and --model_types custom."
        return build_custom_tokenizer(model_path)
    if 'esm2' in key or 'random' in key or 'dsm' in key:
        from .esm2 import get_esm2_tokenizer
        return get_esm2_tokenizer(model_name, model_path=model_path)
    elif 'esmc' in key:
        from .esmc import get_esmc_tokenizer
        return get_esmc_tokenizer(model_name, model_path=model_path)
    elif 'protbert' in key:
        from .protbert import get_protbert_tokenizer
        return get_protbert_tokenizer(model_name, model_path=model_path)
    elif 'prott5' in key:
        from .prott5 import get_prott5_tokenizer
        return get_prott5_tokenizer(model_name, model_path=model_path)
    elif 'ankh' in key:
        from .ankh import get_ankh_tokenizer
        return get_ankh_tokenizer(model_name, model_path=model_path)
    elif 'glm' in key:
        from .glm import get_glm2_tokenizer
        return get_glm2_tokenizer(model_name, model_path=model_path)
    elif 'dplm2' in key:
        from .dplm2 import get_dplm2_tokenizer
        return get_dplm2_tokenizer(model_name, model_path=model_path)
    elif 'dplm' in key:
        from .dplm import get_dplm_tokenizer
        return get_dplm_tokenizer(model_name, model_path=model_path)
    elif 'e1' in key:
        from .e1 import get_e1_tokenizer
        return get_e1_tokenizer(model_name, model_path=model_path)
    elif 'protclm' in key:
        from .protCLM import get_protCLM_tokenizer
        return get_protCLM_tokenizer(model_name, model_path=model_path)
    elif 'onehot' in key:
        from .one_hot import get_one_hot_tokenizer
        return get_one_hot_tokenizer(model_name, model_path=model_path)
    elif 'amplify' in key:
        from .amplify import get_amplify_tokenizer
        return get_amplify_tokenizer(model_name, model_path=model_path)
    elif 'calm' in key:
        from .calm import get_calm_tokenizer
        return get_calm_tokenizer(model_name, model_path=model_path)
    else:
        raise ValueError(f"Model {model_name} not supported")


def _print_supported_models():
    """Print supported and standard model names, with type/size/description when available."""
    try:
        # NOTE(review): absolute import, unlike this package's relative imports; it only
        # succeeds if `resource_info` is importable from sys.path — confirm intent.
        from resource_info import model_descriptions
    except ImportError:
        print("Model descriptions file not found. Only listing model names.")
        print("\n=== Currently Supported Models ===\n")
        for model_name in currently_supported_models:
            print(f"- {model_name}")
    else:
        print("\n=== Currently Supported Models ===\n")
        # Column widths include the 'Unknown' fallback so rows without a description stay
        # aligned; default=0 avoids ValueError when a sequence is empty.
        max_name_len = max((len(name) for name in currently_supported_models), default=0)
        max_type_len = max((len(model_descriptions.get(name, {}).get('type', 'Unknown')) for name in currently_supported_models), default=0)
        max_size_len = max((len(model_descriptions.get(name, {}).get('size', 'Unknown')) for name in currently_supported_models), default=0)

        # Print header
        print(f"{'Model':<{max_name_len+2}}{'Type':<{max_type_len+2}}{'Size':<{max_size_len+2}}Description")
        print("-" * (max_name_len + max_type_len + max_size_len + 50))

        for model_name in currently_supported_models:
            if model_name in model_descriptions:
                model_info = model_descriptions[model_name]
                print(f"{model_name:<{max_name_len+2}}{model_info.get('type', 'Unknown'):<{max_type_len+2}}{model_info.get('size', 'Unknown'):<{max_size_len+2}}{model_info.get('description', 'No description available')}")
            else:
                print(f"{model_name:<{max_name_len+2}}{'Unknown':<{max_type_len+2}}{'Unknown':<{max_size_len+2}}No description available")

    print("\n=== Standard Models ===\n")
    for model_name in standard_models:
        print(f"- {model_name}")


def _download_standard_models():
    """Build every standard model (downloading it), then print a torchinfo summary on a test sequence."""
    from torchinfo import summary
    from ..utils import clear_screen

    download_args = BaseModelArguments(model_names=['standard'])
    for model_name in download_args.model_names:
        model, tokenizer = get_base_model(model_name)
        print(f'Downloaded {model_name}')
        tokenized = tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICLLLICIIVMLL', return_tensors='pt').input_ids
        summary(model, input_data=tokenized)
        clear_screen()


if __name__ == '__main__':
    # py -m src.protify.base_models.get_base_models
    import sys
    import argparse

    parser = argparse.ArgumentParser(description='Download and list supported models')
    parser.add_argument('--download', action='store_true', help='Download all standard models')
    parser.add_argument('--list', action='store_true', help='List all supported models with descriptions')
    args = parser.parse_args()

    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    if args.list:
        _print_supported_models()

    if args.download:
        ### This will download all standard models
        _download_standard_models()