toolcalling-sae

TopK Sparse Autoencoder (SAE) checkpoints from the paper "To Call or Not to Call: Diagnosing Intrinsic Over-Calling Bias in LLM Agents".

Checkpoints

| Model                        | Layer | Dict size | k   | Stage 1    | Stage 2   |
|------------------------------|-------|-----------|-----|------------|-----------|
| gemma-3-1b-it                | L17   | 9,216     | 128 | 50M tokens | 5M tokens |
| gemma-3-4b-it                | L29   | 20,480    | 128 | 50M tokens | 5M tokens |
| gemma-4-E2B-it               | L30   | 12,288    | 128 | 50M tokens | 5M tokens |
| gemma-4-E4B-it               | L30   | 20,480    | 128 | 50M tokens | 5M tokens |
| Ministral-3-3B-Instruct-2512 | L21   | 24,576    | 128 | 50M tokens | 5M tokens |
| Ministral-3-8B-Instruct-2512 | L31   | 32,768    | 128 | 50M tokens | 5M tokens |
| Qwen3.5-4B                   | L25   | 20,480    | 128 | 50M tokens | 5M tokens |
| Qwen3.5-9B                   | L25   | 32,768    | 128 | 50M tokens | 5M tokens |

Stage 1: Pre-trained on OpenWebText2.
Stage 2: Fine-tuned on tool-calling activations from the When2Call benchmark.
All checkpoints use bfloat16 precision.
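
For readers unfamiliar with the "Dict size" and "k" columns, the following is a minimal sketch of a generic TopK SAE forward pass. It is not the TopKSAE class from sae_model.py: the function name, the weight names (W_enc, W_dec, b_enc, b_dec), the bias handling, and the toy sizes are all assumptions made for illustration. The idea it shows is that each activation is projected into a space of "Dict size" latents and only the k largest are kept.

```python
import numpy as np

def topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k):
    """Illustrative TopK SAE forward pass (hypothetical, not the repo's TopKSAE)."""
    # Encode: project the centered activation into the dictionary space.
    pre = (x - b_dec) @ W_enc + b_enc                 # shape: (batch, dict_size)
    # Find the indices of the k largest pre-activations in each row.
    idx = np.argpartition(pre, -k, axis=1)[:, -k:]
    # Copy only those k values into an otherwise all-zero latent vector.
    latents = np.zeros_like(pre)
    np.put_along_axis(latents, idx, np.take_along_axis(pre, idx, axis=1), axis=1)
    latents = np.maximum(latents, 0.0)                # ReLU on the surviving entries
    # Decode: reconstruct the activation from the sparse code.
    recon = latents @ W_dec + b_dec
    return latents, recon

# Toy sizes for illustration only; the released checkpoints use k = 128.
rng = np.random.default_rng(0)
d_model, dict_size, k = 16, 64, 4
x = rng.normal(size=(2, d_model))
W_enc = rng.normal(size=(d_model, dict_size)) * 0.1
W_dec = rng.normal(size=(dict_size, d_model)) * 0.1
b_enc = np.zeros(dict_size)
b_dec = np.zeros(d_model)

latents, recon = topk_sae_forward(x, W_enc, b_enc, W_dec, b_dec, k)
```

In this sketch each row of latents has at most k non-zero entries, which is the sparsity constraint the k column refers to.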

Usage

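from huggingface_hub import hf_hub_download
from sae_model import TopKSAE  # sae_model.py from this repo must be importable

# Download the Stage 2 checkpoint for gemma-3-1b-it (layer 17, dict size 9,216).
ckpt_path = hf_hub_download(
    repo_id="SKwra/toolcalling-sae",
    filename="gemma-3-1b-it/stage2/gemma-3-1b-it-L17-d9216-5M-stage2.pt",
)
# Load the SAE weights onto the GPU.
sae = TopKSAE.load(ckpt_path, device="cuda")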

sae_model.py is included in this repo. The full code is available on GitHub.
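
To load a different checkpoint, the filename has to be changed to match the table. The helper below is a hypothetical convenience, not part of this repo: it assumes every file follows the same naming pattern as the single gemma-3-1b-it Stage 2 example above, which the source does not confirm. Check the repo's file listing before relying on it.

```python
def checkpoint_filename(model: str, layer: int, dict_size: int, tokens: str, stage: int) -> str:
    """Build a repo-relative checkpoint path.

    ASSUMPTION: the pattern is inferred from one documented example
    (gemma-3-1b-it, Stage 2); other files may be named differently.
    """
    name = f"{model}-L{layer}-d{dict_size}-{tokens}-stage{stage}.pt"
    return f"{model}/stage{stage}/{name}"

# Reproduces the filename used in the usage example.
example = checkpoint_filename("gemma-3-1b-it", 17, 9216, "5M", 2)
print(example)
```

The returned string can be passed as the filename argument of hf_hub_download. Here tokens is the token count from the matching Stage column of the table.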
