#!/usr/bin/env python3
"""Re-generate RoPE reference using explicit HF formula (not torch_npu.npu_apply_rotary_pos_emb).

The script loads the ``q_normed`` / ``k_normed`` tensors from ``tests/attn_data``,
applies rotary position embedding with the rotate-half formula
``x * cos + rotate_half(x) * sin``, prints the maximum absolute difference
against the ``q_roped`` file currently on disk, and then overwrites
``q_roped``, ``k_roped``, ``cos`` and ``sin`` with the freshly computed values.

All ``.bin`` files hold raw bfloat16 bit patterns (two bytes per element).
"""
import os
import math
import torch
import torch_npu

torch.npu.set_device(0)
torch.manual_seed(42)

# Problem size: sequence length, number of query heads, number of KV heads,
# and per-head dimension.
S = 5
Hq = 64
Hkv = 4
Dh = 128
theta = 5e6  # RoPE base (rope_theta)
data = 'tests/attn_data'


def load_bf16(name: str, shape) -> torch.Tensor:
    """Load ``<data>/<name>.bin`` as a bfloat16 tensor on the NPU.

    Args:
        name: File stem inside the ``data`` directory (without ``.bin``).
        shape: Sequence of ints giving the tensor shape; its element count
            must match the number of 2-byte values stored in the file.

    Returns:
        A bfloat16 tensor of the requested shape, moved to the NPU.
    """
    with open(os.path.join(data, name + '.bin'), 'rb') as f:
        raw = f.read()
    # torch.frombuffer wants a writable buffer, hence the bytearray copy.
    # The bytes are read as int16 and then reinterpreted (not converted) as
    # bfloat16, which has the same element size.
    a = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(*shape).view(torch.bfloat16)
    return a.npu()


q = load_bf16('q_normed', [1, S, Hq, Dh])
k = load_bf16('k_normed', [1, S, Hkv, Dh])

# Compute cos/sin identical to HF (rope_theta=5e6, positions 0..S-1).
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
pos = torch.arange(S, device='npu').float().unsqueeze(-1)
freqs = pos * inv_freq                 # [S, Dh/2]
emb = torch.cat([freqs, freqs], dim=-1)  # [S, Dh]
# The angles are evaluated in float32 and only the final tables are rounded
# to bfloat16.
cos = emb.cos().to(torch.bfloat16)     # [S, Dh]
sin = emb.sin().to(torch.bfloat16)     # [S, Dh]


# HF (Qwen3) style RoPE: q_rot = q * cos + rotate_half(q) * sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Return ``cat([-x2, x1])`` where ``x1``/``x2`` are the halves of the last dim.

    Args:
        x: Tensor whose last dimension is split at ``x.shape[-1] // 2``.

    Returns:
        Tensor of the same shape as ``x`` (for an even last dimension) with
        the second half negated and placed in front of the first half.
    """
    h = x.shape[-1] // 2
    x1 = x[..., :h]
    x2 = x[..., h:]
    return torch.cat([-x2, x1], dim=-1)


# Broadcast cos/sin from [S, Dh] to [1, S, 1, Dh] so they apply to every head.
cos_b = cos.unsqueeze(0).unsqueeze(2)
sin_b = sin.unsqueeze(0).unsqueeze(2)

q_roped_hf = q * cos_b + rotate_half(q) * sin_b
k_roped_hf = k * cos_b + rotate_half(k) * sin_b

print("HF-style q_roped[0,0,:4]:", q_roped_hf[0,0,0,:4].float().cpu().tolist())
print("cos[0,:4]:", cos[0,:4].float().cpu().tolist())
print("sin[0,:4]:", sin[0,:4].float().cpu().tolist())
print("cos[1,:4]:", cos[1,:4].float().cpu().tolist())

# Compare with existing q_roped (from torch_npu.npu_apply_rotary_pos_emb).
# This must happen before the dump() calls below replace that file.
old_q_roped = load_bf16('q_roped', [1, S, Hq, Dh])
diff = (q_roped_hf - old_q_roped).float().abs().max().item()
print(f"\nDiff between HF formula and npu_apply: max={diff:.4f}")


# Save HF version as ground truth
def dump(name: str, t: torch.Tensor) -> None:
    """Write tensor ``t`` to ``<data>/<name>.bin`` as raw 16-bit values.

    Args:
        name: File stem inside the ``data`` directory (without ``.bin``).
        t: Tensor with a 2-byte element type (bfloat16 in this script); its
            bits are reinterpreted as int16 and written unchanged.
    """
    p = os.path.join(data, name + '.bin')
    # The int16 view is a bit-for-bit reinterpretation, so the file contains
    # the exact bfloat16 payload.
    payload = t.contiguous().cpu().view(torch.int16).numpy().tobytes()
    # The context manager guarantees the data is flushed and the handle closed.
    with open(p, 'wb') as f:
        f.write(payload)


dump('q_roped', q_roped_hf)
dump('k_roped', k_roped_hf)
# Overwrite cos, sin to [1, S, Dh] layout
dump('cos', cos.unsqueeze(0))  # [1, S, Dh]
dump('sin', sin.unsqueeze(0))
print("\nOverwrote q_roped, k_roped, cos, sin with HF-formula ground truth.")