# llm_mutil_npu/scripts/gen_gmm_reference.py
# Author: xianglarry
# Commit 4b9fefd: "Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 x 16 NPU"
#!/usr/bin/env python3
"""Generate a GMM reference case using torch_npu.npu_grouped_matmul.

Dumps: x, w (unpermuted ggml-like layout), group_list, and y_ref to binary files
for the C++ POC to load and validate against.
"""
import os
import sys      # NOTE(review): unused in this script — candidate for removal
import numpy as np  # NOTE(review): unused here; dumps go through torch .numpy()
import struct   # NOTE(review): unused in this script — candidate for removal
# Enable torch_npu
# NOTE(review): setdefault only guarantees LD_LIBRARY_PATH *exists* (possibly
# empty) before the torch_npu import; it adds no search path. Presumably some
# loader chokes on the variable being absent — confirm it is still needed.
os.environ.setdefault('LD_LIBRARY_PATH', '')
import torch
import torch_npu
torch.npu.set_device(0)  # pin all work to NPU card 0
torch.manual_seed(42)    # fixed seed so the dumped reference is reproducible
# Toy Qwen3-like MoE shape: D=64 hidden, I=32 intermediate, E=8 experts, N*K=16 expanded tokens
# (small enough to eyeball; large enough to catch layout bugs)
D, I, E, TOTAL = 64, 32, 8, 16
# Input x: [TOTAL, D] BF16 — expanded routed tokens
x = torch.randn(TOTAL, D, dtype=torch.bfloat16).npu()
# Weight w: per-expert [I, D] BF16 — gate/up has this shape in HF
# We will stack into [E, I, D] and also provide [E, D, I] permuted for comparison
w_per_expert = [torch.randn(I, D, dtype=torch.bfloat16).npu() for _ in range(E)]
# NOTE(review): the suffix "IDL" is misleading — the stacked layout is
# [E, I, D] (it is dumped below under the tag w_hf_EID). Consider renaming.
w_stacked_IDL = torch.stack(w_per_expert, dim=0) # [E, I, D]
# group_list: counts of tokens per expert, sum = TOTAL
group_list = torch.tensor([3, 2, 1, 2, 1, 3, 2, 2], dtype=torch.int64).npu()
assert group_list.sum().item() == TOTAL  # sanity: counts must cover every token
# Reference: use torch_npu.npu_grouped_matmul
# Per cann-recipes: weight needs to be in [E, D, I] for matmul y = x @ w (y shape [total, I])
# i.e. per-expert w is transposed from HF's [I, D] to [D, I]
# .contiguous() forces a real memory permute — GMM needs dense [E, D, I].
w_transposed = w_stacked_IDL.transpose(1, 2).contiguous() # [E, D, I]
# Call GMM: y = x @ w, result [TOTAL, I]
# NOTE(review): per the in-file comments, group_type=0 groups along the token
# (M) axis, group_list_type=1 means group_list carries per-expert COUNTS (not
# cumulative offsets), and split_item=3 means single fused input / single
# fused output tensor. These enum meanings are CANN-version sensitive —
# confirm against the installed torch_npu release.
y_ref = torch_npu.npu_grouped_matmul(
    [x], # x list
    [w_transposed], # weight list (transposed)
    group_list=group_list,
    group_type=0,
    group_list_type=1, # counts
    split_item=3 # single-in single-out
)[0] # unwrap tensor list
# Sanity prints so a human can eyeball shapes/values before trusting the dumps.
print("x shape:", x.shape, x.dtype)
print("w_stacked_IDL shape:", w_stacked_IDL.shape, w_stacked_IDL.dtype)
print("w_transposed shape:", w_transposed.shape)
print("group_list:", group_list.cpu().tolist())
print("y_ref shape:", y_ref.shape)
print("y_ref[0, 0:4]:", y_ref[0, 0:4].cpu().float().tolist())
# Save binary dumps
out_dir = 'tests/poc_data'  # relative to CWD — run from the repo root
os.makedirs(out_dir, exist_ok=True)
def dump_bf16(name, tensor, out_dir='tests/poc_data'):
    """Write *tensor*'s raw BF16 bit pattern to ``<out_dir>/<name>.bin``.

    The BF16 payload is reinterpreted (bit-for-bit, no value conversion) as
    int16 via ``Tensor.view(torch.int16)``, so the C++ POC can read the file
    straight back as bf16 words.

    Args:
        name: basename of the output file (".bin" is appended).
        tensor: BF16 torch tensor on any device; moved to CPU and made
            contiguous before dumping.
        out_dir: destination directory (must already exist). Defaults to the
            script's dump directory so existing two-argument calls behave
            exactly as before; passing it explicitly makes the function
            testable without the module-level global.
    """
    path = os.path.join(out_dir, name + '.bin')
    # view(torch.int16) already yields an int16 ndarray after .numpy();
    # the original trailing .astype('int16') was a no-op copy and is dropped.
    arr = tensor.contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f" wrote {name}.bin: {arr.shape} int16 = BF16 raw, {arr.nbytes} bytes")
def dump_int64(name, tensor, out_dir='tests/poc_data'):
    """Write *tensor* as raw int64 words to ``<out_dir>/<name>.bin``.

    Args:
        name: basename of the output file (".bin" is appended).
        tensor: integer torch tensor on any device; moved to CPU, made
            contiguous, and widened to int64 before dumping so the file
            layout is fixed regardless of the input dtype.
        out_dir: destination directory (must already exist). Defaults to the
            script's dump directory so existing two-argument calls behave
            exactly as before; passing it explicitly makes the function
            testable without the module-level global.
    """
    path = os.path.join(out_dir, name + '.bin')
    arr = tensor.contiguous().cpu().numpy().astype('int64')
    with open(path, 'wb') as f:
        f.write(arr.tobytes())
    print(f" wrote {name}.bin: {arr.shape} int64, {arr.nbytes} bytes")
# HF-style weight layout (ggml stores similar): [E, I, D] = what C++ gets from safetensors after stack
dump_bf16('x', x)
dump_bf16('w_hf_EID', w_stacked_IDL) # C++ input weight (HF layout)
dump_bf16('w_ref_EDI', w_transposed) # Already-permuted reference (for debug)
dump_int64('group_list', group_list)
dump_bf16('y_ref', y_ref)            # golden output the C++ POC must match
# Also dump shapes header so the C++ side does not hard-code dimensions
with open(os.path.join(out_dir, 'shapes.txt'), 'w') as f:
    f.write(f"D={D}\nI={I}\nE={E}\nTOTAL={TOTAL}\n")
print("\nAll dumps in:", out_dir)
print("\nTo validate: C++ loads w_hf_EID, permutes [0,2,1] to [E,D,I], NZ-casts, calls GMMV4, "
      "compares output to y_ref.")