// llm_mutil_npu/include/runner.h
// Initial C++ aclnn EAGER inference for Qwen3-235B-A22B MoE on Ascend 910 × 16 NPU
// (author: xianglarry, commit 4b9fefd)
// runner.h β€” multi-layer transformer Runner for Qwen3-235B-A22B.
//
// Owns: shared weights, per-layer attention + MoE weights, KV cache, scratch buffers.
// Provides: prefill(tokens) and decode(new_token) methods returning logits [vocab] on device.
//
// Memory budget at TP=1 for testing a SUBSET of layers (num_layers_to_load <= 94). Full 94-layer
// inference requires TP=16 where per-rank MoE fits ~28GB.
#pragma once
#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "engine.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "safetensors_loader.h"
#include <vector>
// Runner — owns all device-side state for multi-layer transformer inference and
// exposes the three entry points (prefill / decode / decode_batch) plus cache control.
// Not thread-safe: one Runner per NPU stream; all methods must be called from the
// thread that called init().
class Runner {
public:
  Runner() = default;
  ~Runner() = default;
  // Non-copyable: owns device buffers, an ACL runtime/stream, and HCCL context.
  Runner(const Runner&) = delete;
  Runner& operator=(const Runner&) = delete;

  // Initialize runtime, open safetensors, load shared weights. tp_size/tp_rank configure
  // MoE + attention sharding. num_layers is how many transformer blocks to load (1..94).
  // max_seq is the maximum sequence length (for KV cache allocation).
  // Returns false on any runtime/loader failure; the Runner must not be used afterwards.
  [[nodiscard]] bool init(const std::string& model_dir, int tp_size, int tp_rank,
                          int num_layers_to_load, int64_t max_seq, int device_id = 0);

  // Prefill: ingest S>=1 tokens, produces logits [vocab] for the LAST position. Populates KV
  // cache starting at position 0.
  [[nodiscard]] bool prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out);

  // Decode: take 1 new token, produce logits [vocab] from the new position.
  [[nodiscard]] bool decode(int32_t token, DeviceBuffer& logits_out);

  // Batched decode: take S tokens as "candidate verify batch" at positions
  // [past_len..past_len+S), produce logits [S, vocab]. Uses causal-with-past mask
  // (token i sees past + tokens[0..i]). Foundation for speculative decoding / PLD.
  // tokens: [S] int32
  // S: 1 .. 16
  // all_logits_out: will hold S * vocab_size * 2 bytes BF16, row-major [S, V]
  // Updates past_len by +S on success.
  [[nodiscard]] bool decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out);

  // Warmup: run N dummy decode() calls (resetting cache) to pre-compile aclnn executors,
  // warm HCCL collective buffers, and stabilize NPU thermals. Improves first-N-token latency
  // by ~1 s (especially noticeable on short generations or REPL cold start).
  // Call after init(); safe to call multiple times. Does NOT affect past_len.
  void warmup(int iterations = 3);

  // Accessors.
  [[nodiscard]] const ModelConfig& cfg() const { return cfg_; }
  aclrtStream stream() { return rt_.stream(); }
  [[nodiscard]] int64_t past_len() const { return past_len_; }

  // Logically empty the KV cache (buffers are retained; old entries are overwritten).
  void reset_cache() { past_len_ = 0; }

  // Rewind past_len by n. Used by speculative decoding to discard rejected draft tokens'
  // KV cache entries (they'll be overwritten by subsequent writes).
  // Out-of-range n (n <= 0 or n > past_len) is silently ignored.
  void rewind_cache(int64_t n) { if (n > 0 && n <= past_len_) past_len_ -= n; }

  HcclCtx& hccl_ctx() { return hccl_ctx_; }

  // Profiling: set via LCA_PROFILE=1 env in main_cli. If enabled, decode() accumulates
  // per-phase wall-clock ms into the timer accumulators below.
  bool profile_enabled = false;
  double t_embed_ms = 0, t_layers_ms = 0, t_final_ms = 0;
  int64_t profile_calls = 0;
  void print_profile_summary() const;

private:
  // One-layer forward: x_in [S, D] → x_out [S, D] via attention + residual + MoE + residual.
  // Uses this layer's KV cache starting at past_len; caller updates past_len after each call.
  // batch_decode_mode: true for S>1 at past_len>0 (spec decoding) — uses custom causal mask
  // with past instead of the 2048×2048 prefill mask.
  void layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out,
                      bool batch_decode_mode = false);

  // Build causal-with-past mask in batch_mask_dev_ for decode_batch at current past_len.
  // Shape [1, 1, S, past_len+S] bool, mask[i, j] = 1 iff j > past_len+i.
  void build_batch_decode_mask_(int64_t S);

  // Final: final_norm + lm_head on last position → logits [vocab].
  void final_logits_(void* hidden_last /*[1, D]*/, DeviceBuffer& logits_out);

  // Batched final: final_norm + lm_head on [S, D] → logits [S, V].
  void final_logits_batch_(void* hidden /*[S, D]*/, int64_t S, DeviceBuffer& logits_out);

  AclRuntime rt_;
  SafetensorsLoader st_;
  ModelConfig cfg_;
  HcclCtx hccl_ctx_;
  int num_layers_ = 0;   // transformer blocks actually loaded (<= 94)
  int64_t max_seq_ = 0;  // KV cache capacity in tokens

  SharedWeights shared_;
  std::vector<LayerAttnWeights> attn_;
  std::vector<LayerMoEWeights> moe_;

  // Per-layer KV cache.
  std::vector<DeviceBuffer> k_cache_;
  std::vector<DeviceBuffer> v_cache_;

  // Scratch buffers (reallocated per-call, sized by current S).
  DeviceBuffer q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_;
  DeviceBuffer moe_xn_, moe_rstd_, moe_logits_;
  DeviceBuffer moe_topk_w_, moe_topk_idx_, moe_row_idx_;
  DeviceBuffer moe_ex_x_, moe_ex_ri_, moe_tpe_;
  DeviceBuffer moe_fwd_;
  DeviceBuffer moe_gate_, moe_up_, moe_down_;
  DeviceBuffer moe_packed_, moe_weighted_, moe_out_;
  DeviceBuffer moe_norm_sum_;       // BF16 [S, 1] for on-device topk_w normalize
  DeviceBuffer x_buf_a_, x_buf_b_;  // ping-pong for residual chain

  // Causal mask for prefill (2048 x 2048 bool); decode uses nullptr.
  DeviceBuffer prefill_mask_dev_;

  // Batch decode mask: S_MAX × KV_MAX bool, where mask[i, j] = 1 (masked out) if
  // j > past_len + i. Built on-demand per-call (past_len changes).
  DeviceBuffer batch_mask_dev_;

  // Pre-computed RoPE cos/sin table (sized for max_seq_).
  RopeCache rope_cache_;

  int64_t past_len_ = 0;        // tokens currently resident in the KV cache
  int64_t cur_S_capacity_ = 0;  // scratch sized for this many tokens
};