TMT v3: 30.2% PPL reduction at 48% compute — full paper, code & live demo released

#1
by vigneshwar234 - opened

TemporalMesh Transformer v3 is now fully public 🎉

tl;dr — A new transformer architecture that achieves 29.4 PPL on WikiText-2 vs 42.1 baseline (−30.2%) at 48% relative compute, with zero architectural compromises.


What makes TMT different from every other efficient transformer?

Every prior approach fixes one problem:

  • Longformer/BigBird → sparse attention, but static topology
  • Mamba/RWKV → linear time, but no pairwise attention
  • MoE → high capacity, but uniform depth per token

TMT addresses all three limitations at once:

| Innovation | What it does | Cost |
|---|---|---|
| Mesh Attention | Dynamic $k$NN graph rebuilt per layer from cosine similarity | $O(S \cdot k)$ vs $O(S^2)$ |
| Temporal Decay Encoding | Learned multiplicative scalar attenuates semantically distant tokens post-softmax | ~0% overhead |
| Adaptive Depth Routing | Per-token exit gate: punctuation exits at layer 2, rare tokens at layer 12 | −52% avg compute |
| Dual-Stream FFN | Parallel syntax + semantic streams with sigmoid fusion gate | Same FLOPs as standard FFN |
| EMA Memory Anchors | 16 persistent fast-weight vectors, cross-sequence recall without recurrence | 32 KB extra params |
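
For readers who want a concrete picture of the Mesh Attention row, here is a minimal sketch in plain Python. It is not the code from the `tmt` package: the function names `knn_graph` and `mesh_attention` are hypothetical, there is a single head, and the hidden states are used directly as queries, keys and values with no learned projections. It also builds the graph by brute-force all-pairs cosine similarity, so in this sketch only the attention step touches $S \cdot k$ pairs.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0

def knn_graph(hidden, k):
    """For each token, list the indices of its k most cosine-similar tokens.

    A token is always a candidate neighbor of itself. Ties are broken by index.
    """
    neighbors = []
    for h_i in hidden:
        sims = [(cosine(h_i, h_j), j) for j, h_j in enumerate(hidden)]
        sims.sort(key=lambda s: (-s[0], s[1]))
        neighbors.append([j for _, j in sims[:k]])
    return neighbors

def mesh_attention(hidden, k):
    """Single-head attention in which token i attends only to its k graph neighbors."""
    d = len(hidden[0])
    graph = knn_graph(hidden, k)
    outputs = []
    for i, nbrs in enumerate(graph):
        # Scaled dot-product scores, restricted to the neighbor set.
        scores = [
            sum(a * b for a, b in zip(hidden[i], hidden[j])) / math.sqrt(d)
            for j in nbrs
        ]
        # Numerically stable softmax over the k scores.
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted sum of neighbor vectors.
        outputs.append(
            [sum(w * hidden[j][c] for w, j in zip(weights, nbrs)) for c in range(d)]
        )
    return outputs, graph

# Toy input: four 2-d token states forming two similar pairs.
hidden = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
outputs, graph = mesh_attention(hidden, k=2)
print(graph)
```

Calling this once per layer on that layer's hidden states gives the "rebuilt per layer" behaviour described in the table. With $S = 4$ and $k = 2$ the attention step evaluates 8 pairs instead of 16.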

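To make the Adaptive Depth Routing cost figure concrete, the sketch below shows how a set of per-token exit layers translates into relative compute. The helper `relative_compute` and the exit layers in the example are hypothetical, not measurements from TMT, and the calculation assumes every layer costs the same and ignores the cost of the exit gate itself.

```python
def relative_compute(exit_layers, n_layers):
    """Fraction of layer evaluations used when token t stops after exit_layers[t] layers."""
    if not exit_layers:
        raise ValueError("exit_layers must not be empty")
    if any(not 1 <= e <= n_layers for e in exit_layers):
        raise ValueError("exit layers must lie in [1, n_layers]")
    return sum(exit_layers) / (len(exit_layers) * n_layers)

# Hypothetical exit layers for an 8-token sequence in a 12-layer model:
# easy tokens leave early, one hard token runs the full stack.
exit_layers = [2, 2, 4, 4, 6, 8, 8, 12]
fraction = relative_compute(exit_layers, n_layers=12)
print(f"relative compute: {fraction:.1%}")  # relative compute: 47.9%
```

Under these assumptions, 48% relative compute in a 12-layer model corresponds to an average exit depth of a little under 6 layers.
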
Numbers

| Benchmark | Vanilla | Mamba | TMT |
|---|---|---|---|
| WikiText-2 PPL ↓ | 42.1 | 31.8 | 29.4 |
| WikiText-103 PPL ↓ | 51.3 | 38.4 | 36.1 |
| LongBench ↑ | 41.2 | 51.3 | 53.4 |
| C4 PPL ↓ | 38.4 | 30.1 | 27.4 |
| Throughput (TPS, A100 FP16) | 94K | 148K | 138K |
| VRAM at S=4096 | OOM | 12 GB | 18 GB |
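
The headline −30.2% is the relative change in WikiText-2 perplexity between the Vanilla and TMT columns. The short script below applies the same formula to the other rows, using only the numbers from the table; the helper name `relative_change` is just for illustration.

```python
def relative_change(baseline, value):
    """Relative change of value with respect to baseline (negative means lower)."""
    return (value - baseline) / baseline

# (Vanilla, TMT) pairs copied from the table above. Throughput is in thousands.
results = {
    "WikiText-2 PPL": (42.1, 29.4),
    "WikiText-103 PPL": (51.3, 36.1),
    "C4 PPL": (38.4, 27.4),
    "LongBench": (41.2, 53.4),
    "Throughput (TPS)": (94, 138),
}

for name, (vanilla, tmt) in results.items():
    print(f"{name}: {relative_change(vanilla, tmt):+.1%}")
# WikiText-2 PPL: -30.2%
# WikiText-103 PPL: -29.6%
# C4 PPL: -28.6%
# LongBench: +29.6%
# Throughput (TPS): +46.8%
```

The VRAM row is left out because the Vanilla entry is OOM, so there is no baseline value to compare against.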

Resources

Quick start

pip install temporalmesh-transformer

import torch

from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel

# Build a 12-layer model with a GPT-2-sized vocabulary.
model = TMTModel(TMTConfig(vocab_size=50257, d_model=512, n_heads=8, n_layers=12))

# Forward pass on one random sequence of 256 token ids.
out = model(torch.randint(0, 50257, (1, 256)))

# Output fields: out.logits, out.exit_masks, out.graph_edges, out.confidences

Happy to answer questions about the architecture, training setup, or ablations. All results are reproducible with the provided training scripts and 3 fixed seeds.

— Vigneshwar LK
