Spaces:
Running on Zero
Running on Zero
| """ | |
| model.py -- SpikeWhaleLM: combined architecture from SpikeTransformer (My Project) + NanoWhale. | |
| Architecture flow: | |
| Embedding | |
| -> Engram delta (N-gram memory, My Project) | |
| -> [expand to hc_mult copies if HC enabled] | |
| -> N x TransformerBlock: | |
| HC pre-op (NanoWhale) -> RMSNorm -> MLA+DERF+XSA Attention (combined) | |
| -> HC post-op | |
| HC pre-op -> RMSNorm -> MoE FFN w/ shared expert (NanoWhale) | |
| -> HC post-op | |
| -> [mean-pool hc_mult copies if HC enabled] | |
| -> RMSNorm | |
| -> LM head + MTP heads (NanoWhale) | |
| Component origins: | |
| RMSNorm, RotaryEmbedding -- both (standard) | |
| Engram / DERFContextGate -- My Project | |
| MLADerfXSAAttention -- MLA from NanoWhale + DERF+XSA from My Project | |
| SparseMoEFFN w/ shared expert -- NanoWhale MoE structure + My Project aux loss | |
| HyperConnectionLayer -- NanoWhale | |
| SpikeWhaleLM + MTP heads -- NanoWhale | |
| """ | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple, List | |
| from transformers import PreTrainedModel | |
| from transformers.modeling_outputs import CausalLMOutputWithPast | |
| from torch.utils.checkpoint import checkpoint as gradient_checkpoint | |
| from config import SpikeWhaleConfig | |
| # --------------------------------------------------------------------------- | |
| # Primitives | |
| # --------------------------------------------------------------------------- | |
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with a learned scale.

    out = x / sqrt(mean(x**2, dim=-1) + eps) * weight

    The mean of squares is computed in float32 and the normalized tensor is cast
    back to the input dtype before the learned scale is applied. Squaring directly
    in float16 overflows for |x| > ~256 (65504 is the largest finite fp16 value).
    That makes the rms ``inf``, so ``rsqrt`` returns 0 and the output is silently
    zeroed. bfloat16 also loses precision in the reduction. Casting back before
    multiplying by ``weight`` keeps the same result dtype as a plain
    ``x * ... * weight`` expression.

    Args:
        dim: size of the normalized (last) dimension.
        eps: added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_f32 = x.float()
        normed = x_f32 * torch.rsqrt(x_f32.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.to(x.dtype) * self.weight
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) for the rope slice of Q and K.

    Cos/sin tables for positions 0 .. max_positions-1 are precomputed as buffers.
    The forward pass rotates the first half of the last dimension against the
    second half ("rotate-half" layout). Only the qk_rope_head_dim part of each
    head should be passed in.
    """

    def __init__(self, dim: int, max_positions: int = 4096, theta: float = 10000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        self.register_buffer("inv_freq", inv_freq)
        angles = torch.outer(torch.arange(max_positions).float(), inv_freq)
        self.register_buffer("cos_cache", angles.cos())
        self.register_buffer("sin_cache", angles.sin())

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> torch.Tensor:
        """
        x:            [B, H, S, rope_dim]
        position_ids: [B, S] (indices into the precomputed tables)
        """
        # Gather per-position tables and add a head axis: [B, 1, S, rope_dim // 2].
        cos = self.cos_cache[position_ids].unsqueeze(1)
        sin = self.sin_cache[position_ids].unsqueeze(1)
        half = cos.shape[-1]
        first, second = x[..., :half], x[..., half:]
        rotated_first = first * cos - second * sin
        rotated_second = first * sin + second * cos
        return torch.cat((rotated_first, rotated_second), dim=-1)
| # --------------------------------------------------------------------------- | |
| # Engram: N-gram hash lookup + DERF gate (My Project, preserved) | |
| # --------------------------------------------------------------------------- | |
class TokenCompressor(nn.Module):
    """Fixed random linear projection ``embed_dim -> compress_dim`` (no bias).

    The output only feeds the integer hash index in MultiHeadHashLookup
    (``h.abs().long() % table_size``). The ``.long()`` cast is
    non-differentiable, so this weight could never receive a gradient. If left
    trainable it would sit in the weight-decay group and AdamW would slowly
    shrink it toward zero, degrading the hash. The weight is therefore frozen:
    a fixed random projection is the intended LSH-style behavior, and freezing
    keeps it out of the optimizer and out of weight decay. The parameter still
    exists, so it is still saved to and loaded from ``state_dict``.
    """

    def __init__(self, embed_dim: int, compress_dim: int):
        super().__init__()
        projection = nn.Linear(embed_dim, compress_dim, bias=False)
        nn.init.normal_(projection.weight, std=0.02)
        projection.weight.requires_grad_(False)
        self.proj = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` down to ``compress_dim``."""
        return self.proj(x)
class MultiHeadHashLookup(nn.Module):
    """Hashed n-gram memory: per-head embedding tables indexed by random-projection hashes.

    For each order n in 1..max_ngram and each position p >= n-1, the n-gram
    ending at p (compressed tokens p-n+1 .. p) is hashed independently for
    every head:

        h[head] = sum_k compressed[p-n+1+k] . hash_proj_n{n}_p{k}[head]
        idx     = int(|h|) % table_size

    The table rows of all heads and orders that exist at a position are then
    averaged. The hash projections are fixed random unit vectors stored as
    buffers (saved in state_dict, never trained). Only the embedding tables learn.

    Args:
        num_heads: number of independent hash heads (one table each).
        table_size: rows per embedding table.
        compress_dim: feature size of the compressed token vectors.
        out_dim: width of each table row and of the output.
        max_ngram: largest n-gram order hashed.
    """

    def __init__(self, num_heads: int, table_size: int,
                 compress_dim: int, out_dim: int, max_ngram: int = 3):
        super().__init__()
        self.num_heads = num_heads
        self.table_size = table_size
        self.max_ngram = max_ngram
        self.out_dim = out_dim
        self.tables = nn.ModuleList([
            nn.Embedding(table_size, out_dim) for _ in range(num_heads)
        ])
        for t in self.tables:
            nn.init.normal_(t.weight, std=0.01)
        for n in range(1, max_ngram + 1):
            for k in range(n):
                proj = torch.randn(num_heads, compress_dim)
                proj = proj / (proj.norm(dim=1, keepdim=True) + 1e-8)
                self.register_buffer(f"hash_proj_n{n}_p{k}", proj)

    def forward(self, compressed: torch.Tensor) -> torch.Tensor:
        """
        compressed: [B, S, compress_dim]
        returns:    [B, S, out_dim], same dtype as ``compressed``

        All positions are processed in parallel. The outer loop runs max_ngram
        times, not S times, and each iteration is one matmul plus one embedding
        lookup per head across the whole sequence.

        BUGFIX: contributions are now accumulated in at least float32 and cast
        to the input dtype once at the end. Previously the accumulator had the
        input dtype, so with bf16/fp16 inputs every one of the
        num_heads * max_ngram additions was rounded to low precision before
        averaging.
        """
        B, S, _ = compressed.shape
        device = compressed.device
        acc_dtype = torch.promote_types(compressed.dtype, torch.float32)
        out = torch.zeros(B, S, self.out_dim, device=device, dtype=acc_dtype)
        # Number of (order x head) contributions each position receives. Early
        # positions get fewer because longer n-grams do not exist there yet.
        norm = torch.zeros(S, device=device)
        for n in range(1, self.max_ngram + 1):
            if S < n:
                continue
            valid_len = S - n + 1  # positions [n-1 .. S-1] have a full order-n window
            start = n - 1
            # compressed[:, k:k+valid_len] is the k-th token of every order-n
            # window at once -> [B, valid_len, num_heads] after projection.
            h = torch.zeros(B, valid_len, self.num_heads, device=device)
            for k in range(n):
                proj = getattr(self, f"hash_proj_n{n}_p{k}")  # [num_heads, compress_dim]
                # NOTE(review): under CUDA autocast torch.matmul runs in the
                # autocast dtype even with .float() inputs, so the hash depends
                # on that precision. This is left unchanged on purpose: forcing
                # fp32 would remap indices learned by existing checkpoints.
                h = h + torch.matmul(compressed[:, k:k + valid_len, :].float(), proj.t())
            idx = h.abs().long() % self.table_size  # [B, valid_len, num_heads]
            head_sum = torch.zeros(B, valid_len, self.out_dim, device=device, dtype=acc_dtype)
            for head_idx, table in enumerate(self.tables):
                head_sum = head_sum + table(idx[:, :, head_idx]).to(acc_dtype)
            out[:, start:, :] = out[:, start:, :] + head_sum
            norm[start:] += self.num_heads
        return (out / norm.view(1, -1, 1).clamp(min=1)).to(compressed.dtype)
class DERFContextGate(nn.Module):
    """Erf-based gate that scales a retrieved memory vector by the current context.

        gate = gamma * (erf(alpha * proj([retrieved, x]) + bias) + 1) / 2
        out  = retrieved * gate

    erf(...) lies in [-1, 1], so the shifted factor lies in [0, 1] and the gate
    in [0, gamma] elementwise (for positive gamma). A large negative
    ``init_bias`` (default -4.0) pushes erf toward -1 at initialization, so the
    gate starts nearly closed as long as ``alpha * proj(...)`` stays small.
    """

    def __init__(self, obs_size: int, init_bias: float = -4.0):
        super().__init__()
        self.proj = nn.Linear(obs_size * 2, obs_size)
        self.alpha = nn.Parameter(torch.ones(obs_size))
        self.bias = nn.Parameter(torch.full((obs_size,), init_bias))
        self.gamma = nn.Parameter(torch.ones(obs_size))

    def forward(self, retrieved: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        joint = torch.cat((retrieved, x), dim=-1)
        pre_activation = self.alpha * self.proj(joint) + self.bias
        open_fraction = (torch.erf(pre_activation) + 1.0) / 2.0
        return retrieved * (self.gamma * open_fraction)
class EngramModule(nn.Module):
    """N-gram hash memory with an erf context gate (My Project), fully vectorized.

    Pipeline per call: compress the hidden states with a frozen projection,
    hash n-grams of the compressed tokens into per-head embedding tables, then
    gate the retrieved vectors against the current hidden states. Every sequence
    position is handled in a single pass (no Python loop over positions).
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        hidden = cfg.hidden_size
        compress_dim = cfg.engram_compress_dim
        self.compressor = TokenCompressor(hidden, compress_dim)
        self.lookup = MultiHeadHashLookup(
            num_heads=cfg.engram_num_heads,
            table_size=cfg.engram_table_size,
            compress_dim=compress_dim,
            out_dim=hidden,
            max_ngram=cfg.engram_max_ngram,
        )
        self.gate = DERFContextGate(hidden, init_bias=cfg.engram_gate_init_bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, H] -> engram delta [B, S, H]."""
        # The compressed features only drive a non-differentiable hash index, so
        # nothing is gained by back-propagating through this branch. x still
        # receives gradient through the gate below.
        hash_keys = self.compressor(x.detach())  # [B, S, compress_dim]
        memory = self.lookup(hash_keys)          # [B, S, H]
        return self.gate(memory, x)              # [B, S, H]
| # --------------------------------------------------------------------------- | |
| # Hyper-Connections (NanoWhale, simplified) | |
| # --------------------------------------------------------------------------- | |
class HyperConnectionLayer(nn.Module):
    """Simplified Hyper-Connections for one sublayer (attention or FFN).

    The residual is carried as ``hc_mult`` parallel streams, shaped
    [B, hc_mult, S, H]:
      * pre_op: softmax-weighted average of the streams -> the sublayer input [B, S, H].
      * post_op: the sublayer output is added to every stream, scaled by a
        per-stream softmax weight.

    Full HC uses Sinkhorn-normalized 2D routing matrices. This version uses 1D
    softmax weights for both directions. ``hidden_size``, ``sinkhorn_iters`` and
    ``eps`` are accepted for signature compatibility but are not used here.
    """

    def __init__(self, hidden_size: int, hc_mult: int,
                 sinkhorn_iters: int = 20, eps: float = 1e-6):
        super().__init__()
        self.hc_mult = hc_mult
        # The streams must NOT start with identical weights. The model feeds
        # hc_mult identical copies, and with uniform weights they stay
        # bit-identical through every layer. In that state the softmax Jacobian
        # applied to the equal per-stream gradients is exactly zero, so the
        # weights never move and the extra streams only cost memory/compute.
        # A small linear spread (centred on zero, so softmax starts near
        # uniform) breaks the symmetry and lets the streams diverge.
        scale = max(hc_mult, 1)
        self.pre_weight = nn.Parameter(torch.linspace(0.5, -0.5, hc_mult) / scale)
        self.post_weight = nn.Parameter(torch.linspace(-0.5, 0.5, hc_mult) / scale)

    @staticmethod
    def _stream_weights(logits: torch.Tensor) -> torch.Tensor:
        """Softmax over streams, shaped [1, hc_mult, 1, 1] for broadcasting."""
        return F.softmax(logits, dim=0).view(1, -1, 1, 1)

    def pre_op(self, copies: torch.Tensor) -> torch.Tensor:
        """copies: [B, hc_mult, S, H] -> mixed sublayer input [B, S, H]."""
        weighted = copies * self._stream_weights(self.pre_weight)
        return weighted.sum(dim=1)

    def post_op(self, copies: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """
        copies: [B, hc_mult, S, H]
        delta:  [B, S, H] sublayer output
        Returns the updated streams [B, hc_mult, S, H].
        """
        spread = delta.unsqueeze(1) * self._stream_weights(self.post_weight)
        return copies + spread
| # --------------------------------------------------------------------------- | |
| # MLA + DERF + XSA Attention (combined) | |
| # --------------------------------------------------------------------------- | |
class MLADerfXSAAttention(nn.Module):
    """
    Multi-Head Latent Attention (NanoWhale) with DERF scores and XSA correction (My Project).

    MLA (from NanoWhale):
      Q: hidden -> q_lora_rank (RMSNorm) -> num_heads * head_dim (low-rank projection)
      K, V: hidden -> num_kv_heads * head_dim (direct; MQA when num_kv_heads == 1)
      Output: num_heads * head_dim -> o_lora_rank -> hidden (low-rank output)
      Partial RoPE: applied only to the dims after the first nope_head_dim of Q and K
    DERF (from My Project):
      Replaces softmax with gamma * erf(alpha * scores + bias), shifted to
      [0, gamma], masked, then normalized to sum to 1 over keys. alpha, bias and
      gamma are learnable per head.
    XSA (from My Project):
      After the weighted value sum y, subtract the component of y along each
      position's own (normalized) value vector, so the output carries only
      cross-position information instead of echoing the current token.

    Masking: if ``attention_mask`` is given, entries < -1.0 are treated as
    blocked and no extra causal mask is added (this assumes an additive-style
    mask with large negative values at blocked positions). Otherwise both the
    DERF and the SDPA paths use a causal mask aligned to the END of the key
    sequence, so query i may attend to keys 0 .. past_len + i.

    BUGFIX: the DERF path used to apply no mask at all whenever a KV cache was
    present. A multi-token step on top of a cache (e.g. chunked prefill) then
    let each query see later tokens of its own chunk. Single-token decode is
    unaffected, because the causal mask is all-False there.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.num_heads = cfg.num_attention_heads
        self.num_kv_heads = cfg.num_key_value_heads
        self.head_dim = cfg.head_dim
        self.qk_rope_head_dim = cfg.qk_rope_head_dim
        self.nope_head_dim = cfg.nope_head_dim
        self.hidden_size = cfg.hidden_size
        self.use_derf = cfg.use_derf
        self.use_xsa = cfg.use_xsa
        self.dropout_p = cfg.attention_dropout
        self.kv_groups = self.num_heads // self.num_kv_heads
        # Low-rank Q projection (MLA)
        self.q_a_proj = nn.Linear(cfg.hidden_size, cfg.q_lora_rank, bias=False)
        self.q_a_norm = RMSNorm(cfg.q_lora_rank, cfg.rms_norm_eps)
        self.q_b_proj = nn.Linear(cfg.q_lora_rank, self.num_heads * self.head_dim, bias=False)
        # Direct K, V projections (MQA/GQA)
        self.k_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        # Low-rank output projection (MLA)
        self.o_a_proj = nn.Linear(self.num_heads * self.head_dim, cfg.o_lora_rank, bias=False)
        self.o_b_proj = nn.Linear(cfg.o_lora_rank, cfg.hidden_size, bias=False)
        # Partial RoPE: applied to qk_rope_head_dim dims only
        self.rope = RotaryEmbedding(
            self.qk_rope_head_dim,
            max_positions=cfg.max_position_embeddings,
            theta=cfg.rope_theta,
        )
        # DERF parameters: one per query head (My Project)
        if self.use_derf:
            self.derf_alpha = nn.Parameter(torch.ones(self.num_heads))
            self.derf_bias = nn.Parameter(torch.zeros(self.num_heads))
            self.derf_gamma = nn.Parameter(torch.ones(self.num_heads))
        for proj in (self.q_a_proj, self.q_b_proj, self.k_proj,
                     self.v_proj, self.o_a_proj, self.o_b_proj):
            nn.init.normal_(proj.weight, std=cfg.initializer_range)

    @staticmethod
    def _blocked_positions(
        attention_mask: Optional[torch.Tensor], q_len: int, kv_len: int, device: torch.device,
    ) -> torch.Tensor:
        """Boolean mask (True = blocked), broadcastable to [B, H, q_len, kv_len].

        With an explicit ``attention_mask``, entries < -1.0 are blocked.
        Otherwise a causal mask aligned to the end of the keys is built: query i
        (absolute position kv_len - q_len + i) may see keys 0 .. kv_len - q_len + i.
        """
        if attention_mask is not None:
            return attention_mask < -1.0
        causal = torch.ones(q_len, kv_len, dtype=torch.bool, device=device)
        return torch.triu(causal, diagonal=kv_len - q_len + 1).unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        x: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """
        x:            [B, S, hidden]
        position_ids: absolute positions of the S new tokens (indexes the RoPE tables)
        Returns (output [B, S, hidden], present (k, v) before GQA expansion, or None).
        """
        B, S, _ = x.shape
        # Q via low-rank projection with intermediate norm (MLA) -> [B, num_heads, S, head_dim]
        q = self.q_a_norm(self.q_a_proj(x))
        q = self.q_b_proj(q).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        # K, V direct projections -> [B, num_kv_heads, S, head_dim]
        k = self.k_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Partial RoPE: rotate only the dims after the first nope_head_dim
        q_nope, q_rope = q[..., :self.nope_head_dim], q[..., self.nope_head_dim:]
        k_nope, k_rope = k[..., :self.nope_head_dim], k[..., self.nope_head_dim:]
        q = torch.cat([q_nope, self.rope(q_rope, position_ids)], dim=-1)
        k = torch.cat([k_nope, self.rope(k_rope, position_ids)], dim=-1)
        # KV cache for inference (stored before GQA expansion)
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)
        present = (k, v) if use_cache else None
        N = k.shape[2]  # total key positions (past + current)
        # Expand KV heads for MQA/GQA
        if self.kv_groups > 1:
            k = k.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)
            v = v.unsqueeze(2).expand(-1, -1, self.kv_groups, -1, -1).reshape(
                B, self.num_heads, N, self.head_dim)

        if self.use_derf:
            # DERF replaces softmax with an erf nonlinearity, so it cannot use the
            # fused kernel and must materialize the scores.
            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            is_masked = self._blocked_positions(attention_mask, S, N, scores.device)
            # Use a finite fill value, not -inf: if alpha ever reaches 0.0,
            # 0.0 * -inf would be NaN.
            safe_scores = scores.masked_fill(is_masked, -10000.0)
            a = self.derf_alpha.view(1, -1, 1, 1)
            b = self.derf_bias.view(1, -1, 1, 1)
            g = self.derf_gamma.view(1, -1, 1, 1)
            attn_weights = g * torch.erf(a * safe_scores + b)        # [-gamma, gamma]
            attn_weights = (attn_weights + g) / 2.0                  # [0, gamma]
            attn_weights = attn_weights.masked_fill(is_masked, 0.0)  # hard-zero blocked keys
            attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-8)
            if self.dropout_p > 0 and self.training:
                attn_weights = F.dropout(attn_weights, p=self.dropout_p)
            y = torch.matmul(attn_weights, v)  # [B, num_heads, S, head_dim]
        else:
            # Fused scaled_dot_product_attention (Flash / mem-efficient backends)
            # is the hot path when use_derf=False. It scales by 1/sqrt(head_dim).
            #
            # Contiguity: with MQA/GQA, k and v come from expand(...).reshape(...).
            # Under torch.compile the zero-stride view can reach the fused
            # flash-attention backward kernel and trip its stride assertion, so
            # standard strides are forced here.
            q = q.contiguous()
            k = k.contiguous()
            v = v.contiguous()
            drop = self.dropout_p if self.training else 0.0
            if past_key_value is None and attention_mask is None:
                # Prefill / training: pure causal mask, no materialization needed.
                y = F.scaled_dot_product_attention(q, k, v, is_causal=True, dropout_p=drop)
            else:
                # Cached step or explicit mask: pass a boolean keep-mask (True = attend).
                is_masked = self._blocked_positions(attention_mask, S, N, q.device)
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=~is_masked, dropout_p=drop)

        # XSA: for each query position, remove the component of y along the
        # normalized value vector of that same position.
        if self.use_xsa:
            past_len = N - S
            v_self = v[:, :, past_len:past_len + S, :]  # [B, H, S, D]
            vn = v_self / (v_self.norm(dim=-1, keepdim=True) + 1e-8)
            projection = (y * vn).sum(dim=-1, keepdim=True) * vn
            y = y - projection

        # Low-rank output projection (MLA)
        y = y.transpose(1, 2).contiguous().view(B, S, self.num_heads * self.head_dim)
        y = self.o_b_proj(self.o_a_proj(y))
        return y, present
| # --------------------------------------------------------------------------- | |
| # MoE FFN: shared expert + sqrtsoftplus + hash routing (NanoWhale) + aux loss (My Project) | |
| # --------------------------------------------------------------------------- | |
class ExpertFFN(nn.Module):
    """One SwiGLU expert: ``down(silu(gate(x)) * up(x))``, all projections bias-free."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(self.gate_proj(x))
        gated = activated * self.up_proj(x)
        return self.down_proj(gated)
def sqrtsoftplus(x: torch.Tensor) -> torch.Tensor:
    """Return ``sqrt(softplus(x) + 1e-8)``, the NanoWhale expert-scoring function.

    The 1e-8 floor keeps the square root away from zero. For very negative x,
    softplus underflows to exactly 0, and the backward pass would then compute
    inf (from d sqrt(u)/du at u=0) times 0, giving NaN gradients.
    """
    smoothed = F.softplus(x)
    return torch.sqrt(smoothed + 1e-8)
class SparseMoEFFN(nn.Module):
    """
    Mixture-of-experts SwiGLU FFN: NanoWhale MoE structure plus My Project's aux loss.

      - n_shared_experts always-active experts (outputs averaged when > 1)
      - n_routed_experts sparse routed experts, top-k activation
      - sqrtsoftplus scoring (NanoWhale) or softmax, selected by cfg.scoring_func
      - deterministic hash routing for layers with layer_idx < cfg.num_hash_layers
      - norm_topk_prob + routed_scaling_factor (learned routing only)
      - load-balancing aux loss (learned routing only), scaled by moe_aux_loss_coef
        and read back through get_aux_loss() after forward
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int = 0):
        super().__init__()
        self.n_routed_experts = cfg.n_routed_experts
        self.n_shared_experts = cfg.n_shared_experts
        self.num_experts_per_tok = cfg.num_experts_per_tok
        self.norm_topk_prob = cfg.norm_topk_prob
        self.scoring_func = cfg.scoring_func
        self.routed_scaling_factor = cfg.routed_scaling_factor
        self.use_hash_routing = layer_idx < cfg.num_hash_layers
        self.aux_loss_coef = cfg.moe_aux_loss_coef
        self.router = nn.Linear(cfg.hidden_size, cfg.n_routed_experts, bias=False)
        self.experts = nn.ModuleList([
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
            for _ in range(cfg.n_routed_experts)
        ])
        self.shared_experts = nn.ModuleList([
            ExpertFFN(cfg.hidden_size, cfg.moe_intermediate_size)
            for _ in range(cfg.n_shared_experts)
        ]) if cfg.n_shared_experts > 0 else None
        self._last_aux_loss: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        x:            [B, S, H]
        position_ids: absolute token positions, [B, S] or anything that broadcasts
                      to it (e.g. [1, S] or [S]). Only used by hash routing.
        Returns [B, S, H]. Sets self._last_aux_loss (None for hash routing).
        """
        B, S, H = x.shape
        T = B * S
        # reshape (not view) so a non-contiguous input does not raise.
        x_flat = x.reshape(T, H)

        # Shared experts: always active (NanoWhale)
        shared_out = torch.zeros_like(x_flat)
        if self.shared_experts:
            for expert in self.shared_experts:
                shared_out = shared_out + expert(x_flat)
            if len(self.shared_experts) > 1:
                shared_out = shared_out / len(self.shared_experts)

        if self.use_hash_routing:
            # Deterministic assignment without a learned router (NanoWhale). A
            # token at absolute position p uses the DISTINCT experts
            # [p % n, (p+1) % n, ..., (p+k-1) % n] with equal weights 1/k.
            # Keying off the absolute position (not the index in the flattened
            # batch) makes training, prefill and step-by-step KV-cache decoding
            # agree. For S divisible by n_experts this matches the earlier
            # arange(T)-based assignment, so existing checkpoints stay valid.
            if position_ids is not None:
                # expand() lets shared [1, S] / [S] position ids cover every batch row;
                # the old reshape(T, 1) failed on them whenever B > 1.
                positions = position_ids.expand(B, S).reshape(T, 1)
                base = (positions % self.n_routed_experts).long()
            else:
                base = (torch.arange(T, device=x.device) % self.n_routed_experts).unsqueeze(1)
            offsets = torch.arange(self.num_experts_per_tok, device=x.device)  # [k]
            top_k_indices = (base + offsets.unsqueeze(0)) % self.n_routed_experts  # [T, k]
            top_k_weights = torch.ones(T, self.num_experts_per_tok, device=x.device) / self.num_experts_per_tok
            self._last_aux_loss = None
        else:
            router_logits = self.router(x_flat)
            if self.scoring_func == "sqrtsoftplus":
                routing_scores = sqrtsoftplus(router_logits)
            else:
                routing_scores = F.softmax(router_logits, dim=-1)
            top_k_scores, top_k_indices = torch.topk(routing_scores, self.num_experts_per_tok, dim=-1)
            if self.norm_topk_prob:
                top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-8)
            else:
                top_k_weights = top_k_scores
            top_k_weights = top_k_weights * self.routed_scaling_factor
            # Load-balancing aux loss (My Project): n * sum_e f_e * p_e
            softmax_probs = F.softmax(router_logits, dim=-1)
            expert_mask = torch.zeros_like(softmax_probs)
            expert_mask.scatter_(1, top_k_indices, 1.0)
            f_e = expert_mask.mean(0)   # fraction of tokens routed to each expert
            p_e = softmax_probs.mean(0)  # mean router probability per expert
            self._last_aux_loss = self.n_routed_experts * (f_e * p_e).sum() * self.aux_loss_coef

        # Dispatch tokens to routed experts
        out = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            hits = top_k_indices == expert_idx  # [T, k]
            token_mask = hits.any(dim=-1)
            if not token_mask.any():
                continue
            # Total weight this expert gets per selected token. Summing over the
            # k slots (rather than gathering via nonzero) yields exactly one row
            # per token even if one expert fills several slots, which happens
            # with hash routing when k > n_routed_experts.
            weights = (top_k_weights[token_mask] * hits[token_mask]).sum(dim=-1, keepdim=True)
            expert_output = expert(x_flat[token_mask])
            out[token_mask] = out[token_mask] + expert_output * weights
        out = out + shared_out
        return out.view(B, S, H)

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        # None for hash routing or before the first forward. A torch.tensor(0.0)
        # here would live on CPU and mismatch devices when summed with CUDA losses.
        return self._last_aux_loss
class DenseFFN(nn.Module):
    """Dense SwiGLU feed-forward for layers that are not MoE.

    Exposes the same ``forward(x, position_ids)`` call shape and
    ``get_aux_loss()`` method as SparseMoEFFN, so TransformerBlock can call
    either one the same way. ``position_ids`` is accepted but unused.
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        d_model, d_ff = cfg.hidden_size, cfg.moe_intermediate_size
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        activated = F.silu(self.gate_proj(x))
        return self.down_proj(activated * self.up_proj(x))

    def get_aux_loss(self) -> Optional[torch.Tensor]:
        # Dense layers contribute no aux loss. None (rather than a CPU zero
        # tensor) avoids a device mismatch when losses are summed.
        return None
| # --------------------------------------------------------------------------- | |
| # Transformer block with Hyper-Connections | |
| # --------------------------------------------------------------------------- | |
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention sublayer, then FFN sublayer.

    Each sublayer output goes through dropout (``hidden_dropout``) before the
    residual update.
      * With hyper-connections, x is [B, hc_mult, S, H]. Each sublayer reads a
        weighted mix of the streams (``pre_op``) and writes its output back to
        every stream (``post_op``).
      * Without them, x is [B, S, H] and plain residual additions are used.
    The FFN is SparseMoEFFN for MoE layers (``cfg.use_moe`` and
    ``layer_idx in cfg.moe_layers``), otherwise DenseFFN.
    """

    def __init__(self, cfg: SpikeWhaleConfig, layer_idx: int):
        super().__init__()
        self.use_hc = cfg.use_hyper_connections
        self.hidden_dropout = cfg.hidden_dropout
        self.attn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = MLADerfXSAAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        moe_layer = bool(cfg.use_moe and layer_idx in cfg.moe_layers)
        self.ffn = SparseMoEFFN(cfg, layer_idx) if moe_layer else DenseFFN(cfg)
        self.is_moe = moe_layer
        if self.use_hc:
            self.hc_attn = HyperConnectionLayer(cfg.hidden_size, cfg.hc_mult,
                                                cfg.hc_sinkhorn_iters, cfg.hc_eps)
            self.hc_ffn = HyperConnectionLayer(cfg.hidden_size, cfg.hc_mult,
                                               cfg.hc_sinkhorn_iters, cfg.hc_eps)

    def _residual_dropout(self, t: torch.Tensor) -> torch.Tensor:
        """Dropout applied to a sublayer output (active only in training mode)."""
        return F.dropout(t, p=self.hidden_dropout, training=self.training)

    def forward(
        self,
        x: torch.Tensor,  # [B, hc_mult, S, H] if HC else [B, S, H]
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple], Optional[torch.Tensor]]:
        """Returns (updated x, present KV cache or None, FFN aux loss or None)."""
        # Attention sub-layer
        attn_in = self.hc_attn.pre_op(x) if self.use_hc else x
        attn_out, present = self.attn(
            self.attn_norm(attn_in), position_ids, attention_mask, past_key_value, use_cache
        )
        attn_out = self._residual_dropout(attn_out)

        if self.use_hc:
            x = self.hc_attn.post_op(x, attn_out)
            ffn_in = self.hc_ffn.pre_op(x)
        else:
            ffn_in = attn_in + attn_out

        # FFN sub-layer
        ffn_out = self._residual_dropout(self.ffn(self.ffn_norm(ffn_in), position_ids))

        if self.use_hc:
            x = self.hc_ffn.post_op(x, ffn_out)
        else:
            x = ffn_in + ffn_out
        return x, present, self.ffn.get_aux_loss()
| # --------------------------------------------------------------------------- | |
| # Full model | |
| # --------------------------------------------------------------------------- | |
class HRMRefinementBlock(nn.Module):
    """
    HRM-INSPIRED iterative refinement (EXPERIMENTAL, off by default). This is NOT
    the full Hierarchical Reasoning Model. It keeps only the iterative-refinement
    mechanism that the independent ARC-Prize ablation found carried most of
    HRM's benefit, adapted to a causal LM's final hidden state.

    Runs ``steps`` inner steps. Each step applies a gated update conditioned on
    the current state and the original ('anchor') input:

        h <- h + tanh(gate[t]) * up(silu(down([norm(h), anchor])))

    Every gate starts at 0 and tanh(0) = 0, so the block is an EXACT identity at
    init and enabling it cannot hurt a fresh model. All operations act on the
    last dimension only (pointwise over positions), so it is causal-safe.
    In/out [B, S, H].

    BUGFIX: ``up.weight`` used to be zero-initialized as well as the gates. With
    both at zero:
      * the gradient to up.weight is scaled by tanh(gate) = 0;
      * the gradient to the gates is an inner product with up(...) = 0.
    Neither parameter could ever move, so the block stayed an identity forever.
    ``up`` now gets the same small normal init as ``down`` (ReZero-style: only
    the scalar residual gate starts at zero). The gates receive gradient
    immediately, and the identity-at-init property is preserved.
    """

    def __init__(self, hidden_size: int, refine_dim: int, steps: int, eps: float = 1e-6):
        super().__init__()
        self.steps = steps
        self.norm = RMSNorm(hidden_size, eps)
        self.down = nn.Linear(hidden_size * 2, refine_dim, bias=False)
        self.up = nn.Linear(refine_dim, hidden_size, bias=False)
        # One scalar gate per step. tanh(0) = 0 makes every step a no-op at init.
        self.gate = nn.Parameter(torch.zeros(steps))
        nn.init.normal_(self.down.weight, std=0.02)
        # Small random, NOT zero: a zero up.weight combined with zero gates is a
        # gradient dead end (see class docstring).
        nn.init.normal_(self.up.weight, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        anchor = x
        h = x
        for t in range(self.steps):
            inp = torch.cat([self.norm(h), anchor], dim=-1)
            update = self.up(F.silu(self.down(inp)))
            h = h + torch.tanh(self.gate[t]) * update
        return h
class LatentProjection(nn.Module):
    """ModularMind-on-V2: pool the final hidden state into a d_latent output vector.

    Mirrors ModularMind's contract:
      * mean-pool over the sequence axis (every position counts equally; no
        padding mask is applied here);
      * ReLU^2 activation (sparse latent codes);
      * RMSNorm on the result.
    Both projections use Xavier init, NOT zero, so the latent carries signal
    from step 1. A zero init would leave the chain unable to bootstrap.
    """

    def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6):
        super().__init__()
        self.proj1 = nn.Linear(hidden_size, hidden_size, bias=False)
        self.proj2 = nn.Linear(hidden_size, d_latent, bias=False)
        self.norm = RMSNorm(d_latent, eps)
        for layer in (self.proj1, self.proj2):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: [B, S, H] -> latent [B, d_latent]."""
        pooled = x.mean(dim=1)
        sparse_code = torch.relu(self.proj1(pooled)).pow(2)
        return self.norm(self.proj2(sparse_code))
class LatentInjection(nn.Module):
    """Fold an incoming d_latent vector into the token embeddings.

    ModularMind-on-V2 input side: this consumes the previous specialist's latent,
    i.e. the output of RecursiveLink. The latent is lifted to H, RMS-normalised,
    broadcast across every position, and added through a ReGLU gate:

        x + value_proj(c) * relu(gate_proj(c)),  c = norm(up(latent))[:, None]

    The gate projection starts SMALL rather than exactly zero. The injection is
    then near-identity at init (stable), yet a little gradient still flows to the
    upstream RecursiveLink and specialist, so the chain can bootstrap from step 1.
    An exact-zero gate would block that gradient entirely.

    ``gate_init`` is stored so the small gate init can be re-applied later, e.g.
    after a HF weight-init pass.

    Shapes: x [B, S, H], latent [B, d_latent] -> returns [B, S, H].
    """

    def __init__(self, hidden_size: int, d_latent: int, eps: float = 1e-6,
                 gate_init: float = 1e-3):
        super().__init__()
        self.up = nn.Linear(d_latent, hidden_size, bias=False)
        self.norm = RMSNorm(hidden_size, eps)
        self.value_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.gate_init = gate_init
        # Value path carries full-scale signal; only the gate starts tiny.
        nn.init.xavier_uniform_(self.up.weight)
        nn.init.xavier_uniform_(self.value_proj.weight)
        nn.init.normal_(self.gate_proj.weight, std=gate_init)

    def forward(self, x: torch.Tensor, latent: torch.Tensor) -> torch.Tensor:
        cond = self.norm(self.up(latent)).unsqueeze(1)  # [B, 1, H] -> broadcasts over S
        delta = self.value_proj(cond) * F.relu(self.gate_proj(cond))  # ReGLU
        return x + delta
class RecursiveLink(nn.Module):
    """Map one specialist's output latent to the next specialist's input latent.

    ModularMind cross-specialist bridge (V2 build). This is a pre-LayerNorm ReGLU
    MLP in a residual connection whose branch is scaled by a learnable scalar
    (``residual_gate``, initialised to 1):

        out = z + residual_gate * down(value_proj(LN(z)) * relu(gate_proj(LN(z))))

    A single shared instance is meant to be reused for every hop. The module is
    fully differentiable.

    Args:
        d_latent: latent width (same for input and output).
        expansion: hidden-width multiplier; hidden = int(d_latent * expansion).
    """

    def __init__(self, d_latent: int = 256, expansion: float = 2.0):
        super().__init__()
        width = int(d_latent * expansion)
        self.norm = nn.LayerNorm(d_latent)
        self.value_proj = nn.Linear(d_latent, width, bias=False)
        self.gate_proj = nn.Linear(d_latent, width, bias=False)
        self.down = nn.Linear(width, d_latent, bias=False)
        self.residual_gate = nn.Parameter(torch.ones(1))
        for layer in (self.value_proj, self.gate_proj, self.down):
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        normed = self.norm(z)
        mixed = self.value_proj(normed) * F.relu(self.gate_proj(normed))
        branch = self.down(mixed)
        return z + self.residual_gate * branch
class SpikeWhaleModel(nn.Module):
    """Decoder stack without LM head.

    Forward pipeline (stages in parentheses are optional, controlled by ``cfg``):
        embed_tokens -> (+ Engram delta) -> (latent injection) -> (expand to
        ``hc_mult`` Hyper-Connection streams) -> N x TransformerBlock -> (mean
        over HC streams) -> (HRM refinement) -> final RMSNorm -> (output latent)
    """

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        nn.init.normal_(self.embed_tokens.weight, std=cfg.initializer_range)
        self.engram = EngramModule(cfg) if cfg.use_engram else None
        self.layers = nn.ModuleList([
            TransformerBlock(cfg, layer_idx=i)
            for i in range(cfg.num_hidden_layers)
        ])
        self.norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        # Experimental iterative refinement of the final hidden state (off by default).
        self.hrm_refine = (
            HRMRefinementBlock(cfg.hidden_size, cfg.hrm_refine_dim,
                               cfg.hrm_refine_steps, cfg.rms_norm_eps)
            if getattr(cfg, "use_hrm_refine", False) else None
        )
        # ModularMind-on-V2: latent input/output (off unless use_latent_io)
        if getattr(cfg, "use_latent_io", False):
            self.latent_inject = LatentInjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps)
            self.latent_out = LatentProjection(cfg.hidden_size, cfg.d_latent, cfg.rms_norm_eps)
        else:
            self.latent_inject = None
            self.latent_out = None
        # Toggled by SpikeWhaleLM._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def reset_latent_gate(self):
        """Re-init the injection gate SMALL (not zero). Must be called AFTER any HF
        post_init/_init_weights pass, which otherwise re-randomizes the gate to full
        scale. Small-but-nonzero keeps injection near-identity at start while letting
        gradient reach the upstream RecursiveLink (so the chain can bootstrap).
        No-op when latent I/O is disabled."""
        if self.latent_inject is not None:
            nn.init.normal_(self.latent_inject.gate_proj.weight,
                            std=self.latent_inject.gate_init)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        use_cache: bool = False,
        inject_latent: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple]], torch.Tensor, Optional[torch.Tensor]]:
        """Run the decoder stack.

        Args:
            input_ids: token ids, [B, S].
            attention_mask: passed unchanged to every TransformerBlock.
            position_ids: [B, S]. Defaults to past_len .. past_len + S - 1 per row.
            past_key_values: per-layer cache entries from a previous call, or None.
            use_cache: if True, collect and return each layer's ``present`` entry.
            inject_latent: [B, d_latent] latent from the previous specialist. Only
                used when the model was built with ``use_latent_io``.

        Returns:
            (hidden [B, S, H],
             list of per-layer ``present`` entries, or None when use_cache is False,
             summed MoE aux loss as a scalar tensor,
             output latent [B, d_latent], or None when latent I/O is disabled)
        """
        B, S = input_ids.shape
        device = input_ids.device
        if position_ids is None:
            # NOTE(review): assumes the first tensor of each layer's cache entry has
            # the time axis at dim 2. Confirm against the attention cache layout.
            past_len = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(
                past_len, past_len + S, device=device
            ).unsqueeze(0).expand(B, -1)

        # Token embedding
        x = self.embed_tokens(input_ids)  # [B, S, H]

        # Engram N-gram delta (My Project)
        if self.engram is not None:
            x = x + self.engram(x)

        # ModularMind-on-V2: inject the previous specialist's latent (broadcast
        # across positions, ReGLU-gated). Near-identity at init because the gate
        # projection is small-initialised; skipped entirely if no latent is passed.
        if self.latent_inject is not None and inject_latent is not None:
            x = self.latent_inject(x, inject_latent)

        # Expand to hc_mult streams for Hyper-Connections (NanoWhale)
        if self.cfg.use_hyper_connections:
            x = x.unsqueeze(1).expand(-1, self.cfg.hc_mult, -1, -1).clone()
            # [B, hc_mult, S, H]

        present_key_values = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=device)
        for layer_idx, layer in enumerate(self.layers):
            pkv = past_key_values[layer_idx] if past_key_values else None
            if self.gradient_checkpointing and self.training:
                # Non-reentrant gradient checkpointing (NanoWhale). Forward the real
                # cache entry and use_cache flag. Hard-coding (None, False) here
                # silently discarded any incoming past_key_values and
                # (presumably) yielded no real `present` even when use_cache=True.
                x, present, aux_loss = gradient_checkpoint(
                    layer, x, position_ids, attention_mask, pkv, use_cache,
                    use_reentrant=False,
                )
            else:
                x, present, aux_loss = layer(x, position_ids, attention_mask, pkv, use_cache)
            if use_cache:
                present_key_values.append(present)
            if aux_loss is not None:
                total_aux_loss = total_aux_loss + aux_loss

        # Reduce HC streams to single hidden state
        if self.cfg.use_hyper_connections:
            x = x.mean(dim=1)  # [B, S, H]

        if self.hrm_refine is not None:
            x = self.hrm_refine(x)
        x = self.norm(x)

        # ModularMind-on-V2: emit this specialist's output latent (for RecursiveLink).
        out_latent = self.latent_out(x) if self.latent_out is not None else None
        return x, present_key_values, total_aux_loss, out_latent
class SpikeWhaleLM(PreTrainedModel):
    """
    Full causal LM combining all SpikeTransformer + NanoWhale features.

    Training (forward with labels):
        out = model(input_ids=ids, labels=ids)
        loss = out.loss  # CE + MTP loss + MoE aux loss
    Generation:
        out = model(input_ids=ids, use_cache=True)
        past = out.past_key_values
        out2 = model(input_ids=next_id, past_key_values=past, use_cache=True)
    Latent I/O (only when built with cfg.use_latent_io):
        out = model(input_ids=ids, inject_latent=prev_latent)  # prev_latent: [B, d_latent]
        out.latent  # [B, d_latent] latent for the next specialist (None if disabled)
    """
    config_class = SpikeWhaleConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, cfg: SpikeWhaleConfig):
        super().__init__(cfg)
        self.model = SpikeWhaleModel(cfg)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        nn.init.normal_(self.lm_head.weight, std=cfg.initializer_range)
        if cfg.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        # Multi-Token Prediction heads (NanoWhale). Head k (1-based) predicts the
        # token at position + k + 1, i.e. beyond the standard next-token target.
        self.mtp_heads = nn.ModuleList([
            nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
            for _ in range(cfg.num_nextn_predict_layers)
        ]) if cfg.num_nextn_predict_layers > 0 else None
        self.post_init()
        # An HF post_init/_init_weights pass can re-randomize Linear weights,
        # resetting the latent-injection gate to full scale. Re-apply the SMALL
        # (non-zero) gate init so injection starts near-identity while still
        # passing gradient upstream. No-op when latent I/O is disabled.
        self.model.reset_latent_gate()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _set_gradient_checkpointing(self, module, value=False):
        # Legacy HF hook signature (module, value): flips the flag that
        # SpikeWhaleModel.forward checks.
        if isinstance(module, SpikeWhaleModel):
            module.gradient_checkpointing = value

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple]] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        inject_latent: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Compute logits and, if ``labels`` is given, the training loss.

        The loss is next-token CE, plus the mean MTP CE over all MTP heads (when
        present), plus the MoE aux loss returned by the decoder stack. Label value
        -100 is ignored. Extra keyword arguments are accepted and ignored.

        The returned output also carries ``.latent``, this specialist's output
        latent (None unless latent I/O is enabled).
        """
        hidden, present_kvs, aux_loss, out_latent = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inject_latent=inject_latent,
        )
        logits = self.lm_head(hidden)

        loss = None
        if labels is not None:
            # Standard next-token CE loss (shifted by 1)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Multi-Token Prediction loss (NanoWhale)
            # Each MTP head k predicts token at position + k+1 (beyond the standard +1)
            if self.mtp_heads is not None:
                mtp_total = torch.tensor(0.0, device=loss.device)
                for k, head in enumerate(self.mtp_heads, start=1):
                    offset = k + 1  # predicts position + offset
                    # Heads whose offset reaches past the sequence contribute 0, but
                    # the divisor below still counts every head.
                    if hidden.size(1) > offset:
                        mtp_logits = head(hidden[..., :-offset, :].contiguous())
                        mtp_labels = labels[..., offset:].contiguous()
                        mtp_total = mtp_total + F.cross_entropy(
                            mtp_logits.view(-1, mtp_logits.size(-1)),
                            mtp_labels.view(-1),
                            ignore_index=-100,
                        )
                loss = loss + mtp_total / max(len(self.mtp_heads), 1)

            # MoE load-balancing aux loss (My Project)
            loss = loss + aux_loss

        out = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=present_kvs,
        )
        out.latent = out_latent  # ModularMind-on-V2: this specialist's output latent
        return out

    def count_parameters(self) -> int:
        """Total element count over self.parameters() (shared tensors counted once)."""
        return sum(p.numel() for p in self.parameters())