File size: 4,678 Bytes
f440f7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450c551
f440f7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450c551
f440f7a
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import numpy as np
import torch

from typing import Any
from transformers import AutoTokenizer


def splade_max(features, attention_mask):
    """
    SPLADE pooling: log-saturated ReLU activations, max-pooled over the
    sequence axis (dim=1), with padding positions zeroed via the mask.

    Returns the pooled values and the sequence indices that produced them.
    """
    # log(1 + relu(x)) saturates large activations; masking keeps padded
    # tokens from ever winning the max.
    saturated = torch.log(1 + torch.relu(features)) * attention_mask.unsqueeze(-1)
    values, ids_ = saturated.max(dim=1)
    return values, ids_


def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> np.ndarray:
    """
    Encode sentences into SPLADE sparse vectors, batch by batch.

    When return_dict is True the result is a list of {token: weight}
    dicts (optionally printed as bar charts); otherwise a single
    concatenated numpy array of shape (n_sentences, vocab_size).
    top_k_q / top_k_d > 0 prune each representation to its k largest
    entries for queries / documents respectively.
    """
    chunks = []
    for start in range(0, len(sentences), batch_size):
        texts = sentences[start : start + batch_size]
        inputs = self.create_batch_dict(texts, max_length)
        inputs = {name: tensor.to(self.model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            reps = self(**inputs)[0]
            # Optional sparsification: keep only the top-k activations.
            if prompt_type == "query" and top_k_q > 0:
                reps = top_k(reps, top_k_q)
            elif prompt_type == "document" and top_k_d > 0:
                reps = top_k(reps, top_k_d)
            chunks.append(reps.cpu().float().numpy())
    dense = np.concatenate(chunks, axis=0)
    if not return_dict:
        return dense
    bows = bow_dict(self, dense)
    if print_dict:
        print_bow_bars(sentences, bows)
    return bows


def bow_dict(self, embeddings):
    """
    Turn each dense vector into a {token: weight} bag-of-words dict,
    ordered by descending weight.

    Token strings come from self.reverse_voc (vocab id -> token).
    Zero entries are dropped.
    """
    bows = []
    for vec in embeddings:
        active = np.nonzero(vec)[0]
        # Pair each active vocab id with its weight, heaviest first.
        ranked = sorted(
            zip(active.tolist(), vec[active].tolist()),
            key=lambda pair: pair[1],
            reverse=True,
        )
        bows.append({self.reverse_voc[i]: float(w) for i, w in ranked})
    return bows


def print_bow_bars(sentences, bow_list, width=20):
    """
    Pretty-print each sentence's bag-of-words as a horizontal bar chart,
    tokens sorted by descending weight; bars are scaled to the largest
    weight in that sentence.

    width controls the length (in characters) of the longest bar.
    """
    ascii_header("TOP ACTIVATED WORDS")
    for sent, bow in zip(sentences, bow_list):
        print(f"* INPUT: {sent}\n")
        # An all-zero embedding yields an empty bow; max() on an empty
        # sequence would raise ValueError, so skip the chart entirely.
        if not bow:
            print("\n")
            continue
        max_w = max(bow.values())
        for k, v in sorted(bow.items(), key=lambda x: x[1], reverse=True):
            bar = "█" * int(v / max_w * width)
            print(f"{k[:25]:25} | {bar} {v:.2f}")
        print("\n")


def ascii_header(title, width=70):
    """
    Print `title` centered inside a +---+ ASCII box of the given total
    width, followed by a blank line.
    """
    inner = width - 2
    border = "+" + "-" * inner + "+"
    print(border)
    # Pad the title with one space on each side before centering.
    print("|" + f" {title} ".center(inner) + "|")
    print(border)
    print("\n")


def similarity(self, a, b) -> torch.Tensor:
    """
    MTEB eval requires this

    Dot-product similarity: coerces inputs to tensors, promotes 1-D
    vectors to (1, dim) row matrices, and returns a @ b^T.
    """
    a = a if isinstance(a, torch.Tensor) else torch.tensor(a)
    b = b if isinstance(b, torch.Tensor) else torch.tensor(b)
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)
    return a @ b.transpose(0, 1)


def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """
    Load a Hugging Face tokenizer and ensure it has a pad token and the
    requested padding side.

    The pad token preference order is BOS, then any existing pad token,
    then EOS (so models shipped without a pad token still batch-pad).
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    tok.pad_token = tok.bos_token or tok.pad_token or tok.eos_token
    tok.padding_side = padding_side
    return tok


def get_decoder_model(
    model_name_or_path: str, attn_implementation: str, bidirectional: bool, base_cfg, token=None
):
    """
    base_cfg is the pretrained config of the underlying model

    Loads the bidirectional Qwen3 decoder in bfloat16. Only
    bidirectional=True with flash_attention_2 is supported.
    """
    print("WARNING: bidirectional only tested for transformer 4.51.2")
    # NOTE(review): assert-based validation is stripped under `python -O`;
    # kept as-is to preserve the exception type callers may rely on.
    assert (
        bidirectional is True
    ), "the model has been trained with bi-directional attention!"
    assert (
        attn_implementation == "flash_attention_2"
    ), f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
    # Imported lazily so the module loads even without the bidir weights code.
    from .modeling_qwen3_bidir import Qwen3BidirForCausalLM

    model = Qwen3BidirForCausalLM.from_pretrained(
        model_name_or_path,
        config=base_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        token=token,
    )
    return model


def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """
    zeroes out all but the top-k values in the last dimension of x
    """
    # Boolean mask that is True exactly at the k largest entries of each
    # last-dim slice; scatter_ writes along dim=-1 at the topk indices.
    keep = torch.zeros_like(x, dtype=torch.bool)
    keep.scatter_(-1, x.topk(k, dim=-1).indices, True)
    return x * keep