// test_model_config.cpp — load config.json, derive TP shard sizes, verify all expected
// HF tensors exist in safetensors for Qwen3-235B.
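//
// Usage: test_model_config [model_dir] [tp_size] [tp_rank]
//   e.g. ./test_model_config /path/to/Qwen3-235B-A22B-Instruct-2507-BF16 16 0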
#include "model_config.h"
#include "safetensors_loader.h"

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>
#include <vector>

int main(int argc, char** argv) {
    std::string dir = argc > 1 ? argv[1]
                               : "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    int tp_size = argc > 2 ? std::atoi(argv[2]) : 16;
    int tp_rank = argc > 3 ? std::atoi(argv[3]) : 0;

    ModelConfig cfg;
    if (!cfg.load_from_json(dir + "/config.json")) return 1;
    cfg.compute_derived(tp_size, tp_rank);
    printf("%s\n", cfg.describe().c_str());

    SafetensorsLoader loader;
    if (!loader.open(dir)) return 1;
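    // Note: open() is assumed to index every *.safetensors shard under dir,
    // so get() below can resolve any tensor name regardless of which shard holds it.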

    // Verify all expected tensor names & shapes match cfg.
    int missing = 0, shape_mismatch = 0;
    auto check_shape = [&](const std::string& name, const std::vector<int64_t>& expected) {
        auto* m = loader.get(name);
        if (!m) {
            printf("  MISSING: %s\n", name.c_str());
            missing++;
            return;
        }
        if (m->shape != expected) {
            printf("  SHAPE MISMATCH: %s got=[", name.c_str());
            for (size_t i = 0; i < m->shape.size(); i++) printf("%s%lld", i ? "," : "", (long long)m->shape[i]);
            printf("] want=[");
            for (size_t i = 0; i < expected.size(); i++) printf("%s%lld", i ? "," : "", (long long)expected[i]);
            printf("]\n");
            shape_mismatch++;
        }
    };

    // embed/head
    check_shape("model.embed_tokens.weight", {cfg.vocab_size, cfg.hidden_size});
    check_shape("lm_head.weight",            {cfg.vocab_size, cfg.hidden_size});
    check_shape("model.norm.weight",         {cfg.hidden_size});

    // Attention weights (HF stores as [out, in])
    int64_t q_full = cfg.num_attention_heads * cfg.head_dim;
    int64_t kv_full = cfg.num_key_value_heads * cfg.head_dim;
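    // Under grouped-query attention kv_full <= q_full, since
    // num_key_value_heads <= num_attention_heads; both are full (unsharded) dims.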
    for (int L = 0; L < cfg.num_hidden_layers; L++) {
        auto base = "model.layers." + std::to_string(L);
        check_shape(base + ".input_layernorm.weight",                   {cfg.hidden_size});
        check_shape(base + ".post_attention_layernorm.weight",          {cfg.hidden_size});
        check_shape(base + ".self_attn.q_proj.weight",                  {q_full,  cfg.hidden_size});
        check_shape(base + ".self_attn.k_proj.weight",                  {kv_full, cfg.hidden_size});
        check_shape(base + ".self_attn.v_proj.weight",                  {kv_full, cfg.hidden_size});
        check_shape(base + ".self_attn.o_proj.weight",                  {cfg.hidden_size, q_full});
        // Qwen3 applies RMSNorm to Q and K per head; the weight is a single
        // [head_dim] vector shared across heads. Verify presence and shape.
        check_shape(base + ".self_attn.q_norm.weight",                  {cfg.head_dim});
        check_shape(base + ".self_attn.k_norm.weight",                  {cfg.head_dim});
        // MoE router
        check_shape(base + ".mlp.gate.weight",                          {cfg.num_experts, cfg.hidden_size});
        // Spot-check a few experts (full enumeration would be 94 layers x 128 experts x 3 projections = 36,096 checks).
        for (int e : {0, 1, 63, 127}) {
            auto ebase = base + ".mlp.experts." + std::to_string(e);
            check_shape(ebase + ".gate_proj.weight", {cfg.moe_intermediate_size, cfg.hidden_size});
            check_shape(ebase + ".up_proj.weight",   {cfg.moe_intermediate_size, cfg.hidden_size});
            check_shape(ebase + ".down_proj.weight", {cfg.hidden_size, cfg.moe_intermediate_size});
        }
    }

    // Print TP memory estimate
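    // Sharding assumptions baked into the *_per_rank fields (set by compute_derived):
    // attention is assumed split by heads, so q_dim_per_rank ~= q_full / tp_size and
    // kv_dim_per_rank ~= kv_full / tp_size; experts are assumed split on the
    // intermediate dim, i_per_rank ~= moe_intermediate_size / tp_size.
    // Every "* 2" factor below is BF16 bytes per element.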
    int64_t attn_bytes_per_rank = 0;
    attn_bytes_per_rank += cfg.q_dim_per_rank     * cfg.hidden_size   * 2;   // q_proj
    attn_bytes_per_rank += cfg.kv_dim_per_rank    * cfg.hidden_size   * 2;   // k_proj
    attn_bytes_per_rank += cfg.kv_dim_per_rank    * cfg.hidden_size   * 2;   // v_proj
    attn_bytes_per_rank += cfg.hidden_size        * cfg.q_dim_per_rank * 2;  // o_proj
    attn_bytes_per_rank *= cfg.num_hidden_layers;

    int64_t moe_bytes_per_rank = 0;
    // gate_exps + up_exps: [E, I_per_rank, D]
    moe_bytes_per_rank += 2 * cfg.num_experts * cfg.i_per_rank * cfg.hidden_size * 2;
    // down_exps: [E, D, I_per_rank]
    moe_bytes_per_rank += cfg.num_experts * cfg.hidden_size * cfg.i_per_rank * 2;
    moe_bytes_per_rank *= cfg.num_hidden_layers;

    int64_t embed_bytes = cfg.vocab_size * cfg.hidden_size * 2 * 2;  // embed + lm_head, 2 bytes each
    int64_t router_bytes = cfg.num_experts * cfg.hidden_size * 2 * cfg.num_hidden_layers;  // one gate per layer
    int64_t norm_bytes = cfg.hidden_size * 2 * (2 * cfg.num_hidden_layers + 1);  // 2 per-layer norms + final norm
    int64_t total_per_rank = attn_bytes_per_rank + moe_bytes_per_rank + embed_bytes + router_bytes + norm_bytes;

    printf("\nPer-rank weight memory estimate (BF16, TP=%d):\n", tp_size);
    printf("  attention: %.2f GB\n", attn_bytes_per_rank / 1e9);
    printf("  MoE exps:  %.2f GB\n", moe_bytes_per_rank / 1e9);
    printf("  embed+head: %.2f GB (replicated)\n", embed_bytes / 1e9);
    printf("  router:    %.2f MB (replicated)\n", router_bytes / 1e6);
    printf("  norms:     %.2f MB (replicated)\n", norm_bytes / 1e6);
    printf("  TOTAL:     %.2f GB\n", total_per_rank / 1e9);

    int errors = missing + shape_mismatch;
    printf("\nMissing: %d, Shape mismatch: %d\n", missing, shape_mismatch);
    printf("%s\n", errors == 0 ? "=== test_model_config PASS ==="
                               : "=== test_model_config FAIL ===");
    return errors == 0 ? 0 : 1;
}