# splade.py — SPLADE sparse-retrieval model wrapper (from the splade-code-06B repo).
import os
from transformers import (
PretrainedConfig,
PreTrainedModel,
AutoConfig,
)
from huggingface_hub import snapshot_download
from typing import Optional
from transformers.utils import is_flash_attn_2_available
from .utils import (
get_decoder_model,
prepare_tokenizer,
splade_max,
similarity,
encode,
)
from peft import PeftModel
class SpladeConfig(PretrainedConfig):
    """Configuration for the SPLADE sparse-retrieval wrapper.

    Args:
        model_name_or_path: Hub id or local path of the backbone LM.
        attn_implementation: Requested attention kernel (e.g. "flash_attention_2").
        bidirectional: Enable bidirectional attention; only meaningful for
            decoder backbones.
        padding_side: Tokenizer padding side ("left" or "right").
    """

    model_type = "splade"

    def __init__(
        self,
        model_name_or_path: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        bidirectional: bool = True,  # only for decoder models
        padding_side: str = "left",
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Backbone selection and attention kernel.
        self.model_name_or_path = model_name_or_path
        self.attn_implementation = attn_implementation
        # Tokenization / masking behaviour.
        self.padding_side = padding_side
        self.bidirectional = bidirectional
class Splade(PreTrainedModel):
    """SPLADE sparse-retrieval model wrapping a decoder language model.

    Produces sparse lexical representations by max-pooling the backbone's
    vocabulary logits (via ``splade_max`` from ``.utils``).
    """

    config_class = SpladeConfig

    # Methods required by MTEB's model interface (defined in .utils).
    similarity = similarity
    encode = encode

    def __init__(self, config, weights_path=None, token=None):
        """
        Args:
            config: ``SpladeConfig`` naming the backbone and attention settings.
            weights_path: Optional path/repo holding fine-tuned weights; falls
                back to ``config.model_name_or_path`` when None.
            token: Optional Hugging Face auth token forwarded to the loader.
        """
        super().__init__(config)
        self.name = "splade"

        # Resolve the attention implementation BEFORE building the base config,
        # so a fallback to "sdpa" is reflected in `base_cfg` too. (Previously
        # base_cfg was created first and could request flash_attention_2 even
        # when it is not installed.)
        if is_flash_attn_2_available():
            config.attn_implementation = "flash_attention_2"
        else:
            config.attn_implementation = "sdpa"

        base_cfg = AutoConfig.from_pretrained(
            config.model_name_or_path,
            attn_implementation=config.attn_implementation,
            torch_dtype="auto",
        )
        self.tokenizer = prepare_tokenizer(
            config.model_name_or_path, padding_side=config.padding_side
        )

        # Prefer fine-tuned weights when a path was given; otherwise load the
        # base checkpoint named in the config.
        source = weights_path or config.model_name_or_path
        self.model = get_decoder_model(
            model_name_or_path=source,
            attn_implementation=config.attn_implementation,
            bidirectional=getattr(config, "bidirectional", False),
            base_cfg=base_cfg,
            token=token,
        )

    def save_pretrained(self, save_directory, *args, **kwargs):
        """Save adapter weights under ``save_directory/lora`` plus the config."""
        self.model.save_pretrained(os.path.join(save_directory, "lora"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, model_name_or_path, *args, **kwargs):
        """Load config and weights from a hub repo or local directory.

        Also builds ``reverse_voc`` (token-id -> token string), useful for
        inspecting the sparse representations.
        """
        token = kwargs.get("token", None)
        config = SpladeConfig.from_pretrained(
            model_name_or_path,
            token=token,
        )
        model = cls(config, weights_path=model_name_or_path, token=token)
        model.reverse_voc = {v: k for k, v in model.tokenizer.vocab.items()}
        return model

    def forward(self, **tokens):
        """Run the backbone and return a 1-tuple of SPLADE representations.

        Args:
            tokens: Tokenizer output; must include ``attention_mask``.
        """
        output = self.model(**tokens)
        splade_reps, _ = splade_max(output.logits, tokens["attention_mask"])
        return (splade_reps,)

    def get_width(self):
        """Dimensionality of the sparse representation (= backbone vocab size)."""
        return self.model.config.vocab_size

    def create_batch_dict(self, input_texts, max_length):
        """Tokenize a batch of texts into padded/truncated PyTorch tensors."""
        return self.tokenizer(
            input_texts,
            add_special_tokens=True,
            padding="longest",
            truncation=True,
            max_length=max_length,
            return_attention_mask=True,
            return_tensors="pt",
        )