"""Quasar hybrid transformer — HuggingFace compatible. """

import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import GenerationMixin
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

from .configuration_quasar import QuasarConfig

logger = logging.get_logger(__name__)

# FLA layer imports — required
from fla.layers.quasar import QuasarAttention
from fla.layers.gla import GatedLinearAttention
from fla.models.utils import Cache as FlaCache, FLAGenerationMixin


# ===================================================================
# RMSNorm (standalone — weight name: .weight, no bias)
# ===================================================================
class RMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel scale (no bias).

    The statistic is computed in float32; the normalised tensor is cast
    back to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


# ===================================================================
# Rotary Embedding (cos/sin tables cached in buffers; inv_freq and the
# tables are registered non-persistent, so they are recomputed from
# `dim`/`base` rather than loaded from a checkpoint)
# ===================================================================
class RotaryEmbedding(nn.Module):
    """Precomputes RoPE cos/sin tables of shape ``(1, 1, positions, dim)``.

    The tables are grown on demand when ``forward`` is asked for more
    positions than are currently cached.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=100000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Pre-compute cos/sin cache
        t = torch.arange(max_position_embeddings + 1, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("_cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("_sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions in ``x.dtype``.

        Args:
            x: Any tensor; only its ``device`` and ``dtype`` are used.
            seq_len: Number of positions required. ``None`` returns the
                whole cached table.
        """
        if seq_len is not None and seq_len > self._cos_cached.shape[2]:
            # Rebuild with 1024 positions of headroom so that subsequent
            # slightly-longer requests do not trigger another rebuild.
            t = torch.arange(seq_len + 1024, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.register_buffer("_cos_cached", emb.cos()[None, None, :, :].to(self._cos_cached.dtype), persistent=False)
            self.register_buffer("_sin_cached", emb.sin()[None, None, :, :].to(self._sin_cached.dtype), persistent=False)
        return (
            self._cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self._sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


# ===================================================================
# Latent Memory Module (use_triton=False — PyTorch bmm is faster)
# ===================================================================
class LatentMemoryModule(nn.Module):
    """Persistent Latent Parameter Memory — weight names match checkpoint.

    The module is stateless: the memory matrix ``M`` of shape
    ``(batch, memory_slots, memory_dim)`` is passed in, and
    :meth:`write_memory` returns an updated copy.

    Args:
        hidden_size: Width of the token features fed to read/write.
        memory_slots: Number of memory rows ``K``.
        memory_dim: Width of each memory row ``D``.
        use_triton: Accepted for compatibility; ignored (always ``False``).
    """

    def __init__(self, hidden_size, memory_slots=128, memory_dim=128, use_triton=False):
        super().__init__()
        self.K = memory_slots
        self.D = memory_dim
        self.W_eta = nn.Linear(hidden_size, 1, bias=True)
        nn.init.zeros_(self.W_eta.weight)
        nn.init.constant_(self.W_eta.bias, -5.0)
        self.segment_len = 64
        self.summary_query = nn.Parameter(torch.randn(1, 1, memory_dim))
        self.summary_proj = nn.Linear(hidden_size, memory_dim, bias=True)
        self.eta_channels = nn.Parameter(torch.ones(1, 1, memory_dim))
        self.temperature = nn.Parameter(torch.ones(1))
        self.hidden_size = hidden_size
        self.use_triton = False
        self.input_norm = nn.LayerNorm(hidden_size)
        # NOTE(review): compress_z is not used by read/write below; kept so
        # that checkpoint weight names still load.
        self.compress_z = nn.Sequential(
            nn.Linear(hidden_size, memory_dim * 2, bias=False),
            nn.SiLU(),
            nn.Linear(memory_dim * 2, memory_dim, bias=False),
        )
        self.W_qkv_mem = nn.Linear(hidden_size, memory_dim * 3, bias=False)
        self.scale = 1.0 / math.sqrt(memory_dim)

    def get_diversity_loss(self, M):
        """Mean squared off-diagonal cosine similarity between memory slots."""
        B, K, D = M.shape
        M_norm = F.normalize(M, p=2, dim=-1)
        sim = torch.bmm(M_norm, M_norm.transpose(1, 2))
        mask = torch.eye(K, device=M.device).unsqueeze(0)
        sim = sim * (1 - mask)
        return sim.pow(2).mean()

    def write_memory(self, H, M, chunk_idx=0):
        """Write a segment-pooled summary of ``H`` into memory ``M``.

        ``H`` is ``(B, T, hidden_size)``. The sequence is padded to a
        multiple of ``segment_len``; each segment is pooled to one
        ``memory_dim`` vector by attention with ``summary_query``. Each
        pooled vector is then soft-assigned to memory slots.

        Returns:
            ``(M_new, loss, eta_raw_sig)``. ``loss`` combines the
            update-magnitude and slot-diversity regularisers.
            ``eta_raw_sig`` is the per-token write gate
            ``sigmoid(W_eta(H))``. ``chunk_idx`` is unused.
        """
        H = self.input_norm(H)
        B, T, _ = H.shape
        H_mem = self.summary_proj(H)
        eta_tokens = self.W_eta(H).squeeze(-1)

        L = self.segment_len
        if T % L != 0:
            pad_len = L - (T % L)
            H_padded = F.pad(H_mem, (0, 0, 0, pad_len))
            # Very negative gate logit so padded tokens never win the
            # per-segment max below.
            eta_padded = F.pad(eta_tokens, (0, pad_len), value=-10.0)
        else:
            H_padded = H_mem
            eta_padded = eta_tokens

        T_pad = H_padded.shape[1]
        num_segments = T_pad // L
        H_segs = H_padded.view(B * num_segments, L, self.D)

        # NOTE(review): the zero rows added by F.pad are not masked, so they
        # take part in this softmax (score 0) and dilute the last segment's
        # summary — most strongly when T < segment_len, e.g. one-token decode
        # steps. Left unchanged to stay consistent with trained weights.
        summary_scores = torch.bmm(
            self.summary_query.expand(B * num_segments, -1, -1),
            H_segs.transpose(1, 2),
        )
        summary_weights = F.softmax(summary_scores * self.scale, dim=-1)
        Z_seg = torch.bmm(summary_weights, H_segs).view(B, num_segments, self.D)

        eta_raw_sig = torch.sigmoid(eta_tokens)
        # Per-segment write strength = strongest token gate in the segment.
        eta_seg_sig = torch.max(
            torch.sigmoid(eta_padded.view(B, num_segments, L)), dim=-1, keepdim=True
        )[0]

        scores = torch.bmm(Z_seg, M.transpose(-1, -2)) * self.scale * torch.exp(self.temperature)
        A = F.softmax(scores, dim=-1)
        DeltaM_seg = torch.bmm(A.transpose(1, 2), Z_seg * eta_seg_sig)

        eta_avg = eta_seg_sig.mean(dim=1, keepdim=True)
        gate = eta_avg * torch.sigmoid(self.eta_channels)
        M_new = (1.0 - gate) * M + DeltaM_seg / num_segments

        norm_sq = torch.sum(DeltaM_seg ** 2) / num_segments
        div_loss = self.get_diversity_loss(M_new)
        return M_new, norm_sq * 0.01 + div_loss * 0.1, eta_raw_sig

    def read_memory(self, H, M, memory_scale=1.0):
        """Attend from each token of ``H`` over the slots of ``M``.

        Only the last third (the "Q" part) of ``W_qkv_mem``'s output is used.
        When ``M`` has more than 1024 slots, attention is restricted to each
        token's top-64 slots. Returns ``(B, T, memory_dim)`` scaled by
        ``memory_scale``.
        """
        H = self.input_norm(H)
        qkv_mem = self.W_qkv_mem(H)
        _, _, Q_r = torch.split(qkv_mem, [self.D, self.D, self.D], dim=-1)
        scores = torch.bmm(Q_r, M.transpose(-1, -2))
        if M.shape[1] > 1024:
            top_k = 64
            top_vals, top_idx = torch.topk(scores, top_k, dim=-1)
            mask = torch.full_like(scores, float('-inf'))
            mask.scatter_(-1, top_idx, top_vals)
            scores = mask
        A = F.softmax(scores * 2.0, dim=-1)
        C = torch.bmm(A, M)
        return C * memory_scale


# ===================================================================
# FFN Components
# ===================================================================
class SwiGLUBlock(nn.Module):
    """Dense FFN — weight names: gate.weight, up.weight, down.weight"""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))


class SigmoidRouter(nn.Module):
    """Router with router_weights Parameter — weight name: router.router_weights

    ``forward`` returns ``(sigmoid(logits), logits)``.
    """

    def __init__(self, d_model, num_experts):
        super().__init__()
        self.router_weights = nn.Parameter(torch.zeros(num_experts, d_model))
        nn.init.kaiming_uniform_(self.router_weights, a=math.sqrt(5))

    def forward(self, x):
        logits = F.linear(x, self.router_weights)
        scores = torch.sigmoid(logits)
        return scores, logits


class BigMacMoE(nn.Module):
    """BigMac MoE with DCCA bottleneck — matches checkpoint weight names exactly.

    Weights: w_down_proj, w_up_proj, experts_w12, experts_w3,
    router.router_weights, shared_experts.{i}.{gate,up,down}.weight, max_vio

    Routed experts run in a ``bottle_dim``-wide space. Tokens are projected
    down by ``w_down_proj``, dispatched into fixed-capacity per-expert
    buckets and batched with ``bmm``, then projected back up by
    ``w_up_proj``. Assignments that overflow an expert's capacity
    contribute zero.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.d_model = config.d_model
        self.bigmac_r = getattr(config, 'bigmac_r', 0.25)
        self.bottle_dim = int(self.d_model * self.bigmac_r)
        self.num_shared_experts = getattr(config, 'num_shared_experts', 1)
        self.num_routed_experts = getattr(config, 'num_routed_experts', 64)
        self.top_k = getattr(config, 'top_k', 4)
        default_routed_size = int(getattr(config, 'routed_expert_size', 768) / self.bigmac_r)
        self.routed_expert_size = getattr(config, 'bigmac_expert_size', default_routed_size)
        self.shared_expert_size = getattr(config, 'shared_expert_size', config.d_ff)
        self.layer_idx = layer_idx

        self.shared_experts = nn.ModuleList([
            SwiGLUBlock(self.d_model, self.shared_expert_size)
            for _ in range(self.num_shared_experts)
        ])

        # BigMac DCCA Projections
        self.w_down_proj = nn.Linear(self.d_model, self.bottle_dim, bias=False)
        self.w_up_proj = nn.Linear(self.bottle_dim, self.d_model, bias=False)

        # BigMac Experts (fused gate+up W12, down W3)
        self.experts_w12 = nn.Parameter(torch.zeros(self.num_routed_experts, self.bottle_dim, 2 * self.routed_expert_size))
        self.experts_w3 = nn.Parameter(torch.zeros(self.num_routed_experts, self.routed_expert_size, self.bottle_dim))

        self.router = SigmoidRouter(self.d_model, self.num_routed_experts)
        # Per-layer SMEBU state, only used by update_bias(); created lazily there.
        self.expert_bias = None
        self.expert_momentum = None
        self.smebu_kappa = getattr(config, 'smebu_kappa', 2.0)
        self.smebu_lambda = getattr(config, 'smebu_lambda', 2e-3)
        self.smebu_beta = getattr(config, 'smebu_beta', 0.5)
        self.z_loss_weight = getattr(config, 'moe_z_loss_coeff', 1e-4)
        self.aux_loss_weight = getattr(config, 'moe_aux_loss_coeff', 1e-4)
        self.register_buffer("max_vio", torch.tensor(0.0))
        self.route_scale = math.sqrt(self.top_k)
        self.moe_scale = 1.0 / (1.0 + float(self.num_shared_experts > 0))

        # Buffers for padded BMM dispatch: one all-zero "dummy" row that
        # unused capacity slots / dropped assignments point at.
        self.register_buffer("_dummy_token", torch.zeros(1, self.bottle_dim, dtype=torch.bfloat16), persistent=False)
        self.register_buffer("_dummy_out", torch.zeros(1, self.bottle_dim, dtype=torch.bfloat16), persistent=False)
        # Cache of index tensors that depend only on (N, K, device).
        self._cached_N = -1
        self._cached_K = -1
        self._cached_device = None
        self._cached_indices = None

    def _init_weights(self, std=0.011):
        nn.init.normal_(self.w_down_proj.weight, std=std)
        nn.init.normal_(self.w_up_proj.weight, std=std)
        nn.init.normal_(self.experts_w12, std=std)
        nn.init.normal_(self.experts_w3, std=std)
        for expert in self.shared_experts:
            nn.init.normal_(expert.gate.weight, std=std)
            nn.init.normal_(expert.up.weight, std=std)
            nn.init.normal_(expert.down.weight, std=std)

    def forward(self, x, expert_bias=None):
        """Run shared + routed experts on ``x`` of shape ``(B, S, d_model)``.

        Args:
            x: Input activations.
            expert_bias: Optional ``(num_routed_experts,)`` bias added to the
                router scores for expert *selection* only (not gating).

        Returns:
            ``(out, loss)``: ``out`` has ``x``'s shape and dtype; ``loss`` is
            the router z-loss plus (in training) the load-balance aux loss.
        """
        batch_size, seq_len, d_model = x.shape
        hidden_states = x.view(-1, d_model)
        N, D = hidden_states.shape
        K = self.top_k
        num_tokens_total = N * K

        # 1. Routing & Gating
        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            scores, logits = self.router(hidden_states)
            z_loss = torch.mean(logits.nan_to_num() ** 2) * self.z_loss_weight
            bias = expert_bias if expert_bias is not None else torch.zeros(self.num_routed_experts, device=x.device)
            selection_scores = scores + bias
            _, topk_indices = torch.topk(selection_scores, K, dim=-1)
            topk_indices = topk_indices.clamp(0, logits.shape[1] - 1)
            topk_logits = torch.gather(logits, 1, topk_indices)
            gating_scores = F.softmax(topk_logits, dim=-1).to(torch.bfloat16)

        # 2. Aux loss
        if self.training:
            flat_topk_idx = topk_indices.view(-1)
            expert_counts = torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
            fi = expert_counts.float() / num_tokens_total
            Pi = scores.nan_to_num().mean(dim=0)
            aux_loss = torch.sum(fi * Pi) * self.aux_loss_weight
        else:
            aux_loss = torch.tensor(0.0, device=x.device)
            expert_counts = None

        # 3. Shared experts
        shared_out = 0
        if self.num_shared_experts > 0:
            for expert in self.shared_experts:
                shared_out = shared_out + expert(hidden_states)

        # 4. Bottleneck projection
        down_proj_hidden = self.w_down_proj(hidden_states)

        # 5. Routed experts (padded BMM dispatch)
        flat_topk_idx = topk_indices.view(-1).clamp(0, self.num_routed_experts - 1)
        sorted_experts, permutation = torch.sort(flat_topk_idx)

        # The cache is also keyed on device: after the module is moved,
        # tensors cached on the old device must not be reused.
        if self._cached_N == N and self._cached_K == K and self._cached_device == x.device:
            token_indices, global_rel_idx = self._cached_indices
        else:
            token_indices = torch.arange(N, device=x.device).repeat_interleave(K)
            global_rel_idx = torch.arange(num_tokens_total, device=x.device)
            self._cached_N, self._cached_K, self._cached_device = N, K, x.device
            self._cached_indices = (token_indices, global_rel_idx)

        # Per-expert capacity (multiple of 8); overflow assignments are dropped.
        max_load = ((num_tokens_total // self.num_routed_experts) // 8 + 6) * 8
        used_counts = expert_counts if expert_counts is not None else torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
        expert_ptr = torch.cumsum(used_counts, dim=0) - used_counts
        local_idx = global_rel_idx - expert_ptr.index_select(0, sorted_experts)
        capacity_mask = local_idx < max_load
        valid_slots = sorted_experts[capacity_mask] * max_load + local_idx[capacity_mask]
        num_slots = self.num_routed_experts * max_load

        # Slot -> source row; unfilled slots read the zero dummy row (index N).
        hidden_with_dummy = torch.cat([down_proj_hidden, self._dummy_token], dim=0)
        reverse_map = torch.full((num_slots,), N, device=x.device, dtype=torch.long)
        reverse_map.scatter_(0, valid_slots.long(), token_indices[permutation][capacity_mask])
        padding = hidden_with_dummy.index_select(0, reverse_map).view(self.num_routed_experts, max_load, self.bottle_dim)

        h12 = torch.bmm(padding, self.experts_w12)
        h1, h2 = h12.chunk(2, dim=-1)
        padded_out = torch.bmm(F.silu(h1) * h2, self.experts_w3)

        # Assignment -> slot; dropped assignments read the zero dummy output.
        padded_out_flat = padded_out.view(-1, self.bottle_dim)
        padded_out_with_dummy = torch.cat([padded_out_flat, self._dummy_out], dim=0)
        gather_map = torch.full((num_tokens_total,), num_slots, device=x.device, dtype=torch.long)
        gather_map.scatter_(0, permutation[capacity_mask], valid_slots)
        gathered_out = padded_out_with_dummy.index_select(0, gather_map).view(N, K, self.bottle_dim)

        routed_out_bottle = torch.bmm(gating_scores.to(gathered_out.dtype).unsqueeze(1), gathered_out).squeeze(1)
        routed_out = self.w_up_proj(routed_out_bottle)

        if self.training:
            # Positive => under-loaded expert (SMEBU raises its bias).
            mean_load = num_tokens_total / self.num_routed_experts
            self._pending_violation = (mean_load - used_counts.float()) / (mean_load + 1e-6)

        # NOTE(review): routed output is scaled by sqrt(top_k) in training but
        # by 1.0 in eval (self.route_scale is not used), so eval outputs
        # differ from what training saw. Confirm this skew is intended.
        route_scale = math.sqrt(self.top_k) if self.training else 1.0
        out = (shared_out + routed_out * route_scale) * self.moe_scale
        out = out.view(batch_size, seq_len, d_model).to(x.dtype)
        return out, z_loss + aux_loss

    def update_bias(self, counts, num_tokens):
        """Per-layer SMEBU bias update from expert assignment ``counts``.

        ``expert_bias``/``expert_momentum`` start as ``None``. They are now
        created as zeros on first use; previously this method always failed
        with an AttributeError on ``None``.
        """
        if self.expert_bias is None:
            self.expert_bias = torch.zeros(self.num_routed_experts, device=counts.device)
        if self.expert_momentum is None:
            self.expert_momentum = torch.zeros(self.num_routed_experts, device=counts.device)
        expert_counts = counts.float()
        mean_load = num_tokens * self.top_k / self.num_routed_experts
        violation = (mean_load - expert_counts) / (mean_load + 1e-6)
        clamped_update = torch.tanh(self.smebu_kappa * violation)
        delta_bi = self.smebu_lambda * clamped_update
        delta_bi = delta_bi - delta_bi.mean()
        self.expert_momentum.data = self.smebu_beta * self.expert_momentum.data + (1 - self.smebu_beta) * delta_bi
        self.expert_bias.data = (self.expert_bias.data + self.expert_momentum.data).nan_to_num_().clamp(-10.0, 10.0)
        self.expert_bias.data -= self.expert_bias.data.mean()
        current_max_vio = -violation.min()
        self.max_vio.copy_(0.99 * self.max_vio + 0.01 * current_max_vio)


class GroupedMoE(nn.Module):
    """Grouped MoE fallback — for non-BigMac configs.

    Routed experts operate at full ``d_model`` width. Tokens are grouped per
    expert into a zero-padded ``(experts, max_tokens, d_model)`` batch.
    Unlike BigMacMoE, no capacity limit is applied.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.d_model = config.d_model
        self.num_shared_experts = getattr(config, 'num_shared_experts', 1)
        self.num_routed_experts = getattr(config, 'num_routed_experts', 64)
        self.top_k = getattr(config, 'top_k', 6)
        self.shared_expert_size = getattr(config, 'shared_expert_size', config.d_ff)
        self.routed_expert_size = getattr(config, 'routed_expert_size', 1408)
        self.layer_idx = layer_idx

        self.shared_experts = nn.ModuleList([
            SwiGLUBlock(self.d_model, self.shared_expert_size)
            for _ in range(self.num_shared_experts)
        ])

        self.experts_w12 = nn.Parameter(torch.zeros(self.num_routed_experts, self.d_model, 2 * self.routed_expert_size))
        self.experts_w3 = nn.Parameter(torch.zeros(self.num_routed_experts, self.routed_expert_size, self.d_model))
        # Sized from the same attributes as the expert tensors above.
        self.router = nn.Linear(self.d_model, self.num_routed_experts, bias=False)
        with torch.no_grad():
            nn.init.normal_(self.router.weight, std=0.01)

        self.z_loss_weight = getattr(config, 'moe_z_loss_coeff', 1e-6)
        self.aux_loss_weight = getattr(config, 'moe_aux_loss_coeff', 1e-4)
        self.smebu_kappa = getattr(config, 'smebu_kappa', 2.0)
        self.smebu_lambda = getattr(config, 'smebu_lambda', 5e-4)
        self.smebu_beta = getattr(config, 'smebu_beta', 0.5)
        self.register_buffer("max_vio", torch.tensor(0.0))
        self.moe_scale = 1.0 / (1.0 + float(self.num_shared_experts > 0))

    def _init_weights(self, std=0.011):
        nn.init.normal_(self.experts_w12, std=std)
        nn.init.normal_(self.experts_w3, std=std)
        for expert in self.shared_experts:
            nn.init.normal_(expert.gate.weight, std=std)
            nn.init.normal_(expert.up.weight, std=std)
            nn.init.normal_(expert.down.weight, std=std)

    def forward(self, x, expert_bias=None):
        """Run shared + routed experts on ``x`` of shape ``(B, S, d_model)``.

        Returns ``(out, loss)`` with ``out`` in ``x``'s shape and dtype.
        """
        batch_size, seq_len, d_model = x.shape
        hidden_states = x.view(-1, d_model)
        N, D = hidden_states.shape
        K = self.top_k

        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            logits = self.router(hidden_states)
            scores = torch.sigmoid(logits)
            z_loss = torch.mean(logits.nan_to_num() ** 2) * self.z_loss_weight
            bias = expert_bias if expert_bias is not None else torch.zeros(self.num_routed_experts, device=x.device)
            selection_scores = scores + bias
            _, topk_indices = torch.topk(selection_scores, K, dim=-1)
            topk_indices = topk_indices.clamp(0, logits.shape[1] - 1)
            topk_logits = torch.gather(logits, 1, topk_indices)
            gating_scores = F.softmax(topk_logits, dim=-1).to(torch.bfloat16)

        if self.training:
            flat_topk_idx = topk_indices.view(-1)
            expert_counts = torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
            fi = expert_counts.float() / (N * K)
            Pi = scores.nan_to_num().mean(dim=0)
            aux_loss = torch.sum(fi * Pi) * self.aux_loss_weight
            # Same convention as BigMacMoE: positive => under-loaded, so the
            # SMEBU update raises that expert's selection bias. (The previous
            # `fi - 1/E` had the opposite sign and amplified imbalance.)
            mean_load = (N * K) / self.num_routed_experts
            self._pending_violation = (mean_load - expert_counts.float()) / (mean_load + 1e-6)
        else:
            aux_loss = torch.tensor(0.0, device=x.device)
            expert_counts = None
            self._pending_violation = torch.zeros(self.num_routed_experts, device=x.device)

        shared_out = 0
        if self.num_shared_experts > 0:
            for expert in self.shared_experts:
                shared_out = shared_out + expert(hidden_states)

        # Padded BMM dispatch
        num_experts = self.num_routed_experts
        flat_topk_idx = topk_indices.view(-1)
        tokens_per_expert = torch.bincount(flat_topk_idx, minlength=num_experts)
        max_tokens = tokens_per_expert.max().item()

        if max_tokens == 0:
            # Only reachable when there are no routed assignments (N * K == 0).
            out = shared_out * self.moe_scale
            return out.view(batch_size, seq_len, d_model).to(x.dtype), aux_loss

        sorted_indices = torch.argsort(flat_topk_idx)
        token_indices = torch.arange(N, device=x.device).repeat_interleave(K)[sorted_indices]
        grouped_x = hidden_states[token_indices]

        padded_x = torch.zeros(num_experts, max_tokens, D, device=x.device, dtype=x.dtype)
        expert_starts = torch.cat([torch.tensor([0], device=x.device), tokens_per_expert[:-1].cumsum(0)])
        intra_offsets = torch.arange(N * K, device=x.device) - expert_starts.repeat_interleave(tokens_per_expert)
        expert_idx = flat_topk_idx[sorted_indices]
        padded_x_flat = padded_x.view(-1, D)
        flat_dest_indices = expert_idx * max_tokens + intra_offsets
        padded_x_flat.index_put_((flat_dest_indices,), grouped_x)

        h12 = torch.bmm(padded_x, self.experts_w12)
        h1, h2 = h12.chunk(2, dim=-1)
        h = F.silu(h1) * h2
        expert_out_padded = torch.bmm(h, self.experts_w3)

        full_expert_out = expert_out_padded.view(-1, D)[flat_dest_indices]
        gating_flat = gating_scores.view(-1)
        sorted_gating = gating_flat[sorted_indices].unsqueeze(1)
        # Cast gates to the expert-output dtype (as BigMacMoE does). This
        # avoids bf16 x fp16 promoting to fp32.
        weighted_out = full_expert_out * sorted_gating.to(full_expert_out.dtype)

        routed_out = torch.zeros_like(hidden_states)
        # index_add_ requires source and destination dtypes to match; under
        # autocast the expert outputs can differ from hidden_states' dtype.
        routed_out.index_add_(0, token_indices, weighted_out.to(routed_out.dtype))

        # NOTE(review): sqrt(top_k) scaling is applied only in training; eval
        # uses 1.0. Confirm this train/eval skew is intended.
        route_scale = math.sqrt(self.top_k) if self.training else 1.0
        out = (shared_out + routed_out * route_scale) * self.moe_scale
        out = out.view(batch_size, seq_len, d_model).to(x.dtype)
        return out, z_loss + aux_loss


# ===================================================================
# HybridBlock — one transformer layer
# Weight names:
#   ln1.weight, ln1_out.weight, ln2.weight, ln2_out.weight,
#   attn.*, memory.*, W_alpha.*, C_to_hidden.*,
#   ffn.*, injection_gate
# ===================================================================
class HybridBlock(nn.Module):
    """One Quasar layer: optional anchor injection -> attention -> FFN/MoE.

    ``"quasar"`` layers use ``QuasarAttention``. ``"gla"`` layers use
    ``GatedLinearAttention`` plus a latent memory read/write. Both
    sub-blocks use sandwich norms (``ln*`` before, ``ln*_out`` after) and a
    ``residual_scale``-weighted residual.
    """

    def __init__(self, config: QuasarConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.d_model
        self.layer_idx = layer_idx
        self.n_layers = config.n_layers
        self.config = config
        self.gradient_checkpointing = False

        # Looped Transformer injection gate (checkpoint always has it)
        self.use_looped_injection = config.use_looped_injection
        self.injection_gate = nn.Parameter(torch.tensor([-2.197]))

        # Determine layer type (use hybrid_layer_types for quasar/gla distinction)
        self.layer_type = config.hybrid_layer_types[layer_idx]

        # Attention layer
        if self.layer_type == "quasar":
            self.attn = QuasarAttention(
                mode=config.attn_mode,
                hidden_size=config.d_model,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.n_heads,
                num_v_heads=config.num_v_heads,
                use_short_conv=config.use_short_conv,
                allow_neg_eigval=config.allow_neg_eigval,
                conv_size=config.conv_size,
                norm_eps=config.rms_norm_eps,
                layer_idx=layer_idx,
            )
        elif self.layer_type == "gla":
            self.attn = GatedLinearAttention(
                mode=config.gla_mode,
                hidden_size=config.d_model,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.n_heads,
                layer_idx=layer_idx,
            )
            # Latent Memory Module
            self.memory = LatentMemoryModule(
                hidden_size=config.d_model,
                memory_slots=config.memory_slots,
                memory_dim=config.memory_dim,
                use_triton=False,
            )
            nn.init.constant_(self.memory.W_eta.bias, -1.0)
            self.W_alpha = nn.Linear(config.d_model, 1)
            self.C_to_hidden = nn.Linear(config.memory_dim, config.d_model, bias=False)
        else:
            raise ValueError(f"Unknown layer_type: {self.layer_type}")

        # Sandwich norms
        self.ln1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ln1_out = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ln2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ln2_out = RMSNorm(config.d_model, eps=config.rms_norm_eps)

        # FFN vs MoE
        dense_layers = config.dense_input_layers
        num_routed = config.num_routed_experts
        if layer_idx < dense_layers or num_routed == 0:
            self.is_moe = False
            self.ffn = SwiGLUBlock(config.d_model, config.d_ff)
        else:
            self.is_moe = True
            if config.moe_type == "bigmac":
                self.ffn = BigMacMoE(config, layer_idx=layer_idx)
            elif config.moe_type == "deepseek":
                # DeepSeekMoE could be added here if needed
                self.ffn = BigMacMoE(config, layer_idx=layer_idx)
            else:
                self.ffn = GroupedMoE(config, layer_idx=layer_idx)

        self.dropout = nn.Dropout(config.dropout)
        self.scale_factor = 1.0 / math.sqrt(2 * self.n_layers)
        self.residual_scale = config.residual_scale

        self._init_weights()

    def _init_weights(self):
        """Layer-local initialisation (std = 0.5 / sqrt(d_model), depth-scaled
        for output projections).

        NOTE(review): QuasarPreTrainedModel._init_weights re-initialises
        every nn.Linear when HF's post_init() runs. For a freshly constructed
        model that presumably overrides the Linear values set here (e.g. the
        W_alpha bias of -10). Confirm for from-scratch training.
        """
        trinity_std = 0.5 / math.sqrt(self.hidden_size)
        if self.layer_type == "gla":
            nn.init.constant_(self.W_alpha.bias, -10.0)
            nn.init.zeros_(self.W_alpha.weight)
            nn.init.normal_(self.C_to_hidden.weight, std=trinity_std)

        def apply_deep_init(m):
            # Depth-scaled init for a SwiGLU-style module's `down` projection.
            if hasattr(m, 'down') and isinstance(m.down, nn.Linear):
                nn.init.normal_(m.down.weight, mean=0.0, std=trinity_std * self.scale_factor)

        if not self.is_moe:
            nn.init.normal_(self.ffn.gate.weight, mean=0.0, std=trinity_std)
            nn.init.normal_(self.ffn.up.weight, mean=0.0, std=trinity_std)
            apply_deep_init(self.ffn)
        else:
            self.ffn._init_weights(std=trinity_std)
            for expert in self.ffn.shared_experts:
                apply_deep_init(expert)
            nn.init.normal_(self.ffn.experts_w3, mean=0.0, std=trinity_std)

        nn.init.constant_(self.ln1_out.weight, 1.0)
        nn.init.constant_(self.ln2_out.weight, 1.0)

        if hasattr(self.attn, 'o_proj') and isinstance(self.attn.o_proj, nn.Linear):
            nn.init.normal_(self.attn.o_proj.weight, mean=0.0, std=trinity_std * self.scale_factor)

        for proj_name in ['q_proj', 'k_proj', 'v_proj', 'g_proj']:
            if hasattr(self.attn, proj_name):
                m = getattr(self.attn, proj_name)
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, mean=0.0, std=trinity_std)
                elif isinstance(m, nn.Sequential):
                    for subm in m:
                        if isinstance(subm, nn.Linear):
                            nn.init.normal_(subm.weight, mean=0.0, std=trinity_std)

    def forward(self, x, cos=None, sin=None, expert_bias=None, memory_state=None, lambda_reg=0.01, **kwargs):
        """Apply the layer; returns ``(x, aux_loss, new_memory_state, mem_loss)``.

        If looped injection is enabled and ``kwargs['P']`` is given, the
        anchor ``P`` is added through a sigmoid gate before the block runs.
        """
        if self.use_looped_injection:
            P = kwargs.get('P')
            if P is not None:
                x = x + (torch.sigmoid(self.injection_gate) * P)

        if self.gradient_checkpointing and self.training:
            return torch.utils.checkpoint.checkpoint(
                self._forward, x, cos, sin, expert_bias, memory_state, lambda_reg,
                use_reentrant=False, **kwargs,
            )
        return self._forward(x, cos, sin, expert_bias, memory_state, lambda_reg, **kwargs)

    def _forward(self, x, cos=None, sin=None, expert_bias=None, memory_state=None, lambda_reg=0.01, **kwargs):
        # 1. Attention block
        residual = x
        x = self.ln1(x)

        # Build attention kwargs
        attn_kwargs = {}
        if cos is not None and sin is not None:
            attn_kwargs['cos'] = cos
            attn_kwargs['sin'] = sin
        # Pass past_key_values for FLA cache support
        if 'past_key_values' in kwargs and kwargs['past_key_values'] is not None:
            attn_kwargs['past_key_values'] = kwargs['past_key_values']
        if 'use_cache' in kwargs:
            attn_kwargs['use_cache'] = kwargs['use_cache']

        attn_out = self.attn(x, **attn_kwargs)
        if isinstance(attn_out, tuple):
            attn_out = attn_out[0]

        new_memory_state = None
        mem_loss = torch.tensor(0.0, device=x.device)

        # GLA layers: read/write latent memory
        if self.layer_type == "gla" and memory_state is not None:
            new_memory_state, total_mem_loss, _ = self.memory.write_memory(x, memory_state)
            C = self.memory.read_memory(x, new_memory_state)
            alpha = torch.sigmoid(self.W_alpha(x))
            C_proj = self.C_to_hidden(C)
            attn_out = attn_out + (alpha * C_proj)
            mem_loss = total_mem_loss

        # Sandwich norm + residual scaling
        x = residual + self.residual_scale * self.dropout(self.ln1_out(attn_out))

        # 2. FFN / MoE block
        residual = x
        x = self.ln2(x)
        if self.is_moe:
            block_out, aux_loss = self.ffn(x, expert_bias=expert_bias)
        else:
            block_out = self.ffn(x)
            aux_loss = torch.tensor(0.0, device=x.device)
        x = residual + self.residual_scale * self.dropout(self.ln2_out(block_out))

        return x, aux_loss, new_memory_state, mem_loss


# ===================================================================
# Output dataclasses
# ===================================================================
@dataclass
class QuasarModelOutputWithPast(BaseModelOutputWithPast):
    memory_states: dict | None = None
    memory_loss: torch.Tensor | None = None


@dataclass
class QuasarCausalLMOutputWithPast(CausalLMOutputWithPast):
    memory_states: dict | None = None
    memory_loss: torch.Tensor | None = None
    aux_loss: torch.Tensor | None = None


# ===================================================================
# PreTrainedModel base
# ===================================================================
class QuasarPreTrainedModel(PreTrainedModel):
    config_class = QuasarConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HybridBlock"]
    _supports_cache_class = True

    def _init_weights(self, module):
        """HF init hook: normal(std) Linear weights with zero bias; normal(0.02) embeddings."""
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


# ===================================================================
# QuasarModel — base transformer (no LM head)
# Weight prefix: model.* (embed_tokens, embed_norm, layers, norm, rotary_emb, all_moe_*)
# ===================================================================
class QuasarModel(QuasarPreTrainedModel):
    """Embedding -> ``num_loops`` passes over the HybridBlock stack -> RMSNorm.

    ``forward`` returns a tuple ``(QuasarModelOutputWithPast, total_aux_loss)``
    rather than just the output object.
    """

    config: QuasarConfig

    def __init__(self, config: QuasarConfig):
        super().__init__(config)
        self.config = config
        d_model = config.d_model
        n_heads = config.n_heads
        n_layers = config.n_layers
        vocab_size = config.vocab_size
        max_seq_len = config.max_seq_len

        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        self.embed_norm = RMSNorm(d_model, eps=config.rms_norm_eps)

        self.layers = nn.ModuleList([
            HybridBlock(config, i) for i in range(n_layers)
        ])

        self.norm = RMSNorm(d_model, eps=config.rms_norm_eps)

        self.rotary_emb = RotaryEmbedding(
            d_model // n_heads,
            max_seq_len,
            base=config.rope_theta,
        )

        # SMEBU global buffers — sized [num_moe, num_experts] to match checkpoint
        self.moe_layer_ffns = [l.ffn for l in self.layers if getattr(l, 'is_moe', False)]
        self.num_moe = len(self.moe_layer_ffns)
        num_experts = config.num_routed_experts
        if self.num_moe > 0 and num_experts > 0:
            self.register_buffer("all_moe_bias", torch.zeros(self.num_moe, num_experts))
            self.register_buffer("all_moe_momentum", torch.zeros(self.num_moe, num_experts))
            self.register_buffer("all_moe_max_vio", torch.zeros(self.num_moe))

        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def init_memory(self, batch_size, device, dtype=torch.float32):
        """Zero memory ``(batch_size, K, D)`` for every GLA layer, keyed by layer index."""
        memory_states = {}
        for i, layer in enumerate(self.layers):
            if layer.layer_type == "gla":
                m = torch.zeros(batch_size, layer.memory.K, layer.memory.D, device=device, dtype=dtype)
                memory_states[i] = m
        return memory_states

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        memory_states: dict | None = None,
        lambda_reg: float = 0.01,
        **kwargs,
    ):
        """Run the stack; returns ``(QuasarModelOutputWithPast, total_aux_loss)``.

        NOTE(review): ``attention_mask`` is accepted but never forwarded to
        the layers, so padded batches are not masked.

        NOTE(review): ``past_key_values`` is only passed through. No cache is
        created here when ``use_cache`` is set and it is ``None``, even
        though ``FlaCache`` is imported at module level. Confirm the
        generation path supplies one.

        NOTE(review): with ``num_loops > 1`` every loop calls the same layers
        (same ``layer_idx``) with the same cache object.
        """
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Embed norm for stability
        hidden_states = self.embed_norm(inputs_embeds)
        batch_size, seq_len, _ = hidden_states.shape

        # Position ids
        if position_ids is None:
            past_seen_tokens = 0
            if past_key_values is not None:
                try:
                    past_seen_tokens = past_key_values.get_seq_length()
                except Exception:
                    # Best effort: unknown cache objects fall back to
                    # position 0, but say so instead of failing silently.
                    logger.warning_once(
                        "Could not read sequence length from past_key_values; "
                        "positions will start at 0."
                    )
                    past_seen_tokens = 0
            position_ids = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=hidden_states.device)

        # RoPE
        max_pos = int(position_ids.max().item() + 1) if position_ids.numel() > 0 else seq_len
        cos_full, sin_full = self.rotary_emb(hidden_states, seq_len=max_pos)
        if position_ids.dim() == 1:
            cos = cos_full[:, :, position_ids]
            sin = sin_full[:, :, position_ids]
        else:
            # Only the first row is used: all batch rows are assumed to
            # share positions.
            cos = cos_full[:, :, position_ids[0]]
            sin = sin_full[:, :, position_ids[0]]

        # Memory states
        if memory_states is None:
            memory_states = self.init_memory(batch_size, hidden_states.device, hidden_states.dtype)

        all_hidden_states = () if output_hidden_states else None
        aux_losses = []
        mem_losses = []
        new_memory_states = {}

        # Looped transformer anchor
        P = hidden_states
        num_loops = self.config.num_loops
        current_memory_states = memory_states

        # Snapshot expert bias for gradient checkpointing consistency
        if self.num_moe > 0:
            bias_snapshot = self.all_moe_bias.detach().clone()
        else:
            bias_snapshot = None

        for loop_idx in range(num_loops):
            moe_idx = 0
            iteration_new_memory_states = {}
            for layer in self.layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                bias = bias_snapshot[moe_idx] if (getattr(layer, 'is_moe', False) and bias_snapshot is not None) else None
                layer_out = layer(
                    hidden_states,
                    cos=cos,
                    sin=sin,
                    expert_bias=bias,
                    memory_state=current_memory_states.get(layer.layer_idx),
                    lambda_reg=lambda_reg,
                    P=P if self.config.use_looped_injection else None,
                    past_key_values=past_key_values,
                    use_cache=use_cache,
                    **kwargs,
                )
                hidden_states, aux_loss, new_m, m_loss = layer_out
                if new_m is not None:
                    iteration_new_memory_states[layer.layer_idx] = new_m
                    mem_losses.append(m_loss)
                if bias is not None:
                    moe_idx += 1
                aux_losses.append(aux_loss)

            current_memory_states = iteration_new_memory_states
            new_memory_states = iteration_new_memory_states

        # SMEBU bias update (no_grad to avoid checkpointing issues)
        if self.training and self.num_moe > 0:
            with torch.no_grad():
                self._update_all_moe_biases()

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Per-layer losses live on each layer's device; with layers split
        # across devices (device_map dispatch) torch.stack needs them
        # gathered onto one device first.
        out_device = hidden_states.device
        total_aux = (
            torch.stack([loss.to(out_device) for loss in aux_losses]).sum()
            if aux_losses else torch.tensor(0.0, device=out_device)
        )
        total_mem = (
            torch.stack([loss.to(out_device) for loss in mem_losses]).sum()
            if mem_losses else torch.tensor(0.0, device=out_device)
        )

        return QuasarModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            memory_states=new_memory_states,
            memory_loss=total_mem,
        ), total_aux

    def _update_all_moe_biases(self):
        """Batched SMEBU update of ``all_moe_bias`` from each MoE layer's
        ``_pending_violation`` (positive => under-loaded expert).

        Hyper-parameters come from the first MoE layer. The per-layer
        ``_pending_violation`` attributes are consumed (deleted).
        """
        buf_device = self.all_moe_bias.device
        # Layers may live on different devices; align with the shared buffers.
        violations = torch.stack([m._pending_violation.to(buf_device) for m in self.moe_layer_ffns])
        m0 = self.moe_layer_ffns[0]
        kappa, lamb, beta = m0.smebu_kappa, m0.smebu_lambda, m0.smebu_beta

        clamped_update = torch.tanh(kappa * violations)
        delta_bi = lamb * clamped_update
        delta_bi = delta_bi - delta_bi.mean(dim=-1, keepdim=True)

        self.all_moe_momentum.mul_(beta).add_(delta_bi, alpha=1 - beta)
        self.all_moe_bias.add_(self.all_moe_momentum).nan_to_num_().clamp_(-10.0, 10.0)
        self.all_moe_bias.sub_(self.all_moe_bias.mean(dim=-1, keepdim=True))

        current_max_vios = -violations.min(dim=-1).values
        self.all_moe_max_vio.mul_(0.99).add_(current_max_vios, alpha=0.01)
        for i, moe in enumerate(self.moe_layer_ffns):
            # copy_ handles a cross-device source.
            moe.max_vio.copy_(self.all_moe_max_vio[i])
            del moe._pending_violation


# ===================================================================
# QuasarForCausalLM — with LM head + generation support
# Weight prefix: lm_head.* (top-level), model.* (from
# ===================================================================
class QuasarForCausalLM(QuasarPreTrainedModel, FLAGenerationMixin):
    """Quasar backbone plus an untied language-modelling head.

    Weights live under ``model.*`` (the :class:`QuasarModel` backbone) and
    ``lm_head.*`` (a bias-free ``d_model -> vocab_size`` projection). Besides
    the regular cache, the backbone consumes and returns a ``memory_states``
    dict; this class threads it from one ``generate`` step to the next.
    """

    config: QuasarConfig
    _tied_weights_keys = {}

    def __init__(self, config: QuasarConfig):
        super().__init__(config)
        self.model = QuasarModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Intentionally a no-op: input embeddings and ``lm_head`` stay untied.

        Tying them crashes FSDP, so the two weight tensors are kept separate.
        """

    # ------------------------------------------------------------------
    # Forward / loss
    # ------------------------------------------------------------------
    def _chunked_lm_loss(self, hidden_states, labels, chunk_size=256):
        """Mean next-token cross-entropy over positions whose label is not -100.

        Logits are produced ``chunk_size`` supervised tokens at a time instead
        of for every position at once.

        Args:
            hidden_states: final backbone hidden states; the second-to-last
                dim is the sequence axis (shifted by one below).
            labels: token ids aligned with ``hidden_states``; -100 = ignored.
            chunk_size: number of supervised tokens per ``lm_head`` call.

        Returns:
            A scalar tensor, or ``None`` when no position is supervised.
        """
        labels = labels.to(hidden_states.device)
        # Position t predicts the label at position t + 1.
        flat_hidden = hidden_states[..., :-1, :].reshape(-1, self.config.d_model)
        flat_labels = labels[..., 1:].reshape(-1)
        mask = flat_labels != -100
        if not mask.any():
            return None

        active_hidden = flat_hidden[mask]
        active_labels = flat_labels[mask]
        total_tokens = active_labels.numel()
        total_loss = 0.0
        for start in range(0, total_tokens, chunk_size):
            end = min(start + chunk_size, total_tokens)
            chunk_logits = self.lm_head(active_hidden[start:end])
            # Upcast so the softmax runs in float32 even under bf16/fp16.
            total_loss = total_loss + F.cross_entropy(
                chunk_logits.float(), active_labels[start:end], reduction="sum"
            )
        return total_loss / total_tokens

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        memory_states: dict | None = None,
        lambda_reg: float = 0.01,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone, then either the LM loss or the LM head.

        All inputs except ``labels`` and ``return_dict`` are forwarded to
        :class:`QuasarModel` unchanged.

        Returns:
            ``QuasarCausalLMOutputWithPast``, or a tuple when ``return_dict``
            resolves to false. With ``labels``, ``logits`` is ``None`` and
            ``loss`` = mean next-token CE + MoE aux loss + memory loss.
            Without ``labels``, ``loss`` is ``None`` and ``logits`` holds the
            full ``lm_head`` output.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs, total_aux = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            memory_states=memory_states,
            lambda_reg=lambda_reg,
            **kwargs,
        )
        hidden_states = model_outputs.last_hidden_state

        loss = None
        logits = None
        if labels is not None:
            lm_loss = self._chunked_lm_loss(hidden_states, labels)
            if lm_loss is None:
                # Nothing supervised in this batch. Use a zero that stays attached
                # to the graph (a fresh leaf tensor leaves every parameter without
                # a gradient, which can trip DDP's unused-parameter check), and
                # keep the label-independent aux / memory terms.
                lm_loss = hidden_states.sum() * 0.0
            loss = lm_loss + total_aux + model_outputs.memory_loss
        else:
            # Full logits are only materialised when no loss is requested.
            logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return QuasarCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            memory_states=model_outputs.memory_states,
            memory_loss=model_outputs.memory_loss,
            aux_loss=total_aux,
        )

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------
    @staticmethod
    def _has_cached_tokens(past_key_values) -> bool:
        """Return True if ``past_key_values`` already holds processed tokens.

        Recent ``transformers`` versions may pass an *empty* cache object on
        the first ``generate`` step. Treating that as "has past" would drop
        all but the last prompt token.
        """
        if past_key_values is None:
            return False
        try:
            return bool(past_key_values.get_seq_length() > 0)
        except Exception:
            # Cache types without a usable get_seq_length: try len() below.
            pass
        try:
            return len(past_key_values) > 0
        except TypeError:
            # Unknown cache type: keep the historical "non-None means populated".
            return True

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        memory_states=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Build the model kwargs for one decoding step of ``generate``.

        Once the cache holds tokens, only the newest token id is fed.
        ``inputs_embeds`` are used only while nothing is cached yet. When
        ``memory_states`` is not given, it falls back to a ``memory_states``
        attribute on the cache object (``None`` if absent).
        """
        has_past = self._has_cached_tokens(past_key_values)
        if has_past and input_ids is not None:
            input_ids = input_ids[:, -1:]

        if inputs_embeds is not None and not has_past:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # Slicing the last column yields a strided view; pass contiguous ids.
            model_inputs = {
                "input_ids": input_ids.contiguous() if input_ids is not None else None
            }

        if memory_states is None and past_key_values is not None:
            memory_states = getattr(past_key_values, "memory_states", None)

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "memory_states": memory_states,
        })
        return model_inputs

    @staticmethod
    def _carry_memory_states(outputs, model_kwargs):
        """Copy ``outputs.memory_states`` (when present) into ``model_kwargs``."""
        memory_states = getattr(outputs, "memory_states", None)
        if memory_states is not None:
            model_kwargs["memory_states"] = memory_states
        return model_kwargs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, *args, **kwargs):
        """HF ``generate`` hook: run the standard update, then carry ``memory_states``.

        ``GenerationMixin`` invokes this private hook between decoding steps.
        Without it, the ``memory_states`` returned by one step never reach
        the next. Extra args are passed through because the parent signature
        differs across ``transformers`` versions.
        """
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs, model_kwargs, *args, **kwargs
        )
        return self._carry_memory_states(outputs, model_kwargs)

    def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_seq2seq=False, num_new_tokens=1):
        """Public alias kept for backward compatibility with earlier callers."""
        parent = getattr(super(), "update_model_kwargs_for_generation", None)
        if parent is None:
            # Stock GenerationMixin only exposes the private hook.
            return self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=is_seq2seq,
                num_new_tokens=num_new_tokens,
            )
        model_kwargs = parent(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_seq2seq=is_seq2seq,
            num_new_tokens=num_new_tokens,
        )
        return self._carry_memory_states(outputs, model_kwargs)

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the cache for beam search (delegates to the cache object)."""
        # NOTE(review): memory_states in model_kwargs are not reordered here; if
        # they carry a batch/beam dimension, beam search would mix beams — confirm.
        if past_key_values is None:
            return None
        return past_key_values.reorder_cache(beam_idx)


__all__ = [
    "QuasarConfig",
    "QuasarPreTrainedModel",
    "QuasarModel",
    "QuasarForCausalLM",
    "QuasarModelOutputWithPast",
    "QuasarCausalLMOutputWithPast",
]