File size: 6,641 Bytes
// test_rope_fused.cpp — test aclnnApplyRotaryPosEmbV2 vs our manual 8-op HF RoPE.
// If rotaryMode="half" matches HF, we can replace apply_rope_manual with a single op,
// cutting the per-layer op count of the RoPE phase from 8 to 1.
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "rope.h"
#include "engine.h" // for fill_cos_sin_hf + RopeCache
#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>
// Widen a bfloat16 bit pattern to float: the 16 input bits become the upper
// half of a 32-bit word (lower half zero), which is then reinterpreted as a
// float via memcpy.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(result));
    return result;
}
// Narrow a float to a bfloat16 bit pattern using round-to-nearest-even on the
// 16 discarded mantissa bits.
//
// NaN inputs are handled separately: the rounding add below can carry a NaN
// with a small payload into the Inf encoding, or wrap a NaN with a large
// payload around to zero. For NaN we instead keep the top 16 bits and force a
// mantissa bit on, so the result is always a (quiet) NaN with the same sign.
static uint16_t f_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    // Exponent all ones and non-zero mantissa => NaN.
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        return static_cast<uint16_t>((u >> 16) | 0x0040u);
    }
    // Bias of 0x7FFF plus the lowest kept bit implements ties-to-even.
    const uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<uint16_t>((u + rounding_bias) >> 16);
}
// Compares apply_rope_manual against aclnnApplyRotaryPosEmbV2 on identical
// pseudo-random bf16 Q/K inputs.
//
// Flow:
//   1. Build deterministic Q/K host data and a cos/sin table via fill_cos_sin_hf.
//   2. Upload two copies of Q/K (one per path) plus cos/sin to the device.
//   3. Path 1: run apply_rope_manual on the first copy and read it back.
//   4. Path 2: probe a list of (layout, mode, shape) combinations with
//      aclnnApplyRotaryPosEmbV2GetWorkspaceSize; launch the first one accepted.
//   5. Report relative L2 error and max absolute error for Q and K.
//
// Returns 0 when both relative errors are below 1e-2, 1 otherwise (including
// when no combination is accepted or the launch reports a non-zero status).
int main() {
    AclRuntime rt;
    rt.init(0);

    // Test shape: 1 batch, 5 seq, 4 heads, head_dim=128 (Qwen3-like)
    const int64_t B = 1, S = 5, Hq = 4, Hkv = 4, Dh = 128;
    const float theta = 5e6f;  // Qwen3 theta

    // Random q, k (deterministic from seed)
    std::vector<uint16_t> h_q(B * S * Hq * Dh), h_k(B * S * Hkv * Dh);
    uint32_t seed = 42;
    auto rnd = [&seed]() {
        seed = seed * 1103515245 + 12345;
        return f_to_bf16(((seed >> 16) / 32768.0f - 1.0f) * 0.1f);
    };
    for (auto& x : h_q) x = rnd();
    for (auto& x : h_k) x = rnd();

    // cos/sin cache (positions 0..S-1)
    std::vector<uint16_t> cos_h, sin_h;
    fill_cos_sin_hf(cos_h, sin_h, 0, S, Dh, theta);

    // Byte sizes of the bf16 host arrays (2 bytes per element).
    const size_t q_bytes = h_q.size() * sizeof(uint16_t);
    const size_t k_bytes = h_k.size() * sizeof(uint16_t);
    const size_t cos_bytes = cos_h.size() * sizeof(uint16_t);
    const size_t sin_bytes = sin_h.size() * sizeof(uint16_t);

    DeviceBuffer q1(q_bytes), k1(k_bytes);
    DeviceBuffer q2(q_bytes), k2(k_bytes);
    DeviceBuffer cos_dev(cos_bytes), sin_dev(sin_bytes);
    DeviceBuffer scratch(B * S * Hq * Dh * 2);
    ACL_CHECK(aclrtMemcpy(q1.get(), q_bytes, h_q.data(), q_bytes, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(q2.get(), q_bytes, h_q.data(), q_bytes, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(k1.get(), k_bytes, h_k.data(), k_bytes, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(k2.get(), k_bytes, h_k.data(), k_bytes, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(cos_dev.get(), cos_bytes, cos_h.data(), cos_bytes, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sin_dev.get(), sin_bytes, sin_h.data(), sin_bytes, ACL_MEMCPY_HOST_TO_DEVICE));

    // --- Path 1: our manual HF RoPE ---
    apply_rope_manual(rt.stream(), q1.get(), B, S, Hq, Dh, k1.get(), Hkv,
                      cos_dev.get(), sin_dev.get(), scratch.get());
    rt.sync();
    std::vector<uint16_t> q1_out(h_q.size()), k1_out(h_k.size());
    ACL_CHECK(aclrtMemcpy(q1_out.data(), q_bytes, q1.get(), q_bytes, ACL_MEMCPY_DEVICE_TO_HOST));
    ACL_CHECK(aclrtMemcpy(k1_out.data(), k_bytes, k1.get(), k_bytes, ACL_MEMCPY_DEVICE_TO_HOST));

    // --- Path 2: aclnnApplyRotaryPosEmbV2 with rotaryMode="half" ---
    // Layout: see docs. Common: 0=BSND, 1=SBND, 2=BNSD. q/k shape [B, S, N, Dh].
    // cos/sin shape: typically [1, S, 1, Dh] or [S, Dh].
    // Try multiple combinations until one works
    struct Try {
        int64_t layout;
        const char* mode;
        std::vector<int64_t> qshape;
        std::vector<int64_t> cshape;
    };
    std::vector<Try> tries = {
        {0, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}},
        {1, "half", {B, S, Hq, Dh}, {1, S, 1, Dh}},
        {2, "half", {B, Hq, S, Dh}, {1, 1, S, Dh}},
        {0, "half", {B, S, Hq, Dh}, {S, Dh}},
        {0, "interleaved", {B, S, Hq, Dh}, {1, S, 1, Dh}},
        {0, "half", {S, Hq, Dh}, {S, 1, Dh}},
    };

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    aclnnStatus s1 = -1;
    Try chosen{};
    DeviceBuffer wb;
    bool launched = false;
    for (auto& t : tries) {
        auto t_q = make_contig_tensor(q2.get(), ACL_BF16, t.qshape);

        // K has the same shape as Q except for the head-count axis. In the
        // BNSD attempt (layout 2) heads sit third-from-last; in the other
        // attempts they sit second-from-last. Always patching size-2 would
        // overwrite S for the BNSD shape.
        std::vector<int64_t> kshape = t.qshape;
        if (kshape.size() >= 3) {
            const size_t head_axis = (t.layout == 2) ? kshape.size() - 3 : kshape.size() - 2;
            kshape[head_axis] = Hkv;
        }
        auto t_k = make_contig_tensor(k2.get(), ACL_BF16, kshape);
        auto t_cos = make_contig_tensor(cos_dev.get(), ACL_BF16, t.cshape);
        auto t_sin = make_contig_tensor(sin_dev.get(), ACL_BF16, t.cshape);

        // Mutable, always NUL-terminated copy of the mode string for the API.
        char mode_buf[32] = {0};
        std::strncpy(mode_buf, t.mode, sizeof(mode_buf) - 1);

        s1 = aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
                                                      t.layout, mode_buf, &ws, &exec);
        printf("[ropev2] layout=%lld mode=%-12s qshape=%zu cshape=%zu → status=%d\n",
               static_cast<long long>(t.layout), t.mode, t.qshape.size(), t.cshape.size(),
               static_cast<int>(s1));
        if (s1 != 0) continue;

        chosen = t;
        printf("→ winning: layout=%lld mode=%s\n", static_cast<long long>(chosen.layout), chosen.mode);

        // Launch and sync inside this iteration so the tensor handles and
        // mode_buf handed to GetWorkspaceSize are still in scope while the
        // executor runs. NOTE(review): presumably the handles release their
        // aclTensor when they go out of scope — confirm against
        // make_contig_tensor; launching after the loop would then run the
        // executor on already-released descriptors.
        if (ws > 0) wb.alloc(ws);
        s1 = aclnnApplyRotaryPosEmbV2(wb.get(), ws, exec, rt.stream());
        printf("[ropev2] exec: status=%d\n", static_cast<int>(s1));
        if (s1 != 0) return 1;
        rt.sync();
        launched = true;
        break;
    }
    if (!launched) {
        fprintf(stderr, "All combos failed\n");
        return 1;
    }

    std::vector<uint16_t> q2_out(h_q.size()), k2_out(h_k.size());
    ACL_CHECK(aclrtMemcpy(q2_out.data(), q_bytes, q2.get(), q_bytes, ACL_MEMCPY_DEVICE_TO_HOST));
    ACL_CHECK(aclrtMemcpy(k2_out.data(), k_bytes, k2.get(), k_bytes, ACL_MEMCPY_DEVICE_TO_HOST));

    // Compare: relative L2 error (||ref - got|| / ||ref||) and max |ref - got|.
    struct Diff {
        double rel;
        double max_abs;
    };
    auto compare = [](const std::vector<uint16_t>& ref, const std::vector<uint16_t>& got) {
        double l2_diff = 0, l2_ref = 0, max_abs = 0;
        for (size_t i = 0; i < ref.size(); i++) {
            const float a = bf16_to_float(ref[i]);
            const float b = bf16_to_float(got[i]);
            l2_diff += (a - b) * (a - b);
            l2_ref += a * a;
            if (std::abs(a - b) > max_abs) max_abs = std::abs(a - b);
        }
        // The 1e-10 term keeps the division defined for an all-zero reference.
        return Diff{std::sqrt(l2_diff) / (std::sqrt(l2_ref) + 1e-10), max_abs};
    };
    const Diff q_diff = compare(q1_out, q2_out);
    const Diff k_diff = compare(k1_out, k2_out);

    // Report the combination that actually ran rather than a fixed label.
    printf("\nManual-HF vs aclnnApplyRotaryPosEmbV2(layout=%lld, mode=%s):\n",
           static_cast<long long>(chosen.layout), chosen.mode);
    printf(" Q: rel=%.4e max=%.4f\n", q_diff.rel, q_diff.max_abs);
    printf(" K: rel=%.4e max=%.4f\n", k_diff.rel, k_diff.max_abs);
    printf(" Q[0,:4] manual: %.5f %.5f %.5f %.5f\n",
           bf16_to_float(q1_out[0]), bf16_to_float(q1_out[1]),
           bf16_to_float(q1_out[2]), bf16_to_float(q1_out[3]));
    printf(" Q[0,:4] ropev2: %.5f %.5f %.5f %.5f\n",
           bf16_to_float(q2_out[0]), bf16_to_float(q2_out[1]),
           bf16_to_float(q2_out[2]), bf16_to_float(q2_out[3]));

    const bool pass = q_diff.rel < 1e-2 && k_diff.rel < 1e-2;
    printf("\n%s\n", pass ? "=== RoPE V2 matches manual HF ===" : "=== MISMATCH — need different mode/layout ===");
    return pass ? 0 : 1;
}
|