"""Generate an RMSNorm reference on an NPU using PyTorch and torch_npu.

The script builds a random bfloat16 input ``x`` of shape (N, D) and a scale
vector ``gamma`` of shape (D,), runs ``torch_npu.npu_rms_norm`` on them, and
writes the inputs, the reference output and the problem parameters into
``tests/rms_norm_data`` (relative to the current working directory):

* ``x.bin``, ``gamma.bin``, ``y_ref.bin`` -- raw bfloat16 bit patterns
* ``shape.txt`` -- ``N``, ``D`` and ``eps`` as ``key=value`` lines
"""
import os
import struct
import numpy as np
import torch
import torch_npu


# Pin the device and the RNG seed so repeated runs produce the same tensors.
torch.npu.set_device(0)
torch.manual_seed(123)


# Problem size (N rows of D features) and the epsilon handed to the kernel.
N, D = 5, 4096
eps = 1e-6


# Inputs are created with torch.randn in bfloat16 and then moved to the NPU.
x = torch.randn(N, D, dtype=torch.bfloat16).npu()
# Scale weights are spread narrowly around 1.0.
gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0


# Only the normalized output is kept; the second returned value is unused.
y_ref, _ = torch_npu.npu_rms_norm(x, gamma, epsilon=eps)


out_dir = 'tests/rms_norm_data'
os.makedirs(out_dir, exist_ok=True)


def dump_bf16(name, t):
    """Write a bfloat16 tensor to ``<out_dir>/<name>.bin`` as raw 16-bit words.

    The tensor is made contiguous, copied to the host and reinterpreted as
    int16 (NumPy has no bfloat16 dtype), so the file holds the unmodified
    bfloat16 bit patterns in the host's native byte order.

    Args:
        name: File name without the ``.bin`` extension.
        t: Tensor of dtype ``torch.bfloat16`` to serialize.

    Returns:
        The path of the file that was written.

    Raises:
        TypeError: If ``t`` is not a bfloat16 tensor. Reinterpreting another
            dtype as int16 would silently produce a file with the wrong
            layout.
    """
    if t.dtype != torch.bfloat16:
        raise TypeError(
            f"dump_bf16 expects a torch.bfloat16 tensor, got {t.dtype}"
        )
    path = os.path.join(out_dir, name + '.bin')
    # view(torch.int16) is a bit-level reinterpretation, not a numeric cast.
    raw = t.contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:
        f.write(raw.tobytes())
    return path


dump_bf16('x', x)
dump_bf16('gamma', gamma)
dump_bf16('y_ref', y_ref)


# Record the problem parameters next to the binary dumps.
with open(os.path.join(out_dir, 'shape.txt'), 'w', encoding='utf-8') as f:
    f.write(f"N={N}\nD={D}\neps={eps}\n")


print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
print("y_ref[0, :8]:", y_ref[0, :8].float().cpu().tolist())
print("saved in", out_dir)
|