---
library_name: transformers
license: apache-2.0
language:
- en
tags:
- monoid
- causal-lm
- linear-attention
- state-space
- O(1)-inference
- vector-decay
- reasoning
pipeline_tag: text-generation
model-index:
- name: Spartacus-1B-Instruct
  results: []
---

# Spartacus-1B-Instruct: Causal Monoid Language Model

A 1.3B-parameter language model that replaces softmax attention with **causal monoid state compression**, giving **O(1) time per token** and **O(1) memory** at inference, independent of sequence length.

## SFT Training Curves

| Loss | Accuracy |
|:---:|:---:|
| ![SFT Loss](LOSS_SPAR.png) | ![SFT Accuracy](ACC_SPAR.png) |

## Core Mechanism

![Core Mechanism: The Monoid Recurrence](ARCH.png)

## Key Properties

| Property | Transformer (Llama) | Spartacus (Monoid) |
|---|---|---|
| Inference time per token | O(T): scans the full KV-cache | **O(1)**: single state update |
| Inference memory per layer | O(T): stores all past K, V | **O(1)**: fixed d×d state matrix per head |
| Sequence length extrapolation | Degrades beyond training length | **No architectural limit**: state size is constant |
| Causality | Imposed via attention mask | **Built into the recurrence** |
| Training complexity | O(T²) | **O(T)** via parallel prefix scan |

## The Monoid Recurrence

Standard attention computes:

```
o_t = Σ_{i≤t} softmax(q_t · k_i) v_i      (requires an O(T) KV-cache)
```

Monoid attention compresses the entire causal history into a **fixed-size state matrix** S_t per head:

```
S_t = diag(α_t) · S_{t-1} + k_t ⊗ v_t     (vector-decay monoid recurrence)
o_t = q_t · S_t                           (state readout)
```

This is a monoid because the binary operator `(α, S) ⊕ (β, X) = (α·β, diag(β)·S + X)`, with `α·β` taken elementwise, is **associative** and has the identity element `(1, 0)`. Associativity is what allows an O(T) parallel prefix scan for training and an O(1) sequential update for inference.
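
The equivalence between the sequential recurrence and the monoid operator can be checked numerically. The following is a minimal pure-Python sketch, assuming a zero initial state and toy dimensions; the helpers `monoid_op`, `token_element`, `readout`, and `max_abs_diff` are written for this illustration and are not the repository's implementation.

```python
import random

def monoid_op(left, right):
    """Compose two segments: (a, S) ⊕ (b, X) = (a*b, diag(b)·S + X)."""
    a, S = left
    b, X = right
    decay = [ai * bi for ai, bi in zip(a, b)]          # elementwise product
    state = [[bi * s + x for s, x in zip(S_row, X_row)]
             for bi, S_row, X_row in zip(b, S, X)]     # row i scaled by b[i]
    return decay, state

def token_element(alpha, k, v):
    """Monoid element for one token: (alpha_t, k_t ⊗ v_t)."""
    return alpha, [[ki * vj for vj in v] for ki in k]

def readout(q, S):
    """o_t = q_t · S_t (row vector times a d_k × d_v matrix)."""
    return [sum(qi * row[j] for qi, row in zip(q, S)) for j in range(len(S[0]))]

def max_abs_diff(A, B):
    return max(abs(x - y) for ra, rb in zip(A, B) for x, y in zip(ra, rb))

random.seed(0)
d_k, d_v, T = 3, 2, 5  # toy sizes, chosen only for this example
elems = [token_element([random.uniform(0.1, 0.9) for _ in range(d_k)],
                       [random.gauss(0, 1) for _ in range(d_k)],
                       [random.gauss(0, 1) for _ in range(d_v)])
         for _ in range(T)]

# Sequential decode: S_t = diag(alpha_t) · S_{t-1} + k_t ⊗ v_t, with S_0 = 0 (assumed).
S = [[0.0] * d_v for _ in range(d_k)]
for alpha, kv in elems:
    S = [[a * s + x for s, x in zip(S_row, kv_row)]
         for a, S_row, kv_row in zip(alpha, S, kv)]

# Left-to-right fold with the monoid operator.
left_fold = elems[0]
for e in elems[1:]:
    left_fold = monoid_op(left_fold, e)

# A different grouping, as a parallel scan would use: (e0 ⊕ e1) ⊕ ((e2 ⊕ e3) ⊕ e4).
tree = monoid_op(monoid_op(elems[0], elems[1]),
                 monoid_op(monoid_op(elems[2], elems[3]), elems[4]))

q = [random.gauss(0, 1) for _ in range(d_k)]
o = readout(q, S)  # readout from the final state

print(max_abs_diff(S, left_fold[1]), max_abs_diff(S, tree[1]))
```

Both printed differences should be at floating-point rounding level: the grouping of `⊕` does not change the final state, which is the property a prefix scan relies on.
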
## Vector Decay: Per-Dimension Memory Lifetimes

Unlike scalar decay (one α per head), Spartacus uses **vector decay**: each key dimension i of a head has its own decay rate α_t[i] ∈ (0, 1), which scales row i of the state:

```
S_t[i,j] = α_t[i] · S_{t-1}[i,j] + k_t[i] · v_t[j]
```

This allows different feature dimensions to specialize:

- **Fast-decaying dimensions** (α ≈ 0): suited to local syntax, punctuation, function words
- **Slow-decaying dimensions** (α ≈ 1): suited to entity memory, topic tracking, long-range facts

The decay gate uses a **Sigmoid** activation:

```
α_t = σ(W·x_t + b)
```

| Property | Value |
|---|---|
| Range | α ∈ (0, 1): bounded, no explosion |
| Perfect memory | W·x → +∞ ⟹ σ → 1 (lossless retention in the limit) |
| Full forgetting | W·x → -∞ ⟹ σ → 0 (complete reset in the limit) |
| Stability | α < 1 by construction: the decay never amplifies the state, regardless of input magnitude |
| Bias init | b = 3.0 ⟹ σ(3) ≈ 0.95, so the model starts in "mostly remember" mode |

## Attention Mask: Padding-Aware Recurrence

The monoid recurrence handles `attention_mask` for padded batches (e.g., left-padding during `generate()`). For PAD positions (mask=0):

```
α = α * mask + (1 - mask)     → α = 1   (state preserved unchanged)
k = k * mask, v = v * mask    → kv = 0  (no information injected)
```

Net effect: `S_t = 1·S_{t-1} + 0 = S_{t-1}`. A PAD position acts as the **monoid identity element** and is invisible to the recurrence, so the state is the same whether or not the input is padded.

## Design Choices

- **SiLU-activated keys**: `k = SiLU(k_proj(x))` keeps keys bounded below and mostly non-negative (SiLU's minimum is about -0.28, so keys are not strictly non-negative, and S is not positive semi-definite in general).
This limits "feature erasure", where one token's contribution cancels another's
- **QK-Norm**: RMSNorm on both q and k before readout, stabilizing the scale of q·S when the state matrix accumulates many outer products
- **Output Norm**: RMSNorm on the readout o after `q·S`, further stabilizing scale before gating
- **Output Gate**: `gate = SiLU(gate_proj(x))` modulates the multi-head readout before o_proj (similar to GLA/RetNet). It lets the model suppress or amplify specific head outputs conditioned on the current input
- **Sigmoid decay gate**: Ensures α ∈ (0, 1) by construction, allowing near-perfect memory (α→1) while ruling out state explosion (α>1). The bias is initialized to 3.0 so σ(3)≈0.95, starting in high-retention mode
- **Learnable h0**: The initial state S₀ = h0 is a learnable parameter (zero-initialized), acting as a compressed "system prompt"
- **Log-space decay in scan**: The parallel prefix scan works with `log(α)` to avoid numerical underflow when computing cumulative products over long sequences

## Three Forward Paths

| Path | Condition | Complexity | Description |
|---|---|---|---|
| Training | `use_cache=False` | O(T) parallel scan | Vectorized outer products → parallel prefix scan → vectorized readout |
| Inference prefill | `use_cache=True, T>1` | O(T) parallel scan | Same as training, plus extraction of the final state S_T for the cache |
| Inference decode | `use_cache=True, T=1` | **O(1)** monoid_op | Single `monoid_op` folds the new token into the state → one matmul readout |

## Model Details

| Parameter | Value |
|---|---|
| Model | `NoesisLab/Spartacus-1B-Instruct` |
| Architecture | MonoidForCausalLM |
| Parameters | ~1.34B (tied embeddings) |
| Hidden size | 2048 |
| Intermediate size (MLP) | 8192 |
| Layers | 16 |
| Attention heads | 32 |
| Head dimension | 64 |
| Decay gate | Vector decay (Sigmoid), d=64 per head |
| State matrix per head | 64 × 64 = 4,096 floats |
| Vocabulary | 128,256 (Llama-3.2 tokenizer) |
| Precision | bfloat16 |

## Benchmarks (0-shot)

| Task | Metric | Value | Stderr |
|---|---|---|---|
| ARC-Challenge | acc_norm | 0.3063 | ±0.0135 |
| ARC-Easy | acc | 0.5518 | ±0.0102 |
| HellaSwag | acc_norm | 0.4610 | ±0.0050 |
| PIQA | acc_norm | 0.6915 | ±0.0108 |
| WinoGrande | acc | 0.5225 | ±0.0140 |

### Comparison with ~1B Baselines (0-shot)

Spartacus scores use the metrics from the table above (acc for ARC-E and WinoGrande, acc_norm otherwise). Bold marks the Spartacus column, not the best score in each row.

| Task | Spartacus-1B | TinyLlama-1.1B | Llama 3.2-1B | Mamba-1.4B | RWKV-6-1.6B |
|---|---|---|---|---|---|
| ARC-C | **0.3063** | 0.3268 | ~0.359 | 0.284 | ~0.301 |
| ARC-E | **0.5518** | 0.5547 | ~0.752 | 0.512 | ~0.530 |
| HellaSwag | **0.4610** | 0.4670 | ~0.546 | 0.435 | ~0.450 |
| PIQA | **0.6915** | 0.7210 | ~0.740 | 0.655 | ~0.670 |
| WinoGrande | **0.5225** | 0.5040 | ~0.592 | 0.510 | ~0.515 |

> On these five tasks Spartacus scores above the listed Mamba-1.4B and RWKV-6-1.6B numbers while keeping **O(1) inference time and memory per token**. It trails TinyLlama-1.1B on four of the five tasks and Llama 3.2-1B on all five. Scores marked with ~ are approximate community-reported values.

## Parallel Scan Implementation

The `monoid_scan_cuda.py` module provides a Triton JIT-compiled prefix scan for the vector-decay monoid:

- **Grid**: `(B*H*D_k, ceil(D_v/BLOCK_DV))`, i.e. one program per state-matrix row and D_v block
- **Forward**: Sequential scan along T per row, parallelized across all (batch, head, d_k) dimensions
- **Backward**: Reverse-order adjoint scan with per-row D_v reduction (minimal atomic_add)
- **Fallback**: Pure PyTorch sequential scan for CPU/MPS
- **Auto-dispatch**: CUDA → Triton kernel, otherwise → PyTorch fallback

## Usage

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "NoesisLab/Spartacus-1B-Instruct"

# trust_remote_code is needed because the architecture is defined in the repo (MonoidForCausalLM.py).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Build the prompt with the chat template, then tokenize and move tensors to the model's device.
messages = [{"role": "user", "content": "Hello!"}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)

outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

## File Structure

```
MonoidForCausalLM.py       # Model architecture (MonoidConfig, MonoidAttention, MonoidForCausalLM)
monoid_scan_cuda.py        # Triton JIT parallel prefix scan (vector decay) + PyTorch fallback
model.safetensors          # Model weights (bfloat16)
config.json                # Model configuration
tokenizer.json             # Llama-3.2 tokenizer
ARCH.png                   # Core mechanism diagram (monoid recurrence + parallel scan)
ACC_SPAR.png               # SFT accuracy curve
LOSS_SPAR.png              # SFT loss curve
```

## Citation

```bibtex
@software{spartacus2025,
  title={Spartacus: Causal Monoid Language Model with O(1) Inference},
  author={NoesisLab},
  year={2025},
  url={https://huggingface.co/NoesisLab/Spartacus-1B-Instruct},
  description={Replaces softmax attention with vector-decay monoid state compression for constant-time, constant-memory autoregressive generation}
}
```

## License

Apache 2.0