| import torch |
| import torch.nn.functional as F |
| from collections import namedtuple |
|
|
| from kernels.benchmark import Benchmark |
|
|
|
|
def moe_mlp_reference(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    gate_up_proj: torch.Tensor,
    gate_up_proj_bias: torch.Tensor,
    down_proj: torch.Tensor,
    down_proj_bias: torch.Tensor,
    top_k: int = 4,
    alpha: float = 1.702,
    limit: float = 7.0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for a top-k routed MoE MLP.

    Each token is routed to ``top_k`` experts via a linear router; per-token
    routing weights are the softmax over the selected router logits. Every
    hit expert applies an interleaved gate/up projection, a clamped
    SiLU-style gating (``gate * sigmoid(alpha * gate)`` with a ``(up + 1)``
    linear branch), then a down projection. Expert outputs are combined as a
    routing-weight-scaled sum.

    Returns:
        A tuple ``(output, routing_weights)`` where ``output`` has the same
        shape as ``x`` and ``routing_weights`` has shape
        ``(num_tokens, top_k)``.
    """
    original_shape = x.shape
    n_experts = router_weight.shape[0]
    d_model = x.shape[-1]

    tokens = x.view(-1, d_model)

    # Route: pick the top-k experts per token, softmax over their logits.
    router_logits = F.linear(tokens, router_weight, router_bias)
    selected_logits, router_indices = torch.topk(router_logits, top_k, dim=-1)
    routing_weights = F.softmax(selected_logits, dim=-1)

    combined = torch.zeros_like(tokens)

    # Boolean-ish mask of shape (n_experts, top_k, num_tokens) telling which
    # token occupies which top-k slot of which expert; routing bookkeeping
    # carries no gradient.
    with torch.no_grad():
        expert_mask = F.one_hot(router_indices, num_classes=n_experts).permute(2, 1, 0)
        hit_experts = (expert_mask.sum(dim=(-1, -2)) > 0).nonzero()

    for hit in hit_experts:
        e = hit[0]
        with torch.no_grad():
            slot_idx, tok_idx = torch.where(expert_mask[e])

        expert_in = tokens[tok_idx]

        # Fused projection; gate and up channels are interleaved along the
        # last axis (even indices = gate, odd indices = up).
        fused = expert_in @ gate_up_proj[e] + gate_up_proj_bias[e]
        gate, up = fused[..., ::2], fused[..., 1::2]

        # Clamp for numerical stability: gate only from above, up both ways.
        gate = gate.clamp(min=None, max=limit)
        up = up.clamp(min=-limit, max=limit)

        # SiLU-style gating with slope alpha; the linear branch is shifted
        # up by one.
        activated = (up + 1) * (gate * torch.sigmoid(gate * alpha))

        expert_out = activated @ down_proj[e] + down_proj_bias[e]

        # Scale each token's expert output by its routing weight for this
        # expert and scatter-accumulate back into the combined output.
        scale = routing_weights[tok_idx, slot_idx]
        combined.index_add_(0, tok_idx, (expert_out * scale[:, None]).to(tokens.dtype))

    return combined.view(original_shape), routing_weights
|
|
|
|
class MegaBlocksMoeBenchmark(Benchmark):
    """Benchmark for a MegaBlocks-style MoE MLP kernel.

    Both workloads share the same model dimensions (128 experts, hidden
    size 1152, intermediate size 3072, top-k 4); they differ only in the
    input batch/sequence size and in whether the expert container carries a
    ``capacity_factor`` field. Kernel outputs are verified against the
    pure-PyTorch :func:`moe_mlp_reference`.
    """

    seed: int = 42

    def _setup_common(self, batch, seq, capacity_factor=None):
        """Build weights, input, model, and output buffer for one workload.

        The input tensor is laid out as ``(seq, batch, hidden)``. When
        ``capacity_factor`` is not None, the experts namedtuple carries an
        extra ``capacity_factor`` field (used by the large workload).

        NOTE: the sequence of random-number-consuming calls (randn,
        kaiming_uniform_, Linear construction) is kept in the original
        order so the generated tensors are unchanged for a given seed.
        """
        ne, hs, isz = 128, 1152, 3072

        # Router parameters; the initial randn values are immediately
        # re-initialized in place by kaiming_uniform_ (randn is retained so
        # the RNG stream stays identical to the original setup).
        self.router_weight = torch.randn(
            ne, hs, device=self.device, dtype=torch.float32
        )
        torch.nn.init.kaiming_uniform_(self.router_weight)
        self.router_bias = torch.zeros(ne, device=self.device, dtype=torch.float32)

        # Expert weights. Gate and up channels are interleaved along the
        # last axis of gate_up_proj, so down_proj consumes isz // 2 features.
        self.gate_up_proj = (
            torch.randn(ne, hs, isz, device=self.device, dtype=torch.float32) * 0.02
        )
        self.gate_up_proj_bias = torch.zeros(
            ne, isz, device=self.device, dtype=torch.float32
        )
        self.down_proj = (
            torch.randn(ne, isz // 2, hs, device=self.device, dtype=torch.float32)
            * 0.02
        )
        self.down_proj_bias = torch.zeros(
            ne, hs, device=self.device, dtype=torch.float32
        )

        # Input activations, scaled down to keep magnitudes small.
        self.x = (
            torch.randn(seq, batch, hs, device=self.device, dtype=torch.float32) * 0.1
        )

        # Kernel under test; its router mirrors the reference tensors.
        self.model = self.kernel.layers.MegaBlocksMoeMLP()
        self.model.router = torch.nn.Linear(hs, ne, device=self.device)
        self.model.router.weight.data = self.router_weight.clone()
        self.model.router.bias.data = self.router_bias.clone()

        # Expert container: a namedtuple stands in for a module holding the
        # expert parameters; the field set is extended only when a capacity
        # factor is requested.
        fields = [
            "gate_up_proj",
            "gate_up_proj_bias",
            "down_proj",
            "down_proj_bias",
            "hidden_size",
            "num_experts",
        ]
        values = {
            "gate_up_proj": torch.nn.Parameter(self.gate_up_proj.clone()),
            "gate_up_proj_bias": torch.nn.Parameter(self.gate_up_proj_bias.clone()),
            "down_proj": torch.nn.Parameter(self.down_proj.clone()),
            "down_proj_bias": torch.nn.Parameter(self.down_proj_bias.clone()),
            "hidden_size": hs,
            "num_experts": ne,
        }
        if capacity_factor is not None:
            fields.append("capacity_factor")
            values["capacity_factor"] = capacity_factor
        Experts = namedtuple("Experts", fields)
        self.model.experts = Experts(**values)

        # Preallocated output buffer matching the input layout.
        self.out = torch.empty(seq, batch, hs, device=self.device, dtype=torch.float32)

    def _reference(self) -> torch.Tensor:
        """Run the pure-PyTorch reference on the current workload tensors."""
        ref_out, _ = moe_mlp_reference(
            self.x,
            self.router_weight,
            self.router_bias,
            self.gate_up_proj,
            self.gate_up_proj_bias,
            self.down_proj,
            self.down_proj_bias,
            top_k=4,
        )
        return ref_out

    def setup(self):
        # Small workload: 8 tokens total (batch=8, seq=1).
        self._setup_common(batch=8, seq=1)

    def benchmark_base(self):
        self.out, self.expert_weights = self.model(self.x)

    def verify_base(self) -> torch.Tensor:
        return self._reference()

    def setup_large(self):
        # Large workload: 512 tokens (batch=32, seq=16) with a capacity
        # factor on the expert container.
        self._setup_common(batch=32, seq=16, capacity_factor=4.0)

    def benchmark_large(self):
        self.out, self.expert_weights = self.model(self.x)

    def verify_large(self) -> torch.Tensor:
        return self._reference()
|
|