""" model.py -- SpikeWhaleLM: combined architecture from SpikeTransformer (My Project) + NanoWhale. Architecture flow: Embedding -> Engram delta (N-gram memory, My Project) -> [expand to hc_mult copies if HC enabled] -> N x TransformerBlock: HC pre-op (NanoWhale) -> RMSNorm -> MLA+DERF+XSA Attention (combined) -> HC post-op HC pre-op -> RMSNorm -> MoE FFN w/ shared expert (NanoWhale) -> HC post-op -> [mean-pool hc_mult copies if HC enabled] -> RMSNorm -> LM head + MTP heads (NanoWhale) Component origins: RMSNorm, RotaryEmbedding -- both (standard) Engram / DERFContextGate -- My Project MLADerfXSAAttention -- MLA from NanoWhale + DERF+XSA from My Project SparseMoEFFN w/ shared expert -- NanoWhale MoE structure + My Project aux loss HyperConnectionLayer -- NanoWhale SpikeWhaleLM + MTP heads -- NanoWhale """ import math import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, List from transformers import PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast from torch.utils.checkpoint import checkpoint as gradient_checkpoint from config import SpikeWhaleConfig # --------------------------------------------------------------------------- # Primitives # --------------------------------------------------------------------------- class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight class RotaryEmbedding(nn.Module): """RoPE for the rope partition of Q and K (qk_rope_head_dim dims only).""" def __init__(self, dim: int, max_positions: int = 4096, theta: float = 10000.0): super().__init__() inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(max_positions).float() freqs = torch.outer(t, inv_freq) self.register_buffer("cos_cache", freqs.cos()) self.register_buffer("sin_cache", freqs.sin()) def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor: """ x: [B, H, S, rope_dim] position_ids: [B, S] """ cos = self.cos_cache[position_ids].unsqueeze(1) # [B, 1, S, rope_dim//2] sin = self.sin_cache[position_ids].unsqueeze(1) d = cos.shape[-1] x1, x2 = x[..., :d], x[..., d:] return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1) # --------------------------------------------------------------------------- # Engram: N-gram hash lookup + DERF gate (My Project, preserved) # --------------------------------------------------------------------------- class TokenCompressor(nn.Module): def __init__(self, embed_dim: int, compress_dim: int): super().__init__() self.proj = nn.Linear(embed_dim, compress_dim, bias=False) nn.init.normal_(self.proj.weight, std=0.02) # BUGFIX: this projection feeds ONLY the integer hash index # (idx = h.abs().long() % table_size) in MultiHeadHashLookup. The .long() # cast is non-differentiable, so no gradient ever reaches this weight -- # it can never learn. Worse, _classify_params put it in the weight-decay # group, so AdamW was steadily shrinking it toward zero and degrading the # hash projection over a long run. Freeze it: a fixed random projection is # exactly the right behavior for an LSH-style hash, and freezing drops it # from the optimizer (saves state) and from weight decay. Checkpoint-safe: # the parameter still exists and is still saved/loaded in state_dict. self.proj.weight.requires_grad_(False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) class MultiHeadHashLookup(nn.Module): def __init__(self, num_heads: int, table_size: int, compress_dim: int, out_dim: int, max_ngram: int = 3): super().__init__() self.num_heads = num_heads self.table_size = table_size self.max_ngram = max_ngram self.out_dim = out_dim self.tables = nn.ModuleList([ nn.Embedding(table_size, out_dim) for _ in range(num_heads) ]) for t in self.tables: nn.init.normal_(t.weight, std=0.01) for n in range(1, max_ngram + 1): for k in range(n): proj = torch.randn(num_heads, compress_dim) proj = proj / (proj.norm(dim=1, keepdim=True) + 1e-8) self.register_buffer(f"hash_proj_n{n}_p{k}", proj) def forward(self, compressed: torch.Tensor) -> torch.Tensor: """ compressed: [B, S, compress_dim] returns: [B, S, out_dim] All positions are processed in parallel. The outer loop runs max_ngram times (≤3), not S times (≤2048). Each iteration is a single matmul + embedding lookup across the whole sequence, making this GPU-friendly and compatible with torch.compile. """ B, S, _ = compressed.shape device = compressed.device out = torch.zeros(B, S, self.out_dim, device=device, dtype=compressed.dtype) # Per-position normalization: tracks how many (n-gram × head) contributions # each position receives. Positions near the start get fewer contributions # because shorter n-grams don't exist yet (matches original causal behavior). norm = torch.zeros(S, device=device) for n in range(1, self.max_ngram + 1): if S < n: continue valid_len = S - n + 1 # positions [n-1 .. S-1] are valid for order-n start = n - 1 # Accumulate position-k contribution to the order-n hash. # compressed[:, k : k+valid_len, :] is the k-th token of every n-gram # window simultaneously → [B, valid_len, num_heads] after projection. h = torch.zeros(B, valid_len, self.num_heads, device=device) for k in range(n): proj = getattr(self, f"hash_proj_n{n}_p{k}") # [num_heads, compress_dim] h = h + torch.matmul(compressed[:, k:k + valid_len, :].float(), proj.t()) idx = h.abs().long() % self.table_size # [B, valid_len, num_heads] for head_idx, table in enumerate(self.tables): out[:, start:, :] = out[:, start:, :] + table(idx[:, :, head_idx]) norm[start:] += self.num_heads # Cast back to input dtype: the norm division promotes bf16→float32 under autocast. # Keeping the output in the same dtype as the input avoids a silent dtype mismatch # when EngramModule adds this result back onto the (bf16) embedding tensor. return (out / norm.view(1, -1, 1).clamp(min=1)).to(compressed.dtype) class DERFContextGate(nn.Module): """ DERF gate: gate = gamma * erf(alpha * proj([retrieved, x]) + bias) Positive probability = (gate + 1) / 2 applied to retrieved embedding. Large negative init_bias keeps gate closed at start of training. """ def __init__(self, obs_size: int, init_bias: float = -4.0): super().__init__() self.proj = nn.Linear(obs_size * 2, obs_size) self.alpha = nn.Parameter(torch.ones(obs_size)) self.bias = nn.Parameter(torch.full((obs_size,), init_bias)) self.gamma = nn.Parameter(torch.ones(obs_size)) def forward(self, retrieved: torch.Tensor, x: torch.Tensor) -> torch.Tensor: logits = self.proj(torch.cat([retrieved, x], dim=-1)) gate = self.gamma * ((torch.erf(self.alpha * logits + self.bias) + 1.0) / 2.0) return retrieved * gate class EngramModule(nn.Module): """ N-gram hash lookup with DERF gate (My Project), fully vectorized. All S positions are processed in parallel — the sequential Python loop over sequence positions has been eliminated. The lookup now accepts the full [B, S, compress_dim] compressed tensor and returns [B, S, H] in one pass. """ def __init__(self, cfg: SpikeWhaleConfig): super().__init__() self.compressor = TokenCompressor(cfg.hidden_size, cfg.engram_compress_dim) self.lookup = MultiHeadHashLookup( cfg.engram_num_heads, cfg.engram_table_size, cfg.engram_compress_dim, cfg.hidden_size, cfg.engram_max_ngram, ) self.gate = DERFContextGate(cfg.hidden_size, cfg.engram_gate_init_bias) def forward(self, x: torch.Tensor) -> torch.Tensor: """x: [B, S, H] -> engram_delta: [B, S, H]""" compressed = self.compressor(x.detach()) # [B, S, compress_dim] retrieved = self.lookup(compressed) # [B, S, H] return self.gate(retrieved, x) # [B, S, H] # --------------------------------------------------------------------------- # Hyper-Connections (NanoWhale, simplified) # --------------------------------------------------------------------------- class HyperConnectionLayer(nn.Module): """ Simplified Hyper-Connections for one sublayer (attention or FFN). Maintains hc_mult parallel residual streams. Pre-op: learned weighted average of hc_mult copies -> single hidden state for sublayer. Post-op: sublayer output added to each copy with learned per-stream weights. Full HC uses Sinkhorn-normalized 2D routing matrices; this uses softmax-normalized 1D weights for pre/post routing -- captures the same multi-stream routing spirit. """ def __init__(self, hidden_size: int, hc_mult: int, sinkhorn_iters: int = 20, eps: float = 1e-6): super().__init__() self.hc_mult = hc_mult # pre_weight: how to mix hc_mult copies into one sublayer input # post_weight: how to distribute the sublayer delta to each copy # # BUGFIX: these must NOT be initialized identically across streams. # The model expands the hidden state into hc_mult *identical* copies. # With uniform pre/post weights, pre_op produces sum_i copy_i * w_i = # copy * sum(softmax)=copy (all copies equal), and post_op adds the same # delta to every copy -- so the streams stay byte-for-byte identical at # every layer. When all streams are equal, the softmax Jacobian applied # to the (equal) per-stream gradients is exactly zero, so pre_weight and # post_weight receive ZERO gradient and never move off 1/hc_mult. The HC # routing then learns nothing and just burns hc_mult x memory/compute. # # Breaking the post_weight symmetry at init makes the streams diverge # after the first sublayer, which restores gradient flow to all HC # weights. We center post_weight so softmax starts near-uniform (keeps # the residual baseline ~unchanged) but with a distinct value per stream. self.pre_weight = nn.Parameter( torch.linspace(0.5, -0.5, hc_mult) / max(hc_mult, 1) ) self.post_weight = nn.Parameter( torch.linspace(-0.5, 0.5, hc_mult) / max(hc_mult, 1) ) def pre_op(self, copies: torch.Tensor) -> torch.Tensor: """copies: [B, hc_mult, S, H] -> [B, S, H]""" w = F.softmax(self.pre_weight, dim=0) # [hc_mult] return (copies * w.view(1, -1, 1, 1)).sum(dim=1) def post_op(self, copies: torch.Tensor, delta: torch.Tensor) -> torch.Tensor: """ copies: [B, hc_mult, S, H] delta: [B, S, H] Returns updated copies: [B, hc_mult, S, H] """ w = F.softmax(self.post_weight, dim=0) # [hc_mult] return copies + delta.unsqueeze(1) * w.view(1, -1, 1, 1) # --------------------------------------------------------------------------- # MLA + DERF + XSA Attention (combined) # --------------------------------------------------------------------------- class MLADerfXSAAttention(nn.Module): """ Multi-Head Latent Attention (NanoWhale) with DERF scores + XSA correction (My Project). MLA (from NanoWhale): Q: hidden -> q_lora_rank (RMSNorm) -> num_heads * head_dim (low-rank projection) K, V: hidden -> num_kv_heads * head_dim (direct, MQA by default with num_kv_heads=1) Output: num_heads * head_dim -> o_lora_rank -> hidden (low-rank output) Partial RoPE: applied only to the last qk_rope_head_dim dims of Q and K DERF (from My Project): Replaces softmax: erf(alpha * scores + bias) * gamma, shifted to [0,1] then normalized. Per-head learnable alpha, bias, gamma. XSA (from My Project): After computing the weighted value sum y, subtract the component of y that projects onto each position's own value vector. Forces the output to carry only cross-position information, not echo the current token back. """ def __init__(self, cfg: SpikeWhaleConfig): super().__init__() self.num_heads = cfg.num_attention_heads self.num_kv_heads = cfg.num_key_value_heads self.head_dim = cfg.head_dim self.qk_rope_head_dim = cfg.qk_rope_head_dim self.nope_head_dim = cfg.nope_head_dim self.hidden_size = cfg.hidden_size self.use_derf = cfg.use_derf self.use_xsa = cfg.use_xsa self.dropout_p = cfg.attention_dropout self.kv_groups = self.num_heads // self.num_kv_heads # Low-rank Q projection (MLA) self.q_a_proj = nn.Linear(cfg.hidden_size, cfg.q_lora_rank, bias=False) self.q_a_norm = RMSNorm(cfg.q_lora_rank, cfg.rms_norm_eps) self.q_b_proj = nn.Linear(cfg.q_lora_rank, self.num_heads * self.head_dim, bias=False) # Direct K, V projections (MQA/GQA) self.k_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False) # Low-rank output projection (MLA) self.o_a_proj = nn.Linear(self.num_heads * self.head_dim, cfg.o_lora_rank, bias=False) self.o_b_proj = nn.Linear(cfg.o_lora_rank, cfg.hidden_size, bias=False) # Partial RoPE: applied to qk_rope_head_dim dims only self.rope = RotaryEmbedding( self.qk_rope_head_dim, max_positions=cfg.max_position_embeddings, theta=cfg.rope_theta, ) # DERF parameters: one per query head (My Project) if self.use_derf: self.derf_alpha = nn.Parameter(torch.ones(self.num_heads)) self.derf_bias = nn.Parameter(torch.zeros(self.num_heads)) self.derf_gamma = nn.Parameter(torch.ones(self.num_heads)) nn.init.normal_(self.q_a_proj.weight, std=cfg.initializer_range) nn.init.normal_(self.q_b_proj.weight, std=cfg.initializer_range) nn.init.normal_(self.k_proj.weight, std=cfg.initializer_range) nn.init.normal_(self.v_proj.weight, std=cfg.initializer_range) nn.init.normal_(self.o_a_proj.weight, std=cfg.initializer_range) nn.init.normal_(self.o_b_proj.weight, std=cfg.initializer_range) def forward( self, x: torch.Tensor, position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: B, S, _ = x.shape # Q via low-rank projection with intermediate norm (MLA) q = self.q_a_norm(self.q_a_proj(x)) q = self.q_b_proj(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) # [B, num_heads, S, head_dim] # K, V direct projections k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) # Partial RoPE: split into nope and rope partitions, rotate only the rope part q_nope = q[..., :self.nope_head_dim] q_rope = q[..., self.nope_head_dim:] # qk_rope_head_dim dims k_nope = k[..., :self.nope_head_dim] k_rope = k[..., self.nope_head_dim:] q_rope = self.rope(q_rope, position_ids) k_rope = self.rope(k_rope, position_ids) q = torch.cat([q_nope, q_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1) # KV cache for inference if past_key_value is not None: k = torch.cat([past_key_value[0], k], dim=2) v = torch.cat([past_key_value[1], v], dim=2) present = (k, v) if use_cache else None N = k.shape[2] # total key positions (past + current) # Expand KV heads for MQA/GQA if self.kv_groups > 1: k = k.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape( B, self.num_heads, N, self.head_dim) v = v.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape( B, self.num_heads, N, self.head_dim) # Scaled dot-product attention. if self.use_derf: # DERF replaces softmax with a custom erf nonlinearity, so it cannot # use the fused kernel and must materialize scores explicitly. scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim) # Build boolean mask for causality (this avoids the -inf math errors) if attention_mask is None and past_key_value is None: is_masked = torch.triu(torch.ones(S, N, dtype=torch.bool, device=scores.device), diagonal=N - S + 1).unsqueeze(0).unsqueeze(0) else: is_masked = (attention_mask < -1.0) if attention_mask is not None else torch.zeros_like(scores, dtype=torch.bool) # FIX 2: Do NOT use float('-inf'). If alpha ever hits 0.0, 0.0 * -inf = NaN. # Use a safe negative scalar (-10000.0) for masked positions. safe_scores = scores.masked_fill(is_masked, -10000.0) a = self.derf_alpha.view(1, -1, 1, 1) b = self.derf_bias.view(1, -1, 1, 1) g = self.derf_gamma.view(1, -1, 1, 1) attn_weights = g * torch.erf(a * safe_scores + b) # [-gamma, gamma] attn_weights = (attn_weights + g) / 2.0 # shift to [0, gamma] attn_weights = attn_weights.masked_fill(is_masked, 0.0) # enforce causal mask safely attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-8) if self.dropout_p > 0 and self.training: attn_weights = F.dropout(attn_weights, p=self.dropout_p) y = torch.matmul(attn_weights, v) # [B, num_heads, S, head_dim] else: # OPTIMIZATION: standard (softmax) attention goes through the fused # scaled_dot_product_attention kernel (FlashAttention / mem-efficient # backends). This is the hot path during pretraining (use_derf=False) # and is much faster + lower memory than materializing [B,H,S,N] # scores and a softmax. SDPA already scales by 1/sqrt(head_dim). # # CONTIGUITY FIX: with MQA/GQA, k and v above are built via # .unsqueeze(2).expand(...).reshape(...). Under torch.compile, inductor # can trace the broadcasted (zero-stride) view through to the fused # flash-attention BACKWARD kernel, whose meta-kernel then asserts on the # mismatched stride (e.g. "stride 120==245760 at dim=1") and aborts. # Forcing contiguity guarantees standard strides into the fused kernel. q = q.contiguous() k = k.contiguous() v = v.contiguous() drop = self.dropout_p if self.training else 0.0 if past_key_value is None and attention_mask is None: # Prefill / training: pure causal mask, no materialization needed. y = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop) else: # Incremental decode or a provided mask: pass an explicit boolean # keep-mask (True = attend). SDPA fills masked positions with -inf. if attention_mask is not None: is_masked = (attention_mask < -1.0) else: is_masked = torch.triu( torch.ones(S, N, dtype=torch.bool, device=q.device), diagonal=N - S + 1, ).unsqueeze(0).unsqueeze(0) y = F.scaled_dot_product_attention( q, k, v, attn_mask=~is_masked, dropout_p=drop) # XSA: remove self-projection from output (My Project) # For each query position s, subtract the component of y[:,:,s,:] that # projects onto the normalized value vector at the same position. if self.use_xsa: past_len = N - S v_self = v[:, :, past_len:past_len + S, :] # [B, H, S, D] vn = v_self / (v_self.norm(dim=-1, keepdim=True) + 1e-8) projection = (y * vn).sum(dim=-1, keepdim=True) * vn y = y - projection # Low-rank output projection (MLA) y = y.transpose(1, 2).contiguous().view(B, S, self.num_heads * self.head_dim) y = self.o_b_proj(self.o_a_proj(y)) return y, present # --------------------------------------------------------------------------- # MoE FFN: shared expert + sqrtsoftplus + hash routing (NanoWhale) + aux loss (My Project) # --------------------------------------------------------------------------- class ExpertFFN(nn.Module): """Single SwiGLU expert.""" def __init__(self, hidden_size: int, intermediate_size: int): super().__init__() self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) def sqrtsoftplus(x: torch.Tensor) -> torch.Tensor: """sqrt(softplus(x)) = sqrt(log(1+exp(x))). NanoWhale expert scoring.""" # FIX 1: Added 1e-8. If F.softplus(x) evaluates to 0.0, torch.sqrt(0) produces NaN gradients on backward pass. return torch.sqrt(F.softplus(x) + 1e-8) class SparseMoEFFN(nn.Module): """ Combines NanoWhale MoE structure with My Project aux loss: - n_shared_experts always-active experts (NanoWhale) - n_routed_experts sparse routed experts, top-k activation - sqrtsoftplus scoring (NanoWhale) vs softmax - hash routing for early layers (NanoWhale) - norm_topk_prob + routed_scaling_factor (NanoWhale) - load-balancing aux loss (My Project) """ def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int = 0): super().__init__() self.n_routed_experts = cfg.n_routed_experts self.n_shared_experts = cfg.n_shared_experts self.num_experts_per_tok = cfg.num_experts_per_tok self.norm_topk_prob = cfg.norm_topk_prob self.scoring_func = cfg.scoring_func self.routed_scaling_factor = cfg.routed_scaling_factor self.use_hash_routing = layer_idx < cfg.num_hash_layers self.aux_loss_coef = cfg.moe_aux_loss_coef self.router = nn.Linear(cfg.hidden_size, cfg.n_routed_experts, bias=False) self.experts = nn.ModuleList([ ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size) for _ in range(cfg.n_routed_experts) ]) self.shared_experts = nn.ModuleList([ ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size) for _ in range(cfg.n_shared_experts) ]) if cfg.n_shared_experts > 0 else None self._last_aux_loss: Optional[torch.Tensor] = None def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: B, S, H = x.shape x_flat = x.view(B * S, H) T = B * S # Shared experts: always active (NanoWhale) shared_out = torch.zeros_like(x_flat) if self.shared_experts: for expert in self.shared_experts: shared_out = shared_out + expert(x_flat) if len(self.shared_experts) > 1: shared_out = shared_out / len(self.shared_experts) # Router if self.use_hash_routing: # Hash routing: deterministic assignment without learned router (NanoWhale). # Assign each of the num_experts_per_tok slots a DISTINCT expert by cycling: # token at absolute position p -> experts [p%n, (p+1)%n, ..., (p+k-1)%n]. # # BUGFIX: the assignment must key off the token's ABSOLUTE sequence # position, not torch.arange(T) (its index in the current flattened # batch). With arange(T), incremental KV-cache decoding (S=1) always # sees index 0 and routes every token to expert 0, so generation used # a different expert assignment than training and silently diverged. # Using position_ids makes prefill, full-sequence training, and # step-by-step generation all agree. (For S divisible by n_experts, # this matches the previous training-time behavior exactly, so existing # checkpoints stay valid.) if position_ids is not None: base = (position_ids.reshape(T, 1) % self.n_routed_experts).long() else: base = (torch.arange(T, device=x.device) % self.n_routed_experts).unsqueeze(1) offsets = torch.arange(self.num_experts_per_tok, device=x.device) # [k] top_k_indices = (base + offsets.unsqueeze(0)) % self.n_routed_experts # [T, k] top_k_weights = torch.ones(T, self.num_experts_per_tok, device=x.device) / self.num_experts_per_tok self._last_aux_loss = None else: router_logits = self.router(x_flat) if self.scoring_func == "sqrtsoftplus": routing_scores = sqrtsoftplus(router_logits) else: routing_scores = F.softmax(router_logits, dim=-1) top_k_scores, top_k_indices = torch.topk(routing_scores, self.num_experts_per_tok, dim=-1) if self.norm_topk_prob: top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-8) else: top_k_weights = top_k_scores top_k_weights = top_k_weights * self.routed_scaling_factor # Load-balancing aux loss (My Project) softmax_probs = F.softmax(router_logits, dim=-1) expert_mask = torch.zeros_like(softmax_probs) expert_mask.scatter_(1, top_k_indices, 1.0) f_e = expert_mask.mean(0) p_e = softmax_probs.mean(0) self._last_aux_loss = self.n_routed_experts * (f_e * p_e).sum() * self.aux_loss_coef # Dispatch tokens to routed experts out = torch.zeros_like(x_flat) for expert_idx, expert in enumerate(self.experts): token_mask = (top_k_indices == expert_idx).any(dim=-1) if not token_mask.any(): continue expert_input = x_flat[token_mask] expert_output = expert(expert_input) k_pos = (top_k_indices[token_mask] == expert_idx).nonzero(as_tuple=False) weights = top_k_weights[token_mask][k_pos[:, 0], k_pos[:, 1]].unsqueeze(-1) out[token_mask] = out[token_mask] + expert_output * weights out = out + shared_out return out.view(B, S, H) def get_aux_loss(self) -> Optional[torch.Tensor]: # Return None when hash routing (no aux loss) or when forward hasn't run yet. # Returning torch.tensor(0.0) here would be a CPU tensor and cause a device # mismatch when added to the CUDA total_aux_loss in SpikeWhaleModel. return self._last_aux_loss class DenseFFN(nn.Module): """Dense SwiGLU FFN for non-MoE layers.""" def __init__(self, cfg: SpikeWhaleConfig): super().__init__() self.gate_proj = nn.Linear(cfg.hidden_size, cfg.moe_intermediate_size, bias=False) self.up_proj = nn.Linear(cfg.hidden_size, cfg.moe_intermediate_size, bias=False) self.down_proj = nn.Linear(cfg.moe_intermediate_size, cfg.hidden_size, bias=False) def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor: return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) def get_aux_loss(self) -> Optional[torch.Tensor]: return None # dense layers have no aux loss; None avoids CPU-tensor device mismatch # --------------------------------------------------------------------------- # Transformer block with Hyper-Connections # --------------------------------------------------------------------------- class TransformerBlock(nn.Module): """ Transformer block combining all features: - Hyper-Connections: pre/post routing through hc_mult streams (NanoWhale) - MLA + DERF + XSA attention (combined) - MoE FFN with shared expert (NanoWhale) + aux loss (My Project) """ def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int): super().__init__() self.use_hc = cfg.use_hyper_connections self.hidden_dropout = cfg.hidden_dropout self.attn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps) self.attn = MLADerfXSAAttention(cfg) self.ffn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps) if cfg.use_moe and layer_idx in cfg.moe_layers: self.ffn = SparseMoEFFN(cfg, layer_idx) self.is_moe = True else: self.ffn = DenseFFN(cfg) self.is_moe = False if self.use_hc: self.hc_attn = HyperConnectionLayer(cfg.hidden_size, cfg.hc_mult, cfg.hc_sinkhorn_iters, cfg.hc_eps) self.hc_ffn = HyperConnectionLayer(cfg.hidden_size, cfg.hc_mult, cfg.hc_sinkhorn_iters, cfg.hc_eps) def forward( self, x: torch.Tensor, # [B, hc_mult, S, H] if HC else [B, S, H] position_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[torch.Tensor]]: # --- Attention sub-layer --- if self.use_hc: h = self.hc_attn.pre_op(x) # [B, S, H] else: h = x attn_out, present = self.attn( self.attn_norm(h), position_ids, attention_mask, past_key_value, use_cache ) attn_out = F.dropout(attn_out, p=self.hidden_dropout, training=self.training) if self.use_hc: x = self.hc_attn.post_op(x, attn_out) h = self.hc_ffn.pre_op(x) # [B, S, H] else: h = h + attn_out # --- FFN sub-layer --- ffn_out = self.ffn(self.ffn_norm(h), position_ids) ffn_out = F.dropout(ffn_out, p=self.hidden_dropout, training=self.training) if self.use_hc: x = self.hc_ffn.post_op(x, ffn_out) else: x = h + ffn_out return x, present, self.ffn.get_aux_loss() # --------------------------------------------------------------------------- # Full model # --------------------------------------------------------------------------- class HRMRefinementBlock(nn.Module): """ HRM-INSPIRED iterative refinement (EXPERIMENTAL, off by default). NOT the full Hierarchical Reasoning Model -- only the iterative-refinement mechanism that the independent ARC-Prize ablation found carried most of HRM's benefit, adapted to a causal LM's final hidden state. Runs N inner steps; each computes a small gated update conditioned on the current state AND the original ('anchor') input. Per-step gate inits at 0 and up.weight is zero-init -> the block is an EXACT identity at init, so enabling it cannot hurt a fresh model; it only contributes if training opens the gate. Pointwise over positions -> causal-safe (no future-token leakage). In/out [B,S,H]. """ def __init__(self, hidden_size: int, refine_dim: int, steps: int, eps: float = 1e-6): super().__init__() self.steps = steps self.norm = RMSNorm(hidden_size, eps) self.down = nn.Linear(hidden_size * 2, refine_dim, bias=False) self.up = nn.Linear(refine_dim, hidden_size, bias=False) self.gate = nn.Parameter(torch.zeros(steps)) nn.init.normal_(self.down.weight, std=0.02) nn.init.zeros_(self.up.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: anchor = x h = x for t in range(self.steps): inp = torch.cat([self.norm(h), anchor], dim=-1) update = self.up(F.silu(self.down(inp))) h = h + torch.tanh(self.gate[t]) * update return h class LatentProjection(nn.Module): """ModularMind-on-V2: pool final hidden state -> d_latent output vector. Mirrors ModularMind's contract: mean-pool over sequence, ReLU^2 activation (sparse latent codes), Xavier init (NOT zero) so the latent carries signal from step 1 — zero-init would make the chain unable to bootstrap.""" def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6): super().__init__() self.proj1 = nn.Linear(hidden_size, hidden_size, bias=False) self.proj2 = nn.Linear(hidden_size, d_latent, bias=False) self.norm = RMSNorm(d_latent, eps) nn.init.xavier_uniform_(self.proj1.weight) nn.init.xavier_uniform_(self.proj2.weight) def forward(self, x: torch.Tensor) -> torch.Tensor: pooled = x.mean(dim=1) # [B, S, H] -> [B, H] h = torch.relu(self.proj1(pooled)) ** 2 return self.norm(self.proj2(h)) # [B, d_latent] class LatentInjection(nn.Module): """ModularMind-on-V2: fold an incoming d_latent vector into embeddings. Broadcast across positions, ReGLU-gated add. Gate starts SMALL (not exactly zero): the injection is near-identity at init (stable) while still passing a little gradient, so the upstream RecursiveLink + specialist can bootstrap from step 1. (Exact-zero gate would block all gradient to the link -- the bootstrapping problem ModularMind's LatentProjection docstring warns about.) This is the INPUT side of RecursiveLink (the prev specialist's latent).""" def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6, gate_init: float = 1e-3): super().__init__() self.up = nn.Linear(d_latent, hidden_size, bias=False) self.norm = RMSNorm(hidden_size, eps) self.value_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False) self.gate_init = gate_init nn.init.xavier_uniform_(self.up.weight) nn.init.xavier_uniform_(self.value_proj.weight) nn.init.normal_(self.gate_proj.weight, std=gate_init) # small, not zero def forward(self, x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor: # x: [B, S, H], latent: [B, d_latent] inj = self.norm(self.up(latent)).unsqueeze(1) # [B, 1, H] broadcast over S value = self.value_proj(inj) gate = torch.relu(self.gate_proj(inj)) return x + value * gate class RecursiveLink(nn.Module): """ModularMind cross-specialist bridge, V2 build. Converts one specialist's output latent into the next specialist's input latent. ReGLU + residual, single shared module reused for every hop. Fully differentiable.""" def __init__(self, d_latent: int = 256, expansion: float = 2.0): super().__init__() d_hidden = int(d_latent * expansion) self.norm = nn.LayerNorm(d_latent) self.value_proj = nn.Linear(d_latent, d_hidden, bias=False) self.gate_proj = nn.Linear(d_latent, d_hidden, bias=False) self.down = nn.Linear(d_hidden, d_latent, bias=False) self.residual_gate = nn.Parameter(torch.ones(1)) nn.init.xavier_uniform_(self.value_proj.weight) nn.init.xavier_uniform_(self.gate_proj.weight) nn.init.xavier_uniform_(self.down.weight) def forward(self, z: torch.Tensor) -> torch.Tensor: n = self.norm(z) h = self.value_proj(n) * torch.relu(self.gate_proj(n)) return z + self.residual_gate * self.down(h) class SpikeWhaleModel(nn.Module): """Decoder stack without LM head.""" def __init__(self, cfg: SpikeWhaleConfig): super().__init__() self.cfg = cfg self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size) nn.init.normal_(self.embed_tokens.weight, std=cfg.initializer_range) self.engram = EngramModule(cfg) if cfg.use_engram else None self.layers = nn.ModuleList([ TransformerBlock(cfg, layer_idx=i) for i in range(cfg.num_hidden_layers) ]) self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps) self.hrm_refine = ( HRMRefinementBlock(cfg.hidden_size, cfg.hrm_refine_dim, cfg.hrm_refine_steps, cfg.rms_norm_eps) if getattr(cfg, "use_hrm_refine", False) else None ) # ModularMind-on-V2: latent input/output (off unless use_latent_io) if getattr(cfg, "use_latent_io", False): self.latent_inject = LatentInjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps) self.latent_out = LatentProjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps) else: self.latent_inject = None self.latent_out = None self.gradient_checkpointing = False def reset_latent_gate(self): """Re-init the injection gate SMALL (not zero). Must be called AFTER any HF post_init/_init_weights pass, which otherwise re-randomizes the gate to full scale. Small-but-nonzero keeps injection near-identity at start while letting gradient reach the upstream RecursiveLink (so the chain can bootstrap).""" if self.latent_inject is not None: nn.init.normal_(self.latent_inject.gate_proj.weight, std=self.latent_inject.gate_init) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple]] = None, use_cache: bool = False, inject_latent: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[List[Tuple]], torch.Tensor]: B, S = input_ids.shape device = input_ids.device if position_ids is None: past_len = past_key_values[0][0].shape[2] if past_key_values else 0 position_ids = torch.arange( past_len, past_len + S, device=device ).unsqueeze(0).expand(B, -1) # Token embedding x = self.embed_tokens(input_ids) # [B, S, H] # Engram N-gram delta (My Project) if self.engram is not None: x = x + self.engram(x) # ModularMind-on-V2: inject the previous specialist's latent (broadcast # across positions, ReGLU-gated). No-op at init (gate zero) and skipped # entirely if no latent is passed. if self.latent_inject is not None and inject_latent is not None: x = self.latent_inject(x, inject_latent) # Expand to hc_mult streams for Hyper-Connections (NanoWhale) if self.cfg.use_hyper_connections: x = x.unsqueeze(1).expand(-1, self.cfg.hc_mult, -1, -1).clone() # [B, hc_mult, S, H] present_key_values = [] if use_cache else None total_aux_loss = torch.tensor(0.0, device=device) for layer_idx, layer in enumerate(self.layers): pkv = past_key_values[layer_idx] if past_key_values else None if self.gradient_checkpointing and self.training: # Gradient checkpointing with use_reentrant=False (NanoWhale) x, present, aux_loss = gradient_checkpoint( layer, x, position_ids, attention_mask, None, False, use_reentrant=False, ) else: x, present, aux_loss = layer(x, position_ids, attention_mask, pkv, use_cache) if use_cache: present_key_values.append(present) if aux_loss is not None: total_aux_loss = total_aux_loss + aux_loss # Reduce HC streams to single hidden state if self.cfg.use_hyper_connections: x = x.mean(dim=1) # [B, S, H] if self.hrm_refine is not None: x = self.hrm_refine(x) x = self.norm(x) # ModularMind-on-V2: emit this specialist's output latent (for RecursiveLink). out_latent = self.latent_out(x) if self.latent_out is not None else None return x, present_key_values, total_aux_loss, out_latent class SpikeWhaleLM(PreTrainedModel): """ Full causal LM combining all SpikeTransformer + NanoWhale features. Training (forward with labels): out = model(input_ids=ids, labels=ids) loss = out.loss # CE + MTP loss + MoE aux loss Generation: out = model(input_ids=ids, use_cache=True) past = out.past_key_values out2 = model(input_ids=next_id, past_key_values=past, use_cache=True) """ config_class = SpikeWhaleConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["TransformerBlock"] def __init__(self, cfg: SpikeWhaleConfig): super().__init__(cfg) self.model = SpikeWhaleModel(cfg) self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False) nn.init.normal_(self.lm_head.weight, std=cfg.initializer_range) if cfg.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight # Multi-Token Prediction heads (NanoWhale): predict token at position+k self.mtp_heads = nn.ModuleList([ nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False) for _ in range(cfg.num_nextn_predict_layers) ]) if cfg.num_nextn_predict_layers > 0 else None self.post_init() # HF post_init re-randomizes Linear weights, clobbering the zero-init # injection gate. Restore it so the latent injection is identity-at-start. self.model.reset_latent_gate() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, SpikeWhaleModel): module.gradient_checkpointing = value def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple]] = None, labels: Optional[torch.Tensor] = None, use_cache: bool = False, inject_latent: Optional[torch.Tensor] = None, **kwargs, ) -> CausalLMOutputWithPast: hidden, present_kvs, aux_loss, out_latent = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, inject_latent=inject_latent, ) logits = self.lm_head(hidden) loss = None if labels is not None: # Standard next-token CE loss (shifted by 1) shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), ignore_index=-100, ) # Multi-Token Prediction loss (NanoWhale) # Each MTP head k predicts token at position + k+1 (beyond the standard +1) if self.mtp_heads is not None: mtp_total = torch.tensor(0.0, device=loss.device) for k, head in enumerate(self.mtp_heads, start=1): offset = k + 1 # predicts position + offset if hidden.size(1) > offset: mtp_logits = head(hidden[..., :-offset, :].contiguous()) mtp_labels = labels[..., offset:].contiguous() mtp_total = mtp_total + F.cross_entropy( mtp_logits.view(-1, mtp_logits.size(-1)), mtp_labels.view(-1), ignore_index=-100, ) loss = loss + mtp_total / max(len(self.mtp_heads), 1) # MoE load-balancing aux loss (My Project) loss = loss + aux_loss out = CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=present_kvs, ) out.latent = out_latent # ModularMind-on-V2: this specialist's output latent return out def count_parameters(self) -> int: return sum(p.numel() for p in self.parameters())