| |
| """Re-generate RoPE reference using explicit HF formula (not torch_npu.npu_apply_rotary_pos_emb).""" |
| import os, math, torch, torch_npu |
# Run everything on NPU 0. The seed is fixed for reproducibility, although no
# random ops appear in the visible code.
torch.npu.set_device(0)
torch.manual_seed(42)

# Problem sizes (used below as q: [1, S, Hq, Dh], k: [1, S, Hkv, Dh]).
S = 5
Hq = 64
Hkv = 4
Dh = 128

# RoPE base: inv_freq = 1 / theta ** (2i / Dh).
theta = 5e6

# Directory of the raw *.bin tensor dumps; relative, so resolved against the CWD.
data = 'tests/attn_data'
|
|
def load_bf16(name, shape):
    """Load ``<data>/<name>.bin`` as a bfloat16 tensor on the NPU.

    The file holds raw 16-bit words. They are read as int16, reshaped to
    ``shape`` and then bit-cast (not value-converted) to bfloat16, so the file
    must be a contiguous bfloat16 dump in native byte order.

    Args:
        name: File stem inside ``data``; ``.bin`` is appended.
        shape: Target shape (sequence of ints).

    Returns:
        A bfloat16 tensor of ``shape`` on the current NPU device.

    Raises:
        ValueError: if the file size is not ``numel(shape) * 2`` bytes.
    """
    path = os.path.join(data, name + '.bin')
    # Close the handle deterministically instead of relying on GC.
    with open(path, 'rb') as f:
        # A writable copy avoids torch.frombuffer's non-writable-buffer warning.
        raw = bytearray(f.read())
    expected = torch.Size(shape).numel() * 2  # 2 bytes per bfloat16 element
    if len(raw) != expected:
        raise ValueError(
            f"{path}: expected {expected} bytes for shape {list(shape)}, "
            f"got {len(raw)}"
        )
    a = torch.frombuffer(raw, dtype=torch.int16).view(*shape).view(torch.bfloat16)
    return a.npu()
|
|
# Load the *_normed dumps (presumably post-normalisation activations). Both use
# the layout [1, S, heads, Dh]: q has Hq heads and k has Hkv heads.
q, k = (
    load_bf16(stem, [1, S, heads, Dh])
    for stem, heads in (('q_normed', Hq), ('k_normed', Hkv))
)
|
|
| |
# HF-style RoPE tables, computed in float32 and only then rounded to bfloat16:
#   inv_freq[i]  = 1 / theta ** (2i / Dh),  i in [0, Dh/2)
#   freqs[p, i]  = p * inv_freq[i]
#   emb          = [freqs, freqs] along the last axis, so columns j and j + Dh/2
#                  share a frequency (the pairing rotate_half uses).
even_idx = torch.arange(0, Dh, 2, dtype=torch.float32).npu()
inv_freq = 1.0 / theta ** (even_idx / Dh)
pos = torch.arange(S, device='npu').float()[:, None]   # [S, 1]
freqs = pos * inv_freq                                  # [S, Dh/2]
emb = freqs.repeat(1, 2)                                # [S, Dh]
cos, sin = emb.cos().to(torch.bfloat16), emb.sin().to(torch.bfloat16)
|
|
| |
def rotate_half(x):
    """Rotate the last axis by half: ``[a, b] -> [-b, a]``.

    ``a`` is the first ``D // 2`` entries of the last axis and ``b`` is the
    remaining ``D - D // 2``. For odd ``D``, ``b`` is therefore the longer part.
    """
    width = x.shape[-1]
    front, back = torch.split(x, [width // 2, width - width // 2], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
|
|
| |
# Broadcast the [S, Dh] tables to [1, S, 1, Dh]: position runs along dim 1 of
# q/k, and the tables are shared across dim 2 (heads).
cos_b = cos[None, :, None, :]
sin_b = sin[None, :, None, :]


def _rope_hf(x):
    """Return ``x * cos_b + rotate_half(x) * sin_b``.

    Every operand here is bfloat16, so the products and the sum are evaluated
    in bfloat16.
    """
    return x * cos_b + rotate_half(x) * sin_b


q_roped_hf = _rope_hf(q)
k_roped_hf = _rope_hf(k)
|
|
# Spot-check the first few values. At position 0 every angle is 0, so row 0
# of cos should be all 1 and row 0 of sin all 0 (identity rotation).
# The label now matches the 4-index slice that is actually printed.
print("HF-style q_roped[0,0,0,:4]:", q_roped_hf[0, 0, 0, :4].float().cpu().tolist())
print("cos[0,:4]:", cos[0, :4].float().cpu().tolist())
print("sin[0,:4]:", sin[0, :4].float().cpu().tolist())
print("cos[1,:4]:", cos[1, :4].float().cpu().tolist())
|
|
| |
# Compare against the q_roped.bin currently on disk. Per the label, that file
# was presumably produced by npu_apply_rotary_pos_emb. NOTE: this script
# overwrites q_roped.bin further down, so on a re-run the comparison is against
# this script's own earlier output rather than the npu_apply result.
old_q_roped = load_bf16('q_roped', [1, S, Hq, Dh])
# Upcast before subtracting so the reported error is not itself rounded to bf16.
diff = (q_roped_hf.float() - old_q_roped.float()).abs().max().item()
print(f"\nDiff between HF formula and npu_apply: max={diff:.4f}")
|
|
| |
def dump(name, t):
    """Write ``t`` to ``<data>/<name>.bin`` as raw 16-bit words.

    The tensor is made contiguous, copied to host memory and bit-cast to int16.
    This produces the raw layout that ``load_bf16`` reads back. Any existing
    file at that path is overwritten.

    Args:
        name: File stem inside ``data``; ``.bin`` is appended.
        t: Tensor with a 2-byte element type (e.g. bfloat16).

    Raises:
        ValueError: if ``t`` does not have a 2-byte element type. Such a
            tensor would otherwise be silently reinterpreted as a different
            number of int16 words.
    """
    if t.element_size() != 2:
        raise ValueError(
            f"dump({name!r}): expected a 2-byte dtype, got {t.dtype}"
        )
    path = os.path.join(data, name + '.bin')
    # view(int16) already yields int16, so no extra astype copy is needed.
    words = t.contiguous().cpu().view(torch.int16).numpy()
    with open(path, 'wb') as f:  # ensure the data is flushed and the file closed
        f.write(words.tobytes())
# Overwrite the stored references with the HF-formula results. cos/sin get a
# leading singleton dim, so they are written with shape [1, S, Dh].
outputs = {
    'q_roped': q_roped_hf,
    'k_roped': k_roped_hf,
    'cos': cos.unsqueeze(0),
    'sin': sin.unsqueeze(0),
}
for stem, tensor in outputs.items():
    dump(stem, tensor)

print("\nOverwrote q_roped, k_roped, cos, sin with HF-formula ground truth.")
|
|