// test_weight_load.cpp — validate attention weight loading for layer 0 and print memory use.
#include "acl_runtime.h"
#include "device_weights.h"
#include "model_config.h"
#include "safetensors_loader.h"

#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

// Decode a BF16 value to float by placing it in the high 16 bits of an IEEE-754 binary32.
static float bf16_to_float(uint16_t x) {
    uint32_t u = (uint32_t)x << 16;
    float f;
    std::memcpy(&f, &u, 4);
    return f;
}

int main(int argc, char** argv) {
    const std::string dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
    int tp_size = argc > 1 ? std::atoi(argv[1]) : 16;
    int tp_rank = argc > 2 ? std::atoi(argv[2]) : 0;
    int layer   = argc > 3 ? std::atoi(argv[3]) : 0;

    ModelConfig cfg;
    if (!cfg.load_from_json(dir + "/config.json")) return 1;
    cfg.compute_derived(tp_size, tp_rank);

    SafetensorsLoader st;
    if (!st.open(dir)) return 1;

    AclRuntime rt;
    if (!rt.init(0)) return 1;

    DeviceWeightsLoader dw(st, cfg);

    // Load shared weights (large: ~2.5 GB each for embed_tokens / lm_head).
    SharedWeights shared;
    printf("Loading shared weights...\n");
    if (!dw.load_shared(shared)) return 1;
    printf("  embed_tokens: %.2f MB\n", shared.embed_tokens.size / 1e6);
    printf("  lm_head:      %.2f MB\n", shared.lm_head.size / 1e6);
    printf("  final_norm:   %.2f MB\n", shared.final_norm.size / 1e6);

    // Load attention weights for the requested layer (default: layer 0).
    LayerAttnWeights attn;
    printf("\nLoading layer %d attention...\n", layer);
    if (!dw.load_attention(layer, attn)) return 1;
    printf("  input_layernorm: %.1f KB\n", attn.input_layernorm.size / 1e3);
    printf("  q_proj: %.2f MB (q_dim_per_rank=%ld)\n", attn.q_proj.size / 1e6, cfg.q_dim_per_rank);
    printf("  k_proj: %.2f MB\n", attn.k_proj.size / 1e6);
    printf("  v_proj: %.2f MB\n", attn.v_proj.size / 1e6);
    printf("  o_proj: %.2f MB\n", attn.o_proj.size / 1e6);
    printf("  q_norm / k_norm: %zu B each\n", attn.q_norm.size);

    // Sanity check: q_proj expected bytes = q_dim_per_rank * hidden_size * sizeof(BF16).
    int64_t expected_q = cfg.q_dim_per_rank * cfg.hidden_size * 2;
    int64_t expected_o = cfg.hidden_size * cfg.q_dim_per_rank * 2;
    bool ok_q = (attn.q_proj.size == (size_t)expected_q);
    bool ok_o = (attn.o_proj.size == (size_t)expected_o);
    printf("\nq_proj size check: %zu == %ld %s\n", attn.q_proj.size, expected_q, ok_q ? "OK" : "FAIL");
    printf("o_proj size check: %zu == %ld %s\n", attn.o_proj.size, expected_o, ok_o ? "OK" : "FAIL");

    // Spot-check: copy the first 4 BF16 values of q_proj from device to host and compare
    // against the HF safetensors data.
    std::vector<uint16_t> q_first(4);
    ACL_CHECK(aclrtMemcpy(q_first.data(), 8, attn.q_proj.get(), 8, ACL_MEMCPY_DEVICE_TO_HOST));
    printf("q_proj first 4 BF16 (raw hex, decoded) (rank=%d, starts at head %ld): ",
           tp_rank, tp_rank * cfg.n_heads_per_rank);
    for (auto v : q_first) printf("0x%04x (%.4f) ", v, bf16_to_float(v));
    printf("\n");

    // Compare with host-side reference: HF q_proj full shape is [8192, 4096].
    // Rank r takes rows [r * q_dim_per_rank, (r+1) * q_dim_per_rank).
    // First row of rank 0 = HF[0, 0..3].
    const auto* m = st.get("model.layers." + std::to_string(layer) + ".self_attn.q_proj.weight");
    const auto* host_q = (const uint16_t*)st.data_ptr(*m);
    int64_t row_off = tp_rank * cfg.q_dim_per_rank;
    const uint16_t* host_first = host_q + row_off * cfg.hidden_size;
    printf("host_q reference (row %ld first 4): ", row_off);
    for (int i = 0; i < 4; i++) printf("0x%04x ", host_first[i]);
    printf("\n");
    bool bytes_match = (std::memcmp(q_first.data(), host_first, 8) == 0);
    printf("Bytes match: %s\n", bytes_match ? "OK" : "FAIL");

    bool pass = ok_q && ok_o && bytes_match;
    printf("\n%s\n", pass ? "=== test_weight_load PASS ===" : "=== test_weight_load FAIL ===");
    return pass ? 0 : 1;
}