| import torch |
| import torch.nn as nn |
| from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel |
| from transformers.modeling_outputs import BaseModelOutput |
|
|
|
|
# Device type passed to torch.amp.autocast in T5TextEncoder.encode, where
# autocast is explicitly disabled for this device type.
| DEVICE_TYPE = "cuda" |
|
|
|
|
class TransformersTextEncoderBase(nn.Module):
    """Text encoder built from a Hugging Face tokenizer/model pair and a linear head.

    ``forward`` tokenizes a batch of strings, runs the model, projects the
    model's ``last_hidden_state`` with ``self.proj`` and returns it together
    with a boolean mask derived from the tokenizer's attention mask.
    """

    # Tokenizers that declare no length limit report a huge sentinel as
    # ``model_max_length``; anything above this threshold is treated as
    # "no limit" instead of being forwarded as ``max_length``.
    _UNBOUNDED_LENGTH = int(1e20)

    def __init__(self, model_name: str, embed_dim: int):
        """Load tokenizer and model for ``model_name`` and build the projection.

        Args:
            model_name: Identifier handed to ``from_pretrained``.
            embed_dim: Output feature size of the linear projection.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)

    def forward(self, text: list[str]) -> dict[str, torch.Tensor]:
        """Encode ``text`` and project the resulting hidden states.

        Args:
            text: Batch of input strings.

        Returns:
            A dict with ``"output"`` (projected hidden states) and ``"mask"``
            (boolean mask returned by :meth:`encode`).
        """
        hidden, mask = self.encode(text)
        return {"output": self.projection(hidden), "mask": mask}

    def encode(self, text: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize ``text`` and run the underlying model.

        Args:
            text: Batch of input strings.

        Returns:
            ``(last_hidden_state, mask)`` where ``mask`` is True wherever the
            tokenizer's attention mask equals 1.
        """
        device = self.model.device
        max_length = self._truncation_length()
        batch = self.tokenizer(
            text,
            max_length=max_length,
            padding=True,
            # Only truncate when the tokenizer declares a usable limit;
            # forwarding the "unbounded" sentinel as max_length can overflow.
            truncation=max_length is not None,
            return_tensors="pt",
        )
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)
        output: BaseModelOutput = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        # attention_mask already lives on `device`, so the comparison result
        # needs no further transfer.
        mask = attention_mask == 1
        return output.last_hidden_state, mask

    def projection(self, x):
        """Apply the linear projection ``self.proj`` to ``x``."""
        return self.proj(x)

    def _truncation_length(self) -> "int | None":
        """Return the tokenizer's length limit, or None when it has none."""
        limit = self.tokenizer.model_max_length
        if limit is None or limit > self._UNBOUNDED_LENGTH:
            return None
        return limit
|
|
|
|
class T5TextEncoder(TransformersTextEncoderBase):
    """Text encoder using a frozen T5 encoder followed by a linear projection.

    The T5 parameters have ``requires_grad`` disabled and the encoder is kept
    in eval mode; ``self.proj`` is created separately and is not frozen here.
    """

    def __init__(
        self, embed_dim: int, model_name: str = "google/flan-t5-large"
    ):
        """Load the T5 tokenizer/encoder, freeze the encoder, build the projection.

        Args:
            embed_dim: Output feature size of the linear projection.
            model_name: Identifier handed to ``from_pretrained``.
        """
        # Initialise nn.Module directly: the parent __init__ would load an
        # AutoTokenizer/AutoModel pair, which this class replaces.
        nn.Module.__init__(self)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.eval()
        self.proj = nn.Linear(self.model.config.hidden_size, embed_dim)

    def train(self, mode: bool = True) -> "T5TextEncoder":
        """Set the training mode, keeping the frozen T5 encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        undo the ``self.model.eval()`` call made in ``__init__`` as soon as
        this module (or any parent containing it) is switched to training.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            This module.
        """
        super().train(mode)
        self.model.eval()
        return self

    def encode(
        self,
        text: list[str],
    ):
        """Run the parent ``encode`` without gradients and with autocast disabled.

        Args:
            text: Batch of input strings.

        Returns:
            The ``(last_hidden_state, mask)`` pair produced by the parent class.
        """
        # NOTE(review): autocast is disabled only for DEVICE_TYPE, presumably
        # to keep the encoder out of reduced precision - confirm this is the
        # intent for models that run on a different device type.
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            return super().encode(text)
|
|
|
|
if __name__ == "__main__":
    # Smoke test: encode two captions and show the result and tensor shapes.
    captions = ["a man is speaking", "a woman is singing while a dog is barking"]
    encoder = T5TextEncoder(embed_dim=512)

    result = encoder(captions)
    print(result)
    for key in ("output", "mask"):
        print(result[key].shape)
|
|