| | """Self-contained modeling file for trust_remote_code use. |
| | |
| | This file merges mup_models.py and hf_wrapper.py into a single module with no |
| | imports from looped_scaling.*. It is intended to be placed alongside a |
| | config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that |
| | HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it |
| | without requiring the looped_scaling package to be installed. |
| | |
| | Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer), |
| | "moe" (MoETransformer), "looped-moe" (LoopedMoETransformer). |
| | """ |
| |
|
| | import torch |
| | import math |
| | import sys |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from collections.abc import Callable, Iterable |
| | from einops import rearrange, einsum, reduce, repeat |
| | from typing import IO, Any, BinaryIO, Optional |
| | from torch import Tensor |
| | from collections import Counter, defaultdict |
| | from torch.nn.functional import scaled_dot_product_attention as sdpa |
| | from torch.nn.functional import grouped_mm, silu |
| | from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM |
| | from transformers.generation import GenerationMixin |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
# Reference ("base") dimensions. The per-layer base initialisation standard
# deviations in this file are computed from these values and then rescaled
# by ``1 / sqrt(width_ratio)`` inside the individual layers.
BASE_D_MODEL: int = 128  # reference model (residual stream) width
BASE_D_FF: int = 384  # reference feed-forward hidden width
| |
|
| | """ Standard Transformer and Components implemented with muP """ |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def softmax(logits: Tensor, dim: int) -> Tensor:
    """Numerically stable softmax along ``dim``, always computed in float32.

    The per-slice maximum is subtracted before exponentiation so that large
    logits do not overflow. The result is returned as float32 regardless of
    the dtype of ``logits``.
    """
    as_f32 = logits.float()
    # Subtracting the maximum leaves the result unchanged mathematically
    # but keeps every exponent <= 0.
    peak = as_f32.max(dim=dim, keepdim=True).values
    exps = torch.exp(as_f32 - peak)
    normaliser = exps.sum(dim=dim, keepdim=True)
    return exps / normaliser
| |
|
| |
|
| | |
class Linear(nn.Module):
    """Bias-free linear layer whose initial scale depends on ``width_ratio``.

    The weight has shape ``(out_features, in_features)`` and is drawn from a
    normal distribution with standard deviation
    ``std_base / sqrt(width_ratio)``, truncated at three standard deviations
    on either side of zero.
    """

    def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
        super().__init__()
        std = std_base / math.sqrt(width_ratio)
        cutoff = 3 * std

        self.weight = nn.Parameter(
            torch.empty(out_features, in_features, dtype=dtype, device=device)
        )
        nn.init.trunc_normal_(self.weight, mean=0.0, std=std, a=-cutoff, b=cutoff)

    def forward(self, x: Tensor) -> Tensor:
        """Map the last axis of ``x`` from ``d_in`` to ``d_out``; leading axes are kept."""
        return einsum(self.weight, x, "d_out d_in, ... d_in -> ... d_out")
| |
|
class Embedding(nn.Module):
    """Lookup table that maps integer token ids to embedding vectors.

    The table has shape ``(num_embeddings, embedding_dim)`` and is initialised
    from a unit-variance normal distribution truncated to ``[-3, 3]``.
    """

    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        table = torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device)
        self.weight = nn.Parameter(table)
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)

    def forward(self, token_ids: Tensor) -> Tensor:
        """Return the rows of the table selected by ``token_ids``.

        The output has the shape of ``token_ids`` with one extra trailing
        axis of size ``embedding_dim``.
        """
        return self.weight[token_ids]
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis, with no learnable gain.

    ``device`` and ``dtype`` are accepted for signature parity with the other
    layers but are not used: the module owns no parameters or buffers.
    """

    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        """Divide each feature vector by its RMS (computed in float32).

        The einops patterns require ``x`` to have at least two axes
        (``... seq d``). The result is cast back to the dtype of ``x``.
        """
        original_dtype = x.dtype
        x32 = x.to(torch.float32)

        # Mean of squares over the feature axis, then the stabilised root.
        mean_sq = (1/self.d_model)*einsum(x32, x32, "... seq d, ... seq d -> ... seq")
        root = torch.sqrt(mean_sq + self.eps)

        # Scale every feature of a position by that position's 1/RMS.
        normalised = einsum(x32, 1/root, "... seq d, ... seq -> ... seq d")
        return normalised.to(original_dtype)
| |
|
class PositionwiseFeedforward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    All three projections share the same base initialisation standard
    deviation, ``sqrt(2 / (BASE_D_MODEL + BASE_D_FF))``, which each
    ``Linear`` then rescales using ``width_ratio``.
    """

    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()
        base_std = math.sqrt(2/(BASE_D_MODEL+BASE_D_FF))
        factory = dict(device=device, dtype=dtype)

        # Creation order (w1, w2, w3) is kept so random initialisation
        # consumes the RNG stream in the same sequence.
        self.w1 = Linear(d_model, d_ff, width_ratio, base_std, **factory)
        self.w2 = Linear(d_ff, d_model, width_ratio, base_std, **factory)
        self.w3 = Linear(d_model, d_ff, width_ratio, base_std, **factory)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward transform to the last axis of ``x``."""
        activated = silu(self.w1(x))
        gate = self.w3(x)
        return self.w2(activated * gate)
| |
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary position embedding backed by a precomputed table of 2x2 rotations.

    The persistent buffer ``rotations`` has shape
    ``(max_seq_len, d_k // 2, 2, 2)``. Entry ``[i, k]`` is the rotation matrix
    for angle ``i / theta ** (2 * k / d_k)``::

        [[cos, -sin],
         [sin,  cos]]
    """

    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        """
        theta: float Θ value for the RoPE
        d_k: int dimension of query and key vectors
        max_seq_len: int Maximum sequence length that will be inputted
        device: torch.device | None = None Device to store the buffer on
        dtype: torch.dtype | None = None dtype of the stored buffer
        """
        super().__init__()
        half = d_k // 2

        # The buffer is allocated exactly as before so the device/dtype
        # semantics (including default-device contexts) are unchanged.
        rotations = torch.empty(max_seq_len, half, 2, 2, device=device, dtype=dtype)

        # Build the whole table at once on the CPU instead of creating one
        # small tensor per (position, pair) in a Python double loop.
        # The per-pair divisors use Python floats, as the original loop did.
        divisors = torch.tensor(
            [theta ** (2 * k / d_k) for k in range(half)],
            dtype=torch.float64,
            device="cpu",
        )
        positions = torch.arange(max_seq_len, dtype=torch.float64, device="cpu")
        angles = positions[:, None] / divisors[None, :]  # (max_seq_len, half)
        cos, sin = torch.cos(angles), torch.sin(angles)

        top_row = torch.stack((cos, -sin), dim=-1)
        bottom_row = torch.stack((sin, cos), dim=-1)
        table = torch.stack((top_row, bottom_row), dim=-2)  # (L, half, 2, 2)

        # Values pass through float32 before reaching the buffer dtype,
        # matching the precision of the previous element-wise construction.
        rotations.copy_(table.to(torch.float32))

        self.register_buffer("rotations", rotations, persistent=True)

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """
        self.rotations shape: (seq_dim, feature_dim, 2, 2)
        x: tensor of shape (..., seq_dim, feature_dim)
        token_positions: tensor of shape (..., seq_dim)

        Consecutive feature pairs of ``x`` are rotated by the matrix selected
        for their position; the output has the same shape as ``x``.
        """
        # Select one rotation per position and match the input dtype.
        rot = self.rotations[token_positions].to(dtype=x.dtype)

        # Split the feature axis into (pair, 2) so each pair can be rotated.
        x_pairs = rearrange(x, "... seq_dim (feature_dim i) -> ... seq_dim feature_dim i", i=2)

        y_pairs = einsum(rot, x_pairs, "... seq_dim feature_dim i j, ... seq_dim feature_dim j -> ... seq_dim feature_dim i")

        # Merge the pairs back into a flat feature axis.
        return rearrange(y_pairs, "... seq_dim feature_dim i -> ... seq_dim (feature_dim i)")
| |
|
def scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """
    Given key (K), query (Q), and value (V) tensors, return
    the output of scaled dot product attention.

    Note that the scores are divided by ``d_k`` (not ``sqrt(d_k)``).

    Args:
        let m be seq length of inputs, n be seq length of outputs
        d_k is look-up dim, d_v is value dim
        Q (Float[Tensor, "batch ... n d_k"]): Query tensor
        K (Float[Tensor, "batch ... m d_k"]): Key tensor
        V (Float[Tensor, "batch ... m d_v"]): Values tensor
        mask (Float[Tensor, " ... n m"] | None): Mask tensor; entries that
            convert to ``True`` are kept, all others are excluded. A query
            row with every key excluded produces NaN weights.
    Returns:
        Float[Tensor, " ... n d_v"]: Output of SDPA, in the dtype of ``V``
    """
    d_k = Q.shape[-1]
    assert d_k == K.shape[-1]

    # Pairwise query/key similarities, scaled by 1/d_k.
    scores = einsum(Q, K, "... n d_k, ... m d_k -> ... n m") / d_k

    if mask is not None:
        # Additive mask: 0 where attention is allowed, -inf elsewhere.
        bool_mask = mask.bool()
        attn_mask = torch.where(bool_mask, 0.0, float('-inf')).to(scores.dtype)
        scores = scores + attn_mask

    # ``softmax`` always returns float32. Cast back to the value dtype so the
    # weighted sum below does not mix dtypes when V is half/bfloat16/float64.
    weights = softmax(scores, dim=-1).to(V.dtype)

    return einsum(weights, V, "... n m, ... m d_v -> ... n d_v")
| |
|
class MultiheadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional rotary position embeddings.

    Args:
        d_model (int): Dimensionality of the input and output features.
        num_heads (int): Number of attention heads; must divide ``d_model``.
        max_seq_len (int | None): Size of the precomputed RoPE table.
            Required when ``theta`` is given.
        theta (float | None): RoPE Θ value. RoPE is disabled when this is
            ``None`` (or otherwise falsy).
        width_ratio (float): Forwarded to every ``Linear`` projection, which
            uses it to rescale its initialisation.
        device, dtype: Forwarded to the projections and the RoPE table.

    The attention scores are scaled by ``1 / d_k`` (not ``1 / sqrt(d_k)``),
    where ``d_k = d_model // num_heads``.
    """

    def __init__(self, d_model: int, num_heads: int, max_seq_len: Optional[int] = None, theta: Optional[float] = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()

        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads

        # Base std for a square (d_model x d_model) projection at base width.
        attn_std_base = math.sqrt(2/(BASE_D_MODEL+BASE_D_MODEL))

        # Creation order is kept (q, k, v, output) so random initialisation
        # consumes the RNG stream in the same sequence.
        self.q_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, attn_std_base, device=device, dtype=dtype)

        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."

        if theta:
            d_k = d_model//num_heads
            self.rope = RotaryPositionalEmbedding(theta, d_k, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        """Run causal self-attention over ``x``.

        Args:
            x: Tensor of shape ``(..., seq, d_model)``.
            token_positions: Optional integer positions used to index the RoPE
                table, of shape ``(seq,)`` or with the same leading axes as
                ``x`` (``(..., seq)``). Defaults to ``0 .. seq-1``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        d_k = self.d_model // self.num_heads
        d_v = self.d_model // self.num_heads

        # Split the feature axis into heads and move heads before seq.
        q_heads = rearrange(Q, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)
        k_heads = rearrange(K, "... seq (heads d_k) -> ... heads seq d_k", d_k=d_k)
        v_heads = rearrange(V, "... seq (heads d_v) -> ... heads seq d_v", d_v=d_v)

        if self.rope is not None:
            if token_positions is None:
                seq_dim = x.shape[-2]
                token_positions = torch.arange(seq_dim, device=x.device)
                token_positions = rearrange(token_positions, "seq -> 1 seq")
            elif token_positions.dim() > 1 and token_positions.dim() == x.dim() - 1:
                # Positions carry the same leading (batch) axes as ``x``.
                # Insert a singleton heads axis so they broadcast against
                # (..., heads, seq, d_k) instead of being aligned with heads.
                token_positions = token_positions.unsqueeze(-2)

            q_heads = self.rope(q_heads, token_positions)
            k_heads = self.rope(k_heads, token_positions)

        # Fused causal attention with the 1/d_k score scale.
        mha_heads = sdpa(q_heads, k_heads, v_heads, is_causal=True, scale=1.0/d_k)
        mha = rearrange(mha_heads, "... heads seq d_v -> ... seq (heads d_v)")

        return self.output_proj(mha)
| |
|
class PrenormBlock(nn.Module):
    """Pre-norm Transformer block: attention and feed-forward, each with a residual.

    Computes ``y = x + attn(ln1(x))`` followed by ``y + ffn(ln2(y))``.
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        # Attention sub-layer (norm + attention).
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # Feed-forward sub-layer (norm + gated FFN).
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        """Apply the block to ``x``; ``token_positions`` is passed to attention."""
        attended = self.attn(self.ln1(x), token_positions)
        assert(x.shape == attended.shape)
        after_attn = attended + x

        fed_forward = self.ffn(self.ln2(after_attn))
        assert(fed_forward.shape == after_attn.shape)
        return after_attn + fed_forward
| |
|
class MuTransformer(nn.Module):
    """Decoder-only Transformer ("base" variant): embeddings, blocks, norm, LM head.

    With ``weight_tying=True`` the embedding matrix is reused as the output
    projection and the logits are additionally divided by ``width_ratio``.
    Otherwise a separate ``Linear`` head is created.
    """

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)

        blocks = [
            PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        self.weight_tying = weight_tying
        if weight_tying:
            # Same Parameter object as the embedding table.
            self.lm_head = self.token_embeddings.weight
        else:
            head_std = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std, device=device, dtype=dtype)
        self.width_ratio = width_ratio

    def forward(self, x: Tensor) -> Tensor:
        """Map token ids ``x`` to next-token logits over the vocabulary."""
        hidden = self.token_embeddings(x)

        for block in self.layers:
            hidden = block(hidden)

        hidden = self.ln_final(hidden)

        if not self.weight_tying:
            return self.lm_head(hidden)

        # Tied head: project with the embedding matrix, then rescale.
        return einsum(hidden, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
| |
|
| | """ Looped Language Models implemented with MuP """ |
| |
|
class LoopedStack(nn.Module):
    """A stack of blocks intended to be applied repeatedly by a looped model.

    Holds ``num_layers_in_stack`` blocks: ``GroupedMoEPrenormBlock`` when
    ``mixture_of_experts`` is true, ``PrenormBlock`` otherwise.
    """

    def __init__(
            self,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            mixture_of_experts: bool = False,
            num_experts: Optional[int] = None,
            num_active: Optional[int] = None,
            device=None, dtype=None):
        super().__init__()
        if mixture_of_experts:
            blocks = [
                GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                       context_length, rope_theta, width_ratio, device, dtype)
                for _ in range(num_layers_in_stack)
            ]
        else:
            blocks = [
                PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                             width_ratio, device, dtype)
                for _ in range(num_layers_in_stack)
            ]
        self.layers = nn.ModuleList(blocks)
        self.mixture_of_experts = mixture_of_experts

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through every block once.

        Returns the transformed tensor for dense stacks. For MoE stacks it
        returns ``(x, lb_total, lz_total)``, where the two totals are the
        sums of the per-block auxiliary losses.
        """
        if not self.mixture_of_experts:
            for block in self.layers:
                x = block(x)
            return x

        lb_total = 0
        lz_total = 0
        for block in self.layers:
            x, lb, lz = block(x)
            lb_total += lb
            lz_total += lz
        return x, lb_total, lz_total
| |
|
class LoopedTransformer(nn.Module):
    """Transformer that applies one shared ``LoopedStack`` ``num_stacks`` times.

    With ``weight_tying=True`` the embedding matrix is reused as the output
    projection and the logits are additionally divided by ``width_ratio``.
    """

    def __init__(
            self,
            vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks

        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio

        if weight_tying:
            # Same Parameter object as the embedding table.
            self.lm_head = self.token_embeddings.weight
        else:
            head_std = math.sqrt(2/(BASE_D_MODEL+vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio, head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Map token ids ``x`` to next-token logits over the vocabulary."""
        hidden = self.token_embeddings(x)

        # The same stack (same weights) is reused on every pass.
        for _ in range(self.num_stacks):
            hidden = self.stack(hidden)

        hidden = self.ln_final(hidden)

        if not self.weight_tying:
            return self.lm_head(hidden)

        # Tied head: project with the embedding matrix, then rescale.
        return einsum(hidden, self.lm_head, "... s d, v d -> ... s v")/self.width_ratio
| |
|
| | """ Mixture-of-Experts Implementation in muP """ |
| |
|
| | |
class Router(nn.Module):
    """Top-k expert router: a linear gate followed by a softmax over experts.

    ``num_active`` is the ``k`` passed to ``torch.topk`` in ``forward``.
    """

    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        gate_std = math.sqrt(2/(BASE_D_MODEL+num_experts))
        self.gate = Linear(d_model, num_experts, width_ratio, gate_std, device=device, dtype=dtype)
        self.num_active = num_active

    def forward(self, x: Tensor):
        """Score every expert for every token and pick the top ``num_active``.

        Returns:
            ``(logits, probs, top_scores, top_experts)`` where ``logits`` are
            the raw gate outputs, ``probs`` their softmax over the last axis,
            ``top_experts`` the indices of the selected experts, and
            ``top_scores`` their probabilities renormalised to sum to one
            over the selected experts.
        """
        logits = self.gate(x)
        probs = softmax(logits, dim=-1)

        selected_probs, top_experts = torch.topk(probs, k=self.num_active, dim=-1)

        # Renormalise so the kept probabilities sum to 1 per token.
        top_scores = selected_probs/torch.sum(selected_probs, dim=-1, keepdim=True)

        return logits, probs, top_scores, top_experts
| |
|
class MoEPrenormBlock(nn.Module):
    """Pre-norm Transformer block whose feed-forward is a top-k mixture of experts.

    Experts are separate ``PositionwiseFeedforward`` modules evaluated one at
    a time in a Python loop. Each expert has hidden size
    ``d_ff // num_active``.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # Attention sub-layer.
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # Mixture-of-experts sub-layer.
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        self.num_experts = num_experts
        self.num_active = num_active

        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)])

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple:
        """Apply attention and the routed experts to ``x`` of shape ``(batch, seq, dim)``.

        Returns:
            ``(output, lb, lz)``:
            ``output`` has the shape of ``x``;
            ``lb`` is the load-balancing loss
            ``num_experts * sum_i f_i * p_i`` (``f_i`` = fraction of routing
            slots assigned to expert ``i``, ``p_i`` = mean router probability
            of expert ``i``), returned as a 0-dim tensor so gradients reach
            the router through ``p_i``;
            ``lz`` is the mean squared log-sum-exp of the router logits.
        """
        batch, seq, _ = x.shape

        attn_out = self.attn(self.ln1(x), token_positions)
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        norm2_out = self.ln2(resid1_out)
        logits, probs, top_scores, top_experts = self.router(norm2_out)

        experts_out = torch.zeros_like(norm2_out)
        for expert_idx, expert in enumerate(self.experts):
            # (batch, seq, k) slots routed to this expert, and the tokens
            # owning at least one such slot.
            expert_mask = (top_experts == expert_idx)
            embed_mask = expert_mask.any(dim=-1)
            if not embed_mask.any():
                continue

            # Both selections enumerate tokens in the same row-major order.
            weights = top_scores[expert_mask]
            expert_out = expert(norm2_out[embed_mask])
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out

        # Load-balancing loss, computed with tensors (no ``.item()``) so it
        # stays on-device and differentiable. Experts that received no token
        # have f_i == 0 and therefore contribute nothing.
        total_tokens_assigned = batch*seq*self.num_active
        if total_tokens_assigned == 0:
            lb = probs.new_zeros(())
        else:
            counts = torch.bincount(top_experts.reshape(-1), minlength=self.num_experts)
            fi = counts.float() / total_tokens_assigned
            pi = torch.mean(probs, dim=(0, 1))
            lb = self.num_experts * torch.sum(fi * pi)

        # Router z-loss.
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)

        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
| |
|
| |
|
class GroupedMoEPrenormBlock(nn.Module):
    """Pre-norm mixture-of-experts block that evaluates all experts with ``grouped_mm``.

    Expert weights are stored as three stacked parameters of shape
    ``(num_experts, in_features, out_features)``. Token/expert assignments
    are sorted by expert so each expert's rows are contiguous.
    """

    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        """Create one stacked, truncated-normal-initialised expert weight."""
        std = std_base / math.sqrt(width_ratio)
        bound = 3*std
        stacked = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype)
        nn.init.trunc_normal_(stacked, mean=0.0, std=std, a=-bound, b=bound)
        return nn.Parameter(stacked)

    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # Attention sub-layer.
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)

        # Mixture-of-experts sub-layer.
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)

        self.num_experts = num_experts
        self.num_active = num_active

        # Each expert gets hidden size d_ff // num_active.
        d_ff_expert = d_ff // num_active
        base_std = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))

        # Creation order (w1, w2, w3) is kept so the RNG stream is consumed
        # in the same sequence.
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, base_std, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, base_std, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, base_std, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        """Apply attention and the routed experts to ``x`` of shape ``(batch, seq, dim)``.

        Returns ``(output, lb, lz)``: the block output (same shape as ``x``),
        the load-balancing loss and the router z-loss.
        """
        batch, seq, dim = x.shape
        total_tokens = batch * seq

        # Attention sub-layer with residual.
        attn_out = self.attn(self.ln1(x), token_positions)
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x

        norm2_out = self.ln2(resid1_out)
        logits, probs, top_scores, top_experts = self.router(norm2_out)

        # Flatten to one row per token and one entry per (token, slot).
        tokens = rearrange(norm2_out, 'b s d -> (b s) d')
        slot_expert = rearrange(top_experts, 'b s k -> (b s k)')
        slot_score = rearrange(top_scores, 'b s k -> (b s k)')
        token_index = torch.arange(total_tokens, device=x.device)
        slot_token = repeat(token_index, 'n -> (n k)', k=self.num_active)

        # Group the slots by expert; the stable sort keeps token order
        # within each expert.
        order = slot_expert.argsort(stable=True)
        sorted_expert = slot_expert[order]
        sorted_token = slot_token[order]
        sorted_score = slot_score[order]
        sorted_x = tokens[sorted_token]

        # Cumulative per-expert row counts, passed to grouped_mm as ``offs``.
        counts = torch.bincount(sorted_expert, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32)

        # Gated feed-forward for all experts at once.
        up = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        gate = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        hidden = silu(up) * gate
        expert_rows = grouped_mm(hidden, self.experts_w2, offs=offs)

        # Weight each row by its router score and add it back onto its token.
        expert_rows = einsum(expert_rows, sorted_score, 'n d, n -> n d')
        combined = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_rows.dtype)
        combined.index_add_(0, sorted_token, expert_rows)
        experts_out = rearrange(combined, '(b s) d -> b s d', b=batch, s=seq)

        # Load-balancing loss: num_experts * sum_i f_i * p_i.
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')

        # Router z-loss: mean squared log-sum-exp of the gate logits.
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')

        assert(experts_out.shape == resid1_out.shape)
        return resid1_out + experts_out, lb, lz
| |
|
| |
|
| | |
class MoETransformer(nn.Module):
    """Decoder-only Transformer built from ``GroupedMoEPrenormBlock`` layers ("moe" variant).

    ``forward`` returns the logits together with the auxiliary router losses
    averaged over the layers.
    """

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.num_layers = num_layers

        blocks = [
            GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                   context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # Untied output head.
        head_std = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Map token ids ``x`` to ``(logits, lb_avg, lz_avg)``.

        ``lb_avg`` / ``lz_avg`` are the per-layer load-balancing and z-losses
        summed over all layers and divided by ``num_layers``.
        """
        lb_total = 0
        lz_total = 0

        hidden = self.token_embeddings(x)
        for block in self.layers:
            hidden, lb, lz = block(hidden)
            lb_total += lb
            lz_total += lz

        logits = self.lm_head(self.ln_final(hidden))

        return logits, lb_total / self.num_layers, lz_total / self.num_layers
| |
|
class LoopedMoETransformer(nn.Module):
    """Looped Transformer whose shared stack uses mixture-of-experts blocks.

    One ``LoopedStack`` (built with ``mixture_of_experts=True``) is applied
    ``num_stacks`` times. Auxiliary losses are averaged over the effective
    depth ``num_stacks * num_layers_in_stack``.
    """

    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float,
            device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        self.total_layers = num_stacks*num_layers_in_stack

        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.stack = LoopedStack(
            context_length, d_model, num_layers_in_stack, num_heads,
            d_ff, rope_theta, width_ratio,
            mixture_of_experts=True,
            num_experts=num_experts,
            num_active=num_active,
            device=device, dtype=dtype,
        )
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)

        # Untied output head.
        head_std = math.sqrt(2/(BASE_D_MODEL+vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        """Map token ids ``x`` to ``(logits, lb_avg, lz_avg)``."""
        lb_total = 0
        lz_total = 0

        hidden = self.token_embeddings(x)

        # Every pass reuses the same stack and adds that pass's summed
        # auxiliary losses.
        for _ in range(self.stack_depth):
            hidden, lb, lz = self.stack(hidden)
            lb_total += lb
            lz_total += lz

        logits = self.lm_head(self.ln_final(hidden))

        return logits, lb_total / self.total_layers, lz_total / self.total_layers
| |
|
| |
|
| | |
| | |
| | |
| |
|
class LoopLMConfig(PretrainedConfig):
    """Config for all four loop-lm model variants.

    ``model_variant`` selects the architecture ("base", "looped", "moe" or
    "looped-moe"). Fields that a given variant does not use are still stored.
    """

    model_type = "loop-lm"

    def __init__(
        self,
        # Which architecture to build.
        model_variant: str = "base",
        # Dimensions shared by every variant.
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,
        # Depth of the non-looped variants ("base", "moe").
        num_layers: int = 16,
        # Output-head tying ("base", "looped").
        weight_tying: bool = False,
        # Looped variants ("looped", "looped-moe").
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # Mixture-of-experts variants ("moe", "looped-moe").
        num_experts: int = 8,
        num_active: int = 2,
        # Weights of the auxiliary router losses in the training loss.
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Architecture selection and shared dimensions.
        self.model_variant = model_variant
        self.vocab_size = vocab_size
        self.context_length = context_length
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.rope_theta = rope_theta
        self.width_ratio = width_ratio

        # Depth and weight tying.
        self.num_layers = num_layers
        self.weight_tying = weight_tying
        self.num_layers_in_stack = num_layers_in_stack
        self.num_stacks = num_stacks

        # Mixture-of-experts settings and auxiliary loss weights.
        self.num_experts = num_experts
        self.num_active = num_active
        self.lb_loss_factor = lb_loss_factor
        self.lz_loss_factor = lz_loss_factor

        # Set after super().__init__ so it overrides any value from kwargs.
        self.max_length = context_length
| |
|
| |
|
class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.

    Implements the HuggingFace PreTrainedModel interface so you can:
    - Upload/download via push_to_hub / from_pretrained
    - Run lm-evaluation-harness evals
    - Fine-tune with TRL's SFTTrainer / DPOTrainer

    The wrapped network lives in ``self.model`` and is chosen by
    ``config.model_variant``.
    """

    config_class = LoopLMConfig

    _keys_to_ignore_on_load_missing = []

    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        self.model = self._build_inner_model(config)
        self.post_init()

    def _build_inner_model(self, config: LoopLMConfig):
        """Instantiate the network selected by ``config.model_variant``.

        Raises:
            ValueError: if ``model_variant`` is not one of the four known names.
        """
        # Arguments accepted by every variant.
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")

    def get_input_embeddings(self):
        """Return the token embedding module of the wrapped network."""
        return self.model.token_embeddings

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped network."""
        self.model.token_embeddings = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored — models use a built-in causal mask
            labels: (batch, seq) token ids aligned with ``input_ids`` (the
                HuggingFace convention). The shift is done here: the logits
                at position ``t`` are scored against ``labels[t + 1]``.
                Positions labelled ``-100`` are ignored by the cross-entropy.
                For MoE variants, aux losses (lb + lz) are added to the CE
                loss in training mode only.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            (batch, seq, vocab) and ``loss`` (``None`` when no labels given).
        """
        is_moe = self.config.model_variant in ("moe", "looped-moe")

        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0

        loss = None
        if labels is not None:
            # Next-token objective: drop the last prediction and the first
            # label so prediction t lines up with token t + 1.
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:].to(logits.device)
            # reshape (not view): the shifted slices are not contiguous.
            ce_loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
            aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz
            loss = ce_loss + aux if self.training else ce_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Pass only ``input_ids`` to ``forward``; every other generation kwarg is dropped."""
        return {"input_ids": input_ids}
| |
|