#!/usr/bin/env python3
"""Generate MoE layer forward reference for Qwen3-235B layer 0.
Input: hidden_states from the attention output (tests/attn_data/final_out.bin), so the
reference runs on realistic activations rather than random data.
Output: hidden_states after MoE + residual. All .bin dumps are raw BF16 bytes.
"""
import os, json, torch, torch_npu
from safetensors.torch import load_file
torch.npu.set_device(0)
torch.set_grad_enabled(False)
MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/moe_data'
os.makedirs(OUT_DIR, exist_ok=True)
cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size'] # 4096
I = cfg['moe_intermediate_size'] # 1536
E = cfg['num_experts'] # 128
TK = cfg['num_experts_per_tok'] # 8
eps = cfg['rms_norm_eps']
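# norm_topk_prob: renormalize the top-k routing weights to sum to 1 (set in Qwen3-MoE configs)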
norm_topk = cfg.get('norm_topk_prob', True)
# Use attention output as input (more realistic than random)
attn_out_raw = open('tests/attn_data/final_out.bin', 'rb').read()
S = 5  # token count of the attention dump (final_out.bin is [1, S, D] BF16)
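# The dump stores raw BF16 bytes: read as int16 and bitcast (view) to bfloat16,
# so no numeric conversion takes place; this is the same convention dump() uses below.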
x_in = torch.frombuffer(bytearray(attn_out_raw), dtype=torch.int16).view(1, S, D).view(torch.bfloat16).npu()
print(f"x_in: {x_in.shape}")
# Load required weights for layer 0
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']
needed = ['model.layers.0.post_attention_layernorm.weight',
          'model.layers.0.mlp.gate.weight']
for e in range(E):
    for p in ['gate_proj', 'up_proj', 'down_proj']:
        needed.append(f'model.layers.0.mlp.experts.{e}.{p}.weight')
shards = sorted({wm[n] for n in needed})
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded %d tensors from %d shards" % (len(weights), len(shards)))
# Residual = input
residual = x_in
# Post-attention RmsNorm
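# npu_rms_norm returns a pair (normalized output, rstd); rstd is discarded here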
xn, _ = torch_npu.npu_rms_norm(x_in, weights['model.layers.0.post_attention_layernorm.weight'], epsilon=eps)
xn_flat = xn.view(S, D) # flatten batch
# Router: logits [S, E]
W_router = weights['model.layers.0.mlp.gate.weight'] # [E, D]
logits = xn_flat @ W_router.t() # [S, E]
# Top-k softmax
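# Note: softmax over only the top-k logits is mathematically equivalent to the
# usual reference order (full softmax, then top-k, then renormalize), because
# renormalizing a restricted softmax reproduces softmax over the selected logits.
# The norm_topk renorm below is therefore nearly a no-op, kept to mirror the flag.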
topk_logits, topk_idx = logits.topk(TK, dim=-1) # both [S, TK]
topk_weights = torch.softmax(topk_logits.float(), dim=-1) # [S, TK] F32
if norm_topk:
    topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-20)
topk_weights = topk_weights.to(torch.bfloat16)
topk_idx = topk_idx.to(torch.int32)
print(f"topk_idx[0]: {topk_idx[0].cpu().tolist()}")
print(f"topk_weights[0]: {topk_weights[0].cpu().float().tolist()}")
# MoE forward — loop over tokens (simple reference, not optimized)
out_flat = torch.zeros(S, D, dtype=torch.bfloat16, device='npu')
for s in range(S):
    token = xn_flat[s]  # [D]
    acc = torch.zeros(D, dtype=torch.bfloat16, device='npu')
    for k in range(TK):
        e = int(topk_idx[s, k].item())
        w = topk_weights[s, k]
        Wg = weights[f'model.layers.0.mlp.experts.{e}.gate_proj.weight']  # [I, D]
        Wu = weights[f'model.layers.0.mlp.experts.{e}.up_proj.weight']    # [I, D]
        Wd = weights[f'model.layers.0.mlp.experts.{e}.down_proj.weight']  # [D, I]
        gate = token @ Wg.t()  # [I]
        up = token @ Wu.t()    # [I]
        act = torch.nn.functional.silu(gate) * up  # SwiGLU: silu(gate) * up
        down = act @ Wd.t()  # [D]
        acc = acc + w * down  # routing-weighted sum over the selected experts
    out_flat[s] = acc
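# Optional sanity check: the reference output should be finite
assert torch.isfinite(out_flat.cpu().float()).all(), 'MoE reference output contains NaN/Inf'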
moe_out = out_flat.view(1, S, D)
final_out = residual + moe_out
print(f"final_out[0,0,:4] = {final_out[0,0,:4].float().cpu().tolist()}")
# Dump
def dump(name, t):
    """Write tensor to OUT_DIR as raw BF16 bytes (bitcast through int16)."""
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy()
    with open(p, 'wb') as f:
        f.write(a.tobytes())
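# Loading a dump back (hypothetical consumer-side sketch, assuming numpy as np):
#   a = np.fromfile(os.path.join(OUT_DIR, 'x_in.bin'), dtype=np.int16)
#   x = torch.from_numpy(a).view(torch.bfloat16).reshape(1, S, D)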
dump('x_in', x_in)
dump('final_out', final_out)
dump('moe_out', moe_out)
dump('router', W_router)
dump('xn', xn)
dump('topk_w', topk_weights) # [S, TK] BF16 (normalized)
dump('out_flat', out_flat) # [S, D] BF16 — moe contrib before residual
# expert_idx as int32 dump (raw bytes)
topk_idx_bytes = topk_idx.contiguous().cpu().numpy().astype('int32').tobytes()
open(os.path.join(OUT_DIR, 'topk_idx.bin'), 'wb').write(topk_idx_bytes)
with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nI={I}\nE={E}\nTK={TK}\n")
print(f"\nDumps in {OUT_DIR}")