// test_rope_fused.cpp — test aclnnApplyRotaryPosEmbV2 vs our manual 8-op HF RoPE. // If rotaryMode="half" matches HF, we can replace apply_rope_manual with 1 op → 7× reduction // of per-layer op count for RoPE phase. #include "acl_common.h" #include "acl_runtime.h" #include "aclnn_ops.h" #include "rope.h" #include "engine.h" // for fill_cos_sin_hf + RopeCache #include #include #include #include #include static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; } static uint16_t f_to_bf16(float f) { uint32_t u; std::memcpy(&u, &f, 4); return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16); } int main() { AclRuntime rt; rt.init(0); // Test shape: 1 batch, 5 seq, 4 heads, head_dim=128 (Qwen3-like) const int64_t B = 1, S = 5, Hq = 4, Hkv = 4, Dh = 128; const float theta = 5e6f; // Qwen3 theta // Random q, k (deterministic from seed) std::vector h_q(B * S * Hq * Dh), h_k(B * S * Hkv * Dh); uint32_t seed = 42; auto rnd = [&seed]() { seed = seed * 1103515245 + 12345; return f_to_bf16(((seed >> 16) / 32768.0f - 1.0f) * 0.1f); }; for (auto& x : h_q) x = rnd(); for (auto& x : h_k) x = rnd(); // cos/sin cache (positions 0..S-1) std::vector cos_h, sin_h; fill_cos_sin_hf(cos_h, sin_h, 0, S, Dh, theta); DeviceBuffer q1(h_q.size() * 2), k1(h_k.size() * 2); DeviceBuffer q2(h_q.size() * 2), k2(h_k.size() * 2); DeviceBuffer cos_dev(cos_h.size() * 2), sin_dev(sin_h.size() * 2); DeviceBuffer scratch(B * S * Hq * Dh * 2); ACL_CHECK(aclrtMemcpy(q1.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(q2.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(k1.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(k2.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(cos_dev.get(), cos_h.size()*2, cos_h.data(), cos_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy(sin_dev.get(), sin_h.size()*2, sin_h.data(), sin_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE)); // --- Path 1: our manual HF RoPE --- apply_rope_manual(rt.stream(), q1.get(), B, S, Hq, Dh, k1.get(), Hkv, cos_dev.get(), sin_dev.get(), scratch.get()); rt.sync(); std::vector q1_out(h_q.size()), k1_out(h_k.size()); ACL_CHECK(aclrtMemcpy(q1_out.data(), h_q.size()*2, q1.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(k1_out.data(), h_k.size()*2, k1.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); // --- Path 2: aclnnApplyRotaryPosEmbV2 with rotaryMode="half" --- // Layout: see docs. Common: 0=BSND, 1=SBND, 2=BNSD. q/k shape [B, S, N, Dh]. // cos/sin shape: typically [1, S, 1, Dh] or [S, Dh]. // Try multiple combinations until one works struct Try { int64_t layout; const char* mode; std::vector qshape; std::vector cshape; }; std::vector tries = { {0, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}}, {1, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}}, {2, "half", {B, Hq, S, Dh}, {1, 1, S, Dh}}, {0, "half", {B, S, Hq, Dh}, {S, Dh}}, {0, "interleaved", {B, S, Hq, Dh}, {1, S, 1, Dh}}, {0, "half", {S, Hq, Dh}, {S, 1, Dh}}, }; uint64_t ws = 0; aclOpExecutor* exec = nullptr; aclnnStatus s1 = -1; Try chosen{}; for (auto& t : tries) { auto t_q = make_contig_tensor(q2.get(), ACL_BF16, t.qshape); std::vector kshape = t.qshape; if (kshape.size() >= 3) kshape[kshape.size()-2] = Hkv; auto t_k = make_contig_tensor(k2.get(), ACL_BF16, kshape); auto t_cos = make_contig_tensor(cos_dev.get(), ACL_BF16, t.cshape); auto t_sin = make_contig_tensor(sin_dev.get(), ACL_BF16, t.cshape); char buf[32]; strncpy(buf, t.mode, sizeof(buf)); s1 = aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q.get(), t_k.get(), t_cos.get(), t_sin.get(), t.layout, buf, &ws, &exec); printf("[ropev2] layout=%ld mode=%-12s qshape=%zu cshape=%zu → status=%d\n", t.layout, t.mode, t.qshape.size(), t.cshape.size(), (int)s1); if (s1 == 0) { chosen = t; break; } } if (s1 != 0) { fprintf(stderr, "All combos failed\n"); return 1; } printf("→ winning: layout=%ld mode=%s\n", chosen.layout, chosen.mode); DeviceBuffer wb; if (ws > 0) wb.alloc(ws); s1 = aclnnApplyRotaryPosEmbV2(wb.get(), ws, exec, rt.stream()); printf("[ropev2] exec: status=%d\n", (int)s1); if (s1 != 0) return 1; rt.sync(); std::vector q2_out(h_q.size()), k2_out(h_k.size()); ACL_CHECK(aclrtMemcpy(q2_out.data(), h_q.size()*2, q2.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(k2_out.data(), h_k.size()*2, k2.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST)); // Compare double q_l2d = 0, q_l2r = 0, q_max = 0; for (size_t i = 0; i < h_q.size(); i++) { float a = bf16_to_float(q1_out[i]), b = bf16_to_float(q2_out[i]); q_l2d += (a-b)*(a-b); q_l2r += a*a; if (std::abs(a-b) > q_max) q_max = std::abs(a-b); } double q_rel = std::sqrt(q_l2d) / (std::sqrt(q_l2r) + 1e-10); double k_l2d = 0, k_l2r = 0, k_max = 0; for (size_t i = 0; i < h_k.size(); i++) { float a = bf16_to_float(k1_out[i]), b = bf16_to_float(k2_out[i]); k_l2d += (a-b)*(a-b); k_l2r += a*a; if (std::abs(a-b) > k_max) k_max = std::abs(a-b); } double k_rel = std::sqrt(k_l2d) / (std::sqrt(k_l2r) + 1e-10); printf("\nManual-HF vs aclnnApplyRotaryPosEmbV2(layout=0, mode=half):\n"); printf(" Q: rel=%.4e max=%.4f\n", q_rel, q_max); printf(" K: rel=%.4e max=%.4f\n", k_rel, k_max); printf(" Q[0,:4] manual: %.5f %.5f %.5f %.5f\n", bf16_to_float(q1_out[0]), bf16_to_float(q1_out[1]), bf16_to_float(q1_out[2]), bf16_to_float(q1_out[3])); printf(" Q[0,:4] ropev2: %.5f %.5f %.5f %.5f\n", bf16_to_float(q2_out[0]), bf16_to_float(q2_out[1]), bf16_to_float(q2_out[2]), bf16_to_float(q2_out[3])); bool pass = q_rel < 1e-2 && k_rel < 1e-2; printf("\n%s\n", pass ? "=== RoPE V2 matches manual HF ===" : "=== MISMATCH — need different mode/layout ==="); return pass ? 0 : 1; }