from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM
| |
|
| |
|
"""
Custom models are expected to load fully via
AutoModel.from_pretrained(path, trust_remote_code=True); this module adds no
model-specific code of its own.
"""
| |
|
| |
|
class CustomModelForEmbedding(nn.Module):
    """Wrap a remote-code ``AutoModel`` so it returns its final hidden states.

    The wrapped model is loaded with
    ``AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)``.
    If the loaded model has a ``tokenizer`` attribute, it is re-exposed as
    ``self.tokenizer``. Otherwise the attribute is left undefined, so reading
    it raises ``AttributeError``, which callers can use to fall back to loading
    a tokenizer separately.
    """

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)
        # Deliberately only set when present: absence of the attribute is the
        # signal that a tokenizer must be loaded some other way.
        if hasattr(self.model, 'tokenizer'):
            self.tokenizer = self.model.tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> Union[torch.Tensor, Tuple]:
        """Run the wrapped model and return its ``last_hidden_state``.

        Args:
            input_ids: Token ids forwarded positionally to the wrapped model.
            attention_mask: Optional mask forwarded as ``attention_mask``.
            output_attentions: When truthy, also request and return the
                model's ``attentions`` output.
            output_hidden_states: Accepted for interface compatibility;
                currently ignored.
            **kwargs: Accepted for interface compatibility; currently ignored.

        Returns:
            ``last_hidden_state`` by default, or the tuple
            ``(last_hidden_state, attentions)`` when ``output_attentions`` is
            truthy. The structure of ``attentions`` is whatever the wrapped
            model returns (presumably per-layer attention tensors; not
            verified here).
        """
        if output_attentions:
            out = self.model(input_ids, attention_mask=attention_mask, output_attentions=output_attentions)
            return out.last_hidden_state, out.attentions
        return self.model(input_ids, attention_mask=attention_mask).last_hidden_state
| |
|
| |
|
def build_custom_model(model_path: str, masked_lm: bool = False, dtype: Optional[torch.dtype] = None, **kwargs):
    """Load a remote-code model in eval mode together with a tokenizer.

    Args:
        model_path: Hub id or local path passed to ``from_pretrained``.
        masked_lm: If True, load with ``AutoModelForMaskedLM``. Otherwise wrap
            ``AutoModel`` in ``CustomModelForEmbedding``.
        dtype: Optional dtype forwarded to ``from_pretrained``.
        **kwargs: Accepted for call-site compatibility; currently ignored.

    Returns:
        ``(model, tokenizer)``. The model has had ``.eval()`` called on it. The
        tokenizer is the model's own ``tokenizer`` attribute when it has one;
        otherwise it comes from
        ``AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)``.
    """
    if masked_lm:
        model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    else:
        model = CustomModelForEmbedding(model_path, dtype=dtype).eval()
    try:
        tokenizer = model.tokenizer
    except AttributeError:
        # nn.Module reports a missing attribute with AttributeError. Catching
        # only that (instead of a bare except) keeps KeyboardInterrupt and
        # genuine errors from being silently turned into a tokenizer reload.
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
| |
|
| |
|
def build_custom_tokenizer(model_path: str, **kwargs):
    """Return ``AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)``.

    Remote code is trusted so that checkpoints shipping their own tokenizer
    class can load it. Extra ``**kwargs`` are accepted but not used.
    """
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
| |
|
| |
|
if __name__ == "__main__":
    # Manual smoke test: load a public checkpoint, print it, and round-trip
    # a sample sequence through the tokenizer.
    demo_model, demo_tokenizer = build_custom_model('answerdotai/ModernBERT-base')
    for obj in (demo_model, demo_tokenizer):
        print(obj)
    sample_seq = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'
    token_ids = demo_tokenizer.encode(sample_seq)
    round_trip = demo_tokenizer.decode(token_ids)
    print(token_ids)
    print(round_trip)
| |
|