// runner.h — multi-layer transformer Runner for Qwen3-235B-A22B.
//
// Owns: shared weights, per-layer attention + MoE weights, KV cache, scratch buffers.
// Provides: prefill(tokens) and decode(new_token) methods returning logits [vocab] on device.
//
// Memory budget at TP=1 for testing a SUBSET of layers (num_layers_to_load <= 94). Full 94-layer
// inference requires TP=16 where per-rank MoE fits ~28GB.
#pragma once

#include "acl_common.h"
#include "acl_runtime.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "engine.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "safetensors_loader.h"
// NOTE(review): the target of the following #include is missing in this copy of the file
// (an angle-bracket header name appears to have been lost). This header uses std::string,
// std::vector and the fixed-width integer types — restore the intended standard header(s).
#include

// Runner drives a stack of transformer layers: it exposes prefill / decode / decode_batch
// entry points and tracks how many positions of the KV cache are in use via past_len_.
// Copying is disabled (copy constructor and copy assignment are deleted).
class Runner {
public:
    Runner() = default;
    ~Runner() = default;
    Runner(const Runner&) = delete;
    Runner& operator=(const Runner&) = delete;

    // Initialize runtime, open safetensors, load shared weights. tp_size/tp_rank configure
    // MoE + attention sharding. num_layers_to_load is how many transformer blocks to load
    // (1..94). max_seq is the maximum sequence length (for KV cache allocation).
    // device_id defaults to 0. Returns a bool status.
    bool init(const std::string& model_dir, int tp_size, int tp_rank, int num_layers_to_load,
              int64_t max_seq, int device_id = 0);

    // Prefill: ingest S>=1 tokens, produces logits [vocab] for the LAST position. Populates KV
    // cache starting at position 0.
    // NOTE(review): earlier documentation mentioned an optional `hidden_out` result holding the
    // final hidden state [S, D]; this signature has no such parameter — logits_out is the only
    // output argument.
    bool prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out);

    // Decode: take 1 new token, produce logits [vocab] from the new position.
    bool decode(int32_t token, DeviceBuffer& logits_out);

    // Batched decode: take S tokens as "candidate verify batch" at positions [past_len..past_len+S),
    // produce logits [S, vocab]. Uses causal-with-past mask (token i sees past+tokens[0..i]).
    // Foundation for speculative decoding / PLD.
    //   tokens:         [S] int32
    //   S:              1 .. 16
    //   all_logits_out: will hold S * vocab_size * 2 bytes BF16, row-major [S, V]
    // Updates past_len by +S on success.
    bool decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out);

    // Warmup: run N dummy decode() calls (resetting cache) to pre-compile aclnn executors,
    // warm HCCL collective buffers, and stabilize NPU thermals. Improves first-N-token latency
    // by ~1 s (especially noticeable on short generations or REPL cold start).
    // Call after init(); safe to call multiple times. Does NOT affect past_len.
    // `iterations` is the number of dummy calls (default 3).
    void warmup(int iterations = 3);

    // Accessors
    const ModelConfig& cfg() const { return cfg_; }
    aclrtStream stream() { return rt_.stream(); }
    int64_t past_len() const { return past_len_; }

    // Resets the logical cache length to 0. Only past_len_ is modified here; no buffer is
    // touched by this inline body.
    void reset_cache() { past_len_ = 0; }

    // Rewind past_len by n. Used by speculative decoding to discard rejected draft tokens'
    // KV cache entries (they'll be overwritten by subsequent writes).
    // The call is a silent no-op when n <= 0 or n > past_len_, so past_len_ never goes negative.
    void rewind_cache(int64_t n) { if (n > 0 && n <= past_len_) past_len_ -= n; }

    HcclCtx& hccl_ctx() { return hccl_ctx_; }

    // Profiling: set via LCA_PROFILE=1 env in main_cli. If enabled, decode() accumulates
    // per-phase wall-clock ms into the timer accumulators below.
    bool profile_enabled = false;
    double t_embed_ms = 0, t_layers_ms = 0, t_final_ms = 0;
    int64_t profile_calls = 0;
    void print_profile_summary() const;

private:
    // One-layer forward: x_in [S, D] → x_out [S, D] via attention + residual + MoE + residual.
    // Uses this layer's KV cache starting at past_len; caller updates past_len after each call.
    // batch_decode_mode: true for S>1 at past_len>0 (spec decoding) — uses custom causal mask
    // with past instead of the 2048×2048 prefill mask.
    void layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out,
                        bool batch_decode_mode = false);

    // Build causal-with-past mask in batch_mask_dev_ for decode_batch at current past_len.
    // Shape [1, 1, S, past_len+S] bool, mask[i, j] = 1 iff j > past_len+i.
    void build_batch_decode_mask_(int64_t S);

    // Final: final_norm + lm_head on last position → logits [vocab].
    void final_logits_(void* hidden_last /*[1, D]*/, DeviceBuffer& logits_out);

    // Batched final: final_norm + lm_head on [S, D] → logits [S, V].
    void final_logits_batch_(void* hidden /*[S, D]*/, int64_t S, DeviceBuffer& logits_out);

    AclRuntime rt_;
    SafetensorsLoader st_;
    ModelConfig cfg_;
    HcclCtx hccl_ctx_;

    int num_layers_ = 0;
    int64_t max_seq_ = 0;

    SharedWeights shared_;
    // NOTE(review): the element types of the four std::vector members below are missing in
    // this copy of the file (template arguments appear to have been lost). Presumably attn_ and
    // moe_ hold the per-layer weight structs and k_cache_/v_cache_ hold one device buffer per
    // layer — restore the template arguments from the original source before building.
    std::vector attn_;
    std::vector moe_;

    // Per-layer KV cache
    std::vector k_cache_;
    std::vector v_cache_;

    // Scratch (reallocated per-call sized by current S)
    DeviceBuffer q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_;
    DeviceBuffer moe_xn_, moe_rstd_, moe_logits_;
    DeviceBuffer moe_topk_w_, moe_topk_idx_, moe_row_idx_;
    DeviceBuffer moe_ex_x_, moe_ex_ri_, moe_tpe_;
    DeviceBuffer moe_fwd_;
    DeviceBuffer moe_gate_, moe_up_, moe_down_;
    DeviceBuffer moe_packed_, moe_weighted_, moe_out_;
    DeviceBuffer moe_norm_sum_;  // BF16 [S, 1] for on-device topk_w normalize
    DeviceBuffer x_buf_a_, x_buf_b_;  // ping-pong for residual chain

    // Causal mask for prefill (2048 x 2048 bool); decode uses nullptr
    DeviceBuffer prefill_mask_dev_;

    // Batch decode mask: S_MAX × KV_MAX bool, where mask[i, j] = 1 (masked out) if
    // j > past_len + i. Built on-demand per-call (past_len changes).
    DeviceBuffer batch_mask_dev_;

    // Pre-computed RoPE cos/sin table (sized for max_seq_)
    RopeCache rope_cache_;

    // Number of positions currently considered valid in the KV cache; read by past_len() and
    // adjusted by reset_cache() / rewind_cache().
    int64_t past_len_ = 0;
    int64_t cur_S_capacity_ = 0;  // scratch sized for this many tokens
};