| import numpy as np |
| import torch |
|
|
| from typing import Any |
| from transformers import AutoTokenizer |
|
|
|
|
def splade_max(features, attention_mask):
    """
    SPLADE pooling operation.

    Max-pools log-saturated, ReLU-activated logits over the sequence
    dimension, masking out padding positions.

    Args:
        features: activation tensor, assumed (batch, seq_len, vocab) —
            pooling is over dim=1. TODO confirm shape against caller.
        attention_mask: (batch, seq_len) mask; padded positions contribute 0.

    Returns:
        (values, ids_): per-vocab pooled weights and the sequence position
        (argmax index) each weight came from.
    """
    # torch.log1p(relu(x)) == log(1 + relu(x)) but is numerically stabler,
    # and the functional relu avoids building a fresh nn.ReLU module per call.
    saturated = torch.log1p(torch.relu(features))
    values, ids_ = torch.max(saturated * attention_mask.unsqueeze(-1), dim=1)
    return values, ids_
|
|
|
|
def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> np.ndarray:
    """
    Encode sentences into SPLADE sparse vectors, batch by batch.

    Args:
        sentences: input texts.
        max_length: tokenizer truncation length.
        prompt_type: "query" or "document"; selects which top-k cap applies.
        return_dict: if True, return word->weight dicts instead of the array
            (note: annotation says np.ndarray; dict path returns a list of dicts).
        print_dict: if True (with return_dict), pretty-print the dicts.
        batch_size: encoding batch size.
        top_k_q / top_k_d: keep only the k largest activations (<= 0 disables).

    Returns:
        Stacked embeddings as a numpy array, or a list of bow dicts.
    """
    chunks = []
    for start in range(0, len(sentences), batch_size):
        batch = self.create_batch_dict(sentences[start : start + batch_size], max_length)
        batch = {name: tensor.to(self.model.device) for name, tensor in batch.items()}
        with torch.no_grad():
            reps = self(**batch)[0]
        # sparsify according to the prompt type's top-k budget
        if prompt_type == "query":
            cap = top_k_q
        elif prompt_type == "document":
            cap = top_k_d
        else:
            cap = -1
        if cap > 0:
            reps = top_k(reps, cap)
        chunks.append(reps.cpu().float().numpy())
    embeddings = np.concatenate(chunks, axis=0)
    if not return_dict:
        return embeddings
    d = bow_dict(self, embeddings)
    if print_dict:
        print_bow_bars(sentences, d)
    return d
|
|
|
|
def bow_dict(self, embeddings):
    """
    Convert dense vocabulary-sized vectors into word -> weight dicts.

    Only nonzero entries are kept; each dict is ordered by descending
    weight, with indices mapped to words via self.reverse_voc.
    """
    results = []
    for vec in embeddings:
        nz = np.flatnonzero(vec)
        # pair (vocab_id, weight), heaviest first
        pairs = sorted(zip(nz.tolist(), vec[nz].tolist()), key=lambda p: p[1], reverse=True)
        results.append({self.reverse_voc[i]: float(w) for i, w in pairs})
    return results
|
|
|
|
def print_bow_bars(sentences, bow_list, width=20):
    """
    Pretty-print each sentence's bag-of-words weights as ASCII bar charts.

    Args:
        sentences: input strings, aligned with bow_list.
        bow_list: list of {word: weight} dicts (as produced by bow_dict).
        width: maximum bar width in characters.
    """
    ascii_header("TOP ACTIVATED WORDS")
    for sent, bow in zip(sentences, bow_list):
        print(f"* INPUT: {sent}\n")
        if not bow:
            # an all-zero vector yields an empty dict; max() would raise
            print("\n")
            continue
        max_w = max(bow.values())
        for k, v in sorted(bow.items(), key=lambda x: x[1], reverse=True):
            # guard max_w == 0 to avoid ZeroDivisionError on degenerate weights
            bar = "█" * int(v / max_w * width) if max_w else ""
            print(f"{k[:25]:25} | {bar} {v:.2f}")
        print("\n")
|
|
|
|
def ascii_header(title, width=70):
    """
    Print *title* centered inside a simple ASCII box, followed by a blank line.
    """
    inner = width - 2
    border = "+" + "-" * inner + "+"
    print(border)
    print("|" + f" {title} ".center(inner) + "|")
    print(border)
    print("\n")
|
|
|
|
def similarity(self, a, b) -> torch.Tensor:
    """
    MTEB eval requires this

    Dot-product similarity: promotes 1-D inputs to row matrices and
    returns the (len(a), len(b)) score matrix a @ b^T.
    """
    # accept lists / arrays as well as tensors
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    # single vectors become 1-row matrices so the matmul is always 2-D
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)
    return a @ b.transpose(0, 1)
|
|
|
|
def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """
    loads and prepares tokenizer

    Ensures a pad token exists (preferring BOS, then the existing pad
    token, then EOS) and sets the padding side.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    # some checkpoints ship without a usable pad token — fall back in order
    tok.pad_token = tok.bos_token or tok.pad_token or tok.eos_token
    tok.padding_side = padding_side
    return tok
|
|
|
|
def get_decoder_model(
    model_name_or_path: str, attn_implementation: str, bidirectional: bool, base_cfg, token=None
):
    """
    base_cfg is the pretrained config of the underlying model

    Loads the bidirectional Qwen3 decoder in bfloat16. Only
    bidirectional=True with flash_attention_2 is supported.
    """
    print("WARNING: bidirectional only tested for transformer 4.51.2")
    assert bidirectional is True, "the model has been trained with bi-directional attention!"
    assert attn_implementation == "flash_attention_2", (
        f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
    )
    # imported lazily so the module is only required on this code path
    from .modeling_qwen3_bidir import Qwen3BidirForCausalLM

    load_kwargs = dict(
        config=base_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        token=token,
    )
    return Qwen3BidirForCausalLM.from_pretrained(model_name_or_path, **load_kwargs)
|
|
|
|
def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    zeroes out all but the top-k values in the last dimension of x
    """
    # boolean keep-mask: True exactly at the k largest entries per row
    keep = torch.zeros_like(x, dtype=torch.bool)
    keep.scatter_(-1, x.topk(k, dim=-1).indices, True)
    return x * keep