"""Generate a single-layer attention forward reference for Qwen3-235B layer 0.

Input: token ids (representing "The capital of France is")
Output: hidden_states after layer 0 attention (residual already added).
Also dumps all intermediate tensors for step-wise debugging.
"""
import os, json, math, struct
import torch
import torch_npu
from safetensors.torch import load_file


torch.npu.set_device(0)
torch.set_grad_enabled(False)


MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/attn_data'
os.makedirs(OUT_DIR, exist_ok=True)


cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size']
Hq = cfg['num_attention_heads']
Hkv = cfg['num_key_value_heads']
Dh = cfg['head_dim']
Q_DIM = Hq * Dh
KV_DIM = Hkv * Dh
eps = cfg['rms_norm_eps']
theta = cfg['rope_theta']
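# Qwen3 stores head_dim explicitly in config.json, so Q_DIM/KV_DIM are derived from
# it rather than from hidden_size (Hq * Dh need not equal D). With fewer KV heads
# than query heads this is grouped-query attention.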


idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']
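# weight_map maps each tensor name to the safetensors shard file that holds it.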


needed = [
    'model.embed_tokens.weight',
    'model.layers.0.input_layernorm.weight',
    'model.layers.0.self_attn.q_proj.weight',
    'model.layers.0.self_attn.k_proj.weight',
    'model.layers.0.self_attn.v_proj.weight',
    'model.layers.0.self_attn.o_proj.weight',
    'model.layers.0.self_attn.q_norm.weight',
    'model.layers.0.self_attn.k_norm.weight',
]
shards = sorted({wm[n] for n in needed})
print("Need to load shards:", shards)


weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded:", list(weights.keys()))


token_ids = torch.tensor([785, 6722, 315, 9625, 374], dtype=torch.long).npu()
S = token_ids.shape[0]
print(f"S = {S}")


x = weights['model.embed_tokens.weight'][token_ids]
x = x.unsqueeze(0)
print("embed x:", x.shape, x.dtype)


residual = x
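# x is the (1, S, D) bf16 prompt embedding; it is kept as `residual` so the
# attention output can be added back after o_proj (pre-norm block).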


ln = weights['model.layers.0.input_layernorm.weight']
xn, _ = torch_npu.npu_rms_norm(x, ln, epsilon=eps)
print("after_input_norm xn:", xn.shape)


Wq = weights['model.layers.0.self_attn.q_proj.weight']
Wk = weights['model.layers.0.self_attn.k_proj.weight']
Wv = weights['model.layers.0.self_attn.v_proj.weight']
q = torch.matmul(xn, Wq.t())
k = torch.matmul(xn, Wk.t())
v = torch.matmul(xn, Wv.t())
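# Projection weights are stored (out_features, in_features) as in nn.Linear, hence
# the transposes; no bias tensors are loaded (Qwen3 attention projections are
# bias-free). Shapes: q (1, S, Q_DIM), k/v (1, S, KV_DIM).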


q = q.view(1, S, Hq, Dh)
k = k.view(1, S, Hkv, Dh)
v = v.view(1, S, Hkv, Dh)


qn_w = weights['model.layers.0.self_attn.q_norm.weight']
kn_w = weights['model.layers.0.self_attn.k_norm.weight']
q_normed, _ = torch_npu.npu_rms_norm(q, qn_w, epsilon=eps)
k_normed, _ = torch_npu.npu_rms_norm(k, kn_w, epsilon=eps)
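# Qwen3's q_norm/k_norm are per-head RMSNorms: the weights have shape (Dh,), so
# npu_rms_norm normalizes the last axis of the (1, S, H, Dh) tensors head by head.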


position_ids = torch.arange(S, device='npu').unsqueeze(0)
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
freqs = position_ids.float().unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)

emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().to(torch.bfloat16)
sin = emb.sin().to(torch.bfloat16)
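# Non-interleaved RoPE tables: inv_freq[i] = theta**(-2i/Dh), freqs is (1, S, Dh/2),
# and concatenating [freqs, freqs] gives full-width cos/sin in the rotate_half
# convention. Cast to bf16 to match q/k.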


cos_b = cos.unsqueeze(2)
sin_b = sin.unsqueeze(2)
q_roped, k_roped = torch_npu.npu_apply_rotary_pos_emb(q_normed, k_normed, cos_b, sin_b)
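# cos/sin become (1, S, 1, Dh) so they broadcast over the head axis of the
# (1, S, H, Dh) q/k tensors; the fused op rotates both in a single call.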


q_bsh = q_roped.reshape(1, S, Q_DIM)
k_bsh = k_roped.reshape(1, S, KV_DIM)
v_bsh = v.reshape(1, S, KV_DIM)


scale = 1.0 / math.sqrt(Dh)

MASK_SIZE = 2048
mask = torch.triu(torch.ones(MASK_SIZE, MASK_SIZE, dtype=torch.bool, device='npu'), diagonal=1)
mask = mask.view(1, 1, MASK_SIZE, MASK_SIZE)
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
    q_bsh, k_bsh, v_bsh,
    num_heads=Hq,
    num_key_value_heads=Hkv,
    scale=scale,
    input_layout="BSH",
    sparse_mode=3,
    atten_mask=mask,
    actual_seq_lengths=[S],
    actual_seq_lengths_kv=[S],
)
print("attn_out:", attn_out.shape)


Wo = weights['model.layers.0.self_attn.o_proj.weight']
o = torch.matmul(attn_out, Wo.t())


out = residual + o
print("out:", out.shape, out[0, 0, :4].float().tolist())


def dump(name, t):
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    open(p, 'wb').write(a.tobytes())
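# dump() writes raw bf16 bit patterns: view(torch.int16) reinterprets the 2-byte
# values, so each .bin is just the tensor's bytes in row-major order. To read one
# back (sketch): load as uint16, widen to uint32, shift left 16 bits, and
# reinterpret as float32.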


with open(os.path.join(OUT_DIR, 'token_ids.bin'), 'wb') as f:
    f.write(struct.pack('<i', S))
    for tid in token_ids.cpu().tolist():
        f.write(struct.pack('<i', tid))
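# token_ids.bin layout: one little-endian int32 count followed by S little-endian
# int32 token ids.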


dump('x_input', x)
dump('x_normed', xn)
dump('q_normed', q_normed)
dump('k_normed', k_normed)
dump('q_roped', q_roped)
dump('k_roped', k_roped)
dump('cos', cos)
dump('sin', sin)
dump('attn_out', attn_out)
dump('final_out', out)

for name, path_name in [
    ('model.layers.0.input_layernorm.weight', 'w_input_norm'),
    ('model.layers.0.self_attn.q_proj.weight', 'w_q_proj'),
    ('model.layers.0.self_attn.k_proj.weight', 'w_k_proj'),
    ('model.layers.0.self_attn.v_proj.weight', 'w_v_proj'),
    ('model.layers.0.self_attn.o_proj.weight', 'w_o_proj'),
    ('model.layers.0.self_attn.q_norm.weight', 'w_q_norm'),
    ('model.layers.0.self_attn.k_norm.weight', 'w_k_norm'),
]:
    dump(path_name, weights[name])


with open(os.path.join(OUT_DIR, 'shape.txt'), 'w') as f:
    f.write(f"S={S}\nD={D}\nHq={Hq}\nHkv={Hkv}\nDh={Dh}\nQ_DIM={Q_DIM}\nKV_DIM={KV_DIM}\neps={eps}\ntheta={theta}\n")


print("\nAll dumps in:", OUT_DIR)
print("Final output first 4:", out[0, 0, :4].float().cpu().tolist())