| | import torch |
| | import torch.nn as nn |
| | from typing import Optional, Tuple, Union, List |
| | from transformers import ( |
| | AutoTokenizer, |
| | AutoModel, |
| | AutoModelForSequenceClassification, |
| | AutoModelForTokenClassification |
| | ) |
| | from .base_tokenizer import BaseSequenceTokenizer |
| |
|
| |
|
# Maps human-readable preset names to Hugging Face Hub model ids.
presets = {
    "ProtCLM-1b": "biomap-research/proteinglm-1b-clm",


}
| |
|
| |
|
class ProtCLMTokenizerWrapper(BaseSequenceTokenizer):
    """Tokenize protein sequences with ProtCLM-friendly defaults.

    A bare string is promoted to a one-element batch, and — unless the
    caller overrides them — calls default to PyTorch tensors, longest
    padding, and special tokens enabled.
    """

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs):
        batch = [sequences] if isinstance(sequences, str) else sequences
        # Fill in ProtCLM defaults without clobbering caller overrides.
        defaults = {
            "return_tensors": "pt",
            "padding": "longest",
            "add_special_tokens": True,
        }
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        return self.tokenizer(batch, **kwargs)
| |
|
class ProtCLMForEmbedding(nn.Module):
    """Wrap a pretrained ProtCLM backbone to expose per-token embeddings.

    The forward pass returns only ``last_hidden_state``; attention maps and
    intermediate hidden states are not produced by this wrapper.
    """

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        """Load the backbone from *model_path* (local path or HF Hub id).

        NOTE(review): the ``dtype=`` kwarg of ``from_pretrained`` is only
        accepted by recent transformers releases (older ones use
        ``torch_dtype=``) — confirm against the pinned transformers version.
        """
        super().__init__()
        self.plm = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return the backbone's last hidden state for *input_ids*.

        Args:
            input_ids: token ids produced by the ProtCLM tokenizer.
            attention_mask: optional padding mask, forwarded unchanged.
            output_attentions / output_hidden_states: not supported; kept in
                the signature for API compatibility. Requesting both at once
                raises. Extra ``**kwargs`` are accepted and ignored.

        Raises:
            ValueError: if both ``output_attentions`` and
                ``output_hidden_states`` are truthy.
        """
        # Was an `assert`, which is silently stripped under `python -O`;
        # raise explicitly so the contract survives optimized runs.
        if output_attentions and output_hidden_states:
            raise ValueError(
                "output_attentions=True and output_hidden_states=True are not supported by ProtCLMForEmbedding."
            )

        out = self.plm(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        return out.last_hidden_state
| |
|
| |
|
def get_protCLM_tokenizer(preset: str, model_path: Optional[str] = None) -> BaseSequenceTokenizer:
    """Build the wrapped ProtCLM tokenizer.

    Args:
        preset: key into ``presets``; consulted only when *model_path* is None.
        model_path: optional local path or Hub id overriding the preset.

    Returns:
        A ``ProtCLMTokenizerWrapper`` around the pretrained tokenizer.

    Raises:
        KeyError: if *model_path* is None and *preset* is not in ``presets``.
    """
    return ProtCLMTokenizerWrapper(
        AutoTokenizer.from_pretrained(model_path or presets[preset], trust_remote_code=True)
    )
| |
|
| |
|
def build_protCLM(preset: str, masked_lm: bool = False, dtype: Optional[torch.dtype] = None, model_path: Optional[str] = None, **kwargs) -> Tuple[AutoModel, BaseSequenceTokenizer]:
    """Build a ProtCLM embedding model (eval mode) and its tokenizer.

    Args:
        preset: key into ``presets``.
        masked_lm: unsupported for this causal-LM family; True raises.
        dtype: optional torch dtype forwarded to model loading.
        model_path: optional local path or Hub id overriding the preset.
        **kwargs: accepted for interface parity; ignored.

    Returns:
        ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if ``masked_lm`` is True.
    """
    if masked_lm:
        raise ValueError(f"Model {preset} does not support masked language modeling")
    model_path = model_path or presets[preset]
    model = ProtCLMForEmbedding(model_path, dtype=dtype).eval()
    # Bug fix: previously the tokenizer ignored an explicit `model_path` and
    # always loaded from the preset, so model and tokenizer could come from
    # different checkpoints. Keep them in sync.
    tokenizer = get_protCLM_tokenizer(preset, model_path)
    return model, tokenizer
| |
|
| |
|
def get_protCLM_for_training(
    preset: str,
    tokenwise: bool = False,
    num_labels: Optional[int] = None,
    hybrid: bool = False,
    dtype: Optional[torch.dtype] = None,
    model_path: Optional[str] = None,
):
    """Build a ProtCLM model head for fine-tuning, plus its tokenizer.

    Args:
        preset: key into ``presets``.
        tokenwise: if True, load a token-classification head; otherwise a
            sequence-classification head. Ignored when ``hybrid`` is True.
        num_labels: number of classification labels for the chosen head.
        hybrid: if True, load the bare backbone (no classification head).
        dtype: optional torch dtype forwarded to model loading.
        model_path: optional local path or Hub id overriding the preset.

    Returns:
        ``(model, tokenizer)`` tuple; the model is returned in eval mode.
    """
    model_path = model_path or presets[preset]
    if hybrid:
        model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    elif tokenwise:
        model = AutoModelForTokenClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    # Bug fix: previously the tokenizer ignored an explicit `model_path` and
    # always loaded from the preset; keep model and tokenizer in sync.
    tokenizer = get_protCLM_tokenizer(preset, model_path)
    return model, tokenizer
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: load the default ProtCLM checkpoint and tokenize a sample
    # protein sequence (requires network access / cached weights).
    plm, tok = build_protCLM("ProtCLM-1b")
    print(plm)
    print(tok)
    print(tok("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL"))
| |
|