File size: 6,641 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// test_rope_fused.cpp — test aclnnApplyRotaryPosEmbV2 vs our manual 8-op HF RoPE.
// If rotaryMode="half" matches HF, we can replace apply_rope_manual with 1 op → 7× reduction
// of per-layer op count for RoPE phase.
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "rope.h"
#include "engine.h"   // for fill_cos_sin_hf + RopeCache

#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <vector>

// Widen a bfloat16 bit pattern to float: the 16 bf16 bits become the upper
// half of a binary32 word and the low 16 mantissa bits are zero-filled.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float out;
    std::memcpy(&out, &widened, sizeof(out));
    return out;
}
// Narrow a float to a bfloat16 bit pattern using round-to-nearest-even on the
// 16 dropped mantissa bits. NaNs are handled before rounding: adding the
// rounding bias to a NaN can carry a small payload into the exponent (turning
// it into +/-Inf) or wrap uint32 for all-ones patterns (turning it into +0).
static uint16_t f_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    if ((u & 0x7FFFFFFFu) > 0x7F800000u) {
        // Keep the sign and top payload bits; set the quiet bit so the result
        // is guaranteed to remain a NaN after truncation.
        return static_cast<uint16_t>((u >> 16) | 0x0040u);
    }
    // 0x7FFF plus the lsb of the kept half implements round-half-to-even.
    const uint32_t rounding_bias = 0x7FFFu + ((u >> 16) & 1u);
    return static_cast<uint16_t>((u + rounding_bias) >> 16);
}

int main() {
    AclRuntime rt;
    rt.init(0);

    // Test shape: 1 batch, 5 seq, 4 heads, head_dim=128 (Qwen3-like)
    const int64_t B = 1, S = 5, Hq = 4, Hkv = 4, Dh = 128;
    const float theta = 5e6f;   // Qwen3 theta

    // Random q, k (deterministic from seed)
    std::vector<uint16_t> h_q(B * S * Hq * Dh), h_k(B * S * Hkv * Dh);
    uint32_t seed = 42;
    auto rnd = [&seed]() {
        seed = seed * 1103515245 + 12345;
        return f_to_bf16(((seed >> 16) / 32768.0f - 1.0f) * 0.1f);
    };
    for (auto& x : h_q) x = rnd();
    for (auto& x : h_k) x = rnd();

    // cos/sin cache (positions 0..S-1)
    std::vector<uint16_t> cos_h, sin_h;
    fill_cos_sin_hf(cos_h, sin_h, 0, S, Dh, theta);

    DeviceBuffer q1(h_q.size() * 2), k1(h_k.size() * 2);
    DeviceBuffer q2(h_q.size() * 2), k2(h_k.size() * 2);
    DeviceBuffer cos_dev(cos_h.size() * 2), sin_dev(sin_h.size() * 2);
    DeviceBuffer scratch(B * S * Hq * Dh * 2);

    ACL_CHECK(aclrtMemcpy(q1.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(q2.get(), h_q.size()*2, h_q.data(), h_q.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(k1.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(k2.get(), h_k.size()*2, h_k.data(), h_k.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(cos_dev.get(), cos_h.size()*2, cos_h.data(), cos_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sin_dev.get(), sin_h.size()*2, sin_h.data(), sin_h.size()*2, ACL_MEMCPY_HOST_TO_DEVICE));

    // --- Path 1: our manual HF RoPE ---
    apply_rope_manual(rt.stream(), q1.get(), B, S, Hq, Dh, k1.get(), Hkv,
                      cos_dev.get(), sin_dev.get(), scratch.get());
    rt.sync();

    std::vector<uint16_t> q1_out(h_q.size()), k1_out(h_k.size());
    ACL_CHECK(aclrtMemcpy(q1_out.data(), h_q.size()*2, q1.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));
    ACL_CHECK(aclrtMemcpy(k1_out.data(), h_k.size()*2, k1.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));

    // --- Path 2: aclnnApplyRotaryPosEmbV2 with rotaryMode="half" ---
    // Layout: see docs. Common: 0=BSND, 1=SBND, 2=BNSD. q/k shape [B, S, N, Dh].
    // cos/sin shape: typically [1, S, 1, Dh] or [S, Dh].

    // Try multiple combinations until one works
    struct Try { int64_t layout; const char* mode; std::vector<int64_t> qshape; std::vector<int64_t> cshape; };
    std::vector<Try> tries = {
        {0, "half",         {B, S, Hq, Dh}, {1, S, 1, Dh}},
        {1, "half",         {B, S, Hq, Dh}, {1, S, 1, Dh}},
        {2, "half",         {B, Hq, S, Dh}, {1, 1, S, Dh}},
        {0, "half",         {B, S, Hq, Dh}, {S, Dh}},
        {0, "interleaved",  {B, S, Hq, Dh}, {1, S, 1, Dh}},
        {0, "half",         {S, Hq, Dh},    {S, 1, Dh}},
    };

    uint64_t ws = 0; aclOpExecutor* exec = nullptr;
    aclnnStatus s1 = -1;
    Try chosen{};
    for (auto& t : tries) {
        auto t_q = make_contig_tensor(q2.get(), ACL_BF16, t.qshape);
        std::vector<int64_t> kshape = t.qshape; if (kshape.size() >= 3) kshape[kshape.size()-2] = Hkv;
        auto t_k = make_contig_tensor(k2.get(), ACL_BF16, kshape);
        auto t_cos = make_contig_tensor(cos_dev.get(), ACL_BF16, t.cshape);
        auto t_sin = make_contig_tensor(sin_dev.get(), ACL_BF16, t.cshape);
        char buf[32]; strncpy(buf, t.mode, sizeof(buf));
        s1 = aclnnApplyRotaryPosEmbV2GetWorkspaceSize(t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
                                                      t.layout, buf, &ws, &exec);
        printf("[ropev2] layout=%ld mode=%-12s qshape=%zu cshape=%zu → status=%d\n",
               t.layout, t.mode, t.qshape.size(), t.cshape.size(), (int)s1);
        if (s1 == 0) { chosen = t; break; }
    }
    if (s1 != 0) { fprintf(stderr, "All combos failed\n"); return 1; }
    printf("→ winning: layout=%ld mode=%s\n", chosen.layout, chosen.mode);
    DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
    s1 = aclnnApplyRotaryPosEmbV2(wb.get(), ws, exec, rt.stream());
    printf("[ropev2] exec: status=%d\n", (int)s1);
    if (s1 != 0) return 1;
    rt.sync();

    std::vector<uint16_t> q2_out(h_q.size()), k2_out(h_k.size());
    ACL_CHECK(aclrtMemcpy(q2_out.data(), h_q.size()*2, q2.get(), h_q.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));
    ACL_CHECK(aclrtMemcpy(k2_out.data(), h_k.size()*2, k2.get(), h_k.size()*2, ACL_MEMCPY_DEVICE_TO_HOST));

    // Compare
    double q_l2d = 0, q_l2r = 0, q_max = 0;
    for (size_t i = 0; i < h_q.size(); i++) {
        float a = bf16_to_float(q1_out[i]), b = bf16_to_float(q2_out[i]);
        q_l2d += (a-b)*(a-b); q_l2r += a*a;
        if (std::abs(a-b) > q_max) q_max = std::abs(a-b);
    }
    double q_rel = std::sqrt(q_l2d) / (std::sqrt(q_l2r) + 1e-10);

    double k_l2d = 0, k_l2r = 0, k_max = 0;
    for (size_t i = 0; i < h_k.size(); i++) {
        float a = bf16_to_float(k1_out[i]), b = bf16_to_float(k2_out[i]);
        k_l2d += (a-b)*(a-b); k_l2r += a*a;
        if (std::abs(a-b) > k_max) k_max = std::abs(a-b);
    }
    double k_rel = std::sqrt(k_l2d) / (std::sqrt(k_l2r) + 1e-10);

    printf("\nManual-HF vs aclnnApplyRotaryPosEmbV2(layout=0, mode=half):\n");
    printf("  Q:  rel=%.4e  max=%.4f\n", q_rel, q_max);
    printf("  K:  rel=%.4e  max=%.4f\n", k_rel, k_max);
    printf("  Q[0,:4] manual:  %.5f %.5f %.5f %.5f\n",
           bf16_to_float(q1_out[0]), bf16_to_float(q1_out[1]),
           bf16_to_float(q1_out[2]), bf16_to_float(q1_out[3]));
    printf("  Q[0,:4] ropev2:  %.5f %.5f %.5f %.5f\n",
           bf16_to_float(q2_out[0]), bf16_to_float(q2_out[1]),
           bf16_to_float(q2_out[2]), bf16_to_float(q2_out[3]));

    bool pass = q_rel < 1e-2 && k_rel < 1e-2;
    printf("\n%s\n", pass ? "=== RoPE V2 matches manual HF ===" : "=== MISMATCH — need different mode/layout ===");
    return pass ? 0 : 1;
}