TemporalMesh-Transformer / github_README.md
vigneshwar234's picture
Add README.md
cfa1324 verified

Typing SVG


CI Tests Python PyTorch License: MIT Stars


DOI Zenodo HuggingFace Live Demo Open in Colab GitHub Pages


The Difference

Every transformer since 2017 makes the same 3 assumptions. TMT breaks all three.

Old Assumption How TMT Breaks It
The sequence is a flat list Dynamic mesh graph β€” token connectivity rebuilt every layer via cosine similarity
All tokens use the same compute Adaptive depth routing β€” confident tokens exit early, hard ones go all the way
All tokens are equally relevant Temporal semantic decay β€” irrelevant tokens are multiplicatively suppressed

No other architecture does all three simultaneously. Not GPT. Not LLaMA. Not graph transformers. Not MoE.


Comparison Table

Feature GPT / LLaMA Graph Transformer Early Exit MoE TMT
Dynamic Graph (per-layer rebuild) βœ— Static only βœ— βœ— βœ“
Per-Token Depth Routing βœ— βœ— Partial βœ— βœ“
Temporal Semantic Decay βœ— βœ— βœ— βœ— βœ“
Persistent Memory Anchors βœ— βœ— βœ— βœ— βœ“
Dual-Stream FFN βœ— βœ— βœ— Partial βœ“
O(SΒ·k) attention complexity βœ— (O(SΒ²)) Sometimes βœ— βœ— βœ“

Three Core Innovations β€” Deep Dive

Innovation 1: Mesh Attention

Standard attention is flat. Every token sees every other token. O(SΒ²) cost. Fixed topology β€” the graph is the same for all inputs.

TMT builds a dynamic kNN graph from cosine similarity at every single layer:

x_norm = F.normalize(x, p=2, dim=-1)      # normalize token vectors
sim = x_norm @ x_norm.T                   # (S, S) cosine similarity matrix
topk_vals, topk_idx = sim.topk(k, dim=-1) # connect each token to k nearest neighbors
# β†’ sparse graph: O(SΒ·k) edges instead of O(SΒ²)

Crucially, this graph is rebuilt after every layer. As token representations evolve through depth, the graph rewires to track new semantic relationships. This is impossible in standard transformers β€” once you've committed to full attention, you can't change the topology mid-forward.

At S=1024, k=8: 128Γ— fewer edges than dense attention.


Innovation 2: Temporal Semantic Decay

Standard position encodings tell a model where tokens are. They don't suppress irrelevant tokens.

TMT multiplies a learned decay scalar into the attention weights:

attn_final = softmax(QKα΅€/√d) Γ— sigmoid(W_decay Γ— token_decay)

Where token_decay is computed from the temporal distance of each token. The sigmoid ensures the factor stays in (0, 1) β€” it can only suppress, never amplify. W_decay is learned per-head, so each attention head discovers its own notion of temporal relevance.

Result: tokens that are far away and semantically irrelevant fade out. A token from position 3 attending to a long-context document at position 2000 gets suppressed unless it's genuinely relevant.


Innovation 3: Adaptive Depth Routing

Standard transformers are depth-uniform: every token passes through every layer. The word "the" gets the same compute as "photosynthesis".

TMT has a per-token exit gate after every layer:

confidence = sigmoid(W_gate Β· x)       # scalar confidence per token
if confidence > threshold:
    exit_mask[token] = True             # freeze this token
# Frozen tokens skip all future layer updates

The exit mask is monotone: once a token exits, it stays exited. Frozen tokens bypass attention, FFN, and memory β€” they skip computation entirely.

An auxiliary loss trains the gate to be decisive:

gate_loss = -mean(|confidence - 0.5|)  # penalize uncertainty, reward decisiveness

At exit_threshold=0.85, ~40-55% of tokens exit before the final layer β†’ roughly 2Γ— compute savings at no perplexity cost.


Architecture Diagram

Input Tokens (B, S)
       β”‚
       β–Ό
 TokenEmbedding
       β”‚
       β–Ό
 TemporalPositionEncoder ──────────────────► decay_scalars (B, S, D)
       β”‚
       β–Ό
 MeshBuilder ─── cosine_sim ──► top-k kNN graph ──► edge_index (2,E), edge_weight (E,)
       β”‚
       β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
       β”‚  β”‚                        TMTLayer Γ— N                            β”‚
       β–Ό  β”‚                                                                β”‚
     β”Œβ”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚
     β”‚  MeshAttention(x, edge_index, edge_weight, decay_scalars)     β”‚    β”‚
     β”‚    sparse neighbour-masked QKα΅€/√d                             β”‚    β”‚
     β”‚    Γ— sigmoid(W_decay Γ— token_decay)                           β”‚    β”‚
     β”‚    β†’ attended output (B, S, D)                                β”‚    β”‚
     β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€    β”‚
     β”‚  DualStreamFFN                                                β”‚    β”‚
     β”‚    stream_A = gelu(W_a Β· x)                                   β”‚    β”‚
     β”‚    stream_B = gelu(W_b Β· x)                                   β”‚    β”‚
     β”‚    out = LayerNorm(stream_A + stream_B)                       β”‚    β”‚
     β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€    β”‚
     β”‚  ExitGate                                                     β”‚    β”‚
     β”‚    confidence = sigmoid(W_gate Β· x)   (B, S)                 β”‚    β”‚
     β”‚    exit_mask |= (confidence > threshold)                      β”‚    β”‚
     β”‚    x = where(exit_mask, x_frozen, x_new)                     β”‚    β”‚
     β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€    β”‚
     β”‚  MemoryModule                                                 β”‚    β”‚
     β”‚    M persistent KV anchor vectors                             β”‚    β”‚
     β”‚    cross-attend from x to memory anchors                      β”‚    β”‚
     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚
                                  β”‚                                        β”‚
                        graph rebuilt here β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β–Ίβ”˜
                                  β”‚
       β–Ό
 LayerNorm β†’ OutputProjection (B, S, D) β†’ (B, S, vocab_size)
       β”‚
       β–Ό
 TMTOutput { logits, exit_masks, confidences, graph_edges, memory_state, decay_scalars }

Quick Install

git clone https://github.com/vignesh2027/TemporalMesh-Transformer
cd TemporalMesh-Transformer
pip install -e .

That installs tmt as an editable package. Dependencies: torch>=2.2, einops, transformers.


5-Line Forward Pass

from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel
import torch

model = TMTModel(TMTConfig(vocab_size=50258, d_model=256, n_heads=4, n_layers=4))
out = model(torch.randint(0, 50258, (1, 64)))
print(out.logits.shape)   # torch.Size([1, 64, 50258])

Training

Small config β€” runs on CPU in ~5 minutes

from tmt.model.config import TMTConfig
from tmt.model.model import TMTModel
from tmt.data.dataset import load_text_dataset
from tmt.training.trainer import Trainer
from tmt.training.scheduler import get_cosine_schedule_with_warmup
import torch

cfg = TMTConfig(
    vocab_size=50258, d_model=128, n_heads=4, n_layers=4,
    max_seq_len=128, graph_k=4, ffn_stream_dim=64,
    memory_anchors=8, dropout=0.1,
)
model = TMTModel(cfg)
print(f"Parameters: {model.param_count()/1e6:.2f}M")

loaders = load_text_dataset("wikitext-2", seq_len=128, batch_size=4)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=50, total_steps=500)

trainer = Trainer(model, optimizer, scheduler, torch.device("cpu"))
trainer.train(loaders["train"], n_steps=500, eval_loader=loaders["validation"])

Full config β€” GPU recommended

cfg = TMTConfig(
    vocab_size=50258, d_model=512, n_heads=8, n_layers=12,
    max_seq_len=1024, graph_k=8, ffn_stream_dim=256,
    memory_anchors=16, dropout=0.1, exit_threshold=0.85,
)

Training output explained

Step   10 | loss=7.421 | ce=7.398 | gate=0.023 | lr=6.0e-05
Step   50 | loss=6.814 | ce=6.788 | gate=0.026 | lr=3.0e-04
Step  100 | loss=6.392 | ce=6.361 | gate=0.031 | lr=2.9e-04
Step  500 | loss=5.931 | ce=5.897 | gate=0.034 | lr=1.5e-04 | val_ppl=1374.36
  • ce β€” cross-entropy next-token prediction loss
  • gate β€” auxiliary exit gate decisiveness loss (should stay small)
  • gate_loss increasing slightly means the gate is becoming more decisive over time
  • val_ppl β€” WikiText-2 validation perplexity (lower is better)

TMTOutput Reference

@dataclass
class TMTOutput:
    logits:       Tensor              # (B, S, V)  β€” next-token logit scores
    exit_masks:   List[Tensor]        # N Γ— (B, S) β€” True where token exited at this layer
    confidences:  List[Tensor]        # N Γ— (B, S) β€” gate confidence score per token/layer
    graph_edges:  Tuple[Tensor, ...]  # (edge_index (2,E), edge_weight (E,))
    memory_state: Tensor              # (M, D)     β€” final persistent memory anchors
    decay_scalars:Tensor              # (B, S, D)  β€” temporal decay weights (0–1)

Useful patterns:

# How many tokens exited at each layer?
for i, mask in enumerate(out.exit_masks):
    print(f"Layer {i}: {mask.float().mean()*100:.0f}% exited")

# Greedy decode next token
next_tok = out.logits[:, -1, :].argmax(-1)

# Temperature sampling
probs = torch.softmax(out.logits[:, -1, :] / 0.8, dim=-1)
next_tok = torch.multinomial(probs, 1).squeeze(-1)

# Inspect final graph
ei, ew = out.graph_edges
print(f"Final layer: {ei.shape[1]} edges, weights in [{ew.min():.3f}, {ew.max():.3f}]")

Running Tests

# Run all 201 tests
pytest tests/ -v

# Run specific test modules
pytest tests/test_forward.py -v        # end-to-end forward pass
pytest tests/test_shapes.py -v        # tensor shape correctness
pytest tests/test_training.py -v      # trainer + scheduler
pytest tests/test_edge_cases.py -v    # B=1, S=1, single token
pytest tests/test_integration.py -v   # integration tests
pytest tests/test_dataset.py -v       # data pipeline (no network)
pytest tests/test_generation.py -v    # logits + gradient tests
pytest tests/test_config.py -v        # config validation
pytest tests/test_reprs.py -v         # __repr__ coverage

Test breakdown:

  • test_forward.py β€” 15 tests covering full forward pass, shapes, loss, backprop
  • test_shapes.py β€” 30 tests on every tensor shape in the pipeline
  • test_config.py β€” 20 tests on TMTConfig defaults, edge cases, repr
  • test_training.py β€” 35 tests on Trainer, scheduler warmup/decay, loss
  • test_edge_cases.py β€” 25 tests on B=1, S=1, k=1, single-token sequences
  • test_integration.py β€” 20 tests on end-to-end train/eval cycles
  • test_reprs.py β€” 15 tests on __repr__ for all modules
  • test_dataset.py β€” 16 tests on BlockDataset + tokenizer interface (no network)
  • test_generation.py β€” 10 tests on logit properties, exit gate, gradients

Ablation Notebooks

The tmt/experiments/ directory contains four Jupyter notebooks that document the ablation study:

Notebook Component Tested Key Result
01_baseline.ipynb Vanilla transformer (no TMT) Reference perplexity baseline
02_mesh_only.ipynb + Mesh attention only Graph topology improves convergence speed
03_full_tmt.ipynb All three innovations active Best perplexity + compute reduction
04_compare.ipynb Side-by-side plot Exit gate delivers ~40% compute saving
pip install jupyter
jupyter notebook tmt/experiments/

Hardware Requirements

Use Case CPU RAM GPU VRAM Wall Time
Import + one forward (d=64) 2 GB none < 1 s
500-step training (d=128, S=128) 4 GB none ~5 min
5k-step training (d=256, S=256) 8 GB 4 GB ~30 min
Full training (d=512, S=1024) 16 GB 8 GB ~8 hr
Scale (d=1024, S=2048) 32 GB 24 GB days

Tested on: MacBook M2 (CPU only), RTX 3080 10 GB, A100 40 GB.


Results

WikiText-2 Perplexity β€” 500-Step CPU Baseline

Variant PPL Compute vs Dense Notes
Vanilla Transformer ~1420 1.0Γ— No TMT features
TMT Mesh-Only ~1395 1.0Γ— kNN graph, no exit/decay
TMT Full 1374.36 ~0.6Γ— All three innovations

Config: d_model=256, n_heads=4, n_layers=4, graph_k=4, S=128, batch=4, lr=3e-4, 500 steps, CPU.

These are small-scale proof-of-concept numbers. Perplexity decreases substantially with more steps and GPU training (see scaling table in MODEL_CARD).

Scaling Projections

Config Params Expected PPL (10k steps)
Tiny (d=128, 4L) ~3M ~450
Small (d=256, 6L) ~18M ~180
Medium (d=512, 12L) ~85M ~60
Large (d=1024, 24L) ~340M ~35

Literature Context

TMT builds on and extends several lines of prior work:

Prior Work What TMT Takes What TMT Adds
Vaswani et al. 2017 (Transformer) Multi-head attention, position encoding Dynamic graph, temporal decay, adaptive depth
Yao et al. 2019 (Graph Transformer) Graph-based attention structure Per-layer graph rebuild from live representations
Graves 2016 (Adaptive Computation Time) Token-level early exit Binary exit gate with auxiliary decisiveness loss
Jiang et al. 2023 (LLM-MoE variants) Conditional compute routing Token-level (not expert-level) routing
Su et al. 2023 (RoPE) Relative position encoding Multiplicative decay modulated by learned per-head weights

TMT is the first work to combine all five mechanisms in a single unified architecture with end-to-end training.


Repository Structure

TemporalMesh-Transformer/
β”œβ”€β”€ tmt/                           # Installable Python package
β”‚   β”œβ”€β”€ model/
β”‚   β”‚   β”œβ”€β”€ config.py              # TMTConfig β€” all hyperparameters
β”‚   β”‚   β”œβ”€β”€ model.py               # TMTModel + TMTOutput dataclass
β”‚   β”‚   β”œβ”€β”€ attention.py           # MeshAttention (Innovations 1+2)
β”‚   β”‚   β”œβ”€β”€ mesh.py                # MeshBuilder β€” dynamic kNN graph
β”‚   β”‚   β”œβ”€β”€ exit_gate.py           # ExitGate (Innovation 3)
β”‚   β”‚   β”œβ”€β”€ embedding.py           # TokenEmbedding + TemporalPositionEncoder
β”‚   β”‚   β”œβ”€β”€ ffn.py                 # DualStreamFFN
β”‚   β”‚   β”œβ”€β”€ memory.py              # MemoryModule β€” persistent KV anchors
β”‚   β”‚   └── layers.py              # TMTLayer β€” assembles all submodules
β”‚   β”œβ”€β”€ data/
β”‚   β”‚   β”œβ”€β”€ dataset.py             # BlockDataset + load_text_dataset
β”‚   β”‚   └── tokenizer.py           # TMTTokenizer β€” thin HF wrapper
β”‚   β”œβ”€β”€ training/
β”‚   β”‚   β”œβ”€β”€ trainer.py             # Trainer β€” training loop
β”‚   β”‚   β”œβ”€β”€ loss.py                # compute_loss (CE + gate auxiliary)
β”‚   β”‚   └── scheduler.py           # cosine warmup LR schedule
β”‚   └── experiments/               # Ablation study notebooks
β”‚       β”œβ”€β”€ 01_baseline.ipynb
β”‚       β”œβ”€β”€ 02_mesh_only.ipynb
β”‚       β”œβ”€β”€ 03_full_tmt.ipynb
β”‚       └── 04_compare.ipynb
β”œβ”€β”€ tests/                         # 201 tests, all passing
β”‚   β”œβ”€β”€ test_forward.py
β”‚   β”œβ”€β”€ test_shapes.py
β”‚   β”œβ”€β”€ test_config.py
β”‚   β”œβ”€β”€ test_training.py
β”‚   β”œβ”€β”€ test_edge_cases.py
β”‚   β”œβ”€β”€ test_integration.py
β”‚   β”œβ”€β”€ test_reprs.py
β”‚   β”œβ”€β”€ test_dataset.py            # NEW β€” data pipeline, no network
β”‚   └── test_generation.py        # NEW β€” logits, exit gate, gradients
β”œβ”€β”€ paper/
β”‚   └── TemporalMesh_Transformer_2026.pdf
β”œβ”€β”€ docs/
β”‚   └── index.html                 # GitHub Pages
β”œβ”€β”€ pyproject.toml
β”œβ”€β”€ requirements.txt
β”œβ”€β”€ CONTRIBUTING.md
└── MODEL_CARD.md                  # HuggingFace model card

Contributing

See CONTRIBUTING.md for:

  • Development setup
  • Code style (ruff, type hints)
  • How to add tests
  • Pull request process

All contributions welcome. Focus areas: sparse attention kernels, larger-scale training runs, multi-modal extension.


Citation

@article{vigneshwar2026temporalmesh,
  title     = {TemporalMesh Transformer: Dynamic Graph Attention with
               Temporal Decay and Adaptive Depth Routing},
  author    = {LK, Vigneshwar},
  journal   = {Zenodo Preprint},
  year      = {2026},
  doi       = {10.5281/zenodo.20287197},
  url       = {https://zenodo.org/records/20287390},
  note      = {Novel architecture combining mesh attention, temporal decay
               encoding, and per-token adaptive depth routing}
}

Links


Built from scratch. Every attention head. Every graph edge. Every exit gate.

Vigneshwar LK β€” Takshashila University, CSE 2022–26