File size: 2,308 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#!/usr/bin/env python3
"""Re-generate RoPE reference using explicit HF formula (not torch_npu.npu_apply_rotary_pos_emb)."""
import os, math, torch, torch_npu
# Select NPU device 0 and seed torch's RNG.
torch.npu.set_device(0)
torch.manual_seed(42)

# Fixture dimensions: q is loaded as [1, S, Hq, Dh] and k as [1, S, Hkv, Dh].
S = 5
Hq = 64
Hkv = 4
Dh = 128

# Base used for the RoPE inverse frequencies (rope_theta).
theta = 5e6

# Directory containing the raw ``.bin`` fixtures that are read and rewritten.
data = 'tests/attn_data'

def load_bf16(name, shape):
    """Load ``<data>/<name>.bin`` as a bfloat16 tensor and move it to the NPU.

    The file content is read as raw bytes, interpreted as int16 values,
    reshaped to ``shape`` and then reinterpreted bit-for-bit as bfloat16
    (a ``view``, not a numeric conversion).

    Args:
        name: Fixture base name, without the ``.bin`` suffix.
        shape: Sequence of dimension sizes passed to ``Tensor.view``; its
            element count has to match the number of 16-bit values in the file.

    Returns:
        The bfloat16 tensor after ``.npu()``.
    """
    path = os.path.join(data, name + '.bin')
    # Use a context manager so the handle is closed deterministically rather
    # than whenever the anonymous file object happens to be collected.
    with open(path, 'rb') as f:
        raw = f.read()
    # torch.frombuffer shares memory with the buffer it is given; the
    # bytearray provides a private, mutable copy of the file content.
    a = torch.frombuffer(bytearray(raw), dtype=torch.int16).view(*shape).view(torch.bfloat16)
    return a.npu()

# Input fixtures 'q_normed' / 'k_normed', viewed as [1, S, Hq, Dh] and [1, S, Hkv, Dh].
q = load_bf16('q_normed', (1, S, Hq, Dh))
k = load_bf16('k_normed', (1, S, Hkv, Dh))

# Build cos/sin the same way HF does (rope_theta=5e6, positions 0..S-1).
# Everything up to the final cast is float32; only the tables are bfloat16.
even_idx = torch.arange(0, Dh, 2, dtype=torch.float32).npu()   # [Dh/2]
inv_freq = 1.0 / (theta ** (even_idx / Dh))                    # [Dh/2]
pos = torch.arange(S, device='npu').float()[:, None]           # [S, 1]
freqs = pos * inv_freq                                         # [S, Dh/2]
# The half-width frequencies are repeated to fill the full last dimension.
emb = torch.cat((freqs, freqs), dim=-1)                        # [S, Dh]
cos = torch.cos(emb).bfloat16()                                # [S, Dh]
sin = torch.sin(emb).bfloat16()                                # [S, Dh]

# HF (Qwen3) style RoPE: q_rot = q * cos + rotate_half(q) * sin
def rotate_half(x):
    """Return ``x`` with the two halves of its last dimension swapped and the
    second half negated: ``[a, b] -> [-b, a]``.

    The split point is ``x.shape[-1] // 2``; the output has the same shape
    as the input.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)

# Insert singleton batch and head axes: [S, Dh] -> [1, S, 1, Dh], so the
# tables broadcast against q and k.
cos_b = cos[None, :, None, :]
sin_b = sin[None, :, None, :]

# x_rot = x * cos + rotate_half(x) * sin, applied to q and then to k.
q_roped_hf, k_roped_hf = (
    t * cos_b + rotate_half(t) * sin_b
    for t in (q, k)
)

# Print the first four values of a few rows for a quick visual check.
for label, row in (
    ("HF-style q_roped[0,0,:4]:", q_roped_hf[0, 0, 0, :4]),
    ("cos[0,:4]:", cos[0, :4]),
    ("sin[0,:4]:", sin[0, :4]),
    ("cos[1,:4]:", cos[1, :4]),
):
    print(label, row.float().cpu().tolist())

# Load the q_roped fixture currently on disk (produced by
# torch_npu.npu_apply_rotary_pos_emb, per the original note) and report the
# largest absolute element-wise difference from the HF-formula result.
old_q_roped = load_bf16('q_roped', (1, S, Hq, Dh))
delta = (q_roped_hf - old_q_roped).float()
diff = torch.max(torch.abs(delta)).item()
print(f"\nDiff between HF formula and npu_apply: max={diff:.4f}")

# Save HF version as ground truth
def dump(name, t):
    """Write tensor ``t`` to ``<data>/<name>.bin`` as raw 16-bit values.

    The tensor is made contiguous, copied to the CPU and reinterpreted
    bit-for-bit as int16 before its bytes are written, so the file holds the
    unmodified bit patterns of the elements. An existing file is overwritten.

    Args:
        name: Fixture base name, without the ``.bin`` suffix.
        t: Tensor whose elements are 16 bits wide (the script passes bfloat16).
    """
    p = os.path.join(data, name + '.bin')
    # The int16 view already yields an int16 ndarray, so no extra astype copy
    # is needed before serialising.
    a = t.contiguous().cpu().view(torch.int16).numpy()
    # Context manager guarantees the data is flushed and the handle closed
    # before the function returns.
    with open(p, 'wb') as f:
        f.write(a.tobytes())
# Overwrite the fixtures with the HF-formula results, in this order.
# cos/sin get a leading axis so they are stored in [1, S, Dh] layout.
for fixture_name, tensor in (
    ('q_roped', q_roped_hf),
    ('k_roped', k_roped_hf),
    ('cos', cos[None]),   # [1, S, Dh]
    ('sin', sin[None]),   # [1, S, Dh]
):
    dump(fixture_name, tensor)

print("\nOverwrote q_roped, k_roped, cos, sin with HF-formula ground truth.")