// runner.cpp: C++ aclnn eager-mode inference runner for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU.
#include "runner.h"
#include <algorithm>  // std::max (scratch sizing)
#include <chrono>
#include <cstdio>
#include <cstring>
#include <vector>     // host-side mask staging buffers
// Expose HCCL context for the CLI broadcast helper.
HcclCtx* runner_hccl_ctx_shim(Runner& r) { return &r.hccl_ctx(); }
bool Runner::init(const std::string& model_dir, int tp_size, int tp_rank,
int num_layers_to_load, int64_t max_seq, int device_id) {
if (!cfg_.load_from_json(model_dir + "/config.json")) return false;
cfg_.compute_derived(tp_size, tp_rank);
if (num_layers_to_load < 1 || num_layers_to_load > (int)cfg_.num_hidden_layers) {
fprintf(stderr, "runner: invalid num_layers %d (max %ld)\n",
num_layers_to_load, cfg_.num_hidden_layers);
return false;
}
num_layers_ = num_layers_to_load;
max_seq_ = max_seq;
if (!st_.open(model_dir)) return false;
rt_.init(device_id);
// HCCL init (no-op if tp_size == 1)
if (!hccl_init(hccl_ctx_, tp_size, tp_rank)) {
fprintf(stderr, "runner: HCCL init failed\n");
return false;
}
DeviceWeightsLoader dw(st_, cfg_);
printf("runner: loading shared weights (embed, lm_head, final_norm)...\n");
if (!dw.load_shared(shared_)) return false;
attn_.resize(num_layers_);
moe_.resize(num_layers_);
k_cache_.resize(num_layers_);
v_cache_.resize(num_layers_);
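// Per-layer KV cache covers the whole context window:
// max_seq_ x (n_kv_heads_per_rank x head_dim) bf16 elements, hence the *2 (bytes) in the allocs below.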
const int64_t KV_DIM = cfg_.n_kv_heads_per_rank * cfg_.head_dim;
for (int L = 0; L < num_layers_; L++) {
printf("runner: loading layer %d/%d...\n", L + 1, num_layers_);
if (!dw.load_attention(L, attn_[L])) return false;
if (!dw.load_moe(L, rt_.stream(), moe_[L])) return false;
k_cache_[L].alloc(max_seq_ * KV_DIM * 2);
v_cache_[L].alloc(max_seq_ * KV_DIM * 2);
}
rt_.sync();
// Prefill mask (2048x2048 bool causal)
const int64_t MASK = 2048;
std::vector<uint8_t> mh(MASK * MASK, 0);
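// As built here, 1 marks a masked (future) position and 0 a visible one:
// only the strict upper triangle (j > i) is set, which is exactly a causal mask.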
for (int i = 0; i < MASK; i++)
for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
prefill_mask_dev_.alloc(MASK * MASK);
ACL_CHECK(aclrtMemcpy(prefill_mask_dev_.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));
// Pre-compute RoPE cos/sin table once (covers all positions up to max_seq_)
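// (Assuming the standard RoPE parameterization: inv_freq[i] = rope_theta^(-2i/head_dim), with
//  cos(pos * inv_freq[i]) and sin(pos * inv_freq[i]) tabulated for every pos in [0, max_seq_).)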
rope_cache_build(rope_cache_, max_seq_, cfg_.head_dim, cfg_.rope_theta);
past_len_ = 0;
cur_S_capacity_ = 0;
return true;
}
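// Typical call sequence (illustrative sketch only; the sampler below is a hypothetical host-side helper):
//
//   Runner r;
//   if (!r.init(model_dir, tp_size, tp_rank, num_layers, /*max_seq=*/4096, device_id)) return 1;
//   DeviceBuffer logits;
//   r.prefill(prompt_ids.data(), (int64_t)prompt_ids.size(), logits);  // fills KV cache, logits for last pos
//   for (int step = 0; step < max_new_tokens; step++) {
//       int32_t next = sample_from(logits);   // hypothetical: argmax/top-p over the [1, V] logits
//       r.decode(next, logits);               // one token per call, past_len_ advances by 1
//   }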
static void ensure_sc_(DeviceBuffer& buf, size_t needed) {
if (buf.size < needed) buf.alloc(needed);
}
static void ensure_all_scratch_(Runner* self, int64_t S, const ModelConfig& cfg,
DeviceBuffer& q_sc, DeviceBuffer& k_sc, DeviceBuffer& v_sc,
DeviceBuffer& xn_sc, DeviceBuffer& rstd_sc, DeviceBuffer& rope_sc,
DeviceBuffer& attn_fias_sc, DeviceBuffer& attn_out_sc,
DeviceBuffer& moe_xn, DeviceBuffer& moe_rstd, DeviceBuffer& moe_logits,
DeviceBuffer& moe_topk_w, DeviceBuffer& moe_topk_idx, DeviceBuffer& moe_row_idx,
DeviceBuffer& moe_ex_x, DeviceBuffer& moe_ex_ri, DeviceBuffer& moe_tpe,
DeviceBuffer& moe_fwd,
DeviceBuffer& moe_gate, DeviceBuffer& moe_up, DeviceBuffer& moe_down,
DeviceBuffer& moe_packed, DeviceBuffer& moe_weighted, DeviceBuffer& moe_out,
DeviceBuffer& moe_norm_sum,
DeviceBuffer& x_buf_a, DeviceBuffer& x_buf_b) {
(void)self;
const int64_t D = cfg.hidden_size;
const int64_t Hq = cfg.n_heads_per_rank, Hkv = cfg.n_kv_heads_per_rank;
const int64_t Dh = cfg.head_dim;
const int64_t Q_DIM = Hq * Dh;
const int64_t KV_DIM = Hkv * Dh;
const int64_t I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok;
const int64_t TOTAL = S * K;
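// All sizes passed to ensure_sc_ are byte counts; the trailing factors are bytes per element
// (matching the dtypes used elsewhere in this file: bf16 = 2, fp32/int32 = 4, int64 = 8).
// e.g. q_sc holds the S x (Hq*Dh) bf16 query projection, so it needs S * Q_DIM * 2 bytes.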
ensure_sc_(q_sc, S * Q_DIM * 2);
ensure_sc_(k_sc, S * KV_DIM * 2);
ensure_sc_(v_sc, S * KV_DIM * 2);
ensure_sc_(xn_sc, S * D * 2);
ensure_sc_(rstd_sc, S * std::max(Hq, Hkv) * 4);
ensure_sc_(rope_sc, 1 * S * Hq * Dh * 2);
ensure_sc_(attn_fias_sc, S * Q_DIM * 2);
ensure_sc_(attn_out_sc, S * D * 2);
ensure_sc_(moe_xn, S * D * 2);
ensure_sc_(moe_rstd, S * 4);
ensure_sc_(moe_logits, S * E * 2);
ensure_sc_(moe_topk_w, S * K * 2);
ensure_sc_(moe_topk_idx, S * K * 4);
ensure_sc_(moe_row_idx, S * K * 4);
ensure_sc_(moe_ex_x, TOTAL * D * 2);
ensure_sc_(moe_ex_ri, TOTAL * 4);
ensure_sc_(moe_tpe, E * 8);
ensure_sc_(moe_fwd, TOTAL * 8);
ensure_sc_(moe_gate, TOTAL * I * 2);
ensure_sc_(moe_up, TOTAL * I * 2);
ensure_sc_(moe_down, TOTAL * D * 2);
ensure_sc_(moe_packed, TOTAL * D * 2);
ensure_sc_(moe_weighted, S * K * D * 2);
ensure_sc_(moe_out, S * D * 2);
ensure_sc_(moe_norm_sum, S * 2);
ensure_sc_(x_buf_a, S * D * 2);
ensure_sc_(x_buf_b, S * D * 2);
}
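// One transformer block. The residual adds live here; normalization, RoPE/attention, and the MoE FFN
// are delegated to attention_forward / moe_forward (which receive the xn_/rstd_ norm scratch buffers):
//   x1    = x_in + attention_forward(x_in)   -> written to x_buf_a_
//   x_out = x1   + moe_forward(x1)           -> written to x_out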
void Runner::layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out, bool batch_decode_mode) {
const int64_t D = cfg_.hidden_size;
// Attention mask selection:
//   prefill      (S>1, past=0):  2048×2048 upper-tri mask, sparse_mode=3 (FIAS internal causal)
//   decode       (S==1):         mask=nullptr; sparse_mode stays -1 so attention_forward uses its
//                                default (sparse_mode 0: the single query sees the whole cache)
//   batch decode (S>1, past>0):  S × (past+S) causal-with-past mask, sparse_mode=0
aclTensor* mask = nullptr;
int64_t sparse_mode = -1; // -1 = let attention_forward pick the default
AclTensorPtr t_mask_ptr;
if (batch_decode_mode) {
build_batch_decode_mask_(S);
int64_t kv_len = past_len_ + S;
t_mask_ptr = make_contig_tensor(batch_mask_dev_.get(), ACL_BOOL, {1, 1, S, kv_len});
mask = t_mask_ptr.get();
sparse_mode = 0;
} else if (S > 1) {
// Pure prefill from past=0
t_mask_ptr = make_contig_tensor(prefill_mask_dev_.get(), ACL_BOOL, {1, 1, 2048, 2048});
mask = t_mask_ptr.get();
sparse_mode = 3;
}
// else: S==1 decode, mask stays nullptr and sparse_mode stays -1 (attention_forward default)
attention_forward(
rt_.stream(), cfg_, attn_[layer_idx],
x_in, S, past_len_,
k_cache_[layer_idx].get(), v_cache_[layer_idx].get(), max_seq_,
mask,
q_sc_.get(), k_sc_.get(), v_sc_.get(),
xn_sc_.get(), rstd_sc_.get(), rope_sc_.get(),
attn_fias_sc_.get(),
attn_out_sc_.get(),
(hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
&rope_cache_,
sparse_mode);
// x1 = x_in + attn_out (residual)
auto t_x_in     = make_contig_tensor(x_in,               ACL_BF16, {S, D});
auto t_attn_out = make_contig_tensor(attn_out_sc_.get(), ACL_BF16, {S, D});
auto t_x1       = make_contig_tensor(x_buf_a_.get(),     ACL_BF16, {S, D});
{
float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
uint64_t ws = 0; aclOpExecutor* e = nullptr;
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x_in.get(), t_attn_out.get(), al, t_x1.get(), &ws, &e));
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
aclDestroyScalar(al);
}
// MoE
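// Pipeline stages suggested by the scratch buffer names (descriptive only; details live in moe_forward):
//   RMSNorm(x1) -> router logits [S, E] -> top-K weights/indices -> expand tokens to S*K rows ->
//   per-expert gate/up/down projections -> re-pack and weight by router scores -> sum back to [S, D],
//   all-reduced across TP ranks when hccl_ctx_ is passed.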
moe_forward(
rt_.stream(), cfg_, attn_[layer_idx], moe_[layer_idx],
x_buf_a_.get(), S,
moe_xn_.get(), moe_rstd_.get(),
moe_logits_.get(),
moe_topk_w_.get(), moe_topk_idx_.get(), moe_row_idx_.get(),
moe_ex_x_.get(), moe_ex_ri_.get(), moe_tpe_.get(),
moe_fwd_.get(),
moe_gate_.get(), moe_up_.get(), moe_down_.get(),
moe_packed_.get(), moe_weighted_.get(),
moe_out_.get(),
(hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
moe_norm_sum_.get());
// x_out = x1 + moe_out (residual)
auto t_moe_out = make_contig_tensor(moe_out_.get(), ACL_BF16, {S, D});
auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
{
float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
uint64_t ws = 0; aclOpExecutor* e = nullptr;
ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x1.get(), t_moe_out.get(), al, t_out.get(), &ws, &e));
DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
aclDestroyScalar(al);
}
}
void Runner::final_logits_(void* hidden_last, DeviceBuffer& logits_out) {
// Single-position variant: hidden_last is [1, D], output [1, V].
final_logits_batch_(hidden_last, 1, logits_out);
}
void Runner::final_logits_batch_(void* hidden, int64_t S, DeviceBuffer& logits_out) {
const int64_t D = cfg_.hidden_size;
const int64_t V = cfg_.vocab_size;
DeviceBuffer hn(S * D * 2), rstd(S * 4);
auto t_h = make_contig_tensor(hidden, ACL_BF16, {S, D});
auto t_hn = make_contig_tensor(hn.get(), ACL_BF16, {S, D});
auto t_lnw = make_contig_tensor(shared_.final_norm.get(), ACL_BF16, {D});
auto t_rstd = make_contig_tensor(rstd.get(), ACL_FLOAT, {S});
rms_norm(rt_.stream(), t_h.get(), t_lnw.get(), cfg_.rms_norm_eps, t_hn.get(), t_rstd.get());
logits_out.alloc(S * V * 2);
auto t_logits = make_contig_tensor(logits_out.get(), ACL_BF16, {S, V});
linear_hf(rt_.stream(), t_hn.get(), shared_.lm_head.get(), ACL_BF16, V, D, t_logits.get());
}
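// decode_batch: advance the KV cache by S already-chosen tokens in one forward pass and return
// logits for every one of the S positions (prefill, below, only returns the last position's logits).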
bool Runner::decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out) {
if (S < 1) return false;
if (past_len_ + S > max_seq_) {
fprintf(stderr, "runner: decode_batch exceeds max_seq (%ld + %ld > %ld)\n",
past_len_, S, max_seq_);
return false;
}
const int64_t D = cfg_.hidden_size;
ensure_all_scratch_(this, S, cfg_,
q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
moe_xn_, moe_rstd_, moe_logits_,
moe_topk_w_, moe_topk_idx_, moe_row_idx_,
moe_ex_x_, moe_ex_ri_, moe_tpe_,
moe_fwd_,
moe_gate_, moe_up_, moe_down_,
moe_packed_, moe_weighted_, moe_out_,
moe_norm_sum_,
x_buf_a_, x_buf_b_);
// Embed S tokens
DeviceBuffer tok_dev(S * 4);
ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
DeviceBuffer x0(S * D * 2);
auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());
DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
void* cur_in = xping.get();
void* cur_out = xpong.get();
// batch_decode_mode=true uses proper causal-with-past mask (S × past+S, sparse_mode=0).
for (int L = 0; L < num_layers_; L++) {
layer_forward_(L, S, cur_in, cur_out, /*batch_decode_mode=*/past_len_ > 0);
std::swap(cur_in, cur_out);
}
rt_.sync();
// Get logits for ALL S positions (not just last)
final_logits_batch_(cur_in, S, all_logits_out);
rt_.sync();
past_len_ += S;
return true;
}
bool Runner::prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out) {
if (S < 1) return false;
if (past_len_ + S > max_seq_) {
fprintf(stderr, "runner: prefill exceeds max_seq (%ld + %ld > %ld)\n",
past_len_, S, max_seq_);
return false;
}
const int64_t D = cfg_.hidden_size;
ensure_all_scratch_(this, S, cfg_,
q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
moe_xn_, moe_rstd_, moe_logits_,
moe_topk_w_, moe_topk_idx_, moe_row_idx_,
moe_ex_x_, moe_ex_ri_, moe_tpe_,
moe_fwd_,
moe_gate_, moe_up_, moe_down_,
moe_packed_, moe_weighted_, moe_out_,
moe_norm_sum_,
x_buf_a_, x_buf_b_);
// Embed
DeviceBuffer tok_dev(S * 4);
ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
DeviceBuffer x0(S * D * 2);
auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());
// Layer chain: ping-pong between two buffers
DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
void* cur_in = xping.get();
void* cur_out = xpong.get();
for (int L = 0; L < num_layers_; L++) {
layer_forward_(L, S, cur_in, cur_out);
std::swap(cur_in, cur_out);
}
rt_.sync();
// Take last position's hidden → final_logits
DeviceBuffer last(1 * D * 2);
ACL_CHECK(aclrtMemcpy(last.get(), 1*D*2,
(char*)cur_in + (S - 1) * D * 2, 1*D*2,
ACL_MEMCPY_DEVICE_TO_DEVICE));
final_logits_(last.get(), logits_out);
rt_.sync();
past_len_ += S;
return true;
}
bool Runner::decode(int32_t token, DeviceBuffer& logits_out) {
const int64_t D = cfg_.hidden_size;
if (past_len_ + 1 > max_seq_) {
fprintf(stderr, "runner: decode exceeds max_seq\n");
return false;
}
const int64_t S = 1;
ensure_all_scratch_(this, S, cfg_,
q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
moe_xn_, moe_rstd_, moe_logits_,
moe_topk_w_, moe_topk_idx_, moe_row_idx_,
moe_ex_x_, moe_ex_ri_, moe_tpe_,
moe_fwd_,
moe_gate_, moe_up_, moe_down_,
moe_packed_, moe_weighted_, moe_out_,
moe_norm_sum_,
x_buf_a_, x_buf_b_);
DeviceBuffer tok_dev(1 * 4);
ACL_CHECK(aclrtMemcpy(tok_dev.get(), 4, &token, 4, ACL_MEMCPY_HOST_TO_DEVICE));
auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {1});
auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
auto t0 = std::chrono::steady_clock::now();
DeviceBuffer x0(1 * D * 2);
auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {1, D});
index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());
DeviceBuffer xping(1 * D * 2), xpong(1 * D * 2);
ACL_CHECK(aclrtMemcpyAsync(xping.get(), 1*D*2, x0.get(), 1*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
if (profile_enabled) { ACL_CHECK(aclrtSynchronizeStream(rt_.stream())); }
auto t1 = std::chrono::steady_clock::now();
void* cur_in = xping.get();
void* cur_out = xpong.get();
for (int L = 0; L < num_layers_; L++) {
layer_forward_(L, 1, cur_in, cur_out);
std::swap(cur_in, cur_out);
}
rt_.sync();
auto t2 = std::chrono::steady_clock::now();
final_logits_(cur_in, logits_out);
rt_.sync();
auto t3 = std::chrono::steady_clock::now();
if (profile_enabled) {
using ms = std::chrono::duration<double, std::milli>;
t_embed_ms += ms(t1 - t0).count();
t_layers_ms += ms(t2 - t1).count();
t_final_ms += ms(t3 - t2).count();
profile_calls++;
}
past_len_ += 1;
return true;
}
void Runner::build_batch_decode_mask_(int64_t S) {
int64_t kv_len = past_len_ + S;
size_t bytes = (size_t)S * kv_len; // bool = 1 byte
if (batch_mask_dev_.size < bytes) batch_mask_dev_.alloc(bytes);
std::vector<uint8_t> h_mask(bytes, 0);
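// Worked example (hypothetical values): past_len_ = 3, S = 2, so kv_len = 5:
//   row 0 (query at position 3): 0 0 0 0 1   (sees the 3 cached positions and itself)
//   row 1 (query at position 4): 0 0 0 0 0   (sees everything up to and including itself)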
for (int64_t i = 0; i < S; i++) {
// Row i: positions j ≤ past_len_+i are visible (0), j > past_len_+i are masked (1).
for (int64_t j = past_len_ + i + 1; j < kv_len; j++) {
h_mask[i * kv_len + j] = 1;
}
}
ACL_CHECK(aclrtMemcpy(batch_mask_dev_.get(), bytes, h_mask.data(), bytes,
ACL_MEMCPY_HOST_TO_DEVICE));
}
void Runner::warmup(int iterations) {
if (num_layers_ == 0) return;
int64_t saved_past = past_len_;
past_len_ = 0;
int32_t dummy_tok = 0; // any in-vocab token id works for warmup; id 0 is always valid
DeviceBuffer dummy_logits;
for (int i = 0; i < iterations; i++) {
past_len_ = 0;
if (!decode(dummy_tok, dummy_logits)) break;
}
past_len_ = saved_past;
fprintf(stderr, "[runner] warmup: %d iterations done\n", iterations);
}
void Runner::print_profile_summary() const {
if (!profile_enabled || profile_calls == 0) return;
double total = t_embed_ms + t_layers_ms + t_final_ms;
fprintf(stderr, "\n=== Runner profile (%ld decode calls) ===\n", profile_calls);
fprintf(stderr, " phase total_ms avg_ms/call pct\n");
fprintf(stderr, " embed %8.1f %10.3f %5.1f%%\n",
t_embed_ms, t_embed_ms / profile_calls, 100.0 * t_embed_ms / total);
fprintf(stderr, " layers (x%d) %8.1f %10.3f %5.1f%% → %.3f ms/layer/call\n",
num_layers_, t_layers_ms, t_layers_ms / profile_calls,
100.0 * t_layers_ms / total,
t_layers_ms / profile_calls / num_layers_);
fprintf(stderr, " final+lm_hd %8.1f %10.3f %5.1f%%\n",
t_final_ms, t_final_ms / profile_calls, 100.0 * t_final_ms / total);
fprintf(stderr, " total %8.1f %10.3f 100.0%%\n",
total, total / profile_calls);
}