| """Quasar hybrid transformer — HuggingFace compatible. |
| |
| """ |
|
|
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from transformers import GenerationMixin |
| from transformers.cache_utils import Cache |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import logging |
|
|
| from .configuration_quasar import QuasarConfig |
|
|
# Module-level logger obtained through transformers' logging utilities
# (respects transformers' global verbosity settings).
logger = logging.get_logger(__name__)
|
|
| |
| from fla.layers.quasar import QuasarAttention |
| from fla.layers.gla import GatedLinearAttention |
| from fla.models.utils import Cache as FlaCache, FLAGenerationMixin |
|
|
|
|
| |
| |
| |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    No mean-centering and no bias. The statistic is computed in float32 and the
    normalised activations are cast back to the input dtype *before* being
    multiplied by ``weight`` (so the result dtype follows normal promotion
    between ``weight`` and the input dtype).

    Args:
        hidden_size: size of the last (normalised) dimension.
        eps: small constant added to the mean square before ``rsqrt``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normed.to(orig_dtype)
|
|
|
|
| |
| |
| |
class RotaryEmbedding(nn.Module):
    """Precomputed rotary-position cos/sin tables that grow on demand.

    The tables are non-persistent buffers ``_cos_cached`` / ``_sin_cached`` of
    shape ``(1, 1, positions, dim)``; the last axis holds the frequency vector
    twice (``cat(freqs, freqs)``). ``forward`` returns the first ``seq_len``
    positions cast to ``x.dtype``; if more positions are requested than are
    cached, the tables are rebuilt with 1024 extra positions of headroom.
    """

    def __init__(self, dim, max_position_embeddings=4096, base=100000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        exponents = torch.arange(0, dim, 2).float().to(device) / dim
        inv_freq = 1.0 / (base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        cos_table, sin_table = self._tables(max_position_embeddings + 1, inv_freq, device)
        self.register_buffer("_cos_cached", cos_table, persistent=False)
        self.register_buffer("_sin_cached", sin_table, persistent=False)

    @staticmethod
    def _tables(num_positions, inv_freq, device):
        """Return ``(cos, sin)`` of shape ``(1, 1, num_positions, 2 * len(inv_freq))``."""
        positions = torch.arange(num_positions, device=device, dtype=inv_freq.dtype)
        angles = torch.einsum("i,j->ij", positions, inv_freq)
        doubled = torch.cat((angles, angles), dim=-1)
        return doubled.cos()[None, None], doubled.sin()[None, None]

    def forward(self, x, seq_len=None):
        needs_growth = seq_len is not None and seq_len > self._cos_cached.shape[2]
        if needs_growth:
            cos_table, sin_table = self._tables(seq_len + 1024, self.inv_freq, x.device)
            # Keep whatever dtype the cached tables currently have.
            self.register_buffer("_cos_cached", cos_table.to(self._cos_cached.dtype), persistent=False)
            self.register_buffer("_sin_cached", sin_table.to(self._sin_cached.dtype), persistent=False)
        window = slice(None, seq_len)
        cos_out = self._cos_cached[:, :, window].to(dtype=x.dtype)
        sin_out = self._sin_cached[:, :, window].to(dtype=x.dtype)
        return cos_out, sin_out
|
|
|
|
| |
| |
| |
class LatentMemoryModule(nn.Module):
    """Persistent Latent Parameter Memory — weight names match checkpoint.

    The module stores no memory tensor itself: callers pass a memory ``M`` of
    shape ``(batch, memory_slots, memory_dim)`` into :meth:`write_memory` and
    :meth:`read_memory` (``bmm`` against ``M`` fixes the rank to 3).

    * ``write_memory`` splits the (layer-normed) input into segments of
      ``segment_len`` tokens, summarises each segment with a learned query,
      softly assigns the summaries to slots and blends them into ``M``.
    * ``read_memory`` lets every token attend over the slots.

    ``use_triton`` is accepted for API compatibility only; ``self.use_triton``
    is always set to False.
    """

    def __init__(self, hidden_size, memory_slots=128, memory_dim=128, use_triton=False):
        super().__init__()
        self.K = memory_slots
        self.D = memory_dim

        # Per-token write-strength logit; zero weight + bias -5 starts every
        # token at sigmoid(-5) write strength.
        self.W_eta = nn.Linear(hidden_size, 1, bias=True)
        nn.init.zeros_(self.W_eta.weight)
        nn.init.constant_(self.W_eta.bias, -5.0)

        # NOTE: statement order below is kept as-is — it fixes the RNG draw
        # order and the parameter registration order.
        self.segment_len = 64
        self.summary_query = nn.Parameter(torch.randn(1, 1, memory_dim))
        self.summary_proj = nn.Linear(hidden_size, memory_dim, bias=True)
        self.eta_channels = nn.Parameter(torch.ones(1, 1, memory_dim))
        self.temperature = nn.Parameter(torch.ones(1))
        self.hidden_size = hidden_size
        self.use_triton = False
        self.input_norm = nn.LayerNorm(hidden_size)
        # compress_z is not referenced by this class's methods; it is kept so
        # that checkpoint parameter names line up.
        self.compress_z = nn.Sequential(
            nn.Linear(hidden_size, memory_dim * 2, bias=False),
            nn.SiLU(),
            nn.Linear(memory_dim * 2, memory_dim, bias=False),
        )
        self.W_qkv_mem = nn.Linear(hidden_size, memory_dim * 3, bias=False)
        self.scale = 1.0 / math.sqrt(memory_dim)

    def get_diversity_loss(self, M):
        """Mean squared off-diagonal cosine similarity between memory slots."""
        _, num_slots, _ = M.shape
        unit = F.normalize(M, p=2, dim=-1)
        cosine = torch.bmm(unit, unit.transpose(1, 2))
        off_diagonal = 1 - torch.eye(num_slots, device=M.device).unsqueeze(0)
        return (cosine * off_diagonal).pow(2).mean()

    def write_memory(self, H, M, chunk_idx=0):
        """Blend segment summaries of ``H`` into memory ``M``.

        Args:
            H: token states ``(batch, seq, hidden_size)``.
            M: current memory ``(batch, memory_slots, memory_dim)``.
            chunk_idx: unused.

        Returns:
            ``(M_new, loss, eta)`` where ``loss`` is ``0.01 * update energy +
            0.1 * slot diversity loss`` and ``eta`` is the per-token write
            strength ``sigmoid(W_eta(H))`` of shape ``(batch, seq)``.
        """
        H = self.input_norm(H)
        batch, seq_len, _ = H.shape
        seg = self.segment_len
        projected = self.summary_proj(H)
        eta_logits = self.W_eta(H).squeeze(-1)

        # Right-pad to a whole number of segments; padded eta logits are -10
        # so they never win the per-segment max below.
        remainder = seq_len % seg
        if remainder:
            extra = seg - remainder
            projected_padded = F.pad(projected, (0, 0, 0, extra))
            eta_logits_padded = F.pad(eta_logits, (0, extra), value=-10.0)
        else:
            projected_padded, eta_logits_padded = projected, eta_logits

        n_seg = projected_padded.shape[1] // seg
        segments = projected_padded.view(batch * n_seg, seg, self.D)

        # One learned query attends over each segment -> one summary vector.
        query = self.summary_query.expand(batch * n_seg, -1, -1)
        pooling = F.softmax(torch.bmm(query, segments.transpose(1, 2)) * self.scale, dim=-1)
        seg_summary = torch.bmm(pooling, segments).view(batch, n_seg, self.D)

        token_eta = torch.sigmoid(eta_logits)
        seg_eta = torch.sigmoid(eta_logits_padded.view(batch, n_seg, seg)).max(dim=-1, keepdim=True).values

        # Soft assignment of each segment summary to the memory slots.
        slot_logits = torch.bmm(seg_summary, M.transpose(-1, -2)) * self.scale * torch.exp(self.temperature)
        assignment = F.softmax(slot_logits, dim=-1)
        delta = torch.bmm(assignment.transpose(1, 2), seg_summary * seg_eta)

        gate = seg_eta.mean(dim=1, keepdim=True) * torch.sigmoid(self.eta_channels)
        M_new = (1.0 - gate) * M + delta / n_seg
        update_energy = torch.sum(delta ** 2) / n_seg
        diversity = self.get_diversity_loss(M_new)
        return M_new, update_energy * 0.01 + diversity * 0.1, token_eta

    def read_memory(self, H, M, memory_scale=1.0):
        """Attend from every token of ``H`` over the slots of ``M``.

        Only the third ``memory_dim``-wide chunk of ``W_qkv_mem(H)`` is used as
        the query. With more than 1024 slots, only the top-64 scores per token
        are kept. Returns ``(batch, seq, memory_dim)`` scaled by ``memory_scale``.
        """
        H = self.input_norm(H)
        read_query = self.W_qkv_mem(H)[..., 2 * self.D:]
        logits = torch.bmm(read_query, M.transpose(-1, -2))
        if M.shape[1] > 1024:
            keep_vals, keep_idx = torch.topk(logits, 64, dim=-1)
            logits = torch.full_like(logits, float('-inf')).scatter(-1, keep_idx, keep_vals)
        weights = F.softmax(logits * 2.0, dim=-1)
        return torch.bmm(weights, M) * memory_scale
|
|
|
|
| |
| |
| |
class SwiGLUBlock(nn.Module):
    """Dense SwiGLU FFN: ``down(silu(gate(x)) * up(x))``, all projections bias-free.

    Weight names: ``gate.weight``, ``up.weight``, ``down.weight``.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Creation order (gate, up, down) fixes RNG draws and state_dict order.
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        activated = F.silu(self.gate(x))
        expanded = activated * self.up(x)
        return self.down(expanded)
|
|
|
|
class SigmoidRouter(nn.Module):
    """Linear router returning per-expert sigmoid scores and raw logits.

    Weight name: ``router.router_weights`` with shape ``(num_experts, d_model)``.
    """

    def __init__(self, d_model, num_experts):
        super().__init__()
        self.router_weights = nn.Parameter(torch.zeros(num_experts, d_model))
        # Same initialisation nn.Linear uses for its weight.
        nn.init.kaiming_uniform_(self.router_weights, a=math.sqrt(5))

    def forward(self, x):
        """Return ``(sigmoid(logits), logits)`` where ``logits = x @ router_weights.T``."""
        raw = F.linear(x, self.router_weights)
        return torch.sigmoid(raw), raw
|
|
|
|
class BigMacMoE(nn.Module):
    """BigMac MoE with DCCA bottleneck — matches checkpoint weight names exactly.

    Weights: w_down_proj, w_up_proj, experts_w12, experts_w3,
    router.router_weights, shared_experts.{i}.{gate,up,down}.weight,
    max_vio

    Routed experts operate in a reduced space of width
    ``bottle_dim = int(d_model * bigmac_r)``: tokens are projected down by
    ``w_down_proj``, run through their selected SwiGLU experts
    (``experts_w12`` packs gate|up, ``experts_w3`` is the down projection),
    mixed with softmax gating weights and projected back by ``w_up_proj``.
    Shared experts run on every token at full width.

    Each expert processes at most ``max_load`` assignments per call; any
    assignment past that capacity reads the zero ``_dummy_token`` row and
    gathers from the zero ``_dummy_out`` row, i.e. it contributes nothing.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.d_model = config.d_model
        self.bigmac_r = getattr(config, 'bigmac_r', 0.25)
        self.bottle_dim = int(self.d_model * self.bigmac_r)

        self.num_shared_experts = getattr(config, 'num_shared_experts', 1)
        self.num_routed_experts = getattr(config, 'num_routed_experts', 64)
        self.top_k = getattr(config, 'top_k', 4)

        # Default routed hidden size is routed_expert_size / bigmac_r, unless
        # bigmac_expert_size is set explicitly on the config.
        default_routed_size = int(getattr(config, 'routed_expert_size', 768) / self.bigmac_r)
        self.routed_expert_size = getattr(config, 'bigmac_expert_size', default_routed_size)
        self.shared_expert_size = getattr(config, 'shared_expert_size', config.d_ff)
        self.layer_idx = layer_idx

        self.shared_experts = nn.ModuleList([
            SwiGLUBlock(self.d_model, self.shared_expert_size)
            for _ in range(self.num_shared_experts)
        ])

        # Bottleneck projections shared by all routed experts.
        self.w_down_proj = nn.Linear(self.d_model, self.bottle_dim, bias=False)
        self.w_up_proj = nn.Linear(self.bottle_dim, self.d_model, bias=False)

        # Packed per-expert SwiGLU weights (zeros here; see _init_weights).
        self.experts_w12 = nn.Parameter(torch.zeros(self.num_routed_experts, self.bottle_dim, 2 * self.routed_expert_size))
        self.experts_w3 = nn.Parameter(torch.zeros(self.num_routed_experts, self.routed_expert_size, self.bottle_dim))

        self.router = SigmoidRouter(self.d_model, self.num_routed_experts)

        # Legacy per-module bias state used only by update_bias(); QuasarModel
        # keeps the live biases in its own all_moe_* buffers.
        self.expert_bias = None
        self.expert_momentum = None
        self.smebu_kappa = getattr(config, 'smebu_kappa', 2.0)
        self.smebu_lambda = getattr(config, 'smebu_lambda', 2e-3)
        self.smebu_beta = getattr(config, 'smebu_beta', 0.5)

        self.z_loss_weight = getattr(config, 'moe_z_loss_coeff', 1e-4)
        self.aux_loss_weight = getattr(config, 'moe_aux_loss_coeff', 1e-4)
        self.register_buffer("max_vio", torch.tensor(0.0))
        self.route_scale = math.sqrt(self.top_k)
        # Halve the output when shared experts are present (shared + routed).
        self.moe_scale = 1.0 / (1.0 + float(self.num_shared_experts > 0))

        # Zero rows appended to the input / output tables so that empty or
        # over-capacity slots index a harmless all-zero row.
        self.register_buffer("_dummy_token", torch.zeros(1, self.bottle_dim, dtype=torch.bfloat16), persistent=False)
        self.register_buffer("_dummy_out", torch.zeros(1, self.bottle_dim, dtype=torch.bfloat16), persistent=False)
        # Cache of (token_indices, global_rel_idx) keyed on (N, K).
        # NOTE(review): the key ignores device; if the module is moved to a
        # different device between calls with the same N, the cached tensors
        # stay on the old device — confirm this cannot happen in practice.
        self._cached_N = -1
        self._cached_K = -1
        self._cached_indices = None

    def _init_weights(self, std=0.011):
        """Normal(0, std) init for bottleneck, routed and shared expert weights."""
        nn.init.normal_(self.w_down_proj.weight, std=std)
        nn.init.normal_(self.w_up_proj.weight, std=std)
        nn.init.normal_(self.experts_w12, std=std)
        nn.init.normal_(self.experts_w3, std=std)
        for expert in self.shared_experts:
            nn.init.normal_(expert.gate.weight, std=std)
            nn.init.normal_(expert.up.weight, std=std)
            nn.init.normal_(expert.down.weight, std=std)

    def forward(self, x, expert_bias=None):
        """Apply shared + routed experts.

        Args:
            x: ``(batch, seq, d_model)`` activations.
            expert_bias: optional ``(num_routed_experts,)`` bias added to the
                sigmoid scores for top-k *selection* only; gating weights come
                from a softmax over the raw logits of the selected experts.

        Returns:
            ``(out, z_loss + aux_loss)`` with ``out`` shaped and typed like ``x``.
            ``aux_loss`` is zero outside training.

        Side effect (training only): sets ``self._pending_violation`` to the
        per-expert relative load deficit ``(mean_load - count) / mean_load``
        (positive = under-loaded), consumed by the model-level bias update.
        """
        batch_size, seq_len, d_model = x.shape
        hidden_states = x.view(-1, d_model)
        N, D = hidden_states.shape
        K = self.top_k
        num_tokens_total = N * K

        # Router under a float32 autocast region.
        # NOTE(review): on CPU, PyTorch's autocast only supports float16 /
        # bfloat16 and warns + disables itself for float32 — confirm the
        # intended precision there.
        with torch.autocast(device_type=x.device.type, dtype=torch.float32):
            scores, logits = self.router(hidden_states)
            z_loss = torch.mean(logits.nan_to_num() ** 2) * self.z_loss_weight

        # Bias affects which experts are picked, not how they are weighted.
        bias = expert_bias if expert_bias is not None else torch.zeros(self.num_routed_experts, device=x.device)
        selection_scores = scores + bias
        _, topk_indices = torch.topk(selection_scores, K, dim=-1)
        topk_indices = topk_indices.clamp(0, logits.shape[1] - 1)

        topk_logits = torch.gather(logits, 1, topk_indices)
        gating_scores = F.softmax(topk_logits, dim=-1).to(torch.bfloat16)

        # Load-balancing loss: fraction of assignments per expert times mean
        # router score per expert.
        if self.training:
            flat_topk_idx = topk_indices.view(-1)
            expert_counts = torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
            fi = expert_counts.float() / num_tokens_total
            Pi = scores.nan_to_num().mean(dim=0)
            aux_loss = torch.sum(fi * Pi) * self.aux_loss_weight
        else:
            aux_loss = torch.tensor(0.0, device=x.device)
            expert_counts = None

        # Shared experts see every token at full width.
        shared_out = 0
        if self.num_shared_experts > 0:
            for expert in self.shared_experts:
                shared_out = shared_out + expert(hidden_states)

        down_proj_hidden = self.w_down_proj(hidden_states)

        # Sort the N*K assignments by expert id; `permutation` maps sorted
        # position -> original flat (token-major) position.
        flat_topk_idx = topk_indices.view(-1).clamp(0, self.num_routed_experts - 1)
        sorted_experts, permutation = torch.sort(flat_topk_idx)

        # token_indices[j] = token of flat assignment j; global_rel_idx = 0..N*K-1.
        if self._cached_N == N and self._cached_K == K:
            token_indices, global_rel_idx = self._cached_indices
        else:
            token_indices = torch.arange(N, device=x.device).repeat_interleave(K)
            global_rel_idx = torch.arange(num_tokens_total, device=x.device)
            self._cached_N, self._cached_K = N, K
            self._cached_indices = (token_indices, global_rel_idx)

        # Per-expert capacity: a multiple of 8, with headroom over the mean load.
        max_load = ((num_tokens_total // self.num_routed_experts) // 8 + 6) * 8
        used_counts = expert_counts if expert_counts is not None else torch.bincount(flat_topk_idx, minlength=self.num_routed_experts)
        # Exclusive prefix sum: start offset of each expert's run in sorted order.
        expert_ptr = torch.cumsum(used_counts, dim=0) - used_counts

        # Position of each sorted assignment within its expert's run; those at
        # or beyond max_load are dropped.
        local_idx = global_rel_idx - expert_ptr.index_select(0, sorted_experts)
        capacity_mask = local_idx < max_load
        valid_slots = sorted_experts[capacity_mask] * max_load + local_idx[capacity_mask]
        num_slots = self.num_routed_experts * max_load

        # reverse_map[slot] = token feeding that slot, or N (the zero dummy row).
        hidden_with_dummy = torch.cat([down_proj_hidden, self._dummy_token], dim=0)
        reverse_map = torch.full((num_slots,), N, device=x.device, dtype=torch.long)
        reverse_map.scatter_(0, valid_slots.long(), token_indices[permutation][capacity_mask])

        # (num_routed_experts, max_load, bottle_dim) padded expert inputs.
        padding = hidden_with_dummy.index_select(0, reverse_map).view(self.num_routed_experts, max_load, self.bottle_dim)

        # Batched SwiGLU over all experts at once.
        h12 = torch.bmm(padding, self.experts_w12)
        h1, h2 = h12.chunk(2, dim=-1)
        padded_out = torch.bmm(F.silu(h1) * h2, self.experts_w3)

        padded_out_flat = padded_out.view(-1, self.bottle_dim)
        padded_out_with_dummy = torch.cat([padded_out_flat, self._dummy_out], dim=0)

        # gather_map[original flat assignment] = slot holding its output, or
        # num_slots (the zero dummy row) if the assignment was dropped.
        gather_map = torch.full((num_tokens_total,), num_slots, device=x.device, dtype=torch.long)
        gather_map.scatter_(0, permutation[capacity_mask], valid_slots)

        gathered_out = padded_out_with_dummy.index_select(0, gather_map).view(N, K, self.bottle_dim)

        # Gate-weighted sum over the K selected experts, then project back up.
        routed_out_bottle = torch.bmm(gating_scores.to(gathered_out.dtype).unsqueeze(1), gathered_out).squeeze(1)
        routed_out = self.w_up_proj(routed_out_bottle)

        if self.training:
            mean_load = num_tokens_total / self.num_routed_experts
            self._pending_violation = (mean_load - used_counts.float()) / (mean_load + 1e-6)

        # Routed output is scaled by sqrt(top_k) in training only.
        route_scale = math.sqrt(self.top_k) if self.training else 1.0
        out = (shared_out + routed_out * route_scale) * self.moe_scale
        out = out.view(batch_size, seq_len, d_model).to(x.dtype)

        return out, z_loss + aux_loss

    def update_bias(self, counts, num_tokens):
        """Per-module momentum bias update from expert ``counts``.

        NOTE(review): reads/writes ``self.expert_bias`` and
        ``self.expert_momentum``, which ``__init__`` sets to None; unless they
        are assigned tensors elsewhere, calling this raises AttributeError.
        QuasarModel._update_all_moe_biases performs the equivalent update on
        its central buffers.
        """
        expert_counts = counts.float()
        mean_load = num_tokens * self.top_k / self.num_routed_experts
        violation = (mean_load - expert_counts) / (mean_load + 1e-6)
        clamped_update = torch.tanh(self.smebu_kappa * violation)
        delta_bi = self.smebu_lambda * clamped_update
        delta_bi = delta_bi - delta_bi.mean()
        self.expert_momentum.data = self.smebu_beta * self.expert_momentum.data + (1 - self.smebu_beta) * delta_bi
        self.expert_bias.data = (self.expert_bias.data + self.expert_momentum.data).nan_to_num_().clamp(-10.0, 10.0)
        self.expert_bias.data -= self.expert_bias.data.mean()
        current_max_vio = -violation.min()
        self.max_vio.copy_(0.99 * self.max_vio + 0.01 * current_max_vio)
|
|
|
|
| class GroupedMoE(nn.Module): |
| """Grouped MoE fallback — for non-BigMac configs.""" |
|
|
| def __init__(self, config, layer_idx=None): |
| super().__init__() |
| self.d_model = config.d_model |
| self.num_shared_experts = getattr(config, 'num_shared_experts', 1) |
| self.num_routed_experts = getattr(config, 'num_routed_experts', 64) |
| self.top_k = getattr(config, 'top_k', 6) |
| self.shared_expert_size = getattr(config, 'shared_expert_size', config.d_ff) |
| self.routed_expert_size = getattr(config, 'routed_expert_size', 1408) |
| self.layer_idx = layer_idx |
|
|
| self.shared_experts = nn.ModuleList([ |
| SwiGLUBlock(self.d_model, self.shared_expert_size) |
| for _ in range(self.num_shared_experts) |
| ]) |
| self.experts_w12 = nn.Parameter(torch.zeros(self.num_routed_experts, self.d_model, 2 * self.routed_expert_size)) |
| self.experts_w3 = nn.Parameter(torch.zeros(self.num_routed_experts, self.routed_expert_size, self.d_model)) |
| self.router = nn.Linear(config.d_model, config.num_routed_experts, bias=False) |
| with torch.no_grad(): |
| nn.init.normal_(self.router.weight, std=0.01) |
| self.z_loss_weight = getattr(config, 'moe_z_loss_coeff', 1e-6) |
| self.aux_loss_weight = getattr(config, 'moe_aux_loss_coeff', 1e-4) |
| self.smebu_kappa = getattr(config, 'smebu_kappa', 2.0) |
| self.smebu_lambda = getattr(config, 'smebu_lambda', 5e-4) |
| self.smebu_beta = getattr(config, 'smebu_beta', 0.5) |
| self.register_buffer("max_vio", torch.tensor(0.0)) |
| self.moe_scale = 1.0 / (1.0 + float(self.num_shared_experts > 0)) |
|
|
| def _init_weights(self, std=0.011): |
| nn.init.normal_(self.experts_w12, std=std) |
| nn.init.normal_(self.experts_w3, std=std) |
| for expert in self.shared_experts: |
| nn.init.normal_(expert.gate.weight, std=std) |
| nn.init.normal_(expert.up.weight, std=std) |
| nn.init.normal_(expert.down.weight, std=std) |
|
|
| def forward(self, x, expert_bias=None): |
| batch_size, seq_len, d_model = x.shape |
| hidden_states = x.view(-1, d_model) |
| N, D = hidden_states.shape |
| K = self.top_k |
|
|
| with torch.autocast(device_type=x.device.type, dtype=torch.float32): |
| logits = self.router(hidden_states) |
| scores = torch.sigmoid(logits) |
| z_loss = torch.mean(logits.nan_to_num() ** 2) * self.z_loss_weight |
| bias = expert_bias if expert_bias is not None else torch.zeros(self.num_routed_experts, device=x.device) |
| selection_scores = scores + bias |
| _, topk_indices = torch.topk(selection_scores, K, dim=-1) |
| topk_indices = topk_indices.clamp(0, logits.shape[1] - 1) |
| topk_logits = torch.gather(logits, 1, topk_indices) |
| gating_scores = F.softmax(topk_logits, dim=-1).to(torch.bfloat16) |
|
|
| if self.training: |
| flat_topk_idx = topk_indices.view(-1) |
| expert_counts = torch.bincount(flat_topk_idx, minlength=self.num_routed_experts) |
| fi = expert_counts.float() / (N * K) |
| Pi = scores.nan_to_num().mean(dim=0) |
| aux_loss = torch.sum(fi * Pi) * self.aux_loss_weight |
| self._pending_violation = fi.detach() - (1.0 / self.num_routed_experts) |
| else: |
| aux_loss = torch.tensor(0.0, device=x.device) |
| expert_counts = None |
| self._pending_violation = torch.zeros(self.num_routed_experts, device=x.device) |
|
|
| shared_out = 0 |
| if self.num_shared_experts > 0: |
| for expert in self.shared_experts: |
| shared_out = shared_out + expert(hidden_states) |
|
|
| |
| num_experts = self.num_routed_experts |
| flat_topk_idx = topk_indices.view(-1) |
| tokens_per_expert = torch.bincount(flat_topk_idx, minlength=num_experts) |
| max_tokens = tokens_per_expert.max().item() |
|
|
| if max_tokens == 0: |
| out = shared_out * self.moe_scale |
| return out.view(batch_size, seq_len, d_model).to(x.dtype), aux_loss |
|
|
| sorted_indices = torch.argsort(flat_topk_idx) |
| token_indices = torch.arange(N, device=x.device).repeat_interleave(K)[sorted_indices] |
| grouped_x = hidden_states[token_indices] |
| padded_x = torch.zeros(num_experts, max_tokens, D, device=x.device, dtype=x.dtype) |
| expert_starts = torch.cat([torch.tensor([0], device=x.device), tokens_per_expert[:-1].cumsum(0)]) |
| intra_offsets = torch.arange(N * K, device=x.device) - expert_starts.repeat_interleave(tokens_per_expert) |
| expert_idx = flat_topk_idx[sorted_indices] |
| padded_x_flat = padded_x.view(-1, D) |
| flat_dest_indices = expert_idx * max_tokens + intra_offsets |
| padded_x_flat.index_put_((flat_dest_indices,), grouped_x) |
| h12 = torch.bmm(padded_x, self.experts_w12) |
| h1, h2 = h12.chunk(2, dim=-1) |
| h = F.silu(h1) * h2 |
| expert_out_padded = torch.bmm(h, self.experts_w3) |
| full_expert_out = expert_out_padded.view(-1, D)[flat_dest_indices] |
| gating_flat = gating_scores.view(-1) |
| sorted_gating = gating_flat[sorted_indices].unsqueeze(1) |
| weighted_out = full_expert_out * sorted_gating |
| routed_out = torch.zeros_like(hidden_states) |
| routed_out.index_add_(0, token_indices, weighted_out) |
|
|
| route_scale = math.sqrt(self.top_k) if self.training else 1.0 |
| out = (shared_out + routed_out * route_scale) * self.moe_scale |
| out = out.view(batch_size, seq_len, d_model).to(x.dtype) |
| return out, z_loss + aux_loss |
|
|
|
|
| |
| |
| |
| |
| |
| |
class HybridBlock(nn.Module):
    """One Quasar layer: token mixer + FFN/MoE, each with sandwich RMSNorm.

    Each sublayer computes ``x + residual_scale * dropout(ln_out(f(ln_in(x))))``.

    Token mixer, chosen by ``config.hybrid_layer_types[layer_idx]``:
      * ``"quasar"`` -> ``QuasarAttention``;
      * ``"gla"``    -> ``GatedLinearAttention`` plus a ``LatentMemoryModule``
        whose read-out is gated by ``sigmoid(W_alpha(x))`` and added to the
        attention output.

    FFN: a dense ``SwiGLUBlock`` for the first ``config.dense_input_layers``
    layers (or when there are no routed experts), otherwise a MoE
    (``BigMacMoE`` for "bigmac"/"deepseek", ``GroupedMoE`` otherwise).
    """

    def __init__(self, config: QuasarConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.d_model
        self.layer_idx = layer_idx
        self.n_layers = config.n_layers
        self.config = config
        self.gradient_checkpointing = False

        # Gate for re-injecting the embedding output P in looped mode;
        # sigmoid(-2.197) is approximately 0.1.
        self.use_looped_injection = config.use_looped_injection
        self.injection_gate = nn.Parameter(torch.tensor([-2.197]))

        self.layer_type = config.hybrid_layer_types[layer_idx]

        if self.layer_type == "quasar":
            self.attn = QuasarAttention(
                mode=config.attn_mode,
                hidden_size=config.d_model,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.n_heads,
                num_v_heads=config.num_v_heads,
                use_short_conv=config.use_short_conv,
                allow_neg_eigval=config.allow_neg_eigval,
                conv_size=config.conv_size,
                norm_eps=config.rms_norm_eps,
                layer_idx=layer_idx,
            )
        elif self.layer_type == "gla":
            self.attn = GatedLinearAttention(
                mode=config.gla_mode,
                hidden_size=config.d_model,
                expand_k=config.expand_k,
                expand_v=config.expand_v,
                num_heads=config.n_heads,
                layer_idx=layer_idx,
            )
            # Only "gla" layers carry latent memory.
            self.memory = LatentMemoryModule(
                hidden_size=config.d_model,
                memory_slots=config.memory_slots,
                memory_dim=config.memory_dim,
                use_triton=False,
            )
            # Overrides the module's own -5.0 write-strength bias.
            nn.init.constant_(self.memory.W_eta.bias, -1.0)
            self.W_alpha = nn.Linear(config.d_model, 1)
            self.C_to_hidden = nn.Linear(config.memory_dim, config.d_model, bias=False)
        else:
            raise ValueError(f"Unknown layer_type: {self.layer_type}")

        # Pre-norms (ln1, ln2) and post-norms on each sublayer output (ln*_out).
        self.ln1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ln1_out = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ln2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ln2_out = RMSNorm(config.d_model, eps=config.rms_norm_eps)

        dense_layers = config.dense_input_layers
        num_routed = config.num_routed_experts

        if layer_idx < dense_layers or num_routed == 0:
            self.is_moe = False
            self.ffn = SwiGLUBlock(config.d_model, config.d_ff)
        else:
            self.is_moe = True
            if config.moe_type == "bigmac":
                self.ffn = BigMacMoE(config, layer_idx=layer_idx)
            elif config.moe_type == "deepseek":
                # "deepseek" is an alias for the BigMac implementation here.
                self.ffn = BigMacMoE(config, layer_idx=layer_idx)
            else:
                self.ffn = GroupedMoE(config, layer_idx=layer_idx)

        self.dropout = nn.Dropout(config.dropout)
        # Depth scaling applied to output-projection init std.
        self.scale_factor = 1.0 / math.sqrt(2 * self.n_layers)
        self.residual_scale = config.residual_scale

        self._init_weights()

    def _init_weights(self):
        """Layer-local init with std ``0.5 / sqrt(d_model)``.

        Output-side projections (FFN ``down``, attention ``o_proj``) are
        further scaled by ``scale_factor``. Statement order determines RNG
        consumption, so it must not be rearranged.
        """
        trinity_std = 0.5 / math.sqrt(self.hidden_size)

        if self.layer_type == "gla":
            # sigmoid(-10) is ~0: the memory read-out starts almost fully gated off.
            nn.init.constant_(self.W_alpha.bias, -10.0)
            nn.init.zeros_(self.W_alpha.weight)
            nn.init.normal_(self.C_to_hidden.weight, std=trinity_std)

        def apply_deep_init(m):
            # Depth-scaled init for any submodule exposing a Linear `down`.
            if hasattr(m, 'down') and isinstance(m.down, nn.Linear):
                nn.init.normal_(m.down.weight, mean=0.0, std=trinity_std * self.scale_factor)

        if not self.is_moe:
            nn.init.normal_(self.ffn.gate.weight, mean=0.0, std=trinity_std)
            nn.init.normal_(self.ffn.up.weight, mean=0.0, std=trinity_std)
            apply_deep_init(self.ffn)
        else:
            self.ffn._init_weights(std=trinity_std)
            for expert in self.ffn.shared_experts:
                apply_deep_init(expert)
            # Re-draws experts_w3 after ffn._init_weights already did.
            nn.init.normal_(self.ffn.experts_w3, mean=0.0, std=trinity_std)

        nn.init.constant_(self.ln1_out.weight, 1.0)
        nn.init.constant_(self.ln2_out.weight, 1.0)

        # Attention projections are probed by attribute name; whatever the
        # attention layer does not define is left untouched.
        if hasattr(self.attn, 'o_proj') and isinstance(self.attn.o_proj, nn.Linear):
            nn.init.normal_(self.attn.o_proj.weight, mean=0.0, std=trinity_std * self.scale_factor)
        for proj_name in ['q_proj', 'k_proj', 'v_proj', 'g_proj']:
            if hasattr(self.attn, proj_name):
                m = getattr(self.attn, proj_name)
                if isinstance(m, nn.Linear):
                    nn.init.normal_(m.weight, mean=0.0, std=trinity_std)
                elif isinstance(m, nn.Sequential):
                    for subm in m:
                        if isinstance(subm, nn.Linear):
                            nn.init.normal_(subm.weight, mean=0.0, std=trinity_std)

    def forward(self, x, cos=None, sin=None, expert_bias=None,
                memory_state=None, lambda_reg=0.01, **kwargs):
        """Optionally inject ``kwargs['P']`` and run ``_forward``.

        Uses non-reentrant activation checkpointing when
        ``gradient_checkpointing`` is set and the module is training. Returns
        ``(x, aux_loss, new_memory_state, mem_loss)``, exactly as ``_forward``.
        """
        if self.use_looped_injection:
            P = kwargs.get('P')
            if P is not None:
                x = x + (torch.sigmoid(self.injection_gate) * P)

        if self.gradient_checkpointing and self.training:
            return torch.utils.checkpoint.checkpoint(
                self._forward, x, cos, sin, expert_bias, memory_state, lambda_reg,
                use_reentrant=False, **kwargs,
            )
        return self._forward(x, cos, sin, expert_bias, memory_state, lambda_reg, **kwargs)

    def _forward(self, x, cos=None, sin=None, expert_bias=None,
                 memory_state=None, lambda_reg=0.01, **kwargs):
        """Token-mixer sublayer, then FFN/MoE sublayer.

        ``lambda_reg`` is accepted but not referenced here. Only
        ``past_key_values`` (when not None) and ``use_cache`` from ``kwargs``
        are forwarded to the attention layer.

        Returns:
            ``(x, aux_loss, new_memory_state, mem_loss)``. ``new_memory_state``
            is None unless this is a "gla" layer that received a memory state.
        """
        residual = x
        x = self.ln1(x)

        attn_kwargs = {}
        if cos is not None and sin is not None:
            attn_kwargs['cos'] = cos
            attn_kwargs['sin'] = sin

        if 'past_key_values' in kwargs and kwargs['past_key_values'] is not None:
            attn_kwargs['past_key_values'] = kwargs['past_key_values']
        if 'use_cache' in kwargs:
            attn_kwargs['use_cache'] = kwargs['use_cache']

        # Attention layers may return a tuple; only the first element is used.
        attn_out = self.attn(x, **attn_kwargs)
        if isinstance(attn_out, tuple):
            attn_out = attn_out[0]

        new_memory_state = None
        mem_loss = torch.tensor(0.0, device=x.device)

        # Latent memory: write from the normed input, read from the updated
        # memory, then add the gated projection to the attention output.
        if self.layer_type == "gla" and memory_state is not None:
            new_memory_state, total_mem_loss, _ = self.memory.write_memory(x, memory_state)
            C = self.memory.read_memory(x, new_memory_state)
            alpha = torch.sigmoid(self.W_alpha(x))
            C_proj = self.C_to_hidden(C)
            attn_out = attn_out + (alpha * C_proj)
            mem_loss = total_mem_loss

        x = residual + self.residual_scale * self.dropout(self.ln1_out(attn_out))

        residual = x
        x = self.ln2(x)
        if self.is_moe:
            block_out, aux_loss = self.ffn(x, expert_bias=expert_bias)
        else:
            block_out = self.ffn(x)
            aux_loss = torch.tensor(0.0, device=x.device)

        x = residual + self.residual_scale * self.dropout(self.ln2_out(block_out))
        return x, aux_loss, new_memory_state, mem_loss
|
|
|
|
| |
| |
| |
@dataclass
class QuasarModelOutputWithPast(BaseModelOutputWithPast):
    """``BaseModelOutputWithPast`` extended with latent-memory outputs.

    Field order is part of the output's tuple layout; do not reorder.
    """

    # Per-layer memory tensors keyed by layer index (only "gla" layers appear
    # when produced by QuasarModel.forward).
    memory_states: dict | None = None
    # Summed memory-write regularisation loss across layers and loops.
    memory_loss: torch.Tensor | None = None
|
|
|
|
@dataclass
class QuasarCausalLMOutputWithPast(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with latent-memory and MoE outputs.

    Field order is part of the output's tuple layout; do not reorder.
    """

    # Per-layer latent memory tensors, keyed by layer index.
    memory_states: dict | None = None
    # Summed memory regularisation loss.
    memory_loss: torch.Tensor | None = None
    # Summed MoE router losses (z-loss + load-balancing); presumably filled
    # by the causal-LM forward — TODO confirm against QuasarForCausalLM.
    aux_loss: torch.Tensor | None = None
|
|
|
|
| |
| |
| |
class QuasarPreTrainedModel(PreTrainedModel):
    """Base class connecting Quasar models to HF ``PreTrainedModel`` plumbing.

    NOTE(review): HF's ``post_init()`` typically applies ``_init_weights`` to
    every submodule that has not been loaded from a checkpoint. For a freshly
    constructed model, that would overwrite the custom ``nn.Linear`` inits
    done in ``HybridBlock._init_weights`` (e.g. ``W_alpha.bias = -10``,
    depth-scaled ``down`` / ``o_proj``) with ``N(0, initializer_range)`` and
    zero biases. Confirm this is intended for training from scratch.
    """

    config_class = QuasarConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["HybridBlock"]
    _supports_cache_class = True

    def _init_weights(self, module):
        """Default init: Linear ~ N(0, initializer_range) with zero bias; Embedding ~ N(0, 0.02)."""
        linear_std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Embedding):
            # Embeddings use a fixed std regardless of initializer_range.
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()
            return
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, mean=0.0, std=linear_std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
|
|
|
| |
| |
| |
| |
| class QuasarModel(QuasarPreTrainedModel): |
| config: QuasarConfig |
|
|
| def __init__(self, config: QuasarConfig): |
| super().__init__(config) |
| self.config = config |
| d_model = config.d_model |
| n_heads = config.n_heads |
| n_layers = config.n_layers |
| vocab_size = config.vocab_size |
| max_seq_len = config.max_seq_len |
|
|
| self.embed_tokens = nn.Embedding(vocab_size, d_model) |
| self.embed_norm = RMSNorm(d_model, eps=config.rms_norm_eps) |
| self.layers = nn.ModuleList([ |
| HybridBlock(config, i) for i in range(n_layers) |
| ]) |
| self.norm = RMSNorm(d_model, eps=config.rms_norm_eps) |
| self.rotary_emb = RotaryEmbedding( |
| d_model // n_heads, max_seq_len, base=config.rope_theta, |
| ) |
|
|
| |
| self.moe_layer_ffns = [l.ffn for l in self.layers if getattr(l, 'is_moe', False)] |
| self.num_moe = len(self.moe_layer_ffns) |
| num_experts = config.num_routed_experts |
| if self.num_moe > 0 and num_experts > 0: |
| self.register_buffer("all_moe_bias", torch.zeros(self.num_moe, num_experts)) |
| self.register_buffer("all_moe_momentum", torch.zeros(self.num_moe, num_experts)) |
| self.register_buffer("all_moe_max_vio", torch.zeros(self.num_moe)) |
|
|
| self.gradient_checkpointing = False |
| self.post_init() |
|
|
| def get_input_embeddings(self): |
| return self.embed_tokens |
|
|
| def set_input_embeddings(self, value): |
| self.embed_tokens = value |
|
|
| def init_memory(self, batch_size, device, dtype=torch.float32): |
| memory_states = {} |
| for i, layer in enumerate(self.layers): |
| if layer.layer_type == "gla": |
| m = torch.zeros(batch_size, layer.memory.K, layer.memory.D, device=device, dtype=dtype) |
| memory_states[i] = m |
| return memory_states |
|
|
| def forward( |
| self, |
| input_ids: torch.LongTensor | None = None, |
| attention_mask: torch.Tensor | None = None, |
| position_ids: torch.LongTensor | None = None, |
| past_key_values: Cache | None = None, |
| inputs_embeds: torch.FloatTensor | None = None, |
| use_cache: bool | None = None, |
| output_hidden_states: bool | None = None, |
| memory_states: dict | None = None, |
| lambda_reg: float = 0.01, |
| **kwargs, |
| ): |
| output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("Specify exactly one of input_ids or inputs_embeds") |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.embed_tokens(input_ids) |
|
|
| |
| hidden_states = self.embed_norm(inputs_embeds) |
| batch_size, seq_len, _ = hidden_states.shape |
|
|
| |
| if position_ids is None: |
| past_seen_tokens = 0 |
| if past_key_values is not None: |
| try: |
| past_seen_tokens = past_key_values.get_seq_length() |
| except Exception: |
| past_seen_tokens = 0 |
| position_ids = torch.arange(past_seen_tokens, past_seen_tokens + seq_len, device=hidden_states.device) |
|
|
| |
| max_pos = int(position_ids.max().item() + 1) if position_ids.numel() > 0 else seq_len |
| cos_full, sin_full = self.rotary_emb(hidden_states, seq_len=max_pos) |
| if position_ids.dim() == 1: |
| cos = cos_full[:, :, position_ids] |
| sin = sin_full[:, :, position_ids] |
| else: |
| cos = cos_full[:, :, position_ids[0]] |
| sin = sin_full[:, :, position_ids[0]] |
|
|
| |
| if memory_states is None: |
| memory_states = self.init_memory(batch_size, hidden_states.device, hidden_states.dtype) |
|
|
| all_hidden_states = () if output_hidden_states else None |
| aux_losses = [] |
| mem_losses = [] |
| new_memory_states = {} |
|
|
| |
| P = hidden_states |
| num_loops = self.config.num_loops |
| current_memory_states = memory_states |
|
|
| |
| if self.num_moe > 0: |
| bias_snapshot = self.all_moe_bias.detach().clone() |
| else: |
| bias_snapshot = None |
|
|
| for loop_idx in range(num_loops): |
| moe_idx = 0 |
| iteration_new_memory_states = {} |
| for layer in self.layers: |
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| bias = bias_snapshot[moe_idx] if (getattr(layer, 'is_moe', False) and bias_snapshot is not None) else None |
|
|
| layer_out = layer( |
| hidden_states, |
| cos=cos, sin=sin, |
| expert_bias=bias, |
| memory_state=current_memory_states.get(layer.layer_idx), |
| lambda_reg=lambda_reg, |
| P=P if self.config.use_looped_injection else None, |
| past_key_values=past_key_values, |
| use_cache=use_cache, |
| **kwargs, |
| ) |
| hidden_states, aux_loss, new_m, m_loss = layer_out |
| if new_m is not None: |
| iteration_new_memory_states[layer.layer_idx] = new_m |
| mem_losses.append(m_loss) |
| if bias is not None: |
| moe_idx += 1 |
| aux_losses.append(aux_loss) |
|
|
| current_memory_states = iteration_new_memory_states |
| new_memory_states = iteration_new_memory_states |
|
|
| |
| if self.training and self.num_moe > 0: |
| with torch.no_grad(): |
| self._update_all_moe_biases() |
|
|
| hidden_states = self.norm(hidden_states) |
|
|
| if output_hidden_states: |
| all_hidden_states += (hidden_states,) |
|
|
| total_aux = torch.stack(aux_losses).sum() if aux_losses else torch.tensor(0.0, device=hidden_states.device) |
| total_mem = torch.stack(mem_losses).sum() if mem_losses else torch.tensor(0.0, device=hidden_states.device) |
|
|
| return QuasarModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values, |
| hidden_states=all_hidden_states, |
| memory_states=new_memory_states, |
| memory_loss=total_mem, |
| ), total_aux |
|
|
def _update_all_moe_biases(self):
    """Update every MoE layer's routing bias in one batched, in-place step.

    Each entry of ``self.moe_layer_ffns`` must carry a ``_pending_violation``
    tensor; these are stacked into one tensor with one row per MoE layer
    (presumably the last axis indexes experts -- TODO confirm against the
    MoE FFN). The step then:

    * builds a bounded update ``lambda * tanh(kappa * violation)`` and
      re-centres it to zero mean along the last axis;
    * folds that update into ``all_moe_momentum`` as an EMA with factor
      ``beta``;
    * adds the momentum to ``all_moe_bias``, replaces NaN/inf
      (``nan_to_num_``), clamps to [-10, 10] and re-centres to zero mean
      along the last axis;
    * updates ``all_moe_max_vio`` as a 0.99/0.01 EMA of the negated per-row
      minimum violation, copies each entry to the matching layer's
      ``max_vio`` and deletes that layer's ``_pending_violation``.

    ``kappa``, ``lambda`` and ``beta`` are taken from the first MoE FFN only.
    An FFN without a pending violation raises ``AttributeError`` before any
    buffer is modified.
    """
    ffns = self.moe_layer_ffns
    pending = torch.stack([ffn._pending_violation for ffn in ffns])

    # Hyper-parameters are shared: read them from the first MoE block.
    first = ffns[0]
    kappa = first.smebu_kappa
    lamb = first.smebu_lambda
    beta = first.smebu_beta

    # tanh keeps the raw step bounded to [-lambda, lambda] before centring.
    step = lamb * torch.tanh(kappa * pending)
    step = step - step.mean(dim=-1, keepdim=True)

    momentum = self.all_moe_momentum
    momentum.mul_(beta).add_(step, alpha=1 - beta)

    bias = self.all_moe_bias
    bias.add_(momentum).nan_to_num_().clamp_(-10.0, 10.0)
    bias.sub_(bias.mean(dim=-1, keepdim=True))

    worst = -pending.min(dim=-1).values
    max_vio_ema = self.all_moe_max_vio
    max_vio_ema.mul_(0.99).add_(worst, alpha=0.01)

    # Mirror the shared EMA into each layer and consume its pending violation.
    for layer_pos, ffn in enumerate(ffns):
        ffn.max_vio.copy_(max_vio_ema[layer_pos])
        del ffn._pending_violation
|
|
|
|
| |
| |
| |
| |
class QuasarForCausalLM(QuasarPreTrainedModel, FLAGenerationMixin):
    """Quasar backbone (:class:`QuasarModel`) plus a linear language-modeling head.

    ``lm_head`` is an independent ``nn.Linear``: it is never tied to the input
    embeddings (``tie_weights`` is a no-op and ``_tied_weights_keys`` is empty).

    When ``labels`` are passed to :meth:`forward`, the cross-entropy is
    computed chunk by chunk over supervised positions only and ``logits`` are
    returned as ``None``. Without labels, full-sequence ``logits`` are returned.
    """

    config: QuasarConfig
    # Empty on purpose: lm_head and embed_tokens are separate weights.
    _tied_weights_keys = {}

    def __init__(self, config: QuasarConfig):
        super().__init__(config)
        self.model = QuasarModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module (``model.embed_tokens``)."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head."""
        self.lm_head = new_embeddings

    def tie_weights(self, missing_keys=None, recompute_mapping=False):
        """Intentionally a no-op so ``lm_head`` is never tied to ``embed_tokens``."""

    def _chunked_lm_loss(self, hidden_states, labels, chunk_size=256):
        """Mean next-token cross-entropy over positions whose label is not -100.

        Position ``t`` of ``hidden_states`` predicts ``labels[..., t + 1]``.
        Only supervised positions are projected through ``lm_head``, in
        slices of ``chunk_size`` rows, so the full logits tensor is never
        materialized. Returns ``None`` when no position is supervised.
        """
        shift_hidden = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        flat_hidden = shift_hidden.view(-1, self.config.d_model)
        flat_labels = shift_labels.view(-1)
        mask = flat_labels != -100
        if not mask.any():
            return None

        active_hidden = flat_hidden[mask]
        active_labels = flat_labels[mask]
        total_tokens = active_labels.numel()
        total_loss = 0.0
        for start in range(0, total_tokens, chunk_size):
            end = min(start + chunk_size, total_tokens)
            chunk_logits = self.lm_head(active_hidden[start:end])
            # Upcast so the softmax/log-sum-exp runs in fp32 under mixed precision.
            total_loss = total_loss + F.cross_entropy(
                chunk_logits.float(), active_labels[start:end], reduction='sum'
            )
        return total_loss / total_tokens

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_hidden_states: bool | None = None,
        memory_states: dict | None = None,
        lambda_reg: float = 0.01,
        return_dict: bool | None = None,
        **kwargs,
    ):
        """Run the backbone, then compute either the LM loss or the logits.

        Args:
            input_ids, attention_mask, position_ids, past_key_values,
            inputs_embeds, use_cache, output_hidden_states, **kwargs:
                Forwarded unchanged to :class:`QuasarModel`.
            labels: Targets aligned with the inputs (shifted by one inside
                this method); positions equal to ``-100`` are ignored.
            memory_states: Per-layer memory states from a previous call,
                forwarded to the backbone.
            lambda_reg: Forwarded to the backbone.
            return_dict: Return a ``QuasarCausalLMOutputWithPast`` instead of
                a tuple; defaults to ``config.use_return_dict``.

        Returns:
            With labels: ``loss`` is the mean token cross-entropy plus the
            backbone's auxiliary loss and memory loss, and ``logits`` is
            ``None``. If no label is supervised, ``loss`` is a zero tensor
            (auxiliary/memory losses are not added). Without labels: ``loss``
            is ``None`` and ``logits`` covers every position.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs, total_aux = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            memory_states=memory_states,
            lambda_reg=lambda_reg,
            **kwargs,
        )
        hidden_states = model_outputs.last_hidden_state

        loss = None
        logits = None
        if labels is not None:
            # Boolean masking below requires labels on the same device as the
            # final hidden states (they can differ when the model is sharded).
            labels = labels.to(hidden_states.device)
            lm_loss = self._chunked_lm_loss(hidden_states, labels)
            if lm_loss is None:
                loss = torch.tensor(0.0, device=hidden_states.device, requires_grad=True)
            else:
                loss = lm_loss + total_aux + model_outputs.memory_loss
        else:
            logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return QuasarCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            memory_states=model_outputs.memory_states,
            memory_loss=model_outputs.memory_loss,
            aux_loss=total_aux,
        )

    @staticmethod
    def _cache_has_tokens(past_key_values) -> bool:
        """Return True when ``past_key_values`` already holds processed tokens.

        Generation utilities may hand over an allocated but still empty cache
        on the first step. Treating that as "has history" would drop all but
        the last prompt token and ignore ``inputs_embeds``.
        """
        if past_key_values is None:
            return False
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if callable(get_seq_length):
            return int(get_seq_length()) > 0
        try:
            return len(past_key_values) > 0
        except TypeError:
            # Unknown cache type without a length: keep the historical
            # behaviour of treating any non-None cache as populated.
            return True

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        memory_states=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Build the forward kwargs for one generation step.

        Once the cache holds history, only the newest token is fed.
        ``inputs_embeds`` are used only while the cache is still empty.
        ``memory_states`` fall back to a ``memory_states`` attribute on the
        cache object when not passed explicitly.
        """
        has_history = self._cache_has_tokens(past_key_values)
        if has_history:
            if input_ids is not None:
                input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]

        if inputs_embeds is not None and not has_history:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        if memory_states is None and past_key_values is not None:
            memory_states = getattr(past_key_values, "memory_states", None)

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": use_cache,
            "attention_mask": attention_mask,
            "cache_position": cache_position,
            "memory_states": memory_states,
        })
        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, *args, **kwargs):
        """HF generation hook: also carry ``memory_states`` to the next step.

        ``GenerationMixin`` calls this (underscored) method after every
        decoding step. Extra arguments are passed through unchanged because
        their names and positions differ across ``transformers`` versions.
        """
        model_kwargs = super()._update_model_kwargs_for_generation(outputs, model_kwargs, *args, **kwargs)
        memory_states = getattr(outputs, "memory_states", None)
        if memory_states is not None:
            model_kwargs["memory_states"] = memory_states
        return model_kwargs

    def update_model_kwargs_for_generation(self, outputs, model_kwargs, is_seq2seq=False, num_new_tokens=1):
        """Backward-compatible public alias of ``_update_model_kwargs_for_generation``."""
        return self._update_model_kwargs_for_generation(
            outputs, model_kwargs,
            is_encoder_decoder=is_seq2seq, num_new_tokens=num_new_tokens,
        )

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder the cache along the batch axis for beam search.

        ``Cache.reorder_cache`` in ``transformers`` works in place and returns
        ``None``, so the (mutated) cache object itself is returned in that case.
        NOTE(review): ``memory_states`` are not reordered here -- confirm
        whether beam search with memory needs that too.
        """
        if past_key_values is None:
            return None
        reordered = past_key_values.reorder_cache(beam_idx)
        return past_key_values if reordered is None else reordered
|
|
|
# Public API of this module; QuasarConfig is re-exported from configuration_quasar.
__all__ = [
    # configuration
    "QuasarConfig",
    # models
    "QuasarPreTrainedModel",
    "QuasarModel",
    "QuasarForCausalLM",
    # output containers
    "QuasarModelOutputWithPast",
    "QuasarCausalLMOutputWithPast",
]
|
|