llm_mutil_npu / scripts /regen_rope_reference.py
xianglarry's picture
Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
4b9fefd
#!/usr/bin/env python3
"""Re-generate RoPE reference using explicit HF formula (not torch_npu.npu_apply_rotary_pos_emb)."""
import os, math, torch, torch_npu
# Pin all work to NPU 0 and seed torch's RNG.
torch.npu.set_device(0)
torch.manual_seed(42)

# Geometry of the dumped attention tensors (see the load_bf16 shapes below):
#   S   - number of token positions
#   Hq  - number of query heads
#   Hkv - number of key heads
#   Dh  - per-head dimension
S, Hq, Hkv, Dh = 5, 64, 4, 128
theta = 5e6                # RoPE base (rope_theta)
data = 'tests/attn_data'   # directory holding the raw <name>.bin dumps
def load_bf16(name, shape):
    """Read ``<data>/<name>.bin`` as a bfloat16 tensor of ``shape`` on the NPU.

    The file holds raw 16-bit words. They are read as int16 and then
    reinterpreted bit-for-bit as bfloat16 (a dtype view, no numeric
    conversion), mirroring how ``dump`` writes them.

    Args:
        name: File stem inside the ``data`` directory (without ``.bin``).
        shape: Target shape; its element count must equal file size / 2.

    Returns:
        A bfloat16 tensor on the current NPU device.
    """
    path = os.path.join(data, name + '.bin')
    # Close the handle deterministically instead of relying on GC.
    with open(path, 'rb') as f:
        raw = f.read()
    # bytearray gives torch.frombuffer a writable buffer (it warns on
    # read-only buffers such as bytes).
    words = torch.frombuffer(bytearray(raw), dtype=torch.int16)
    return words.view(*shape).view(torch.bfloat16).npu()
# Query/key inputs dumped earlier, laid out as [1, S, heads, Dh].
q = load_bf16('q_normed', [1, S, Hq, Dh])
k = load_bf16('k_normed', [1, S, Hkv, Dh])

# Rotary tables computed like HF (rope_theta = theta, positions 0..S-1):
#   inv_freq[i] = 1 / theta ** (2*i / Dh)   for i in [0, Dh/2)
#   freqs[p, i] = p * inv_freq[i]
# Each table is then the half-width frequencies tiled twice -> [S, Dh].
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
pos = torch.arange(S, device='npu').float()[:, None]   # [S, 1]
freqs = pos * inv_freq                                  # [S, Dh/2]
emb = freqs.repeat(1, 2)                                # [S, Dh] = freqs | freqs
cos, sin = emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)
# HF (Qwen3) style RoPE: x_rot = x * cos + rotate_half(x) * sin
def rotate_half(x):
    """Return ``cat([-upper, lower])`` over the last axis of ``x``.

    ``lower`` is the first ``n // 2`` elements of the last axis and ``upper``
    is the remaining ``n - n // 2``. The upper part is negated and moved in
    front of the lower part.
    """
    n = x.size(-1)
    lower, upper = x[..., : n // 2], x[..., n // 2 :]
    return torch.cat((upper.neg(), lower), dim=-1)
# Broadcast the [S, Dh] tables to [1, S, 1, Dh] so they align with the
# [1, S, heads, Dh] q/k layout and are shared by every head.
cos_b = cos.unsqueeze(0).unsqueeze(2)
sin_b = sin.unsqueeze(0).unsqueeze(2)
# HF (Qwen3) style RoPE, evaluated in bf16: x * cos + rotate_half(x) * sin
q_roped_hf = q * cos_b + rotate_half(q) * sin_b
k_roped_hf = k * cos_b + rotate_half(k) * sin_b
print("HF-style q_roped[0,0,0,:4]:", q_roped_hf[0,0,0,:4].float().cpu().tolist())
print("cos[0,:4]:", cos[0,:4].float().cpu().tolist())
print("sin[0,:4]:", sin[0,:4].float().cpu().tolist())
print("cos[1,:4]:", cos[1,:4].float().cpu().tolist())
# Compare with the existing q_roped dump (from torch_npu.npu_apply_rotary_pos_emb),
# read here before it is overwritten below.
old_q_roped = load_bf16('q_roped', [1, S, Hq, Dh])
# Upcast before subtracting: a bf16 subtraction can round the difference
# itself to bf16 precision and distort the reported maximum.
diff = (q_roped_hf.float() - old_q_roped.float()).abs().max().item()
print(f"\nDiff between HF formula and npu_apply: max={diff:.4f}")
# Save HF version as ground truth
def dump(name, t):
    """Write tensor ``t`` to ``<data>/<name>.bin`` as raw 16-bit words.

    The tensor is made contiguous, copied to the CPU, and bit-reinterpreted
    as int16 (a dtype view, no numeric conversion). A bfloat16 tensor
    written this way can therefore be read back with ``load_bf16``. The
    function is intended for 2-byte dtypes such as bfloat16.
    """
    path = os.path.join(data, name + '.bin')
    # .numpy() of an int16 tensor is already an int16 array; no extra copy.
    words = t.contiguous().cpu().view(torch.int16).numpy()
    # Context manager guarantees the data is flushed and the file closed.
    with open(path, 'wb') as f:
        f.write(words.tobytes())
# Persist the HF-formula results as the new ground truth. cos/sin get a
# leading axis so they are stored in a [1, S, Dh] layout.
for _name, _tensor in (
    ('q_roped', q_roped_hf),
    ('k_roped', k_roped_hf),
    ('cos', cos.unsqueeze(0)),
    ('sin', sin.unsqueeze(0)),
):
    dump(_name, _tensor)
print("\nOverwrote q_roped, k_roped, cos, sin with HF-formula ground truth.")