File size: 2,916 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// test_linear_hf.cpp — verify linear_hf (y = x @ W.T with HF [out, in] layout).
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <vector>

// Widen a bfloat16 bit pattern into an IEEE-754 float. bf16 is exactly the
// top 16 bits of a float32, so shifting left by 16 (zero-filling the low
// mantissa bits) reconstructs the value; memcpy avoids aliasing UB.
static float bf16_to_float(uint16_t bits) {
    const uint32_t widened = static_cast<uint32_t>(bits) << 16;
    float value;
    std::memcpy(&value, &widened, sizeof(value));
    return value;
}

// Entry point: load bf16 fixtures from tests/mm_data, run linear_hf
// (y = x @ W.T with HF-style [OUT, D] weight layout) on device, and compare
// the result against the CPU reference with a relative-L2 tolerance.
// Returns 0 on PASS, 1 on FAIL or on any fixture/setup error.
int main() {
    const std::string data = "tests/mm_data";

    // Parse N/D/OUT from "key=value" lines in shape.txt.
    int64_t N = 0, D = 0, OUT = 0;
    {
        std::ifstream f(data + "/shape.txt");
        if (!f) { fprintf(stderr, "cannot open %s/shape.txt\n", data.c_str()); return 1; }
        std::string line;
        while (std::getline(f, line)) {
            auto eq = line.find('=');
            if (eq == std::string::npos) continue;
            auto k = line.substr(0, eq);
            auto v = std::atoll(line.c_str() + eq + 1);
            if (k == "N") N = v; else if (k == "D") D = v; else if (k == "OUT") OUT = v;
        }
    }
    // Guard against a missing/garbled shape.txt: with N/D/OUT == 0 the test
    // would otherwise "pass" vacuously on empty buffers.
    if (N <= 0 || D <= 0 || OUT <= 0) {
        fprintf(stderr, "invalid shape.txt: N=%lld D=%lld OUT=%lld\n",
                (long long)N, (long long)D, (long long)OUT);
        return 1;
    }
    // %lld + cast is portable for int64_t (plain %ld is UB where int64_t != long).
    printf("N=%lld D=%lld OUT=%lld\n", (long long)N, (long long)D, (long long)OUT);

    AclRuntime rt;
    rt.init(0);

    // Read an entire binary file. Returns an empty vector on open failure;
    // the size checks below catch that case (tellg() on a failed stream
    // returns -1, which must not be fed to vector's size constructor).
    auto read_all = [&](const std::string& p) -> std::vector<uint8_t> {
        std::ifstream f(p, std::ios::binary | std::ios::ate);
        if (!f) { fprintf(stderr, "cannot open %s\n", p.c_str()); return {}; }
        const auto sz = static_cast<size_t>(f.tellg());
        f.seekg(0);
        std::vector<uint8_t> v(sz);
        f.read(reinterpret_cast<char*>(v.data()), sz);
        return v;
    };
    auto x_h  = read_all(data + "/x.bin");
    auto W_h  = read_all(data + "/W.bin");
    auto yr_h = read_all(data + "/y_ref.bin");

    // Expected byte sizes: bf16 = 2 bytes/element.
    const size_t x_bytes = (size_t)(N * D) * 2;
    const size_t W_bytes = (size_t)(OUT * D) * 2;
    const size_t y_bytes = (size_t)(N * OUT) * 2;
    if (x_h.size() != x_bytes || W_h.size() != W_bytes || yr_h.size() != y_bytes) {
        fprintf(stderr, "fixture size mismatch vs shape.txt\n");
        return 1;
    }

    DeviceBuffer x_d(x_bytes);
    DeviceBuffer W_d(W_bytes);
    DeviceBuffer y_d(y_bytes);
    ACL_CHECK(aclrtMemcpy(x_d.get(), x_h.size(), x_h.data(), x_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(W_d.get(), W_h.size(), W_h.data(), W_h.size(), ACL_MEMCPY_HOST_TO_DEVICE));

    auto t_x = make_contig_tensor(x_d.get(), ACL_BF16, {N, D});
    auto t_y = make_contig_tensor(y_d.get(), ACL_BF16, {N, OUT});

    // y = x @ W.T; W is passed raw because linear_hf interprets it as [OUT, D].
    linear_hf(rt.stream(), t_x.get(), W_d.get(), ACL_BF16, OUT, D, t_y.get());
    rt.sync();

    std::vector<uint16_t> y_cxx(N * OUT);
    ACL_CHECK(aclrtMemcpy(y_cxx.data(), y_bytes, y_d.get(), y_bytes, ACL_MEMCPY_DEVICE_TO_HOST));
    const auto* y_ref = reinterpret_cast<const uint16_t*>(yr_h.data());

    // Relative L2 error plus max absolute elementwise difference.
    double l2d = 0, l2r = 0, maxd = 0;
    for (int64_t i = 0; i < N * OUT; i++) {  // int64_t: N*OUT may exceed INT_MAX
        const float a = bf16_to_float(y_cxx[i]);
        const float b = bf16_to_float(y_ref[i]);
        l2d += (double)(a - b) * (a - b);
        l2r += (double)b * b;
        if (std::abs(a - b) > maxd) maxd = std::abs(a - b);
    }
    const double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
    printf("L2 diff=%.4f ref=%.4f relative=%.4e max_abs=%.4f\n",
           std::sqrt(l2d), std::sqrt(l2r), rel, maxd);
    printf("y_cxx[0..3]: "); for (int i = 0; i < 4; i++) printf("%.3f ", bf16_to_float(y_cxx[i])); printf("\n");
    printf("y_ref[0..3]: "); for (int i = 0; i < 4; i++) printf("%.3f ", bf16_to_float(y_ref[i])); printf("\n");

    // BF16 matmul has more precision loss than RmsNorm. Allow 1% relative error.
    const bool ok = rel < 1e-2;
    printf("\n%s\n", ok ? "=== test_linear_hf PASS ===" : "=== test_linear_hf FAIL ===");
    return ok ? 0 : 1;
}