#!/usr/bin/env python3
"""Generate a single-layer attention forward reference for Qwen3-235B layer 0.
Input: token ids (representing "The capital of France is")
Output: hidden_states after layer 0 attention (residual already added).
Also dumps all intermediate tensors for step-wise debugging.
"""
import os, json, math, struct
import torch
import torch_npu
from safetensors.torch import load_file
torch.npu.set_device(0)
torch.set_grad_enabled(False)
MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/attn_data'
os.makedirs(OUT_DIR, exist_ok=True)
cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size'] # 4096
Hq = cfg['num_attention_heads'] # 64
Hkv = cfg['num_key_value_heads'] # 4
Dh = cfg['head_dim'] # 128
Q_DIM = Hq * Dh # 8192
KV_DIM = Hkv * Dh # 512
eps = cfg['rms_norm_eps']
theta = cfg['rope_theta'] # 5e6 for Qwen3-235B
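# GQA: Hq // Hkv = 64 // 4 = 16 query heads share each KV head.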
# ---- Find which safetensors shard contains layer 0 attention + input_layernorm ----
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']
needed = [
    'model.embed_tokens.weight',
    'model.layers.0.input_layernorm.weight',
    'model.layers.0.self_attn.q_proj.weight',
    'model.layers.0.self_attn.k_proj.weight',
    'model.layers.0.self_attn.v_proj.weight',
    'model.layers.0.self_attn.o_proj.weight',
    'model.layers.0.self_attn.q_norm.weight',
    'model.layers.0.self_attn.k_norm.weight',
]
shards = sorted({wm[n] for n in needed})
print("Need to load shards:", shards)
weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded:", list(weights.keys()))
# ---- Forward ----
# Input tokens (from tokenizer: "The capital of France is")
token_ids = torch.tensor([785, 6722, 315, 9625, 374], dtype=torch.long).npu()
S = token_ids.shape[0]
print(f"S = {S}")
# Embedding lookup
x = weights['model.embed_tokens.weight'][token_ids] # [S, D]
x = x.unsqueeze(0) # [1, S, D]
print("embed x:", x.shape, x.dtype)
# Residual
residual = x
# Input layernorm (RMSNorm)
ln = weights['model.layers.0.input_layernorm.weight']
xn, _ = torch_npu.npu_rms_norm(x, ln, epsilon=eps)
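# npu_rms_norm returns (normalized_output, rstd); only the normalized output is needed here.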
print("after_input_norm xn:", xn.shape)
# Q/K/V projections
Wq = weights['model.layers.0.self_attn.q_proj.weight']
Wk = weights['model.layers.0.self_attn.k_proj.weight']
Wv = weights['model.layers.0.self_attn.v_proj.weight']
q = torch.matmul(xn, Wq.t()) # [1, S, Q_DIM]
k = torch.matmul(xn, Wk.t()) # [1, S, KV_DIM]
v = torch.matmul(xn, Wv.t())
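# Qwen3 attention projections are bias-free, so plain matmuls against the
# transposed weights reproduce the Linear layers.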
# Reshape to heads
q = q.view(1, S, Hq, Dh)
k = k.view(1, S, Hkv, Dh)
v = v.view(1, S, Hkv, Dh)
# Per-head RMSNorm on head_dim (Qwen3 specific)
qn_w = weights['model.layers.0.self_attn.q_norm.weight'] # [Dh]
kn_w = weights['model.layers.0.self_attn.k_norm.weight']
q_normed, _ = torch_npu.npu_rms_norm(q, qn_w, epsilon=eps)
k_normed, _ = torch_npu.npu_rms_norm(k, kn_w, epsilon=eps)
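# q_norm/k_norm weights are [Dh], so npu_rms_norm normalizes over the trailing
# head_dim of the [1, S, H, Dh] tensors, i.e. independently per head and token.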
# RoPE: compute cos/sin for positions [0, S)
position_ids = torch.arange(S, device='npu').unsqueeze(0) # [1, S]
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
freqs = position_ids.float().unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0) # [1, S, Dh/2]
# Concat (half, half) to get [1, S, Dh]
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().to(torch.bfloat16) # [1, S, Dh]
sin = emb.sin().to(torch.bfloat16)
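# The (half, half) concat gives the non-interleaved ("rotate_half") RoPE layout
# used by HF Qwen3: each pair (x_i, x_{i+Dh/2}) is rotated by angle pos * inv_freq_i:
#   x_i'        = x_i * cos - x_{i+Dh/2} * sin
#   x_{i+Dh/2}' = x_i * sin + x_{i+Dh/2} * cos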
# Apply RoPE — npu_apply_rotary_pos_emb expects BSND layout
# cos/sin shape: [1, S, 1, Dh] for broadcast over heads
cos_b = cos.unsqueeze(2)
sin_b = sin.unsqueeze(2)
q_roped, k_roped = torch_npu.npu_apply_rotary_pos_emb(q_normed, k_normed, cos_b, sin_b)
# Flatten for FIAS (BSH layout)
q_bsh = q_roped.reshape(1, S, Q_DIM)
k_bsh = k_roped.reshape(1, S, KV_DIM)
v_bsh = v.reshape(1, S, KV_DIM)
# FIAS with causal mask for prefill
scale = 1.0 / math.sqrt(Dh)
# sparse_mode=3 requires fixed 2048×2048 mask
MASK_SIZE = 2048
mask = torch.triu(torch.ones(MASK_SIZE, MASK_SIZE, dtype=torch.bool, device='npu'), diagonal=1)
mask = mask.view(1, 1, MASK_SIZE, MASK_SIZE)
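# sparse_mode=3 is the bottom-right-aligned causal mode: the kernel reuses this
# fixed 2048x2048 upper-triangular mask and aligns it to the actual S/S_kv
# lengths, so a full [S, S] causal mask never has to be materialized.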
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
    q_bsh, k_bsh, v_bsh,
    num_heads=Hq,
    num_key_value_heads=Hkv,
    scale=scale,
    input_layout="BSH",
    sparse_mode=3,
    atten_mask=mask,
    actual_seq_lengths=[S],
    actual_seq_lengths_kv=[S],
)
print("attn_out:", attn_out.shape) # [1, S, Q_DIM]
# Output projection
Wo = weights['model.layers.0.self_attn.o_proj.weight']
o = torch.matmul(attn_out, Wo.t()) # [1, S, D]
# Residual add
out = residual + o
print("out:", out.shape, out[0, 0, :4].float().tolist())
# ---- Dump ----
def dump(name, t):
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy()
    open(p, 'wb').write(a.tobytes())
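# Each .bin holds the raw bf16 payload bit-cast to int16. A read-back sketch for
# a hypothetical consumer (file name and shape are illustrative), assuming numpy + torch:
#   a = np.fromfile(os.path.join(OUT_DIR, 'x_input.bin'), dtype=np.int16)
#   t = torch.from_numpy(a).view(torch.bfloat16).reshape(1, S, D)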
# Save token_ids
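# Layout: little-endian int32 count S, followed by S little-endian int32 ids.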
with open(os.path.join(OUT_DIR, 'token_ids.bin'), 'wb') as f:
    f.write(struct.pack('<i', S))
    for tid in token_ids.cpu().tolist():
        f.write(struct.pack('<i', tid))
# Save inputs
dump('x_input', x) # embed result
dump('x_normed', xn)
dump('q_normed', q_normed)
dump('k_normed', k_normed)
dump('q_roped', q_roped)
dump('k_roped', k_roped)
dump('cos', cos)
dump('sin', sin)
dump('attn_out', attn_out)
dump('final_out', out)
# Save weights used (dtype=BF16)
for name, path_name in [
    ('model.layers.0.input_layernorm.weight', 'w_input_norm'),
    ('model.layers.0.self_attn.q_proj.weight', 'w_q_proj'),
    ('model.layers.0.self_attn.k_proj.weight', 'w_k_proj'),
    ('model.layers.0.self_attn.v_proj.weight', 'w_v_proj'),
    ('model.layers.0.self_attn.o_proj.weight', 'w_o_proj'),
    ('model.layers.0.self_attn.q_norm.weight', 'w_q_norm'),
    ('model.layers.0.self_attn.k_norm.weight', 'w_k_norm'),
]:
    dump(path_name, weights[name])
with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
f.write(f"S={S}\nD={D}\nHq={Hq}\nHkv={Hkv}\nDh={Dh}\nQ_DIM={Q_DIM}\nKV_DIM={KV_DIM}\neps={eps}\ntheta={theta}\n")
print("\nAll dumps in:", OUT_DIR)
print("Final output first 4:", out[0, 0, :4].float().cpu().tolist())