Support Sentence Transformers via SparseEncoder
Hello!
Congratulations on the release! I believed it was possible to both simplify and integrate with Sentence Transformers, so here is my attempt.
Pull Request overview
- Simplify the bidirectional Qwen3 implementation heavily by relying on https://github.com/huggingface/transformers/pull/43705 (requires transformers v5.2.0+)
- Support SparseEncoder from Sentence Transformers, with outputs matching the original implementation
Details
The https://github.com/huggingface/transformers/pull/43705 pull request in transformers adds support for the is_causal parameter on all architectures, meaning that turning a causal model bidirectional becomes as simple as setting `"is_causal": false` in the config.json. This means that we don't even need the custom modeling code. However, sadly `trust_remote_code=True` is still required because we're loading a CausalLM model using AutoModelForMaskedLM. The implementation can now be very simple, e.g. with:
"""
This file exists solely to allow loading the Qwen3ForCausalLM via the AutoModelForMaskedLM class.
Compared to standard Qwen3, we're using bidirectional attention and not causal attention, but it's specified
with `is_causal=False` in the config.
"""
from transformers import Qwen3ForCausalLM
__all__ = ["Qwen3ForCausalLM"]
Although you can also extend this again to re-add the encode etc. methods.
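As an aside, here's a minimal sketch of what the is_causal flag governs, using plain NumPy rather than the actual transformers internals (just an illustration of the masking pattern, not the real implementation):

```python
import numpy as np

def attention_mask(seq_len: int, is_causal: bool) -> np.ndarray:
    # 1 = may attend, 0 = masked out (illustrative only)
    if is_causal:
        # Lower-triangular mask: token i only sees tokens 0..i (standard causal LM)
        return np.tril(np.ones((seq_len, seq_len), dtype=int))
    # Bidirectional: every token sees every other token (encoder-style)
    return np.ones((seq_len, seq_len), dtype=int)

print(attention_mask(3, is_causal=True))
# [[1 0 0]
#  [1 1 0]
#  [1 1 1]]
print(attention_mask(3, is_causal=False))
# [[1 1 1]
#  [1 1 1]
#  [1 1 1]]
```

Flipping that one flag is all that separates the causal and bidirectional variants, which is why no custom modeling file is needed anymore.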
Instead, I kept it simple for now and just used a SparseEncoder implementation, which only required some configuration files. I also copied the Qwen3-0.6B config.json, as that's now loaded directly rather than via custom code which itself loads a Qwen3 model, and I copied over the Qwen3 tokenizer files.
The SparseEncoder usage is now:
from sentence_transformers import SparseEncoder
model = SparseEncoder("naver/splade-code-06B", trust_remote_code=True, revision="refs/pr/1")
queries = [
"SELECT *\nFROM Student\nWHERE Age = (\nSELECT MAX(Age)\nFROM Student\nWHERE Group = 'specific_group'\n)\nAND Group = 'specific_group';"
]
query_embeddings = model.encode(queries)
print(query_embeddings.shape)
# torch.Size([1, 151936])
sparsity = model.sparsity(query_embeddings)
print(sparsity)
# {'active_dims': 1231.0, 'sparsity_ratio': 0.991897904380792}
decoded = model.decode(query_embeddings, top_k=10)
print(decoded)
# [[
# ("Ġgroup", 2.34375),
# ("Ġage", 2.34375),
# ("ĠAge", 2.34375),
# ("ĠStudent", 2.296875),
# ("Ġspecific", 2.296875),
# ("_group", 2.296875),
# ("ĠMax", 2.21875),
# ("Ġmax", 2.21875),
# ("Ġstudent", 2.203125),
# ("ĠGroup", 2.1875),
# ]]
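For reference, the reported sparsity_ratio is simply the fraction of zero entries in the vocabulary-sized embedding. A quick sanity check on the numbers above (assuming the 151936-dim output shown there):

```python
# Sanity check on the stats above: sparsity_ratio is the fraction of the
# vocabulary-sized embedding (151936 dims) whose entries are exactly zero.
vocab_size = 151936
active_dims = 1231
sparsity_ratio = 1 - active_dims / vocab_size
print(sparsity_ratio)  # ≈ 0.9919, matching the sparsity output above
```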
And it works with transformers >= 5.2.0 and sentence-transformers >= 5.0.0. If you install kernels, you can set model_kwargs={"attn_implementation": "flash_attention_2"} and it will use a kernel from the Hub without having to actually install flash-attn (which is always annoying).
It also works with e.g. sdpa or eager, unlike the current implementation, and is likely to keep working with future transformers versions as it imports very little from transformers.
Note that the above script has revision="refs/pr/1" so you can test it directly from this PR branch without having to check anything out locally or merge it.
I believe the SparseEncoder is also integrated in MTEB, so you can evaluate this easily on MTEB for submission. In Sentence Transformers v5.4.0 I suspect that you'll also be able to avoid the trust_remote_code=True, but I'm not 100% sure on that yet.
Normally my integration PRs aren't this big, and the original code generally still works (but here it doesn't), so I totally understand if you're not as interested. Let me know your thoughts! Perhaps there's a nice middle ground where both Sentence Transformers and regular transformers still work. I can also copy these changes to the 8B model if you're interested.
- Tom Aarsen
Hi Tom,
Thanks a lot for this code and the initial integration.
I have tried to run it on a fresh env with python 3.10.16, torch 2.6.0+cu124, transformers 5.3.0, and sentence-transformers 5.3.0, but I still got a lot of errors.
e.g. for eager or sdpa attention: NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.
and with the kernel for flash attention 2 I got: ValueError: An error occurred while trying to load from 'kernels-community/flash-attn2': Cannot install kernel from repo kernels-community/flash-attn2 (revision: main).
Any idea about these? What torch and Python versions did you use?
Yes, I agree that the best option would be to have both the transformers and the Sentence Transformer code working.
- Simon
Hello Simon,
Those are totally acceptable versions, also versions that we'd like users to be able to use. I'll try to do some digging for the eager/sdpa route.
For the flash-attention-2 route: the ValueError: An error occurred while trying to load from 'kernels-community/flash-attn2': Cannot install kernel from repo kernels-community/flash-attn2 (revision: main) error is a "follow-up error" of the real error, which will be listed above it. Most likely, there isn't a pre-built Flash Attention 2 wheel for your combination of operating system, Torch version, and CUDA version. The current builds are here: https://huggingface.co/kernels-community/flash-attn2/tree/main/build
However, eager, sdpa, and "regular" flash attention should still work and not encounter that Cannot copy out of meta tensor error. I'll try to debug.
- Tom Aarsen
I believe there may have been an issue with torch 2.6.0 and the meta device (e.g. https://github.com/pytorch/pytorch/issues/153330), can you perhaps try torch 2.7.0 or another non-2.6.0 version?
I updated to torch 2.7.0+cu128 and the env is working, but I still got the same meta tensor error. I tried specifying device=cuda or device=cpu, but that doesn't help.