| |
| |
| |
| |
| |
| |
| |
| #include "acl_common.h" |
| #include "acl_runtime.h" |
| #include "aclnn_ops.h" |
| #include "device_weights.h" |
| #include "model_config.h" |
| #include "rope.h" |
| #include "safetensors_loader.h" |
|
|
| #include <cmath> |
| #include <cstdio> |
| #include <cstring> |
| #include <fstream> |
| #include <vector> |
|
|
// Widen a bfloat16 bit pattern to IEEE float: bf16 is the top 16 bits of f32.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = (uint32_t)x << 16;  // zero-fill the low mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
// Narrow an IEEE float to bfloat16 with round-to-nearest-even.
// NaN is handled explicitly: the RNE carry (+0x7FFF) can propagate into the
// exponent and turn a NaN whose payload sits only in the low 16 mantissa bits
// (e.g. 0x7F800001) into +/-Inf. Return a quiet NaN preserving the sign instead.
static uint16_t float_to_bf16(float x) {
    uint32_t u; std::memcpy(&u, &x, 4);
    if ((u & 0x7FFFFFFFu) > 0x7F800000u)       // NaN (any payload)
        return (uint16_t)((u >> 16) | 0x0040); // force a quiet-NaN mantissa bit
    // Round-to-nearest-even: add 0x7FFF plus the LSB of the kept half.
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
// Read an entire binary file into memory.
// Returns an empty vector (with a diagnostic on stderr) if the file cannot be
// opened or fully read. The original code fed tellg()'s -1 failure value into
// a size_t, requesting a ~2^64-byte vector on a missing file.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) {
        fprintf(stderr, "read_file: cannot open %s\n", p.c_str());
        return {};
    }
    const std::streamoff size = f.tellg();
    if (size <= 0) return {};  // empty file or tellg failure
    f.seekg(0);
    std::vector<uint8_t> v((size_t)size);
    if (!f.read((char*)v.data(), size)) {
        fprintf(stderr, "read_file: short read on %s\n", p.c_str());
        return {};
    }
    return v;
}
|
|
| |
| static void fill_cos_sin(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h, |
| int64_t p0, int64_t L, int64_t Dh, float theta) { |
| cos_h.resize(L * Dh); sin_h.resize(L * Dh); |
| int64_t half = Dh / 2; |
| for (int64_t s = 0; s < L; s++) { |
| for (int64_t d = 0; d < Dh; d++) { |
| int64_t pair = (d < half) ? d : (d - half); |
| float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh); |
| float angle = (float)(p0 + s) * theta_pair; |
| cos_h[s * Dh + d] = float_to_bf16(std::cos(angle)); |
| sin_h[s * Dh + d] = float_to_bf16(std::sin(angle)); |
| } |
| } |
| } |
|
|
// End-to-end equivalence test for layer-0 attention of a Qwen3-style model:
//   Path A: full 6-token prefill; take the attention output at position 5.
//   Path B: 5-token prefill written into a KV cache, then a single decode step
//           for the 6th token against that cache.
// The decode output must match the prefill reference row to bf16 tolerance
// (rel L2 < 5e-2). Exit code 0 on pass, 1 on any failure.
int main() {
    const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    const std::string data_dir = "tests/attn_data";

    // --- Model geometry from config.json ---
    ModelConfig cfg;
    if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
    cfg.compute_derived(1, 0);
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.num_attention_heads;    // query heads
    const int64_t Hkv = cfg.num_key_value_heads;   // KV heads (GQA: Hkv <= Hq)
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh); // softmax scaling 1/sqrt(Dh)
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    // --- Weights + device runtime (device 0) ---
    SafetensorsLoader st;
    if (!st.open(model_dir)) return 1;
    AclRuntime rt;
    rt.init(0);

    DeviceWeightsLoader dw(st, cfg);
    SharedWeights shared;
    LayerAttnWeights attn;
    printf("Loading weights...\n");
    if (!dw.load_shared(shared)) return 1;
    if (!dw.load_attention(0, attn)) return 1;  // layer 0 only

    // --- Test tokens: first int32 of the file is the count, payload follows ---
    auto tok_raw = read_file(data_dir + "/token_ids.bin");
    int32_t S_prefill = *(int32_t*)tok_raw.data();
    if (S_prefill < 5) { fprintf(stderr, "need >=5 tokens\n"); return 1; }
    std::vector<int32_t> tokens(S_prefill);
    std::memcpy(tokens.data(), tok_raw.data() + 4, S_prefill * 4);

    // Build a fixed 6-token sequence: first 5 tokens, then token[0] reused as
    // the "decode" token so both paths see identical inputs.
    const int64_t S6 = 6;
    const int64_t S5 = 5;
    std::vector<int32_t> tok6(S6);
    for (int i = 0; i < S5; i++) tok6[i] = tokens[i];
    tok6[5] = tokens[0];
    printf("tokens6=["); for (auto t : tok6) printf("%d,", t); printf("]\n");

    // --- Causal attention mask [1,1,MASK,MASK] ---
    // Upper triangle (j > i) set to 1; presumably true == masked-out per the
    // fused-infer-attention convention — confirm against aclnn docs.
    const int64_t MASK = 2048;
    DeviceBuffer mask_dev(MASK * MASK);
    std::vector<uint8_t> mask_host(MASK * MASK, 0);
    for (int i = 0; i < MASK; i++)
        for (int j = i+1; j < MASK; j++)
            mask_host[i*MASK + j] = 1;
    ACL_CHECK(aclrtMemcpy(mask_dev.get(), MASK*MASK, mask_host.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_mask = make_contig_tensor(mask_dev.get(), ACL_BOOL, {1, 1, MASK, MASK});

    // ======================================================================
    // Path A: 6-token prefill reference
    // ======================================================================
    printf("\n[Path A] 6-token prefill reference\n");

    // Token ids to device, then embedding gather: xA = embed_tokens[tok6]
    DeviceBuffer tokA_dev(S6 * 4);
    ACL_CHECK(aclrtMemcpy(tokA_dev.get(), S6*4, tok6.data(), S6*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokA = make_contig_tensor(tokA_dev.get(), ACL_INT32, {S6});
    auto t_embed_w = make_contig_tensor(shared.embed_tokens.get(), ACL_BF16, {cfg.vocab_size, D});

    DeviceBuffer xA_dev(S6 * D * 2);  // *2 = bf16 element size, here and below
    auto t_xA = make_contig_tensor(xA_dev.get(), ACL_BF16, {S6, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokA.get(), t_xA.get());
    rt.sync();

    // Pre-attention RMSNorm (rstd output required by the op but unused here).
    DeviceBuffer xnA_dev(S6 * D * 2);
    DeviceBuffer rstdA_dev(S6 * 4);
    auto t_xnA = make_contig_tensor(xnA_dev.get(), ACL_BF16, {S6, D});
    auto t_ln_w = make_contig_tensor(attn.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstdA = make_contig_tensor(rstdA_dev.get(), ACL_FLOAT, {S6});
    rms_norm(rt.stream(), t_xA.get(), t_ln_w.get(), eps, t_xnA.get(), t_rstdA.get());

    // Q/K/V projections.
    DeviceBuffer qA_dev(S6 * Q_DIM * 2);
    DeviceBuffer kA_dev(S6 * KV_DIM * 2);
    DeviceBuffer vA_dev(S6 * KV_DIM * 2);
    auto t_qA = make_contig_tensor(qA_dev.get(), ACL_BF16, {S6, Q_DIM});
    auto t_kA = make_contig_tensor(kA_dev.get(), ACL_BF16, {S6, KV_DIM});
    auto t_vA = make_contig_tensor(vA_dev.get(), ACL_BF16, {S6, KV_DIM});
    linear_hf(rt.stream(), t_xnA.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qA.get());
    linear_hf(rt.stream(), t_xnA.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kA.get());
    linear_hf(rt.stream(), t_xnA.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vA.get());

    // Per-head QK-norm (Qwen3-style): RMSNorm over the last Dh axis, in place.
    auto t_qA_4d = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Hq, Dh});
    auto t_kA_4d = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(attn.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(attn.k_norm.get(), ACL_BF16, {Dh});
    DeviceBuffer rstd_qA(S6 * Hq * 4), rstd_kA(S6 * Hkv * 4);
    auto t_rstd_qA = make_contig_tensor(rstd_qA.get(), ACL_FLOAT, {1, S6, Hq});
    auto t_rstd_kA = make_contig_tensor(rstd_kA.get(), ACL_FLOAT, {1, S6, Hkv});
    rms_norm(rt.stream(), t_qA_4d.get(), t_qn_w.get(), eps, t_qA_4d.get(), t_rstd_qA.get());
    rms_norm(rt.stream(), t_kA_4d.get(), t_kn_w.get(), eps, t_kA_4d.get(), t_rstd_kA.get());

    // RoPE for positions 0..5, applied in place to q and k (scratch sized for q).
    std::vector<uint16_t> cosA_h, sinA_h;
    fill_cos_sin(cosA_h, sinA_h, 0, S6, Dh, theta);
    DeviceBuffer cosA_dev(S6 * Dh * 2), sinA_dev(S6 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosA_dev.get(), S6*Dh*2, cosA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinA_dev.get(), S6*Dh*2, sinA_h.data(), S6*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeA_scratch(1 * S6 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qA_dev.get(), 1, S6, Hq, Dh, kA_dev.get(), Hkv,
                      cosA_dev.get(), sinA_dev.get(), ropeA_scratch.get());

    // BSH-layout views for the fused attention op.
    auto t_qA_bsh = make_contig_tensor(qA_dev.get(), ACL_BF16, {1, S6, Q_DIM});
    auto t_kA_bsh = make_contig_tensor(kA_dev.get(), ACL_BF16, {1, S6, KV_DIM});
    auto t_vA_bsh = make_contig_tensor(vA_dev.get(), ACL_BF16, {1, S6, KV_DIM});

    // Fused attention over all 6 positions with the causal mask.
    // NOTE(review): trailing arg 3 is presumably the sparse/mask mode for
    // masked prefill — confirm against aclnnFusedInferAttentionScore docs.
    DeviceBuffer attnA_out(1 * S6 * Q_DIM * 2);
    auto t_attnA_out = make_contig_tensor(attnA_out.get(), ACL_BF16, {1, S6, Q_DIM});
    fused_infer_attention_score(
        rt.stream(), t_qA_bsh.get(), t_kA_bsh.get(), t_vA_bsh.get(),
        t_mask.get(), {S6}, {S6}, Hq, Hkv, scale, 3, t_attnA_out.get());
    rt.sync();

    // Reference: attention output row at position 5 (the last token).
    std::vector<uint16_t> refA_host(Q_DIM);
    ACL_CHECK(aclrtMemcpy(refA_host.data(), Q_DIM*2,
                          (char*)attnA_out.get() + 5 * Q_DIM * 2, Q_DIM*2,
                          ACL_MEMCPY_DEVICE_TO_HOST));
    printf("  attnA_out[5, :4] = %.5f %.5f %.5f %.5f\n",
           bf16_to_float(refA_host[0]), bf16_to_float(refA_host[1]),
           bf16_to_float(refA_host[2]), bf16_to_float(refA_host[3]));

    // ======================================================================
    // Path B: 5-token prefill + 1 decode step via KV cache
    // ======================================================================
    printf("\n[Path B] 5-prefill + 1-decode via KV cache\n");

    // Contiguous KV cache, one [KV_DIM] bf16 row per position.
    const int64_t MAX_LEN = 128;
    DeviceBuffer k_cache(MAX_LEN * KV_DIM * 2);
    DeviceBuffer v_cache(MAX_LEN * KV_DIM * 2);

    // -- Prefill the first 5 tokens (same pipeline as Path A) --
    DeviceBuffer tokB_dev(S5 * 4);
    ACL_CHECK(aclrtMemcpy(tokB_dev.get(), S5*4, tok6.data(), S5*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokB = make_contig_tensor(tokB_dev.get(), ACL_INT32, {S5});
    DeviceBuffer xB_dev(S5 * D * 2);
    auto t_xB = make_contig_tensor(xB_dev.get(), ACL_BF16, {S5, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokB.get(), t_xB.get());
    rt.sync();

    DeviceBuffer xnB_dev(S5 * D * 2);
    DeviceBuffer rstdB_dev(S5 * 4);
    auto t_xnB = make_contig_tensor(xnB_dev.get(), ACL_BF16, {S5, D});
    auto t_rstdB = make_contig_tensor(rstdB_dev.get(), ACL_FLOAT, {S5});
    rms_norm(rt.stream(), t_xB.get(), t_ln_w.get(), eps, t_xnB.get(), t_rstdB.get());

    DeviceBuffer qB_dev(S5 * Q_DIM * 2);
    DeviceBuffer kB_dev(S5 * KV_DIM * 2);
    DeviceBuffer vB_dev(S5 * KV_DIM * 2);
    auto t_qB = make_contig_tensor(qB_dev.get(), ACL_BF16, {S5, Q_DIM});
    auto t_kB = make_contig_tensor(kB_dev.get(), ACL_BF16, {S5, KV_DIM});
    auto t_vB = make_contig_tensor(vB_dev.get(), ACL_BF16, {S5, KV_DIM});
    linear_hf(rt.stream(), t_xnB.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qB.get());
    linear_hf(rt.stream(), t_xnB.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kB.get());
    linear_hf(rt.stream(), t_xnB.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vB.get());

    auto t_qB_4d = make_contig_tensor(qB_dev.get(), ACL_BF16, {1, S5, Hq, Dh});
    auto t_kB_4d = make_contig_tensor(kB_dev.get(), ACL_BF16, {1, S5, Hkv, Dh});
    DeviceBuffer rstd_qB(S5 * Hq * 4), rstd_kB(S5 * Hkv * 4);
    auto t_rstd_qB = make_contig_tensor(rstd_qB.get(), ACL_FLOAT, {1, S5, Hq});
    auto t_rstd_kB = make_contig_tensor(rstd_kB.get(), ACL_FLOAT, {1, S5, Hkv});
    rms_norm(rt.stream(), t_qB_4d.get(), t_qn_w.get(), eps, t_qB_4d.get(), t_rstd_qB.get());
    rms_norm(rt.stream(), t_kB_4d.get(), t_kn_w.get(), eps, t_kB_4d.get(), t_rstd_kB.get());

    // RoPE for positions 0..4.
    std::vector<uint16_t> cosB_h, sinB_h;
    fill_cos_sin(cosB_h, sinB_h, 0, S5, Dh, theta);
    DeviceBuffer cosB_dev(S5 * Dh * 2), sinB_dev(S5 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosB_dev.get(), S5*Dh*2, cosB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinB_dev.get(), S5*Dh*2, sinB_h.data(), S5*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeB_scratch(1 * S5 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qB_dev.get(), 1, S5, Hq, Dh, kB_dev.get(), Hkv,
                      cosB_dev.get(), sinB_dev.get(), ropeB_scratch.get());
    rt.sync();

    // Write post-RoPE K/V for positions 0..4 into the cache.
    ACL_CHECK(aclrtMemcpy(k_cache.get(), S5 * KV_DIM * 2,
                          kB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(v_cache.get(), S5 * KV_DIM * 2,
                          vB_dev.get(), S5 * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    // NOTE(review): %ld assumes LP64 for int64_t; PRId64 would be portable.
    printf("  cached K/V at positions 0..%ld\n", S5 - 1);

    // -- Decode step: single token at position 5 --
    DeviceBuffer tokD_dev(1 * 4);
    int32_t tok_dec = tok6[5];
    ACL_CHECK(aclrtMemcpy(tokD_dev.get(), 4, &tok_dec, 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tokD = make_contig_tensor(tokD_dev.get(), ACL_INT32, {1});
    DeviceBuffer xD_dev(1 * D * 2);
    auto t_xD = make_contig_tensor(xD_dev.get(), ACL_BF16, {1, D});
    index_select(rt.stream(), t_embed_w.get(), 0, t_tokD.get(), t_xD.get());

    DeviceBuffer xnD_dev(1 * D * 2), rstdD_dev(1 * 4);
    auto t_xnD = make_contig_tensor(xnD_dev.get(), ACL_BF16, {1, D});
    auto t_rstdD = make_contig_tensor(rstdD_dev.get(), ACL_FLOAT, {1});
    rms_norm(rt.stream(), t_xD.get(), t_ln_w.get(), eps, t_xnD.get(), t_rstdD.get());

    DeviceBuffer qD_dev(1 * Q_DIM * 2), kD_dev(1 * KV_DIM * 2), vD_dev(1 * KV_DIM * 2);
    auto t_qD = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, Q_DIM});
    auto t_kD = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, KV_DIM});
    auto t_vD = make_contig_tensor(vD_dev.get(), ACL_BF16, {1, KV_DIM});
    linear_hf(rt.stream(), t_xnD.get(), attn.q_proj.get(), ACL_BF16, Q_DIM, D, t_qD.get());
    linear_hf(rt.stream(), t_xnD.get(), attn.k_proj.get(), ACL_BF16, KV_DIM, D, t_kD.get());
    linear_hf(rt.stream(), t_xnD.get(), attn.v_proj.get(), ACL_BF16, KV_DIM, D, t_vD.get());

    auto t_qD_4d = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Hq, Dh});
    auto t_kD_4d = make_contig_tensor(kD_dev.get(), ACL_BF16, {1, 1, Hkv, Dh});
    DeviceBuffer rstd_qD(1 * Hq * 4), rstd_kD(1 * Hkv * 4);
    auto t_rstd_qD = make_contig_tensor(rstd_qD.get(), ACL_FLOAT, {1, 1, Hq});
    auto t_rstd_kD = make_contig_tensor(rstd_kD.get(), ACL_FLOAT, {1, 1, Hkv});
    rms_norm(rt.stream(), t_qD_4d.get(), t_qn_w.get(), eps, t_qD_4d.get(), t_rstd_qD.get());
    rms_norm(rt.stream(), t_kD_4d.get(), t_kn_w.get(), eps, t_kD_4d.get(), t_rstd_kD.get());

    // RoPE for the single decode token at absolute position 5.
    std::vector<uint16_t> cosD_h, sinD_h;
    fill_cos_sin(cosD_h, sinD_h, 5, 1, Dh, theta);
    DeviceBuffer cosD_dev(1 * Dh * 2), sinD_dev(1 * Dh * 2);
    ACL_CHECK(aclrtMemcpy(cosD_dev.get(), Dh*2, cosD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sinD_dev.get(), Dh*2, sinD_h.data(), Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
    DeviceBuffer ropeD_scratch(1 * 1 * Hq * Dh * 2);
    apply_rope_manual(rt.stream(), qD_dev.get(), 1, 1, Hq, Dh, kD_dev.get(), Hkv,
                      cosD_dev.get(), sinD_dev.get(), ropeD_scratch.get());
    rt.sync();

    // Append the decode-step K/V at cache slot 5.
    ACL_CHECK(aclrtMemcpy((char*)k_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
                          kD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy((char*)v_cache.get() + S5 * KV_DIM * 2, KV_DIM * 2,
                          vD_dev.get(), KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE));

    // Views over the now-6-entry cache; query is the single decode token.
    auto t_qD_bsh = make_contig_tensor(qD_dev.get(), ACL_BF16, {1, 1, Q_DIM});
    auto t_kC_bsh = make_contig_tensor(k_cache.get(), ACL_BF16, {1, S6, KV_DIM});
    auto t_vC_bsh = make_contig_tensor(v_cache.get(), ACL_BF16, {1, S6, KV_DIM});

    DeviceBuffer attnD_out(1 * 1 * Q_DIM * 2);
    auto t_attnD_out = make_contig_tensor(attnD_out.get(), ACL_BF16, {1, 1, Q_DIM});

    // Decode attention: no mask needed (the last query may see all 6 cached
    // positions). NOTE(review): mode 0 presumably means "no sparse mask" —
    // confirm against aclnnFusedInferAttentionScore docs.
    fused_infer_attention_score(
        rt.stream(), t_qD_bsh.get(), t_kC_bsh.get(), t_vC_bsh.get(),
        nullptr, {1}, {S6},
        Hq, Hkv, scale, 0, t_attnD_out.get());
    rt.sync();

    std::vector<uint16_t> decB_host(Q_DIM);
    ACL_CHECK(aclrtMemcpy(decB_host.data(), Q_DIM*2, attnD_out.get(), Q_DIM*2, ACL_MEMCPY_DEVICE_TO_HOST));

    printf("\n  attnB_decode[:4] = %.5f %.5f %.5f %.5f\n",
           bf16_to_float(decB_host[0]), bf16_to_float(decB_host[1]),
           bf16_to_float(decB_host[2]), bf16_to_float(decB_host[3]));

    // Relative L2 error and max absolute difference vs the prefill reference.
    double l2d = 0, l2r = 0, maxd = 0;
    for (int i = 0; i < Q_DIM; i++) {
        float a = bf16_to_float(decB_host[i]), b = bf16_to_float(refA_host[i]);
        l2d += (a-b)*(a-b); l2r += b*b;
        if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
    }
    double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
    printf("\nDecode vs 6-prefill comparison: rel=%.4e max_abs=%.4f\n", rel, maxd);

    // Loose 5% tolerance: both paths run in bf16, so bitwise equality is not expected.
    bool pass = rel < 5e-2;
    printf("\n%s\n", pass ? "=== test_attention_decode PASS ===" : "=== test_attention_decode FAIL ===");
    return pass ? 0 : 1;
}
|
|