File size: 3,799 Bytes
4b9fefd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// test_weight_load.cpp — load the shared weights and one layer's attention weights (layer 0 by default) for a single TP rank, sanity-check them, and print their memory use.
#include "acl_runtime.h"
#include "device_weights.h"
#include "model_config.h"
#include "safetensors_loader.h"

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

/// Decode one bfloat16 bit pattern into a float.
///
/// A BF16 value is the upper 16 bits of an IEEE-754 binary32 value, so the
/// bits are shifted into the high half of a 32-bit word and then copied
/// bit-for-bit into a float. memcpy is used instead of a pointer cast to
/// avoid strict-aliasing problems.
///
/// @param x  raw BF16 bits
/// @return   the corresponding float value
static float bf16_to_float(uint16_t x) {
    const uint32_t widened_bits = static_cast<uint32_t>(x) << 16;
    float decoded;
    std::memcpy(&decoded, &widened_bits, sizeof decoded);
    return decoded;
}

int main(int argc, char** argv) {
    const std::string dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    int tp_size = argc > 1 ? std::atoi(argv[1]) : 16;
    int tp_rank = argc > 2 ? std::atoi(argv[2]) : 0;
    int layer   = argc > 3 ? std::atoi(argv[3]) : 0;

    ModelConfig cfg;
    if (!cfg.load_from_json(dir + "/config.json")) return 1;
    cfg.compute_derived(tp_size, tp_rank);

    SafetensorsLoader st;
    if (!st.open(dir)) return 1;

    AclRuntime rt;
    if (!rt.init(0)) return 1;

    DeviceWeightsLoader dw(st, cfg);

    // Load shared (large: ~2.5GB each for embed/head)
    SharedWeights shared;
    printf("Loading shared weights...\n");
    if (!dw.load_shared(shared)) return 1;
    printf("  embed_tokens: %.2f MB\n", shared.embed_tokens.size / 1e6);
    printf("  lm_head:      %.2f MB\n", shared.lm_head.size / 1e6);
    printf("  final_norm:   %.2f MB\n", shared.final_norm.size / 1e6);

    // Load layer 0 attention
    LayerAttnWeights attn;
    printf("\nLoading layer %d attention...\n", layer);
    if (!dw.load_attention(layer, attn)) return 1;
    printf("  input_layernorm: %.1f KB\n", attn.input_layernorm.size / 1e3);
    printf("  q_proj:          %.2f MB  (q_dim_per_rank=%ld)\n",
           attn.q_proj.size / 1e6, cfg.q_dim_per_rank);
    printf("  k_proj:          %.2f MB\n", attn.k_proj.size / 1e6);
    printf("  v_proj:          %.2f MB\n", attn.v_proj.size / 1e6);
    printf("  o_proj:          %.2f MB\n", attn.o_proj.size / 1e6);
    printf("  q_norm / k_norm: %zu B each\n", attn.q_norm.size);

    // Sanity check: q_proj expected bytes = q_dim_per_rank * D * 2
    int64_t expected_q = cfg.q_dim_per_rank * cfg.hidden_size * 2;
    int64_t expected_o = cfg.hidden_size * cfg.q_dim_per_rank * 2;
    bool ok_q = (attn.q_proj.size == (size_t)expected_q);
    bool ok_o = (attn.o_proj.size == (size_t)expected_o);
    printf("\nq_proj size check: %zu == %ld %s\n", attn.q_proj.size, expected_q, ok_q ? "OK" : "FAIL");
    printf("o_proj size check: %zu == %ld %s\n", attn.o_proj.size, expected_o, ok_o ? "OK" : "FAIL");

    // Spot-check: D2H read first 4 BF16 values of q_proj, compare to HF safetensors data.
    std::vector<uint16_t> q_first(4);
    ACL_CHECK(aclrtMemcpy(q_first.data(), 8, attn.q_proj.get(), 8, ACL_MEMCPY_DEVICE_TO_HOST));
    printf("q_proj first 4 BF16 raw (rank=%d, starts at head %ld): ", tp_rank, tp_rank * cfg.n_heads_per_rank);
    for (auto v : q_first) printf("0x%04x ", v);
    printf("\n");

    // Compare with host-side reference: HF q_proj full shape [8192, 4096].
    // Rank r takes rows [r * q_dim_per_rank, (r+1) * q_dim_per_rank).
    // First row of rank 0 = HF[0, 0..3].
    const auto* m = st.get("model.layers." + std::to_string(layer) + ".self_attn.q_proj.weight");
    const auto* host_q = (const uint16_t*)st.data_ptr(*m);
    int64_t row_off = tp_rank * cfg.q_dim_per_rank;
    const uint16_t* host_first = host_q + row_off * cfg.hidden_size;
    printf("host_q reference (row %ld first 4): ", row_off);
    for (int i = 0; i < 4; i++) printf("0x%04x ", host_first[i]);
    printf("\n");
    bool bytes_match = (std::memcmp(q_first.data(), host_first, 8) == 0);
    printf("Bytes match: %s\n", bytes_match ? "OK" : "FAIL");

    bool pass = ok_q && ok_o && bytes_match;
    printf("\n%s\n", pass ? "=== test_weight_load PASS ===" : "=== test_weight_load FAIL ===");
    return pass ? 0 : 1;
}