// test_attention_decode.cpp — validates single-layer attention with KV cache. // // Strategy: compare two paths yielding the same pos-5 attention output: // Path A (reference): prefill 6 tokens in one shot → attn_out[5] // Path B (decode): prefill 5 tokens → K/V cache; decode 6th token via cache → attn_out_decode[0] // // The two outputs should match within BF16 precision. #include "acl_common.h" #include "acl_runtime.h" #include "aclnn_ops.h" #include "device_weights.h" #include "model_config.h" #include "rope.h" #include "safetensors_loader.h" #include #include #include #include #include static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; } static uint16_t float_to_bf16(float x) { uint32_t u; std::memcpy(&u, &x, 4); return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16); } static std::vector read_file(const std::string& p) { std::ifstream f(p, std::ios::binary | std::ios::ate); size_t s = f.tellg(); f.seekg(0); std::vector v(s); f.read((char*)v.data(), s); return v; } // Fill cos/sin tables for a range of positions [p0, p0+L). HF layout: half-half. static void fill_cos_sin(std::vector& cos_h, std::vector& sin_h, int64_t p0, int64_t L, int64_t Dh, float theta) { cos_h.resize(L * Dh); sin_h.resize(L * Dh); int64_t half = Dh / 2; for (int64_t s = 0; s < L; s++) { for (int64_t d = 0; d < Dh; d++) { int64_t pair = (d < half) ? d : (d - half); float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh); float angle = (float)(p0 + s) * theta_pair; cos_h[s * Dh + d] = float_to_bf16(std::cos(angle)); sin_h[s * Dh + d] = float_to_bf16(std::sin(angle)); } } } int main() { const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; const std::string data_dir = "tests/attn_data"; ModelConfig cfg; if (!cfg.load_from_json(model_dir + "/config.json")) return 1; cfg.compute_derived(1, 0); const int64_t D = cfg.hidden_size; const int64_t Hq = cfg.num_attention_heads; const int64_t Hkv = cfg.num_key_value_heads; const int64_t Dh = cfg.head_dim; const int64_t Q_DIM = Hq * Dh; const int64_t KV_DIM = Hkv * Dh; const double scale = 1.0 / std::sqrt((double)Dh); const double eps = cfg.rms_norm_eps; const float theta = cfg.rope_theta; SafetensorsLoader st; if (!st.open(model_dir)) return 1; AclRuntime rt; rt.init(0); DeviceWeightsLoader dw(st, cfg); SharedWeights shared; LayerAttnWeights attn; printf("Loading weights...\n"); if (!dw.load_shared(shared)) return 1; if (!dw.load_attention(0, attn)) return 1; // ---- Load 5 prefill tokens + use token[5]=random as "6th" decoded token ---- auto tok_raw = read_file(data_dir + "/token_ids.bin"); int32_t S_prefill = *(int32_t*)tok_raw.data(); if (S_prefill < 5) { fprintf(stderr, "need >=5 tokens\n"); return 1; } std::vector tokens(S_prefill); std::memcpy(tokens.data(), tok_raw.data() + 4, S_prefill * 4); // Build 6-token sequence (reuse first 5; pick a 6th token id — use token 0 as a simple choice) const int64_t S6 = 6; const int64_t S5 = 5; std::vector tok6(S6); for (int i = 0; i < S5; i++) tok6[i] = tokens[i]; tok6[5] = tokens[0]; // any token works for cross-consistency test printf("tokens6=["); for (auto t : tok6) printf("%d,", t); printf("]\n"); // ---- Causal mask (2048x2048, sparse_mode=3) shared across both paths ---- const int64_t MASK = 2048; DeviceBuffer mask_dev(MASK * MASK); std::vector mask_host(MASK * MASK, 0); for (int i = 0; i < MASK; i++) for (int j = i+1; j < MASK; j++) mask_host[i*MASK + j] = 1; ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE)); auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK}); // ========================================================================= // PATH A: 6-token prefill (reference) // ========================================================================= printf("\n[Path A] 6-token prefill reference\n"); DeviceBuffer tokA_dev(S6 * 4); ACL_CHECK(aclrtMemcpy(tokA_dev.get(), S6*4, tok6.data(), S6*4, ACL_MEMCPY_HOST_TO_DEVICE)); auto t_tokA = make_contig_tensor(tokA_dev.get(), ACL_INT32, {S6}); auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D}); DeviceBuffer xA_dev(S6 * D * 2); auto t_xA = make_contig_tensor(xA_dev.get(), ACL_BF16, {S6, D}); index_select(rt.stream(), t_embed_w.get(), 0, t_tokA.get(), t_xA.get()); rt.sync(); DeviceBuffer xnA_dev(S6 * D * 2); DeviceBuffer rstdA_dev(S6 * 4); auto t_xnA = make_contig_tensor(xnA_dev.get(), ACL_BF16, {S6, D}); auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D}); auto t_rstdA = make_contig_tensor(rstdA_dev.get(), ACL_FLOAT, {S6}); rms_norm(rt.stream(), t_xA.get(), t_ln_w.get(), eps, t_xnA.get(), t_rstdA.get()); DeviceBuffer qA_dev(S6 * Q_DIM * 2); DeviceBuffer kA_dev(S6 * KV_DIM * 2); DeviceBuffer vA_dev(S6 * KV_DIM * 2); auto t_qA = make_contig_tensor(qA_dev.get(), ACL_BF16, {S6, Q_DIM}); auto t_kA = make_contig_tensor(kA_dev.get(), ACL_BF16, {S6, KV_DIM}); auto t_vA = make_contig_tensor(vA_dev.get(), ACL_BF16, {S6, KV_DIM}); linear_hf(rt.stream(), t_xnA.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qA.get()); linear_hf(rt.stream(), t_xnA.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kA.get()); linear_hf(rt.stream(), t_xnA.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vA.get()); // Per-head norm auto t_qA_4d = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Hq, Dh}); auto t_kA_4d = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, Hkv, Dh}); auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh}); auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh}); DeviceBuffer rstd_qA(S6 * Hq * 4), rstd_kA(S6 * Hkv * 4); auto t_rstd_qA = make_contig_tensor(rstd_qA.get(), ACL_FLOAT, {1, S6, Hq}); auto t_rstd_kA = make_contig_tensor(rstd_kA.get(), ACL_FLOAT, {1, S6, Hkv}); rms_norm(rt.stream(), t_qA_4d.get(), t_qn_w.get(), eps, t_qA_4d.get(), t_rstd_qA.get()); rms_norm(rt.stream(), t_kA_4d.get(), t_kn_w.get(), eps, t_kA_4d.get(), t_rstd_kA.get()); // RoPE for positions 0..5 std::vector cosA_h, sinA_h; fill_cos_sin(cosA_h, sinA_h, 0, S6, Dh, theta); DeviceBuffer cosA_dev(S6 * Dh * 2), sinA_dev(S6 * Dh * 2); ACL_CHECK(aclrtMemcpy(cosA_dev.get(), S6*Dh*2, cosA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(sinA_dev.get(), S6*Dh*2, sinA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE)); DeviceBuffer ropeA_scratch(1 * S6 * Hq * Dh * 2); apply_rope_manual(rt.stream(), qA_dev.get(), 1, S6, Hq, Dh, kA_dev.get(), Hkv, cosA_dev.get(), sinA_dev.get(), ropeA_scratch.get()); auto t_qA_bsh = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Q_DIM}); auto t_kA_bsh = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, KV_DIM}); auto t_vA_bsh = make_contig_tensor(vA_dev.get(), ACL_BF16, {1, S6, KV_DIM}); DeviceBuffer attnA_out(1 * S6 * Q_DIM * 2); auto t_attnA_out = make_contig_tensor(attnA_out.get(), ACL_BF16, {1, S6, Q_DIM}); fused_infer_attention_score( rt.stream(), t_qA_bsh.get(), t_kA_bsh.get(), t_vA_bsh.get(), t_mask.get(), {S6}, {S6}, Hq, Hkv, scale, 3, t_attnA_out.get()); rt.sync(); // Extract attnA_out[pos=5] into [1, 1, Q_DIM] for comparison std::vector refA_host(Q_DIM); ACL_CHECK(aclrtMemcpy(refA_host.data(), Q_DIM*2, (char*)attnA_out.get() + 5 * Q_DIM * 2, Q_DIM*2, ACL_MEMCPY_DEVICE_TO_HOST)); printf(" attnA_out[5, :4] = %.5f %.5f %.5f %.5f\n", bf16_to_float(refA_host[0]), bf16_to_float(refA_host[1]), bf16_to_float(refA_host[2]), bf16_to_float(refA_host[3])); // ========================================================================= // PATH B: 5-token prefill + KV cache → 1-token decode // ========================================================================= printf("\n[Path B] 5-prefill + 1-decode via KV cache\n"); const int64_t MAX_LEN = 128; // small cache for test DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2); DeviceBuffer v_cache(MAX_LEN * KV_DIM * 2); // Zero-init unused slots (not strictly needed, FIAS uses actual_seq_lens). // ---- Prefill 5 tokens ---- DeviceBuffer tokB_dev(S5 * 4); ACL_CHECK(aclrtMemcpy(tokB_dev.get(), S5*4, tok6.data(), S5*4, ACL_MEMCPY_HOST_TO_DEVICE)); auto t_tokB = make_contig_tensor(tokB_dev.get(), ACL_INT32, {S5}); DeviceBuffer xB_dev(S5 * D * 2); auto t_xB = make_contig_tensor(xB_dev.get(), ACL_BF16, {S5, D}); index_select(rt.stream(), t_embed_w.get(), 0, t_tokB.get(), t_xB.get()); rt.sync(); DeviceBuffer xnB_dev(S5 * D * 2); DeviceBuffer rstdB_dev(S5 * 4); auto t_xnB = make_contig_tensor(xnB_dev.get(), ACL_BF16, {S5, D}); auto t_rstdB = make_contig_tensor(rstdB_dev.get(), ACL_FLOAT, {S5}); rms_norm(rt.stream(), t_xB.get(), t_ln_w.get(), eps, t_xnB.get(), t_rstdB.get()); DeviceBuffer qB_dev(S5 * Q_DIM * 2); DeviceBuffer kB_dev(S5 * KV_DIM * 2); DeviceBuffer vB_dev(S5 * KV_DIM * 2); auto t_qB = make_contig_tensor(qB_dev.get(), ACL_BF16, {S5, Q_DIM}); auto t_kB = make_contig_tensor(kB_dev.get(), ACL_BF16, {S5, KV_DIM}); auto t_vB = make_contig_tensor(vB_dev.get(), ACL_BF16, {S5, KV_DIM}); linear_hf(rt.stream(), t_xnB.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qB.get()); linear_hf(rt.stream(), t_xnB.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kB.get()); linear_hf(rt.stream(), t_xnB.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vB.get()); auto t_qB_4d = make_contig_tensor(qB_dev.get(), ACL_BF16, {1, S5, Hq, Dh}); auto t_kB_4d = make_contig_tensor(kB_dev.get(), ACL_BF16, {1, S5, Hkv, Dh}); DeviceBuffer rstd_qB(S5 * Hq * 4), rstd_kB(S5 * Hkv * 4); auto t_rstd_qB = make_contig_tensor(rstd_qB.get(), ACL_FLOAT, {1, S5, Hq}); auto t_rstd_kB = make_contig_tensor(rstd_kB.get(), ACL_FLOAT, {1, S5, Hkv}); rms_norm(rt.stream(), t_qB_4d.get(), t_qn_w.get(), eps, t_qB_4d.get(), t_rstd_qB.get()); rms_norm(rt.stream(), t_kB_4d.get(), t_kn_w.get(), eps, t_kB_4d.get(), t_rstd_kB.get()); std::vector cosB_h, sinB_h; fill_cos_sin(cosB_h, sinB_h, 0, S5, Dh, theta); DeviceBuffer cosB_dev(S5 * Dh * 2), sinB_dev(S5 * Dh * 2); ACL_CHECK(aclrtMemcpy(cosB_dev.get(), S5*Dh*2, cosB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(sinB_dev.get(), S5*Dh*2, sinB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE)); DeviceBuffer ropeB_scratch(1 * S5 * Hq * Dh * 2); apply_rope_manual(rt.stream(), qB_dev.get(), 1, S5, Hq, Dh, kB_dev.get(), Hkv, cosB_dev.get(), sinB_dev.get(), ropeB_scratch.get()); rt.sync(); // Append K, V to cache at positions 0..4. ACL_CHECK(aclrtMemcpy(k_cache.get(), S5 * KV_DIM * 2, kB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(v_cache.get(), S5 * KV_DIM * 2, vB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE)); printf(" cached K/V at positions 0..%ld\n", S5 - 1); // ---- Decode 1 token (position = 5) ---- DeviceBuffer tokD_dev(1 * 4); int32_t tok_dec = tok6[5]; ACL_CHECK(aclrtMemcpy(tokD_dev.get(), 4, &tok_dec, 4, ACL_MEMCPY_HOST_TO_DEVICE)); auto t_tokD = make_contig_tensor(tokD_dev.get(), ACL_INT32, {1}); DeviceBuffer xD_dev(1 * D * 2); auto t_xD = make_contig_tensor(xD_dev.get(), ACL_BF16, {1, D}); index_select(rt.stream(), t_embed_w.get(), 0, t_tokD.get(), t_xD.get()); DeviceBuffer xnD_dev(1 * D * 2), rstdD_dev(1 * 4); auto t_xnD = make_contig_tensor(xnD_dev.get(), ACL_BF16, {1, D}); auto t_rstdD = make_contig_tensor(rstdD_dev.get(), ACL_FLOAT, {1}); rms_norm(rt.stream(), t_xD.get(), t_ln_w.get(), eps, t_xnD.get(), t_rstdD.get()); DeviceBuffer qD_dev(1 * Q_DIM * 2), kD_dev(1 * KV_DIM * 2), vD_dev(1 * KV_DIM * 2); auto t_qD = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, Q_DIM}); auto t_kD = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, KV_DIM}); auto t_vD = make_contig_tensor(vD_dev.get(), ACL_BF16, {1, KV_DIM}); linear_hf(rt.stream(), t_xnD.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qD.get()); linear_hf(rt.stream(), t_xnD.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kD.get()); linear_hf(rt.stream(), t_xnD.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vD.get()); auto t_qD_4d = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Hq, Dh}); auto t_kD_4d = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, 1, Hkv, Dh}); DeviceBuffer rstd_qD(1 * Hq * 4), rstd_kD(1 * Hkv * 4); auto t_rstd_qD = make_contig_tensor(rstd_qD.get(), ACL_FLOAT, {1, 1, Hq}); auto t_rstd_kD = make_contig_tensor(rstd_kD.get(), ACL_FLOAT, {1, 1, Hkv}); rms_norm(rt.stream(), t_qD_4d.get(), t_qn_w.get(), eps, t_qD_4d.get(), t_rstd_qD.get()); rms_norm(rt.stream(), t_kD_4d.get(), t_kn_w.get(), eps, t_kD_4d.get(), t_rstd_kD.get()); // RoPE for position 5 only std::vector cosD_h, sinD_h; fill_cos_sin(cosD_h, sinD_h, /*p0=*/5, /*L=*/1, Dh, theta); DeviceBuffer cosD_dev(1 * Dh * 2), sinD_dev(1 * Dh * 2); ACL_CHECK(aclrtMemcpy(cosD_dev.get(), Dh*2, cosD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(sinD_dev.get(), Dh*2, sinD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE)); DeviceBuffer ropeD_scratch(1 * 1 * Hq * Dh * 2); apply_rope_manual(rt.stream(), qD_dev.get(), 1, 1, Hq, Dh, kD_dev.get(), Hkv, cosD_dev.get(), sinD_dev.get(), ropeD_scratch.get()); rt.sync(); // Append K, V to cache at position 5. ACL_CHECK(aclrtMemcpy((char*)k_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2, kD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE)); ACL_CHECK(aclrtMemcpy((char*)v_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2, vD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE)); // FIAS decode: q [1, 1, Q_DIM], k/v [1, 6, KV_DIM] from cache. auto t_qD_bsh = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Q_DIM}); auto t_kC_bsh = make_contig_tensor(k_cache.get(), ACL_BF16, {1, S6, KV_DIM}); auto t_vC_bsh = make_contig_tensor(v_cache.get(), ACL_BF16, {1, S6, KV_DIM}); DeviceBuffer attnD_out(1 * 1 * Q_DIM * 2); auto t_attnD_out = make_contig_tensor(attnD_out.get(), ACL_BF16, {1, 1, Q_DIM}); // Decode: q has 1 token, k/v have 6 tokens. Use sparse_mode=0 with no mask — the single q // at the end can attend to all cached positions; there's no causal constraint on it. fused_infer_attention_score( rt.stream(), t_qD_bsh.get(), t_kC_bsh.get(), t_vC_bsh.get(), nullptr, {1}, {S6}, Hq, Hkv, scale, 0, t_attnD_out.get()); rt.sync(); std::vector decB_host(Q_DIM); ACL_CHECK(aclrtMemcpy(decB_host.data(), Q_DIM*2, attnD_out.get(), Q_DIM*2, ACL_MEMCPY_DEVICE_TO_HOST)); // ---- Compare Path A vs Path B ---- printf("\n attnB_decode[:4] = %.5f %.5f %.5f %.5f\n", bf16_to_float(decB_host[0]), bf16_to_float(decB_host[1]), bf16_to_float(decB_host[2]), bf16_to_float(decB_host[3])); double l2d = 0, l2r = 0, maxd = 0; for (int i = 0; i < Q_DIM; i++) { float a = bf16_to_float(decB_host[i]), b = bf16_to_float(refA_host[i]); l2d += (a-b)*(a-b); l2r += b*b; if (std::abs(a-b) > maxd) maxd = std::abs(a-b); } double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); printf("\nDecode vs 6-prefill comparison: rel=%.4e max_abs=%.4f\n", rel, maxd); bool pass = rel < 5e-2; printf("\n%s\n", pass ? "=== test_attention_decode PASS ===" : "=== test_attention_decode FAIL ==="); return pass ? 0 : 1; }