File size: 3,393 Bytes
5afcf9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6518298
5afcf9e
 
 
 
 
 
 
ad75b95
5afcf9e
 
ad75b95
5afcf9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
789894d
 
 
 
 
 
ad75b95
 
 
789894d
5afcf9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    AutoConfig,
)
from huggingface_hub import snapshot_download
from typing import Optional
from transformers.utils import is_flash_attn_2_available
from .utils import (
    get_decoder_model,
    prepare_tokenizer,
    splade_max,
    similarity,
    encode,
)
from peft import PeftModel


class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-retrieval model wrapper.

    Records which backbone checkpoint to load plus the attention and
    tokenizer-padding options used when the model is instantiated.
    """

    model_type = "splade"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        # Let PretrainedConfig consume any standard HF config kwargs first.
        super().__init__(**kwargs)
        # Backbone checkpoint (hub id or local path).
        self.model_name_or_path = model_name_or_path
        # Requested attention kernel; may be downgraded at model-build time.
        self.attn_implementation = attn_implementation
        # Whether the decoder is patched for bidirectional attention.
        self.bidirectional = bidirectional
        # Tokenizer padding side ("left" suits decoder-only models).
        self.padding_side = padding_side


class Splade(PreTrainedModel):
    """SPLADE-style sparse retriever built on a (LoRA-adapted) decoder LM.

    The vocabulary-sized logits of the underlying language model are pooled
    with ``splade_max`` into a sparse representation. Only the PEFT/LoRA
    adapter is saved and loaded; the base model is rebuilt from
    ``config.model_name_or_path``.
    """

    config_class = SpladeConfig

    # Methods required by MTEB's model interface.
    similarity = similarity
    encode = encode

    def __init__(self, config):
        super().__init__(config)
        self.name = "splade"
        # Resolve the attention implementation BEFORE building the base
        # config so that base_cfg and the decoder model agree. (Previously
        # base_cfg was created with the requested implementation and only
        # afterwards was availability checked, leaving base_cfg stale when
        # flash-attn 2 was not installed.)
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"
        base_cfg = AutoConfig.from_pretrained(
            config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
        )
        self.tokenizer = prepare_tokenizer(
            config.model_name_or_path, padding_side=config.padding_side
        )
        self.model = get_decoder_model(
            model_name_or_path=config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            # Older configs may predate the `bidirectional` field.
            bidirectional=getattr(config, "bidirectional", False),
            base_cfg=base_cfg,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Persist the LoRA adapter (under ``<save_directory>/lora``) and
        the Splade config. The full base-model weights are NOT saved."""
        self.model.save_pretrained(os.path.join(save_directory, "lora"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Rebuild the base model from its config, then attach the LoRA
        adapter stored in the checkpoint's ``lora`` subfolder.

        Args:
            model_name_or_path: Hub id or local path of a saved Splade model.
            **kwargs: May contain ``token`` for private-hub authentication.
        """
        config = SpladeConfig.from_pretrained(model_name_or_path)
        model = cls(config)
        model.model = PeftModel.from_pretrained(
            model.model,
            model_name_or_path,
            subfolder="lora",
            token=kwargs.get("token"),
        )
        # Reverse vocabulary (token id -> token string), used to render
        # sparse vectors in human-readable form.
        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        """Run the decoder and max-pool its logits into a sparse vector.

        Returns a 1-tuple containing the SPLADE representations.
        """
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Dimensionality of the sparse representation (the vocab size)."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize a batch of texts with longest-sequence padding,
        truncation to ``max_length``, and PyTorch tensor output."""
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )