import torch
from sentence_transformers.models import Transformer
# pyright: basic
class ZembedTransformer(Transformer):
    """Transformer module that appends a chat end-of-turn marker before tokenizing."""

    def tokenize(
        self,
        texts: list[str] | list[dict] | list[tuple[str, str]],
        padding: str | bool = True,
    ) -> dict[str, torch.Tensor]:
        """Tokenize *texts* with ``<|im_end|>\\n`` appended to each entry.

        Truncates with the ``longest_first`` strategy to ``self.max_seq_length``
        and returns PyTorch tensors from the underlying tokenizer.
        """
        # NOTE: string concatenation assumes str entries; the signature's
        # dict/tuple variants are deliberately not suffixed (hence the ignore).
        terminated: list = []
        for item in texts:
            terminated.append(item + "<|im_end|>\n")  # pyright: ignore[reportOperatorIssue]
        return self.tokenizer(
            terminated,
            padding=padding,
            truncation="longest_first",
            return_tensors="pt",
            max_length=self.max_seq_length,
        )