| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #include "acl_common.h" |
| #include "acl_runtime.h" |
| #include "aclnn_ops.h" |
| #include "device_weights.h" |
| #include "model_config.h" |
| #include "safetensors_loader.h" |
|
|
| #include <algorithm> |
| #include <cmath> |
| #include <cstdio> |
| #include <cstring> |
| #include <fstream> |
| #include <tuple> |
| #include <vector> |
|
|
// Widen a bfloat16 value to float32. bf16 is by construction the upper
// 16 bits of an IEEE-754 binary32, so shifting left and reinterpreting
// the bits yields the exact value.
static float bf16_to_float(uint16_t x) {
  const uint32_t bits = static_cast<uint32_t>(x) << 16;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}
// Convert float32 -> bfloat16 with round-to-nearest-even.
// Fix: the rounding add below can carry out of the mantissa; for a NaN
// input (exponent all-ones, mantissa != 0) that carry could clear the
// mantissa and silently produce +/-Inf. Guard NaN explicitly and force a
// quiet-NaN payload bit so the result stays NaN in bf16.
static uint16_t float_to_bf16(float x) {
  uint32_t u; std::memcpy(&u, &x, 4);
  if ((u & 0x7FFFFFFFu) > 0x7F800000u)        // NaN (magnitude above +Inf bits)
    return (uint16_t)((u >> 16) | 0x0040);    // keep sign, set a mantissa bit
  // Round to nearest even: add 0x7FFF plus the LSB of the retained mantissa.
  return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
// Slurp an entire file into memory.
// Fix: the original never checked the open; on failure tellg() returns -1,
// which converts to a huge size_t and invokes UB / bad_alloc. Return an
// empty vector (with a diagnostic) on any failure instead.
static std::vector<uint8_t> read_file(const std::string& p) {
  std::ifstream f(p, std::ios::binary | std::ios::ate);
  if (!f) {
    std::fprintf(stderr, "read_file: cannot open %s\n", p.c_str());
    return {};
  }
  const std::streamoff sz = f.tellg();
  if (sz <= 0) return {};
  std::vector<uint8_t> v((size_t)sz);
  f.seekg(0);
  f.read((char*)v.data(), sz);
  return v;
}
|
|
| int main() { |
| const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; |
| const std::string data_dir = "tests/moe_data"; |
|
|
| ModelConfig cfg; |
| if (!cfg.load_from_json(model_dir + "/config.json")) return 1; |
| cfg.compute_derived(1, 0); |
| const int64_t D = cfg.hidden_size; |
| const int64_t I = cfg.moe_intermediate_size; |
| const int64_t E = cfg.num_experts; |
| const int64_t K = cfg.num_experts_per_tok; |
| const double eps = cfg.rms_norm_eps; |
|
|
| AclRuntime rt; |
| rt.init(0); |
| printf("[dbg] rt init ok\n"); fflush(stdout); |
|
|
| SafetensorsLoader st; |
| if (!st.open(model_dir)) return 1; |
|
|
| |
| printf("Loading layer 0 attention weights (for post_attention_layernorm)...\n"); |
| DeviceWeightsLoader dw(st, cfg); |
| LayerAttnWeights attn; |
| if (!dw.load_attention(0, attn)) return 1; |
|
|
| printf("Loading layer 0 MoE weights (128 experts Γ 3 projections, stacking + permute)...\n"); fflush(stdout); |
| LayerMoEWeights moe; |
| if (!dw.load_moe(0, rt.stream(), moe)) return 1; |
| rt.sync(); |
| printf("[dbg] moe load ok\n"); fflush(stdout); |
| printf(" router %.1f MB gate_exps %.0f MB up_exps %.0f MB down_exps %.0f MB\n", |
| moe.router.size / 1e6, moe.gate_exps.size / 1e6, moe.up_exps.size / 1e6, moe.down_exps.size / 1e6); |
|
|
| |
| int S = 5; |
| auto x_in_host = read_file(data_dir + "/x_in.bin"); |
| auto ref_out_host = read_file(data_dir + "/final_out.bin"); |
| DeviceBuffer x_dev(S * D * 2); |
| ACL_CHECK(aclrtMemcpy(x_dev.get(), x_in_host.size(), x_in_host.data(), x_in_host.size(), ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
| |
| DeviceBuffer residual_dev(S * D * 2); |
| ACL_CHECK(aclrtMemcpy(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE)); |
|
|
| printf("[dbg] loaded data and residual ok, TOTAL=%ld\n", S * K); fflush(stdout); |
|
|
| |
| DeviceBuffer xn_dev(S * D * 2); |
| DeviceBuffer rstd_dev(S * 4); |
| auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D}); |
| auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D}); |
| auto t_ln = make_contig_tensor(attn.post_attention_layernorm.get(), ACL_BF16, {D}); |
| auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S}); |
| rms_norm(rt.stream(), t_x.get(), t_ln.get(), eps, t_xn.get(), t_rstd.get()); |
| rt.sync(); |
| printf("[dbg] rms_norm ok\n"); fflush(stdout); |
|
|
| |
| DeviceBuffer logits_dev(S * E * 2); |
| auto t_logits = make_contig_tensor(logits_dev.get(), ACL_BF16, {S, E}); |
| |
| linear_hf(rt.stream(), t_xn.get(), moe.router.get(), ACL_BF16, E, D, t_logits.get()); |
| rt.sync(); |
| printf("[dbg] router linear ok\n"); fflush(stdout); |
|
|
| |
| DeviceBuffer topk_w_dev(S * K * 2); |
| DeviceBuffer topk_idx_dev(S * K * 4); |
| DeviceBuffer row_idx_dev(S * K * 4); |
| auto t_topk_w = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K}); |
| auto t_topk_idx = make_contig_tensor(topk_idx_dev.get(), ACL_INT32, {S, K}); |
| auto t_row_idx = make_contig_tensor(row_idx_dev.get(), ACL_INT32, {S, K}); |
| moe_gating_topk_softmax(rt.stream(), t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get()); |
| rt.sync(); |
| printf("[dbg] topk_softmax ok\n"); fflush(stdout); |
|
|
| |
| std::vector<uint16_t> tw_bf(S * K); |
| ACL_CHECK(aclrtMemcpy(tw_bf.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| for (int s = 0; s < S; s++) { |
| float sum = 0.0f; |
| for (int k = 0; k < K; k++) sum += bf16_to_float(tw_bf[s*K + k]); |
| sum += 1e-20f; |
| for (int k = 0; k < K; k++) { |
| float v = bf16_to_float(tw_bf[s*K + k]) / sum; |
| tw_bf[s*K + k] = float_to_bf16(v); |
| } |
| } |
| ACL_CHECK(aclrtMemcpy(topk_w_dev.get(), S*K*2, tw_bf.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE)); |
|
|
| |
| int64_t TOTAL = S * K; |
| DeviceBuffer expanded_x_dev(TOTAL * D * 2); |
| DeviceBuffer expanded_row_idx_dev(TOTAL * 4); |
| DeviceBuffer tokens_per_expert_dev(E * 8); |
|
|
| auto t_ex_x = make_contig_tensor(expanded_x_dev.get(), ACL_BF16, {TOTAL, D}); |
| auto t_ex_ri = make_contig_tensor(expanded_row_idx_dev.get(), ACL_INT32, {TOTAL}); |
| auto t_tpe = make_contig_tensor(tokens_per_expert_dev.get(), ACL_INT64, {E}); |
|
|
| moe_init_routing_v3(rt.stream(), |
| t_xn.get(), t_topk_idx.get(), |
| E, TOTAL, |
| t_ex_x.get(), t_ex_ri.get(), t_tpe.get()); |
| rt.sync(); |
| printf("[dbg] moe_init_routing ok\n"); fflush(stdout); |
|
|
| |
| DeviceBuffer tpe_cumsum_dev(E * 8); |
| { |
| std::vector<int64_t> h_counts(E), h_cum(E); |
| ACL_CHECK(aclrtMemcpy(h_counts.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); |
| int64_t acc = 0; |
| for (int i = 0; i < E; i++) { acc += h_counts[i]; h_cum[i] = acc; } |
| ACL_CHECK(aclrtMemcpy(tpe_cumsum_dev.get(), E*8, h_cum.data(), E*8, ACL_MEMCPY_HOST_TO_DEVICE)); |
| } |
| auto t_tpe_cum = make_contig_tensor(tpe_cumsum_dev.get(), ACL_INT64, {E}); |
|
|
| |
| DeviceBuffer gate_out_dev(TOTAL * I * 2); |
| DeviceBuffer up_out_dev(TOTAL * I * 2); |
| auto t_gate_out = make_contig_tensor(gate_out_dev.get(), ACL_BF16, {TOTAL, I}); |
| auto t_up_out = make_contig_tensor(up_out_dev.get(), ACL_BF16, {TOTAL, I}); |
| |
| auto t_w_gate = make_contig_tensor(moe.gate_exps.get(), ACL_BF16, {E, D, I}); |
| auto t_w_up = make_contig_tensor(moe.up_exps.get(), ACL_BF16, {E, D, I}); |
| |
| grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_gate.get(), t_tpe_cum.get(), t_gate_out.get(), 0); |
| rt.sync(); |
| printf("[dbg] gmm gate ok\n"); fflush(stdout); |
| grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_up.get(), t_tpe_cum.get(), t_up_out.get(), 0); |
| rt.sync(); |
| printf("[dbg] gmm up ok\n"); fflush(stdout); |
|
|
| |
| |
| silu(rt.stream(), t_gate_out.get(), t_gate_out.get()); |
| rt.sync(); printf("[dbg] silu ok\n"); fflush(stdout); |
| mul(rt.stream(), t_gate_out.get(), t_up_out.get(), t_gate_out.get()); |
| rt.sync(); printf("[dbg] mul ok\n"); fflush(stdout); |
| |
|
|
| |
| DeviceBuffer down_out_dev(TOTAL * D * 2); |
| auto t_down_out = make_contig_tensor(down_out_dev.get(), ACL_BF16, {TOTAL, D}); |
| auto t_w_down = make_contig_tensor(moe.down_exps.get(), ACL_BF16, {E, I, D}); |
| grouped_matmul_v4(rt.stream(), t_gate_out.get(), t_w_down.get(), t_tpe_cum.get(), t_down_out.get(), 0); |
| rt.sync(); |
| printf("[dbg] gmm down ok\n"); fflush(stdout); |
|
|
| |
| |
| |
| |
| DeviceBuffer fwd_dev(TOTAL * 8); |
| { |
| std::vector<int64_t> h_tpe2(E); |
| std::vector<int32_t> h_tidx3(S * K); |
| ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_tidx3.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| |
| |
| |
| |
| |
| |
| |
| std::vector<int64_t> expert_base(E + 1, 0); |
| for (int e = 0; e < E; e++) expert_base[e + 1] = expert_base[e] + h_tpe2[e]; |
|
|
| std::vector<int> expert_slot(E, 0); |
| std::vector<int64_t> fwd(TOTAL); |
| |
| |
| |
| std::vector<std::tuple<int, int, int>> triples; |
| triples.reserve(TOTAL); |
| for (int n = 0; n < S; n++) for (int k = 0; k < K; k++) { |
| triples.emplace_back(h_tidx3[n * K + k], n, k); |
| } |
| std::sort(triples.begin(), triples.end(), [](const auto& a, const auto& b){ |
| if (std::get<0>(a) != std::get<0>(b)) return std::get<0>(a) < std::get<0>(b); |
| return std::get<1>(a) < std::get<1>(b); |
| }); |
| for (int64_t p = 0; p < TOTAL; p++) { |
| auto [e, n, k] = triples[p]; |
| fwd[n * K + k] = p; |
| } |
| ACL_CHECK(aclrtMemcpy(fwd_dev.get(), TOTAL*8, fwd.data(), TOTAL*8, ACL_MEMCPY_HOST_TO_DEVICE)); |
| } |
| auto t_fwd = make_contig_tensor(fwd_dev.get(), ACL_INT64, {TOTAL}); |
|
|
| |
| DeviceBuffer packed_dev(TOTAL * D * 2); |
| auto t_packed = make_contig_tensor(packed_dev.get(), ACL_BF16, {TOTAL, D}); |
| index_select(rt.stream(), t_down_out.get(), 0, t_fwd.get(), t_packed.get()); |
| rt.sync(); |
|
|
| |
| auto t_packed_3d = make_contig_tensor(packed_dev.get(), ACL_BF16, {S, K, D}); |
| auto t_topk_w_3d = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K, 1}); |
| DeviceBuffer weighted_dev(S * K * D * 2); |
| auto t_weighted = make_contig_tensor(weighted_dev.get(), ACL_BF16, {S, K, D}); |
| mul(rt.stream(), t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get()); |
| rt.sync(); |
|
|
| |
| { |
| std::vector<uint16_t> h_pk_all(S * K * D); |
| std::vector<uint16_t> h_wt_all(S * K * D); |
| std::vector<uint16_t> h_tw_all(S * K); |
| ACL_CHECK(aclrtMemcpy(h_pk_all.data(), S*K*D*2, packed_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_wt_all.data(), S*K*D*2, weighted_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_tw_all.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| printf(" verify weighted[0, k, 0] = packed[0, k, 0] * topk_w[0, k] for all k:\n"); |
| float host_sum = 0; |
| for (int k = 0; k < K; k++) { |
| float p = bf16_to_float(h_pk_all[k * D]); |
| float w = bf16_to_float(h_tw_all[k]); |
| float wt = bf16_to_float(h_wt_all[k * D]); |
| host_sum += p * w; |
| printf(" k=%d: packed=%.5f * topk_w=%.5f = expect=%.5f dev=%.5f\n", |
| k, p, w, p*w, wt); |
| } |
| printf(" host_sum_of_weighted[0, :, 0] = %.5f (expected moe_out[0,0] = -0.02466)\n", host_sum); |
| } |
|
|
| |
| DeviceBuffer moe_out_dev(S * D * 2); |
| auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D}); |
| reduce_sum(rt.stream(), t_weighted.get(), {1}, false, ACL_BF16, t_moe_out.get()); |
| rt.sync(); |
| printf("[dbg] device-side finalize (gather+mul+reduce) ok\n"); fflush(stdout); |
|
|
| |
| float alpha_v = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT); |
| DeviceBuffer final_dev(S * D * 2); |
| auto t_final = make_contig_tensor(final_dev.get(), ACL_BF16, {S, D}); |
| auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D}); |
| { |
| uint64_t ws = 0; aclOpExecutor* e = nullptr; |
| ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_moe_out.get(), alpha, t_final.get(), &ws, &e)); |
| DeviceBuffer wb; if (ws > 0) wb.alloc(ws); |
| ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream())); |
| } |
| aclDestroyScalar(alpha); |
| rt.sync(); |
|
|
| |
| auto compare_bf16 = [&](const char* label, void* dev_ptr, int64_t nelem, |
| const std::string& ref_file) { |
| std::vector<uint16_t> cxx(nelem); |
| ACL_CHECK(aclrtMemcpy(cxx.data(), nelem*2, dev_ptr, nelem*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| auto refbuf = read_file(data_dir + "/" + ref_file); |
| auto* ref = (const uint16_t*)refbuf.data(); |
| double l2d = 0, l2r = 0, maxd = 0; |
| for (int64_t i = 0; i < nelem; i++) { |
| float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]); |
| l2d += (a-b)*(a-b); l2r += b*b; |
| if (std::abs(a-b) > maxd) maxd = std::abs(a-b); |
| } |
| double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); |
| printf(" [cmp] %-12s rel=%.4e max_abs=%.4f cxx[:4]=%.5f %.5f %.5f %.5f ref[:4]=%.5f %.5f %.5f %.5f\n", |
| label, rel, maxd, |
| bf16_to_float(cxx[0]), bf16_to_float(cxx[1]), bf16_to_float(cxx[2]), bf16_to_float(cxx[3]), |
| bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3])); |
| return rel; |
| }; |
|
|
| printf("\n=== Intermediate diagnostics ===\n"); |
| compare_bf16("xn", xn_dev.get(), S * D, "xn.bin"); |
| compare_bf16("topk_w", topk_w_dev.get(), S * K, "topk_w.bin"); |
|
|
| |
| { |
| std::vector<int32_t> cxx_idx(S*K); |
| ACL_CHECK(aclrtMemcpy(cxx_idx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| auto refbuf = read_file(data_dir + "/topk_idx.bin"); |
| auto* ref = (const int32_t*)refbuf.data(); |
| int mismatches = 0; |
| for (int i = 0; i < S*K; i++) if (cxx_idx[i] != ref[i]) mismatches++; |
| printf(" [cmp] topk_idx mismatches=%d/%d cxx[0,:4]=%d %d %d %d ref[0,:4]=%d %d %d %d\n", |
| mismatches, S*K, |
| cxx_idx[0], cxx_idx[1], cxx_idx[2], cxx_idx[3], |
| ref[0], ref[1], ref[2], ref[3]); |
| } |
|
|
| printf("\n=== MoE-only (before residual) ===\n"); |
| compare_bf16("moe_out", moe_out_dev.get(), S * D, "out_flat.bin"); |
|
|
| |
| { |
| std::vector<uint16_t> h_down(TOTAL * D); |
| std::vector<int32_t> h_ri(TOTAL); |
| std::vector<uint16_t> h_tw(S * K); |
| ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_ri.data(), TOTAL*4, expanded_row_idx_dev.get(), TOTAL*4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| printf(" expanded_row_idx (all %ld):\n ", TOTAL); |
| for (int i = 0; i < TOTAL; i++) { |
| printf("%d ", h_ri[i]); |
| if ((i+1) % 10 == 0) printf("\n "); |
| } |
| printf("\n"); |
| |
| std::vector<int> count(TOTAL, 0); |
| int out_of_range = 0; |
| for (int i = 0; i < TOTAL; i++) { |
| int v = h_ri[i]; |
| if (v >= 0 && v < TOTAL) count[v]++; |
| else out_of_range++; |
| } |
| int bijection_ok = (out_of_range == 0); |
| for (int i = 0; i < TOTAL && bijection_ok; i++) if (count[i] != 1) bijection_ok = 0; |
| printf(" bijection=%s out_of_range=%d\n", bijection_ok ? "YES" : "NO", out_of_range); |
|
|
| |
| std::vector<int64_t> h_tpe(E); |
| ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); |
| int64_t tpe_sum = 0, nonzero = 0; |
| int64_t tpe_max = 0; |
| for (int i = 0; i < E; i++) { tpe_sum += h_tpe[i]; if (h_tpe[i]>0) nonzero++; if (h_tpe[i]>tpe_max) tpe_max=h_tpe[i]; } |
| printf(" tokens_per_expert: sum=%ld nonzero=%ld max=%ld (expected sum=%ld if counts, or last=%ld if cumsum)\n", |
| tpe_sum, nonzero, tpe_max, TOTAL, TOTAL); |
| printf(" tpe[last 4]: %ld %ld %ld %ld\n", h_tpe[E-4], h_tpe[E-3], h_tpe[E-2], h_tpe[E-1]); |
|
|
| std::vector<float> manual(S * D, 0.0f); |
| for (int64_t p = 0; p < TOTAL; p++) { |
| int32_t src = h_ri[p]; |
| int s = src / K; |
| int k = src % K; |
| if (s < 0 || s >= S || k < 0 || k >= K) { printf(" bad idx p=%ld src=%d\n", p, src); continue; } |
| float w = bf16_to_float(h_tw[s * K + k]); |
| for (int d = 0; d < D; d++) { |
| manual[s * D + d] += w * bf16_to_float(h_down[p * D + d]); |
| } |
| } |
| |
| auto refbuf = read_file(data_dir + "/out_flat.bin"); |
| auto* ref = (const uint16_t*)refbuf.data(); |
| double l2d=0, l2r=0, maxd=0; |
| for (int64_t i = 0; i < S*D; i++) { |
| float a = manual[i], b = bf16_to_float(ref[i]); |
| l2d += (a-b)*(a-b); l2r += b*b; |
| if (std::abs(a-b) > maxd) maxd = std::abs(a-b); |
| } |
| double rel_manual = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); |
| printf(" [cmp] MANUAL(row_idx=srcβflat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f r[:4]=%.5f %.5f %.5f %.5f\n", |
| rel_manual, maxd, |
| manual[0], manual[1], manual[2], manual[3], |
| bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3])); |
|
|
| |
| |
| std::vector<float> manual2(S * D, 0.0f); |
| for (int64_t p = 0; p < TOTAL; p++) { |
| int32_t dst = h_ri[p]; |
| int s = dst / K; |
| int k = dst % K; |
| if (s < 0 || s >= S || k < 0 || k >= K) continue; |
| float w = bf16_to_float(h_tw[s * K + k]); |
| for (int d = 0; d < D; d++) { |
| manual2[s * D + d] += w * bf16_to_float(h_down[p * D + d]); |
| } |
| } |
| double l2d2=0, l2r2=0, maxd2=0; |
| for (int64_t i = 0; i < S*D; i++) { |
| float a = manual2[i], b = bf16_to_float(ref[i]); |
| l2d2 += (a-b)*(a-b); l2r2 += b*b; |
| if (std::abs(a-b) > maxd2) maxd2 = std::abs(a-b); |
| } |
| double rel_manual2 = std::sqrt(l2d2) / (std::sqrt(l2r2) + 1e-10); |
| printf(" [cmp] MANUAL(row_idx=pβdst_flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f\n", |
| rel_manual2, maxd2, |
| manual2[0], manual2[1], manual2[2], manual2[3]); |
| } |
|
|
| |
| |
| |
| { |
| std::vector<uint16_t> h_down(TOTAL * D); |
| std::vector<int64_t> h_tpe(E); |
| std::vector<int32_t> h_tidx(S * K); |
| std::vector<uint16_t> h_tw(S * K); |
| std::vector<uint16_t> h_xn_all(S * D); |
| std::vector<uint16_t> h_ex_all(TOTAL * D); |
| ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_tidx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| |
| |
| |
| std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL); |
| int64_t cum = 0; |
| int cursor_e = 0; |
| for (int64_t p = 0; p < TOTAL; p++) { |
| while (cursor_e < E && p >= cum + h_tpe[cursor_e]) { cum += h_tpe[cursor_e]; cursor_e++; } |
| p_to_e[p] = cursor_e; |
| float ev = bf16_to_float(h_ex_all[p * D]); |
| int best = -1; float bd = 1e30f; |
| for (int s = 0; s < S; s++) { |
| float df = std::abs(bf16_to_float(h_xn_all[s * D]) - ev); |
| if (df < bd) { bd = df; best = s; } |
| } |
| p_to_s[p] = best; |
| } |
|
|
| |
| std::vector<float> manual_cum(S * D, 0.0f); |
| int found_count = 0; |
| for (int n = 0; n < S; n++) { |
| for (int k = 0; k < K; k++) { |
| int e = h_tidx[n * K + k]; |
| float w = bf16_to_float(h_tw[n * K + k]); |
| |
| int found_p = -1; |
| for (int64_t p = 0; p < TOTAL; p++) { |
| if (p_to_s[p] == n && p_to_e[p] == e) { found_p = p; break; } |
| } |
| if (found_p < 0) { |
| printf(" [!!!] not found: n=%d k=%d expert=%d\n", n, k, e); |
| continue; |
| } |
| found_count++; |
| for (int d = 0; d < D; d++) |
| manual_cum[n * D + d] += w * bf16_to_float(h_down[found_p * D + d]); |
| } |
| } |
| auto refbuf = read_file(data_dir + "/out_flat.bin"); |
| auto* ref = (const uint16_t*)refbuf.data(); |
| double l2d=0, l2r=0, maxd=0; |
| for (int64_t i = 0; i < S*D; i++) { |
| float a = manual_cum[i], b = bf16_to_float(ref[i]); |
| l2d += (a-b)*(a-b); l2r += b*b; |
| if (std::abs(a-b) > maxd) maxd = std::abs(a-b); |
| } |
| double rel_cum = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); |
| printf(" [cmp] MANUAL_CUMSUM (p via expert cumsum) rel=%.4e max=%.4f found=%d/40 m[:4]=%.5f %.5f %.5f %.5f\n", |
| rel_cum, maxd, found_count, manual_cum[0], manual_cum[1], manual_cum[2], manual_cum[3]); |
| } |
|
|
| |
| { |
| std::vector<uint16_t> h_xn_all(S * D); |
| ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| std::vector<uint16_t> h_ex_all(TOTAL * D); |
| ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| printf(" xn[s, 0]: "); |
| for (int s = 0; s < S; s++) printf("%.5f ", bf16_to_float(h_xn_all[s * D])); |
| printf("\n expanded_x[p, 0]: "); |
| for (int p = 0; p < TOTAL; p++) printf("%.5f ", bf16_to_float(h_ex_all[p * D])); |
| printf("\n mapping pβs (by matching expanded_x[p,0] to xn[s,0]): "); |
| for (int p = 0; p < TOTAL; p++) { |
| float e = bf16_to_float(h_ex_all[p * D]); |
| int match = -1; float best = 1e30f; |
| for (int s = 0; s < S; s++) { |
| float df = std::abs(bf16_to_float(h_xn_all[s * D]) - e); |
| if (df < best) { best = df; match = s; } |
| } |
| printf("%d ", match); |
| } |
| printf("\n"); |
| } |
|
|
| |
| { |
| std::vector<uint16_t> h_gate(I); |
| |
| |
| std::vector<uint16_t> h_d(D); |
| ACL_CHECK(aclrtMemcpy(h_d.data(), D*2, (char*)down_out_dev.get() + 4*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| printf(" down_out[p=4, :4] (s=0, k=0, expert=10): %.5f %.5f %.5f %.5f\n", |
| bf16_to_float(h_d[0]), bf16_to_float(h_d[1]), bf16_to_float(h_d[2]), bf16_to_float(h_d[3])); |
| |
| |
| } |
|
|
| |
| |
| |
| |
| |
| { |
| std::vector<int32_t> h_tidx_local(S * K); |
| ACL_CHECK(aclrtMemcpy(h_tidx_local.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); |
| int target_expert = h_tidx_local[0 * K + 0]; |
| printf("\n === Single-expert linear_hf vs GMM sanity (token 0, expert %d) ===\n", target_expert); |
|
|
| |
| std::vector<int64_t> h_tpe2(E); |
| std::vector<uint16_t> h_xn_all2(S * D); |
| std::vector<uint16_t> h_ex_all2(TOTAL * D); |
| ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_xn_all2.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| ACL_CHECK(aclrtMemcpy(h_ex_all2.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL); |
| { |
| int64_t cum = 0; int ce = 0; |
| for (int64_t p = 0; p < TOTAL; p++) { |
| while (ce < E && p >= cum + h_tpe2[ce]) { cum += h_tpe2[ce]; ce++; } |
| p_to_e[p] = ce; |
| float ev = bf16_to_float(h_ex_all2[p * D]); |
| int best = -1; float bd = 1e30f; |
| for (int s = 0; s < S; s++) { |
| float df = std::abs(bf16_to_float(h_xn_all2[s * D]) - ev); |
| if (df < bd) { bd = df; best = s; } |
| } |
| p_to_s[p] = best; |
| } |
| } |
|
|
| DeviceBuffer g_w, u_w, d_w; |
| char ename[256]; |
| snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", target_expert); |
| if (!dw.st().get(ename)) { printf(" missing %s\n", ename); goto after_sanity; } |
|
|
| |
| |
| { |
| auto* m_gate = dw.st().get(ename); |
| DeviceBuffer gw_buf(m_gate->nbytes); |
| ACL_CHECK(aclrtMemcpy(gw_buf.get(), m_gate->nbytes, dw.st().data_ptr(*m_gate), m_gate->nbytes, ACL_MEMCPY_HOST_TO_DEVICE)); |
| g_w = std::move(gw_buf); |
|
|
| snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.up_proj.weight", target_expert); |
| auto* m_up = dw.st().get(ename); |
| DeviceBuffer uw_buf(m_up->nbytes); |
| ACL_CHECK(aclrtMemcpy(uw_buf.get(), m_up->nbytes, dw.st().data_ptr(*m_up), m_up->nbytes, ACL_MEMCPY_HOST_TO_DEVICE)); |
| u_w = std::move(uw_buf); |
|
|
| snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.down_proj.weight", target_expert); |
| auto* m_down = dw.st().get(ename); |
| DeviceBuffer dw_buf(m_down->nbytes); |
| ACL_CHECK(aclrtMemcpy(dw_buf.get(), m_down->nbytes, dw.st().data_ptr(*m_down), m_down->nbytes, ACL_MEMCPY_HOST_TO_DEVICE)); |
| d_w = std::move(dw_buf); |
| } |
|
|
| |
| DeviceBuffer xn0_dev(D * 2); |
| ACL_CHECK(aclrtMemcpy(xn0_dev.get(), D*2, xn_dev.get(), D*2, ACL_MEMCPY_DEVICE_TO_DEVICE)); |
|
|
| DeviceBuffer gate_v(I * 2), up_v(I * 2), act_v(I * 2), down_v(D * 2); |
| auto t_xn0 = make_contig_tensor(xn0_dev.get(), ACL_BF16, {1, D}); |
| auto t_gate = make_contig_tensor(gate_v.get(), ACL_BF16, {1, I}); |
| auto t_up = make_contig_tensor(up_v.get(), ACL_BF16, {1, I}); |
| auto t_act = make_contig_tensor(act_v.get(), ACL_BF16, {1, I}); |
| auto t_down = make_contig_tensor(down_v.get(), ACL_BF16, {1, D}); |
| linear_hf(rt.stream(), t_xn0.get(), g_w.get(), ACL_BF16, I, D, t_gate.get()); |
| linear_hf(rt.stream(), t_xn0.get(), u_w.get(), ACL_BF16, I, D, t_up.get()); |
| rt.sync(); |
| silu(rt.stream(), t_gate.get(), t_act.get()); |
| mul(rt.stream(), t_act.get(), t_up.get(), t_act.get()); |
| rt.sync(); |
| linear_hf(rt.stream(), t_act.get(), d_w.get(), ACL_BF16, D, I, t_down.get()); |
| rt.sync(); |
|
|
| std::vector<uint16_t> h_down_lin(D); |
| ACL_CHECK(aclrtMemcpy(h_down_lin.data(), D*2, down_v.get(), D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
|
|
| |
| int found_p = -1; |
| for (int64_t p = 0; p < TOTAL; p++) { |
| if (p_to_s[p] == 0 && p_to_e[p] == target_expert) { found_p = p; break; } |
| } |
| if (found_p >= 0) { |
| std::vector<uint16_t> h_down_gmm(D); |
| ACL_CHECK(aclrtMemcpy(h_down_gmm.data(), D*2, (char*)down_out_dev.get() + found_p*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST)); |
| double l2d=0, l2r=0, maxd=0; |
| for (int i = 0; i < D; i++) { |
| float a = bf16_to_float(h_down_gmm[i]), b = bf16_to_float(h_down_lin[i]); |
| l2d += (a-b)*(a-b); l2r += b*b; |
| if (std::abs(a-b) > maxd) maxd = std::abs(a-b); |
| } |
| double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); |
| printf(" GMM down_out[p=%d] vs linear_hf down: rel=%.4e max=%.4f\n", found_p, rel, maxd); |
| printf(" GMM[:4]: %.5f %.5f %.5f %.5f\n", |
| bf16_to_float(h_down_gmm[0]), bf16_to_float(h_down_gmm[1]), bf16_to_float(h_down_gmm[2]), bf16_to_float(h_down_gmm[3])); |
| printf(" linear[:4]: %.5f %.5f %.5f %.5f\n", |
| bf16_to_float(h_down_lin[0]), bf16_to_float(h_down_lin[1]), bf16_to_float(h_down_lin[2]), bf16_to_float(h_down_lin[3])); |
| } else { |
| printf(" not found p for (s=0, expert=%d)\n", target_expert); |
| } |
| } |
| after_sanity:; |
|
|
| |
| { |
| int expert_id = 10; |
| std::vector<uint16_t> h_stacked(4 * 4); |
| |
| |
| for (int d = 0; d < 4; d++) { |
| ACL_CHECK(aclrtMemcpy(h_stacked.data() + d*4, 8, |
| (char*)moe.gate_exps.get() + (expert_id * D * I + d * I) * 2, 8, |
| ACL_MEMCPY_DEVICE_TO_HOST)); |
| } |
| char ename[256]; |
| snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", expert_id); |
| auto* m = dw.st().get(ename); |
| |
| |
| |
| std::vector<uint16_t> h_expected(4 * 4); |
| auto* hf = (const uint16_t*)dw.st().data_ptr(*m); |
| for (int d = 0; d < 4; d++) { |
| for (int i = 0; i < 4; i++) { |
| h_expected[d*4 + i] = hf[i * D + d]; |
| } |
| } |
| printf("\n === gate_exps[10, :4, :4] layout check ===\n"); |
| printf(" stacked: "); |
| for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_stacked[i])); |
| printf("\n expected: "); |
| for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_expected[i])); |
| printf("\n"); |
| int mism = 0; |
| for (int i = 0; i < 16; i++) if (h_stacked[i] != h_expected[i]) mism++; |
| printf(" mismatches: %d / 16\n", mism); |
| } |
|
|
| printf("\n=== Final (with residual) ===\n"); |
| double rel = compare_bf16("final_out", final_dev.get(), S * D, "final_out.bin"); |
| bool pass = rel < 5e-2; |
| printf("\n%s\n", pass ? "=== test_moe_layer PASS ===" : "=== test_moe_layer FAIL ==="); |
| return pass ? 0 : 1; |
| } |
|
|