File size: 4,936 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
// test_rope.cpp — verify aclnnApplyRotaryPosEmb works on 910 initial for Qwen3 shapes.
#include "acl_common.h"
#include "acl_runtime.h"

#include <aclnnop/aclnn_apply_rotary_pos_emb.h>
#include <aclnnop/aclnn_apply_rotary_pos_emb_v2.h>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Widens a bfloat16 bit pattern to float. bf16 is the upper half of an IEEE-754
// binary32, so placing the 16 bits in the high half and reinterpreting is exact.
static float bf16_to_float(uint16_t x) {
    const uint32_t widened = static_cast<uint32_t>(x) << 16;
    float value;
    std::memcpy(&value, &widened, sizeof value);
    return value;
}

// Reads the entire file at `p` into a byte vector (empty file -> empty vector).
// Throws std::runtime_error naming the path if the file cannot be opened, sized,
// or fully read. The explicit checks matter: on failure tellg() returns -1, which
// would otherwise wrap to a huge size_t and surface as an unrelated allocation error.
static std::vector<uint8_t> read_file(const std::string& p) {
    std::ifstream f(p, std::ios::binary | std::ios::ate);
    if (!f) throw std::runtime_error("read_file: cannot open " + p);

    const std::streamoff end = f.tellg();
    if (end < 0) throw std::runtime_error("read_file: cannot determine size of " + p);

    std::vector<uint8_t> v(static_cast<size_t>(end));
    f.seekg(0);
    if (!v.empty() &&
        !f.read(reinterpret_cast<char*>(v.data()), static_cast<std::streamsize>(v.size()))) {
        throw std::runtime_error("read_file: short read from " + p);
    }
    return v;
}

// Compares `n` bf16 values in `out` against `ref`, prints a one-line summary tagged
// with `label`, and returns the relative L2 error ||out - ref|| / ||ref||.
// NaN/Inf in the difference propagate into the result, so a `rel < tol` check fails
// instead of silently passing.
static double report_bf16_diff(const char* label, const uint16_t* out,
                               const uint16_t* ref, size_t n) {
    double l2d = 0, l2r = 0, maxd = 0;
    for (size_t i = 0; i < n; i++) {
        const double a = bf16_to_float(out[i]);
        const double b = bf16_to_float(ref[i]);
        const double d = a - b;
        l2d += d * d;
        l2r += b * b;
        if (std::abs(d) > maxd) maxd = std::abs(d);
    }
    const double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
    printf("%s rope compare: rel=%.4e max_abs=%.4f\n", label, rel, maxd);
    return rel;
}

// Runs aclnnApplyRotaryPosEmb (or V2 when $V2 is set) on dumped Qwen3 q/k/cos/sin and
// checks both rotated Q and K against the dumped references.
// Env: LAYOUT (int, default 1), MODE (V2 rotaryMode, default "half"), V2 (any value).
// Returns 0 on PASS, 1 on FAIL or on any setup error.
int main() {
    const std::string data = "tests/attn_data";

    AclRuntime rt;
    rt.init(0);

    // Load q_normed, k_normed (before RoPE) and cos, sin, plus q_roped, k_roped (reference after RoPE)
    auto qn_h  = read_file(data + "/q_normed.bin");
    auto kn_h  = read_file(data + "/k_normed.bin");
    auto cos_h = read_file(data + "/cos.bin");
    auto sin_h = read_file(data + "/sin.bin");
    auto qr_h  = read_file(data + "/q_roped.bin");
    auto kr_h  = read_file(data + "/k_roped.bin");

    // Shapes: q=[1, S, Hq, Dh], k=[1, S, Hkv, Dh], cos/sin=[1, S, Dh]; every dump is bf16.
    const int64_t S = 5, Hq = 64, Hkv = 4, Dh = 128;
    const size_t q_elems  = static_cast<size_t>(S * Hq * Dh);
    const size_t k_elems  = static_cast<size_t>(S * Hkv * Dh);
    const size_t cs_elems = static_cast<size_t>(S * Dh);

    // Reject dumps whose size disagrees with the hard-coded shapes. The tensors below are
    // declared with these shapes over buffers sized from the files, and the host compare
    // indexes by these element counts, so a mismatch would read/write out of bounds.
    struct DumpSpec { const char* name; const std::vector<uint8_t>* bytes; size_t elems; };
    const DumpSpec specs[] = {
        {"q_normed.bin", &qn_h,  q_elems},  {"k_normed.bin", &kn_h,  k_elems},
        {"cos.bin",      &cos_h, cs_elems}, {"sin.bin",      &sin_h, cs_elems},
        {"q_roped.bin",  &qr_h,  q_elems},  {"k_roped.bin",  &kr_h,  k_elems},
    };
    for (const auto& spec : specs) {
        const size_t want = spec.elems * sizeof(uint16_t);
        if (spec.bytes->size() != want) {
            fprintf(stderr, "%s: %zu bytes, expected %zu (bf16 x %zu)\n",
                    spec.name, spec.bytes->size(), want, spec.elems);
            return 1;
        }
    }

    DeviceBuffer q_d(qn_h.size()), k_d(kn_h.size()), cos_d(cos_h.size()), sin_d(sin_h.size());
    ACL_CHECK(aclrtMemcpy(q_d.get(),   qn_h.size(), qn_h.data(),   qn_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(k_d.get(),   kn_h.size(), kn_h.data(),   kn_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(cos_d.get(), cos_h.size(), cos_h.data(), cos_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(sin_d.get(), sin_h.size(), sin_h.data(), sin_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));

    // Python dumped with cos/sin shape [1, S, Dh] (unsqueeze done inline in npu_apply_rotary_pos_emb call w/ [1,S,1,Dh]).
    // For aclnnApplyRotaryPosEmb layout=1 (BSND): q [B,S,N,Dh], cos/sin [B,S,1,Dh].
    // Our dump is cos [1, S, Dh] — add a broadcast-1 dim by using view shape [1, S, 1, Dh].
    auto t_q    = make_contig_tensor(q_d.get(),   ACL_BF16, {1, S, Hq, Dh});
    auto t_k    = make_contig_tensor(k_d.get(),   ACL_BF16, {1, S, Hkv, Dh});
    auto t_cos  = make_contig_tensor(cos_d.get(), ACL_BF16, {1, S, 1, Dh});
    auto t_sin  = make_contig_tensor(sin_d.get(), ACL_BF16, {1, S, 1, Dh});

    int layout = 1;
    const char* env_layout = std::getenv("LAYOUT");
    if (env_layout) layout = std::atoi(env_layout);
    std::string mode = "half";
    const char* env_mode = std::getenv("MODE");
    if (env_mode) mode = env_mode;
    bool use_v2 = (std::getenv("V2") != nullptr);
    printf("layout=%d mode=%s v2=%d\n", layout, mode.c_str(), (int)use_v2);

    uint64_t ws = 0;
    aclOpExecutor* exec = nullptr;
    // Declared at function scope so the workspace outlives the launch: the op is enqueued
    // on rt.stream() and completion is only awaited at rt.sync() below. Scoping it inside
    // the branches would free it while the op may still be using it.
    DeviceBuffer ws_buf;
    if (use_v2) {
        // v2 accepts rotaryMode string: "half" (rotate_half, HF Llama/Qwen style) or
        // "interleave" (rotates adjacent element pairs, GPT-J style).
        aclnnStatus st = aclnnApplyRotaryPosEmbV2GetWorkspaceSize(
            t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
            layout, (char*)mode.c_str(),
            &ws, &exec);
        if (st != 0) {
            fprintf(stderr, "V2 GetWS status=%d %s\n", (int)st, aclGetRecentErrMsg());
            return 1;
        }
        if (ws > 0) ws_buf.alloc(ws);
        ACLNN_CHECK(aclnnApplyRotaryPosEmbV2(ws_buf.get(), ws, exec, rt.stream()));
    } else {
        aclnnStatus st = aclnnApplyRotaryPosEmbGetWorkspaceSize(
            t_q.get(), t_k.get(), t_cos.get(), t_sin.get(),
            layout,
            &ws, &exec);
        if (st != 0) {
            fprintf(stderr, "V1 GetWS status=%d %s\n", (int)st, aclGetRecentErrMsg());
            return 1;
        }
        if (ws > 0) ws_buf.alloc(ws);
        ACLNN_CHECK(aclnnApplyRotaryPosEmb(ws_buf.get(), ws, exec, rt.stream()));
    }
    rt.sync();

    // The op rotates q and k in place, so the device input buffers now hold the results.
    // Host vectors are sized from the validated element counts, matching the copy sizes.
    std::vector<uint16_t> q_out(q_elems), k_out(k_elems);
    ACL_CHECK(aclrtMemcpy(q_out.data(), qn_h.size(), q_d.get(), qn_h.size(), ACL_MEMCPY_DEVICE_TO_HOST));
    ACL_CHECK(aclrtMemcpy(k_out.data(), kn_h.size(), k_d.get(), kn_h.size(), ACL_MEMCPY_DEVICE_TO_HOST));

    // Copy references into uint16_t storage rather than type-punning the byte vectors.
    std::vector<uint16_t> q_ref(q_elems), k_ref(k_elems);
    std::memcpy(q_ref.data(), qr_h.data(), qr_h.size());
    std::memcpy(k_ref.data(), kr_h.data(), kr_h.size());

    const double q_rel = report_bf16_diff("Q", q_out.data(), q_ref.data(), q_elems);
    printf("  cxx q[0,0,:4]: "); for (int i = 0; i < 4; i++) printf("%.4f ", bf16_to_float(q_out[i]));
    printf("\n  ref q[0,0,:4]: "); for (int i = 0; i < 4; i++) printf("%.4f ", bf16_to_float(q_ref[i])); printf("\n");
    const double k_rel = report_bf16_diff("K", k_out.data(), k_ref.data(), k_elems);

    // Both rotated outputs must match; previously only Q was checked even though the
    // K reference was loaded.
    constexpr double kTol = 1e-2;
    const bool ok = q_rel < kTol && k_rel < kTol;
    printf("\n%s\n", ok ? "=== test_rope PASS ===" : "=== test_rope FAIL ===");
    return ok ? 0 : 1;
}