| | from dataclasses import dataclass |
| |
|
| | from .supported_models import currently_supported_models, standard_models, experimental_models |
| |
|
| |
|
class BaseModelArguments:
    """Resolved list of models to evaluate, built from presets or explicit paths.

    Two construction modes:
      * Preset mode (``model_names``): the sentinel 'standard' expands to
        ``standard_models``; a first entry containing 'exp' (case-insensitive)
        expands to ``experimental_models``; anything else is used verbatim.
      * Path mode (``model_paths`` + ``model_types``): parallel lists of explicit
        checkpoint paths and dispatch-type keywords; display names are derived
        from the last '/'-separated component of each path.

    NOTE(review): the original class carried a ``@dataclass`` decorator despite
    defining ``__init__`` by hand and declaring no fields. The decorator did not
    affect construction but installed a degenerate field-less ``__eq__`` under
    which any two instances compared equal; it has been removed.
    """

    def __init__(self, model_names: list[str] = None, model_paths: list[str] = None,
                 model_types: list[str] = None, model_dtype=None, **kwargs):
        """Initialize from either preset names or explicit paths+types.

        Extra ``**kwargs`` are accepted and silently ignored (convenience for
        passing a larger argument namespace through).
        """
        if model_paths is not None:
            # Path mode: each path needs a matching dispatch-type keyword.
            assert model_types is not None, "model_types is required when model_paths is provided."
            assert len(model_paths) == len(model_types), f"model_paths ({len(model_paths)}) and model_types ({len(model_types)}) must have the same length."
            self.model_names = [p.split('/')[-1] for p in model_paths]
            self._model_types = list(model_types)
            self._model_paths = list(model_paths)
        else:
            # Preset mode: only the FIRST entry is inspected for a preset keyword.
            assert model_names is not None, "Either model_names or model_paths/model_types must be provided."
            if model_names[0] == 'standard':
                self.model_names = standard_models
            elif 'exp' in model_names[0].lower():
                self.model_names = experimental_models
            else:
                self.model_names = model_names
            self._model_types = None
            self._model_paths = None
        self.model_dtype = model_dtype

    def model_entries(self):
        """Yields (display_name, dispatch_type, model_path) tuples for each model.

        In preset mode: dispatch_type is the preset name, model_path is None.
        In path mode: dispatch_type is the model type keyword, model_path is the explicit path.
        """
        if self._model_paths is not None:
            yield from zip(self.model_names, self._model_types, self._model_paths)
        else:
            for name in self.model_names:
                yield name, name, None
| |
|
| |
|
def get_base_model(model_name: str, masked_lm: bool = False, dtype=None, model_path: str = None):
    """Build a base model by dispatching on keyword substrings of ``model_name``.

    The first matching branch wins (so 'dplm2' is tested before 'dplm').
    Backend modules are imported lazily so unused families are never loaded.

    Args:
        model_name: Name containing a recognized family keyword (e.g. 'esm2').
        masked_lm: Whether to build the masked-LM variant.
        dtype: Optional dtype forwarded to the builder.
        model_path: Optional explicit checkpoint path; required for 'custom'.

    Raises:
        ValueError: If no family keyword matches ``model_name``.
    """
    key = model_name.lower()
    common = dict(masked_lm=masked_lm, dtype=dtype, model_path=model_path)
    if 'random' in key:
        from .random import build_random_model
        return build_random_model(model_name, **common)
    # Exactly one 'esm2' occurrence (count==1 implies presence).
    if key.count('esm2') == 1:
        from .esm2 import build_esm2_model
        return build_esm2_model(model_name, **common)
    if 'dsm' in key:
        from .esm2 import build_esm2_model
        return build_esm2_model(model_name, **common)
    if 'esmc' in key:
        from .esmc import build_esmc_model
        return build_esmc_model(model_name, **common)
    if 'protbert' in key:
        from .protbert import build_protbert_model
        return build_protbert_model(model_name, **common)
    if 'prott5' in key:
        from .prott5 import build_prott5_model
        return build_prott5_model(model_name, **common)
    if 'ankh' in key:
        from .ankh import build_ankh_model
        return build_ankh_model(model_name, **common)
    if 'glm' in key:
        from .glm import build_glm2_model
        return build_glm2_model(model_name, **common)
    if 'dplm2' in key:
        from .dplm2 import build_dplm2_model
        return build_dplm2_model(model_name, **common)
    if 'dplm' in key:
        from .dplm import build_dplm_model
        return build_dplm_model(model_name, **common)
    if 'protclm' in key:
        from .protCLM import build_protCLM
        return build_protCLM(model_name, **common)
    if 'onehot' in key:
        from .one_hot import build_one_hot_model
        return build_one_hot_model(model_name, **common)
    if 'amplify' in key:
        from .amplify import build_amplify_model
        return build_amplify_model(model_name, **common)
    if 'e1' in key:
        from .e1 import build_e1_model
        return build_e1_model(model_name, **common)
    if 'vec2vec' in key:
        from .vec2vec import build_vec2vec_model
        return build_vec2vec_model(model_name, **common)
    if 'calm' in key:
        from .calm import build_calm_model
        return build_calm_model(model_name, **common)
    if 'custom' in key:
        from .custom_model import build_custom_model
        # Custom models are loaded from an explicit path, not a name.
        assert model_path is not None, "model_path is required for custom models. Use --model_paths and --model_types custom."
        return build_custom_model(model_path, masked_lm=masked_lm, dtype=dtype)
    raise ValueError(f"Model {model_name} not supported")
| |
|
| |
|
def get_base_model_for_training(model_name: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype=None, model_path: str = None):
    """Build a trainable model by dispatching on keyword substrings of ``model_name``.

    First matching branch wins ('dplm2' is tested before 'dplm'); backend
    modules are imported lazily. ``tokenwise``/``num_labels``/``hybrid`` are
    forwarded positionally to the selected builder.

    Raises:
        ValueError: If no family keyword matches ``model_name``.
    """
    key = model_name.lower()
    head = (tokenwise, num_labels, hybrid)
    if 'esm2' in key or 'dsm' in key:
        from .esm2 import get_esm2_for_training
        return get_esm2_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'esmc' in key:
        from .esmc import get_esmc_for_training
        return get_esmc_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'protbert' in key:
        from .protbert import get_protbert_for_training
        return get_protbert_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'prott5' in key:
        from .prott5 import get_prott5_for_training
        return get_prott5_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'ankh' in key:
        from .ankh import get_ankh_for_training
        return get_ankh_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'glm' in key:
        from .glm import get_glm2_for_training
        return get_glm2_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'dplm2' in key:
        from .dplm2 import get_dplm2_for_training
        return get_dplm2_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'dplm' in key:
        from .dplm import get_dplm_for_training
        return get_dplm_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'e1' in key:
        from .e1 import get_e1_for_training
        return get_e1_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'protclm' in key:
        from .protCLM import get_protCLM_for_training
        return get_protCLM_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'amplify' in key:
        from .amplify import get_amplify_for_training
        return get_amplify_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    if 'calm' in key:
        from .calm import get_calm_for_training
        return get_calm_for_training(model_name, *head, dtype=dtype, model_path=model_path)
    raise ValueError(f"Model {model_name} not supported")
| |
|
| |
|
def get_tokenizer(model_name: str, model_path: str = None):
    """Return the tokenizer for ``model_name`` via keyword-substring dispatch.

    'custom' is checked first and loads from an explicit ``model_path``; all
    other families dispatch on the first matching keyword ('dplm2' before
    'dplm'), importing their backend module lazily.

    Raises:
        ValueError: If no family keyword matches ``model_name``.
    """
    key = model_name.lower()
    if 'custom' in key:
        from .custom_model import build_custom_tokenizer
        assert model_path is not None, "model_path is required for custom models. Use --model_paths and --model_types custom."
        return build_custom_tokenizer(model_path)
    if 'esm2' in key or 'random' in key or 'dsm' in key:
        from .esm2 import get_esm2_tokenizer
        return get_esm2_tokenizer(model_name, model_path=model_path)
    if 'esmc' in key:
        from .esmc import get_esmc_tokenizer
        return get_esmc_tokenizer(model_name, model_path=model_path)
    if 'protbert' in key:
        from .protbert import get_protbert_tokenizer
        return get_protbert_tokenizer(model_name, model_path=model_path)
    if 'prott5' in key:
        from .prott5 import get_prott5_tokenizer
        return get_prott5_tokenizer(model_name, model_path=model_path)
    if 'ankh' in key:
        from .ankh import get_ankh_tokenizer
        return get_ankh_tokenizer(model_name, model_path=model_path)
    if 'glm' in key:
        from .glm import get_glm2_tokenizer
        return get_glm2_tokenizer(model_name, model_path=model_path)
    if 'dplm2' in key:
        from .dplm2 import get_dplm2_tokenizer
        return get_dplm2_tokenizer(model_name, model_path=model_path)
    if 'dplm' in key:
        from .dplm import get_dplm_tokenizer
        return get_dplm_tokenizer(model_name, model_path=model_path)
    if 'e1' in key:
        from .e1 import get_e1_tokenizer
        return get_e1_tokenizer(model_name, model_path=model_path)
    if 'protclm' in key:
        from .protCLM import get_protCLM_tokenizer
        return get_protCLM_tokenizer(model_name, model_path=model_path)
    if 'onehot' in key:
        from .one_hot import get_one_hot_tokenizer
        return get_one_hot_tokenizer(model_name, model_path=model_path)
    if 'amplify' in key:
        from .amplify import get_amplify_tokenizer
        return get_amplify_tokenizer(model_name, model_path=model_path)
    if 'calm' in key:
        from .calm import get_calm_tokenizer
        return get_calm_tokenizer(model_name, model_path=model_path)
    raise ValueError(f"Model {model_name} not supported")
| |
|
| |
|
if __name__ == '__main__':
    # CLI utility: list supported models and/or pre-download the standard set.
    import sys
    import argparse

    parser = argparse.ArgumentParser(description='Download and list supported models')
    parser.add_argument('--download', action='store_true', help='Download all standard models')
    parser.add_argument('--list', action='store_true', help='List all supported models with descriptions')
    args = parser.parse_args()

    # No flags at all: show usage instead of silently doing nothing.
    if len(sys.argv) == 1:
        parser.print_help()
        sys.exit(1)

    if args.list:
        try:
            from resource_info import model_descriptions
            print("\n=== Currently Supported Models ===\n")

            # Column widths for aligned output. ``default=`` guards against
            # empty sequences: max() on an empty iterable raises ValueError,
            # which previously crashed when no supported model had an entry in
            # model_descriptions (or the model list was empty).
            described = [name for name in currently_supported_models if name in model_descriptions]
            max_name_len = max((len(name) for name in currently_supported_models), default=len('Model'))
            max_type_len = max((len(model_descriptions[name].get('type', 'Unknown')) for name in described), default=len('Unknown'))
            max_size_len = max((len(model_descriptions[name].get('size', 'Unknown')) for name in described), default=len('Unknown'))

            print(f"{'Model':<{max_name_len+2}}{'Type':<{max_type_len+2}}{'Size':<{max_size_len+2}}Description")
            print("-" * (max_name_len + max_type_len + max_size_len + 50))

            for model_name in currently_supported_models:
                if model_name in model_descriptions:
                    model_info = model_descriptions[model_name]
                    print(f"{model_name:<{max_name_len+2}}{model_info.get('type', 'Unknown'):<{max_type_len+2}}{model_info.get('size', 'Unknown'):<{max_size_len+2}}{model_info.get('description', 'No description available')}")
                else:
                    print(f"{model_name:<{max_name_len+2}}{'Unknown':<{max_type_len+2}}{'Unknown':<{max_size_len+2}}No description available")

            print("\n=== Standard Models ===\n")
            for model_name in standard_models:
                print(f"- {model_name}")

        except ImportError:
            # resource_info is optional; fall back to bare name listing.
            print("Model descriptions file not found. Only listing model names.")
            print("\n=== Currently Supported Models ===\n")
            for model_name in currently_supported_models:
                print(f"- {model_name}")

            print("\n=== Standard Models ===\n")
            for model_name in standard_models:
                print(f"- {model_name}")

    if args.download:
        # NOTE(review): requires third-party torchinfo, and assumes each builder
        # returns a (model, tokenizer) pair with an HF-style tokenizer — confirm
        # against the individual build_* implementations.
        from torchinfo import summary
        from ..utils import clear_screen
        download_args = BaseModelArguments(model_names=['standard'])
        for model_name in download_args.model_names:
            model, tokenizer = get_base_model(model_name)
            print(f'Downloaded {model_name}')
            tokenized = tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICLLLICIIVMLL', return_tensors='pt').input_ids
            summary(model, input_data=tokenized)
            clear_screen()
| |
|