Tom Aarsen committed
Commit ebcd7f4 · 1 Parent(s): 222c67d

Patch loading SparseEncoder from Hub

Files changed (2):
  1. modules.json +1 -1
  2. splade.py +17 -1
modules.json CHANGED
@@ -3,7 +3,7 @@
     "idx": 0,
     "name": "0",
     "path": "",
-    "type": "sentence_transformers.sparse_encoder.models.MLMTransformer"
+    "type": "splade.SpladeCodeMLMTransformer"
   },
   {
     "idx": 1,
splade.py CHANGED
@@ -3,7 +3,7 @@ Compared to standard Qwen3, we're using bidirectional attention and not causal a
 with `is_causal=False` in the config.

 This file supports two loading paths:
-1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via AutoModelForMaskedLM -> Qwen3ForCausalLM
+1. Sentence Transformers: `SparseEncoder("naver/splade-code-8B", trust_remote_code=True)` via SpladeCodeMLMTransformer -> AutoModelForMaskedLM -> Qwen3ForCausalLM
 2. Transformers: `AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)` -> Splade

 The checkpoint is distributed as a LoRA adapter on top of Qwen/Qwen3-8B; `Qwen3ForCausalLM.from_pretrained`
@@ -166,3 +166,19 @@ class Splade(PreTrainedModel):


 __all__ = ["Qwen3ForCausalLM", "Splade"]
+
+
+# Override ST's `_load_config` to return our `Qwen3Config` (with `auto_map`)
+# instead of a `PeftConfig`, so hub-path loads route to `splade.Qwen3ForCausalLM`
+# instead of failing in `AutoModelForMaskedLM`. The LoRA is still applied by
+# transformers' built-in PEFT path.
+try:
+    from sentence_transformers.sparse_encoder.models import MLMTransformer
+
+    class SpladeCodeMLMTransformer(MLMTransformer):
+        def _load_config(self, model_name_or_path, backend, config_kwargs):
+            return AutoConfig.from_pretrained(model_name_or_path, **config_kwargs), False
+
+    __all__.append("SpladeCodeMLMTransformer")
+except ImportError:
+    pass
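
For reference, a minimal usage sketch of the two loading paths listed in the docstring above (assuming `sentence-transformers` with sparse encoder support and `peft` are installed; the input string is purely illustrative):

from sentence_transformers import SparseEncoder
from transformers import AutoModelForCausalLM

# Path 1: Sentence Transformers. modules.json now points module 0 at
# splade.SpladeCodeMLMTransformer, so trust_remote_code=True loads splade.py from
# the repo and config resolution goes through AutoConfig rather than PeftConfig.
model = SparseEncoder("naver/splade-code-8B", trust_remote_code=True)
embeddings = model.encode(["def binary_search(arr, target): ..."])  # illustrative input

# Path 2: plain Transformers. auto_map in the config resolves to the custom Splade class.
splade = AutoModelForCausalLM.from_pretrained("naver/splade-code-8B", trust_remote_code=True)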