# scaled-base / modeling_loop_lm.py
# (Hugging Face Hub page residue — "ml-ryanlee's picture",
# "Update modeling_loop_lm.py", commit "e57525e verified" —
# commented out so the module remains valid Python.)
"""Self-contained modeling file for trust_remote_code use.
This file merges mup_models.py and hf_wrapper.py into a single module with no
imports from looped_scaling.*. It is intended to be placed alongside a
config.json that sets ``auto_map`` / ``model_type = "loop-lm"`` so that
HuggingFace's ``from_pretrained(..., trust_remote_code=True)`` can load it
without requiring the looped_scaling package to be installed.
Supported model variants: "base" (MuTransformer), "looped" (LoopedTransformer),
"moe" (MoETransformer), "looped-moe" (LoopedMoETransformer).
"""
import torch
import math
import sys
import torch.nn as nn
import torch.nn.functional as F
from collections.abc import Callable, Iterable
from einops import rearrange, einsum, reduce, repeat
from typing import IO, Any, BinaryIO, Optional
from torch import Tensor
from collections import Counter, defaultdict
from torch.nn.functional import scaled_dot_product_attention as sdpa # for flash attention
from torch.nn.functional import grouped_mm, silu
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
BASE_D_MODEL = 128
BASE_D_FF = 384
""" Standard Transformer and Components implemented with muP """
# ---------------------------------------------------------------------------
# Numerically stable softmax (inlined from looped_scaling/utils.py)
# ---------------------------------------------------------------------------
def softmax(logits: Tensor, dim: int) -> Tensor:
    """Numerically stable softmax over ``dim``, computed in float32.

    Subtracts the per-slice maximum before exponentiating so large logits
    cannot overflow. Note the result stays in float32 regardless of the
    input dtype (callers rely on the upcast for stability).
    """
    logits = logits.float()
    # shift so the max element is 0; exp() then cannot overflow
    stabilized = logits - logits.amax(dim=dim, keepdim=True)
    exps = torch.exp(stabilized)
    return exps / exps.sum(dim=dim, keepdim=True)
# y = Wx (no bias terms!)
class Linear(nn.Module):
    """Bias-free linear layer ``y = x @ W.T`` with muP-scaled init.

    The init std is the base model's std divided by sqrt(width_ratio),
    truncated at +/- 3 sigma.
    """
    def __init__(self, in_features, out_features, width_ratio, std_base, device=None, dtype=None):
        super().__init__()
        # Parameter is registered (with its shape) before init so HF
        # meta-device loading always sees the weight.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device=device))
        sigma = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(self.weight, mean=0.0, std=sigma, a=-3 * sigma, b=3 * sigma)

    def forward(self, x: Tensor) -> Tensor:
        # contract the trailing input dim against W: (..., d_in) -> (..., d_out)
        return F.linear(x, self.weight)
class Embedding(nn.Module):
    """Token-embedding lookup table with unit-std truncated-normal init."""
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        # Parameter registered before init so the shape survives HF
        # meta-device loading.
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device=device))
        nn.init.trunc_normal_(self.weight, mean=0.0, std=1.0, a=-3, b=3)

    def forward(self, token_ids: Tensor) -> Tensor:
        # gather one embedding row per id; works for any id-tensor shape
        return F.embedding(token_ids, self.weight)
class RMSNorm(nn.Module):
    """Gain-free RMS normalisation (muP drops the learned scale).

    Normalises the last (feature) dimension by sqrt(mean(x^2) + eps);
    the math runs in float32 and the result is cast back to the input
    dtype.
    """
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        orig_dtype = x.dtype
        x32 = x.to(torch.float32)
        # mean of squares over the feature dim — one scalar per position
        mean_sq = x32.pow(2).sum(dim=-1) / self.d_model
        inv_rms = 1.0 / torch.sqrt(mean_sq + self.eps)
        return (x32 * inv_rms.unsqueeze(-1)).to(orig_dtype)
class PositionwiseFeedforward(nn.Module):
    """SwiGLU feed-forward network: W2(SiLU(W1 x) * W3 x), muP-initialised."""
    def __init__(self, d_model: int, d_ff: int, width_ratio: float, device=None, dtype=None):
        super().__init__()
        # One shared base std: fan_in + fan_out equals d_model + d_ff for
        # every one of the three matrices, so a single value suffices.
        base_std = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.w1 = Linear(d_model, d_ff, width_ratio, base_std, device=device, dtype=dtype)
        self.w2 = Linear(d_ff, d_model, width_ratio, base_std, device=device, dtype=dtype)
        self.w3 = Linear(d_model, d_ff, width_ratio, base_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # gated SwiGLU product, then project back to d_model
        return self.w2(silu(self.w1(x)) * self.w3(x))
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None, dtype=None):
        """
        Precompute per-position 2x2 RoPE rotation matrices.

        theta: float        Θ base for the frequency schedule
        d_k: int            dimension of query/key vectors (must be even)
        max_seq_len: int    maximum sequence length that will be inputted
        device: torch.device | None  device to store the buffer on
        dtype: torch.dtype | None    buffer dtype (defaults to torch's default)
        """
        super().__init__()
        half = d_k // 2
        # Vectorized build: the original per-element Python loop performed
        # O(max_seq_len * d_k) tiny tensor constructions, which is very slow
        # at realistic context lengths. Angles are computed in float64 to
        # match math.cos/math.sin double precision, then cast on store.
        positions = torch.arange(max_seq_len, dtype=torch.float64)
        freqs = theta ** (2.0 * torch.arange(half, dtype=torch.float64) / d_k)
        angles = positions[:, None] / freqs[None, :]  # (seq, d_k//2)
        cos, sin = torch.cos(angles), torch.sin(angles)
        # rows are [[cos, -sin], [sin, cos]] -> shape (seq, d_k//2, 2, 2)
        rotations = torch.stack(
            (torch.stack((cos, -sin), dim=-1), torch.stack((sin, cos), dim=-1)),
            dim=-2,
        )
        rotations = rotations.to(
            device=device,
            dtype=dtype if dtype is not None else torch.get_default_dtype(),
        )
        self.register_buffer("rotations", rotations, persistent=True)

    def forward(self, x: Tensor, token_positions: Tensor) -> Tensor:
        """
        Apply the precomputed rotations to ``x``.

        self.rotations shape: (seq_dim, feature_dim, 2, 2)
        x: tensor of shape (..., seq_dim, d_k)
        token_positions: integer tensor of shape (..., seq_dim) indexing the buffer
        """
        # Select per-token rotation matrices; cast to the activation dtype
        # (buffer is float32, activations may be bfloat16).
        rot = self.rotations[token_positions].to(dtype=x.dtype)  # (..., seq, d_k//2, 2, 2)
        # Pair up adjacent features, rotate each pair, flatten back.
        x_pairs = x.unflatten(-1, (-1, 2))                       # (..., seq, d_k//2, 2)
        y_pairs = (rot @ x_pairs.unsqueeze(-1)).squeeze(-1)      # batched 2x2 matvec
        return y_pairs.flatten(-2)
def scaled_dot_product_attention(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    mask: Optional[Tensor] = None,
) -> Tensor:
    """Reference (non-fused) attention with muP 1/d_k score scaling.

    Args:
        Q: (..., n, d_k) query tensor
        K: (..., m, d_k) key tensor
        V: (..., m, d_v) value tensor
        mask: optional (..., n, m); truthy/nonzero positions are kept,
            all others receive -inf before the softmax.
    Returns:
        (..., n, d_v) attention output. Note the custom ``softmax``
        upcasts to float32 and does not cast back.
    """
    d_k = Q.shape[-1]
    assert K.shape[-1] == d_k
    # muP: scores scaled by 1/d_k rather than the usual 1/sqrt(d_k)
    scores = (Q @ K.transpose(-2, -1)) / d_k
    if mask is not None:
        keep = mask.bool()  # accept either float or bool masks
        scores = scores + torch.where(keep, 0.0, float("-inf")).to(scores.dtype)
    # normalise over the m key positions for each output position
    attn = softmax(scores, dim=-1)
    return attn @ V
class MultiheadSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional RoPE, in muP.

    The Q/K/V/output projections are muP-initialised ``Linear`` layers over
    the full d_model; heads are formed by reshaping the projected tensors.
    Attention uses torch's fused SDPA with muP's 1/d_k score scaling.

    Args:
        d_model: dimensionality of the block input and output
        num_heads: number of attention heads (must divide d_model)
        max_seq_len: maximum sequence length (required when theta is given)
        theta: RoPE Θ base; RoPE is disabled when falsy
        width_ratio: muP width ratio used to scale projection init
    """
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = None, theta: float = None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        assert d_model % num_heads == 0, f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        self.d_model = d_model
        self.num_heads = num_heads
        # base-model std for all four projections (fan_in + fan_out = 2*d_model)
        base_std = math.sqrt(2 / (BASE_D_MODEL + BASE_D_MODEL))
        self.q_proj = Linear(d_model, d_model, width_ratio, base_std, device=device, dtype=dtype)
        self.k_proj = Linear(d_model, d_model, width_ratio, base_std, device=device, dtype=dtype)
        self.v_proj = Linear(d_model, d_model, width_ratio, base_std, device=device, dtype=dtype)
        self.output_proj = Linear(d_model, d_model, width_ratio, base_std, device=device, dtype=dtype)
        assert theta is None or max_seq_len is not None, "max_seq_len must be provided when theta is given for multi-head self attention with RoPE."
        if theta:
            self.rope = RotaryPositionalEmbedding(theta, d_model // num_heads, max_seq_len, device, dtype)
        else:
            self.rope = None

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        head_dim = self.d_model // self.num_heads

        def split_heads(t: Tensor) -> Tensor:
            # (..., seq, heads*dim) -> (..., heads, seq, dim)
            return t.unflatten(-1, (self.num_heads, head_dim)).transpose(-3, -2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))
        if self.rope is not None:
            if token_positions is None:
                # default: absolute positions 0..seq-1, broadcast over batch
                token_positions = torch.arange(x.shape[-2], device=x.device).unsqueeze(0)
            q = self.rope(q, token_positions)
            k = self.rope(k, token_positions)
        # fused attention with a causal mask; muP scales scores by 1/d_k
        # instead of the default 1/sqrt(d_k)
        heads_out = sdpa(q, k, v, is_causal=True, scale=1.0 / head_dim)
        # (..., heads, seq, dim) -> (..., seq, heads*dim) and project out
        merged = heads_out.transpose(-3, -2).flatten(-2)
        return self.output_proj(merged)
class PrenormBlock(nn.Module):
    """Pre-norm transformer block: x + Attn(LN(x)), then + FFN(LN(.))."""
    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 d_ff: int,
                 max_seq_len: int,
                 theta: float,
                 width_ratio: float,
                 device=None,
                 dtype=None):
        super().__init__()
        # attention sub-layer (RMSNorm -> MHSA with RoPE)
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # feed-forward sub-layer (RMSNorm -> SwiGLU FFN)
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        self.ffn = PositionwiseFeedforward(d_model, d_ff, width_ratio, device, dtype)

    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> Tensor:
        # attention sub-layer with residual connection
        attn_out = self.attn(self.ln1(x), token_positions)
        assert x.shape == attn_out.shape  # guard against silent broadcasting
        x = x + attn_out
        # feed-forward sub-layer with residual connection
        ffn_out = self.ffn(self.ln2(x))
        assert ffn_out.shape == x.shape
        return x + ffn_out
class MuTransformer(nn.Module):
    """Standard decoder-only transformer with muP parameterisation."""
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.layers = nn.ModuleList([
            PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        ])
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio
        if weight_tying:
            # tie the unembedding to the token-embedding matrix
            self.lm_head = self.token_embeddings.weight
        else:
            head_std = math.sqrt(2 / (BASE_D_MODEL + vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        # embed -> num_layers pre-norm blocks -> final norm -> unembed
        h = self.token_embeddings(x)
        for block in self.layers:
            h = block(h)
        h = self.ln_final(h)
        if self.weight_tying:
            # tied head: project onto embedding rows; muP divides the
            # tied logits by width_ratio (untied heads scale at init/lr)
            return (h @ self.lm_head.t()) / self.width_ratio
        return self.lm_head(h)
""" Looped Language Models implemented with MuP """
class LoopedStack(nn.Module):
    """A reusable stack of pre-norm blocks (dense or MoE) for looped models.

    When ``mixture_of_experts`` is set, forward returns a 3-tuple
    (output, load_balancing_loss, z_loss); otherwise it returns the
    output tensor alone.
    """
    def __init__(
            self,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            mixture_of_experts: bool = False,
            num_experts: Optional[int] = None,
            num_active: Optional[int] = None,
            device=None, dtype=None):
        super().__init__()
        if mixture_of_experts:
            blocks = [GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                             context_length, rope_theta, width_ratio, device, dtype)
                      for _ in range(num_layers_in_stack)]
        else:
            blocks = [PrenormBlock(d_model, num_heads, d_ff, context_length, rope_theta,
                                   width_ratio, device, dtype)
                      for _ in range(num_layers_in_stack)]
        self.layers = nn.ModuleList(blocks)
        self.mixture_of_experts = mixture_of_experts

    def forward(self, x: Tensor):
        if not self.mixture_of_experts:
            for block in self.layers:
                x = block(x)
            return x
        # MoE path: accumulate load-balancing and router z-losses per layer
        lb_total, lz_total = 0, 0
        for block in self.layers:
            x, lb, lz = block(x)
            lb_total = lb_total + lb
            lz_total = lz_total + lz
        return x, lb_total, lz_total
class LoopedTransformer(nn.Module):
    """Transformer that reapplies one weight-shared stack ``num_stacks`` times."""
    def __init__(
            self,
            vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            weight_tying: bool = False,
            device=None, dtype=None):
        super().__init__()
        self.num_stacks = num_stacks
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        # a single stack whose weights are reused for every loop iteration
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads, d_ff, rope_theta, width_ratio, device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        self.weight_tying = weight_tying
        self.width_ratio = width_ratio
        if weight_tying:
            # tie the unembedding to the token-embedding matrix
            self.lm_head = self.token_embeddings.weight
        else:
            head_std = math.sqrt(2 / (BASE_D_MODEL + vocab_size))
            self.lm_head = Linear(d_model, vocab_size, width_ratio, head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        h = self.token_embeddings(x)
        # weight-shared recurrence: the same stack runs num_stacks times
        for _ in range(self.num_stacks):
            h = self.stack(h)
        h = self.ln_final(h)
        if self.weight_tying:
            # tied head with muP's 1/width_ratio logit scaling
            return (h @ self.lm_head.t()) / self.width_ratio
        return self.lm_head(h)
""" Mixture-of-Experts Implementation in muP """
# Router Class
class Router(nn.Module):
    """Top-k softmax router over experts with a muP-initialised gate."""
    def __init__(self, d_model: int, num_experts: int, num_active=None, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # the gate is a plain linear layer; init std derived from the base model
        gate_std = math.sqrt(2 / (BASE_D_MODEL + num_experts))
        self.gate = Linear(d_model, num_experts, width_ratio, gate_std, device=device, dtype=dtype)
        self.num_active = num_active

    def forward(self, x: Tensor):
        """Return (logits, probs, top_scores, top_experts).

        logits/probs have shape (batch, seq, num_experts); top_scores and
        top_experts have shape (batch, seq, num_active), with top_scores
        renormalised to sum to 1 over the selected experts.
        """
        logits = self.gate(x)
        probs = softmax(logits, dim=-1)
        top_scores, top_experts = torch.topk(probs, k=self.num_active, dim=-1)
        # renormalise so the selected experts' weights form a convex combination
        top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)
        return logits, probs, top_scores, top_experts
class MoEPrenormBlock(nn.Module):
    """Pre-norm transformer block whose FFN is a top-k mixture of experts.

    The attention sub-layer matches ``PrenormBlock``; the dense FFN is
    replaced by ``num_experts`` SwiGLU experts of width ``d_ff // num_active``
    each, with a softmax router selecting ``num_active`` experts per token.
    ``forward`` returns the block output plus load-balancing and router-z
    auxiliary losses.
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before mHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # norm layer before position-wise feedfoward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        # router
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)
        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active
        # initialize MoE FFNs as a module list
        # each expert is 1/num_active of the dense width, keeping active FLOPs roughly constant
        d_ff_expert = d_ff // num_active
        self.experts = nn.ModuleList([PositionwiseFeedforward(d_model, d_ff_expert, width_ratio, device, dtype) for _ in range(num_experts)]) # adjusted for muP
    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, float, Tensor]:
        """Return (output, load_balancing_loss, router_z_loss) for one batch."""
        # input dims
        batch, seq, dim = x.shape
        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        # we may have to define token_positions if it is not given
        attn_out = self.attn(norm1_out, token_positions)
        # ensure no broadcasting, elementwise addition on [batch seq d_model]
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x
        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)
        # get scores from Router. returns shape (batch,seq,k)
        logits, probs, top_scores, top_experts = self.router(norm2_out) # logits and probs are (batch, seq, n_routers)
        expert_mean_probs = torch.mean(probs, dim=(0, 1)) # take mean across batch and seq dims
        # apply mixture of experts
        experts_out = torch.zeros_like(norm2_out) # copies shape, device and dtype
        total_tokens_assigned = batch*seq*self.num_active
        lb_sum = 0
        for expert_idx in range(self.num_experts):
            # get masks for expert selection
            expert_mask = (top_experts == expert_idx)
            embed_mask = expert_mask.any(dim=-1) # if any of the k is expert, we want to transform embed
            # skip experts that received no tokens in this batch
            if not embed_mask.any(): continue
            # pi: mean router probability for this expert; fi: fraction of
            # routed slots it received
            pi = expert_mean_probs[expert_idx].item()
            fi = (expert_mask.sum().item())/total_tokens_assigned # num embeds assigned to expert in batch
            lb_sum += fi*pi
            # extract embeds and weights for activated experts
            weights = top_scores[expert_mask] # (num_embeds)
            expert_embeds = norm2_out[embed_mask] # (num_embeds, hidden_dim)
            # forward for the correct experts
            expert_out = self.experts[expert_idx](expert_embeds) # Vanilla Implementation
            # map back to experts output
            experts_out[embed_mask] += weights.unsqueeze(-1)*expert_out # broadcast elementwise multiply by hidden dim
        # calculate batch's load balancing loss
        # NOTE(review): pi and fi come from .item(), so lb is a plain Python
        # float with no gradient path through the router probabilities —
        # confirm this is intentional (GroupedMoEPrenormBlock keeps lb as a tensor).
        lb = self.num_experts*lb_sum
        # calculate batch's router z loss
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = torch.mean(logsumexp ** 2)
        # ensure no broadcasting, elementwise addition
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
class GroupedMoEPrenormBlock(nn.Module):
    """MoE pre-norm block that batches all expert FFNs through ``grouped_mm``.

    Same routing semantics as ``MoEPrenormBlock`` (top-k softmax router,
    SwiGLU experts of width ``d_ff // num_active``), but the expert weights
    are stored stacked as (num_experts, in, out) parameters and applied in a
    single grouped matmul over tokens sorted by expert id.
    ``forward`` returns (output, load_balancing_loss, router_z_loss).
    """
    @staticmethod
    def _init_expert_weights(num_experts, in_features, out_features, width_ratio, std_base, device, dtype) -> nn.Parameter:
        # stacked per-expert weights; muP scales the init std by 1/sqrt(width_ratio)
        w = torch.empty(num_experts, in_features, out_features, device=device, dtype=dtype) # (batch, in, out)
        std_scaled = std_base / math.sqrt(width_ratio)
        nn.init.trunc_normal_(w, mean=0.0, std=std_scaled, a=-3*std_scaled, b=3*std_scaled)
        return nn.Parameter(w)
    def __init__(self, d_model: int, num_heads: int, d_ff: int, num_experts: int, num_active: int,
                 max_seq_len: int, theta: float, width_ratio: float = 1.0, device=None, dtype=None):
        super().__init__()
        # norm layer before mHSA+RoPE
        self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)
        # mhsa with rope
        self.attn = MultiheadSelfAttention(d_model, num_heads, max_seq_len, theta, width_ratio, device, dtype)
        # norm layer before position-wise feedfoward
        self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)
        # router
        self.router = Router(d_model, num_experts, num_active, width_ratio=width_ratio, device=device, dtype=dtype)
        # save MoE hyperparams
        self.num_experts = num_experts
        self.num_active = num_active
        # initialize MoE FFNs as a module list
        # expert width is 1/num_active of the dense FFN, keeping active FLOPs roughly constant
        d_ff_expert = d_ff // num_active
        # expose and stack the MoE SwiGLU weights for all experts. with experts in string, optimizer scales weights by width_ratio
        w_std_base = math.sqrt(2 / (BASE_D_MODEL + BASE_D_FF))
        self.experts_w1 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
        self.experts_w2 = self._init_expert_weights(num_experts, d_ff_expert, d_model, width_ratio, w_std_base, device, dtype)
        self.experts_w3 = self._init_expert_weights(num_experts, d_model, d_ff_expert, width_ratio, w_std_base, device, dtype)
    def forward(self, x: Tensor, token_positions: Optional[Tensor] = None) -> tuple[Tensor, Tensor, Tensor]:
        """Return (output, load_balancing_loss, router_z_loss) for one batch."""
        batch, seq, dim = x.shape
        total_tokens = batch * seq
        # first Tx operation, Norm + MHSA w/ RoPE
        norm1_out = self.ln1(x)
        attn_out = self.attn(norm1_out, token_positions)
        assert(x.shape == attn_out.shape)
        resid1_out = attn_out + x
        # prenorm before position-wise feedforward
        norm2_out = self.ln2(resid1_out)
        # get scores from Router. returns shape (batch, seq, k)
        logits, probs, top_scores, top_experts = self.router(norm2_out)
        # flatten to 2D for grouped_mm
        x_flat = rearrange(norm2_out, 'b s d -> (b s) d') # (total_tokens, d_model)
        flat_expert_ids = rearrange(top_experts, 'b s k -> (b s k)') # (total_tokens * k,)
        flat_scores = rearrange(top_scores, 'b s k -> (b s k)') # (total_tokens * k,)
        flat_positions = torch.arange(total_tokens, device=x.device) # (total_tokens)
        flat_token_ids = repeat(flat_positions, 'n -> (n k)', k=self.num_active) # (total_tokens * k)
        # sort by expert
        # stable sort keeps within-expert token order deterministic
        sort_indices = flat_expert_ids.argsort(stable=True)
        sorted_expert_ids = flat_expert_ids[sort_indices]
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_scores = flat_scores[sort_indices]
        sorted_x = x_flat[sorted_token_ids] # (total_tokens * k, d_model)
        # build offs (cumulative token counts per expert)
        counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        offs = counts.cumsum(0).to(torch.int32) # (num_experts,)
        # grouped SwiGLU: W2(SiLU(W1 x) dot W3 x)
        # NOTE(review): torch.nn.functional.grouped_mm exists only in recent
        # torch builds — confirm the minimum supported torch version.
        h1 = grouped_mm(sorted_x, self.experts_w1, offs=offs)
        h3 = grouped_mm(sorted_x, self.experts_w3, offs=offs)
        gated = silu(h1) * h3
        expert_out = grouped_mm(gated, self.experts_w2, offs=offs) # (total_tokens * k, d_model)
        # weight by router scores and scatter-add back
        expert_out = einsum(expert_out, sorted_scores, 'n d, n -> n d')
        output_flat = torch.zeros(total_tokens, dim, device=x.device, dtype=expert_out.dtype)
        output_flat.index_add_(0, sorted_token_ids, expert_out)
        # reshape back to (batch, seq, d_model)
        experts_out = rearrange(output_flat, '(b s) d -> b s d', b=batch, s=seq)
        # aux losses
        # fi: fraction of routed slots per expert; pi: mean router probability
        fi = counts.float() / (total_tokens * self.num_active)
        pi = reduce(probs, 'b s e -> e', 'mean')
        lb = self.num_experts * einsum(fi, pi, 'e, e ->')
        logsumexp = torch.logsumexp(logits.float(), dim=-1)
        lz = reduce(logsumexp ** 2, '... -> ', 'mean')
        # residual connection
        assert(experts_out.shape == resid1_out.shape)
        final_out = resid1_out + experts_out
        return final_out, lb, lz
# MoE Implementation
class MoETransformer(nn.Module):
    """Decoder-only MoE transformer.

    ``forward`` returns (logits, lb_avg, lz_avg) where the auxiliary losses
    are averaged over the number of layers.
    """
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float = 1.0,
            device=None, dtype=None):
        super().__init__()
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        self.num_layers = num_layers
        self.layers = nn.ModuleList([
            GroupedMoEPrenormBlock(d_model, num_heads, d_ff, num_experts, num_active,
                                   context_length, rope_theta, width_ratio, device, dtype)
            for _ in range(num_layers)
        ])
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        # untied, muP-scaled unembedding (no weight tying for MoE variants)
        head_std = math.sqrt(2 / (BASE_D_MODEL + vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor):
        lb_total, lz_total = 0, 0
        h = self.token_embeddings(x)
        # each MoE block also reports its aux losses
        for block in self.layers:
            h, lb, lz = block(h)
            lb_total = lb_total + lb
            lz_total = lz_total + lz
        h = self.ln_final(h)
        logits = self.lm_head(h)
        # report per-layer averages of the auxiliary losses
        return logits, lb_total / self.num_layers, lz_total / self.num_layers
class LoopedMoETransformer(nn.Module):
    """Looped MoE transformer: one shared MoE stack applied ``num_stacks`` times.

    ``forward`` returns (logits, lb_avg, lz_avg), with aux losses averaged
    over the effective (unrolled) layer count.
    """
    def __init__(
            self, vocab_size: int,
            context_length: int,
            d_model: int,
            num_layers_in_stack: int,
            num_stacks: int,
            num_heads: int,
            d_ff: int,
            num_experts: int,
            num_active: int,
            rope_theta: float,
            width_ratio: float,
            device=None, dtype=None):
        super().__init__()
        self.stack_depth = num_stacks
        # aux losses are averaged over the unrolled depth, not the stack size
        self.total_layers = num_stacks * num_layers_in_stack
        self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=dtype)
        # one weight-shared MoE stack reused on every loop iteration
        self.stack = LoopedStack(context_length, d_model, num_layers_in_stack, num_heads,
                                 d_ff, rope_theta, width_ratio, mixture_of_experts=True,
                                 num_experts=num_experts, num_active=num_active,
                                 device=device, dtype=dtype)
        self.ln_final = RMSNorm(d_model, device=device, dtype=dtype)
        # untied, muP-scaled unembedding
        head_std = math.sqrt(2 / (BASE_D_MODEL + vocab_size))
        self.lm_head = Linear(d_model, vocab_size, width_ratio=width_ratio, std_base=head_std, device=device, dtype=dtype)

    def forward(self, x: Tensor):
        lb_total, lz_total = 0, 0
        h = self.token_embeddings(x)
        # weight-shared recurrence; each pass reports its aux losses
        for _ in range(self.stack_depth):
            h, lb, lz = self.stack(h)
            lb_total = lb_total + lb
            lz_total = lz_total + lz
        h = self.ln_final(h)
        logits = self.lm_head(h)
        return logits, lb_total / self.total_layers, lz_total / self.total_layers
# ---------------------------------------------------------------------------
# HuggingFace wrapper (from hf_wrapper.py)
# ---------------------------------------------------------------------------
class LoopLMConfig(PretrainedConfig):
    """Config shared by all four loop-lm model variants.

    ``model_variant`` selects the architecture ("base", "looped", "moe",
    "looped-moe"); the remaining fields are consumed only by the variants
    they apply to (see the parameter comments below).
    """
    model_type = "loop-lm"
    def __init__(
        self,
        # which of the four architectures to use
        model_variant: str = "base",  # "base" | "looped" | "moe" | "looped-moe"
        # shared
        vocab_size: int = 50257,
        context_length: int = 1024,
        d_model: int = 1024,
        num_heads: int = 16,
        d_ff: int = 2752,
        rope_theta: float = 10000.0,
        width_ratio: float = 8.0,  # d_model / base_d_model (128); set at training time
        # base + moe only
        num_layers: int = 16,
        # base + looped only
        weight_tying: bool = False,
        # looped + looped-moe only
        num_layers_in_stack: int = 8,
        num_stacks: int = 2,
        # moe + looped-moe only
        num_experts: int = 8,
        num_active: int = 2,
        # aux loss weights — used when forward() is called with labels
        lb_loss_factor: float = 0.01,
        lz_loss_factor: float = 0.001,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # copy every named hyperparameter onto the config verbatim
        args = locals()
        for name in (
            "model_variant", "vocab_size", "context_length", "d_model",
            "num_heads", "d_ff", "rope_theta", "width_ratio", "num_layers",
            "weight_tying", "num_layers_in_stack", "num_stacks",
            "num_experts", "num_active", "lb_loss_factor", "lz_loss_factor",
        ):
            setattr(self, name, args[name])
        # lm-evaluation-harness reads this attribute to cap sequence length
        self.max_length = context_length
class LoopLMForCausalLM(PreTrainedModel, GenerationMixin):
    """Causal LM wrapper over all four looped-scaling variants.
    Implements the HuggingFace PreTrainedModel interface so you can:
    - Upload/download via push_to_hub / from_pretrained
    - Run lm-evaluation-harness evals
    - Fine-tune with TRL's SFTTrainer / DPOTrainer
    """
    config_class = LoopLMConfig
    # no state-dict keys are exempted from the missing-keys check on load
    _keys_to_ignore_on_load_missing = []
    def __init__(self, config: LoopLMConfig):
        super().__init__(config)
        # instantiate the variant selected by config.model_variant
        self.model = self._build_inner_model(config)
        # standard HF post-construction hook (init/tying bookkeeping)
        self.post_init()
    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------
    def _build_inner_model(self, config: LoopLMConfig):
        """Build the inner nn.Module for config.model_variant.

        Raises ValueError for an unknown variant name.
        """
        # kwargs shared by all four architectures
        kw = dict(
            vocab_size=config.vocab_size,
            context_length=config.context_length,
            d_model=config.d_model,
            num_heads=config.num_heads,
            d_ff=config.d_ff,
            rope_theta=config.rope_theta,
            width_ratio=config.width_ratio,
            # device=None so weights are placed on CPU; caller uses .to(device)
        )
        v = config.model_variant
        if v == "base":
            return MuTransformer(
                **kw,
                num_layers=config.num_layers,
                weight_tying=config.weight_tying,
            )
        elif v == "looped":
            return LoopedTransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                weight_tying=config.weight_tying,
            )
        elif v == "moe":
            return MoETransformer(
                **kw,
                num_layers=config.num_layers,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        elif v == "looped-moe":
            return LoopedMoETransformer(
                **kw,
                num_layers_in_stack=config.num_layers_in_stack,
                num_stacks=config.num_stacks,
                num_experts=config.num_experts,
                num_active=config.num_active,
            )
        else:
            raise ValueError(f"Unknown model_variant: {v!r}. Choose from: base, looped, moe, looped-moe")
    # ------------------------------------------------------------------
    # Embedding access (required by some HF utilities)
    # ------------------------------------------------------------------
    def get_input_embeddings(self):
        return self.model.token_embeddings
    def set_input_embeddings(self, value):
        self.model.token_embeddings = value
    # ------------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------------
    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # causal mask is handled internally
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Args:
            input_ids: (batch, seq)
            attention_mask: ignored — models use a built-in causal mask
            labels: (batch, seq) token ids; if provided, returns cross-entropy loss.
                For MoE variants, aux losses (lb + lz) are added to the CE loss.
        """
        # MoE variants return (logits, lb, lz); dense variants return logits only
        is_moe = self.config.model_variant in ("moe", "looped-moe")
        if is_moe:
            logits, lb, lz = self.model(input_ids)
        else:
            logits = self.model(input_ids)
            lb = lz = 0.0
        loss = None
        if labels is not None:
            # NOTE(review): no logits/labels shift is applied here, unlike stock
            # HF causal-LM heads which shift internally — confirm callers pass
            # labels already aligned as training expects. F.cross_entropy's
            # default ignore_index (-100) still applies.
            ce_loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
            )
            # auxiliary (MoE) losses only contribute in training mode
            aux = self.config.lb_loss_factor * lb + self.config.lz_loss_factor * lz
            loss = ce_loss + aux if self.training else ce_loss
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )
    # ------------------------------------------------------------------
    # Generation support (no KV cache — generation is correct but slow)
    # ------------------------------------------------------------------
    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        # no KV cache: the full sequence is re-fed on every decode step
        return {"input_ids": input_ids}