import numpy as np
import torch
from typing import Any
from transformers import AutoTokenizer
def splade_max(features, attention_mask):
    """
    SPLADE pooling: log-saturate the ReLU'd activations, zero out padded
    positions via the attention mask, then max-pool over the sequence axis.

    Returns a (values, indices) pair from the max over dim 1.
    """
    activations = torch.log(1 + torch.relu(features))
    masked = activations * attention_mask.unsqueeze(-1)
    pooled = masked.max(dim=1)
    return pooled.values, pooled.indices
def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> np.ndarray:
    """
    Encode sentences into SPLADE representations, processing them in
    mini-batches on the model's device.

    When top_k_q / top_k_d is positive, only the top-k activations are kept
    for queries / documents respectively. With return_dict=True the dense
    matrix is converted to per-sentence {token: weight} dicts (optionally
    pretty-printed); otherwise the concatenated numpy matrix is returned.
    """
    collected = []
    for start in range(0, len(sentences), batch_size):
        chunk = sentences[start : start + batch_size]
        inputs = self.create_batch_dict(chunk, max_length)
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            reps = self(**inputs)[0]
            if prompt_type == "query" and top_k_q > 0:
                reps = top_k(reps, top_k_q)
            if prompt_type == "document" and top_k_d > 0:
                reps = top_k(reps, top_k_d)
            collected.append(reps.cpu().float().numpy())
    dense = np.concatenate(collected, axis=0)
    if not return_dict:
        return dense
    bows = bow_dict(self, dense)
    if print_dict:
        print_bow_bars(sentences, bows)
    return bows
def bow_dict(self, embeddings):
    """
    Turn dense SPLADE vectors into bag-of-words dicts.

    For each row, the nonzero dimensions are mapped through
    self.reverse_voc (id -> token) and emitted as a dict ordered by
    descending weight.
    """
    bows = []
    for vec in embeddings:
        nz = np.flatnonzero(vec)
        pairs = sorted(
            zip(nz.tolist(), vec[nz].tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        bows.append({self.reverse_voc[i]: float(w) for i, w in pairs})
    return bows
def print_bow_bars(sentences, bow_list, width=20):
    """
    Pretty-print each sentence's bag-of-words as horizontal ASCII bars,
    scaled so the heaviest token spans `width` characters.
    """
    ascii_header("TOP ACTIVATED WORDS")
    for text, bow in zip(sentences, bow_list):
        print(f"* INPUT: {text}\n")
        peak = max(bow.values())
        for token, weight in sorted(bow.items(), key=lambda kv: kv[1], reverse=True):
            bar_len = int(weight / peak * width)
            print(f"{token[:25]:25} | {'█' * bar_len} {weight:.2f}")
        print("\n")
def ascii_header(title, width=70):
    """Print `title` centered inside a boxed ASCII banner of total `width`."""
    border = "+" + "-" * (width - 2) + "+"
    padded = f" {title} "
    print(border)
    print("|" + padded.center(width - 2) + "|")
    print(border)
    print("\n")
def similarity(self, a, b) -> torch.Tensor:
    """
    Dot-product similarity between two (batches of) embeddings.

    MTEB eval requires this. 1-D inputs are promoted to a batch of one;
    lists/arrays are converted to tensors first.
    """
    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    b = b if isinstance(b, torch.Tensor) else torch.tensor(b)
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)
    return a @ b.transpose(0, 1)
def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """
    Load a Hugging Face tokenizer and make it padding-ready.

    The pad token is chosen with preference order BOS -> existing pad -> EOS
    (first truthy wins), and the padding side is set as requested.
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    fallback = tok.pad_token or tok.eos_token
    tok.pad_token = tok.bos_token if tok.bos_token else fallback
    tok.padding_side = padding_side
    return tok
def get_decoder_model(
    model_name_or_path: str, attn_implementation: str, bidirectional: bool, base_cfg, token=None
):
    """
    Load the bidirectional Qwen3 decoder backbone in bfloat16.

    base_cfg is the pretrained config of the underlying model.

    Args:
        model_name_or_path: checkpoint path or hub id passed to from_pretrained.
        attn_implementation: must be "flash_attention_2" (only supported backend).
        bidirectional: must be True — the model was trained with bi-directional attention.
        base_cfg: pretrained config object forwarded as `config=`.
        token: optional auth token forwarded to from_pretrained.

    Raises:
        ValueError: if bidirectional is not True or attn_implementation is unsupported.
    """
    print("WARNING: bidirectional only tested for transformer 4.51.2")
    # Validate with explicit raises instead of `assert`: asserts are stripped
    # under `python -O`, which would silently skip these checks.
    if bidirectional is not True:
        raise ValueError("the model has been trained with bi-directional attention!")
    if attn_implementation != "flash_attention_2":
        raise ValueError(
            f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
        )
    # Imported lazily so merely importing this module does not require the
    # bidirectional modeling code.
    from .modeling_qwen3_bidir import Qwen3BidirForCausalLM

    return Qwen3BidirForCausalLM.from_pretrained(
        model_name_or_path,
        config=base_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        token=token,
    )
def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    Keep only the k largest entries along the last dimension of x,
    zeroing everything else.
    """
    values, indices = x.topk(k, dim=-1)
    # Start from all zeros and write the surviving values back in place.
    result = torch.zeros_like(x)
    return result.scatter(-1, indices, values)