maxoul committed on
Commit
eeb94eb
·
verified ·
1 Parent(s): 1231748

Create splade.py

Browse files
Files changed (1) hide show
  1. splade.py +105 -0
splade.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from transformers import (
3
+ PretrainedConfig,
4
+ PreTrainedModel,
5
+ AutoConfig,
6
+ )
7
+ from huggingface_hub import snapshot_download
8
+ from typing import Optional
9
+ from transformers.utils import is_flash_attn_2_available
10
+ from .utils import (
11
+ get_decoder_model,
12
+ prepare_tokenizer,
13
+ splade_max,
14
+ similarity,
15
+ encode,
16
+ )
17
+ from peft import PeftModel
18
+
19
+
20
class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-retrieval model wrapper.

    Stores the backbone checkpoint name plus runtime options (attention
    implementation, bidirectional attention for decoder backbones, and the
    tokenizer padding side) on a standard ``PretrainedConfig`` so they can be
    serialized with ``save_pretrained`` / ``from_pretrained``.
    """

    model_type = "splade"

    def __init__(
        self,
        model_name_or_path: str = "meta-llama/Llama-3.1-8B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only meaningful for decoder models
        padding_side: str = "right",
        **kwargs,
    ):
        # Let PretrainedConfig consume its own kwargs first, then record ours.
        super().__init__(**kwargs)
        for attr_name, attr_value in (
            ("model_name_or_path", model_name_or_path),
            ("attn_implementation", attn_implementation),
            ("bidirectional", bidirectional),
            ("padding_side", padding_side),
        ):
            setattr(self, attr_name, attr_value)
36
+
37
+
38
class Splade(PreTrainedModel):
    """SPLADE sparse retriever wrapping a decoder LM plus a LoRA adapter.

    The forward pass produces vocabulary-sized sparse representations via
    ``splade_max`` over the backbone's logits. ``save_pretrained`` /
    ``from_pretrained`` persist only the LoRA adapter (under a ``lora``
    subfolder) and the config, not the full backbone weights.
    """

    config_class = SpladeConfig

    # Methods bound at class level for MTEB's evaluation interface.
    similarity = similarity
    encode = encode

    def __init__(self, config):
        super().__init__(config)
        self.name = "splade"
        # Resolve the attention implementation BEFORE building the base
        # config, so base_cfg and the instantiated model agree. (Previously
        # base_cfg was created first and could still request
        # "flash_attention_2" even when flash-attn is not installed.)
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"
        base_cfg = AutoConfig.from_pretrained(
            config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
        )
        self.tokenizer = prepare_tokenizer(
            config.model_name_or_path, padding_side=config.padding_side
        )
        self.model = get_decoder_model(
            model_name_or_path=config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            # bidirectional attention is only used for decoder backbones
            bidirectional=getattr(config, "bidirectional", False),
            base_cfg=base_cfg,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save the LoRA adapter (under ``<save_directory>/lora``) and config.

        NOTE(review): the tokenizer is not saved here; it is re-created from
        ``config.model_name_or_path`` on load.
        """
        self.model.save_pretrained(os.path.join(save_directory, "lora"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Rebuild the model and attach the LoRA adapter from ``lora/``.

        Only ``token`` is forwarded from kwargs (for private Hub repos);
        other args are accepted for interface compatibility but ignored.
        """
        config = SpladeConfig.from_pretrained(model_name_or_path)
        model = cls(config)
        # Load the adapter from the "lora" subfolder written by save_pretrained.
        model.model = PeftModel.from_pretrained(
            model.model,
            model_name_or_path,
            subfolder="lora",
            token=kwargs.get("token", None),
        )
        # id -> token mapping, used to decode sparse representations.
        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        """Return a 1-tuple with the SPLADE sparse representations."""
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Dimensionality of the sparse representation (vocab size)."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize a batch of texts to padded/truncated PyTorch tensors."""
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )