File size: 2,195 Bytes
714cf46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM


"""
Custom models are currently expected to load entirely through AutoModel.from_pretrained(path, trust_remote_code=True).
"""


class CustomModelForEmbedding(nn.Module):
    """Thin ``nn.Module`` wrapper that returns the backbone's last hidden state.

    The backbone is loaded with ``AutoModel.from_pretrained(model_path,
    dtype=dtype, trust_remote_code=True)`` and stored as ``self.model``.
    NOTE: ``trust_remote_code=True`` lets the checkpoint's own Python code run
    at load time, so only point this at checkpoints you trust.

    Args:
        model_path: Path or hub identifier forwarded to ``from_pretrained``.
        dtype: Optional dtype forwarded to ``from_pretrained``; ``None`` leaves
            the choice to the loader.
    """

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)
        # Re-export the backbone's tokenizer only when it provides one. The
        # attribute is intentionally left undefined otherwise, so looking it up
        # on this wrapper raises AttributeError.
        if hasattr(self.model, 'tokenizer'):
            self.tokenizer = self.model.tokenizer

    def forward(
            self,
            input_ids: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = False,
            **kwargs,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
        """Run the backbone and return its ``last_hidden_state``.

        Args:
            input_ids: Token ids passed positionally to the backbone.
            attention_mask: Optional mask passed to the backbone as a keyword.
            output_attentions: When truthy, attentions are requested from the
                backbone and returned alongside the hidden state.
            output_hidden_states: Accepted for signature compatibility; not
                used by this wrapper.
            **kwargs: Accepted and ignored, so callers may pass extra batch
                fields without error.

        Returns:
            ``last_hidden_state`` alone, or the tuple
            ``(last_hidden_state, attentions)`` when ``output_attentions`` is
            truthy.
        """
        if output_attentions:
            out = self.model(input_ids, attention_mask=attention_mask, output_attentions=output_attentions)
            return out.last_hidden_state, out.attentions
        # Default path: do not pass output_attentions at all, so backbones that
        # lack the argument still work.
        return self.model(input_ids, attention_mask=attention_mask).last_hidden_state


def build_custom_model(model_path: str, masked_lm: bool = False, dtype: Optional[torch.dtype] = None, **kwargs):
    """Load a custom model in eval mode together with a tokenizer for it.

    Args:
        model_path: Path or hub identifier of the checkpoint.
        masked_lm: When true, load with ``AutoModelForMaskedLM``; otherwise
            wrap the checkpoint in ``CustomModelForEmbedding``.
        dtype: Optional dtype forwarded to the model loader.
        **kwargs: Accepted for signature compatibility; not used here.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer is the model's own
        ``tokenizer`` attribute when it has one that is not ``None``;
        otherwise it is loaded with ``AutoTokenizer.from_pretrained``.
    """
    if masked_lm:
        model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    else:
        model = CustomModelForEmbedding(model_path, dtype=dtype).eval()
    # Look the attribute up explicitly instead of wrapping it in a bare
    # `except:`, which would also swallow KeyboardInterrupt/SystemExit.
    tokenizer = getattr(model, 'tokenizer', None)
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer


def build_custom_tokenizer(model_path: str, **kwargs):
    """Return the tokenizer for ``model_path``, loaded with remote code enabled.

    Extra keyword arguments are accepted but not used.
    """
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)


if __name__ == "__main__":
    # Manual smoke test; run with: py -m src.protify.base_models.custom_model
    loaded_model, loaded_tokenizer = build_custom_model('answerdotai/ModernBERT-base')
    for component in (loaded_model, loaded_tokenizer):
        print(component)

    # Round-trip a sample sequence through the tokenizer and print both directions.
    sample = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'
    token_ids = loaded_tokenizer.encode(sample)
    round_trip = loaded_tokenizer.decode(token_ids)
    for shown in (token_ids, round_trip):
        print(shown)