"""Generate MoE layer forward reference for Qwen3-235B layer 0.

Input: hidden_states from attention output (use attn_data/final_out.bin as input — realistic).
Output: hidden_states after MoE + residual.
"""
import json
import math
import os

import torch
import torch_npu
from safetensors.torch import load_file

torch.npu.set_device(0)
torch.set_grad_enabled(False)  # inference-only reference; autograd is never needed

MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/moe_data'
os.makedirs(OUT_DIR, exist_ok=True)

# Model hyperparameters from the HF config shipped with the checkpoint.
with open(os.path.join(MODEL_DIR, 'config.json')) as f:
    cfg = json.load(f)
D = cfg['hidden_size']               # model (hidden) dimension
I = cfg['moe_intermediate_size']     # per-expert MLP intermediate dimension
E = cfg['num_experts']               # total routed experts
TK = cfg['num_experts_per_tok']      # experts activated per token (top-k)
eps = cfg['rms_norm_eps']
norm_topk = cfg.get('norm_topk_prob', True)  # whether top-k router probs are renormalized

# Input activations: the attention block's output dumped by the attention
# reference script, stored as raw bf16 bytes. S must match the sequence
# length used when that dump was produced — TODO confirm against the
# attention script.
with open('tests/attn_data/final_out.bin', 'rb') as f:
    attn_out_raw = f.read()
S = 5
# bytearray() makes the buffer writable (frombuffer warns on read-only bytes);
# int16 view + .view(torch.bfloat16) reinterprets the raw bits, no conversion.
x_in = torch.frombuffer(bytearray(attn_out_raw), dtype=torch.int16).view(1, S, D).view(torch.bfloat16).npu()
print(f"x_in: {x_in.shape}")

# Resolve which safetensors shard holds each needed tensor via the index file.
with open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')) as f:
    idx = json.load(f)
wm = idx['weight_map']

needed = ['model.layers.0.post_attention_layernorm.weight',
          'model.layers.0.mlp.gate.weight']
for e in range(E):
    for p in ['gate_proj', 'up_proj', 'down_proj']:
        needed.append(f'model.layers.0.mlp.experts.{e}.{p}.weight')

# Load only the shards that contain at least one needed tensor; move each
# selected tensor to the NPU. Set membership avoids rescanning `needed` per shard.
needed_set = set(needed)
shards = sorted({wm[n] for n in needed})
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n, tensor in t.items():
        if n in needed_set:
            weights[n] = tensor.to('npu')
print("loaded %d tensors from %d shards" % (len(weights), len(shards)))

# Residual branch: the MoE output is added back onto the attention output.
residual = x_in

# Pre-MoE RMSNorm (post_attention_layernorm) via the fused NPU kernel.
xn, _ = torch_npu.npu_rms_norm(x_in, weights['model.layers.0.post_attention_layernorm.weight'], epsilon=eps)
xn_flat = xn.view(S, D)

# Router: dense linear producing one logit per expert for every token.
W_router = weights['model.layers.0.mlp.gate.weight']
logits = xn_flat @ W_router.t()

# HF Qwen3-MoE routing order: softmax over ALL expert logits (in fp32) first,
# then take the top-k probabilities. Top-k-then-softmax would silently
# renormalize over the selected experts, which only coincides with the HF
# reference when norm_topk_prob is True; softmax-first is correct either way.
probs = torch.softmax(logits.float(), dim=-1)
topk_weights, topk_idx = probs.topk(TK, dim=-1)
if norm_topk:
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
topk_weights = topk_weights.to(torch.bfloat16)
topk_idx = topk_idx.to(torch.int32)

print(f"topk_idx[0]: {topk_idx[0].cpu().tolist()}")
print(f"topk_weights[0]: {topk_weights[0].cpu().float().tolist()}")

# Dense reference MoE: push each token through its top-k experts' SwiGLU MLPs
# and mix the expert outputs with the router weights. Accumulation is done in
# bf16, matching the tensors' native dtype.
silu = torch.nn.functional.silu
out_flat = torch.zeros(S, D, dtype=torch.bfloat16, device='npu')
for s, token in enumerate(xn_flat):
    mixed = torch.zeros(D, dtype=torch.bfloat16, device='npu')
    for expert_id, gate_coef in zip(topk_idx[s].tolist(), topk_weights[s]):
        e = int(expert_id)
        w_gate = weights[f'model.layers.0.mlp.experts.{e}.gate_proj.weight']
        w_up = weights[f'model.layers.0.mlp.experts.{e}.up_proj.weight']
        w_down = weights[f'model.layers.0.mlp.experts.{e}.down_proj.weight']
        hidden = silu(token @ w_gate.t()) * (token @ w_up.t())
        mixed = mixed + gate_coef * (hidden @ w_down.t())
    out_flat[s] = mixed

moe_out = out_flat.view(1, S, D)
final_out = residual + moe_out
print(f"final_out[0,0,:4] = {final_out[0,0,:4].float().cpu().tolist()}")

def dump(name, t, out_dir=None):
    """Write tensor *t* to ``<out_dir>/<name>.bin`` as its raw underlying bytes.

    *t* must have a 2-byte element type (bf16/fp16/int16): the tensor is
    reinterpreted as int16 — a bit-level view, not a numeric conversion — so
    bf16 payloads round-trip exactly. *out_dir* defaults to the module-level
    OUT_DIR, preserving the original two-argument call sites.
    """
    target = OUT_DIR if out_dir is None else out_dir
    p = os.path.join(target, name + '.bin')
    # .numpy() on the int16 view already yields an int16 array, so no extra
    # astype copy is needed before serialization.
    a = t.contiguous().cpu().view(torch.int16).numpy()
    with open(p, 'wb') as f:  # close deterministically instead of leaking the handle
        f.write(a.tobytes())

|
# Persist the reference inputs/outputs for the kernel test harness (raw bytes).
dump('x_in', x_in)
dump('final_out', final_out)
dump('moe_out', moe_out)
dump('router', W_router)
dump('xn', xn)
dump('topk_w', topk_weights)
dump('out_flat', out_flat)

# topk_idx is int32 (cast above); write its buffer directly rather than through
# dump(), which assumes a 2-byte dtype.
with open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'wb') as f:
    f.write(topk_idx.contiguous().cpu().numpy().astype('int32').tobytes())

# Human-readable shape metadata so the consumer does not hard-code dimensions.
with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nI={I}\nE={E}\nTK={TK}\n")

print(f"\nDumps in {OUT_DIR}")
