// test_moe_layer.cpp — Full MoE layer forward (Qwen3-235B layer 0), TP=1. // // Pipeline: // 1. Post-attention RmsNorm (input from attn_data/final_out.bin) // 2. Router: xn @ W_router.T → logits [S, E] // 3. TopK softmax → weights [S, K], expert_ids [S, K] // 4. Host-normalize top_k weights (Qwen3 norm_topk_prob) // 5. MoeInitRoutingV3 → expanded_x [S*K, D], expanded_row_idx, tokens_per_expert // 6. GMM gate: expanded_x × gate_exps → [S*K, I] // 7. GMM up: same → [S*K, I] // 8. silu(gate) * up → [S*K, I] // 9. GMM down: act × down_exps → [S*K, D] // 10. MoeFinalizeRouting (weighted sum) → [S, D] // 11. + residual #include "acl_common.h" #include "acl_runtime.h" #include "aclnn_ops.h" #include "device_weights.h" #include "model_config.h" #include "safetensors_loader.h" #include #include #include #include #include #include #include static float bf16_to_float(uint16_t x) { uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f; } static uint16_t float_to_bf16(float x) { uint32_t u; std::memcpy(&u, &x, 4); return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16); } static std::vector read_file(const std::string& p) { std::ifstream f(p, std::ios::binary | std::ios::ate); size_t s = f.tellg(); f.seekg(0); std::vector v(s); f.read((char*)v.data(), s); return v; } int main() { const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16"; const std::string data_dir = "tests/moe_data"; ModelConfig cfg; if (!cfg.load_from_json(model_dir + "/config.json")) return 1; cfg.compute_derived(1, 0); // TP=1 const int64_t D = cfg.hidden_size; const int64_t I = cfg.moe_intermediate_size; const int64_t E = cfg.num_experts; const int64_t K = cfg.num_experts_per_tok; const double eps = cfg.rms_norm_eps; AclRuntime rt; rt.init(0); printf("[dbg] rt init ok\n"); fflush(stdout); SafetensorsLoader st; if (!st.open(model_dir)) return 1; // ---- Load weights ---- printf("Loading layer 0 attention weights (for post_attention_layernorm)...\n"); DeviceWeightsLoader dw(st, cfg); LayerAttnWeights attn; if (!dw.load_attention(0, attn)) return 1; printf("Loading layer 0 MoE weights (128 experts × 3 projections, stacking + permute)...\n"); fflush(stdout); LayerMoEWeights moe; if (!dw.load_moe(0, rt.stream(), moe)) return 1; rt.sync(); printf("[dbg] moe load ok\n"); fflush(stdout); printf(" router %.1f MB gate_exps %.0f MB up_exps %.0f MB down_exps %.0f MB\n", moe.router.size / 1e6, moe.gate_exps.size / 1e6, moe.up_exps.size / 1e6, moe.down_exps.size / 1e6); // ---- Load input & Python reference ---- int S = 5; auto x_in_host = read_file(data_dir + "/x_in.bin"); auto ref_out_host = read_file(data_dir + "/final_out.bin"); DeviceBuffer x_dev(S * D * 2); ACL_CHECK(aclrtMemcpy(x_dev.get(), x_in_host.size(), x_in_host.data(), x_in_host.size(), ACL_MEMCPY_HOST_TO_DEVICE)); // Residual snapshot DeviceBuffer residual_dev(S * D * 2); ACL_CHECK(aclrtMemcpy(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE)); printf("[dbg] loaded data and residual ok, TOTAL=%ld\n", S * K); fflush(stdout); // ---- Step 1: Post-attention RmsNorm ---- DeviceBuffer xn_dev(S * D * 2); DeviceBuffer rstd_dev(S * 4); auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D}); auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D}); auto t_ln = make_contig_tensor(attn.post_attention_layernorm.get(), ACL_BF16, {D}); auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S}); rms_norm(rt.stream(), t_x.get(), t_ln.get(), eps, t_xn.get(), t_rstd.get()); rt.sync(); printf("[dbg] rms_norm ok\n"); fflush(stdout); // ---- Step 2: Router (gate matmul) ---- DeviceBuffer logits_dev(S * E * 2); auto t_logits = make_contig_tensor(logits_dev.get(), ACL_BF16, {S, E}); // router is [E, D] (HF). logits = xn @ router.T linear_hf(rt.stream(), t_xn.get(), moe.router.get(), ACL_BF16, E, D, t_logits.get()); rt.sync(); printf("[dbg] router linear ok\n"); fflush(stdout); // ---- Step 3: TopK softmax ---- DeviceBuffer topk_w_dev(S * K * 2); // BF16 DeviceBuffer topk_idx_dev(S * K * 4); // int32 DeviceBuffer row_idx_dev(S * K * 4); // int32 (from gating op, unused for our routing) auto t_topk_w = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K}); auto t_topk_idx = make_contig_tensor(topk_idx_dev.get(), ACL_INT32, {S, K}); auto t_row_idx = make_contig_tensor(row_idx_dev.get(), ACL_INT32, {S, K}); moe_gating_topk_softmax(rt.stream(), t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get()); rt.sync(); printf("[dbg] topk_softmax ok\n"); fflush(stdout); // ---- Step 4: Host-normalize top_k weights (norm_topk_prob=true) ---- std::vector tw_bf(S * K); ACL_CHECK(aclrtMemcpy(tw_bf.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); for (int s = 0; s < S; s++) { float sum = 0.0f; for (int k = 0; k < K; k++) sum += bf16_to_float(tw_bf[s*K + k]); sum += 1e-20f; for (int k = 0; k < K; k++) { float v = bf16_to_float(tw_bf[s*K + k]) / sum; tw_bf[s*K + k] = float_to_bf16(v); } } ACL_CHECK(aclrtMemcpy(topk_w_dev.get(), S*K*2, tw_bf.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE)); // ---- Step 5: MoE init routing ---- int64_t TOTAL = S * K; DeviceBuffer expanded_x_dev(TOTAL * D * 2); DeviceBuffer expanded_row_idx_dev(TOTAL * 4); DeviceBuffer tokens_per_expert_dev(E * 8); auto t_ex_x = make_contig_tensor(expanded_x_dev.get(), ACL_BF16, {TOTAL, D}); auto t_ex_ri = make_contig_tensor(expanded_row_idx_dev.get(), ACL_INT32, {TOTAL}); auto t_tpe = make_contig_tensor(tokens_per_expert_dev.get(), ACL_INT64, {E}); moe_init_routing_v3(rt.stream(), t_xn.get(), t_topk_idx.get(), E, TOTAL, t_ex_x.get(), t_ex_ri.get(), t_tpe.get()); rt.sync(); printf("[dbg] moe_init_routing ok\n"); fflush(stdout); // Convert tokens_per_expert from counts to cumsum (on host) for GMM groupListType=0. DeviceBuffer tpe_cumsum_dev(E * 8); { std::vector h_counts(E), h_cum(E); ACL_CHECK(aclrtMemcpy(h_counts.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); int64_t acc = 0; for (int i = 0; i < E; i++) { acc += h_counts[i]; h_cum[i] = acc; } ACL_CHECK(aclrtMemcpy(tpe_cumsum_dev.get(), E*8, h_cum.data(), E*8, ACL_MEMCPY_HOST_TO_DEVICE)); } auto t_tpe_cum = make_contig_tensor(tpe_cumsum_dev.get(), ACL_INT64, {E}); // ---- Step 6/7: GMM gate and up ---- DeviceBuffer gate_out_dev(TOTAL * I * 2); DeviceBuffer up_out_dev(TOTAL * I * 2); auto t_gate_out = make_contig_tensor(gate_out_dev.get(), ACL_BF16, {TOTAL, I}); auto t_up_out = make_contig_tensor(up_out_dev.get(), ACL_BF16, {TOTAL, I}); // gate/up_exps loaded as [E, D, I] row-major auto t_w_gate = make_contig_tensor(moe.gate_exps.get(), ACL_BF16, {E, D, I}); auto t_w_up = make_contig_tensor(moe.up_exps.get(), ACL_BF16, {E, D, I}); // Use cumsum group_list (groupListType=0): empirically more reliable with many zero-count experts. grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_gate.get(), t_tpe_cum.get(), t_gate_out.get(), 0); rt.sync(); printf("[dbg] gmm gate ok\n"); fflush(stdout); grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_up.get(), t_tpe_cum.get(), t_up_out.get(), 0); rt.sync(); printf("[dbg] gmm up ok\n"); fflush(stdout); // ---- Step 8: SwiGLU ---- // act = silu(gate) * up (inplace on gate_out) silu(rt.stream(), t_gate_out.get(), t_gate_out.get()); rt.sync(); printf("[dbg] silu ok\n"); fflush(stdout); mul(rt.stream(), t_gate_out.get(), t_up_out.get(), t_gate_out.get()); rt.sync(); printf("[dbg] mul ok\n"); fflush(stdout); // now gate_out_dev contains the activated intermediate // ---- Step 9: GMM down ---- DeviceBuffer down_out_dev(TOTAL * D * 2); auto t_down_out = make_contig_tensor(down_out_dev.get(), ACL_BF16, {TOTAL, D}); auto t_w_down = make_contig_tensor(moe.down_exps.get(), ACL_BF16, {E, I, D}); grouped_matmul_v4(rt.stream(), t_gate_out.get(), t_w_down.get(), t_tpe_cum.get(), t_down_out.get(), 0); rt.sync(); printf("[dbg] gmm down ok\n"); fflush(stdout); // ---- Step 10: Device-side manual finalize (replacement for buggy MoeFinalizeRoutingV2) ---- // Compute forward permutation fwd[n*K + k] = p where token n's k-th expert's output is at // expanded position p. We use tokens_per_expert (cumsum) + topk_idx to resolve this correctly, // regardless of the exact rowIdxType semantics returned by MoeInitRoutingV3. DeviceBuffer fwd_dev(TOTAL * 8); { std::vector h_tpe2(E); std::vector h_tidx3(S * K); ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_tidx3.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); // Sort (n, k) pairs by expert ascending (stable). For each expert in order, tokens // appear in ascending token index (since MoeInitRoutingV3 is stable by s). // Specifically: expanded positions 0..tpe[0]-1 are for expert 0 (tokens picking e=0, in n-ascending order), // next tpe[1] are for expert 1, etc. // // To build fwd: for each (n, k), expert e = topk_idx[n, k]. Position p is the base of expert e's // block plus the rank of n within tokens picking e. std::vector expert_base(E + 1, 0); for (int e = 0; e < E; e++) expert_base[e + 1] = expert_base[e] + h_tpe2[e]; std::vector expert_slot(E, 0); // next available slot per expert std::vector fwd(TOTAL); // Iterate in token-ascending, k-ascending order — match MoeInitRoutingV3's stable sort convention. // For each (n, k) sorted by (expert[n,k], n), assign p. // Simpler: pre-collect (e, n, k) triples, sort by (e, n), then p is the rank. std::vector> triples; triples.reserve(TOTAL); for (int n = 0; n < S; n++) for (int k = 0; k < K; k++) { triples.emplace_back(h_tidx3[n * K + k], n, k); } std::sort(triples.begin(), triples.end(), [](const auto& a, const auto& b){ if (std::get<0>(a) != std::get<0>(b)) return std::get<0>(a) < std::get<0>(b); return std::get<1>(a) < std::get<1>(b); }); for (int64_t p = 0; p < TOTAL; p++) { auto [e, n, k] = triples[p]; fwd[n * K + k] = p; } ACL_CHECK(aclrtMemcpy(fwd_dev.get(), TOTAL*8, fwd.data(), TOTAL*8, ACL_MEMCPY_HOST_TO_DEVICE)); } auto t_fwd = make_contig_tensor(fwd_dev.get(), ACL_INT64, {TOTAL}); // Gather: packed [S*K, D] = down_out[fwd, :] DeviceBuffer packed_dev(TOTAL * D * 2); auto t_packed = make_contig_tensor(packed_dev.get(), ACL_BF16, {TOTAL, D}); index_select(rt.stream(), t_down_out.get(), 0, t_fwd.get(), t_packed.get()); rt.sync(); // Broadcast-multiply by topk_w: view packed as [S, K, D], topk_w as [S, K, 1]. auto t_packed_3d = make_contig_tensor(packed_dev.get(), ACL_BF16, {S, K, D}); auto t_topk_w_3d = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K, 1}); DeviceBuffer weighted_dev(S * K * D * 2); auto t_weighted = make_contig_tensor(weighted_dev.get(), ACL_BF16, {S, K, D}); mul(rt.stream(), t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get()); rt.sync(); // Verify broadcast mul + sum by dumping all k entries and summing on host. { std::vector h_pk_all(S * K * D); std::vector h_wt_all(S * K * D); std::vector h_tw_all(S * K); ACL_CHECK(aclrtMemcpy(h_pk_all.data(), S*K*D*2, packed_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_wt_all.data(), S*K*D*2, weighted_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_tw_all.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); printf(" verify weighted[0, k, 0] = packed[0, k, 0] * topk_w[0, k] for all k:\n"); float host_sum = 0; for (int k = 0; k < K; k++) { float p = bf16_to_float(h_pk_all[k * D]); // packed[0, k, 0] = offset s*K*D + k*D + 0 = k*D (for s=0) float w = bf16_to_float(h_tw_all[k]); // topk_w[0, k] float wt = bf16_to_float(h_wt_all[k * D]); // weighted[0, k, 0] host_sum += p * w; printf(" k=%d: packed=%.5f * topk_w=%.5f = expect=%.5f dev=%.5f\n", k, p, w, p*w, wt); } printf(" host_sum_of_weighted[0, :, 0] = %.5f (expected moe_out[0,0] = -0.02466)\n", host_sum); } // ReduceSum over K axis → [S, D] DeviceBuffer moe_out_dev(S * D * 2); auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D}); reduce_sum(rt.stream(), t_weighted.get(), {1}, /*keep_dims=*/false, ACL_BF16, t_moe_out.get()); rt.sync(); printf("[dbg] device-side finalize (gather+mul+reduce) ok\n"); fflush(stdout); // Residual add to produce final_out float alpha_v = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT); DeviceBuffer final_dev(S * D * 2); auto t_final = make_contig_tensor(final_dev.get(), ACL_BF16, {S, D}); auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D}); { uint64_t ws = 0; aclOpExecutor* e = nullptr; ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_moe_out.get(), alpha, t_final.get(), &ws, &e)); DeviceBuffer wb; if (ws > 0) wb.alloc(ws); ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream())); } aclDestroyScalar(alpha); rt.sync(); // ---- Compare (intermediate + final) ---- auto compare_bf16 = [&](const char* label, void* dev_ptr, int64_t nelem, const std::string& ref_file) { std::vector cxx(nelem); ACL_CHECK(aclrtMemcpy(cxx.data(), nelem*2, dev_ptr, nelem*2, ACL_MEMCPY_DEVICE_TO_HOST)); auto refbuf = read_file(data_dir + "/" + ref_file); auto* ref = (const uint16_t*)refbuf.data(); double l2d = 0, l2r = 0, maxd = 0; for (int64_t i = 0; i < nelem; i++) { float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]); l2d += (a-b)*(a-b); l2r += b*b; if (std::abs(a-b) > maxd) maxd = std::abs(a-b); } double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); printf(" [cmp] %-12s rel=%.4e max_abs=%.4f cxx[:4]=%.5f %.5f %.5f %.5f ref[:4]=%.5f %.5f %.5f %.5f\n", label, rel, maxd, bf16_to_float(cxx[0]), bf16_to_float(cxx[1]), bf16_to_float(cxx[2]), bf16_to_float(cxx[3]), bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3])); return rel; }; printf("\n=== Intermediate diagnostics ===\n"); compare_bf16("xn", xn_dev.get(), S * D, "xn.bin"); compare_bf16("topk_w", topk_w_dev.get(), S * K, "topk_w.bin"); // Dump topk_idx (int32) to compare { std::vector cxx_idx(S*K); ACL_CHECK(aclrtMemcpy(cxx_idx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); auto refbuf = read_file(data_dir + "/topk_idx.bin"); auto* ref = (const int32_t*)refbuf.data(); int mismatches = 0; for (int i = 0; i < S*K; i++) if (cxx_idx[i] != ref[i]) mismatches++; printf(" [cmp] topk_idx mismatches=%d/%d cxx[0,:4]=%d %d %d %d ref[0,:4]=%d %d %d %d\n", mismatches, S*K, cxx_idx[0], cxx_idx[1], cxx_idx[2], cxx_idx[3], ref[0], ref[1], ref[2], ref[3]); } printf("\n=== MoE-only (before residual) ===\n"); compare_bf16("moe_out", moe_out_dev.get(), S * D, "out_flat.bin"); // Manual host-side finalize: verify what down_out + expanded_row_idx + topk_w produce. { std::vector h_down(TOTAL * D); std::vector h_ri(TOTAL); std::vector h_tw(S * K); ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_ri.data(), TOTAL*4, expanded_row_idx_dev.get(), TOTAL*4, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); printf(" expanded_row_idx (all %ld):\n ", TOTAL); for (int i = 0; i < TOTAL; i++) { printf("%d ", h_ri[i]); if ((i+1) % 10 == 0) printf("\n "); } printf("\n"); // count unique and check bijection std::vector count(TOTAL, 0); int out_of_range = 0; for (int i = 0; i < TOTAL; i++) { int v = h_ri[i]; if (v >= 0 && v < TOTAL) count[v]++; else out_of_range++; } int bijection_ok = (out_of_range == 0); for (int i = 0; i < TOTAL && bijection_ok; i++) if (count[i] != 1) bijection_ok = 0; printf(" bijection=%s out_of_range=%d\n", bijection_ok ? "YES" : "NO", out_of_range); // Also dump tokens_per_expert (int64) — should sum to TOTAL std::vector h_tpe(E); ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); int64_t tpe_sum = 0, nonzero = 0; int64_t tpe_max = 0; for (int i = 0; i < E; i++) { tpe_sum += h_tpe[i]; if (h_tpe[i]>0) nonzero++; if (h_tpe[i]>tpe_max) tpe_max=h_tpe[i]; } printf(" tokens_per_expert: sum=%ld nonzero=%ld max=%ld (expected sum=%ld if counts, or last=%ld if cumsum)\n", tpe_sum, nonzero, tpe_max, TOTAL, TOTAL); printf(" tpe[last 4]: %ld %ld %ld %ld\n", h_tpe[E-4], h_tpe[E-3], h_tpe[E-2], h_tpe[E-1]); std::vector manual(S * D, 0.0f); for (int64_t p = 0; p < TOTAL; p++) { int32_t src = h_ri[p]; int s = src / K; int k = src % K; if (s < 0 || s >= S || k < 0 || k >= K) { printf(" bad idx p=%ld src=%d\n", p, src); continue; } float w = bf16_to_float(h_tw[s * K + k]); for (int d = 0; d < D; d++) { manual[s * D + d] += w * bf16_to_float(h_down[p * D + d]); } } // Convert to bf16 and compare to Python out_flat auto refbuf = read_file(data_dir + "/out_flat.bin"); auto* ref = (const uint16_t*)refbuf.data(); double l2d=0, l2r=0, maxd=0; for (int64_t i = 0; i < S*D; i++) { float a = manual[i], b = bf16_to_float(ref[i]); l2d += (a-b)*(a-b); l2r += b*b; if (std::abs(a-b) > maxd) maxd = std::abs(a-b); } double rel_manual = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); printf(" [cmp] MANUAL(row_idx=src→flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f r[:4]=%.5f %.5f %.5f %.5f\n", rel_manual, maxd, manual[0], manual[1], manual[2], manual[3], bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3])); // Alternative semantic: row_idx[p] = destination position // In that case: p=src_row, dst=h_ri[p] std::vector manual2(S * D, 0.0f); for (int64_t p = 0; p < TOTAL; p++) { int32_t dst = h_ri[p]; int s = dst / K; int k = dst % K; if (s < 0 || s >= S || k < 0 || k >= K) continue; float w = bf16_to_float(h_tw[s * K + k]); for (int d = 0; d < D; d++) { manual2[s * D + d] += w * bf16_to_float(h_down[p * D + d]); } } double l2d2=0, l2r2=0, maxd2=0; for (int64_t i = 0; i < S*D; i++) { float a = manual2[i], b = bf16_to_float(ref[i]); l2d2 += (a-b)*(a-b); l2r2 += b*b; if (std::abs(a-b) > maxd2) maxd2 = std::abs(a-b); } double rel_manual2 = std::sqrt(l2d2) / (std::sqrt(l2r2) + 1e-10); printf(" [cmp] MANUAL(row_idx=p→dst_flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f\n", rel_manual2, maxd2, manual2[0], manual2[1], manual2[2], manual2[3]); } // Manual finalize using cumsum (semantics-independent): // For each (n, k), find p such that actual_s(p)=n AND expert(p)=topk_idx[n,k], then // out[n] += topk_w[n,k] * down_out[p]. { std::vector h_down(TOTAL * D); std::vector h_tpe(E); std::vector h_tidx(S * K); std::vector h_tw(S * K); std::vector h_xn_all(S * D); std::vector h_ex_all(TOTAL * D); ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_tidx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); // Build p → (actual_s, actual_expert). // actual_s: find s with xn[s,0] == expanded_x[p,0] // actual_expert: find e such that cumsum_tpe[e-1] <= p < cumsum_tpe[e] std::vector p_to_s(TOTAL), p_to_e(TOTAL); int64_t cum = 0; int cursor_e = 0; for (int64_t p = 0; p < TOTAL; p++) { while (cursor_e < E && p >= cum + h_tpe[cursor_e]) { cum += h_tpe[cursor_e]; cursor_e++; } p_to_e[p] = cursor_e; float ev = bf16_to_float(h_ex_all[p * D]); int best = -1; float bd = 1e30f; for (int s = 0; s < S; s++) { float df = std::abs(bf16_to_float(h_xn_all[s * D]) - ev); if (df < bd) { bd = df; best = s; } } p_to_s[p] = best; } // Build (n, k) → p lookup via (n, expert) → p std::vector manual_cum(S * D, 0.0f); int found_count = 0; for (int n = 0; n < S; n++) { for (int k = 0; k < K; k++) { int e = h_tidx[n * K + k]; float w = bf16_to_float(h_tw[n * K + k]); // search p with p_to_s[p]==n and p_to_e[p]==e int found_p = -1; for (int64_t p = 0; p < TOTAL; p++) { if (p_to_s[p] == n && p_to_e[p] == e) { found_p = p; break; } } if (found_p < 0) { printf(" [!!!] not found: n=%d k=%d expert=%d\n", n, k, e); continue; } found_count++; for (int d = 0; d < D; d++) manual_cum[n * D + d] += w * bf16_to_float(h_down[found_p * D + d]); } } auto refbuf = read_file(data_dir + "/out_flat.bin"); auto* ref = (const uint16_t*)refbuf.data(); double l2d=0, l2r=0, maxd=0; for (int64_t i = 0; i < S*D; i++) { float a = manual_cum[i], b = bf16_to_float(ref[i]); l2d += (a-b)*(a-b); l2r += b*b; if (std::abs(a-b) > maxd) maxd = std::abs(a-b); } double rel_cum = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); printf(" [cmp] MANUAL_CUMSUM (p via expert cumsum) rel=%.4e max=%.4f found=%d/40 m[:4]=%.5f %.5f %.5f %.5f\n", rel_cum, maxd, found_count, manual_cum[0], manual_cum[1], manual_cum[2], manual_cum[3]); } // Dump all expanded_x[p, 0] and all xn[s, 0] to determine the mapping. { std::vector h_xn_all(S * D); ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); std::vector h_ex_all(TOTAL * D); ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); printf(" xn[s, 0]: "); for (int s = 0; s < S; s++) printf("%.5f ", bf16_to_float(h_xn_all[s * D])); printf("\n expanded_x[p, 0]: "); for (int p = 0; p < TOTAL; p++) printf("%.5f ", bf16_to_float(h_ex_all[p * D])); printf("\n mapping p→s (by matching expanded_x[p,0] to xn[s,0]): "); for (int p = 0; p < TOTAL; p++) { float e = bf16_to_float(h_ex_all[p * D]); int match = -1; float best = 1e30f; for (int s = 0; s < S; s++) { float df = std::abs(bf16_to_float(h_xn_all[s * D]) - e); if (df < best) { best = df; match = s; } } printf("%d ", match); } printf("\n"); } // Dump gate_out[p=4, :8] — gate activation of xn[0] via expert 10 { std::vector h_gate(I); // NOTE: gate_out_dev was overwritten by silu+mul. So we need to reload from scratch. // Instead just show down_out[4, :4]. std::vector h_d(D); ACL_CHECK(aclrtMemcpy(h_d.data(), D*2, (char*)down_out_dev.get() + 4*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST)); printf(" down_out[p=4, :4] (s=0, k=0, expert=10): %.5f %.5f %.5f %.5f\n", bf16_to_float(h_d[0]), bf16_to_float(h_d[1]), bf16_to_float(h_d[2]), bf16_to_float(h_d[3])); // If GMM is correct, down_out[4] ~ ref[0] / topk_w[0,0]. ref[0,:4]=[-0.025, -0.007, 0.005, -0.008] / 0.224 ~ [-0.113, -0.031, 0.024, -0.036]. // But it's just ONE contribution so hard to compare directly. } // Single-expert verification using linear_hf: compute gate/up/down for (xn[0], expert=10) // and compare with GMM's down_out at the corresponding position. // linear_hf expects HF-layout weight [out_features, in_features]; our stacked gate_exps/up_exps // are [E, D, I] — meaning per-expert shape is [D, I] (K, N) NOT HF [I, D]. So we can NOT directly // linear_hf from gate_exps. Instead, load the expert-10 weight fresh and use linear_hf. { std::vector h_tidx_local(S * K); ACL_CHECK(aclrtMemcpy(h_tidx_local.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST)); int target_expert = h_tidx_local[0 * K + 0]; // topk_idx[0, 0] should be 10 from Python ref printf("\n === Single-expert linear_hf vs GMM sanity (token 0, expert %d) ===\n", target_expert); // Recompute p_to_s and p_to_e from host data (scoped locally). std::vector h_tpe2(E); std::vector h_xn_all2(S * D); std::vector h_ex_all2(TOTAL * D); ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_xn_all2.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); ACL_CHECK(aclrtMemcpy(h_ex_all2.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST)); std::vector p_to_s(TOTAL), p_to_e(TOTAL); { int64_t cum = 0; int ce = 0; for (int64_t p = 0; p < TOTAL; p++) { while (ce < E && p >= cum + h_tpe2[ce]) { cum += h_tpe2[ce]; ce++; } p_to_e[p] = ce; float ev = bf16_to_float(h_ex_all2[p * D]); int best = -1; float bd = 1e30f; for (int s = 0; s < S; s++) { float df = std::abs(bf16_to_float(h_xn_all2[s * D]) - ev); if (df < bd) { bd = df; best = s; } } p_to_s[p] = best; } } DeviceBuffer g_w, u_w, d_w; char ename[256]; snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", target_expert); if (!dw.st().get(ename)) { printf(" missing %s\n", ename); goto after_sanity; } // Load full per-expert weight using public helpers (indirectly via loader). // Easiest: use load_tensor_full_ via friend access... Instead, use st_ directly. { auto* m_gate = dw.st().get(ename); DeviceBuffer gw_buf(m_gate->nbytes); ACL_CHECK(aclrtMemcpy(gw_buf.get(), m_gate->nbytes, dw.st().data_ptr(*m_gate), m_gate->nbytes, ACL_MEMCPY_HOST_TO_DEVICE)); g_w = std::move(gw_buf); snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.up_proj.weight", target_expert); auto* m_up = dw.st().get(ename); DeviceBuffer uw_buf(m_up->nbytes); ACL_CHECK(aclrtMemcpy(uw_buf.get(), m_up->nbytes, dw.st().data_ptr(*m_up), m_up->nbytes, ACL_MEMCPY_HOST_TO_DEVICE)); u_w = std::move(uw_buf); snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.down_proj.weight", target_expert); auto* m_down = dw.st().get(ename); DeviceBuffer dw_buf(m_down->nbytes); ACL_CHECK(aclrtMemcpy(dw_buf.get(), m_down->nbytes, dw.st().data_ptr(*m_down), m_down->nbytes, ACL_MEMCPY_HOST_TO_DEVICE)); d_w = std::move(dw_buf); } // Compute gate = xn[0] @ gate_w.T → [I]; up = xn[0] @ up_w.T → [I]; act; down = act @ down_w.T → [D] DeviceBuffer xn0_dev(D * 2); ACL_CHECK(aclrtMemcpy(xn0_dev.get(), D*2, xn_dev.get(), D*2, ACL_MEMCPY_DEVICE_TO_DEVICE)); DeviceBuffer gate_v(I * 2), up_v(I * 2), act_v(I * 2), down_v(D * 2); auto t_xn0 = make_contig_tensor(xn0_dev.get(), ACL_BF16, {1, D}); auto t_gate = make_contig_tensor(gate_v.get(), ACL_BF16, {1, I}); auto t_up = make_contig_tensor(up_v.get(), ACL_BF16, {1, I}); auto t_act = make_contig_tensor(act_v.get(), ACL_BF16, {1, I}); auto t_down = make_contig_tensor(down_v.get(), ACL_BF16, {1, D}); linear_hf(rt.stream(), t_xn0.get(), g_w.get(), ACL_BF16, I, D, t_gate.get()); // gate_proj HF [I, D] linear_hf(rt.stream(), t_xn0.get(), u_w.get(), ACL_BF16, I, D, t_up.get()); rt.sync(); silu(rt.stream(), t_gate.get(), t_act.get()); mul(rt.stream(), t_act.get(), t_up.get(), t_act.get()); rt.sync(); linear_hf(rt.stream(), t_act.get(), d_w.get(), ACL_BF16, D, I, t_down.get()); // down_proj HF [D, I] rt.sync(); std::vector h_down_lin(D); ACL_CHECK(aclrtMemcpy(h_down_lin.data(), D*2, down_v.get(), D*2, ACL_MEMCPY_DEVICE_TO_HOST)); // Find the p in GMM output that corresponds to (s=0, expert=target_expert) int found_p = -1; for (int64_t p = 0; p < TOTAL; p++) { if (p_to_s[p] == 0 && p_to_e[p] == target_expert) { found_p = p; break; } } if (found_p >= 0) { std::vector h_down_gmm(D); ACL_CHECK(aclrtMemcpy(h_down_gmm.data(), D*2, (char*)down_out_dev.get() + found_p*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST)); double l2d=0, l2r=0, maxd=0; for (int i = 0; i < D; i++) { float a = bf16_to_float(h_down_gmm[i]), b = bf16_to_float(h_down_lin[i]); l2d += (a-b)*(a-b); l2r += b*b; if (std::abs(a-b) > maxd) maxd = std::abs(a-b); } double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10); printf(" GMM down_out[p=%d] vs linear_hf down: rel=%.4e max=%.4f\n", found_p, rel, maxd); printf(" GMM[:4]: %.5f %.5f %.5f %.5f\n", bf16_to_float(h_down_gmm[0]), bf16_to_float(h_down_gmm[1]), bf16_to_float(h_down_gmm[2]), bf16_to_float(h_down_gmm[3])); printf(" linear[:4]: %.5f %.5f %.5f %.5f\n", bf16_to_float(h_down_lin[0]), bf16_to_float(h_down_lin[1]), bf16_to_float(h_down_lin[2]), bf16_to_float(h_down_lin[3])); } else { printf(" not found p for (s=0, expert=%d)\n", target_expert); } } after_sanity:; // Direct verification: gate_exps[expert_10, :4, :4] vs HF gate_proj_10 (transposed). { int expert_id = 10; std::vector h_stacked(4 * 4); // gate_exps shape [E, D, I]. Expert 10 starts at offset expert_id * D * I * 2. // Read the first 4 rows (d=0..3), first 4 cols (i=0..3). Row stride = I * 2 bytes. for (int d = 0; d < 4; d++) { ACL_CHECK(aclrtMemcpy(h_stacked.data() + d*4, 8, (char*)moe.gate_exps.get() + (expert_id * D * I + d * I) * 2, 8, ACL_MEMCPY_DEVICE_TO_HOST)); } char ename[256]; snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", expert_id); auto* m = dw.st().get(ename); // HF gate_proj [I, D] row-major. Element at (i, d) is at offset (i*D + d)*2. // Expected gate_exps[10, d, i] == HF_gate_proj[10][i, d]. // So for d in 0..3, i in 0..3: expected is HF[i, d]. std::vector h_expected(4 * 4); auto* hf = (const uint16_t*)dw.st().data_ptr(*m); for (int d = 0; d < 4; d++) { for (int i = 0; i < 4; i++) { h_expected[d*4 + i] = hf[i * D + d]; // HF[i, d] } } printf("\n === gate_exps[10, :4, :4] layout check ===\n"); printf(" stacked: "); for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_stacked[i])); printf("\n expected: "); for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_expected[i])); printf("\n"); int mism = 0; for (int i = 0; i < 16; i++) if (h_stacked[i] != h_expected[i]) mism++; printf(" mismatches: %d / 16\n", mism); } printf("\n=== Final (with residual) ===\n"); double rel = compare_bf16("final_out", final_dev.get(), S * D, "final_out.bin"); bool pass = rel < 5e-2; printf("\n%s\n", pass ? "=== test_moe_layer PASS ===" : "=== test_moe_layer FAIL ==="); return pass ? 0 : 1; }