// engine.h - single-layer forward functions for attention and MoE.
//
// Both functions operate on device tensors. The caller owns all buffers (input, output, weights,
// KV cache slots, scratch). They take RoPE cos/sin tables and act as pure forward kernels.
//
// Design goals:
// - Zero allocations per call (all scratch is passed in)
// - Same signature works for prefill (S>=1) and decode (S=1); caller picks sparse_mode.
// - Residual connection is NOT included (caller decides when to add residual).
#pragma once
#include "acl_common.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "rope.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <tuple>
#include <vector>
// BF16 conversion helper (round-to-nearest-even) used by fill_cos_sin_hf.
static inline uint16_t _engine_f2bf16(float x) {
uint32_t u; std::memcpy(&u, &x, 4);
return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
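// Rounding example for the helper above: 0x3F808000 (exactly halfway between BF16 0x3F80 and
// 0x3F81) rounds to the even pattern 0x3F80, while 0x3F818000 (halfway between 0x3F81 and
// 0x3F82) rounds up to 0x3F82.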
// Fill cos/sin tables for positions [p0, p0+L) with the HF half-half layout. Writes contiguous
// [L * Dh] BF16 into the provided host vectors (caller uploads to device).
inline void fill_cos_sin_hf(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
int64_t p0, int64_t L, int64_t Dh, float theta) {
cos_h.resize(L * Dh);
sin_h.resize(L * Dh);
int64_t half = Dh / 2;
for (int64_t s = 0; s < L; s++) {
for (int64_t d = 0; d < Dh; d++) {
int64_t pair = (d < half) ? d : (d - half);
float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
float angle = (float)(p0 + s) * theta_pair;
cos_h[s * Dh + d] = _engine_f2bf16(std::cos(angle));
sin_h[s * Dh + d] = _engine_f2bf16(std::sin(angle));
}
}
}
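// Usage sketch (host side; the L/Dh values are illustrative, theta comes from the model config):
//   std::vector<uint16_t> cos_h, sin_h;
//   fill_cos_sin_hf(cos_h, sin_h, /*p0=*/0, /*L=*/4, /*Dh=*/128, cfg.rope_theta);
//   // Half-half layout: for d < Dh/2, cos_h[s*Dh + d] == cos_h[s*Dh + d + Dh/2].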
// Precomputed RoPE cos/sin table: BF16 [max_seq, Dh]. One-time cost per runtime.
struct RopeCache {
DeviceBuffer cos; // [max_seq, Dh] BF16
DeviceBuffer sin; // [max_seq, Dh] BF16
int64_t max_seq = 0;
int64_t head_dim = 0;
float theta = 0.0f;
};
inline bool rope_cache_build(RopeCache& rc, int64_t max_seq, int64_t head_dim, float theta) {
std::vector<uint16_t> cos_h, sin_h;
fill_cos_sin_hf(cos_h, sin_h, /*p0=*/0, max_seq, head_dim, theta);
rc.cos.alloc(max_seq * head_dim * 2);
rc.sin.alloc(max_seq * head_dim * 2);
ACL_CHECK(aclrtMemcpy(rc.cos.get(), cos_h.size() * 2, cos_h.data(), cos_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(rc.sin.get(), sin_h.size() * 2, sin_h.data(), sin_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
rc.max_seq = max_seq; rc.head_dim = head_dim; rc.theta = theta;
return true;
}
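// Usage sketch (one-time at engine init; the 32768 max length below is illustrative, not mandated here):
//   RopeCache rope_cache;
//   rope_cache_build(rope_cache, /*max_seq=*/32768, cfg.head_dim, cfg.rope_theta);
//   // Each attention_forward call can then pass &rope_cache to skip the per-call H2D copy.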
// Attention forward for a single layer.
//
// x_in [S, D] (hidden state, pre input_layernorm)
// x_out [S, D] (attention output - NOT residual-added)
//
// K cache / V cache are contiguous [MAX_LEN, KV_DIM] BF16 buffers. This call writes new
// positions at [past_len, past_len+S) and then runs FIAS over [0, past_len+S).
//
// Scratch requirements:
// q_scratch: S * Q_DIM * 2 bytes
// k_scratch: S * KV_DIM * 2 bytes
// v_scratch: S * KV_DIM * 2 bytes
// xn_scratch: S * D * 2 bytes
// rstd_scratch: S * max(Hq, Hkv) * 4 bytes (RmsNorm rstd output; also reused by the per-head q/k norms)
// rope_scratch: S * Hq * Dh * 2 bytes
//
// mask: [1, 1, 2048, 2048] bool for prefill; ignored (pass nullptr) for decode.
inline void attention_forward(
aclrtStream stream,
const ModelConfig& cfg,
LayerAttnWeights& w,
void* x_in, // [S, D] BF16
int64_t S,
int64_t past_len, // prior KV positions
void* k_cache, void* v_cache, int64_t max_len,
aclTensor* mask_tensor, // may be nullptr for decode
void* q_scratch, void* k_scratch, void* v_scratch,
void* xn_scratch, void* rstd_scratch, void* rope_scratch,
void* attn_out_scratch, // S * Q_DIM * 2 bytes (FIAS output before o_proj)
void* x_out, // [S, D] BF16
HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce x_out after o_proj
const RopeCache* rope_cache = nullptr, // if provided, use cached cos/sin table; avoids per-call H2D
int64_t sparse_mode = -1 // -1=auto (3 for prefill, 0 for decode); explicit 0/3 overrides
) {
const int64_t D = cfg.hidden_size;
const int64_t Hq = cfg.n_heads_per_rank;
const int64_t Hkv = cfg.n_kv_heads_per_rank;
const int64_t Dh = cfg.head_dim;
const int64_t Q_DIM = Hq * Dh;
const int64_t KV_DIM = Hkv * Dh;
const double scale = 1.0 / std::sqrt((double)Dh);
const double eps = cfg.rms_norm_eps;
const float theta = cfg.rope_theta;
// 1. Input layernorm: xn = rmsnorm(x_in, input_layernorm_weight)
auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
auto t_lnw = make_contig_tensor(w.input_layernorm.get(), ACL_BF16, {D});
auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());
// 2. Q/K/V projection
auto t_q = make_contig_tensor(q_scratch, ACL_BF16, {S, Q_DIM});
auto t_k = make_contig_tensor(k_scratch, ACL_BF16, {S, KV_DIM});
auto t_v = make_contig_tensor(v_scratch, ACL_BF16, {S, KV_DIM});
linear_hf(stream, t_xn.get(), w.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
linear_hf(stream, t_xn.get(), w.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
linear_hf(stream, t_xn.get(), w.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());
// 3. Per-head q_norm, k_norm
auto t_q_4d = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Hq, Dh});
auto t_k_4d = make_contig_tensor(k_scratch, ACL_BF16, {1, S, Hkv, Dh});
auto t_qn_w = make_contig_tensor(w.q_norm.get(), ACL_BF16, {Dh});
auto t_kn_w = make_contig_tensor(w.k_norm.get(), ACL_BF16, {Dh});
// The rstd outputs for the per-head norms reuse rstd_scratch; per the scratch requirements above
// it must hold at least S * max(Hq, Hkv) floats.
auto t_rstd_q = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hq});
auto t_rstd_k = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hkv});
rms_norm(stream, t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
rms_norm(stream, t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());
// 4. RoPE: positions [past_len, past_len + S). Fused aclnnApplyRotaryPosEmbV2 is 1 op
// vs the 8-op manual version: saves ~7 kernel launches/layer x 94 layers = 658/token.
if (rope_cache && rope_cache->cos.get() && past_len + S <= rope_cache->max_seq) {
void* cos_ptr = (char*)rope_cache->cos.get() + past_len * Dh * 2;
void* sin_ptr = (char*)rope_cache->sin.get() + past_len * Dh * 2;
apply_rope_fused(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv, cos_ptr, sin_ptr);
} else {
std::vector<uint16_t> cos_h, sin_h;
fill_cos_sin_hf(cos_h, sin_h, past_len, S, Dh, theta);
DeviceBuffer cos_dev(S * Dh * 2), sin_dev(S * Dh * 2);
ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
apply_rope_manual(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv,
cos_dev.get(), sin_dev.get(), rope_scratch);
// Sync before the local DeviceBuffers leave scope; otherwise async kernels could still read them.
ACL_CHECK(aclrtSynchronizeStream(stream));
}
// 5. Append K, V to cache at [past_len, past_len + S)
ACL_CHECK(aclrtMemcpyAsync((char*)k_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
k_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
ACL_CHECK(aclrtMemcpyAsync((char*)v_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
v_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
// 6. FIAS: q [1, S, Q_DIM], k/v [1, kv_len, KV_DIM] from cache
int64_t kv_len = past_len + S;
auto t_q_bsh = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Q_DIM});
auto t_k_bsh = make_contig_tensor(k_cache, ACL_BF16, {1, kv_len, KV_DIM});
auto t_v_bsh = make_contig_tensor(v_cache, ACL_BF16, {1, kv_len, KV_DIM});
// FIAS writes to a separate buffer (attn_out_scratch) — aliasing q→out is unsafe.
auto t_attn_out_bsh = make_contig_tensor(attn_out_scratch, ACL_BF16, {1, S, Q_DIM});
// sparse_mode selection:
// 3 = left-top causal (prefill, q.S == kv.S with 2048 mask)
// 0 = user mask (decode with cache, batch verify)
// -1 (sentinel) = auto: 3 if mask given & past_len==0 & S>1 (prefill), else 0
int64_t sparse = sparse_mode;
if (sparse < 0) {
sparse = (mask_tensor != nullptr && past_len == 0 && S > 1) ? 3 : 0;
}
fused_infer_attention_score(
stream, t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
mask_tensor, {S}, {kv_len},
Hq, Hkv, scale, sparse, t_attn_out_bsh.get());
// 7. O projection: y = attn_out @ o_proj.T -> [S, D]
auto t_attn_2d = make_contig_tensor(attn_out_scratch, ACL_BF16, {S, Q_DIM});
auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
linear_hf(stream, t_attn_2d.get(), w.o_proj.get(), ACL_BF16, D, Q_DIM, t_out.get());
// 8. TP AllReduce on x_out (row-parallel o_proj -> SUM across ranks)
if (hccl_ctx && hccl_ctx->tp_size > 1) {
hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
}
}
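// Usage sketch for a decode step (S=1). All buffers below are hypothetical caller-owned device
// pointers sized per the scratch requirements above; weight/cache/HCCL setup is not shown:
//   attention_forward(stream, cfg, layer_attn, x_dev, /*S=*/1, past_len,
//                     k_cache, v_cache, max_len, /*mask_tensor=*/nullptr,
//                     q_s, k_s, v_s, xn_s, rstd_s, rope_s, attn_out_s, attn_y,
//                     &hccl, &rope_cache, /*sparse_mode=*/0);
//   // Caller adds the residual afterwards (x_dev += attn_y).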
// MoE forward for a single layer. Residual NOT applied here.
//
// x_in [S, D] (hidden state, pre post_attention_layernorm)
// x_out [S, D] (MoE output)
//
// Scratch:
// xn_scratch: S * D * 2
// rstd_scratch: S * 4
// logits_scratch: S * E * 2
// topk_w_scratch: S * K * 2
// topk_idx_scratch: S * K * 4
// row_idx_scratch: S * K * 4 (gating output unused)
// expanded_x_scratch: TOTAL * D * 2
// expanded_ri_scratch: TOTAL * 4
// tpe_scratch: E * 8
// fwd_scratch: TOTAL * 8
// gate_out_scratch: TOTAL * I * 2
// up_out_scratch: TOTAL * I * 2
// down_out_scratch: TOTAL * D * 2
// packed_scratch: TOTAL * D * 2
// weighted_scratch: S * K * D * 2
//
// where TOTAL = S * K, I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok.
//
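// Example sizes (illustrative, using the Qwen3-235B-A22B router config E = 128, K = 8): a decode
// step (S = 1) expands to TOTAL = 8 rows, so expanded_x_scratch needs 8 * D * 2 bytes.
//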
// IMPORTANT: the post_attention_layernorm weight lives in `attn_w` (not in LayerMoEWeights).
inline void moe_forward(
aclrtStream stream,
const ModelConfig& cfg,
LayerAttnWeights& attn_w, // for post_attention_layernorm
LayerMoEWeights& w,
void* x_in, int64_t S,
void* xn_scratch, void* rstd_scratch,
void* logits_scratch,
void* topk_w_scratch, void* topk_idx_scratch, void* row_idx_scratch,
void* expanded_x_scratch, void* expanded_ri_scratch, void* tpe_scratch,
void* fwd_scratch,
void* gate_out_scratch, void* up_out_scratch, void* down_out_scratch,
void* packed_scratch, void* weighted_scratch,
void* x_out,
HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce after reduce_sum
void* norm_sum_scratch = nullptr // S * 2 bytes - persistent buffer for topk_w normalize
) {
const int64_t D = cfg.hidden_size;
const int64_t I = cfg.i_per_rank;
const int64_t E = cfg.num_experts;
const int64_t K = cfg.num_experts_per_tok;
const double eps = cfg.rms_norm_eps;
const int64_t TOTAL = S * K;
// 1. post_attention_layernorm
auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
auto t_lnw = make_contig_tensor(attn_w.post_attention_layernorm.get(), ACL_BF16, {D});
auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());
// 2. Router linear: logits = xn @ router.T -> [S, E]
auto t_logits = make_contig_tensor(logits_scratch, ACL_BF16, {S, E});
linear_hf(stream, t_xn.get(), w.router.get(), ACL_BF16, E, D, t_logits.get());
// 3. TopK softmax
auto t_topk_w = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K});
auto t_topk_idx = make_contig_tensor(topk_idx_scratch, ACL_INT32, {S, K});
auto t_row_idx = make_contig_tensor(row_idx_scratch, ACL_INT32, {S, K});
moe_gating_topk_softmax(stream, t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());
// 4. Device-side normalize topk weights (Qwen3 norm_topk_prob=true).
// sum = reduce_sum(topk_w, dim=-1, keepdim=true) # [S, 1] F32 in rstd_scratch
// sum += 1e-20
// sum_bf16 = cast(sum, BF16) # [S, 1] in norm_sum_scratch (caller-owned)
// topk_w /= sum_bf16 # broadcast divide
// No per-layer syncs - all scratch buffers persist across layers.
if (norm_sum_scratch) {
auto t_sum = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S, 1});
auto t_sum_bf16 = make_contig_tensor(norm_sum_scratch, ACL_BF16, {S, 1});
reduce_sum(stream, t_topk_w.get(), {-1}, /*keep_dims=*/true, ACL_FLOAT, t_sum.get());
inplace_adds(stream, t_sum.get(), 1e-20);
cast(stream, t_sum.get(), ACL_BF16, t_sum_bf16.get());
div_tensor(stream, t_topk_w.get(), t_sum_bf16.get(), t_topk_w.get());
} else {
// Fallback: host-side normalize (for callers that didn't provide scratch).
ACL_CHECK(aclrtSynchronizeStream(stream));
std::vector<uint16_t> h_tw(S * K);
ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_scratch, S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
for (int s = 0; s < S; s++) {
float sum = 0;
for (int k = 0; k < K; k++) {
uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
float v; std::memcpy(&v, &u, 4);
sum += v;
}
sum += 1e-20f;
for (int k = 0; k < K; k++) {
uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
float v; std::memcpy(&v, &u, 4);
v /= sum;
std::memcpy(&u, &v, 4);
h_tw[s*K + k] = (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
}
ACL_CHECK(aclrtMemcpy(topk_w_scratch, S*K*2, h_tw.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
}
// 5. MoE init routing
auto t_ex_x = make_contig_tensor(expanded_x_scratch, ACL_BF16, {TOTAL, D});
auto t_ex_ri = make_contig_tensor(expanded_ri_scratch, ACL_INT32, {TOTAL});
auto t_tpe = make_contig_tensor(tpe_scratch, ACL_INT64, {E});
moe_init_routing_v3(stream, t_xn.get(), t_topk_idx.get(),
E, TOTAL, t_ex_x.get(), t_ex_ri.get(), t_tpe.get());
// 6. GMM gate + up
auto t_gate_out = make_contig_tensor(gate_out_scratch, ACL_BF16, {TOTAL, I});
auto t_up_out = make_contig_tensor(up_out_scratch, ACL_BF16, {TOTAL, I});
auto t_w_gate = make_contig_tensor(w.gate_exps.get(), ACL_BF16, {E, D, I});
auto t_w_up = make_contig_tensor(w.up_exps.get(), ACL_BF16, {E, D, I});
grouped_matmul_v4(stream, t_ex_x.get(), t_w_gate.get(), t_tpe.get(), t_gate_out.get(), 1);
grouped_matmul_v4(stream, t_ex_x.get(), t_w_up.get(), t_tpe.get(), t_up_out.get(), 1);
// 7. SwiGLU: gate_out = silu(gate_out) * up_out
silu(stream, t_gate_out.get(), t_gate_out.get());
mul(stream, t_gate_out.get(), t_up_out.get(), t_gate_out.get());
// 8. GMM down
auto t_down_out = make_contig_tensor(down_out_scratch, ACL_BF16, {TOTAL, D});
auto t_w_down = make_contig_tensor(w.down_exps.get(), ACL_BF16, {E, I, D});
grouped_matmul_v4(stream, t_gate_out.get(), t_w_down.get(), t_tpe.get(), t_down_out.get(), 1);
// 9. Device-side finalize: build forward perm via two consecutive argsorts on topk_idx.
// No host sync - safe for graph capture.
// inv_fwd = argsort(topk_idx.flat)  // sorted position -> original flat index (primary key: expert_id)
// fwd = argsort(inv_fwd)            // original flat index (n*K+k) -> sorted position; what IndexSelect needs
// Stability: aclnnArgsort preserves input order for equal keys; flat index = n*K + k orders
// ties by n-then-k, matching our previous manual sort convention.
//
// Scratch for inv_fwd: reuse first TOTAL*8 bytes of weighted_scratch (gets overwritten
// by the subsequent mul op, so aliasing is safe).
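// Worked example (illustrative, S=2, K=2):
//   topk_idx.flat = [2, 0, 1, 0]
//   inv_fwd       = [1, 3, 2, 0]   // expert-sorted order: rows 1,3 (expert 0), 2 (expert 1), 0 (expert 2)
//   fwd           = [3, 0, 2, 1]   // row 0 lands at sorted position 3, row 1 at 0, row 2 at 2, row 3 at 1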
{
auto t_topk_idx_flat = make_contig_tensor(topk_idx_scratch, ACL_INT32, {TOTAL});
auto t_inv_fwd = make_contig_tensor(weighted_scratch, ACL_INT64, {TOTAL});
auto t_fwd_64 = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
argsort(stream, t_topk_idx_flat.get(), /*dim=*/0, /*descending=*/false, t_inv_fwd.get());
argsort(stream, t_inv_fwd.get(), /*dim=*/0, /*descending=*/false, t_fwd_64.get());
}
auto t_fwd = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
auto t_packed = make_contig_tensor(packed_scratch, ACL_BF16, {TOTAL, D});
index_select(stream, t_down_out.get(), 0, t_fwd.get(), t_packed.get());
auto t_packed_3d = make_contig_tensor(packed_scratch, ACL_BF16, {S, K, D});
auto t_topk_w_3d = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K, 1});
auto t_weighted = make_contig_tensor(weighted_scratch, ACL_BF16, {S, K, D});
mul(stream, t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());
auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
reduce_sum(stream, t_weighted.get(), {1}, false, ACL_BF16, t_out.get());
// TP AllReduce on MoE output (column-parallel experts -> SUM partial intermediate outputs)
if (hccl_ctx && hccl_ctx->tp_size > 1) {
hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
}
}
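// Putting it together for one decoder layer (sketch; `add_` stands in for whatever elementwise
// residual-add op the runtime provides, and the buffer names are hypothetical):
//   attention_forward(stream, cfg, layer.attn, x, S, past_len, ..., attn_y, &hccl, &rope_cache);
//   add_(stream, x, attn_y);   // residual 1 (caller's responsibility, see header comment)
//   moe_forward(stream, cfg, layer.attn, layer.moe, x, S, ..., moe_y, &hccl, norm_sum_s);
//   add_(stream, x, moe_y);    // residual 2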