nikraf's picture
Upload folder using huggingface_hub
714cf46 verified
from dataclasses import dataclass
from .supported_models import currently_supported_models, standard_models, experimental_models
@dataclass
class BaseModelArguments:
def __init__(self, model_names: list[str] = None, model_paths: list[str] = None, model_types: list[str] = None, model_dtype=None, **kwargs):
if model_paths is not None:
assert model_types is not None, "model_types is required when model_paths is provided."
assert len(model_paths) == len(model_types), f"model_paths ({len(model_paths)}) and model_types ({len(model_types)}) must have the same length."
self.model_names = [p.split('/')[-1] for p in model_paths]
self._model_types = list(model_types)
self._model_paths = list(model_paths)
else:
assert model_names is not None, "Either model_names or model_paths/model_types must be provided."
if model_names[0] == 'standard':
self.model_names = standard_models
elif 'exp' in model_names[0].lower():
self.model_names = experimental_models
else:
self.model_names = model_names
self._model_types = None
self._model_paths = None
self.model_dtype = model_dtype
def model_entries(self):
"""Yields (display_name, dispatch_type, model_path) tuples for each model.
In preset mode: dispatch_type is the preset name, model_path is None.
In path mode: dispatch_type is the model type keyword, model_path is the explicit path.
"""
if self._model_paths is not None:
for name, mtype, mpath in zip(self.model_names, self._model_types, self._model_paths):
yield name, mtype, mpath
else:
for name in self.model_names:
yield name, name, None
def get_base_model(model_name: str, masked_lm: bool = False, dtype=None, model_path: str = None):
if 'random' in model_name.lower():
from .random import build_random_model
return build_random_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'esm2' in model_name.lower() and model_name.lower().count('esm2') == 1:
from .esm2 import build_esm2_model
return build_esm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'dsm' in model_name.lower():
from .esm2 import build_esm2_model
return build_esm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'esmc' in model_name.lower():
from .esmc import build_esmc_model
return build_esmc_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'protbert' in model_name.lower():
from .protbert import build_protbert_model
return build_protbert_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'prott5' in model_name.lower():
from .prott5 import build_prott5_model
return build_prott5_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'ankh' in model_name.lower():
from .ankh import build_ankh_model
return build_ankh_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'glm' in model_name.lower():
from .glm import build_glm2_model
return build_glm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'dplm2' in model_name.lower():
from .dplm2 import build_dplm2_model
return build_dplm2_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'dplm' in model_name.lower():
from .dplm import build_dplm_model
return build_dplm_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'protclm' in model_name.lower():
from .protCLM import build_protCLM
return build_protCLM(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'onehot' in model_name.lower():
from .one_hot import build_one_hot_model
return build_one_hot_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'amplify' in model_name.lower():
from .amplify import build_amplify_model
return build_amplify_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'e1' in model_name.lower():
from .e1 import build_e1_model
return build_e1_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'vec2vec' in model_name.lower():
from .vec2vec import build_vec2vec_model
return build_vec2vec_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'calm' in model_name.lower():
from .calm import build_calm_model
return build_calm_model(model_name, masked_lm=masked_lm, dtype=dtype, model_path=model_path)
elif 'custom' in model_name.lower():
from .custom_model import build_custom_model
assert model_path is not None, "model_path is required for custom models. Use --model_paths and --model_types custom."
return build_custom_model(model_path, masked_lm=masked_lm, dtype=dtype)
else:
raise ValueError(f"Model {model_name} not supported")
def get_base_model_for_training(model_name: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype=None, model_path: str = None):
if 'esm2' in model_name.lower() or 'dsm' in model_name.lower():
from .esm2 import get_esm2_for_training
return get_esm2_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'esmc' in model_name.lower():
from .esmc import get_esmc_for_training
return get_esmc_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'protbert' in model_name.lower():
from .protbert import get_protbert_for_training
return get_protbert_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'prott5' in model_name.lower():
from .prott5 import get_prott5_for_training
return get_prott5_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'ankh' in model_name.lower():
from .ankh import get_ankh_for_training
return get_ankh_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'glm' in model_name.lower():
from .glm import get_glm2_for_training
return get_glm2_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'dplm2' in model_name.lower():
from .dplm2 import get_dplm2_for_training
return get_dplm2_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'dplm' in model_name.lower():
from .dplm import get_dplm_for_training
return get_dplm_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'e1' in model_name.lower():
from .e1 import get_e1_for_training
return get_e1_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'protclm' in model_name.lower():
from .protCLM import get_protCLM_for_training
return get_protCLM_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'amplify' in model_name.lower():
from .amplify import get_amplify_for_training
return get_amplify_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
elif 'calm' in model_name.lower():
from .calm import get_calm_for_training
return get_calm_for_training(model_name, tokenwise, num_labels, hybrid, dtype=dtype, model_path=model_path)
else:
raise ValueError(f"Model {model_name} not supported")
def get_tokenizer(model_name: str, model_path: str = None):
if 'custom' in model_name.lower():
from .custom_model import build_custom_tokenizer
assert model_path is not None, "model_path is required for custom models. Use --model_paths and --model_types custom."
return build_custom_tokenizer(model_path)
if 'esm2' in model_name.lower() or 'random' in model_name.lower() or 'dsm' in model_name.lower():
from .esm2 import get_esm2_tokenizer
return get_esm2_tokenizer(model_name, model_path=model_path)
elif 'esmc' in model_name.lower():
from .esmc import get_esmc_tokenizer
return get_esmc_tokenizer(model_name, model_path=model_path)
elif 'protbert' in model_name.lower():
from .protbert import get_protbert_tokenizer
return get_protbert_tokenizer(model_name, model_path=model_path)
elif 'prott5' in model_name.lower():
from .prott5 import get_prott5_tokenizer
return get_prott5_tokenizer(model_name, model_path=model_path)
elif 'ankh' in model_name.lower():
from .ankh import get_ankh_tokenizer
return get_ankh_tokenizer(model_name, model_path=model_path)
elif 'glm' in model_name.lower():
from .glm import get_glm2_tokenizer
return get_glm2_tokenizer(model_name, model_path=model_path)
elif 'dplm2' in model_name.lower():
from .dplm2 import get_dplm2_tokenizer
return get_dplm2_tokenizer(model_name, model_path=model_path)
elif 'dplm' in model_name.lower():
from .dplm import get_dplm_tokenizer
return get_dplm_tokenizer(model_name, model_path=model_path)
elif 'e1' in model_name.lower():
from .e1 import get_e1_tokenizer
return get_e1_tokenizer(model_name, model_path=model_path)
elif 'protclm' in model_name.lower():
from .protCLM import get_protCLM_tokenizer
return get_protCLM_tokenizer(model_name, model_path=model_path)
elif 'onehot' in model_name.lower():
from .one_hot import get_one_hot_tokenizer
return get_one_hot_tokenizer(model_name, model_path=model_path)
elif 'amplify' in model_name.lower():
from .amplify import get_amplify_tokenizer
return get_amplify_tokenizer(model_name, model_path=model_path)
elif 'calm' in model_name.lower():
from .calm import get_calm_tokenizer
return get_calm_tokenizer(model_name, model_path=model_path)
else:
raise ValueError(f"Model {model_name} not supported")
if __name__ == '__main__':
# py -m src.protify.base_models.get_base_models
import sys
import argparse
parser = argparse.ArgumentParser(description='Download and list supported models')
parser.add_argument('--download', action='store_true', help='Download all standard models')
parser.add_argument('--list', action='store_true', help='List all supported models with descriptions')
args = parser.parse_args()
if len(sys.argv) == 1:
parser.print_help()
sys.exit(1)
if args.list:
try:
from resource_info import model_descriptions
print("\n=== Currently Supported Models ===\n")
max_name_len = max(len(name) for name in currently_supported_models)
max_type_len = max(len(model_descriptions.get(name, {}).get('type', 'Unknown')) for name in currently_supported_models if name in model_descriptions)
max_size_len = max(len(model_descriptions.get(name, {}).get('size', 'Unknown')) for name in currently_supported_models if name in model_descriptions)
# Print header
print(f"{'Model':<{max_name_len+2}}{'Type':<{max_type_len+2}}{'Size':<{max_size_len+2}}Description")
print("-" * (max_name_len + max_type_len + max_size_len + 50))
for model_name in currently_supported_models:
if model_name in model_descriptions:
model_info = model_descriptions[model_name]
print(f"{model_name:<{max_name_len+2}}{model_info.get('type', 'Unknown'):<{max_type_len+2}}{model_info.get('size', 'Unknown'):<{max_size_len+2}}{model_info.get('description', 'No description available')}")
else:
print(f"{model_name:<{max_name_len+2}}{'Unknown':<{max_type_len+2}}{'Unknown':<{max_size_len+2}}No description available")
print("\n=== Standard Models ===\n")
for model_name in standard_models:
print(f"- {model_name}")
except ImportError:
print("Model descriptions file not found. Only listing model names.")
print("\n=== Currently Supported Models ===\n")
for model_name in currently_supported_models:
print(f"- {model_name}")
print("\n=== Standard Models ===\n")
for model_name in standard_models:
print(f"- {model_name}")
if args.download:
### This will download all standard models
from torchinfo import summary
from ..utils import clear_screen
download_args = BaseModelArguments(model_names=['standard'])
for model_name in download_args.model_names:
model, tokenizer = get_base_model(model_name)
print(f'Downloaded {model_name}')
tokenized = tokenizer('MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICLLLICIIVMLL', return_tensors='pt').input_ids
summary(model, input_data=tokenized)
clear_screen()