File size: 1,109 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#!/usr/bin/env python3
"""Generate a RmsNorm reference using PyTorch."""
import os, struct
import numpy as np
import torch
import torch_npu

torch.npu.set_device(0)
# Seed the global RNG so the CPU-generated inputs below are reproducible.
torch.manual_seed(123)

N, D = 5, 4096  # 5 tokens, Qwen3 hidden_size
eps = 1e-6

# Inputs are drawn on the CPU (no device argument) and then moved to the NPU.
x = torch.randn(N, D, dtype=torch.bfloat16).npu()
# Weights centred on 1.0 with a small random perturbation (std 0.1).
gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0

# Prefer torch_npu's fused kernel; if this torch_npu build does not expose it,
# fall back to a manual RMSNorm: y = x * rsqrt(mean(x^2, last dim) + eps) * gamma.
if hasattr(torch_npu, 'npu_rms_norm'):
    # The op returns a pair; only the normalized output is kept (the second
    # element is presumably the reciprocal RMS — TODO confirm; unused here).
    y_ref, _ = torch_npu.npu_rms_norm(x, gamma, epsilon=eps)
else:
    # Accumulate in float32 and round to bf16 once at the end.
    # NOTE(review): may differ from the fused kernel in the last bf16 ulp.
    x_f32 = x.float()
    inv_rms = torch.rsqrt(x_f32.pow(2).mean(dim=-1, keepdim=True) + eps)
    y_ref = (x_f32 * inv_rms * gamma.float()).to(torch.bfloat16)

out_dir = 'tests/rms_norm_data'
os.makedirs(out_dir, exist_ok=True)

def dump_bf16(name, t, directory=None):
    """Write a bfloat16 tensor to ``<directory>/<name>.bin`` as raw bit patterns.

    The tensor is detached, made contiguous and copied to the host. Its 16-bit
    bfloat16 encodings are then written in row-major order as little-endian
    int16 values, with no header.

    Args:
        name: File stem; ``.bin`` is appended.
        t: Tensor to dump. Must have dtype ``torch.bfloat16``.
        directory: Output directory. Defaults to the module-level ``out_dir``.

    Returns:
        The path of the written file.

    Raises:
        TypeError: If ``t`` is not a bfloat16 tensor. Reinterpreting any other
            dtype as int16 would silently write the wrong bytes.
    """
    if t.dtype != torch.bfloat16:
        raise TypeError(f"dump_bf16 expects a torch.bfloat16 tensor, got {t.dtype}")
    if directory is None:
        directory = out_dir
    path = os.path.join(directory, name + '.bin')
    # Reinterpret (not convert) the bf16 bits as int16; numpy has no bfloat16.
    bits = t.detach().contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        # '<i2' pins little-endian so the file layout does not depend on the host.
        f.write(bits.astype('<i2', copy=False).tobytes())
    return path

# Persist the inputs and the reference output as raw bf16 bit patterns.
for _name, _tensor in (('x', x), ('gamma', gamma), ('y_ref', y_ref)):
    dump_bf16(_name, _tensor)

# Record the problem dimensions and epsilon alongside the binary files.
with open(os.path.join(out_dir, 'shape.txt'), 'w') as meta:
    meta.writelines([f"N={N}\n", f"D={D}\n", f"eps={eps}\n"])

# Short human-readable summary of what was generated.
print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
first_row_head = y_ref[0, :8].float().cpu().tolist()
print("y_ref[0, :8]:", first_row_head)
print("saved in", out_dir)