import torch
import torch.nn as nn
from typing import Optional
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM
"""
Custom models are currently supposed to load completely from AutoModel.from_pretrained(path, trust_remote_code=True)
"""
class CustomModelForEmbedding(nn.Module):
    """Embedding wrapper around a model loaded with ``trust_remote_code=True``.

    The wrapped checkpoint is expected to load completely from
    ``AutoModel.from_pretrained(path, trust_remote_code=True)``; this class
    exposes its last hidden states (and optionally attentions) as embeddings.
    """

    def __init__(self, model_path: str, dtype: Optional[torch.dtype] = None):
        """Load the underlying model.

        Args:
            model_path: Hub id or local path passed to ``AutoModel.from_pretrained``.
            dtype: Optional torch dtype forwarded to ``from_pretrained``
                (``None`` keeps the checkpoint's default).
        """
        super().__init__()
        # NOTE(review): trust_remote_code executes code shipped with the
        # checkpoint repo — only point this at trusted model paths.
        self.model = AutoModel.from_pretrained(model_path, dtype=dtype, trust_remote_code=True)
        # Some custom checkpoints bundle their own tokenizer on the model
        # object; surface it here so build_custom_model can pick it up.
        if hasattr(self.model, 'tokenizer'):
            self.tokenizer = self.model.tokenizer

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        # Accepted for API parity with HF forward signatures; currently ignored.
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> "torch.Tensor | tuple":
        """Return last hidden states, plus attentions when requested.

        Args:
            input_ids: Token id tensor — presumably (batch, seq); confirm with caller.
            attention_mask: Optional mask forwarded to the wrapped model.
            output_attentions: When truthy, also return the model's attentions.
            output_hidden_states: Ignored; kept for signature compatibility.
            **kwargs: Ignored; kept for signature compatibility.

        Returns:
            ``last_hidden_state``, or ``(last_hidden_state, attentions)`` when
            ``output_attentions`` is truthy.
        """
        if output_attentions:
            out = self.model(input_ids, attention_mask=attention_mask, output_attentions=output_attentions)
            return out.last_hidden_state, out.attentions
        else:
            return self.model(input_ids, attention_mask=attention_mask).last_hidden_state
def build_custom_model(model_path: str, masked_lm: bool = False, dtype: Optional[torch.dtype] = None, **kwargs):
    """Build a custom model and its matching tokenizer.

    Args:
        model_path: Hub id or local path, loaded with ``trust_remote_code=True``.
        masked_lm: When True, load an ``AutoModelForMaskedLM`` head instead of
            the embedding wrapper.
        dtype: Optional torch dtype forwarded to ``from_pretrained``.
        **kwargs: Accepted for interface parity; currently unused.

    Returns:
        Tuple of ``(model, tokenizer)`` with the model in eval mode.
    """
    if masked_lm:
        # NOTE(review): trust_remote_code executes repo code — trusted paths only.
        model = AutoModelForMaskedLM.from_pretrained(model_path, dtype=dtype, trust_remote_code=True).eval()
    else:
        model = CustomModelForEmbedding(model_path, dtype=dtype).eval()
    # Prefer a tokenizer bundled on the model; fall back to AutoTokenizer.
    # Narrowed from a bare `except:`, which would also have swallowed
    # KeyboardInterrupt/SystemExit — only a missing attribute is expected here.
    try:
        tokenizer = model.tokenizer
    except AttributeError:
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer
def build_custom_tokenizer(model_path: str, **kwargs):
    """Load the tokenizer for *model_path*, trusting remote code.

    Extra keyword arguments are accepted for interface parity and ignored.
    """
    return AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if __name__ == "__main__":
# py -m src.protify.base_models.custom_model
model, tokenizer = build_custom_model('answerdotai/ModernBERT-base')
print(model)
print(tokenizer)
seq = 'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL'
encoded = tokenizer.encode(seq)
decoded = tokenizer.decode(encoded)
print(encoded)
print(decoded)