#!/usr/bin/env python3
"""Generate MoE layer forward reference for Qwen3-235B layer 0.

Input:  hidden_states from the attention output (tests/attn_data/final_out.bin,
        so the reference runs on realistic activations rather than random data).
Output: hidden_states after MoE + residual.
"""
import os, json, torch, torch_npu
from safetensors.torch import load_file

torch.npu.set_device(0)
torch.set_grad_enabled(False)

MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/moe_data'
os.makedirs(OUT_DIR, exist_ok=True)

cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size']             # 4096
I = cfg['moe_intermediate_size']   # 1536
E = cfg['num_experts']             # 128
TK = cfg['num_experts_per_tok']    # 8
eps = cfg['rms_norm_eps']
norm_topk = cfg.get('norm_topk_prob', True)

# Use attention output as input (more realistic than random)
attn_out_raw = open('tests/attn_data/final_out.bin', 'rb').read()
S = 5  # sequence length of the attention reference dump
x_in = torch.frombuffer(bytearray(attn_out_raw), dtype=torch.int16).view(1, S, D).view(torch.bfloat16).npu()
print(f"x_in: {x_in.shape}")

# Load required weights for layer 0
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']
needed = ['model.layers.0.post_attention_layernorm.weight',
          'model.layers.0.mlp.gate.weight']
for e in range(E):
    for p in ['gate_proj', 'up_proj', 'down_proj']:
        needed.append(f'model.layers.0.mlp.experts.{e}.{p}.weight')
shards = sorted({wm[n] for n in needed})
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded %d tensors from %d shards" % (len(weights), len(shards)))

# Residual = input
residual = x_in

# Post-attention RMSNorm
xn, _ = torch_npu.npu_rms_norm(x_in, weights['model.layers.0.post_attention_layernorm.weight'], epsilon=eps)
xn_flat = xn.view(S, D)  # flatten batch

# Router: logits [S, E]
W_router = weights['model.layers.0.mlp.gate.weight']  # [E, D]
logits = xn_flat @ W_router.t()  # [S, E]

# Routing: softmax over all experts, then top-k, then renormalize (HF Qwen3 MoE
# order). With norm_topk_prob=True this is mathematically identical to
# softmaxing the top-k logits directly; without the renorm the two orders differ.
probs = torch.softmax(logits.float(), dim=-1)    # [S, E] F32
topk_weights, topk_idx = probs.topk(TK, dim=-1)  # both [S, TK]
if norm_topk:
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
topk_weights = topk_weights.to(torch.bfloat16)
topk_idx = topk_idx.to(torch.int32)
print(f"topk_idx[0]: {topk_idx[0].cpu().tolist()}")
print(f"topk_weights[0]: {topk_weights[0].cpu().float().tolist()}")

# MoE forward: loop over tokens (simple reference, not optimized)
out_flat = torch.zeros(S, D, dtype=torch.bfloat16, device='npu')
for s in range(S):
    token = xn_flat[s]  # [D]
    acc = torch.zeros(D, dtype=torch.bfloat16, device='npu')
    for k in range(TK):
        e = int(topk_idx[s, k].item())
        w = topk_weights[s, k]
        Wg = weights[f'model.layers.0.mlp.experts.{e}.gate_proj.weight']  # [I, D]
        Wu = weights[f'model.layers.0.mlp.experts.{e}.up_proj.weight']    # [I, D]
        Wd = weights[f'model.layers.0.mlp.experts.{e}.down_proj.weight']  # [D, I]
        # SwiGLU expert: silu(x @ Wg^T) * (x @ Wu^T), then down-projection
        gate = token @ Wg.t()  # [I]
        up = token @ Wu.t()    # [I]
        act = torch.nn.functional.silu(gate) * up
        down = act @ Wd.t()    # [D]
        acc = acc + w * down
    out_flat[s] = acc

moe_out = out_flat.view(1, S, D)
final_out = residual + moe_out
print(f"final_out[0,0,:4] = {final_out[0,0,:4].float().cpu().tolist()}")

# Dump helper: BF16 tensors are written as their raw int16 bytes
def dump(name, t):
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    open(p, 'wb').write(a.tobytes())

dump('x_in', x_in)
dump('final_out', final_out)
dump('moe_out', moe_out)
dump('router', W_router)
dump('xn', xn)
dump('topk_w', topk_weights)  # [S, TK] BF16 (normalized)
dump('out_flat', out_flat)    # [S, D] BF16, MoE contribution before residual

# expert indices as raw int32 bytes
topk_idx_bytes = topk_idx.contiguous().cpu().numpy().astype('int32').tobytes()
open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'wb').write(topk_idx_bytes)

with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nI={I}\nE={E}\nTK={TK}\n")

print(f"\nDumps in {OUT_DIR}")
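
# --- Hedged usage sketch (illustrative only, not executed by the generator) ---
# A consumer test can map a BF16 dump back to a torch tensor by reading the raw
# int16 payload and reinterpreting the bits. `load_bf16` is a hypothetical
# helper, not an established API; the shapes come from shape.txt above.
def load_bf16(path, *shape):
    raw = bytearray(open(path, 'rb').read())
    return torch.frombuffer(raw, dtype=torch.int16).view(*shape).view(torch.bfloat16)

# Example round-trip checks, left commented so the generator's behavior is unchanged:
# ref = load_bf16(os.path.join(OUT_DIR, 'final_out.bin'), 1, S, D)
# assert torch.equal(ref, final_out.cpu())
# idx = torch.frombuffer(bytearray(open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'rb').read()),
#                        dtype=torch.int32).view(S, TK)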