ModuleMind / agents /modmind /model.py
Quazim0t0's picture
Add files using upload-large-folder tool
45e7dfb verified
Raw
History Blame Contribute Delete
46.3 kB
"""
model.py -- SpikeWhaleLM: combined architecture from SpikeTransformer (My Project) + NanoWhale.
Architecture flow:
Embedding
-> Engram delta (N-gram memory, My Project)
-> [expand to hc_mult copies if HC enabled]
-> N x TransformerBlock:
HC pre-op (NanoWhale) -> RMSNorm -> MLA+DERF+XSA Attention (combined)
-> HC post-op
HC pre-op -> RMSNorm -> MoE FFN w/ shared expert (NanoWhale)
-> HC post-op
-> [mean-pool hc_mult copies if HC enabled]
-> RMSNorm
-> LM head + MTP heads (NanoWhale)
Component origins:
RMSNorm, RotaryEmbedding -- both (standard)
Engram / DERFContextGate -- My Project
MLADerfXSAAttention -- MLA from NanoWhale + DERF+XSA from My Project
SparseMoEFFN w/ shared expert -- NanoWhale MoE structure + My Project aux loss
HyperConnectionLayer -- NanoWhale
SpikeWhaleLM + MTP heads -- NanoWhale
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.utils.checkpoint import checkpoint as gradient_checkpoint
from config import SpikeWhaleConfig
# ---------------------------------------------------------------------------
# Primitives
# ---------------------------------------------------------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    Args:
        dim: size of the last (normalised) dimension; also the length of ``weight``.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x^2) + eps) * weight`` (mean taken over the last dim)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        # Same evaluation order as (x * inv_rms) * weight.
        return x * inv_rms * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary position embedding for the rope partition of Q and K (qk_rope_head_dim dims only).

    Cosine/sine tables for positions ``0 .. max_positions - 1`` are precomputed
    once and stored as buffers; ``forward`` only gathers from them.

    Args:
        dim: width of the rotated partition (``dim // 2`` frequencies are used).
        max_positions: number of positions covered by the cached tables.
        theta: base of the geometric frequency progression.
    """

    def __init__(self, dim: int, max_positions: int = 4096, theta: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        self.register_buffer("inv_freq", inv_freq)

        positions = torch.arange(max_positions).float()
        angles = torch.outer(positions, inv_freq)  # [max_positions, dim // 2]
        self.register_buffer("cos_cache", angles.cos())
        self.register_buffer("sin_cache", angles.sin())

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` by the angles of ``position_ids``.

        x: [B, H, S, rope_dim]
        position_ids: [B, S] integer indices into the cached tables
        returns: tensor of the same shape as ``x``

        The first half of the last dim is paired with the second half
        (half-split layout, not interleaved).
        """
        cos = self.cos_cache[position_ids].unsqueeze(1)  # [B, 1, S, rope_dim//2]
        sin = self.sin_cache[position_ids].unsqueeze(1)
        half = cos.shape[-1]
        lower, upper = x[..., :half], x[..., half:]
        rotated_lower = lower * cos - upper * sin
        rotated_upper = lower * sin + upper * cos
        return torch.cat([rotated_lower, rotated_upper], dim=-1)
# ---------------------------------------------------------------------------
# Engram: N-gram hash lookup + DERF gate (My Project, preserved)
# ---------------------------------------------------------------------------
class TokenCompressor(nn.Module):
    """Fixed random linear projection from ``embed_dim`` to ``compress_dim``.

    The projection weight is drawn from N(0, 0.02^2) and then frozen. In this
    file its output is only consumed by MultiHeadHashLookup, which turns it into
    an integer bucket index via ``.long()``; that cast is non-differentiable, so
    the weight could never receive a gradient anyway. Freezing it keeps it out of
    the optimizer and out of weight decay (which would otherwise slowly shrink
    it), while the tensor still lives in the state_dict, so checkpoints are
    unaffected.
    """

    def __init__(self, embed_dim: int, compress_dim: int):
        super().__init__()
        projection = nn.Linear(embed_dim, compress_dim, bias=False)
        nn.init.normal_(projection.weight, std=0.02)
        # Frozen on purpose: acts like an LSH-style random projection.
        projection.weight.requires_grad_(False)
        self.proj = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dim of ``x`` from embed_dim to compress_dim."""
        compressed = self.proj(x)
        return compressed
class MultiHeadHashLookup(nn.Module):
    """Multi-head hashed n-gram embedding lookup.

    For every n-gram order ``n`` in ``1 .. max_ngram`` and every head, the
    compressed vectors of the ``n`` tokens ending at a position are projected
    with fixed random unit vectors, summed, truncated to an integer and reduced
    modulo ``table_size`` to pick a row of that head's embedding table. The
    rows retrieved for all orders and heads are averaged per position.

    Args:
        num_heads: number of independent hash heads (one embedding table each).
        table_size: number of rows in each table.
        compress_dim: width of the incoming compressed vectors.
        out_dim: width of the retrieved embeddings.
        max_ngram: highest n-gram order hashed.
    """

    def __init__(self, num_heads: int, table_size: int,
                 compress_dim: int, out_dim: int, max_ngram: int = 3):
        super().__init__()
        self.num_heads = num_heads
        self.table_size = table_size
        self.max_ngram = max_ngram
        self.out_dim = out_dim

        # Create every table first, then re-initialise them, so the random
        # stream is consumed in the same order as before.
        head_tables = [nn.Embedding(table_size, out_dim) for _ in range(num_heads)]
        self.tables = nn.ModuleList(head_tables)
        for table in self.tables:
            nn.init.normal_(table.weight, std=0.01)

        # One fixed, row-normalised projection per (order n, offset k inside the window).
        for n in range(1, max_ngram + 1):
            for k in range(n):
                plane = torch.randn(num_heads, compress_dim)
                plane = plane / (plane.norm(dim=1, keepdim=True) + 1e-8)
                self.register_buffer(f"hash_proj_n{n}_p{k}", plane)

    def forward(self, compressed: torch.Tensor) -> torch.Tensor:
        """
        compressed: [B, S, compress_dim]
        returns: [B, S, out_dim]

        The whole sequence is handled at once: the outer loop runs over n-gram
        orders (at most max_ngram iterations), never over the S positions.

        NOTE(review): the bucket index is ``|h|`` truncated to an integer, so any
        hash value with magnitude below 1 lands in bucket 0 -- worth confirming
        that the compressed activations are large enough to spread over the table.
        """
        batch, seq_len, _ = compressed.shape
        dev = compressed.device
        out = torch.zeros(batch, seq_len, self.out_dim, device=dev, dtype=compressed.dtype)
        # Number of (order x head) lookups each position receives. Early
        # positions get fewer because longer n-grams do not exist there yet.
        contributions = torch.zeros(seq_len, device=dev)

        # Orders longer than the sequence have no valid window at all.
        highest_order = min(self.max_ngram, seq_len)
        for n in range(1, highest_order + 1):
            first_valid = n - 1                 # first position with a full order-n window
            num_windows = seq_len - first_valid
            hashed = torch.zeros(batch, num_windows, self.num_heads, device=dev)
            for k in range(n):
                plane = getattr(self, f"hash_proj_n{n}_p{k}")  # [num_heads, compress_dim]
                # k-th token of every order-n window, all windows at once.
                window_tokens = compressed[:, k:k + num_windows, :].float()
                hashed = hashed + torch.matmul(window_tokens, plane.t())
            bucket = hashed.abs().long() % self.table_size  # [B, num_windows, num_heads]
            for head, table in enumerate(self.tables):
                out[:, first_valid:, :] = out[:, first_valid:, :] + table(bucket[:, :, head])
            contributions[first_valid:] += self.num_heads

        # The division promotes low-precision inputs to float32; cast back so the
        # result has the same dtype as the input.
        averaged = out / contributions.view(1, -1, 1).clamp(min=1)
        return averaged.to(compressed.dtype)
class DERFContextGate(nn.Module):
    """
    DERF gate applied to a retrieved embedding.

    gate = gamma * (erf(alpha * proj([retrieved, x]) + bias) + 1) / 2
    output = retrieved * gate

    ``alpha``, ``bias`` and ``gamma`` are learnable per-channel vectors. With
    alpha = gamma = 1 at init, a large negative ``init_bias`` pushes erf towards
    -1, so the gate starts almost closed.

    Args:
        obs_size: channel width of both ``retrieved`` and ``x``.
        init_bias: initial value of every entry of ``bias``.
    """

    def __init__(self, obs_size: int, init_bias: float = -4.0):
        super().__init__()
        self.proj = nn.Linear(obs_size * 2, obs_size)
        self.alpha = nn.Parameter(torch.ones(obs_size))
        self.bias = nn.Parameter(torch.full((obs_size,), init_bias))
        self.gamma = nn.Parameter(torch.ones(obs_size))

    def forward(self, retrieved: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """Return ``retrieved`` scaled channel-wise by the gate computed from both inputs."""
        joint = torch.cat([retrieved, x], dim=-1)
        pre_activation = self.alpha * self.proj(joint) + self.bias
        openness = (torch.erf(pre_activation) + 1.0) / 2.0  # in [0, 1]
        gate = self.gamma * openness
        return retrieved * gate
class EngramModule(nn.Module):
    """
    N-gram hash memory with a DERF gate (My Project), fully vectorized.

    Pipeline: frozen compression of the (detached) hidden state -> hashed
    n-gram table lookup -> gate conditioned on the retrieved vector and the
    live hidden state. There is no Python loop over sequence positions; the
    lookup consumes the full [B, S, compress_dim] tensor in one call.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        hidden = cfg.hidden_size
        compress_dim = cfg.engram_compress_dim
        self.compressor = TokenCompressor(hidden, compress_dim)
        self.lookup = MultiHeadHashLookup(
            num_heads=cfg.engram_num_heads,
            table_size=cfg.engram_table_size,
            compress_dim=compress_dim,
            out_dim=hidden,
            max_ngram=cfg.engram_max_ngram,
        )
        self.gate = DERFContextGate(hidden, cfg.engram_gate_init_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, H] -> engram_delta: [B, S, H]"""
        # The hash path is detached from x: only the gate sees the live tensor.
        hash_features = self.compressor(x.detach())  # [B, S, compress_dim]
        memory = self.lookup(hash_features)          # [B, S, H]
        return self.gate(memory, x)                  # [B, S, H]
# ---------------------------------------------------------------------------
# Hyper-Connections (NanoWhale, simplified)
# ---------------------------------------------------------------------------
class HyperConnectionLayer(nn.Module):
    """
    Simplified Hyper-Connections around one sublayer (attention or FFN).

    The residual state is kept as ``hc_mult`` parallel streams.
      * pre-op:  softmax-weighted average of the streams -> single sublayer input.
      * post-op: the sublayer output is added to every stream, scaled by a
                 softmax-normalised per-stream weight.

    Full HC uses Sinkhorn-normalised 2D routing matrices; this variant uses
    softmax-normalised 1D weights for both directions. ``hidden_size``,
    ``sinkhorn_iters`` and ``eps`` are accepted but not used by this class.
    """

    def __init__(self, hidden_size: int, hc_mult: int,
                 sinkhorn_iters: int = 20, eps: float = 1e-6):
        super().__init__()
        self.hc_mult = hc_mult
        # The per-stream logits must differ between streams at init. The streams
        # start out as identical copies; with uniform weights they would remain
        # identical after every sublayer, and then the softmax Jacobian applied
        # to equal per-stream gradients is exactly zero -- pre_weight and
        # post_weight would never move. A small, centred ramp keeps the softmax
        # near-uniform while giving each stream a distinct value, so the streams
        # diverge after the first sublayer and gradients reach the HC weights.
        scale = max(hc_mult, 1)
        descending = torch.linspace(0.5, -0.5, hc_mult) / scale
        ascending = torch.linspace(-0.5, 0.5, hc_mult) / scale
        self.pre_weight = nn.Parameter(descending)   # mixes streams into one input
        self.post_weight = nn.Parameter(ascending)   # spreads the delta over streams

    def pre_op(self, copies: torch.Tensor) -> torch.Tensor:
        """copies: [B, hc_mult, S, H] -> [B, S, H]"""
        mix = F.softmax(self.pre_weight, dim=0).view(1, -1, 1, 1)
        weighted = copies * mix
        return weighted.sum(dim=1)

    def post_op(self, copies: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """
        copies: [B, hc_mult, S, H]
        delta: [B, S, H]
        Returns updated copies: [B, hc_mult, S, H]
        """
        share = F.softmax(self.post_weight, dim=0).view(1, -1, 1, 1)
        per_stream_delta = delta.unsqueeze(1) * share
        return copies + per_stream_delta
# ---------------------------------------------------------------------------
# MLA + DERF + XSA Attention (combined)
# ---------------------------------------------------------------------------
class MLADerfXSAAttention(nn.Module):
    """
    Multi-Head Latent Attention (NanoWhale) with DERF scores + XSA correction (My Project).

    MLA (from NanoWhale):
        Q: hidden -> q_lora_rank (RMSNorm) -> num_heads * head_dim (low-rank projection)
        K, V: hidden -> num_kv_heads * head_dim (direct projections, MQA/GQA)
        Output: num_heads * head_dim -> o_lora_rank -> hidden (low-rank output)
        Partial RoPE: applied only to the last qk_rope_head_dim dims of Q and K
    DERF (from My Project):
        Replaces softmax: erf(alpha * scores + bias) * gamma, shifted to [0, gamma],
        masked, then normalised over keys. Per-head learnable alpha, bias, gamma.
        NOTE(review): because the weights are normalised per row after the shift,
        gamma cancels out of the result (up to the 1e-8 epsilon).
    XSA (from My Project):
        After computing the weighted value sum y, subtract the component of y that
        projects onto each position's own value vector, so the output carries
        only cross-position information.

    Masking convention: when ``attention_mask`` is given, entries below -1.0 are
    treated as masked. NOTE(review): this assumes an additive float mask
    broadcastable to [B, H, S, N] that already encodes causality -- confirm
    against the caller. When it is None, a causal mask is always applied, with
    the S queries aligned to the LAST S of the N key positions.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.num_kv_heads = cfg.num_key_value_heads
        self.head_dim = cfg.head_dim
        self.qk_rope_head_dim = cfg.qk_rope_head_dim
        self.nope_head_dim = cfg.nope_head_dim
        self.hidden_size = cfg.hidden_size
        self.use_derf = cfg.use_derf
        self.use_xsa = cfg.use_xsa
        self.dropout_p = cfg.attention_dropout
        self.kv_groups = self.num_heads // self.num_kv_heads

        # Low-rank Q projection (MLA)
        self.q_a_proj = nn.Linear(cfg.hidden_size, cfg.q_lora_rank, bias=False)
        self.q_a_norm = RMSNorm(cfg.q_lora_rank, cfg.rms_norm_eps)
        self.q_b_proj = nn.Linear(cfg.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        # Direct K, V projections (MQA/GQA)
        self.k_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        # Low-rank output projection (MLA)
        self.o_a_proj = nn.Linear(self.num_heads * self.head_dim, cfg.o_lora_rank, bias=False)
        self.o_b_proj = nn.Linear(cfg.o_lora_rank, cfg.hidden_size, bias=False)
        # Partial RoPE: applied to qk_rope_head_dim dims only
        self.rope = RotaryEmbedding(
            self.qk_rope_head_dim,
            max_positions=cfg.max_position_embeddings,
            theta=cfg.rope_theta,
        )
        # DERF parameters: one per query head (My Project)
        if self.use_derf:
            self.derf_alpha = nn.Parameter(torch.ones(self.num_heads))
            self.derf_bias = nn.Parameter(torch.zeros(self.num_heads))
            self.derf_gamma = nn.Parameter(torch.ones(self.num_heads))

        nn.init.normal_(self.q_a_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.q_b_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.k_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.v_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.o_a_proj.weight, std=cfg.initializer_range)
        nn.init.normal_(self.o_b_proj.weight, std=cfg.initializer_range)

    @staticmethod
    def _causal_block_mask(q_len: int, kv_len: int, device: torch.device) -> torch.Tensor:
        """Return a bool [1, 1, q_len, kv_len] mask, True where a key is in a query's future.

        The q_len queries are aligned to the last q_len key positions, so query i
        may see keys ``0 .. (kv_len - q_len) + i``. With q_len == kv_len this is
        the ordinary causal mask; with q_len == 1 nothing is masked.
        """
        blocked = torch.ones(q_len, kv_len, dtype=torch.bool, device=device)
        return torch.triu(blocked, diagonal=kv_len - q_len + 1).unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention over ``x``.

        Args:
            x: [B, S, hidden] normalised hidden states.
            position_ids: [B, S] positions used for RoPE.
            attention_mask: optional mask; entries < -1.0 are masked (see class docstring).
            past_key_value: optional cached (k, v), each [B, num_kv_heads, past, head_dim].
            use_cache: when True, also return the updated (k, v) cache.

        Returns:
            (output [B, S, hidden], present) where present is the un-expanded
            (k, v) pair including the past, or None when use_cache is False.
        """
        B, S, _ = x.shape

        # Q via low-rank projection with intermediate norm (MLA)
        q = self.q_a_norm(self.q_a_proj(x))
        q = self.q_b_proj(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        # [B, num_heads, S, head_dim]

        # K, V direct projections -> [B, num_kv_heads, S, head_dim]
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Partial RoPE: split into nope and rope partitions, rotate only the rope part
        q_nope = q[..., :self.nope_head_dim]
        q_rope = q[..., self.nope_head_dim:]
        k_nope = k[..., :self.nope_head_dim]
        k_rope = k[..., self.nope_head_dim:]
        q_rope = self.rope(q_rope, position_ids)
        k_rope = self.rope(k_rope, position_ids)
        q = torch.cat([q_nope, q_rope], dim=-1)
        k = torch.cat([k_nope, k_rope], dim=-1)

        # KV cache for inference: keys are cached AFTER rotation, before head expansion.
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None
        N = k.shape[2]  # total key positions (past + current)

        # Expand KV heads for MQA/GQA
        if self.kv_groups > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)
            v = v.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)

        if self.use_derf:
            # DERF replaces softmax with an erf nonlinearity, so it cannot use the
            # fused kernel and must materialise the score matrix.
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

            # BUGFIX: always apply the offset causal mask when no explicit mask is
            # given. Previously a call with a KV cache and no attention_mask used
            # an all-False mask, which is only correct for S == 1; a multi-token
            # step on top of a cache (chunked prefill) let each token attend to
            # later tokens of the same chunk. The offset mask below is identical
            # to the old behaviour for S == N and for S == 1, and now matches the
            # softmax branch for every other case.
            if attention_mask is not None:
                is_masked = attention_mask < -1.0
            else:
                is_masked = self._causal_block_mask(S, N, scores.device)

            # Do NOT fill with -inf: if alpha ever reaches 0.0, 0.0 * -inf = NaN.
            # A large finite negative keeps the erf input well-defined.
            safe_scores = scores.masked_fill(is_masked, -10000.0)
            a = self.derf_alpha.view(1, -1, 1, 1)
            b = self.derf_bias.view(1, -1, 1, 1)
            g = self.derf_gamma.view(1, -1, 1, 1)
            attn_weights = g * torch.erf(a * safe_scores + b)        # [-gamma, gamma]
            attn_weights = (attn_weights + g) / 2.0                  # shift to [0, gamma]
            attn_weights = attn_weights.masked_fill(is_masked, 0.0)  # hard-zero masked keys
            attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-8)
            if self.dropout_p > 0 and self.training:
                attn_weights = F.dropout(attn_weights, p=self.dropout_p)
            y = torch.matmul(attn_weights, v)  # [B, num_heads, S, head_dim]
        else:
            # Standard softmax attention goes through the fused
            # scaled_dot_product_attention kernel, which already scales by
            # 1/sqrt(head_dim) and avoids materialising [B, H, S, N] scores.
            #
            # With MQA/GQA, k and v were built via expand(...).reshape(...).
            # Forcing contiguity guarantees standard strides into the fused
            # kernel; a broadcast (zero-stride) view reaching the flash-attention
            # backward under torch.compile trips a stride assertion.
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()
            drop = self.dropout_p if self.training else 0.0
            if past_key_value is None and attention_mask is None:
                # Prefill / training: pure causal mask, no materialisation needed.
                y = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop)
            else:
                # Incremental decode or a provided mask: pass an explicit boolean
                # keep-mask (True = attend).
                if attention_mask is not None:
                    is_masked = attention_mask < -1.0
                else:
                    is_masked = self._causal_block_mask(S, N, q.device)
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=~is_masked, dropout_p=drop)

        # XSA: for each query position, remove the component of y that lies along
        # the (normalised) value vector of that same position.
        if self.use_xsa:
            past_len = N - S
            v_self = v[:, :, past_len:past_len + S, :]  # [B, H, S, D]
            vn = v_self / (v_self.norm(dim=-1, keepdim=True) + 1e-8)
            projection = (y * vn).sum(dim=-1, keepdim=True) * vn
            y = y - projection

        # Low-rank output projection (MLA)
        y = y.transpose(1, 2).contiguous().view(B, S, self.num_heads * self.head_dim)
        y = self.o_b_proj(self.o_a_proj(y))
        return y, present
# ---------------------------------------------------------------------------
# MoE FFN: shared expert + sqrtsoftplus + hash routing (NanoWhale) + aux loss (My Project)
# ---------------------------------------------------------------------------
class ExpertFFN(nn.Module):
    """One SwiGLU feed-forward expert: ``down(silu(gate(x)) * up(x))``, no biases.

    Args:
        hidden_size: input and output width.
        intermediate_size: width of the gated inner projection.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert to the last dimension of ``x``."""
        gate_activation = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate_activation * lifted)
def sqrtsoftplus(x: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(softplus(x) + 1e-8)``, the NanoWhale expert scoring function.

    The small constant keeps the square root strictly away from zero: when
    softplus underflows to exactly 0.0, the derivative of sqrt there is
    infinite and the backward pass would produce NaN gradients.
    """
    shifted = F.softplus(x) + 1e-8
    return torch.sqrt(shifted)
class SparseMoEFFN(nn.Module):
    """
    Sparse mixture-of-experts FFN combining NanoWhale's MoE structure with the
    My Project auxiliary loss:
      - n_shared_experts always-active experts (their outputs are averaged)
      - n_routed_experts sparse routed experts, top-k activation
      - sqrtsoftplus scoring (NanoWhale) or softmax, selected by cfg.scoring_func
      - deterministic hash routing for layers with layer_idx < cfg.num_hash_layers
      - norm_topk_prob + routed_scaling_factor on the learned-router path
      - load-balancing aux loss, stored after each forward and read via get_aux_loss()
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int = 0):
        super().__init__()
        self.n_routed_experts = cfg.n_routed_experts
        self.n_shared_experts = cfg.n_shared_experts
        self.num_experts_per_tok = cfg.num_experts_per_tok
        self.norm_topk_prob = cfg.norm_topk_prob
        self.scoring_func = cfg.scoring_func
        self.routed_scaling_factor = cfg.routed_scaling_factor
        self.use_hash_routing = layer_idx < cfg.num_hash_layers
        self.aux_loss_coef = cfg.moe_aux_loss_coef

        self.router = nn.Linear(cfg.hidden_size, cfg.n_routed_experts, bias=False)
        self.experts = nn.ModuleList([
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
            for _ in range(cfg.n_routed_experts)
        ])
        self.shared_experts = nn.ModuleList([
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
            for _ in range(cfg.n_shared_experts)
        ]) if cfg.n_shared_experts > 0 else None
        # Aux loss of the most recent forward; None before any forward and on hash layers.
        self._last_aux_loss: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply shared + routed experts to every token.

        Args:
            x: [B, S, H] hidden states.
            position_ids: optional absolute positions, shape [B, S] or anything
                broadcastable to it (e.g. [1, S]). Only used by hash routing.

        Returns:
            [B, S, H] tensor. As a side effect, updates the value returned by
            ``get_aux_loss()``.
        """
        B, S, H = x.shape
        T = B * S
        # reshape (not view) so a non-contiguous input does not raise.
        x_flat = x.reshape(T, H)

        # Shared experts: always active (NanoWhale); averaged when there are several.
        shared_out = torch.zeros_like(x_flat)
        if self.shared_experts:
            for expert in self.shared_experts:
                shared_out = shared_out + expert(x_flat)
            if len(self.shared_experts) > 1:
                shared_out = shared_out / len(self.shared_experts)

        if self.use_hash_routing:
            # Hash routing: deterministic assignment without the learned router.
            # A token at absolute position p uses experts
            # [p % n, (p + 1) % n, ..., (p + k - 1) % n].
            #
            # The assignment keys off the ABSOLUTE position, not the index inside
            # the flattened batch: with KV-cache decoding (S == 1) the flattened
            # index is always 0, which would route every generated token to the
            # same experts and diverge from full-sequence training.
            if position_ids is not None:
                # expand() accepts [B, S] as-is and broadcasts [1, S] / [S], the
                # same shapes the RoPE gather accepts; a bare reshape(T, 1)
                # raised for broadcast position ids when B > 1.
                base = (position_ids.expand(B, S).reshape(T, 1) % self.n_routed_experts).long()
            else:
                base = (torch.arange(T, device=x.device) % self.n_routed_experts).unsqueeze(1)
            offsets = torch.arange(self.num_experts_per_tok, device=x.device)      # [k]
            top_k_indices = (base + offsets.unsqueeze(0)) % self.n_routed_experts   # [T, k]
            # Uniform weights in the activation dtype (float32 weights would
            # promote the expert output away from the dtype of `out`).
            top_k_weights = torch.ones(
                T, self.num_experts_per_tok, device=x.device, dtype=x.dtype
            ) / self.num_experts_per_tok
            self._last_aux_loss = None
        else:
            router_logits = self.router(x_flat)
            if self.scoring_func == "sqrtsoftplus":
                routing_scores = sqrtsoftplus(router_logits)
            else:
                routing_scores = F.softmax(router_logits, dim=-1)
            top_k_scores, top_k_indices = torch.topk(
                routing_scores, self.num_experts_per_tok, dim=-1)
            if self.norm_topk_prob:
                top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-8)
            else:
                top_k_weights = top_k_scores
            top_k_weights = top_k_weights * self.routed_scaling_factor

            # Load-balancing aux loss (My Project): n * sum_e f_e * p_e, where f_e
            # is the fraction of tokens that selected expert e and p_e the mean
            # softmax router probability of expert e.
            softmax_probs = F.softmax(router_logits, dim=-1)
            expert_mask = torch.zeros_like(softmax_probs)
            expert_mask.scatter_(1, top_k_indices, 1.0)
            f_e = expert_mask.mean(0)
            p_e = softmax_probs.mean(0)
            self._last_aux_loss = self.n_routed_experts * (f_e * p_e).sum() * self.aux_loss_coef

        # Dispatch tokens to routed experts.
        out = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            slot_hits = top_k_indices == expert_idx   # [T, k] slots that picked this expert
            token_mask = slot_hits.any(dim=-1)        # [T]
            if not token_mask.any():
                continue
            expert_output = expert(x_flat[token_mask])
            # Per-token weight of this expert: sum over its slot(s). Every other
            # slot contributes an exact 0, so this equals the single selected
            # weight in the normal case and stays well-defined if an expert ever
            # occupies more than one slot of a token.
            weights = (top_k_weights * slot_hits).sum(dim=-1, keepdim=True)[token_mask]
            # Cast before the masked assignment: index assignment requires the
            # source dtype to match `out`, and the weights may be float32 while
            # the activations are lower precision.
            contribution = (expert_output * weights).to(out.dtype)
            out[token_mask] = out[token_mask] + contribution

        out = out + shared_out
        return out.view(B, S, H)

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        """Return the aux loss of the last forward, or None.

        None is returned for hash-routed layers and before the first forward.
        A fresh ``torch.tensor(0.0)`` is deliberately not used: it would live on
        the CPU and clash with an accelerator-resident running total.
        """
        return self._last_aux_loss
class DenseFFN(nn.Module):
    """Dense SwiGLU feed-forward layer used on non-MoE layers.

    Exposes the same ``forward(x, position_ids)`` / ``get_aux_loss()`` surface
    as SparseMoEFFN so a block can call either one the same way.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        hidden, inner = cfg.hidden_size, cfg.moe_intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``down(silu(gate(x)) * up(x))``; ``position_ids`` is accepted and ignored."""
        gate_activation = F.silu(self.gate_proj(x))
        lifted = self.up_proj(x)
        return self.down_proj(gate_activation * lifted)

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        """Dense layers have no aux loss; None (not a CPU zero tensor) avoids device mismatches."""
        return None
# ---------------------------------------------------------------------------
# Transformer block with Hyper-Connections
# ---------------------------------------------------------------------------
class TransformerBlock(nn.Module):
    """
    One decoder block:
      - optional Hyper-Connections: pre/post routing through hc_mult streams (NanoWhale)
      - MLA + DERF + XSA attention (combined)
      - MoE FFN with shared expert (NanoWhale) + aux loss (My Project), or a
        dense SwiGLU FFN on layers not listed in cfg.moe_layers
    Both sublayers are pre-normalised with RMSNorm and followed by dropout.
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int):
        super().__init__()
        self.use_hc = cfg.use_hyper_connections
        self.hidden_dropout = cfg.hidden_dropout

        self.attn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = MLADerfXSAAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)

        self.is_moe = bool(cfg.use_moe and layer_idx in cfg.moe_layers)
        self.ffn = SparseMoEFFN(cfg, layer_idx) if self.is_moe else DenseFFN(cfg)

        if self.use_hc:
            hc_args = (cfg.hidden_size, cfg.hc_mult, cfg.hc_sinkhorn_iters, cfg.hc_eps)
            self.hc_attn = HyperConnectionLayer(*hc_args)
            self.hc_ffn = HyperConnectionLayer(*hc_args)

    def forward(
        self,
        x: torch.Tensor,  # [B, hc_mult, S, H] if HC else [B, S, H]
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[torch.Tensor]]:
        """Run attention then FFN.

        Returns ``(x, present, aux_loss)``: the updated state (same layout as
        the input), the attention KV cache entry (or None), and the FFN's aux
        loss (or None).
        """
        # --- Attention sub-layer ---
        attn_in = self.hc_attn.pre_op(x) if self.use_hc else x  # [B, S, H]
        attn_out, present = self.attn(
            self.attn_norm(attn_in), position_ids, attention_mask, past_key_value, use_cache
        )
        attn_out = F.dropout(attn_out, p=self.hidden_dropout, training=self.training)

        if self.use_hc:
            x = self.hc_attn.post_op(x, attn_out)
            ffn_in = self.hc_ffn.pre_op(x)  # [B, S, H]
        else:
            ffn_in = attn_in + attn_out

        # --- FFN sub-layer ---
        ffn_out = self.ffn(self.ffn_norm(ffn_in), position_ids)
        ffn_out = F.dropout(ffn_out, p=self.hidden_dropout, training=self.training)

        if self.use_hc:
            x = self.hc_ffn.post_op(x, ffn_out)
        else:
            x = ffn_in + ffn_out

        return x, present, self.ffn.get_aux_loss()
# ---------------------------------------------------------------------------
# Full model
# ---------------------------------------------------------------------------
class HRMRefinementBlock(nn.Module):
    """
    HRM-INSPIRED iterative refinement (EXPERIMENTAL, off by default). NOT the full
    Hierarchical Reasoning Model -- only an iterative-refinement mechanism applied
    to a hidden state.

    Runs ``steps`` inner steps; each computes a small gated update conditioned on
    the current state AND the original ('anchor') input:

        h <- h + tanh(gate[t]) * up(silu(down([norm(h), anchor])))

    ``up.weight`` is zero-initialised, so every update is exactly zero and the
    block is an EXACT identity at init. Every operation acts on the last
    dimension only, so positions never mix (causal-safe). In/out [B, S, H].

    BUGFIX: the per-step gate used to be zero-initialised as well. With both
    factors of the product at zero, the gradient w.r.t. ``gate`` (proportional
    to the zero update) and the gradient w.r.t. ``up.weight`` (proportional to
    tanh(gate) = 0) are both exactly zero, and nothing reaches ``down`` or
    ``norm`` either -- the block could never leave identity. The gate now
    starts at a small non-zero value; the zero ``up.weight`` alone preserves
    the identity-at-init property while letting ``up`` receive gradient.

    Args:
        hidden_size: width H of the hidden state.
        refine_dim: width of the bottleneck between ``down`` and ``up``.
        steps: number of refinement iterations (one gate scalar per step).
        eps: epsilon of the internal RMSNorm.
        gate_init: initial value of every per-step gate (pre-tanh). Must be
            non-zero for the block to be trainable from its default init.
    """

    def __init__(self, hidden_size: int, refine_dim: int, steps: int, eps: float = 1e-6,
                 gate_init: float = 0.1):
        super().__init__()
        self.steps = steps
        self.norm = RMSNorm(hidden_size, eps)
        self.down = nn.Linear(hidden_size * 2, refine_dim, bias=False)
        self.up = nn.Linear(refine_dim, hidden_size, bias=False)
        self.gate = nn.Parameter(torch.full((steps,), float(gate_init)))
        nn.init.normal_(self.down.weight, std=0.02)
        # Zero output projection => zero update => exact identity at init.
        nn.init.zeros_(self.up.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Refine ``x`` for ``self.steps`` iterations and return the result (same shape)."""
        anchor = x
        h = x
        for t in range(self.steps):
            # Condition on the normalised current state and the raw original input.
            inp = torch.cat([self.norm(h), anchor], dim=-1)
            update = self.up(F.silu(self.down(inp)))
            h = h + torch.tanh(self.gate[t]) * update
        return h
class LatentProjection(nn.Module):
    """ModularMind-on-V2: pool the final hidden state into a d_latent output vector.

    Contract (mirrors ModularMind): mean-pool over the sequence dimension, apply
    a ReLU^2 activation (sparse latent codes), project to d_latent and RMS-normalise.
    Both projections use Xavier init (NOT zero) so the latent carries signal
    from step 1 -- a zero init would leave the chain unable to bootstrap.
    """

    def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6):
        super().__init__()
        self.proj1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj2 = nn.Linear(hidden_size, d_latent, bias=False)
        self.norm = RMSNorm(d_latent, eps)
        for layer in (self.proj1, self.proj2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, H] -> latent: [B, d_latent]. The mean runs over all S positions."""
        sequence_mean = x.mean(dim=1)                      # [B, H]
        activated = torch.relu(self.proj1(sequence_mean)).pow(2)
        latent = self.proj2(activated)
        return self.norm(latent)                           # [B, d_latent]
class LatentInjection(nn.Module):
    """ModularMind-on-V2: fold an incoming d_latent vector into the embeddings.

    The latent is lifted to hidden size, RMS-normalised, broadcast over all
    positions and added through a ReGLU gate: ``x + value(inj) * relu(gate(inj))``.

    The gate projection is initialised SMALL (std = gate_init), not exactly zero:
    the injection starts near-identity (stable) while still passing a little
    gradient, so the upstream RecursiveLink + specialist can bootstrap from step 1.
    An exact-zero gate would block all gradient to the link.
    This is the INPUT side of RecursiveLink (the previous specialist's latent).

    Args:
        hidden_size: width H of the embeddings.
        d_latent: width of the incoming latent vector.
        eps: epsilon of the internal RMSNorm.
        gate_init: std of the gate projection's initial weights (also stored as
            ``self.gate_init``).
    """

    def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6,
                 gate_init: float = 1e-3):
        super().__init__()
        self.up = nn.Linear(d_latent, hidden_size, bias=False)
        self.norm = RMSNorm(hidden_size, eps)
        self.value_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gate_init = gate_init

        for layer in (self.up, self.value_proj):
            nn.init.xavier_uniform_(layer.weight)
        nn.init.normal_(self.gate_proj.weight, std=gate_init)  # small, not zero

    def forward(self, x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        """x: [B, S, H], latent: [B, d_latent] -> [B, S, H]."""
        lifted = self.norm(self.up(latent))
        injected = lifted.unsqueeze(1)              # [B, 1, H], broadcast over S
        gate = torch.relu(self.gate_proj(injected))
        value = self.value_proj(injected)
        return x + value * gate
class RecursiveLink(nn.Module):
    """ModularMind cross-specialist bridge, V2 build.

    Converts one specialist's output latent into the next specialist's input
    latent: LayerNorm -> ReGLU (value * relu(gate)) -> down projection, added
    back to the input through a learnable scalar ``residual_gate`` (init 1).
    A single shared module is meant to be reused for every hop. Fully differentiable.

    Args:
        d_latent: width of the latent vector.
        expansion: inner width multiplier; the hidden width is int(d_latent * expansion).
    """

    def __init__(self, d_latent: int = 256, expansion: float = 2.0):
        super().__init__()
        inner_width = int(d_latent * expansion)
        self.norm = nn.LayerNorm(d_latent)
        self.value_proj = nn.Linear(d_latent, inner_width, bias=False)
        self.gate_proj = nn.Linear(d_latent, inner_width, bias=False)
        self.down = nn.Linear(inner_width, d_latent, bias=False)
        self.residual_gate = nn.Parameter(torch.ones(1))
        for layer in (self.value_proj, self.gate_proj, self.down):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent ``z`` (last dim = d_latent) to a latent of the same shape."""
        normed = self.norm(z)
        value = self.value_proj(normed)
        gate = torch.relu(self.gate_proj(normed))
        delta = self.down(value * gate)
        return z + self.residual_gate * delta
class SpikeWhaleModel(nn.Module):
    """Decoder stack without LM head.

    Pipeline: token embedding -> optional Engram delta -> optional latent
    injection -> optional expansion to ``hc_mult`` hyper-connection streams ->
    transformer blocks -> mean over streams -> optional HRM refinement ->
    final RMSNorm -> optional output-latent projection.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        """Build the stack from ``cfg``.

        Optional parts are driven by config flags: ``use_engram`` is read
        directly, while ``use_hrm_refine`` and ``use_latent_io`` are read with
        ``getattr(..., False)`` so configs lacking them keep the parts off.
        """
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        nn.init.normal_(self.embed_tokens.weight, std=cfg.initializer_range)
        self.engram = EngramModule(cfg) if cfg.use_engram else None
        self.layers = nn.ModuleList([
            TransformerBlock(cfg, layer_idx=i)
            for i in range(cfg.num_hidden_layers)
        ])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.hrm_refine = (
            HRMRefinementBlock(cfg.hidden_size, cfg.hrm_refine_dim,
                               cfg.hrm_refine_steps, cfg.rms_norm_eps)
            if getattr(cfg, "use_hrm_refine", False) else None
        )
        # ModularMind-on-V2: latent input/output (off unless use_latent_io)
        if getattr(cfg, "use_latent_io", False):
            self.latent_inject = LatentInjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps)
            self.latent_out = LatentProjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps)
        else:
            self.latent_inject = None
            self.latent_out = None
        self.gradient_checkpointing = False

    def reset_latent_gate(self):
        """Re-init the injection gate SMALL (not zero). Must be called AFTER any HF
        post_init/_init_weights pass, which otherwise re-randomizes the gate to full
        scale. Small-but-nonzero keeps injection near-identity at start while letting
        gradient reach the upstream RecursiveLink (so the chain can bootstrap).

        Does nothing when latent I/O is disabled.
        """
        if self.latent_inject is not None:
            nn.init.normal_(self.latent_inject.gate_proj.weight,
                            std=self.latent_inject.gate_init)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        use_cache: bool = False,
        inject_latent: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple]], torch.Tensor, Optional[torch.Tensor]]:
        """Run the decoder stack.

        Args:
            input_ids: Token ids of shape [B, S].
            attention_mask: Forwarded unchanged to every layer.
            position_ids: Positions per token. When omitted they are built as
                ``arange(past_len, past_len + S)``, with ``past_len`` read from
                dim 2 of the first cached tensor of the first layer.
            past_key_values: Per-layer cache from a previous call. It is not
                passed to the layers on the gradient-checkpointing path.
            use_cache: Collect and return each layer's cache. Forced off while
                gradient checkpointing is active in training.
            inject_latent: Previous specialist's latent. Ignored when latent
                I/O is disabled in the config.

        Returns:
            ``(hidden, present_key_values, total_aux_loss, out_latent)``:
            the normalised hidden states, the list of per-layer caches (or
            ``None`` when caching is off), the summed auxiliary loss of all
            layers, and the output latent (``None`` when latent I/O is off).
        """
        B, S = input_ids.shape
        device = input_ids.device
        if position_ids is None:
            past_len = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(
                past_len, past_len + S, device=device
            ).unsqueeze(0).expand(B, -1)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing:
            # Checkpointed layers are called with pkv=None / use_cache=False,
            # so they produce no cache. Returning a list of their `present`
            # values would hand callers something that looks like a cache
            # but is not one; report "no cache" instead.
            use_cache = False

        # Token embedding
        x = self.embed_tokens(input_ids)  # [B, S, H]

        # Engram N-gram delta (My Project)
        if self.engram is not None:
            x = x + self.engram(x)

        # ModularMind-on-V2: inject the previous specialist's latent (broadcast
        # across positions, ReGLU-gated). The gate is initialised small but
        # nonzero (see reset_latent_gate), so this is near-identity at init
        # rather than an exact no-op. Skipped entirely if no latent is passed.
        if self.latent_inject is not None and inject_latent is not None:
            x = self.latent_inject(x, inject_latent)

        # Expand to hc_mult streams for Hyper-Connections (NanoWhale)
        if self.cfg.use_hyper_connections:
            # [B, hc_mult, S, H]; clone() materialises the expanded view.
            x = x.unsqueeze(1).expand(-1, self.cfg.hc_mult, -1, -1).clone()

        present_key_values = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=device)
        for layer_idx, layer in enumerate(self.layers):
            pkv = past_key_values[layer_idx] if past_key_values else None
            if checkpointing:
                # Gradient checkpointing with use_reentrant=False (NanoWhale)
                x, present, aux_loss = gradient_checkpoint(
                    layer, x, position_ids, attention_mask, None, False,
                    use_reentrant=False,
                )
            else:
                x, present, aux_loss = layer(x, position_ids, attention_mask, pkv, use_cache)
            if use_cache:
                present_key_values.append(present)
            if aux_loss is not None:
                total_aux_loss = total_aux_loss + aux_loss

        # Reduce HC streams to single hidden state
        if self.cfg.use_hyper_connections:
            x = x.mean(dim=1)  # [B, S, H]

        if self.hrm_refine is not None:
            x = self.hrm_refine(x)
        x = self.norm(x)

        # ModularMind-on-V2: emit this specialist's output latent (for RecursiveLink).
        out_latent = self.latent_out(x) if self.latent_out is not None else None
        return x, present_key_values, total_aux_loss, out_latent
class SpikeWhaleLM(PreTrainedModel):
    """
    Causal language model: SpikeWhaleModel decoder + LM head (+ optional
    Multi-Token Prediction heads).

    Training (forward with labels):
        out = model(input_ids=ids, labels=ids)
        loss = out.loss  # CE + MTP loss + MoE aux loss

    Generation:
        out = model(input_ids=ids, use_cache=True)
        past = out.past_key_values
        out2 = model(input_ids=next_id, past_key_values=past, use_cache=True)

    The returned output object additionally carries ``.latent``, the decoder's
    output latent (``None`` when latent I/O is disabled).
    """

    config_class = SpikeWhaleConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, cfg: SpikeWhaleConfig):
        """Build decoder, LM head and (if configured) the MTP heads."""
        super().__init__(cfg)
        self.model = SpikeWhaleModel(cfg)

        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=cfg.initializer_range)
        if cfg.tie_word_embeddings:
            # Share one weight tensor between input embedding and LM head.
            self.lm_head.weight = self.model.embed_tokens.weight

        # Multi-Token Prediction heads (NanoWhale): one extra vocabulary
        # projection per additional look-ahead step.
        n_mtp = cfg.num_nextn_predict_layers
        if n_mtp > 0:
            self.mtp_heads = nn.ModuleList(
                [nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
                 for _ in range(n_mtp)]
            )
        else:
            self.mtp_heads = None

        self.post_init()
        # Per the original author's note, HF post_init re-randomizes Linear
        # weights, which would overwrite the small-scale injection gate.
        # Re-apply the small (nonzero) gate init afterwards.
        self.model.reset_latent_gate()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle the checkpointing flag on SpikeWhaleModel instances only."""
        if not isinstance(module, SpikeWhaleModel):
            return
        module.gradient_checkpointing = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        inject_latent: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute logits and, when ``labels`` is given, the training loss.

        The loss is next-token cross-entropy, plus the MTP cross-entropy sum
        divided by the number of MTP heads (when heads exist), plus the
        decoder's auxiliary loss. Label value -100 is ignored everywhere.
        Extra ``**kwargs`` are accepted and not used.
        """
        hidden, present_kvs, aux_loss, out_latent = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inject_latent=inject_latent,
        )
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Standard next-token objective: position t is scored against
            # label t+1.
            pred = logits[..., :-1, :].contiguous()
            target = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                pred.view(-1, pred.size(-1)),
                target.view(-1),
                ignore_index=-100,
            )

            # Multi-Token Prediction (NanoWhale): the first head looks 2
            # tokens ahead, the second 3, and so on.
            if self.mtp_heads is not None:
                mtp_total = torch.tensor(0.0, device=loss.device)
                seq_len = hidden.size(1)
                for offset, head in enumerate(self.mtp_heads, start=2):
                    if seq_len <= offset:
                        # Sequence too short for this look-ahead distance.
                        continue
                    mtp_logits = head(hidden[..., :-offset, :].contiguous())
                    mtp_labels = labels[..., offset:].contiguous()
                    mtp_total = mtp_total + F.cross_entropy(
                        mtp_logits.view(-1, mtp_logits.size(-1)),
                        mtp_labels.view(-1),
                        ignore_index=-100,
                    )
                loss = loss + mtp_total / max(len(self.mtp_heads), 1)

            # MoE load-balancing aux loss (My Project)
            loss = loss + aux_loss

        out = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=present_kvs,
        )
        # ModularMind-on-V2: this specialist's output latent
        out.latent = out_latent
        return out

    def count_parameters(self) -> int:
        """Return the total number of elements across all parameters."""
        total = 0
        for param in self.parameters():
            total += param.numel()
        return total