#!/usr/bin/env python3
"""Generate a single-layer attention forward reference for Qwen3-235B layer 0.
Input: token ids (representing "The capital of France is")
Output: hidden_states after layer 0 attention (residual already added).
Also dumps all intermediate tensors for step-wise debugging.
"""
import os, json, math, struct
import torch
import torch_npu
from safetensors.torch import load_file
torch.npu.set_device(0)
torch.set_grad_enabled(False)
MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/attn_data'
os.makedirs(OUT_DIR, exist_ok=True)
cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size'] # 4096
Hq = cfg['num_attention_heads'] # 64
Hkv = cfg['num_key_value_heads'] # 4
Dh = cfg['head_dim'] # 128
Q_DIM = Hq * Dh # 8192
KV_DIM = Hkv * Dh # 512
eps = cfg['rms_norm_eps']
theta = cfg['rope_theta'] # 5e6 for Qwen3-235B
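# Sanity-check the GQA geometry (added guard, not part of the model tooling):
# with Hq=64 and Hkv=4, every KV head is shared by 16 query heads. This only
# catches accidentally pointing MODEL_DIR at a different checkpoint.
assert Hq % Hkv == 0, 'num_attention_heads must be divisible by num_key_value_heads'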
# ---- Find which safetensors shard contains layer 0 attention + input_layernorm ----
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']
needed = [
    'model.embed_tokens.weight',
    'model.layers.0.input_layernorm.weight',
    'model.layers.0.self_attn.q_proj.weight',
    'model.layers.0.self_attn.k_proj.weight',
    'model.layers.0.self_attn.v_proj.weight',
    'model.layers.0.self_attn.o_proj.weight',
    'model.layers.0.self_attn.q_norm.weight',
    'model.layers.0.self_attn.k_norm.weight',
]
shards = sorted({wm[n] for n in needed})
print("Need to load shards:", shards)
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded:", list(weights.keys()))
# ---- Forward ----
# Input tokens (from tokenizer: "The capital of France is")
token_ids = torch.tensor([785, 6722, 315, 9625, 374], dtype=torch.long).npu()
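# The ids above can be regenerated from the checkpoint's tokenizer, e.g.
# (a sketch, assuming transformers is installed):
#   from transformers import AutoTokenizer
#   AutoTokenizer.from_pretrained(MODEL_DIR)('The capital of France is').input_ids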
S = token_ids.shape[0]
print(f"S = {S}")
# Embedding lookup
x = weights['model.embed_tokens.weight'][token_ids] # [S, D]
x = x.unsqueeze(0) # [1, S, D]
print("embed x:", x.shape, x.dtype)
# Residual
residual = x
# Input layernorm (RMSNorm)
ln = weights['model.layers.0.input_layernorm.weight']
xn, _ = torch_npu.npu_rms_norm(x, ln, epsilon=eps)
print("after_input_norm xn:", xn.shape)
# Q/K/V projections
Wq = weights['model.layers.0.self_attn.q_proj.weight']
Wk = weights['model.layers.0.self_attn.k_proj.weight']
Wv = weights['model.layers.0.self_attn.v_proj.weight']
q = torch.matmul(xn, Wq.t()) # [1, S, Q_DIM]
k = torch.matmul(xn, Wk.t()) # [1, S, KV_DIM]
v = torch.matmul(xn, Wv.t())
# Reshape to heads
q = q.view(1, S, Hq, Dh)
k = k.view(1, S, Hkv, Dh)
v = v.view(1, S, Hkv, Dh)
# Per-head RMSNorm on head_dim (Qwen3 specific)
qn_w = weights['model.layers.0.self_attn.q_norm.weight'] # [Dh]
kn_w = weights['model.layers.0.self_attn.k_norm.weight']
q_normed, _ = torch_npu.npu_rms_norm(q, qn_w, epsilon=eps)
k_normed, _ = torch_npu.npu_rms_norm(k, kn_w, epsilon=eps)
# RoPE: compute cos/sin for positions [0, S)
position_ids = torch.arange(S, device='npu').unsqueeze(0) # [1, S]
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
freqs = position_ids.float().unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0) # [1, S, Dh/2]
# Concat (half, half) to get [1, S, Dh]
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().to(torch.bfloat16) # [1, S, Dh]
sin = emb.sin().to(torch.bfloat16)
# Apply RoPE — npu_apply_rotary_pos_emb expects BSND layout
# cos/sin shape: [1, S, 1, Dh] for broadcast over heads
cos_b = cos.unsqueeze(2)
sin_b = sin.unsqueeze(2)
q_roped, k_roped = torch_npu.npu_apply_rotary_pos_emb(q_normed, k_normed, cos_b, sin_b)
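# Optional cross-check against the HF rotate-half formulation (a sketch;
# assumes npu_apply_rotary_pos_emb uses the same half-split convention):
# def rotate_half(t):
#     t1, t2 = t.chunk(2, dim=-1)
#     return torch.cat((-t2, t1), dim=-1)
# q_ref = q_normed * cos_b + rotate_half(q_normed) * sin_b
# print('rope q max diff:', (q_roped - q_ref).abs().max().item())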
# Flatten for FIAS (BSH layout)
q_bsh = q_roped.reshape(1, S, Q_DIM)
k_bsh = k_roped.reshape(1, S, KV_DIM)
v_bsh = v.reshape(1, S, KV_DIM)
# FIAS with causal mask for prefill
scale = 1.0 / math.sqrt(Dh)
# sparse_mode=3 (right-aligned causal) requires a fixed 2048x2048 upper-triangular
# mask; the op maps it onto the actual sequence lengths internally
MASK_SIZE = 2048
mask = torch.triu(torch.ones(MASK_SIZE, MASK_SIZE, dtype=torch.bool, device='npu'), diagonal=1)
mask = mask.view(1, 1, MASK_SIZE, MASK_SIZE)
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
    q_bsh, k_bsh, v_bsh,
    num_heads=Hq,
    num_key_value_heads=Hkv,
    scale=scale,
    input_layout="BSH",
    sparse_mode=3,
    atten_mask=mask,
    actual_seq_lengths=[S],
    actual_seq_lengths_kv=[S],
)
print("attn_out:", attn_out.shape) # [1, S, Q_DIM]
# Output projection
Wo = weights['model.layers.0.self_attn.o_proj.weight']
o = torch.matmul(attn_out, Wo.t()) # [1, S, D]
# Residual add
out = residual + o
print("out:", out.shape, out[0, 0, :4].float().tolist())
# ---- Dump ----
def dump(name, t):
    """Write a tensor's raw BF16 payload to OUT_DIR, bit-reinterpreted as int16."""
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy()  # bit-exact BF16 bytes
    open(p, 'wb').write(a.tobytes())
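# The .bin files are headerless BF16 payloads. To read one back in Python
# (a sketch; reshape per shape.txt):
#   import numpy as np
#   raw = np.fromfile(os.path.join(OUT_DIR, 'x_input.bin'), dtype=np.int16)
#   t = torch.from_numpy(raw).view(torch.bfloat16)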
# Save token_ids
with open(os.path.join(OUT_DIR, 'token_ids.bin'), 'wb') as f:
    f.write(struct.pack('<i', S))
    for tid in token_ids.cpu().tolist():
        f.write(struct.pack('<i', tid))
# Save inputs
dump('x_input', x) # embed result
dump('x_normed', xn)
dump('q_normed', q_normed)
dump('k_normed', k_normed)
dump('q_roped', q_roped)
dump('k_roped', k_roped)
dump('cos', cos)
dump('sin', sin)
dump('attn_out', attn_out)
dump('final_out', out)
# Save weights used (dtype=BF16)
for name, path_name in [
    ('model.layers.0.input_layernorm.weight', 'w_input_norm'),
    ('model.layers.0.self_attn.q_proj.weight', 'w_q_proj'),
    ('model.layers.0.self_attn.k_proj.weight', 'w_k_proj'),
    ('model.layers.0.self_attn.v_proj.weight', 'w_v_proj'),
    ('model.layers.0.self_attn.o_proj.weight', 'w_o_proj'),
    ('model.layers.0.self_attn.q_norm.weight', 'w_q_norm'),
    ('model.layers.0.self_attn.k_norm.weight', 'w_k_norm'),
]:
    dump(path_name, weights[name])
with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nHq={Hq}\nHkv={Hkv}\nDh={Dh}\nQ_DIM={Q_DIM}\nKV_DIM={KV_DIM}\neps={eps}\ntheta={theta}\n")
print("\nAll dumps in:", OUT_DIR)
print("Final output first 4:", out[0, 0, :4].float().cpu().tolist())