// test_moe_layer.cpp: Full MoE layer forward (Qwen3-235B layer 0), TP=1.
//
// Pipeline:
// 1. Post-attention RmsNorm (input from attn_data/final_out.bin)
// 2. Router: xn @ W_router.T → logits [S, E]
// 3. TopK softmax → weights [S, K], expert_ids [S, K]
// 4. Host-normalize top_k weights (Qwen3 norm_topk_prob)
// 5. MoeInitRoutingV3 → expanded_x [S*K, D], expanded_row_idx, tokens_per_expert
// 6. GMM gate: expanded_x × gate_exps → [S*K, I]
// 7. GMM up: same → [S*K, I]
// 8. silu(gate) * up → [S*K, I]
// 9. GMM down: act × down_exps → [S*K, D]
// 10. MoeFinalizeRouting (weighted sum) → [S, D] (done manually below: gather + mul + reduce)
// 11. + residual
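//
// Shape legend (values are read from config.json at runtime; the parenthesized
// numbers are what this test exercises): S = tokens in the batch (5),
// D = hidden_size, I = moe_intermediate_size, E = num_experts (128),
// K = num_experts_per_tok, TOTAL = S*K expanded rows.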
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "model_config.h"
#include "safetensors_loader.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <tuple>
#include <vector>
static float bf16_to_float(uint16_t x) {
uint32_t u = (uint32_t)x << 16; float f; std::memcpy(&f, &u, 4); return f;
}
static uint16_t float_to_bf16(float x) {
uint32_t u; std::memcpy(&u, &x, 4);
return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
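// float_to_bf16 implements round-to-nearest-even: adding 0x7FFF carries into the
// kept high half exactly when the discarded low 16 bits exceed half an ULP
// (0x8000), and the extra ((u >> 16) & 1) breaks exact ties toward the even
// mantissa. Worked example: 1.00390625f = 0x3F808000 is an exact tie with kept
// lsb 0, so 0x3F808000 + 0x7FFF + 0 = 0x3F80FFFF and >>16 gives 0x3F80, i.e.
// bf16(1.0). (NaNs are not special-cased; fine for the finite test data here.)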
static std::vector<uint8_t> read_file(const std::string& p) {
std::ifstream f(p, std::ios::binary | std::ios::ate);
if (!f) { fprintf(stderr, "read_file: cannot open %s\n", p.c_str()); return {}; }
size_t s = f.tellg();
f.seekg(0); std::vector<uint8_t> v(s); f.read((char*)v.data(), s); return v;
}
int main() {
const std::string model_dir = "/path/to/Qwen3-235B-A22B-Instruct-2507-BF16";
const std::string data_dir = "tests/moe_data";
ModelConfig cfg;
if (!cfg.load_from_json(model_dir + "/config.json")) return 1;
cfg.compute_derived(1, 0); // TP=1
const int64_t D = cfg.hidden_size;
const int64_t I = cfg.moe_intermediate_size;
const int64_t E = cfg.num_experts;
const int64_t K = cfg.num_experts_per_tok;
const double eps = cfg.rms_norm_eps;
AclRuntime rt;
rt.init(0);
printf("[dbg] rt init ok\n"); fflush(stdout);
SafetensorsLoader st;
if (!st.open(model_dir)) return 1;
// ---- Load weights ----
printf("Loading layer 0 attention weights (for post_attention_layernorm)...\n");
DeviceWeightsLoader dw(st, cfg);
LayerAttnWeights attn;
if (!dw.load_attention(0, attn)) return 1;
printf("Loading layer 0 MoE weights (128 experts Γ— 3 projections, stacking + permute)...\n"); fflush(stdout);
LayerMoEWeights moe;
if (!dw.load_moe(0, rt.stream(), moe)) return 1;
rt.sync();
printf("[dbg] moe load ok\n"); fflush(stdout);
printf(" router %.1f MB gate_exps %.0f MB up_exps %.0f MB down_exps %.0f MB\n",
moe.router.size / 1e6, moe.gate_exps.size / 1e6, moe.up_exps.size / 1e6, moe.down_exps.size / 1e6);
// ---- Load input & Python reference ----
int S = 5;
auto x_in_host = read_file(data_dir + "/x_in.bin");
auto ref_out_host = read_file(data_dir + "/final_out.bin");
DeviceBuffer x_dev(S * D * 2);
ACL_CHECK(aclrtMemcpy(x_dev.get(), x_in_host.size(), x_in_host.data(), x_in_host.size(), ACL_MEMCPY_HOST_TO_DEVICE));
// Residual snapshot
DeviceBuffer residual_dev(S * D * 2);
ACL_CHECK(aclrtMemcpy(residual_dev.get(), S*D*2, x_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE));
printf("[dbg] loaded data and residual ok, TOTAL=%ld\n", S * K); fflush(stdout);
// ---- Step 1: Post-attention RmsNorm ----
DeviceBuffer xn_dev(S * D * 2);
DeviceBuffer rstd_dev(S * 4);
auto t_x = make_contig_tensor(x_dev.get(), ACL_BF16, {S, D});
auto t_xn = make_contig_tensor(xn_dev.get(), ACL_BF16, {S, D});
auto t_ln = make_contig_tensor(attn.post_attention_layernorm.get(), ACL_BF16, {D});
auto t_rstd = make_contig_tensor(rstd_dev.get(), ACL_FLOAT, {S});
rms_norm(rt.stream(), t_x.get(), t_ln.get(), eps, t_xn.get(), t_rstd.get());
rt.sync();
printf("[dbg] rms_norm ok\n"); fflush(stdout);
// ---- Step 2: Router (gate matmul) ----
DeviceBuffer logits_dev(S * E * 2);
auto t_logits = make_contig_tensor(logits_dev.get(), ACL_BF16, {S, E});
// router is [E, D] (HF). logits = xn @ router.T
linear_hf(rt.stream(), t_xn.get(), moe.router.get(), ACL_BF16, E, D, t_logits.get());
rt.sync();
printf("[dbg] router linear ok\n"); fflush(stdout);
// ---- Step 3: TopK softmax ----
DeviceBuffer topk_w_dev(S * K * 2); // BF16
DeviceBuffer topk_idx_dev(S * K * 4); // int32
DeviceBuffer row_idx_dev(S * K * 4); // int32 (from gating op, unused for our routing)
auto t_topk_w = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K});
auto t_topk_idx = make_contig_tensor(topk_idx_dev.get(), ACL_INT32, {S, K});
auto t_row_idx = make_contig_tensor(row_idx_dev.get(), ACL_INT32, {S, K});
moe_gating_topk_softmax(rt.stream(), t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());
rt.sync();
printf("[dbg] topk_softmax ok\n"); fflush(stdout);
// ---- Step 4: Host-normalize top_k weights (norm_topk_prob=true) ----
std::vector<uint16_t> tw_bf(S * K);
ACL_CHECK(aclrtMemcpy(tw_bf.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
for (int s = 0; s < S; s++) {
float sum = 0.0f;
for (int k = 0; k < K; k++) sum += bf16_to_float(tw_bf[s*K + k]);
sum += 1e-20f;
for (int k = 0; k < K; k++) {
float v = bf16_to_float(tw_bf[s*K + k]) / sum;
tw_bf[s*K + k] = float_to_bf16(v);
}
}
ACL_CHECK(aclrtMemcpy(topk_w_dev.get(), S*K*2, tw_bf.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
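// This mirrors HF Qwen3 MoE with norm_topk_prob=true: the K selected softmax
// weights are rescaled to sum to 1, i.e. w[s,k] <- w[s,k] / sum_j w[s,j].
// (The 1e-20 guard is ours, to avoid 0/0 in degenerate cases.)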
// ---- Step 5: MoE init routing ----
int64_t TOTAL = S * K;
DeviceBuffer expanded_x_dev(TOTAL * D * 2);
DeviceBuffer expanded_row_idx_dev(TOTAL * 4);
DeviceBuffer tokens_per_expert_dev(E * 8);
auto t_ex_x = make_contig_tensor(expanded_x_dev.get(), ACL_BF16, {TOTAL, D});
auto t_ex_ri = make_contig_tensor(expanded_row_idx_dev.get(), ACL_INT32, {TOTAL});
auto t_tpe = make_contig_tensor(tokens_per_expert_dev.get(), ACL_INT64, {E});
moe_init_routing_v3(rt.stream(),
t_xn.get(), t_topk_idx.get(),
E, TOTAL,
t_ex_x.get(), t_ex_ri.get(), t_tpe.get());
rt.sync();
printf("[dbg] moe_init_routing ok\n"); fflush(stdout);
// Convert tokens_per_expert from counts to cumsum (on host) for GMM groupListType=0.
DeviceBuffer tpe_cumsum_dev(E * 8);
{
std::vector<int64_t> h_counts(E), h_cum(E);
ACL_CHECK(aclrtMemcpy(h_counts.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
int64_t acc = 0;
for (int i = 0; i < E; i++) { acc += h_counts[i]; h_cum[i] = acc; }
ACL_CHECK(aclrtMemcpy(tpe_cumsum_dev.get(), E*8, h_cum.data(), E*8, ACL_MEMCPY_HOST_TO_DEVICE));
}
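// Example with hypothetical counts over 3 experts: counts [2, 0, 3] become the
// cumulative group_list [2, 2, 5]; with groupListType=0, GMM multiplies rows
// [0,2) by expert 0's weight, the empty range [2,2) by expert 1's, and [2,5)
// by expert 2's, so zero-count experts are skipped cleanly.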
auto t_tpe_cum = make_contig_tensor(tpe_cumsum_dev.get(), ACL_INT64, {E});
// ---- Step 6/7: GMM gate and up ----
DeviceBuffer gate_out_dev(TOTAL * I * 2);
DeviceBuffer up_out_dev(TOTAL * I * 2);
auto t_gate_out = make_contig_tensor(gate_out_dev.get(), ACL_BF16, {TOTAL, I});
auto t_up_out = make_contig_tensor(up_out_dev.get(), ACL_BF16, {TOTAL, I});
// gate/up_exps loaded as [E, D, I] row-major
auto t_w_gate = make_contig_tensor(moe.gate_exps.get(), ACL_BF16, {E, D, I});
auto t_w_up = make_contig_tensor(moe.up_exps.get(), ACL_BF16, {E, D, I});
// Use cumsum group_list (groupListType=0): empirically more reliable with many zero-count experts.
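// Each expert slice w[e] is [D, I], already transposed from HF's [I, D] at load
// time, so the GMM computes out[rows_e] = x[rows_e] @ w[e] with no per-call
// transpose (see the gate_exps layout check near the end of this test).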
grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_gate.get(), t_tpe_cum.get(), t_gate_out.get(), 0);
rt.sync();
printf("[dbg] gmm gate ok\n"); fflush(stdout);
grouped_matmul_v4(rt.stream(), t_ex_x.get(), t_w_up.get(), t_tpe_cum.get(), t_up_out.get(), 0);
rt.sync();
printf("[dbg] gmm up ok\n"); fflush(stdout);
// ---- Step 8: SwiGLU ----
// act = silu(gate) * up (inplace on gate_out)
silu(rt.stream(), t_gate_out.get(), t_gate_out.get());
rt.sync(); printf("[dbg] silu ok\n"); fflush(stdout);
mul(rt.stream(), t_gate_out.get(), t_up_out.get(), t_gate_out.get());
rt.sync(); printf("[dbg] mul ok\n"); fflush(stdout);
// now gate_out_dev contains the activated intermediate
// ---- Step 9: GMM down ----
DeviceBuffer down_out_dev(TOTAL * D * 2);
auto t_down_out = make_contig_tensor(down_out_dev.get(), ACL_BF16, {TOTAL, D});
auto t_w_down = make_contig_tensor(moe.down_exps.get(), ACL_BF16, {E, I, D});
grouped_matmul_v4(rt.stream(), t_gate_out.get(), t_w_down.get(), t_tpe_cum.get(), t_down_out.get(), 0);
rt.sync();
printf("[dbg] gmm down ok\n"); fflush(stdout);
// ---- Step 10: Device-side manual finalize (replacement for buggy MoeFinalizeRoutingV2) ----
// Compute forward permutation fwd[n*K + k] = p where token n's k-th expert's output is at
// expanded position p. We use tokens_per_expert (cumsum) + topk_idx to resolve this correctly,
// regardless of the exact rowIdxType semantics returned by MoeInitRoutingV3.
DeviceBuffer fwd_dev(TOTAL * 8);
{
std::vector<int64_t> h_tpe2(E);
std::vector<int32_t> h_tidx3(S * K);
ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_tidx3.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
// MoeInitRoutingV3 groups token copies by expert, stable within each expert:
// expanded positions 0..tpe[0]-1 belong to expert 0 (tokens picking e=0, in
// n-ascending order), the next tpe[1] positions to expert 1, and so on.
// To build fwd: for each (n, k), form the triple (e, n, k) with e = topk_idx[n, k],
// sort by (e, n) (unique per pair, since a token's top-k experts are distinct);
// a triple's sorted rank is its expanded position p, and fwd[n*K + k] = p.
std::vector<int64_t> fwd(TOTAL);
std::vector<std::tuple<int, int, int>> triples;
triples.reserve(TOTAL);
for (int n = 0; n < S; n++) for (int k = 0; k < K; k++) {
triples.emplace_back(h_tidx3[n * K + k], n, k);
}
std::sort(triples.begin(), triples.end(), [](const auto& a, const auto& b){
if (std::get<0>(a) != std::get<0>(b)) return std::get<0>(a) < std::get<0>(b);
return std::get<1>(a) < std::get<1>(b);
});
for (int64_t p = 0; p < TOTAL; p++) {
auto [e, n, k] = triples[p];
fwd[n * K + k] = p;
}
ACL_CHECK(aclrtMemcpy(fwd_dev.get(), TOTAL*8, fwd.data(), TOTAL*8, ACL_MEMCPY_HOST_TO_DEVICE));
}
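// Tiny worked example of fwd (hypothetical S=2, K=2): topk_idx = [[1, 0], [0, 1]]
// gives triples sorted by (e, n): (0,0,1) p=0, (0,1,0) p=1, (1,0,0) p=2,
// (1,1,1) p=3, hence fwd = [2, 0, 1, 3]: token 0's first pick (expert 1)
// lives at expanded row 2.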
auto t_fwd = make_contig_tensor(fwd_dev.get(), ACL_INT64, {TOTAL});
// Gather: packed [S*K, D] = down_out[fwd, :]
DeviceBuffer packed_dev(TOTAL * D * 2);
auto t_packed = make_contig_tensor(packed_dev.get(), ACL_BF16, {TOTAL, D});
index_select(rt.stream(), t_down_out.get(), 0, t_fwd.get(), t_packed.get());
rt.sync();
// Broadcast-multiply by topk_w: view packed as [S, K, D], topk_w as [S, K, 1].
auto t_packed_3d = make_contig_tensor(packed_dev.get(), ACL_BF16, {S, K, D});
auto t_topk_w_3d = make_contig_tensor(topk_w_dev.get(), ACL_BF16, {S, K, 1});
DeviceBuffer weighted_dev(S * K * D * 2);
auto t_weighted = make_contig_tensor(weighted_dev.get(), ACL_BF16, {S, K, D});
mul(rt.stream(), t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());
rt.sync();
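// In formulas: weighted[s,k,d] = topk_w[s,k] * down_out[fwd[s*K+k], d], and the
// ReduceSum below yields moe_out[s,d] = sum_k weighted[s,k,d], i.e. the K-expert
// weighted mixture that MoeFinalizeRouting would produce.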
// Verify broadcast mul + sum by dumping all k entries and summing on host.
{
std::vector<uint16_t> h_pk_all(S * K * D);
std::vector<uint16_t> h_wt_all(S * K * D);
std::vector<uint16_t> h_tw_all(S * K);
ACL_CHECK(aclrtMemcpy(h_pk_all.data(), S*K*D*2, packed_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_wt_all.data(), S*K*D*2, weighted_dev.get(), S*K*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_tw_all.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
printf(" verify weighted[0, k, 0] = packed[0, k, 0] * topk_w[0, k] for all k:\n");
float host_sum = 0;
for (int k = 0; k < K; k++) {
float p = bf16_to_float(h_pk_all[k * D]); // packed[0, k, 0] = offset s*K*D + k*D + 0 = k*D (for s=0)
float w = bf16_to_float(h_tw_all[k]); // topk_w[0, k]
float wt = bf16_to_float(h_wt_all[k * D]); // weighted[0, k, 0]
host_sum += p * w;
printf(" k=%d: packed=%.5f * topk_w=%.5f = expect=%.5f dev=%.5f\n",
k, p, w, p*w, wt);
}
printf(" host_sum_of_weighted[0, :, 0] = %.5f (expected moe_out[0,0] = -0.02466)\n", host_sum);
}
// ReduceSum over K axis → [S, D]
DeviceBuffer moe_out_dev(S * D * 2);
auto t_moe_out = make_contig_tensor(moe_out_dev.get(), ACL_BF16, {S, D});
reduce_sum(rt.stream(), t_weighted.get(), {1}, /*keep_dims=*/false, ACL_BF16, t_moe_out.get());
rt.sync();
printf("[dbg] device-side finalize (gather+mul+reduce) ok\n"); fflush(stdout);
// Residual add to produce final_out
float alpha_v = 1.0f; aclScalar* alpha = aclCreateScalar(&alpha_v, ACL_FLOAT);
DeviceBuffer final_dev(S * D * 2);
auto t_final = make_contig_tensor(final_dev.get(), ACL_BF16, {S, D});
auto t_res = make_contig_tensor(residual_dev.get(), ACL_BF16, {S, D});
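// aclnn ops follow a two-phase contract: GetWorkspaceSize builds the executor
// and reports the scratch size, then the op is launched with that workspace.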
{
uint64_t ws = 0; aclOpExecutor* e = nullptr;
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_res.get(), t_moe_out.get(), alpha, t_final.get(), &ws, &e));
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt.stream()));
}
aclDestroyScalar(alpha);
rt.sync();
// ---- Compare (intermediate + final) ----
auto compare_bf16 = [&](const char* label, void* dev_ptr, int64_t nelem,
const std::string& ref_file) {
std::vector<uint16_t> cxx(nelem);
ACL_CHECK(aclrtMemcpy(cxx.data(), nelem*2, dev_ptr, nelem*2, ACL_MEMCPY_DEVICE_TO_HOST));
auto refbuf = read_file(data_dir + "/" + ref_file);
auto* ref = (const uint16_t*)refbuf.data();
double l2d = 0, l2r = 0, maxd = 0;
for (int64_t i = 0; i < nelem; i++) {
float a = bf16_to_float(cxx[i]), b = bf16_to_float(ref[i]);
l2d += (a-b)*(a-b); l2r += b*b;
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
}
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
printf(" [cmp] %-12s rel=%.4e max_abs=%.4f cxx[:4]=%.5f %.5f %.5f %.5f ref[:4]=%.5f %.5f %.5f %.5f\n",
label, rel, maxd,
bf16_to_float(cxx[0]), bf16_to_float(cxx[1]), bf16_to_float(cxx[2]), bf16_to_float(cxx[3]),
bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3]));
return rel;
};
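// rel is the relative L2 error ||cxx - ref||_2 / (||ref||_2 + 1e-10); max_abs is
// the worst elementwise deviation.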
printf("\n=== Intermediate diagnostics ===\n");
compare_bf16("xn", xn_dev.get(), S * D, "xn.bin");
compare_bf16("topk_w", topk_w_dev.get(), S * K, "topk_w.bin");
// Dump topk_idx (int32) to compare
{
std::vector<int32_t> cxx_idx(S*K);
ACL_CHECK(aclrtMemcpy(cxx_idx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
auto refbuf = read_file(data_dir + "/topk_idx.bin");
auto* ref = (const int32_t*)refbuf.data();
int mismatches = 0;
for (int i = 0; i < S*K; i++) if (cxx_idx[i] != ref[i]) mismatches++;
printf(" [cmp] topk_idx mismatches=%d/%d cxx[0,:4]=%d %d %d %d ref[0,:4]=%d %d %d %d\n",
mismatches, S*K,
cxx_idx[0], cxx_idx[1], cxx_idx[2], cxx_idx[3],
ref[0], ref[1], ref[2], ref[3]);
}
printf("\n=== MoE-only (before residual) ===\n");
compare_bf16("moe_out", moe_out_dev.get(), S * D, "out_flat.bin");
// Manual host-side finalize: verify what down_out + expanded_row_idx + topk_w produce.
{
std::vector<uint16_t> h_down(TOTAL * D);
std::vector<int32_t> h_ri(TOTAL);
std::vector<uint16_t> h_tw(S * K);
ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_ri.data(), TOTAL*4, expanded_row_idx_dev.get(), TOTAL*4, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
printf(" expanded_row_idx (all %ld):\n ", TOTAL);
for (int i = 0; i < TOTAL; i++) {
printf("%d ", h_ri[i]);
if ((i+1) % 10 == 0) printf("\n ");
}
printf("\n");
// count unique and check bijection
std::vector<int> count(TOTAL, 0);
int out_of_range = 0;
for (int i = 0; i < TOTAL; i++) {
int v = h_ri[i];
if (v >= 0 && v < TOTAL) count[v]++;
else out_of_range++;
}
int bijection_ok = (out_of_range == 0);
for (int i = 0; i < TOTAL && bijection_ok; i++) if (count[i] != 1) bijection_ok = 0;
printf(" bijection=%s out_of_range=%d\n", bijection_ok ? "YES" : "NO", out_of_range);
// Also dump tokens_per_expert (int64); it should sum to TOTAL
std::vector<int64_t> h_tpe(E);
ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
int64_t tpe_sum = 0, nonzero = 0;
int64_t tpe_max = 0;
for (int i = 0; i < E; i++) { tpe_sum += h_tpe[i]; if (h_tpe[i]>0) nonzero++; if (h_tpe[i]>tpe_max) tpe_max=h_tpe[i]; }
printf(" tokens_per_expert: sum=%ld nonzero=%ld max=%ld (expected sum=%ld if counts, or last=%ld if cumsum)\n",
tpe_sum, nonzero, tpe_max, TOTAL, TOTAL);
printf(" tpe[last 4]: %ld %ld %ld %ld\n", h_tpe[E-4], h_tpe[E-3], h_tpe[E-2], h_tpe[E-1]);
std::vector<float> manual(S * D, 0.0f);
for (int64_t p = 0; p < TOTAL; p++) {
int32_t src = h_ri[p];
int s = src / K;
int k = src % K;
if (s < 0 || s >= S || k < 0 || k >= K) { printf(" bad idx p=%ld src=%d\n", p, src); continue; }
float w = bf16_to_float(h_tw[s * K + k]);
for (int d = 0; d < D; d++) {
manual[s * D + d] += w * bf16_to_float(h_down[p * D + d]);
}
}
// Convert to bf16 and compare to Python out_flat
auto refbuf = read_file(data_dir + "/out_flat.bin");
auto* ref = (const uint16_t*)refbuf.data();
double l2d=0, l2r=0, maxd=0;
for (int64_t i = 0; i < S*D; i++) {
float a = manual[i], b = bf16_to_float(ref[i]);
l2d += (a-b)*(a-b); l2r += b*b;
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
}
double rel_manual = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
printf(" [cmp] MANUAL(row_idx=src→flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f r[:4]=%.5f %.5f %.5f %.5f\n",
rel_manual, maxd,
manual[0], manual[1], manual[2], manual[3],
bf16_to_float(ref[0]), bf16_to_float(ref[1]), bf16_to_float(ref[2]), bf16_to_float(ref[3]));
// Alternative semantic: row_idx[p] = destination position
// In that case: p=src_row, dst=h_ri[p]
std::vector<float> manual2(S * D, 0.0f);
for (int64_t p = 0; p < TOTAL; p++) {
int32_t dst = h_ri[p];
int s = dst / K;
int k = dst % K;
if (s < 0 || s >= S || k < 0 || k >= K) continue;
float w = bf16_to_float(h_tw[s * K + k]);
for (int d = 0; d < D; d++) {
manual2[s * D + d] += w * bf16_to_float(h_down[p * D + d]);
}
}
double l2d2=0, l2r2=0, maxd2=0;
for (int64_t i = 0; i < S*D; i++) {
float a = manual2[i], b = bf16_to_float(ref[i]);
l2d2 += (a-b)*(a-b); l2r2 += b*b;
if (std::abs(a-b) > maxd2) maxd2 = std::abs(a-b);
}
double rel_manual2 = std::sqrt(l2d2) / (std::sqrt(l2r2) + 1e-10);
printf(" [cmp] MANUAL(row_idx=p→dst_flat) rel=%.4e max_abs=%.4f m[:4]=%.5f %.5f %.5f %.5f\n",
rel_manual2, maxd2,
manual2[0], manual2[1], manual2[2], manual2[3]);
}
// Manual finalize using cumsum (semantics-independent):
// For each (n, k), find p such that actual_s(p)=n AND expert(p)=topk_idx[n,k], then
// out[n] += topk_w[n,k] * down_out[p].
{
std::vector<uint16_t> h_down(TOTAL * D);
std::vector<int64_t> h_tpe(E);
std::vector<int32_t> h_tidx(S * K);
std::vector<uint16_t> h_tw(S * K);
std::vector<uint16_t> h_xn_all(S * D);
std::vector<uint16_t> h_ex_all(TOTAL * D);
ACL_CHECK(aclrtMemcpy(h_down.data(), TOTAL*D*2, down_out_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_tpe.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_tidx.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_dev.get(), S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
// Build p → (actual_s, actual_expert).
// actual_s: find s with xn[s,0] == expanded_x[p,0]
// actual_expert: find e such that cumsum_tpe[e-1] <= p < cumsum_tpe[e]
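// NOTE: matching on a single leading bf16 element is a heuristic; it can
// mis-assign p_to_s if two tokens share the same xn[s, 0] value. Acceptable
// for this 5-token fixture, not for general inputs.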
std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL);
int64_t cum = 0;
int cursor_e = 0;
for (int64_t p = 0; p < TOTAL; p++) {
while (cursor_e < E && p >= cum + h_tpe[cursor_e]) { cum += h_tpe[cursor_e]; cursor_e++; }
p_to_e[p] = cursor_e;
float ev = bf16_to_float(h_ex_all[p * D]);
int best = -1; float bd = 1e30f;
for (int s = 0; s < S; s++) {
float df = std::abs(bf16_to_float(h_xn_all[s * D]) - ev);
if (df < bd) { bd = df; best = s; }
}
p_to_s[p] = best;
}
// Build (n, k) → p lookup via (n, expert) → p
std::vector<float> manual_cum(S * D, 0.0f);
int found_count = 0;
for (int n = 0; n < S; n++) {
for (int k = 0; k < K; k++) {
int e = h_tidx[n * K + k];
float w = bf16_to_float(h_tw[n * K + k]);
// search p with p_to_s[p]==n and p_to_e[p]==e
int found_p = -1;
for (int64_t p = 0; p < TOTAL; p++) {
if (p_to_s[p] == n && p_to_e[p] == e) { found_p = p; break; }
}
if (found_p < 0) {
printf(" [!!!] not found: n=%d k=%d expert=%d\n", n, k, e);
continue;
}
found_count++;
for (int d = 0; d < D; d++)
manual_cum[n * D + d] += w * bf16_to_float(h_down[found_p * D + d]);
}
}
auto refbuf = read_file(data_dir + "/out_flat.bin");
auto* ref = (const uint16_t*)refbuf.data();
double l2d=0, l2r=0, maxd=0;
for (int64_t i = 0; i < S*D; i++) {
float a = manual_cum[i], b = bf16_to_float(ref[i]);
l2d += (a-b)*(a-b); l2r += b*b;
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
}
double rel_cum = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
printf(" [cmp] MANUAL_CUMSUM (p via expert cumsum) rel=%.4e max=%.4f found=%d/40 m[:4]=%.5f %.5f %.5f %.5f\n",
rel_cum, maxd, found_count, manual_cum[0], manual_cum[1], manual_cum[2], manual_cum[3]);
}
// Dump all expanded_x[p, 0] and all xn[s, 0] to determine the mapping.
{
std::vector<uint16_t> h_xn_all(S * D);
ACL_CHECK(aclrtMemcpy(h_xn_all.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
std::vector<uint16_t> h_ex_all(TOTAL * D);
ACL_CHECK(aclrtMemcpy(h_ex_all.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
printf(" xn[s, 0]: ");
for (int s = 0; s < S; s++) printf("%.5f ", bf16_to_float(h_xn_all[s * D]));
printf("\n expanded_x[p, 0]: ");
for (int p = 0; p < TOTAL; p++) printf("%.5f ", bf16_to_float(h_ex_all[p * D]));
printf("\n mapping p→s (by matching expanded_x[p,0] to xn[s,0]): ");
for (int p = 0; p < TOTAL; p++) {
float e = bf16_to_float(h_ex_all[p * D]);
int match = -1; float best = 1e30f;
for (int s = 0; s < S; s++) {
float df = std::abs(bf16_to_float(h_xn_all[s * D]) - e);
if (df < best) { best = df; match = s; }
}
printf("%d ", match);
}
printf("\n");
}
// Spot-check down_out[p=4, :4], the down projection of xn[0] through expert 10.
// (The raw gate activation is unavailable: gate_out_dev was overwritten in place
// by silu+mul, so we inspect the down output instead.)
{
std::vector<uint16_t> h_d(D);
ACL_CHECK(aclrtMemcpy(h_d.data(), D*2, (char*)down_out_dev.get() + 4*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST));
printf(" down_out[p=4, :4] (s=0, k=0, expert=10): %.5f %.5f %.5f %.5f\n",
bf16_to_float(h_d[0]), bf16_to_float(h_d[1]), bf16_to_float(h_d[2]), bf16_to_float(h_d[3]));
// If GMM is correct, down_out[4] ~ ref[0] / topk_w[0,0]. ref[0,:4]=[-0.025, -0.007, 0.005, -0.008] / 0.224 ~ [-0.113, -0.031, 0.024, -0.036].
// But it's just ONE contribution so hard to compare directly.
}
// Single-expert verification using linear_hf: compute gate/up/down for (xn[0], expert=10)
// and compare with GMM's down_out at the corresponding position.
// linear_hf expects HF-layout weight [out_features, in_features]; our stacked gate_exps/up_exps
// are [E, D, I], i.e. per-expert shape [D, I] (K, N), NOT HF [I, D]. So we cannot directly
// linear_hf from gate_exps. Instead, load the expert-10 weight fresh and use linear_hf.
{
std::vector<int32_t> h_tidx_local(S * K);
ACL_CHECK(aclrtMemcpy(h_tidx_local.data(), S*K*4, topk_idx_dev.get(), S*K*4, ACL_MEMCPY_DEVICE_TO_HOST));
int target_expert = h_tidx_local[0 * K + 0]; // topk_idx[0, 0] should be 10 from Python ref
printf("\n === Single-expert linear_hf vs GMM sanity (token 0, expert %d) ===\n", target_expert);
// Recompute p_to_s and p_to_e from host data (scoped locally).
std::vector<int64_t> h_tpe2(E);
std::vector<uint16_t> h_xn_all2(S * D);
std::vector<uint16_t> h_ex_all2(TOTAL * D);
ACL_CHECK(aclrtMemcpy(h_tpe2.data(), E*8, tokens_per_expert_dev.get(), E*8, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_xn_all2.data(), S*D*2, xn_dev.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
ACL_CHECK(aclrtMemcpy(h_ex_all2.data(), TOTAL*D*2, expanded_x_dev.get(), TOTAL*D*2, ACL_MEMCPY_DEVICE_TO_HOST));
std::vector<int> p_to_s(TOTAL), p_to_e(TOTAL);
{
int64_t cum = 0; int ce = 0;
for (int64_t p = 0; p < TOTAL; p++) {
while (ce < E && p >= cum + h_tpe2[ce]) { cum += h_tpe2[ce]; ce++; }
p_to_e[p] = ce;
float ev = bf16_to_float(h_ex_all2[p * D]);
int best = -1; float bd = 1e30f;
for (int s = 0; s < S; s++) {
float df = std::abs(bf16_to_float(h_xn_all2[s * D]) - ev);
if (df < bd) { bd = df; best = s; }
}
p_to_s[p] = best;
}
}
DeviceBuffer g_w, u_w, d_w;
char ename[256];
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", target_expert);
if (!dw.st().get(ename)) { printf(" missing %s\n", ename); goto after_sanity; }
// Load the full per-expert weight straight from the safetensors loader,
// bypassing the stacking/permute path used for moe.gate_exps above.
{
auto* m_gate = dw.st().get(ename);
DeviceBuffer gw_buf(m_gate->nbytes);
ACL_CHECK(aclrtMemcpy(gw_buf.get(), m_gate->nbytes, dw.st().data_ptr(*m_gate), m_gate->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
g_w = std::move(gw_buf);
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.up_proj.weight", target_expert);
auto* m_up = dw.st().get(ename);
DeviceBuffer uw_buf(m_up->nbytes);
ACL_CHECK(aclrtMemcpy(uw_buf.get(), m_up->nbytes, dw.st().data_ptr(*m_up), m_up->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
u_w = std::move(uw_buf);
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.down_proj.weight", target_expert);
auto* m_down = dw.st().get(ename);
DeviceBuffer dw_buf(m_down->nbytes);
ACL_CHECK(aclrtMemcpy(dw_buf.get(), m_down->nbytes, dw.st().data_ptr(*m_down), m_down->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
d_w = std::move(dw_buf);
}
// Compute gate = xn[0] @ gate_w.T → [I]; up = xn[0] @ up_w.T → [I]; act; down = act @ down_w.T → [D]
DeviceBuffer xn0_dev(D * 2);
ACL_CHECK(aclrtMemcpy(xn0_dev.get(), D*2, xn_dev.get(), D*2, ACL_MEMCPY_DEVICE_TO_DEVICE));
DeviceBuffer gate_v(I * 2), up_v(I * 2), act_v(I * 2), down_v(D * 2);
auto t_xn0 = make_contig_tensor(xn0_dev.get(), ACL_BF16, {1, D});
auto t_gate = make_contig_tensor(gate_v.get(), ACL_BF16, {1, I});
auto t_up = make_contig_tensor(up_v.get(), ACL_BF16, {1, I});
auto t_act = make_contig_tensor(act_v.get(), ACL_BF16, {1, I});
auto t_down = make_contig_tensor(down_v.get(), ACL_BF16, {1, D});
linear_hf(rt.stream(), t_xn0.get(), g_w.get(), ACL_BF16, I, D, t_gate.get()); // gate_proj HF [I, D]
linear_hf(rt.stream(), t_xn0.get(), u_w.get(), ACL_BF16, I, D, t_up.get());
rt.sync();
silu(rt.stream(), t_gate.get(), t_act.get());
mul(rt.stream(), t_act.get(), t_up.get(), t_act.get());
rt.sync();
linear_hf(rt.stream(), t_act.get(), d_w.get(), ACL_BF16, D, I, t_down.get()); // down_proj HF [D, I]
rt.sync();
std::vector<uint16_t> h_down_lin(D);
ACL_CHECK(aclrtMemcpy(h_down_lin.data(), D*2, down_v.get(), D*2, ACL_MEMCPY_DEVICE_TO_HOST));
// Find the p in GMM output that corresponds to (s=0, expert=target_expert)
int found_p = -1;
for (int64_t p = 0; p < TOTAL; p++) {
if (p_to_s[p] == 0 && p_to_e[p] == target_expert) { found_p = p; break; }
}
if (found_p >= 0) {
std::vector<uint16_t> h_down_gmm(D);
ACL_CHECK(aclrtMemcpy(h_down_gmm.data(), D*2, (char*)down_out_dev.get() + found_p*D*2, D*2, ACL_MEMCPY_DEVICE_TO_HOST));
double l2d=0, l2r=0, maxd=0;
for (int i = 0; i < D; i++) {
float a = bf16_to_float(h_down_gmm[i]), b = bf16_to_float(h_down_lin[i]);
l2d += (a-b)*(a-b); l2r += b*b;
if (std::abs(a-b) > maxd) maxd = std::abs(a-b);
}
double rel = std::sqrt(l2d) / (std::sqrt(l2r) + 1e-10);
printf(" GMM down_out[p=%d] vs linear_hf down: rel=%.4e max=%.4f\n", found_p, rel, maxd);
printf(" GMM[:4]: %.5f %.5f %.5f %.5f\n",
bf16_to_float(h_down_gmm[0]), bf16_to_float(h_down_gmm[1]), bf16_to_float(h_down_gmm[2]), bf16_to_float(h_down_gmm[3]));
printf(" linear[:4]: %.5f %.5f %.5f %.5f\n",
bf16_to_float(h_down_lin[0]), bf16_to_float(h_down_lin[1]), bf16_to_float(h_down_lin[2]), bf16_to_float(h_down_lin[3]));
} else {
printf(" not found p for (s=0, expert=%d)\n", target_expert);
}
}
after_sanity:;
// Direct verification: gate_exps[expert_10, :4, :4] vs HF gate_proj_10 (transposed).
{
int expert_id = 10;
std::vector<uint16_t> h_stacked(4 * 4);
// gate_exps shape [E, D, I]. Expert 10 starts at offset expert_id * D * I * 2.
// Read the first 4 rows (d=0..3), first 4 cols (i=0..3). Row stride = I * 2 bytes.
for (int d = 0; d < 4; d++) {
ACL_CHECK(aclrtMemcpy(h_stacked.data() + d*4, 8,
(char*)moe.gate_exps.get() + (expert_id * D * I + d * I) * 2, 8,
ACL_MEMCPY_DEVICE_TO_HOST));
}
char ename[256];
snprintf(ename, sizeof(ename), "model.layers.0.mlp.experts.%d.gate_proj.weight", expert_id);
auto* m = dw.st().get(ename);
// HF gate_proj [I, D] row-major. Element at (i, d) is at offset (i*D + d)*2.
// Expected gate_exps[10, d, i] == HF_gate_proj[10][i, d].
// So for d in 0..3, i in 0..3: expected is HF[i, d].
std::vector<uint16_t> h_expected(4 * 4);
auto* hf = (const uint16_t*)dw.st().data_ptr(*m);
for (int d = 0; d < 4; d++) {
for (int i = 0; i < 4; i++) {
h_expected[d*4 + i] = hf[i * D + d]; // HF[i, d]
}
}
printf("\n === gate_exps[10, :4, :4] layout check ===\n");
printf(" stacked: ");
for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_stacked[i]));
printf("\n expected: ");
for (int i = 0; i < 16; i++) printf("%.5f ", bf16_to_float(h_expected[i]));
printf("\n");
int mism = 0;
for (int i = 0; i < 16; i++) if (h_stacked[i] != h_expected[i]) mism++;
printf(" mismatches: %d / 16\n", mism);
}
printf("\n=== Final (with residual) ===\n");
double rel = compare_bf16("final_out", final_dev.get(), S * D, "final_out.bin");
bool pass = rel < 5e-2;
printf("\n%s\n", pass ? "=== test_moe_layer PASS ===" : "=== test_moe_layer FAIL ===");
return pass ? 0 : 1;
}