File size: 3,437 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// test_rms_norm.cpp — verify aclnnRmsNorm against Python reference.
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <vector>

// Widen a raw bfloat16 bit pattern to a 32-bit float.
// The 16 input bits become the upper half of the float's bit pattern and the
// lower 16 bits are zero, so the conversion is exact.
static float bf16_to_float(uint16_t x) {
    const uint32_t bits = static_cast<uint32_t>(x) << 16;
    float value;
    // memcpy reinterprets the bits without violating aliasing rules.
    std::memcpy(&value, &bits, sizeof(value));
    return value;
}

// Read the whole file at `path` into a byte vector.
//
// Prints a message to stderr and aborts the process when the file cannot be
// opened, its size cannot be determined, or fewer bytes than the reported size
// could be read. An empty file yields an empty vector.
static std::vector<uint8_t> read_file(const std::string& path) {
    // Open positioned at the end so tellg() reports the file size.
    std::ifstream f(path, std::ios::binary | std::ios::ate);
    if (!f) { fprintf(stderr, "open %s failed\n", path.c_str()); std::abort(); }

    // tellg() returns -1 on failure; converting that to size_t would request a
    // gigantic allocation, so reject it explicitly.
    const std::streamoff end = f.tellg();
    if (end < 0) { fprintf(stderr, "tellg %s failed\n", path.c_str()); std::abort(); }

    f.seekg(0);
    std::vector<uint8_t> v(static_cast<size_t>(end));
    const std::streamsize want = static_cast<std::streamsize>(v.size());
    f.read(reinterpret_cast<char*>(v.data()), want);
    // A short read would otherwise leave the tail of the buffer zero-filled.
    if (f.gcount() != want) { fprintf(stderr, "read %s failed\n", path.c_str()); std::abort(); }
    return v;
}

int main() {
    const std::string data = "tests/rms_norm_data";

    // Parse shape.txt
    int64_t N = 0, D = 0;
    double eps = 1e-6;
    {
        std::ifstream f(data + "/shape.txt");
        std::string line;
        while (std::getline(f, line)) {
            auto eq = line.find('=');
            if (eq == std::string::npos) continue;
            auto k = line.substr(0, eq);
            auto v = line.substr(eq + 1);
            if (k == "N") N = std::atoll(v.c_str());
            else if (k == "D") D = std::atoll(v.c_str());
            else if (k == "eps") eps = std::atof(v.c_str());
        }
    }
    printf("Shape: N=%ld D=%ld eps=%g\n", N, D, eps);

    AclRuntime rt;
    rt.init(0);

    auto x_host     = read_file(data + "/x.bin");
    auto gamma_host = read_file(data + "/gamma.bin");
    auto y_ref_host = read_file(data + "/y_ref.bin");

    DeviceBuffer x_dev(N * D * 2);
    DeviceBuffer gamma_dev(D * 2);
    DeviceBuffer y_dev(N * D * 2);
    DeviceBuffer rstd_dev(N * 4);  // F32

    ACL_CHECK(aclrtMemcpy(x_dev.get(),     x_host.size(),     x_host.data(),     x_host.size(),     ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(gamma_dev.get(), gamma_host.size(), gamma_host.data(), gamma_host.size(), ACL_MEMCPY_HOST_TO_DEVICE));

    auto t_x     = make_contig_tensor(x_dev.get(),     ACL_BF16,  {N, D});
    auto t_gamma = make_contig_tensor(gamma_dev.get(), ACL_BF16,  {D});
    auto t_y     = make_contig_tensor(y_dev.get(),     ACL_BF16,  {N, D});
    auto t_rstd  = make_contig_tensor(rstd_dev.get(),  ACL_FLOAT, {N});

    rms_norm(rt.stream(), t_x.get(), t_gamma.get(), eps, t_y.get(), t_rstd.get());
    rt.sync();

    std::vector<uint16_t> y_cxx(N * D);
    ACL_CHECK(aclrtMemcpy(y_cxx.data(), N * D * 2, y_dev.get(), N * D * 2, ACL_MEMCPY_DEVICE_TO_HOST));

    auto* y_ref = (const uint16_t*)y_ref_host.data();

    // Compare
    double l2_d = 0, l2_r = 0, max_abs = 0;
    for (int i = 0; i < N * D; i++) {
        float a = bf16_to_float(y_cxx[i]);
        float b = bf16_to_float(y_ref[i]);
        l2_d += (a-b)*(a-b);
        l2_r += b*b;
        if (std::abs(a-b) > max_abs) max_abs = std::abs(a-b);
    }
    double rel = std::sqrt(l2_d) / (std::sqrt(l2_r) + 1e-10);
    printf("L2 diff = %.6f, L2 ref = %.6f, relative = %.6e, max abs = %.6f\n",
           std::sqrt(l2_d), std::sqrt(l2_r), rel, max_abs);

    printf("y_cxx[0, :8]: ");
    for (int i = 0; i < 8; i++) printf("%.4f ", bf16_to_float(y_cxx[i]));
    printf("\ny_ref[0, :8]: ");
    for (int i = 0; i < 8; i++) printf("%.4f ", bf16_to_float(y_ref[i]));
    printf("\n");

    bool ok = rel < 1e-2;
    printf("\n%s\n", ok ? "=== test_rms_norm PASS ===" : "=== test_rms_norm FAIL ===");
    return ok ? 0 : 1;
}