#!/usr/bin/env python3
"""Generate a RmsNorm reference using PyTorch.

Runs on NPU device 0 and writes into ``tests/rms_norm_data``:
  * ``x.bin``, ``gamma.bin``, ``y_ref.bin`` -- raw bfloat16 bit patterns,
    little-endian, no header;
  * ``shape.txt`` -- ``N``, ``D`` and ``eps`` as ``key=value`` lines.
"""
import os
import struct

import numpy as np
import torch
import torch_npu

torch.npu.set_device(0)
torch.manual_seed(123)

N, D = 5, 4096  # 5 tokens, Qwen3 hidden_size
eps = 1e-6

# Inputs are drawn on the CPU generator (seeded above) and then moved to NPU.
x = torch.randn(N, D, dtype=torch.bfloat16).npu()
gamma = torch.randn(D, dtype=torch.bfloat16).npu() * 0.1 + 1.0


def _manual_rms_norm(inp, weight, epsilon):
    """Return ``inp / sqrt(mean(inp**2, last dim) + epsilon) * weight``.

    Computed in float32 and cast back to ``inp.dtype``.
    NOTE(review): meant as a stand-in for torch_npu.npu_rms_norm; results may
    differ from the fused kernel in the last bf16 ulp -- confirm if bit-exact
    parity matters.
    """
    xf = inp.float()
    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    return (xf * rstd * weight.float()).to(inp.dtype)


# Use torch_npu's npu_rms_norm if this torch_npu build provides it, else do it
# manually. npu_rms_norm returns a tuple; only its first element is kept.
_npu_rms_norm = getattr(torch_npu, "npu_rms_norm", None)
if _npu_rms_norm is not None:
    y_ref, _ = _npu_rms_norm(x, gamma, epsilon=eps)
else:
    y_ref = _manual_rms_norm(x, gamma, eps)

out_dir = 'tests/rms_norm_data'
os.makedirs(out_dir, exist_ok=True)


def dump_bf16(name, t):
    """Write tensor ``t`` to ``<out_dir>/<name>.bin`` as raw 16-bit words.

    The tensor's bits are reinterpreted as int16 (no numeric conversion) and
    written little-endian with no header.

    Args:
        name: file stem, without the ``.bin`` suffix.
        t: tensor with a 2-byte element type (bf16 here); any device.

    Returns:
        The path of the file written.
    """
    path = os.path.join(out_dir, name + '.bin')
    # view(torch.int16) reinterprets the bf16 bit patterns; '<i2' pins the
    # byte order so the dump is identical on any host (no copy on LE hosts).
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('<i2', copy=False)
    with open(path, 'wb') as f:
        f.write(a.tobytes())
    return path


dump_bf16('x', x)
dump_bf16('gamma', gamma)
dump_bf16('y_ref', y_ref)

with open(os.path.join(out_dir, 'shape.txt'), 'w', encoding='utf-8') as f:
    f.write(f"N={N}\nD={D}\neps={eps}\n")

print(f"x shape: {x.shape}, gamma: {gamma.shape}, y_ref: {y_ref.shape}")
print("y_ref[0, :8]:", y_ref[0, :8].float().cpu().tolist())
print("saved in", out_dir)