""" We use the FastPLM implementation of DPLM2. """ import sys import os import torch import torch.nn as nn from typing import List, Optional, Union, Dict _FASTPLMS = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'FastPLMs') if _FASTPLMS not in sys.path: sys.path.insert(0, _FASTPLMS) from dplm2_fastplms.modeling_dplm2 import ( DPLM2ForMaskedLM, DPLM2ForSequenceClassification, DPLM2ForTokenClassification, ) from transformers import EsmTokenizer from .base_tokenizer import BaseSequenceTokenizer presets = { "DPLM2-150": "airkingbd/dplm2_150m", "DPLM2-650": "airkingbd/dplm2_650m", "DPLM2-3B": "airkingbd/dplm2_3b", } class DPLM2TokenizerWrapper(BaseSequenceTokenizer): def __init__(self, tokenizer: EsmTokenizer): super().__init__(tokenizer) def __call__( self, sequences: Union[str, List[str]], **kwargs ) -> Dict[str, torch.Tensor]: if isinstance(sequences, str): sequences = [sequences] kwargs.setdefault("return_tensors", "pt") kwargs.setdefault("padding", "longest") kwargs.setdefault("add_special_tokens", True) tokenized = self.tokenizer(sequences, **kwargs) return tokenized class DPLM2ForEmbedding(nn.Module): def __init__(self, model_path: str, dtype: torch.dtype = None): super().__init__() self.dplm2 = DPLM2ForMaskedLM.from_pretrained(model_path, dtype=dtype) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = False, **kwargs, ) -> torch.Tensor: out = self.dplm2( input_ids=input_ids, attention_mask=attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) if output_attentions: return out.last_hidden_state, out.attentions return out.last_hidden_state def get_dplm2_tokenizer(preset: str, model_path: str = None): return DPLM2TokenizerWrapper(EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")) def build_dplm2_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs): model_path = model_path or presets[preset] if masked_lm: model = DPLM2ForMaskedLM.from_pretrained(model_path, dtype=dtype).eval() else: model = DPLM2ForEmbedding(model_path, dtype=dtype).eval() tokenizer = get_dplm2_tokenizer(preset) return model, tokenizer def get_dplm2_for_training( preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None, ): model_path = model_path or presets[preset] if hybrid: model = DPLM2ForMaskedLM.from_pretrained(model_path, dtype=dtype).eval() else: if tokenwise: model = DPLM2ForTokenClassification.from_pretrained(model_path, num_labels=num_labels, dtype=dtype).eval() else: model = DPLM2ForSequenceClassification.from_pretrained(model_path, num_labels=num_labels, dtype=dtype).eval() tokenizer = get_dplm2_tokenizer(preset) return model, tokenizer if __name__ == "__main__": # py -m src.protify.base_models.dplm2 model, tokenizer = build_dplm2_model("DPLM2-150") print(model) print(tokenizer) print(tokenizer("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL"))