# Hugging Face upload metadata (kept as comments so the module parses):
# nikraf's picture
# Upload folder using huggingface_hub
# 714cf46 verified
"""
We use the FastPLM implementation of E1.
"""
import sys
import os
import torch
import torch.nn as nn
from typing import Optional, Union, List, Dict, Tuple
_FASTPLMS = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'FastPLMs')
if _FASTPLMS not in sys.path:
sys.path.insert(0, _FASTPLMS)
from e1_fastplms.modeling_e1 import (
E1Model,
E1ForMaskedLM,
E1ForSequenceClassification,
E1ForTokenClassification,
)
from .base_tokenizer import BaseSequenceTokenizer
from .e1_utils import E1BatchPreparer
# Short preset names -> Hugging Face repo ids for the Synthyra/Profluent E1 family.
presets = {
    f'E1-{size}': f'Synthyra/Profluent-E1-{size}M'
    for size in ('150', '300', '600')
}
class E1TokenizerWrapper(BaseSequenceTokenizer):
    """Adapter exposing an ``E1BatchPreparer`` through the shared tokenizer call API."""

    def __init__(self, tokenizer: E1BatchPreparer):
        super().__init__(tokenizer)

    def __call__(self, sequences: Union[str, List[str]], **kwargs) -> Dict[str, torch.Tensor]:
        # Promote a lone sequence to a singleton batch. Extra keyword arguments
        # are accepted for interface compatibility but not forwarded — the batch
        # preparer only takes the sequence list.
        batch = [sequences] if isinstance(sequences, str) else sequences
        return self.tokenizer.get_batch_kwargs(batch)
class E1ForEmbedding(nn.Module):
    """Thin wrapper around ``E1Model`` that yields sequence embeddings.

    ``forward`` returns the backbone's last hidden state, or a
    ``(last_hidden_state, attentions)`` pair when ``output_attentions`` is
    truthy. NOTE(review): ``output_hidden_states`` is accepted for signature
    compatibility but never honored — hidden states are not returned on
    either path.
    """

    def __init__(self, model_path: str, dtype: torch.dtype = None):
        super().__init__()
        # Backbone without any task head; weights loaded from `model_path`.
        self.e1 = E1Model.from_pretrained(model_path, dtype=dtype)

    def forward(
        self,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, ...]]]:
        # Guard clause: plain embedding path returns a bare tensor.
        if not output_attentions:
            out = self.e1(**kwargs, output_hidden_states=False, output_attentions=False)
            return out.last_hidden_state
        out = self.e1(**kwargs, output_attentions=output_attentions)
        return out.last_hidden_state, out.attentions
def get_e1_tokenizer(preset: str, model_path: str = None):
    """Return the E1 tokenizer wrapper.

    ``preset`` and ``model_path`` are accepted for signature parity with other
    model families but are unused — E1 uses one preset-independent batch preparer.
    """
    return E1TokenizerWrapper(E1BatchPreparer())
def build_e1_model(preset: str, masked_lm: bool = False, dtype: torch.dtype = None, model_path: str = None, **kwargs):
    """Load an E1 model for inference plus its tokenizer.

    ``model_path`` overrides the preset's Hugging Face repo id when given.
    ``masked_lm`` selects the masked-LM head instead of the plain embedding
    wrapper. Extra ``**kwargs`` are accepted and ignored. Returns
    ``(model, tokenizer)`` with the model in eval mode.
    """
    target = model_path or presets[preset]
    if masked_lm:
        model = E1ForMaskedLM.from_pretrained(target, dtype=dtype)
    else:
        model = E1ForEmbedding(target, dtype=dtype)
    return model.eval(), get_e1_tokenizer(preset)
def get_e1_for_training(preset: str, tokenwise: bool = False, num_labels: int = None, hybrid: bool = False, dtype: torch.dtype = None, model_path: str = None):
    """Load an E1 variant suitable for fine-tuning plus its tokenizer.

    ``hybrid`` returns the bare backbone; otherwise ``tokenwise`` chooses a
    token-classification head over a sequence-classification head, each built
    with ``num_labels`` outputs. ``model_path`` overrides the preset repo id.
    Returns ``(model, tokenizer)`` with the model in eval mode.
    """
    target = model_path or presets[preset]
    if hybrid:
        model = E1Model.from_pretrained(target, dtype=dtype)
    elif tokenwise:
        model = E1ForTokenClassification.from_pretrained(target, num_labels=num_labels, dtype=dtype)
    else:
        model = E1ForSequenceClassification.from_pretrained(target, num_labels=num_labels, dtype=dtype)
    return model.eval(), get_e1_tokenizer(preset)
if __name__ == '__main__':
    # py -m base_models.e1
    # Smoke test: load the smallest preset and tokenize a duplicated demo sequence.
    demo = [
        'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL',
        'MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICBBOLLICIIVMLL',
    ]
    model, tokenizer = build_e1_model('E1-150')
    print(model)
    print(tokenizer)
    print(tokenizer(demo))