import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
)
from .base_tokenizer import BaseSequenceTokenizer
presets = {
    "ProtCLM-1b": "biomap-research/proteinglm-1b-clm",
    # "ProtCLM-3b": "biomap-research/proteinglm-3b-clm",
    # "ProtCLM-7b": "biomap-research/proteinglm-7b-clm"
}
class ProtCLMTokenizerWrapper(BaseSequenceTokenizer):
    """Wraps the ProtCLM tokenizer so single sequences and batches share one call path."""

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs):
        if isinstance(sequences, str):
            sequences = [sequences]
        # Default to padded PyTorch tensors with special tokens added.
        kwargs.setdefault("return_tensors", "pt")
        kwargs.setdefault("padding", "longest")
        kwargs.setdefault("add_special_tokens", True)
        return self.tokenizer(sequences, **kwargs)
class ProtCLMForEmbedding(nn.Module):
    """Thin wrapper that exposes ProtCLM's final hidden state as per-token embeddings."""

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.plm = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> torch.Tensor:
        # Only the final hidden state is returned, so neither flag is honored.
        assert not output_attentions and not output_hidden_states, (
            "output_attentions and output_hidden_states are not supported by ProtCLMForEmbedding."
        )
        out = self.plm(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return out.last_hidden_state
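
# Hypothetical helper, not part of the upstream module: a minimal sketch of one
# common way to collapse the per-token embeddings above into a single vector per
# sequence. Assumes attention_mask marks real tokens with 1 and padding with 0.
def mean_pool_embeddings(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)  # (B, L, 1)
    summed = (last_hidden_state * mask).sum(dim=1)                   # (B, H)
    counts = mask.sum(dim=1).clamp(min=1.0)                          # (B, 1)
    return summed / counts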
def get_protCLM_tokenizer(preset: str, model_path: Optional[str] = None) -> BaseSequenceTokenizer:
    return ProtCLMTokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path or presets[preset], trust_remote_code=True)
    )
def build_protCLM(
    preset: str,
    masked_lm: bool = False,
    dtype: Optional[torch.dtype] = None,
    model_path: Optional[str] = None,
    **kwargs,
) -> Tuple[nn.Module, BaseSequenceTokenizer]:
    if masked_lm:
        raise ValueError(f"Model {preset} does not support masked language modeling")
    model_path = model_path or presets[preset]
    model = ProtCLMForEmbedding(model_path, dtype=dtype).eval()
    # Load the tokenizer from the same path as the model so custom paths stay consistent.
    tokenizer = get_protCLM_tokenizer(preset, model_path)
    return model, tokenizer
def get_protCLM_for_training(
    preset: str,
    tokenwise: bool = False,
    num_labels: Optional[int] = None,
    hybrid: bool = False,
    dtype: Optional[torch.dtype] = None,
    model_path: Optional[str] = None,
):
    model_path = model_path or presets[preset]
    if hybrid:
        # Bare backbone for hybrid setups; task heads are attached elsewhere.
        model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    elif tokenwise:
        # One label per residue, e.g. secondary-structure prediction.
        model = AutoModelForTokenClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    else:
        # One label per sequence, e.g. protein-level classification.
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    tokenizer = get_protCLM_tokenizer(preset, model_path)
    return model, tokenizer
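
# Illustrative sketch only; the function name, sequences, and labels below are
# made up for demonstration. Shows one optimizer step with the
# sequence-classification variant, assuming the head computes a cross-entropy
# loss when `labels` is passed, as standard *ForSequenceClassification models do.
def _example_training_step():
    model, tokenizer = get_protCLM_for_training("ProtCLM-1b", num_labels=2)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    batch = tokenizer(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "MEKVQYLTRSAIRRASTIE"])
    labels = torch.tensor([0, 1])
    out = model(**batch, labels=labels)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()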
if __name__ == "__main__":
    # py -m src.protify.base_models.protCLM
    model, tokenizer = build_protCLM("ProtCLM-1b")
    print(model)
    print(tokenizer)
    print(tokenizer("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL"))
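
    # Hedged demo (assumption: the checkpoint above is reachable, so this
    # downloads weights on first run): embed the same sequence and mean-pool
    # it into one vector using the helper sketched earlier.
    with torch.no_grad():
        batch = tokenizer("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL")
        hidden = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
        print(mean_pool_embeddings(hidden, batch["attention_mask"]).shape)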