// engine.h — single-layer forward functions for attention and MoE.
//
// Both functions operate on device tensors. The caller owns all buffers (input, output, weights,
// KV cache slots, scratch). They take RoPE cos/sin tables and act as pure forward kernels.
//
// Design goals:
//   - Zero allocations per call (all scratch is passed in)
//   - Same signature works for prefill (S >= 1) and decode (S = 1); caller picks sparse_mode.
//   - Residual connection is NOT included (caller decides when to add residual).

#pragma once

#include "acl_common.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "rope.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// BF16 conversion helper used by fill_cos_sin_hf: round-to-nearest-even reduction of an
// F32 bit pattern to its top 16 bits.
static inline uint16_t _engine_f2bf16(float x) {
    uint32_t u;
    std::memcpy(&u, &x, 4);
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}

// Fill cos/sin tables for positions [p0, p0+L) with HF half-half layout. Returns
// contiguous [L*Dh] BF16 in the provided host vectors (caller uploads to device).
inline void fill_cos_sin_hf(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
                            int64_t p0, int64_t L, int64_t Dh, float theta) {
    cos_h.resize(L * Dh);
    sin_h.resize(L * Dh);
    int64_t half = Dh / 2;
    for (int64_t s = 0; s < L; s++) {
        for (int64_t d = 0; d < Dh; d++) {
            int64_t pair = (d < half) ? d : (d - half);
            float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
            float angle = (float)(p0 + s) * theta_pair;
            cos_h[s * Dh + d] = _engine_f2bf16(std::cos(angle));
            sin_h[s * Dh + d] = _engine_f2bf16(std::sin(angle));
        }
    }
}

// Precomputed RoPE cos/sin table: BF16 [max_seq, Dh]. One-time cost per runtime.
struct RopeCache {
    DeviceBuffer cos;        // [max_seq, Dh] BF16
    DeviceBuffer sin;        // [max_seq, Dh] BF16
    int64_t max_seq = 0;
    int64_t head_dim = 0;
    float theta = 0.0f;
};

inline bool rope_cache_build(RopeCache& rc, int64_t max_seq, int64_t head_dim, float theta) {
    std::vector<uint16_t> cos_h, sin_h;
    fill_cos_sin_hf(cos_h, sin_h, /*p0=*/0, max_seq, head_dim, theta);
    rc.cos.alloc(max_seq * head_dim * 2);
    rc.sin.alloc(max_seq * head_dim * 2);
    ACL_CHECK(aclrtMemcpy(rc.cos.get(), cos_h.size() * 2, cos_h.data(), cos_h.size() * 2,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(rc.sin.get(), sin_h.size() * 2, sin_h.data(), sin_h.size() * 2,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    rc.max_seq = max_seq;
    rc.head_dim = head_dim;
    rc.theta = theta;
    return true;
}
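// Usage sketch (hypothetical caller code, not part of the API): build the table once at
// engine init, then pass `&rc` to every attention_forward call so RoPE reads device-resident
// cos/sin instead of doing a per-call H2D copy:
//
//   RopeCache rc;
//   rope_cache_build(rc, /*max_seq=*/max_len, cfg.head_dim, cfg.rope_theta);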
// Attention forward for a single layer.
//
//   x_in   [S, D]  (hidden state, pre input_layernorm)
//   x_out  [S, D]  (attention output — NOT residual-added)
//
// K cache / V cache are contiguous [max_len, KV_DIM] BF16 buffers. This call writes new
// positions at [past_len, past_len+S) and then runs FIAS over [0, past_len+S).
//
// Scratch requirements:
//   q_scratch:    S * Q_DIM * 2 bytes
//   k_scratch:    S * KV_DIM * 2 bytes
//   v_scratch:    S * KV_DIM * 2 bytes
//   xn_scratch:   S * D * 2 bytes
//   rstd_scratch: S * max(Hq, Hkv) * 4 bytes (RmsNorm rstd output; reused per-head in step 3)
//   rope_scratch: S * Hq * Dh * 2 bytes
//
// mask: [1, 1, 2048, 2048] bool for prefill; ignored (pass nullptr) for decode.
inline void attention_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& w,
    void* x_in,                  // [S, D] BF16
    int64_t S,
    int64_t past_len,            // prior KV positions
    void* k_cache,
    void* v_cache,
    int64_t max_len,
    aclTensor* mask_tensor,      // may be nullptr for decode
    void* q_scratch,
    void* k_scratch,
    void* v_scratch,
    void* xn_scratch,
    void* rstd_scratch,
    void* rope_scratch,
    void* attn_out_scratch,      // S * Q_DIM * 2 bytes (FIAS output before o_proj)
    void* x_out,                 // [S, D] BF16
    HcclCtx* hccl_ctx = nullptr,            // if tp_size > 1, AllReduce x_out after o_proj
    const RopeCache* rope_cache = nullptr,  // if provided, use cached cos/sin table; avoids per-call H2D
    int64_t sparse_mode = -1                // -1 = auto (3 for prefill, 0 for decode); explicit 0/3 overrides
) {
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.n_heads_per_rank;
    const int64_t Hkv = cfg.n_kv_heads_per_rank;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh);
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    // 1. Input layernorm: xn = rmsnorm(x_in, input_layernorm_weight)
    auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(w.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // 2. Q/K/V projection
    auto t_q = make_contig_tensor(q_scratch, ACL_BF16, {S, Q_DIM});
    auto t_k = make_contig_tensor(k_scratch, ACL_BF16, {S, KV_DIM});
    auto t_v = make_contig_tensor(v_scratch, ACL_BF16, {S, KV_DIM});
    linear_hf(stream, t_xn.get(), w.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
    linear_hf(stream, t_xn.get(), w.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
    linear_hf(stream, t_xn.get(), w.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());

    // 3. Per-head q_norm / k_norm. Both rstd views alias rstd_scratch; the two rms_norm
    // calls are stream-ordered, so the overlap is safe. This is why rstd_scratch must hold
    // S * max(Hq, Hkv) floats (see scratch requirements above).
    auto t_q_4d = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Hq, Dh});
    auto t_k_4d = make_contig_tensor(k_scratch, ACL_BF16, {1, S, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(w.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(w.k_norm.get(), ACL_BF16, {Dh});
    auto t_rstd_q = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hq});
    auto t_rstd_k = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hkv});
    rms_norm(stream, t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
    rms_norm(stream, t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());

    // 4. RoPE: positions [past_len, past_len + S). Fused aclnnApplyRotaryPosEmbV2 is 1 op
    // vs the 8-op manual version — saves ~7 kernel launches/layer × 94 layers = 658/token.
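    // For reference, both paths compute the HF half-half rotation: for each head vector q
    // of length Dh, with q1 = q[0:Dh/2] and q2 = q[Dh/2:Dh],
    //   q' = q * cos + rotate_half(q) * sin,   rotate_half(q) = concat(-q2, q1)
    // where cos/sin are the per-position [Dh] rows built by fill_cos_sin_hf (each angle
    // duplicated across the two halves). The same rotation is applied to k.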
    if (rope_cache && rope_cache->cos.get() && past_len + S <= rope_cache->max_seq) {
        void* cos_ptr = (char*)rope_cache->cos.get() + past_len * Dh * 2;
        void* sin_ptr = (char*)rope_cache->sin.get() + past_len * Dh * 2;
        apply_rope_fused(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv, cos_ptr, sin_ptr);
    } else {
        std::vector<uint16_t> cos_h, sin_h;
        fill_cos_sin_hf(cos_h, sin_h, past_len, S, Dh, theta);
        DeviceBuffer cos_dev(S * Dh * 2), sin_dev(S * Dh * 2);
        ACL_CHECK(aclrtMemcpy(cos_dev.get(), S * Dh * 2, cos_h.data(), S * Dh * 2,
                              ACL_MEMCPY_HOST_TO_DEVICE));
        ACL_CHECK(aclrtMemcpy(sin_dev.get(), S * Dh * 2, sin_h.data(), S * Dh * 2,
                              ACL_MEMCPY_HOST_TO_DEVICE));
        apply_rope_manual(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv,
                          cos_dev.get(), sin_dev.get(), rope_scratch);
        // Must sync here: the local DeviceBuffers would be freed on scope exit while async
        // kernels still read them.
        ACL_CHECK(aclrtSynchronizeStream(stream));
    }

    // 5. Append K, V to cache at [past_len, past_len + S)
    ACL_CHECK(aclrtMemcpyAsync((char*)k_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               k_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
    ACL_CHECK(aclrtMemcpyAsync((char*)v_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               v_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));

    // 6. FIAS: q [1, S, Q_DIM], k/v [1, kv_len, KV_DIM] from cache
    int64_t kv_len = past_len + S;
    auto t_q_bsh = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Q_DIM});
    auto t_k_bsh = make_contig_tensor(k_cache, ACL_BF16, {1, kv_len, KV_DIM});
    auto t_v_bsh = make_contig_tensor(v_cache, ACL_BF16, {1, kv_len, KV_DIM});
    // FIAS writes to a separate buffer (attn_out_scratch) — aliasing q→out is unsafe.
    auto t_attn_out_bsh = make_contig_tensor(attn_out_scratch, ACL_BF16, {1, S, Q_DIM});

    // sparse_mode selection:
    //   3 = left-top causal (prefill, q.S == kv.S with 2048 mask)
    //   0 = user mask (decode with cache, batch verify)
    //  -1 (sentinel) = auto: 3 if mask given & past_len == 0 & S > 1 (prefill), else 0
    int64_t sparse = sparse_mode;
    if (sparse < 0) {
        sparse = (mask_tensor != nullptr && past_len == 0 && S > 1) ? 3 : 0;
    }
    fused_infer_attention_score(stream, t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
                                mask_tensor, {S}, {kv_len}, Hq, Hkv, scale, sparse,
                                t_attn_out_bsh.get());

    // 7. O projection: y = attn_out @ o_proj.T → [S, D]
    auto t_attn_2d = make_contig_tensor(attn_out_scratch, ACL_BF16, {S, Q_DIM});
    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    linear_hf(stream, t_attn_2d.get(), w.o_proj.get(), ACL_BF16, D, Q_DIM, t_out.get());

    // 8. TP AllReduce on x_out (row-parallel o_proj → SUM across ranks)
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
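// Prefill usage sketch (hypothetical caller names; `causal_mask_2048` is a caller-built
// [1, 1, 2048, 2048] bool mask per the doc comment above): past_len = 0, S = prompt length,
// sparse_mode left at -1 so the function auto-selects 3 (left-top causal):
//
//   attention_forward(stream, cfg, layer.attn, x_dev, /*S=*/prompt_len, /*past_len=*/0,
//                     kc, vc, max_len, causal_mask_2048,
//                     q_s, k_s, v_s, xn_s, rstd_s, rope_s, ao_s, y_dev,
//                     &hccl, &rc /*, sparse_mode=-1*/);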
// MoE forward for a single layer. Residual NOT applied here.
//
//   x_in   [S, D]  (hidden state, pre post_attention_layernorm)
//   x_out  [S, D]  (MoE output)
//
// Scratch:
//   xn_scratch:          S * D * 2
//   rstd_scratch:        S * 4
//   logits_scratch:      S * E * 2
//   topk_w_scratch:      S * K * 2
//   topk_idx_scratch:    S * K * 4
//   row_idx_scratch:     S * K * 4 (gating output, unused)
//   expanded_x_scratch:  TOTAL * D * 2
//   expanded_ri_scratch: TOTAL * 4
//   tpe_scratch:         E * 8
//   fwd_scratch:         TOTAL * 8
//   gate_out_scratch:    TOTAL * I * 2
//   up_out_scratch:      TOTAL * I * 2
//   down_out_scratch:    TOTAL * D * 2
//   packed_scratch:      TOTAL * D * 2
//   weighted_scratch:    S * K * D * 2
//
// where TOTAL = S * K, I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok.
//
// IMPORTANT: the post_attention_layernorm weight lives in `attn_w` (not in LayerMoEWeights).
inline void moe_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& attn_w,    // for post_attention_layernorm
    LayerMoEWeights& w,
    void* x_in,
    int64_t S,
    void* xn_scratch,
    void* rstd_scratch,
    void* logits_scratch,
    void* topk_w_scratch,
    void* topk_idx_scratch,
    void* row_idx_scratch,
    void* expanded_x_scratch,
    void* expanded_ri_scratch,
    void* tpe_scratch,
    void* fwd_scratch,
    void* gate_out_scratch,
    void* up_out_scratch,
    void* down_out_scratch,
    void* packed_scratch,
    void* weighted_scratch,
    void* x_out,
    HcclCtx* hccl_ctx = nullptr,      // if tp_size > 1, AllReduce after reduce_sum
    void* norm_sum_scratch = nullptr  // S * 2 bytes — persistent buffer for topk_w normalize
) {
    const int64_t D = cfg.hidden_size;
    const int64_t I = cfg.i_per_rank;
    const int64_t E = cfg.num_experts;
    const int64_t K = cfg.num_experts_per_tok;
    const double eps = cfg.rms_norm_eps;
    const int64_t TOTAL = S * K;

    // 1. post_attention_layernorm
    auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(attn_w.post_attention_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // 2. Router linear: logits = xn @ router.T → [S, E]
    auto t_logits = make_contig_tensor(logits_scratch, ACL_BF16, {S, E});
    linear_hf(stream, t_xn.get(), w.router.get(), ACL_BF16, E, D, t_logits.get());

    // 3. TopK softmax
    auto t_topk_w = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K});
    auto t_topk_idx = make_contig_tensor(topk_idx_scratch, ACL_INT32, {S, K});
    auto t_row_idx = make_contig_tensor(row_idx_scratch, ACL_INT32, {S, K});
    moe_gating_topk_softmax(stream, t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(),
                            t_row_idx.get());

    // 4. Device-side normalize of topk weights (Qwen3 norm_topk_prob=true):
    //      sum      = reduce_sum(topk_w, dim=-1, keepdim=true)  # [S, 1] F32 in rstd_scratch
    //      sum     += 1e-20
    //      sum_bf16 = cast(sum, BF16)                           # [S, 1] in norm_sum_scratch (caller-owned)
    //      topk_w  /= sum_bf16                                  # broadcast divide
    //    No per-layer syncs — all scratch buffers persist across layers.
    if (norm_sum_scratch) {
        auto t_sum = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S, 1});
        auto t_sum_bf16 = make_contig_tensor(norm_sum_scratch, ACL_BF16, {S, 1});
        reduce_sum(stream, t_topk_w.get(), {-1}, /*keep_dims=*/true, ACL_FLOAT, t_sum.get());
        inplace_adds(stream, t_sum.get(), 1e-20);
        cast(stream, t_sum.get(), ACL_BF16, t_sum_bf16.get());
        div_tensor(stream, t_topk_w.get(), t_sum_bf16.get(), t_topk_w.get());
    } else {
        // Fallback: host-side normalize (for callers that didn't provide scratch).
        ACL_CHECK(aclrtSynchronizeStream(stream));
        std::vector<uint16_t> h_tw(S * K);
        ACL_CHECK(aclrtMemcpy(h_tw.data(), S * K * 2, topk_w_scratch, S * K * 2,
                              ACL_MEMCPY_DEVICE_TO_HOST));
        for (int64_t s = 0; s < S; s++) {
            float sum = 0;
            for (int64_t k = 0; k < K; k++) {
                // BF16 → F32: widen the 16-bit pattern into the top half of a uint32.
                uint32_t u = (uint32_t)h_tw[s * K + k] << 16;
                float v;
                std::memcpy(&v, &u, 4);
                sum += v;
            }
            sum += 1e-20f;
            for (int64_t k = 0; k < K; k++) {
                uint32_t u = (uint32_t)h_tw[s * K + k] << 16;
                float v;
                std::memcpy(&v, &u, 4);
                v /= sum;
                std::memcpy(&u, &v, 4);
                // F32 → BF16 with round-to-nearest-even (same scheme as _engine_f2bf16).
                h_tw[s * K + k] = (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
            }
        }
        ACL_CHECK(aclrtMemcpy(topk_w_scratch, S * K * 2, h_tw.data(), S * K * 2,
                              ACL_MEMCPY_HOST_TO_DEVICE));
    }
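    // Illustrative example of the routing that follows (hypothetical expert picks): with
    // S = 2, K = 2, token 0 routed to experts {3, 7} and token 1 to {3, 5}, TOTAL = 4 rows
    // flow through the experts. Step 5 expands and groups them by expert id as
    // [t0·e3, t1·e3, t1·e5, t0·e7] with tokens_per_expert = {e3:2, e5:1, e7:1, else 0};
    // steps 6–8 run grouped matmuls over that packing, and step 9 permutes rows back to
    // token order before the weighted sum.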
    // 5. MoE init routing
    auto t_ex_x = make_contig_tensor(expanded_x_scratch, ACL_BF16, {TOTAL, D});
    auto t_ex_ri = make_contig_tensor(expanded_ri_scratch, ACL_INT32, {TOTAL});
    auto t_tpe = make_contig_tensor(tpe_scratch, ACL_INT64, {E});
    moe_init_routing_v3(stream, t_xn.get(), t_topk_idx.get(), E, TOTAL,
                        t_ex_x.get(), t_ex_ri.get(), t_tpe.get());

    // 6. GMM gate + up
    auto t_gate_out = make_contig_tensor(gate_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_up_out = make_contig_tensor(up_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_w_gate = make_contig_tensor(w.gate_exps.get(), ACL_BF16, {E, D, I});
    auto t_w_up = make_contig_tensor(w.up_exps.get(), ACL_BF16, {E, D, I});
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_gate.get(), t_tpe.get(), t_gate_out.get(), 1);
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_up.get(), t_tpe.get(), t_up_out.get(), 1);

    // 7. SwiGLU: gate_out = silu(gate_out) * up_out
    silu(stream, t_gate_out.get(), t_gate_out.get());
    mul(stream, t_gate_out.get(), t_up_out.get(), t_gate_out.get());

    // 8. GMM down
    auto t_down_out = make_contig_tensor(down_out_scratch, ACL_BF16, {TOTAL, D});
    auto t_w_down = make_contig_tensor(w.down_exps.get(), ACL_BF16, {E, I, D});
    grouped_matmul_v4(stream, t_gate_out.get(), t_w_down.get(), t_tpe.get(), t_down_out.get(), 1);

    // 9. Device-side finalize: build the forward perm via two consecutive argsorts on
    // topk_idx. No host sync — safe for graph capture.
    //   inv_fwd = argsort(topk_idx.flat)  // sorted slot → original flat index (key: expert_id)
    //   fwd     = argsort(inv_fwd)        // original flat index → sorted slot (the rank),
    //                                     // which is the gather index IndexSelect needs
    // Stability: aclnnArgsort preserves input order for equal keys; flat index = n*K + k orders
    // ties by n-then-k, matching our previous manual sort convention.
    //
    // Scratch for inv_fwd: reuse the first TOTAL*8 bytes of weighted_scratch. The argsorts
    // complete (stream-ordered) before the later mul overwrites it, so the aliasing is safe.
    {
        auto t_topk_idx_flat = make_contig_tensor(topk_idx_scratch, ACL_INT32, {TOTAL});
        auto t_inv_fwd = make_contig_tensor(weighted_scratch, ACL_INT64, {TOTAL});
        auto t_fwd_64 = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
        argsort(stream, t_topk_idx_flat.get(), /*dim=*/0, /*descending=*/false, t_inv_fwd.get());
        argsort(stream, t_inv_fwd.get(), /*dim=*/0, /*descending=*/false, t_fwd_64.get());
    }
    auto t_fwd = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
    auto t_packed = make_contig_tensor(packed_scratch, ACL_BF16, {TOTAL, D});
    index_select(stream, t_down_out.get(), 0, t_fwd.get(), t_packed.get());

    auto t_packed_3d = make_contig_tensor(packed_scratch, ACL_BF16, {S, K, D});
    auto t_topk_w_3d = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K, 1});
    auto t_weighted = make_contig_tensor(weighted_scratch, ACL_BF16, {S, K, D});
    mul(stream, t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());

    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    reduce_sum(stream, t_weighted.get(), {1}, false, ACL_BF16, t_out.get());

    // TP AllReduce on MoE output (column-parallel experts → SUM partial intermediate
    // outputs across ranks)
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
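// ---------------------------------------------------------------------------------------------
// Full-layer usage sketch (illustrative, not part of the engine API). `Layer`, `Scratch`, and
// `add_inplace` are hypothetical caller-side names: the engine leaves the residual add to the
// caller, which supplies its own elementwise-add wrapper. One decode step (S = 1):
//
//   void decode_one_layer(aclrtStream stream, const ModelConfig& cfg, Layer& layer,
//                         RopeCache& rc, HcclCtx& hccl, void* x, int64_t past_len,
//                         void* kc, void* vc, int64_t max_len, Scratch& s) {
//       const int64_t S = 1;
//       attention_forward(stream, cfg, layer.attn, x, S, past_len, kc, vc, max_len,
//                         /*mask=*/nullptr, s.q, s.k, s.v, s.xn, s.rstd, s.rope,
//                         s.attn_out, s.y, &hccl, &rc);
//       add_inplace(stream, x, s.y, S * cfg.hidden_size);  // residual (caller-provided op)
//       moe_forward(stream, cfg, layer.attn, layer.moe, x, S, s.xn, s.rstd, s.logits,
//                   s.topk_w, s.topk_idx, s.row_idx, s.ex_x, s.ex_ri, s.tpe, s.fwd,
//                   s.gate, s.up, s.down, s.packed, s.weighted, s.y, &hccl, s.norm_sum);
//       add_inplace(stream, x, s.y, S * cfg.hidden_size);  // residual
//   }
// ---------------------------------------------------------------------------------------------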