llm_mutil_npu / scripts /gen_rms_norm_reference.py
xianglarry's picture
Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
4b9fefd
#!/usr/bin/env python3
"""Generate a RmsNorm reference using PyTorch."""
import os, struct
import numpy as np
import torch
import torch_npu
# Select NPU device 0 and fix the RNG seed used by the torch.randn calls below.
torch.npu.set_device(0)
torch.manual_seed(123)

N, D = 5, 4096  # 5 tokens, Qwen3 hidden_size
eps = 1e-6

# Inputs are created as bfloat16 and then moved to the NPU.
x = torch.randn(N, D, dtype=torch.bfloat16).npu()
# gamma is scaled by 0.1 and shifted by 1.0, i.e. it is spread around 1.0.
gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0

# Use torch_npu's npu_rms_norm if available, else do it manually
_npu_rms_norm = getattr(torch_npu, 'npu_rms_norm', None)
if _npu_rms_norm is not None:
    # The op returns a pair; only the first element (the output) is kept.
    y_ref, _ = _npu_rms_norm(x, gamma, epsilon=eps)
else:
    # Manual fallback: y = x * rsqrt(mean(x^2, last dim) + eps) * gamma,
    # accumulated in float32 and cast back to the input dtype.
    _x_f32 = x.float()
    _inv_rms = torch.rsqrt(_x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    y_ref = (_x_f32 * _inv_rms * gamma.float()).to(x.dtype)

# All generated files go here (relative to the current working directory).
out_dir = 'tests/rms_norm_data'
os.makedirs(out_dir, exist_ok=True)
def dump_bf16(name, t):
    """Write a bfloat16 tensor to ``<out_dir>/<name>.bin`` as raw 16-bit words.

    The tensor is made contiguous, copied to the CPU and its storage is
    reinterpreted as int16 so it can pass through numpy; the resulting bytes
    are written without any header.

    Args:
        name: File stem; ``.bin`` is appended.
        t: Tensor with dtype ``torch.bfloat16``.

    Returns:
        The path of the file that was written.

    Raises:
        TypeError: If ``t`` is not a bfloat16 tensor. Reinterpreting any other
            dtype as int16 would silently write data of the wrong size.
    """
    if t.dtype != torch.bfloat16:
        raise TypeError(
            f"dump_bf16 expects a torch.bfloat16 tensor, got {t.dtype}"
        )
    path = os.path.join(out_dir, name + '.bin')
    # view(torch.int16) reinterprets the bits; it does not convert values.
    raw = t.detach().contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        f.write(raw.tobytes())
    return path
# Write the two inputs and the reference output, one .bin file each.
for _stem, _tensor in (('x', x), ('gamma', gamma), ('y_ref', y_ref)):
    dump_bf16(_stem, _tensor)

# Store the problem size and epsilon next to the binaries as key=value lines.
_shape_path = os.path.join(out_dir, 'shape.txt')
with open(_shape_path, 'w') as f:
    f.write(f"N={N}\nD={D}\neps={eps}\n")

# Short console summary: tensor shapes plus the first 8 values of row 0.
print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
_head = y_ref[0, :8].float().cpu().tolist()
print("y_ref[0, :8]:", _head)
print("saved in", out_dir)