#!/usr/bin/env python3
"""Generate a single-layer attention forward reference for Qwen3-235B layer 0.

Input:  token ids (representing "The capital of France is")
Output: hidden_states after layer 0 attention (residual already added).
Also dumps all intermediate tensors for step-wise debugging.
"""
import os, json, math, struct

import torch
import torch_npu
from safetensors.torch import load_file

torch.npu.set_device(0)
torch.set_grad_enabled(False)

MODEL_DIR = '/path/to/Qwen3-235B-A22B-Instruct-2507-BF16'
OUT_DIR = 'tests/attn_data'
os.makedirs(OUT_DIR, exist_ok=True)

cfg = json.load(open(os.path.join(MODEL_DIR, 'config.json')))
D = cfg['hidden_size']            # 4096
Hq = cfg['num_attention_heads']   # 64
Hkv = cfg['num_key_value_heads']  # 4
Dh = cfg['head_dim']              # 128
Q_DIM = Hq * Dh                   # 8192
KV_DIM = Hkv * Dh                 # 512
eps = cfg['rms_norm_eps']
theta = cfg['rope_theta']         # 5e6 for Qwen3-235B

# ---- Find which safetensors shard contains layer 0 attention + input_layernorm ----
idx = json.load(open(os.path.join(MODEL_DIR, 'model.safetensors.index.json')))
wm = idx['weight_map']
needed = [
    'model.embed_tokens.weight',
    'model.layers.0.input_layernorm.weight',
    'model.layers.0.self_attn.q_proj.weight',
    'model.layers.0.self_attn.k_proj.weight',
    'model.layers.0.self_attn.v_proj.weight',
    'model.layers.0.self_attn.o_proj.weight',
    'model.layers.0.self_attn.q_norm.weight',
    'model.layers.0.self_attn.k_norm.weight',
]
shards = sorted({wm[n] for n in needed})
print("Need to load shards:", shards)

weights = {}
for sh in shards:
    t = load_file(os.path.join(MODEL_DIR, sh))
    for n in needed:
        if n in t:
            weights[n] = t[n].to('npu')
print("loaded:", list(weights.keys()))

# ---- Forward ----
# Input tokens (from tokenizer: "The capital of France is")
token_ids = torch.tensor([785, 6722, 315, 9625, 374], dtype=torch.long).npu()
S = token_ids.shape[0]
print(f"S = {S}")

# Embedding lookup
x = weights['model.embed_tokens.weight'][token_ids]  # [S, D]
x = x.unsqueeze(0)                                    # [1, S, D]
print("embed x:", x.shape, x.dtype)

# Residual
residual = x

# Input layernorm (RMSNorm)
ln = weights['model.layers.0.input_layernorm.weight']
xn, _ = torch_npu.npu_rms_norm(x, ln, epsilon=eps)
print("after_input_norm xn:", xn.shape)

# Q/K/V projections
Wq = weights['model.layers.0.self_attn.q_proj.weight']
Wk = weights['model.layers.0.self_attn.k_proj.weight']
Wv = weights['model.layers.0.self_attn.v_proj.weight']
q = torch.matmul(xn, Wq.t())  # [1, S, Q_DIM]
k = torch.matmul(xn, Wk.t())  # [1, S, KV_DIM]
v = torch.matmul(xn, Wv.t())

# Reshape to heads
q = q.view(1, S, Hq, Dh)
k = k.view(1, S, Hkv, Dh)
v = v.view(1, S, Hkv, Dh)

# Per-head RMSNorm on head_dim (Qwen3 specific)
qn_w = weights['model.layers.0.self_attn.q_norm.weight']  # [Dh]
kn_w = weights['model.layers.0.self_attn.k_norm.weight']
q_normed, _ = torch_npu.npu_rms_norm(q, qn_w, epsilon=eps)
k_normed, _ = torch_npu.npu_rms_norm(k, kn_w, epsilon=eps)

# RoPE: compute cos/sin for positions [0, S)
position_ids = torch.arange(S, device='npu').unsqueeze(0)  # [1, S]
inv_freq = 1.0 / (theta ** (torch.arange(0, Dh, 2, dtype=torch.float32).npu() / Dh))
freqs = position_ids.float().unsqueeze(-1) * inv_freq.unsqueeze(0).unsqueeze(0)  # [1, S, Dh/2]
# Concat (half, half) to get [1, S, Dh]
emb = torch.cat([freqs, freqs], dim=-1)
cos = emb.cos().to(torch.bfloat16)  # [1, S, Dh]
sin = emb.sin().to(torch.bfloat16)

# Apply RoPE: npu_apply_rotary_pos_emb expects BSND layout
# cos/sin shape: [1, S, 1, Dh] for broadcast over heads
cos_b = cos.unsqueeze(2)
sin_b = sin.unsqueeze(2)
q_roped, k_roped = torch_npu.npu_apply_rotary_pos_emb(q_normed, k_normed, cos_b, sin_b)
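# GQA note: with Hq = 64 query heads and Hkv = 4 KV heads, each KV head is
# shared by 16 query heads; the fused attention call below receives both
# num_heads and num_key_value_heads and handles that grouping internally,
# so q/k/v only need to be flattened to BSH here.
# Optional sanity check on the post-RoPE shapes (per the shape comments above):
assert q_roped.shape == (1, S, Hq, Dh)
assert k_roped.shape == (1, S, Hkv, Dh)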
# Flatten for FIAS (BSH layout)
q_bsh = q_roped.reshape(1, S, Q_DIM)
k_bsh = k_roped.reshape(1, S, KV_DIM)
v_bsh = v.reshape(1, S, KV_DIM)

# FIAS with causal mask for prefill
scale = 1.0 / math.sqrt(Dh)
# sparse_mode=3 requires fixed 2048×2048 mask
MASK_SIZE = 2048
mask = torch.triu(torch.ones(MASK_SIZE, MASK_SIZE, dtype=torch.bool, device='npu'), diagonal=1)
mask = mask.view(1, 1, MASK_SIZE, MASK_SIZE)
attn_out, _ = torch_npu.npu_fused_infer_attention_score(
    q_bsh, k_bsh, v_bsh,
    num_heads=Hq,
    num_key_value_heads=Hkv,
    scale=scale,
    input_layout="BSH",
    sparse_mode=3,
    atten_mask=mask,
    actual_seq_lengths=[S],
    actual_seq_lengths_kv=[S],
)
print("attn_out:", attn_out.shape)  # [1, S, Q_DIM]

# Output projection
Wo = weights['model.layers.0.self_attn.o_proj.weight']
o = torch.matmul(attn_out, Wo.t())  # [1, S, D]

# Residual add
out = residual + o
print("out:", out.shape, out[0, 0, :4].float().tolist())

# ---- Dump ----
def dump(name, t):
    # Reinterpret the bf16 tensor's bits as int16 so the dump is an exact
    # byte-level copy that can be reloaded without precision loss.
    p = os.path.join(OUT_DIR, name + '.bin')
    a = t.contiguous().cpu().view(torch.int16).numpy().astype('int16')
    open(p, 'wb').write(a.tobytes())

# Save token_ids
with open(os.path.join(OUT_DIR, 'token_ids.bin'), 'wb') as f:
    f.write(struct.pack('