// llm_mutil_npu / tests/test_attention_layer.cpp
// Author: xianglarry
// Commit message: Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
// Commit: 4b9fefd
// test_attention_layer.cpp — full single-layer attention forward (Qwen3-235B layer 0), TP=1.
// Validates C++ output against Python HF-style reference (attn_data/final_out.bin).
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "model_config.h"
#include "rope.h"
#include "safetensors_loader.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>
// Expands a BF16 bit pattern into a float. BF16 is the upper half of an
// IEEE-754 binary32, so the 16 bits are placed in the high half and the low
// mantissa bits are zero-filled.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float result;
    std::memcpy(&result, &widened, 4);  // bit-copy avoids type-punning UB
    return result;
}
// Converts a float to BF16 bits, rounding the dropped low 16 bits to nearest,
// ties to even. Finite overflow correctly rounds up to +/-Inf.
//
// NaN inputs are handled before rounding. Applying the rounding bias to a NaN
// can carry its low mantissa bits into the exponent, turning e.g. 0x7F800001
// into +Inf (0x7F80). It can also wrap the uint32 sum for some negative NaNs,
// producing +0. Instead, NaNs keep their sign and high mantissa bits and get a
// quiet bit forced on, so the result is always a BF16 NaN.
static uint16_t float_to_bf16(float x) {
    uint32_t u;
    std::memcpy(&u, &x, 4);
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        return (uint16_t)((u >> 16) | 0x0040u);
    }
    // Bias of 0x7FFF plus the current LSB of the kept half implements round-half-to-even.
    return (uint16_t)((u + 0x7FFFu + ((u >> 16) & 1u)) >> 16);
}
// Reads the entire file at `p` into a byte vector.
//
// Throws std::runtime_error naming the path if the file cannot be opened,
// its size cannot be determined, or the read comes up short. Without these
// checks a missing file makes tellg() return -1. That value is then converted
// to a huge size_t and used as the allocation size.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) throw std::runtime_error("read_file: cannot open " + p);
    const std::streamoff end = f.tellg();
    if (end < 0) throw std::runtime_error("read_file: cannot determine size of " + p);
    const size_t size = static_cast<size_t>(end);
    std::vector<uint8_t> v(size);
    f.seekg(0);
    if (size > 0 && !f.read(reinterpret_cast<char*>(v.data()), static_cast<std::streamsize>(size))) {
        throw std::runtime_error("read_file: short read from " + p);
    }
    return v;
}
int main() {
const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
const std::string data_dir = "tests/attn_data";
ModelConfig cfg;
if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
cfg.compute_derived(/*tp_size=*/1, /*tp_rank=*/0); // single rank for correctness test
const int64_t D = cfg.hidden_size;
const int64_t Hq = cfg.num_attention_heads;
const int64_t Hkv = cfg.num_key_value_heads;
const int64_t Dh = cfg.head_dim;
const int64_t Q_DIM = Hq * Dh;
const int64_t KV_DIM = Hkv * Dh;
const double scale = 1.0 / std::sqrt((double)Dh);
const double eps = cfg.rms_norm_eps;
const float theta = cfg.rope_theta;
SafetensorsLoader st;
if (!st.open(model_dir)) return 1;
AclRuntime rt;
rt.init(0);
// ---- Load weights (layer 0 attention + embed) ----
DeviceWeightsLoader dw(st, cfg);
SharedWeights shared;
LayerAttnWeights attn;
printf("Loading weights...\n");
if (!dw.load_shared(shared)) return 1;
if (!dw.load_attention(0, attn)) return 1;
printf(" shared.embed %.0fMB, attn total ~140MB\n", shared.embed_tokens.size / 1e6);
// ---- Load token ids (5 tokens: "The capital of France is") ----
auto tok_raw = read_file(data_dir + "/token_ids.bin");
int32_t S = *(int32_t*)tok_raw.data();
std::vector<int32_t> tokens(S);
std::memcpy(tokens.data(), tok_raw.data() + 4, S * 4);
printf("S=%d tokens=[", S); for (auto t : tokens) printf("%d,", t); printf("]\n");
// ---- Embed lookup: [S, D] ----
DeviceBuffer tok_dev(S * 4);
ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens.data(), S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
// embed weight shape [vocab, D]
auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});
DeviceBuffer x_dev(S * D * 2);
auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
index_select(rt.stream(), t_embed_w.get(), 0, t_tok.get(), t_x.get());
rt.sync();
// ---- Residual snapshot (copy x) ----
DeviceBuffer residual_dev(S * D * 2);
ACL_CHECK(aclrtMemcpyAsync(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt.stream()));
// ---- Input layernorm ----
DeviceBuffer xn_dev(S * D * 2);
DeviceBuffer rstd_dev(S * 4);
auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D});
auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S});
rms_norm(rt.stream(), t_x.get(), t_ln_w.get(), eps, t_xn.get(), t_rstd.get());
// ---- Q/K/V projections (linear_hf: y = x @ W.T, W stored as [out, in]) ----
DeviceBuffer q_dev(S * Q_DIM * 2);
DeviceBuffer k_dev(S * KV_DIM * 2);
DeviceBuffer v_dev(S * KV_DIM * 2);
auto t_q = make_contig_tensor(q_dev.get(), ACL_BF16, {S, Q_DIM});
auto t_k = make_contig_tensor(k_dev.get(), ACL_BF16, {S, KV_DIM});
auto t_v = make_contig_tensor(v_dev.get(), ACL_BF16, {S, KV_DIM});
linear_hf(rt.stream(), t_xn.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
linear_hf(rt.stream(), t_xn.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
linear_hf(rt.stream(), t_xn.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());
// ---- Reshape Q, K as [B=1, S, N, Dh] for q_norm/k_norm + RoPE ----
// Same memory; just new views.
// q_dev has S * Q_DIM = S * Hq * Dh BF16
auto t_q_4d = make_contig_tensor(q_dev.get(), ACL_BF16, {1, S, Hq, Dh});
auto t_k_4d = make_contig_tensor(k_dev.get(), ACL_BF16, {1, S, Hkv, Dh});
// Per-head RmsNorm on last dim (gamma shape [Dh])
auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
DeviceBuffer rstd_q_dev(S * Hq * 4); // rstd shape = q's all-but-last dims
DeviceBuffer rstd_k_dev(S * Hkv * 4);
auto t_rstd_q = make_contig_tensor(rstd_q_dev.get(), ACL_FLOAT, {1, S, Hq});
auto t_rstd_k = make_contig_tensor(rstd_k_dev.get(), ACL_FLOAT, {1, S, Hkv});
// RmsNorm in place on q/k
rms_norm(rt.stream(), t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
rms_norm(rt.stream(), t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());
// ---- Compute cos/sin on device ----
// cos/sin shape [1, S, Dh] BF16
std::vector<uint16_t> cos_host(S * Dh), sin_host(S * Dh);
for (int s = 0; s < S; s++) {
for (int64_t d = 0; d < Dh; d++) {
// freq index: for half-half layout, index d corresponds to pair index (d % (Dh/2))
int64_t half = Dh / 2;
int64_t pair = (d < half) ? d : (d - half);
float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
float angle = (float)s * theta_pair;
cos_host[s * Dh + d] = float_to_bf16(std::cos(angle));
sin_host[s * Dh + d] = float_to_bf16(std::sin(angle));
}
}
DeviceBuffer cos_dev(S * Dh * 2);
DeviceBuffer sin_dev(S * Dh * 2);
ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_host.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_host.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
// ---- RoPE ----
DeviceBuffer rope_scratch(1 * S * Hq * Dh * 2);
apply_rope_manual(rt.stream(),
q_dev.get(), 1, S, Hq, Dh,
k_dev.get(), Hkv,
cos_dev.get(), sin_dev.get(),
rope_scratch.get());
// ---- FIAS ----
// q/k/v are reshaped back to BSH [1, S, Hq*Dh or Hkv*Dh]
auto t_q_bsh = make_contig_tensor(q_dev.get(), ACL_BF16, {1, S, Q_DIM});
auto t_k_bsh = make_contig_tensor(k_dev.get(), ACL_BF16, {1, S, KV_DIM});
auto t_v_bsh = make_contig_tensor(v_dev.get(), ACL_BF16, {1, S, KV_DIM});
// Causal mask 2048x2048 (sparse_mode=3 requires fixed size)
const int64_t MASK = 2048;
DeviceBuffer mask_dev(MASK * MASK); // bool = 1 byte
std::vector<uint8_t> mask_host(MASK * MASK, 0);
for (int i = 0; i < MASK; i++)
for (int j = i+1; j < MASK; j++)
mask_host[i*MASK + j] = 1; // upper triangular = True
ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});
DeviceBuffer attn_out_dev(1 * S * Q_DIM * 2);
auto t_attn_out = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {1, S, Q_DIM});
fused_infer_attention_score(
rt.stream(),
t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
t_mask.get(),
{S}, {S},
Hq, Hkv,
scale,
3, // sparse_mode = causal
t_attn_out.get());
// ---- O projection ----
auto t_attn_out_2d = make_contig_tensor(attn_out_dev.get(), ACL_BF16, {S, Q_DIM});
DeviceBuffer o_dev(S * D * 2);
auto t_o = make_contig_tensor(o_dev.get(), ACL_BF16, {S, D});
linear_hf(rt.stream(), t_attn_out_2d.get(), attn.o_proj.get(), ACL_BF16, D, Q_DIM, t_o.get());
// ---- Residual add: out = residual + o ----
auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D});
float alpha_v = 1.0f;
aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT);
DeviceBuffer out_dev(S * D * 2);
auto t_out = make_contig_tensor(out_dev.get(), ACL_BF16, {S, D});
{
uint64_t ws = 0; aclOpExecutor* e = nullptr;
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_o.get(), alpha, t_out.get(), &ws, &e));
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream()));
}
aclDestroyScalar(alpha);
rt.sync();
// ---- Compare with Python reference ----
auto ref_h = read_file(data_dir + "/final_out.bin");
std::vector<uint16_t> cxx(S * D);
ACL_CHECK(aclrtMemcpy(cxx.data(), S*D*2, out_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
auto* ref = (const uint16_t*)ref_h.data();
double l2d = 0, l2r = 0, maxd = 0;
for (int i = 0; i < S * D; i++) {
float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]);
l2d += (a-b)*(a-b); l2r += b*b;
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
}
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
printf("\nAttention layer output compare: rel=%.4e max_abs=%.4f\n", rel, maxd);
printf(" cxx[0, :4]: "); for (int i = 0; i < 4; i++) printf("%.6f ", bf16_to_float(cxx[i]));
printf("\n ref[0, :4]: "); for (int i = 0; i < 4; i++) printf("%.6f ", bf16_to_float(ref[i])); printf("\n");
bool pass = rel < 5e-2; // BF16 accumulation across 5+ ops loses ~1-2% per step
printf("\n%s\n", pass ? "=== test_attention_layer PASS ===" : "=== test_attention_layer FAIL ===");
return pass ? 0 : 1;
}