// Source: llm_mutil_npu/tests/test_attention_decode.cpp
// Commit 4b9fefd by xianglarry: "Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE
// on Ascend 910 × 16 NPU"
// test_attention_decode.cpp — validates single-layer attention with KV cache.
//
// Strategy: compare two paths yielding the same pos-5 attention output:
// Path A (reference): prefill 6 tokens in one shot → attn_out[5]
// Path B (decode): prefill 5 tokens → K/V cache; decode 6th token via cache → attn_out_decode[0]
//
// The two outputs should match within BF16 precision (pass criterion: relative L2 error < 5e-2).
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "model_config.h"
#include "rope.h"
#include "safetensors_loader.h"
#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
// Widen raw bfloat16 bits to a float. bfloat16 is the upper 16 bits of an
// IEEE-754 binary32, so placing the bits in the high half and reinterpreting
// the 32-bit pattern reproduces the value exactly.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof out);
    return out;
}
// Convert a float to raw bfloat16 bits using round-to-nearest, ties-to-even.
// NaN inputs are handled separately and mapped to a quiet NaN with the same sign.
// Without that check, the rounding bias would carry a NaN's low mantissa bits
// into the exponent (turning e.g. 0x7F800001 into +Inf), or overflow uint32_t
// for payloads such as 0xFFFFFFFF (yielding +0).
static uint16_t float_to_bf16(float x) {
    uint32_t u;
    std::memcpy(&u, &x, 4);
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: keep sign + upper payload bits and force the quiet bit so the
        // truncated mantissa can never become zero (which would read as Inf).
        return static_cast<uint16_t>((u >> 16) | 0x0040u);
    }
    // Bias of 0x7FFF plus the LSB of the retained half implements ties-to-even.
    // For non-NaN inputs the sum cannot overflow 32 bits.
    return static_cast<uint16_t>((u + 0x7FFFu + ((u >> 16) & 1u)) >> 16);
}
// Read the whole file at path `p` into a byte vector.
// Throws std::runtime_error (message includes the path) when the file cannot be
// opened, its size cannot be determined, or it cannot be read completely.
// Without these checks a missing file makes tellg() return -1, which converts to
// an enormous size_t and fails later with an unrelated allocation error.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) throw std::runtime_error("read_file: cannot open " + p);
    const std::streamoff end = f.tellg();
    if (end < 0) throw std::runtime_error("read_file: cannot determine size of " + p);
    std::vector<uint8_t> v(static_cast<size_t>(end));
    f.seekg(0);
    if (!v.empty() &&
        !f.read(reinterpret_cast<char*>(v.data()), static_cast<std::streamsize>(v.size()))) {
        throw std::runtime_error("read_file: short read from " + p);
    }
    return v;
}
// Fill RoPE cos/sin tables for positions [p0, p0+L) in the HF "half-half" layout:
// columns k and k + Dh/2 share inverse frequency 1 / theta^(2k/Dh).
// Both outputs are resized to L*Dh and written row-major ([L, Dh]) as bf16 bits.
static void fill_cos_sin(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
                         int64_t p0, int64_t L, int64_t Dh, float theta) {
    cos_h.resize(L * Dh);
    sin_h.resize(L * Dh);
    const int64_t half = Dh / 2;

    // Inverse frequencies depend only on the column pair, not on the position,
    // so compute them once. Size Dh - half also covers the odd-Dh last column.
    std::vector<float> inv_freq(Dh - half);
    for (int64_t k = 0; k < static_cast<int64_t>(inv_freq.size()); ++k) {
        inv_freq[k] = 1.0f / std::pow(theta, (2.0f * k) / Dh);
    }

    for (int64_t row = 0; row < L; ++row) {
        const float pos = static_cast<float>(p0 + row);
        uint16_t* cos_row = cos_h.data() + row * Dh;
        uint16_t* sin_row = sin_h.data() + row * Dh;
        for (int64_t col = 0; col < Dh; ++col) {
            const float angle = pos * inv_freq[col < half ? col : col - half];
            cos_row[col] = float_to_bf16(std::cos(angle));
            sin_row[col] = float_to_bf16(std::sin(angle));
        }
    }
}
// Single-layer attention decode test.
//
// Usage: test_attention_decode [model_dir] [data_dir]
//   model_dir - directory holding config.json and the safetensors weights
//               (defaults to the placeholder path below).
//   data_dir  - directory holding token_ids.bin (defaults to tests/attn_data).
//
// Runs attention for layer 0 twice: once as a 6-token prefill (Path A), and once as a
// 5-token prefill that fills a K/V cache followed by a 1-token decode (Path B).
// It then compares the position-5 outputs.
// Returns 0 on PASS, 1 on FAIL or on any setup error.
int main(int argc, char** argv) {
    const std::string model_dir = (argc > 1) ? argv[1] : "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    const std::string data_dir = (argc > 2) ? argv[2] : "tests/attn_data";
    ModelConfig cfg;
    if (!cfg.load_from_json(model_dir + "/config.json")) {
        fprintf(stderr, "failed to load %s/config.json\n", model_dir.c_str());
        return 1;
    }
    cfg.compute_derived(1, 0);
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.num_attention_heads;
    const int64_t Hkv = cfg.num_key_value_heads;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh);
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    SafetensorsLoader st;
    if (!st.open(model_dir)) {
        fprintf(stderr, "failed to open safetensors in %s\n", model_dir.c_str());
        return 1;
    }
    AclRuntime rt;
    rt.init(0);
    DeviceWeightsLoader dw(st, cfg);
    SharedWeights shared;
    LayerAttnWeights attn;
    printf("Loading weights...\n");
    if (!dw.load_shared(shared)) {
        fprintf(stderr, "failed to load shared weights\n");
        return 1;
    }
    if (!dw.load_attention(0, attn)) {
        fprintf(stderr, "failed to load attention weights\n");
        return 1;
    }

    // ---- Load prefill tokens: token_ids.bin = int32 count, then `count` int32 ids ----
    // Validate the size before touching the payload so a truncated file cannot
    // cause an out-of-bounds read.
    auto tok_raw = read_file(data_dir + "/token_ids.bin");
    if (tok_raw.size() < sizeof(int32_t)) {
        fprintf(stderr, "token_ids.bin is too short to hold its token count\n");
        return 1;
    }
    int32_t S_prefill = 0;
    std::memcpy(&S_prefill, tok_raw.data(), sizeof(S_prefill));
    if (S_prefill < 5) { fprintf(stderr, "need >=5 tokens\n"); return 1; }
    if (tok_raw.size() < sizeof(int32_t) * (1 + static_cast<size_t>(S_prefill))) {
        fprintf(stderr, "token_ids.bin declares %d tokens but holds only %zu bytes\n",
                S_prefill, tok_raw.size());
        return 1;
    }
    std::vector<int32_t> tokens(S_prefill);
    std::memcpy(tokens.data(), tok_raw.data() + sizeof(int32_t), tokens.size() * sizeof(int32_t));

    // Build the 6-token sequence: the first 5 prefill tokens, then tokens[0] reused
    // as the 6th id (any id works for this cross-consistency test).
    const int64_t S6 = 6;
    const int64_t S5 = 5;
    std::vector<int32_t> tok6(S6);
    for (int64_t i = 0; i < S5; i++) tok6[i] = tokens[i];
    tok6[S5] = tokens[0];
    printf("tokens6=["); for (auto t : tok6) printf("%d,", t); printf("]\n");

    // ---- Causal mask (2048x2048, sparse_mode=3) shared across both paths ----
    // The strict upper triangle (column j > row i, i.e. future positions) is set to 1.
    const int64_t MASK = 2048;
    DeviceBuffer mask_dev(MASK * MASK);
    std::vector<uint8_t> mask_host(MASK * MASK, 0);
    for (int64_t i = 0; i < MASK; i++)
        for (int64_t j = i + 1; j < MASK; j++)
            mask_host[i * MASK + j] = 1;
    ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});

    // =========================================================================
    // PATH A: 6-token prefill (reference)
    // =========================================================================
    printf("\n[Path A] 6-token prefill reference\n");
    DeviceBuffer tokA_dev(S6 * 4);
    ACL_CHECK(aclrtMemcpy(tokA_dev.get(), S6*4, tok6.data(), S6*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokA = make_contig_tensor(tokA_dev.get(), ACL_INT32, {S6});
    auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
    DeviceBuffer xA_dev(S6 * D * 2);
    auto t_xA = make_contig_tensor(xA_dev.get(), ACL_BF16, {S6, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokA.get(), t_xA.get());
    rt.sync();
    DeviceBuffer xnA_dev(S6 * D * 2);
    DeviceBuffer rstdA_dev(S6 * 4);
    auto t_xnA = make_contig_tensor(xnA_dev.get(), ACL_BF16, {S6, D});
    auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstdA = make_contig_tensor(rstdA_dev.get(), ACL_FLOAT, {S6});
    rms_norm(rt.stream(), t_xA.get(), t_ln_w.get(), eps, t_xnA.get(), t_rstdA.get());
    DeviceBuffer qA_dev(S6 * Q_DIM * 2);
    DeviceBuffer kA_dev(S6 * KV_DIM * 2);
    DeviceBuffer vA_dev(S6 * KV_DIM * 2);
    auto t_qA = make_contig_tensor(qA_dev.get(), ACL_BF16, {S6, Q_DIM});
    auto t_kA = make_contig_tensor(kA_dev.get(), ACL_BF16, {S6, KV_DIM});
    auto t_vA = make_contig_tensor(vA_dev.get(), ACL_BF16, {S6, KV_DIM});
    linear_hf(rt.stream(), t_xnA.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qA.get());
    linear_hf(rt.stream(), t_xnA.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kA.get());
    linear_hf(rt.stream(), t_xnA.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vA.get());
    // Per-head norm (in place on the q/k buffers viewed as [1, S, H, Dh])
    auto t_qA_4d = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Hq, Dh});
    auto t_kA_4d = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
    DeviceBuffer rstd_qA(S6 * Hq * 4), rstd_kA(S6 * Hkv * 4);
    auto t_rstd_qA = make_contig_tensor(rstd_qA.get(), ACL_FLOAT, {1, S6, Hq});
    auto t_rstd_kA = make_contig_tensor(rstd_kA.get(), ACL_FLOAT, {1, S6, Hkv});
    rms_norm(rt.stream(), t_qA_4d.get(), t_qn_w.get(), eps, t_qA_4d.get(), t_rstd_qA.get());
    rms_norm(rt.stream(), t_kA_4d.get(), t_kn_w.get(), eps, t_kA_4d.get(), t_rstd_kA.get());
    // RoPE for positions 0..5
    std::vector<uint16_t> cosA_h, sinA_h;
    fill_cos_sin(cosA_h, sinA_h, 0, S6, Dh, theta);
    DeviceBuffer cosA_dev(S6 * Dh * 2), sinA_dev(S6 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosA_dev.get(), S6*Dh*2, cosA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinA_dev.get(), S6*Dh*2, sinA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeA_scratch(1 * S6 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qA_dev.get(), 1, S6, Hq, Dh, kA_dev.get(), Hkv,
                      cosA_dev.get(), sinA_dev.get(), ropeA_scratch.get());
    auto t_qA_bsh = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Q_DIM});
    auto t_kA_bsh = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, KV_DIM});
    auto t_vA_bsh = make_contig_tensor(vA_dev.get(), ACL_BF16, {1, S6, KV_DIM});
    DeviceBuffer attnA_out(1 * S6 * Q_DIM * 2);
    auto t_attnA_out = make_contig_tensor(attnA_out.get(), ACL_BF16, {1, S6, Q_DIM});
    fused_infer_attention_score(
        rt.stream(), t_qA_bsh.get(), t_kA_bsh.get(), t_vA_bsh.get(),
        t_mask.get(), {S6}, {S6}, Hq, Hkv, scale, 3, t_attnA_out.get());
    rt.sync();
    // Extract attnA_out[pos = S6-1 = 5] (one row of Q_DIM bf16 values) for comparison
    std::vector<uint16_t> refA_host(Q_DIM);
    ACL_CHECK(aclrtMemcpy(refA_host.data(), Q_DIM*2,
                          (char*)attnA_out.get() + (S6 - 1) * Q_DIM * 2, Q_DIM*2,
                          ACL_MEMCPY_DEVICE_TO_HOST));
    printf(" attnA_out[5, :4] = %.5f %.5f %.5f %.5f\n",
           bf16_to_float(refA_host[0]), bf16_to_float(refA_host[1]),
           bf16_to_float(refA_host[2]), bf16_to_float(refA_host[3]));

    // =========================================================================
    // PATH B: 5-token prefill + KV cache → 1-token decode
    // =========================================================================
    printf("\n[Path B] 5-prefill + 1-decode via KV cache\n");
    const int64_t MAX_LEN = 128;  // small cache for test
    DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2);
    DeviceBuffer v_cache(MAX_LEN * KV_DIM * 2);
    // Cache slots are not zero-initialized; only rows 0..S6-1 are written below,
    // and the FIAS K/V views built later span exactly those S6 rows.

    // ---- Prefill 5 tokens ----
    DeviceBuffer tokB_dev(S5 * 4);
    ACL_CHECK(aclrtMemcpy(tokB_dev.get(), S5*4, tok6.data(), S5*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokB = make_contig_tensor(tokB_dev.get(), ACL_INT32, {S5});
    DeviceBuffer xB_dev(S5 * D * 2);
    auto t_xB = make_contig_tensor(xB_dev.get(), ACL_BF16, {S5, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokB.get(), t_xB.get());
    rt.sync();
    DeviceBuffer xnB_dev(S5 * D * 2);
    DeviceBuffer rstdB_dev(S5 * 4);
    auto t_xnB = make_contig_tensor(xnB_dev.get(), ACL_BF16, {S5, D});
    auto t_rstdB = make_contig_tensor(rstdB_dev.get(), ACL_FLOAT, {S5});
    rms_norm(rt.stream(), t_xB.get(), t_ln_w.get(), eps, t_xnB.get(), t_rstdB.get());
    DeviceBuffer qB_dev(S5 * Q_DIM * 2);
    DeviceBuffer kB_dev(S5 * KV_DIM * 2);
    DeviceBuffer vB_dev(S5 * KV_DIM * 2);
    auto t_qB = make_contig_tensor(qB_dev.get(), ACL_BF16, {S5, Q_DIM});
    auto t_kB = make_contig_tensor(kB_dev.get(), ACL_BF16, {S5, KV_DIM});
    auto t_vB = make_contig_tensor(vB_dev.get(), ACL_BF16, {S5, KV_DIM});
    linear_hf(rt.stream(), t_xnB.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qB.get());
    linear_hf(rt.stream(), t_xnB.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kB.get());
    linear_hf(rt.stream(), t_xnB.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vB.get());
    auto t_qB_4d = make_contig_tensor(qB_dev.get(), ACL_BF16, {1, S5, Hq, Dh});
    auto t_kB_4d = make_contig_tensor(kB_dev.get(), ACL_BF16, {1, S5, Hkv, Dh});
    DeviceBuffer rstd_qB(S5 * Hq * 4), rstd_kB(S5 * Hkv * 4);
    auto t_rstd_qB = make_contig_tensor(rstd_qB.get(), ACL_FLOAT, {1, S5, Hq});
    auto t_rstd_kB = make_contig_tensor(rstd_kB.get(), ACL_FLOAT, {1, S5, Hkv});
    rms_norm(rt.stream(), t_qB_4d.get(), t_qn_w.get(), eps, t_qB_4d.get(), t_rstd_qB.get());
    rms_norm(rt.stream(), t_kB_4d.get(), t_kn_w.get(), eps, t_kB_4d.get(), t_rstd_kB.get());
    std::vector<uint16_t> cosB_h, sinB_h;
    fill_cos_sin(cosB_h, sinB_h, 0, S5, Dh, theta);
    DeviceBuffer cosB_dev(S5 * Dh * 2), sinB_dev(S5 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosB_dev.get(), S5*Dh*2, cosB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinB_dev.get(), S5*Dh*2, sinB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeB_scratch(1 * S5 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qB_dev.get(), 1, S5, Hq, Dh, kB_dev.get(), Hkv,
                      cosB_dev.get(), sinB_dev.get(), ropeB_scratch.get());
    rt.sync();
    // Append K, V to cache at positions 0..4.
    ACL_CHECK(aclrtMemcpy(k_cache.get(), S5 * KV_DIM * 2,
                          kB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(v_cache.get(), S5 * KV_DIM * 2,
                          vB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    printf(" cached K/V at positions 0..%" PRId64 "\n", S5 - 1);

    // ---- Decode 1 token (position = 5) ----
    DeviceBuffer tokD_dev(1 * 4);
    int32_t tok_dec = tok6[S5];
    ACL_CHECK(aclrtMemcpy(tokD_dev.get(), 4, &tok_dec, 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokD = make_contig_tensor(tokD_dev.get(), ACL_INT32, {1});
    DeviceBuffer xD_dev(1 * D * 2);
    auto t_xD = make_contig_tensor(xD_dev.get(), ACL_BF16, {1, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokD.get(), t_xD.get());
    DeviceBuffer xnD_dev(1 * D * 2), rstdD_dev(1 * 4);
    auto t_xnD = make_contig_tensor(xnD_dev.get(), ACL_BF16, {1, D});
    auto t_rstdD = make_contig_tensor(rstdD_dev.get(), ACL_FLOAT, {1});
    rms_norm(rt.stream(), t_xD.get(), t_ln_w.get(), eps, t_xnD.get(), t_rstdD.get());
    DeviceBuffer qD_dev(1 * Q_DIM * 2), kD_dev(1 * KV_DIM * 2), vD_dev(1 * KV_DIM * 2);
    auto t_qD = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, Q_DIM});
    auto t_kD = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, KV_DIM});
    auto t_vD = make_contig_tensor(vD_dev.get(), ACL_BF16, {1, KV_DIM});
    linear_hf(rt.stream(), t_xnD.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qD.get());
    linear_hf(rt.stream(), t_xnD.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kD.get());
    linear_hf(rt.stream(), t_xnD.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vD.get());
    auto t_qD_4d = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Hq, Dh});
    auto t_kD_4d = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, 1, Hkv, Dh});
    DeviceBuffer rstd_qD(1 * Hq * 4), rstd_kD(1 * Hkv * 4);
    auto t_rstd_qD = make_contig_tensor(rstd_qD.get(), ACL_FLOAT, {1, 1, Hq});
    auto t_rstd_kD = make_contig_tensor(rstd_kD.get(), ACL_FLOAT, {1, 1, Hkv});
    rms_norm(rt.stream(), t_qD_4d.get(), t_qn_w.get(), eps, t_qD_4d.get(), t_rstd_qD.get());
    rms_norm(rt.stream(), t_kD_4d.get(), t_kn_w.get(), eps, t_kD_4d.get(), t_rstd_kD.get());
    // RoPE for position 5 only
    std::vector<uint16_t> cosD_h, sinD_h;
    fill_cos_sin(cosD_h, sinD_h, /*p0=*/S5, /*L=*/1, Dh, theta);
    DeviceBuffer cosD_dev(1 * Dh * 2), sinD_dev(1 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosD_dev.get(), Dh*2, cosD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinD_dev.get(), Dh*2, sinD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeD_scratch(1 * 1 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qD_dev.get(), 1, 1, Hq, Dh, kD_dev.get(), Hkv,
                      cosD_dev.get(), sinD_dev.get(), ropeD_scratch.get());
    rt.sync();
    // Append K, V to cache at position 5.
    ACL_CHECK(aclrtMemcpy((char*)k_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
                          kD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy((char*)v_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
                          vD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    // FIAS decode: q [1, 1, Q_DIM], k/v [1, 6, KV_DIM] from cache.
    auto t_qD_bsh = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Q_DIM});
    auto t_kC_bsh = make_contig_tensor(k_cache.get(), ACL_BF16, {1, S6, KV_DIM});
    auto t_vC_bsh = make_contig_tensor(v_cache.get(), ACL_BF16, {1, S6, KV_DIM});
    DeviceBuffer attnD_out(1 * 1 * Q_DIM * 2);
    auto t_attnD_out = make_contig_tensor(attnD_out.get(), ACL_BF16, {1, 1, Q_DIM});
    // Decode: q has 1 token, k/v have 6 tokens. Use sparse_mode=0 with no mask — the single q
    // at the end can attend to all cached positions; there's no causal constraint on it.
    fused_infer_attention_score(
        rt.stream(), t_qD_bsh.get(), t_kC_bsh.get(), t_vC_bsh.get(),
        nullptr, {1}, {S6},
        Hq, Hkv, scale, 0, t_attnD_out.get());
    rt.sync();
    std::vector<uint16_t> decB_host(Q_DIM);
    ACL_CHECK(aclrtMemcpy(decB_host.data(), Q_DIM*2, attnD_out.get(), Q_DIM*2, ACL_MEMCPY_DEVICE_TO_HOST));

    // ---- Compare Path A vs Path B ----
    printf("\n attnB_decode[:4] = %.5f %.5f %.5f %.5f\n",
           bf16_to_float(decB_host[0]), bf16_to_float(decB_host[1]),
           bf16_to_float(decB_host[2]), bf16_to_float(decB_host[3]));
    // Accumulate in double so the error norm is not limited by float rounding.
    double l2d = 0, l2r = 0, maxd = 0;
    for (int64_t i = 0; i < Q_DIM; i++) {
        const double a = bf16_to_float(decB_host[i]);
        const double b = bf16_to_float(refA_host[i]);
        const double d = a - b;
        l2d += d * d;
        l2r += b * b;
        if (std::abs(d) > maxd) maxd = std::abs(d);
    }
    const double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
    printf("\nDecode vs 6-prefill comparison: rel=%.4e max_abs=%.4f\n", rel, maxd);
    // Guard against a vacuous pass: with an all-zero or non-finite reference,
    // rel is trivially 0 (or NaN) and says nothing about decode correctness.
    const bool ref_ok = std::isfinite(l2r) && l2r > 0.0;
    if (!ref_ok) fprintf(stderr, "reference attention output is all-zero or non-finite\n");
    const bool pass = ref_ok && std::isfinite(rel) && rel < 5e-2;
    printf("\n%s\n", pass ? "=== test_attention_decode PASS ===" : "=== test_attention_decode FAIL ===");
    return pass ? 0 : 1;
}