"""Self-contained modeling file for trust_remote_code use. This file merges mup_models.py and hf_wrapper.py into a single module with no imports from looped_scaling.*. It is intended to be placed alongside a config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it without requiring the looped_scaling package to be installed. Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer), "moe" (MoETransformer), "looped-moe" (LoopedMoETransformer). """ import torch import math import sys import torch.nn as nn import torch.nn.functional as F from collections.abc import Callable, Iterable from einops import rearrange, einsum, reduce, repeat from typing import IO, Any, BinaryIO, Optional from torch import Tensor from collections import Counter, defaultdict from torch.nn.functional import scaled_dot_product_attention as sdpa # for flash attention from torch.nn.functional import grouped_mm, silu from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM from transformers.generation import GenerationMixin from transformers.modeling_outputs import CausalLMOutputWithPast BASE_D_MODEL = 128 BASE_D_FF = 384 """ Standard Transformer and Components implemented with muP """ # --------------------------------------------------------------------------- # Numerically stable softmax (inlined from looped_scaling/utils.py) # --------------------------------------------------------------------------- def softmax(logits: Tensor, dim: int) -> Tensor: logits = logits.float() # get max values over specified dimension max_values = torch.max(logits, dim=dim, keepdim=True).values # subtract max_values from x so max element is 0 shifted = logits - max_values # broadcast should work # get exp of shifted terms shifted_exps = torch.exp(shifted) # get sum of shifted terms shifted_exp_sums = torch.sum(shifted_exps, dim=dim, keepdim=True) # calculate product product = shifted_exps / shifted_exp_sums return product # y = Wx (no bias terms!) class Linear(nn.Module): def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None): super().__init__() # Register parameter first so shape is always stored (required for HF meta-device loading) self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device=device)) # for muP, derive initial std deviation from given base model's std_deviation and width ratio std_scaled = std_base / math.sqrt(width_ratio) nn.init.trunc_normal_(self.weight, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled) def forward(self, x: Tensor) -> Tensor: # Pytorch standard: on input side of expression, d_in is last dim of x so "... d_in" # on output side of einsum expression, so "... d_out" follows convention # to put the output dim last return einsum(self.weight, x, "d_out d_in, ... d_in -> ... d_out") class Embedding(nn.Module): def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None): super().__init__() # Register parameter first so shape is always stored (required for HF meta-device loading) self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)) # normalize the embeddings to spec nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3) def forward(self, token_ids: Tensor) -> Tensor: # for every id, we need to pull the row vector associated return self.weight[token_ids] class RMSNorm(nn.Module): def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None): super().__init__() # for muP no gain parameter on the rms self.d_model = d_model self.eps = eps def forward(self, x: Tensor) -> Tensor: # upcast input to torch.float32 in_dtype = x.dtype x = x.to(torch.float32) # calculate the RMS scalar # scalar for every ex. in batch, for every emb in sequence mean_squared_sum = (1/self.d_model)*einsum(x, x, "... seq d, ... seq d -> ... seq") rms = torch.sqrt(mean_squared_sum + self.eps) # for muP, no gain on rms norm as is normally applied. rms_norm = einsum(x, 1/rms, "... seq d, ... seq -> ... seq d") # return result to original dtype return rms_norm.to(in_dtype) class PositionwiseFeedforward(nn.Module): # SwiGLU(x) = W2(SiLU(W1x)⊙W3x) def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None): super().__init__() # for muP, calculate the base model's standard deviation w_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_FF)) # same for all W because d_model+d_ff = d_ff+d_model # initialize parameters of SWiGLU FFN self.w1 = Linear(d_model, d_ff, width_ratio, w_std_base, device=device, dtype=dtype) self.w2 = Linear(d_ff, d_model, width_ratio, w_std_base, device=device, dtype=dtype) self.w3 = Linear(d_model, d_ff, width_ratio, w_std_base, device=device, dtype=dtype) def forward(self, x: Tensor) -> Tensor: # FFN = W2*(SiLU(W1*X) dot W3X) silu_in = self.w1(x) silu_out = silu(silu_in) # silu_in * torch.sigmoid(silu_in) gate = self.w3(x) gated_prod = silu_out * gate final_prod = self.w2(gated_prod) return final_prod class RotaryPositionalEmbedding(nn.Module): def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None): """ theta: float Θ value for the RoPE d_k: int dimension of query and key vectors max_seq_len: int Maximum sequence length that will be inputted device: torch.device | None = None Device to store the buffer on """ super().__init__() rotations = torch.empty(max_seq_len, d_k//2, 2, 2, device=device, dtype=dtype) # initialize rotation matrix for i in range(max_seq_len): for k in range(d_k//2): angle = i/(theta**(2*k/d_k)) rot = Tensor([[math.cos(angle), -math.sin(angle)], [math.sin(angle), math.cos(angle)]]) rotations[i, k, :] = rot self.register_buffer("rotations", rotations, persistent=True) def forward(self, x: Tensor, token_positions: Tensor) -> Tensor: """ self.rotations shape: (seq_dim, feature_dim, 2, 2) x: tensor of shape (..., seq_dim, feature_dim) token_positions: tensor of shape (..., seq_dim) """ # get the correct rotation matrices # by default, 0'th dim of array_indexed is index dim, last dim of indices is feature dim rot = self.rotations[token_positions].to(dtype=x.dtype) # match activation dtype (buffer is float32, activations may be bfloat16) # rearrange by every two elements along feature dim of input x x_pairs = rearrange(x, "... seq_dim (feature_dim i) -> ... seq_dim feature_dim i", i=2) # apply rotations to these. for each pairwise position is A@x->y : (ixj)@(j,)->(i,) y_pairs = einsum(rot, x_pairs, "... seq_dim feature_dim i j, ... seq_dim feature_dim j -> ... seq_dim feature_dim i") # reshape y_pairs back to original shape y = rearrange(y_pairs, "... seq_dim feature_dim i -> ... seq_dim (feature_dim i)") return y def scaled_dot_product_attention( Q: Tensor, K: Tensor, V: Tensor, mask: Optional[Tensor] = None, ) -> Tensor: """ Given key (K), query (Q), and value (V) tensors, return the output of your scaled dot product attention implementation. Args: let m be seq length of inputs, n be seq length of outputs d_k is look-up dim, d_v is value dim Q (Float[Tensor, "batch ... n d_k"]): Query tensor K (Float[Tensor, "batch ... m d_k"]): Key tensor V (Float[Tensor, "batch ... m d_v"]): Values tensor mask (Float[Tensor, " ... n m"] | None): Mask tensor Returns: Float[Tensor, " ... n d_v"]: Output of SDPA """ # get the key feature dim (should be last dim of Q and K) d_k = Q.shape[-1] assert d_k == K.shape[-1] # calculate the weighted scores (similarity product). for muP, scale by d_k not sqrt(d_k) scores = einsum(Q, K, "... n d_k, ... m d_k -> ... n m") / d_k # apply the mask if there is one if mask is not None: bool_mask = mask.bool() # compatible if somehow, input is mask bool or if float attn_mask = torch.where(bool_mask, 0.0, float('-inf')).to(scores.dtype) scores = scores + attn_mask # calculate the weighted weights = softmax(scores, dim=-1) # the softmax should be taken over the m inputs at an i'th output pos. # return weights@V return einsum(weights, V, "... n m, ... m d_v -> ... n d_v") class MultiheadSelfAttention(nn.Module): """ Args: d_model (int): Dimensionality of the feedforward input and output. num_heads (int): Number of heads to use in multi-headed attention. max_seq_len (int): Maximum sequence length to pre-cache if your implementation does that. q_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the Q projection k_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the K projection v_proj_weight (Float[Tensor, "d_k d_in"]): Weights for the V projection o_proj_weight (Float[Tensor, "d_model d_v"]): Weights for the output projection in_features (Float[Tensor, "... sequence_length d_in"]): Tensor to run your implementation on. Returns: Float[Tensor, " ... sequence_length d_out"]: Tensor with the output of running your optimized, batched multi-headed attention implementation with the given QKV projection weights and input features. """ def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None): super().__init__() # initialize the multi-head self attention weights as 1 large matrix (which will be sliced) assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})" self.d_model = d_model self.num_heads = num_heads # for muP, calculate standard deviation of base model attn_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL)) # for muP, initialize the Wq,Wk,Wv,Wo linear weights with width_ratio and base model's stddev self.q_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype) self.k_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype) self.v_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype) self.output_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype) # # Removed for torch sdpa, uncomment if using normal code # if max_seq_len: # causal_mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=dtype, device=device)) # self.register_buffer("causal_mask", causal_mask, persistent=False) # else: # self.register_buffer("causal_mask", None, persistent=False) assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE." if theta: d_k = d_model//num_heads self.rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len, device, dtype) else: self.rope = None def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor: # get Q, K, V matrices Q = self.q_proj(x) # output shape is [batch seq d_model] K = self.k_proj(x) V = self.v_proj(x) # #create causal mask intepreting the second to last dim as seq dim # if self.causal_mask is None: # seq_dim = x.shape[-2] # cmask = torch.tril(torch.ones(seq_dim, seq_dim, dtype=x.dtype, device=x.device)) # else: # # Slice the pre-computed mask to match actual sequence length (could be < than max_seq_len) # seq_dim = x.shape[-2] # cmask = self.causal_mask[:seq_dim, :seq_dim] # get slice size for multi-head self attention d_k = self.d_model // self.num_heads d_v = self.d_model // self.num_heads q_heads = rearrange(Q, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k) k_heads = rearrange(K, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k) # apply RoPE to q_heads and k_heads if self.rope: seq_dim = x.shape[-2] # x is (b,s,d) if token_positions is None: token_positions = torch.arange(seq_dim, device=x.device) token_positions = rearrange(token_positions, "seq -> 1 seq") # 1 seq allows broadcast across batch dim q_heads = self.rope(q_heads, token_positions) k_heads = self.rope(k_heads, token_positions) v_heads = rearrange(V, "... seq (heads d_v) -> ... heads seq d_v", d_v=d_v) #mha_heads = scaled_dot_product_attention(q_heads, k_heads, v_heads, cmask) mha_heads = sdpa(q_heads, k_heads, v_heads, is_causal=True, scale=1.0/d_k) mha = rearrange(mha_heads, "... heads seq d_v -> ... seq (heads d_v)") # apply o_proj_weight to the concatenated multi-head attention product out = self.output_proj(mha) return out class PrenormBlock(nn.Module): def __init__(self, d_model: int, num_heads: int, d_ff: int, max_seq_len: int, theta: float, width_ratio: float, device=None, dtype=None): super().__init__() # norm layer self.ln1 = RMSNorm(d_model, device=device, dtype=dtype) # mhsa with rope self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype) # add step # norm layer self.ln2 = RMSNorm(d_model, device=device, dtype=dtype) # positionwise feed forward self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype) # add to output def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor: # first Tx operation, Norm + MHSA w/ RoPE norm1_out = self.ln1(x) # we may have to define token_positions if it is not given attn_out = self.attn(norm1_out, token_positions) # ensure no broadcasting, elementwise addition on [batch seq d_model] assert(x.shape == attn_out.shape) resid1_out = attn_out + x # second Tx operation, Norm + SwiGLU norm2_out = self.ln2(resid1_out) ffn_out = self.ffn(norm2_out) # ensure no broadcasting, elementwise addition assert(ffn_out.shape == resid1_out.shape) final_out = resid1_out + ffn_out return final_out class MuTransformer(nn.Module): def __init__( self, vocab_size: int, context_length: int, d_model: int, num_layers: int, num_heads: int, d_ff: int, rope_theta: float, width_ratio: float = 1.0, weight_tying: bool = False, device=None, dtype=None): super().__init__() self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype) self.layers = nn.ModuleList([PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)]) self.ln_final = RMSNorm(d_model, device=device, dtype=dtype) self.weight_tying = weight_tying if weight_tying: self.lm_head = self.token_embeddings.weight else: std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size)) self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype) self.width_ratio = width_ratio def forward(self, x: Tensor) -> Tensor: # 1. token embed step, no muP alpha_in x = self.token_embeddings(x) # 2. prenorm blocks step for layer in self.layers: x = layer(x) # 3. Final norm x = self.ln_final(x) # 4. unembed layer, muP implemented as scaling on init variance and lr of lm_head, not output scaling if self.weight_tying: x = einsum(x, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio else: x = self.lm_head(x) # 5. return output, no muP alpha_out return x """ Looped Language Models implemented with MuP """ class LoopedStack(nn.Module): def __init__( self, context_length: int, d_model: int, num_layers_in_stack: int, num_heads: int, d_ff: int, rope_theta: float, width_ratio: float = 1.0, mixture_of_experts: bool = False, num_experts: Optional[int] = None, num_active: Optional[int] = None, device=None, dtype=None): super().__init__() if mixture_of_experts: # self.layers = nn.ModuleList([MoEPrenormBlock(d_model,num_heads,d_ff,num_experts,num_active, # context_length,rope_theta,width_ratio,device,dtype) # for _ in range(num_layers_in_stack)]) self.layers = nn.ModuleList([GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active, context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers_in_stack)]) else: self.layers = nn.ModuleList([PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers_in_stack)]) self.mixture_of_experts = mixture_of_experts def forward(self, x: Tensor) -> Tensor: # prenorm blocks step if self.mixture_of_experts: lb_total = 0 lz_total = 0 # sum up load balancing and z-losses across each layer for layer in self.layers: x, lb, lz = layer(x) lb_total += lb lz_total += lz return x, lb_total, lz_total else: for layer in self.layers: x = layer(x) return x class LoopedTransformer(nn.Module): def __init__( self, vocab_size: int, context_length: int, d_model: int, num_layers_in_stack: int, num_stacks: int, num_heads: int, d_ff: int, rope_theta: float, width_ratio: float = 1.0, weight_tying: bool = False, device=None, dtype=None): super().__init__() self.num_stacks = num_stacks self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype) self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype) self.ln_final = RMSNorm(d_model, device=device, dtype=dtype) self.weight_tying = weight_tying self.width_ratio = width_ratio if weight_tying: self.lm_head = self.token_embeddings.weight else: std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size)) self.lm_head = Linear(d_model, vocab_size, width_ratio, std_base_lm_head, device=device, dtype=dtype) def forward(self, x: Tensor) -> Tensor: # token embed step x = self.token_embeddings(x) # repeated calls to stack for i in range(self.num_stacks): x = self.stack(x) # final norm x = self.ln_final(x) # Vocab projection or lm_head if self.weight_tying: x = einsum(x, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio else: x = self.lm_head(x) return x """ Mixture-of-Experts Implementation in muP """ # Router Class class Router(nn.Module): def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None): super().__init__() # router is simply a linear layer. we initialize (d_in, d_out) according to my code std_base = math.sqrt(2/(BASE_D_MODEL+num_experts)) self.gate = Linear(d_model, num_experts, width_ratio, std_base, device=device, dtype=dtype) # adjusted for muP self.num_active = num_active def forward(self, x: Tensor): # returns scores, top_k_scores, top_k_indices logits = self.gate(x) # should be shape (batch, seq, n_routers) # probs probs = softmax(logits, dim=-1) # get top_k top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1) # renormalize the top scores so weighted sum of expert products can be taken score_sums = torch.sum(top_scores, dim=-1, keepdim=True) # (batch, seq) top_scores = top_scores/score_sums return logits, probs, top_scores, top_experts class MoEPrenormBlock(nn.Module): def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int, max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None): super().__init__() # norm layer before mHSA+RoPE self.ln1 = RMSNorm(d_model, device=device, dtype=dtype) # mhsa with rope self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype) # norm layer before position-wise feedfoward self.ln2 = RMSNorm(d_model, device=device, dtype=dtype) # router self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype) # save MoE hyperparams self.num_experts = num_experts self.num_active = num_active # initialize MoE FFNs as a module list d_ff_expert = d_ff // num_active self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)]) # adjusted for muP def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor: # input dims batch, seq, dim = x.shape # first Tx operation, Norm + MHSA w/ RoPE norm1_out = self.ln1(x) # we may have to define token_positions if it is not given attn_out = self.attn(norm1_out, token_positions) # ensure no broadcasting, elementwise addition on [batch seq d_model] assert(x.shape == attn_out.shape) resid1_out = attn_out + x # prenorm before position-wise feedforward norm2_out = self.ln2(resid1_out) # get scores from Router. returns shape (batch,seq,k) logits, probs, top_scores, top_experts = self.router(norm2_out) # logits and probs are (batch, seq, n_routers) expert_mean_probs = torch.mean(probs, dim=(0, 1)) # take mean across batch and seq dims # apply mixture of experts experts_out = torch.zeros_like(norm2_out) # copies shape, device and dtype total_tokens_assigned = batch*seq*self.num_active lb_sum = 0 for expert_idx in range(self.num_experts): # get masks for expert selection expert_mask = (top_experts == expert_idx) embed_mask = expert_mask.any(dim=-1) # if any of the k is expert, we want to transform embed if not embed_mask.any(): continue pi = expert_mean_probs[expert_idx].item() fi = (expert_mask.sum().item())/total_tokens_assigned # num embeds assigned to expert in batch lb_sum += fi*pi # extract embeds and weights for activated experts weights = top_scores[expert_mask] # (num_embeds) expert_embeds = norm2_out[embed_mask] # (num_embeds, hidden_dim) # forward for the correct experts expert_out = self.experts[expert_idx](expert_embeds) # Vanilla Implementation # map back to experts output experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out # broadcast elementwise multiply by hidden dim # calculate batch's load balancing loss lb = self.num_experts*lb_sum # calculate batch's router z loss logsumexp = torch.logsumexp(logits.float(), dim=-1) lz = torch.mean(logsumexp ** 2) # ensure no broadcasting, elementwise addition assert(experts_out.shape == resid1_out.shape) final_out = resid1_out + experts_out return final_out, lb, lz class GroupedMoEPrenormBlock(nn.Module): @staticmethod def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter: w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype) # (batch, in, out) std_scaled = std_base / math.sqrt(width_ratio) nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled) return nn.Parameter(w) def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int, max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None): super().__init__() # norm layer before mHSA+RoPE self.ln1 = RMSNorm(d_model, device=device, dtype=dtype) # mhsa with rope self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype) # norm layer before position-wise feedfoward self.ln2 = RMSNorm(d_model, device=device, dtype=dtype) # router self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype) # save MoE hyperparams self.num_experts = num_experts self.num_active = num_active # initialize MoE FFNs as a module list d_ff_expert = d_ff // num_active # expose and stack the MoE SwiGLU weights for all experts. with experts in string, optimizer scales weights by width_ratio w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF)) self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype) self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype) self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype) def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor: batch, seq, dim = x.shape total_tokens = batch * seq # first Tx operation, Norm + MHSA w/ RoPE norm1_out = self.ln1(x) attn_out = self.attn(norm1_out, token_positions) assert(x.shape == attn_out.shape) resid1_out = attn_out + x # prenorm before position-wise feedforward norm2_out = self.ln2(resid1_out) # get scores from Router. returns shape (batch, seq, k) logits, probs, top_scores, top_experts = self.router(norm2_out) # flatten to 2D for grouped_mm x_flat = rearrange(norm2_out, 'b s d -> (b s) d') # (total_tokens, d_model) flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)') # (total_tokens * k,) flat_scores = rearrange(top_scores, 'b s k -> (b s k)') # (total_tokens * k,) flat_positions = torch.arange(total_tokens, device=x.device) # (total_tokens) flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active) # (total_tokens * k) # sort by expert sort_indices = flat_expert_ids.argsort(stable=True) sorted_expert_ids = flat_expert_ids[sort_indices] sorted_token_ids = flat_token_ids[sort_indices] sorted_scores = flat_scores[sort_indices] sorted_x = x_flat[sorted_token_ids] # (total_tokens * k, d_model) # build offs (cumulative token counts per expert) counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts) offs = counts.cumsum(0).to(torch.int32) # (num_experts,) # grouped SwiGLU: W2(SiLU(W1 x) dot W3 x) h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs) h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs) gated = silu(h1) * h3 expert_out = grouped_mm(gated, self.experts_w2, offs=offs) # (total_tokens * k, d_model) # weight by router scores and scatter-add back expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d') output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype) output_flat.index_add_(0, sorted_token_ids, expert_out) # reshape back to (batch, seq, d_model) experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq) # aux losses fi = counts.float() / (total_tokens * self.num_active) pi = reduce(probs, 'b s e -> e', 'mean') lb = self.num_experts * einsum(fi, pi, 'e, e ->') logsumexp = torch.logsumexp(logits.float(), dim=-1) lz = reduce(logsumexp ** 2, '... -> ', 'mean') # residual connection assert(experts_out.shape == resid1_out.shape) final_out = resid1_out + experts_out return final_out, lb, lz # MoE Implementation class MoETransformer(nn.Module): def __init__( self, vocab_size: int, context_length: int, d_model: int, num_layers: int, num_heads: int, d_ff: int, num_experts: int, num_active: int, rope_theta: float, width_ratio: float = 1.0, device=None, dtype=None): super().__init__() self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype) self.num_layers = num_layers # self.layers = nn.ModuleList([MoEPrenormBlock(d_model,num_heads,d_ff,num_experts,num_active, # context_length,rope_theta,width_ratio,device,dtype) for _ in range(num_layers)]) self.layers = nn.ModuleList([GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active, context_length, rope_theta, width_ratio, device, dtype) for _ in range(num_layers)]) self.ln_final = RMSNorm(d_model, device=device, dtype=dtype) # only non-tied embeddings now std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size)) self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype) def forward(self, x: Tensor) -> Tensor: # collect aux losses lb_total = 0 lz_total = 0 # 1. token embed step x = self.token_embeddings(x) # 2. prenorm blocks step for layer in self.layers: x, lb, lz = layer(x) lb_total += lb lz_total += lz # 3. Final norm x = self.ln_final(x) # 4. Vocab projection or lm_head x = self.lm_head(x) # calculate average layer aux loss lb_avg = lb_total / self.num_layers lz_avg = lz_total / self.num_layers return x, lb_avg, lz_avg class LoopedMoETransformer(nn.Module): def __init__( self, vocab_size: int, context_length: int, d_model: int, num_layers_in_stack: int, num_stacks: int, num_heads: int, d_ff: int, num_experts: int, num_active: int, rope_theta: float, width_ratio: float, device=None, dtype=None): super().__init__() self.stack_depth = num_stacks self.total_layers = num_stacks*num_layers_in_stack self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype) self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, mixture_of_experts=True, num_experts=num_experts, num_active=num_active, device=device, dtype=dtype) # parameters for loop with MoE self.ln_final = RMSNorm(d_model, device=device, dtype=dtype) # scale lm head std_base_lm_head = math.sqrt(2/(BASE_D_MODEL+vocab_size)) self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=std_base_lm_head, device=device, dtype=dtype) def forward(self, x: Tensor) -> Tensor: # collect aux losses lb_total = 0 lz_total = 0 # token embed step x = self.token_embeddings(x) # repeated calls to stack for i in range(self.stack_depth): x, lb, lz = self.stack(x) lb_total += lb lz_total += lz # final norm x = self.ln_final(x) # Vocab projection or lm_head x = self.lm_head(x) # calculate aux loss averages lb_avg = lb_total / self.total_layers lz_avg = lz_total / self.total_layers return x, lb_avg, lz_avg # --------------------------------------------------------------------------- # HuggingFace wrapper (from hf_wrapper.py) # --------------------------------------------------------------------------- class LoopLMConfig(PretrainedConfig): """Config for all four loop-lm model variants.""" model_type = "loop-lm" def __init__( self, # which of the four architectures to use model_variant: str = "base", # "base" | "looped" | "moe" | "looped-moe" # shared vocab_size: int = 50257, context_length: int = 1024, d_model: int = 1024, num_heads: int = 16, d_ff: int = 2752, rope_theta: float = 10000.0, width_ratio: float = 8.0, # d_model / base_d_model (128); set at training time # base + moe only num_layers: int = 16, # base + looped only weight_tying: bool = False, # looped + looped-moe only num_layers_in_stack: int = 8, num_stacks: int = 2, # moe + looped-moe only num_experts: int = 8, num_active: int = 2, # aux loss weights — used when forward() is called with labels lb_loss_factor: float = 0.01, lz_loss_factor: float = 0.001, **kwargs, ): super().__init__(**kwargs) self.model_variant = model_variant self.vocab_size = vocab_size self.context_length = context_length self.d_model = d_model self.num_heads = num_heads self.d_ff = d_ff self.rope_theta = rope_theta self.width_ratio = width_ratio self.num_layers = num_layers self.weight_tying = weight_tying self.num_layers_in_stack = num_layers_in_stack self.num_stacks = num_stacks self.num_experts = num_experts self.num_active = num_active self.lb_loss_factor = lb_loss_factor self.lz_loss_factor = lz_loss_factor # lm-evaluation-harness looks for this attribute to cap sequence length self.max_length = context_length class LoopLMForCausalLM(PreTrainedModel, GenerationMixin): """Causal LM wrapper over all four looped-scaling variants. Implements the HuggingFace PreTrainedModel interface so you can: - Upload/download via push_to_hub / from_pretrained - Run lm-evaluation-harness evals - Fine-tune with TRL's SFTTrainer / DPOTrainer """ config_class = LoopLMConfig # tell HF which parameter holds the output logits for generation _keys_to_ignore_on_load_missing = [] def __init__(self, config: LoopLMConfig): super().__init__(config) self.model = self._build_inner_model(config) self.post_init() # ------------------------------------------------------------------ # Model construction # ------------------------------------------------------------------ def _build_inner_model(self, config: LoopLMConfig): kw = dict( vocab_size=config.vocab_size, context_length=config.context_length, d_model=config.d_model, num_heads=config.num_heads, d_ff=config.d_ff, rope_theta=config.rope_theta, width_ratio=config.width_ratio, # device=None so weights are placed on CPU; caller uses .to(device) ) v = config.model_variant if v == "base": return MuTransformer( **kw, num_layers=config.num_layers, weight_tying=config.weight_tying, ) elif v == "looped": return LoopedTransformer( **kw, num_layers_in_stack=config.num_layers_in_stack, num_stacks=config.num_stacks, weight_tying=config.weight_tying, ) elif v == "moe": return MoETransformer( **kw, num_layers=config.num_layers, num_experts=config.num_experts, num_active=config.num_active, ) elif v == "looped-moe": return LoopedMoETransformer( **kw, num_layers_in_stack=config.num_layers_in_stack, num_stacks=config.num_stacks, num_experts=config.num_experts, num_active=config.num_active, ) else: raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe") # ------------------------------------------------------------------ # Embedding access (required by some HF utilities) # ------------------------------------------------------------------ def get_input_embeddings(self): return self.model.token_embeddings def set_input_embeddings(self, value): self.model.token_embeddings = value # ------------------------------------------------------------------ # Forward # ------------------------------------------------------------------ def forward( self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, # causal mask is handled internally labels: Optional[torch.LongTensor] = None, **kwargs, ) -> CausalLMOutputWithPast: """ Args: input_ids: (batch, seq) attention_mask: ignored — models use a built-in causal mask labels: (batch, seq) token ids; if provided, returns cross-entropy loss. For MoE variants, aux losses (lb + lz) are added to the CE loss. """ is_moe = self.config.model_variant in ("moe", "looped-moe") if is_moe: logits, lb, lz = self.model(input_ids) else: logits = self.model(input_ids) lb = lz = 0.0 loss = None if labels is not None: ce_loss = F.cross_entropy( logits.view(-1, logits.size(-1)), labels.view(-1), ) aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz loss = ce_loss + aux if self.training else ce_loss return CausalLMOutputWithPast( loss=loss, logits=logits, ) # ------------------------------------------------------------------ # Generation support (no KV cache — generation is correct but slow) # ------------------------------------------------------------------ def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids}