File size: 3,529 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List
from transformers import (
    AutoTokenizer,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification
)
from .base_tokenizer import BaseSequenceTokenizer


# Maps human-readable preset names to Hugging Face Hub repository paths.
# Only the 1b checkpoint is enabled; larger checkpoints exist but are
# disabled by default (uncomment to allow them).
presets = {
    "ProtCLM-1b": "biomap-research/proteinglm-1b-clm",
    #"ProtCLM-3b": "biomap-research/proteinglm-3b-clm",
    #"ProtCLM-7b": "biomap-research/proteinglm-7b-clm"
}


class ProtCLMTokenizerWrapper(BaseSequenceTokenizer):
    """Adapter that gives the ProtCLM Hugging Face tokenizer batch-friendly defaults.

    A single string is promoted to a one-element batch, and padding /
    tensor-return defaults are filled in unless the caller overrides them.
    """

    def __init__(self, tokenizer: AutoTokenizer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs):
        """Tokenize one sequence or a list of sequences.

        Defaults (overridable via **kwargs): PyTorch tensors, pad to the
        longest sequence in the batch, and add special tokens.
        """
        batch = [sequences] if isinstance(sequences, str) else sequences
        defaults = {
            "return_tensors": "pt",
            "padding": "longest",
            "add_special_tokens": True,
        }
        for option, value in defaults.items():
            kwargs.setdefault(option, value)
        return self.tokenizer(batch, **kwargs)

class ProtCLMForEmbedding(nn.Module):
    """Wraps a pretrained ProtCLM model and exposes per-token embeddings.

    The forward pass returns only ``last_hidden_state``; attention maps and
    intermediate hidden states are never propagated to the caller.
    """

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        """Load the base model.

        Args:
            model_path: Hub repo id or local path of the checkpoint.
            dtype: optional torch dtype for the loaded weights.
        """
        super().__init__()
        # trust_remote_code is required because ProtCLM ships custom modeling code.
        self.plm = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Return the final hidden states for ``input_ids``.

        Raises:
            ValueError: if both output_attentions and output_hidden_states
                are requested (this wrapper discards both, so requesting
                them together is almost certainly a caller error).
        """
        # Fix: the original used `assert` for input validation, which is
        # silently stripped when Python runs with -O. Raise explicitly instead.
        if output_attentions and output_hidden_states:
            raise ValueError(
                "output_attentions=True and output_hidden_states=True are not supported by ProtCLMForEmbedding."
            )

        out = self.plm(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        return out.last_hidden_state


def get_protCLM_tokenizer(preset: str, model_path: str = None) -> BaseSequenceTokenizer:
    """Load the ProtCLM tokenizer for *preset* (or an explicit *model_path*).

    An explicit model_path takes precedence over the preset's hub path.
    """
    source = model_path or presets[preset]
    hf_tokenizer = AutoTokenizer.from_pretrained(source, trust_remote_code=True)
    return ProtCLMTokenizerWrapper(hf_tokenizer)


def build_protCLM(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs) -> Tuple[AutoModel, BaseSequenceTokenizer]:
    """Build a ProtCLM embedding model and its matching tokenizer.

    Args:
        preset: key into ``presets`` (e.g. "ProtCLM-1b").
        masked_lm: must be False — ProtCLM is a causal LM.
        dtype: optional torch dtype for the model weights.
        model_path: optional explicit checkpoint path; overrides the preset.
        **kwargs: accepted for interface parity with sibling builders; unused.

    Returns:
        (model, tokenizer) with the model in eval mode.

    Raises:
        ValueError: if masked_lm is requested.
    """
    if masked_lm:
        raise ValueError(f"Model {preset} does not support masked language modeling")
    model_path = model_path or presets[preset]
    model = ProtCLMForEmbedding(model_path, dtype=dtype).eval()
    # Fix: forward the resolved model_path so a custom checkpoint gets a
    # matching tokenizer (previously the tokenizer always came from the preset).
    tokenizer = get_protCLM_tokenizer(preset, model_path=model_path)
    return model, tokenizer


def get_protCLM_for_training(
    preset: str,
    tokenwise: bool = False,
    num_labels: int = None,
    hybrid: bool = False,
    dtype: torch.dtype = None,
    model_path: str = None,
    ):
    """Load ProtCLM configured for downstream training.

    Args:
        preset: key into ``presets``.
        tokenwise: if True, attach a token-classification head; otherwise a
            sequence-classification head (ignored when hybrid=True).
        num_labels: number of output labels for the classification head.
        hybrid: if True, return the bare base model (no task head).
        dtype: optional torch dtype for the weights.
        model_path: optional explicit checkpoint path; overrides the preset.

    Returns:
        (model, tokenizer) with the model in eval mode.
    """
    model_path = model_path or presets[preset]
    if hybrid:
        # Bare backbone, e.g. for probing with an external head.
        model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    elif tokenwise:
        model = AutoModelForTokenClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_path, num_labels=num_labels, dtype=dtype, trust_remote_code=True
        ).eval()
    # Fix: forward the resolved model_path so a custom checkpoint gets a
    # matching tokenizer (previously the tokenizer always came from the preset).
    tokenizer = get_protCLM_tokenizer(preset, model_path=model_path)
    return model, tokenizer


if __name__ == "__main__":
    # Manual smoke test; run with: py -m src.protify.base_models.protCLM
    demo_model, demo_tokenizer = build_protCLM("ProtCLM-1b")
    for obj in (demo_model, demo_tokenizer):
        print(obj)
    # Tokenize a short protein sequence to verify the wrapper end-to-end.
    print(demo_tokenizer("MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL"))