File size: 1,109 Bytes
4b9fefd | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | #!/usr/bin/env python3
"""Generate a RmsNorm reference using PyTorch."""
import os, struct
import numpy as np
import torch
import torch_npu
# Pin NPU 0 and seed the RNG. The inputs below are drawn on the CPU (no device
# argument) and then moved to the NPU, so the seed makes them reproducible.
torch.npu.set_device(0)
torch.manual_seed(123)

N, D = 5, 4096  # 5 tokens, Qwen3 hidden_size
eps = 1e-6

x = torch.randn(N, D, dtype=torch.bfloat16).npu()
# gamma is centred on 1.0 with a small spread; the arithmetic stays in bf16.
gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0

# Use torch_npu's npu_rms_norm when this torch_npu build provides it;
# otherwise compute RMSNorm manually: y = x * rsqrt(mean(x^2) + eps) * gamma.
_npu_rms_norm = getattr(torch_npu, "npu_rms_norm", None)
if _npu_rms_norm is not None:
    # The op returns a pair; only the first element (the normalized output)
    # is used as the reference.
    y_ref, _ = _npu_rms_norm(x, gamma, epsilon=eps)
else:
    # Accumulate in float32 for accuracy, then cast back to bf16 so y_ref has
    # the same 2-byte element type as x and gamma when dumped below.
    x_f32 = x.float()
    rstd = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    y_ref = (x_f32 * rstd * gamma.float()).to(torch.bfloat16)

# Destination for the binary fixtures, relative to the current working
# directory; created if missing (no error if it already exists).
out_dir = 'tests/rms_norm_data'
os.makedirs(out_dir, exist_ok=True)
def dump_bf16(name, t, directory=None):
    """Write a bfloat16 tensor's raw bit patterns to ``<directory>/<name>.bin``.

    The tensor is made contiguous (row-major), copied to the host and
    reinterpreted as int16, which has the same 2-byte width as bfloat16. The
    file therefore holds exactly ``t.numel() * 2`` bytes of raw bf16 words in
    native byte order.

    Args:
        name: File stem; ``.bin`` is appended.
        t: Tensor to dump. Must have dtype ``torch.bfloat16``.
        directory: Destination directory. Defaults to the module-level
            ``out_dir`` when omitted.

    Returns:
        The path of the written file.

    Raises:
        ValueError: If ``t`` is not bfloat16. Reinterpreting another dtype as
            int16 would either fail obscurely or silently write a buffer of
            the wrong size/format.
    """
    if t.dtype != torch.bfloat16:
        raise ValueError(
            f"dump_bf16 expects a torch.bfloat16 tensor, got {t.dtype}"
        )
    if directory is None:
        directory = out_dir
    path = os.path.join(directory, name + '.bin')
    # view(torch.int16) is a bit-level reinterpretation (no value conversion);
    # numpy has no bfloat16 dtype, so the words go through int16 instead.
    bits = t.contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        f.write(bits.tobytes())
    return path
# Serialize the inputs and the reference output as raw bf16 words, in the
# order x, gamma, y_ref.
for stem, tensor in (('x', x), ('gamma', gamma), ('y_ref', y_ref)):
    dump_bf16(stem, tensor)

# Record the problem size and epsilon next to the binaries as key=value lines.
shape_path = os.path.join(out_dir, 'shape.txt')
with open(shape_path, 'w') as meta:
    meta.write(f"N={N}\nD={D}\neps={eps}\n")

# Console summary for a quick sanity check of the generated data.
print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
first_row_head = y_ref[0, :8].float().cpu().tolist()
print("y_ref[0, :8]:", first_row_head)
print("saved in", out_dir)
|