llm_mutil_npu / scripts /gen_mm_reference.py
xianglarry's picture
Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
4b9fefd
#!/usr/bin/env python3
"""Generate a linear (y = x @ W.T) reference for a realistic Qwen3 attention shape."""
import os, struct, torch, torch_npu
# Bind to NPU 0 and pin the RNG seed so repeated runs sample the same tensors.
torch.npu.set_device(0)
torch.manual_seed(7)

# Problem size: prompt length, hidden width, q-projection width.
N = 5
D = 4096
OUT = 8192

# The activation is sampled before the weight; with a fixed seed the order of
# these two calls determines the values, so it must not be swapped.
x = torch.randn((N, D), dtype=torch.bfloat16).npu()
W = torch.randn((OUT, D), dtype=torch.bfloat16).npu()  # HF layout [out, in]

# Reference projection y = x @ W.T: [N, D] x [D, OUT] -> [N, OUT].
y_ref = x @ W.t()

# Destination for the dumped tensors; created on demand.
out_dir = 'tests/mm_data'
os.makedirs(out_dir, exist_ok=True)
def dump(name, t, directory=None):
    """Write tensor ``t`` to ``<directory>/<name>.bin`` as headerless 16-bit words.

    The tensor is made contiguous, copied to host memory and reinterpreted as
    int16 before serialisation. NumPy has no bfloat16 dtype, so the bit
    pattern of each 2-byte element is carried through an int16 view. Elements
    are written in row-major order using the host's native byte order.

    Args:
        name: File stem; ``.bin`` is appended.
        t: Tensor to serialise. The int16 reinterpretation expects a 2-byte
            element type such as bfloat16.
        directory: Destination directory. When omitted, the module-level
            ``out_dir`` is used, which keeps existing two-argument calls
            working unchanged.
    """
    target_dir = out_dir if directory is None else directory
    path = os.path.join(target_dir, name + '.bin')
    # The view is already int16, so no further dtype conversion is required.
    words = t.contiguous().cpu().view(torch.int16).numpy()
    # Context manager guarantees the handle is flushed and closed on exit.
    with open(path, 'wb') as fh:
        fh.write(words.tobytes())
# Persist both operands and the expected product as raw .bin files.
dump('x', x)
dump('W', W)
dump('y_ref', y_ref)

# Record the dimensions alongside the tensors as KEY=VALUE lines.
shape_path = os.path.join(out_dir, 'shape.txt')
with open(shape_path, 'w') as shape_file:
    shape_file.write(f"N={N}\nD={D}\nOUT={OUT}\n")

# Echo the sizes and the first four reference values for a quick visual check.
print(f"N={N} D={D} OUT={OUT}, y_ref[0, :4] = {y_ref[0, :4].float().cpu().tolist()}")