zembed-1 / modeling_zembed.py
Dilawar Mahmood
initial commit
d3c9465
import torch
from sentence_transformers.models import Transformer
# pyright: basic
class ZembedTransformer(Transformer):
    """Transformer module whose tokenization ends every input with ``<|im_end|>``.

    Only ``tokenize`` is overridden here; everything else comes from the
    ``Transformer`` base class.
    """

    def tokenize(
        self,
        texts: list[str] | list[dict] | list[tuple[str, str]],
        padding: str | bool = True,
    ) -> dict[str, torch.Tensor]:
        """Append the end marker to each text and run the tokenizer on the batch.

        Args:
            texts: Inputs to encode. The signature admits dicts and string
                pairs, but the marker is attached with ``+``, so only plain
                strings can actually be concatenated; other entry types
                fail at that step.
            padding: Forwarded unchanged to the tokenizer's ``padding``
                argument.

        Returns:
            The object produced by ``self.tokenizer`` when called with
            ``return_tensors="pt"``, ``truncation="longest_first"`` and
            ``max_length=self.max_seq_length``.
        """
        end_marker = "<|im_end|>\n"

        # Attach the marker before tokenizing so it is part of the encoded text.
        # NOTE(review): truncation is applied afterwards by the tokenizer, so
        # an over-long input may lose the marker -- confirm this is intended.
        marked_texts = []
        for entry in texts:
            marked_texts.append(entry + end_marker)  # pyright: ignore[reportOperatorIssue]

        encoded = self.tokenizer(
            marked_texts,
            max_length=self.max_seq_length,
            truncation="longest_first",
            padding=padding,
            return_tensors="pt",
        )
        return encoded