// runner.h — multi-layer transformer Runner for Qwen3-235B-A22B.
//
// Owns: shared weights, per-layer attention + MoE weights, KV cache, scratch buffers.
// Provides: prefill(tokens) and decode(new_token) methods returning logits [vocab] on device.
//
// Memory budget at TP=1 for testing a SUBSET of layers (num_layers_to_load <= 94). Full 94-layer
// inference requires TP=16 where per-rank MoE fits ~28GB.
| class Runner { | |
| public: | |
| Runner() = default; | |
| ~Runner() = default; | |
| Runner(const Runner&) = delete; | |
| Runner& operator=(const Runner&) = delete; | |
| // Initialize runtime, open safetensors, load shared weights. tp_size/tp_rank configure | |
| // MoE + attention sharding. num_layers is how many transformer blocks to load (1..94). | |
| // max_seq is the maximum sequence length (for KV cache allocation). | |
| bool init(const std::string& model_dir, int tp_size, int tp_rank, | |
| int num_layers_to_load, int64_t max_seq, int device_id = 0); | |
| // Prefill: ingest S>=1 tokens, produces logits [vocab] for the LAST position. Populates KV | |
| // cache starting at position 0. `hidden_out` optionally returns the final hidden state [S, D]. | |
| bool prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out); | |
| // Decode: take 1 new token, produce logits [vocab] from the new position. | |
| bool decode(int32_t token, DeviceBuffer& logits_out); | |
| // Batched decode: take S tokens as "candidate verify batch" at positions [past_len..past_len+S), | |
| // produce logits [S, vocab]. Uses causal-with-past mask (token i sees past+tokens[0..i]). | |
| // Foundation for speculative decoding / PLD. | |
| // tokens: [S] int32 | |
| // S: 1 .. 16 | |
| // all_logits_out: will hold S * vocab_size * 2 bytes BF16, row-major [S, V] | |
| // Updates past_len by +S on success. | |
| bool decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out); | |
| // Warmup: run N dummy decode() calls (resetting cache) to pre-compile aclnn executors, | |
| // warm HCCL collective buffers, and stabilize NPU thermals. Improves first-N-token latency | |
| // by ~1 s (especially noticeable on short generations or REPL cold start). | |
| // Call after init(); safe to call multiple times. Does NOT affect past_len. | |
| void warmup(int iterations = 3); | |
| // Accessors | |
| const ModelConfig& cfg() const { return cfg_; } | |
| aclrtStream stream() { return rt_.stream(); } | |
| int64_t past_len() const { return past_len_; } | |
| void reset_cache() { past_len_ = 0; } | |
| // Rewind past_len by n. Used by speculative decoding to discard rejected draft tokens' | |
| // KV cache entries (they'll be overwritten by subsequent writes). | |
| void rewind_cache(int64_t n) { if (n > 0 && n <= past_len_) past_len_ -= n; } | |
| HcclCtx& hccl_ctx() { return hccl_ctx_; } | |
| // Profiling: set via LCA_PROFILE=1 env in main_cli. If enabled, decode() accumulates | |
| // per-phase wall-clock ms into the timer accumulators below. | |
| bool profile_enabled = false; | |
| double t_embed_ms = 0, t_layers_ms = 0, t_final_ms = 0; | |
| int64_t profile_calls = 0; | |
| void print_profile_summary() const; | |
| private: | |
| // One-layer forward: x_in [S, D] β x_out [S, D] via attention + residual + MoE + residual. | |
| // Uses this layer's KV cache starting at past_len; caller updates past_len after each call. | |
| // batch_decode_mode: true for S>1 at past_len>0 (spec decoding) β uses custom causal mask | |
| // with past instead of the 2048Γ2048 prefill mask. | |
| void layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out, | |
| bool batch_decode_mode = false); | |
| // Build causal-with-past mask in batch_mask_dev_ for decode_batch at current past_len. | |
| // Shape [1, 1, S, past_len+S] bool, mask[i, j] = 1 iff j > past_len+i. | |
| void build_batch_decode_mask_(int64_t S); | |
| // Final: final_norm + lm_head on last position β logits [vocab]. | |
| void final_logits_(void* hidden_last /*[1, D]*/, DeviceBuffer& logits_out); | |
| // Batched final: final_norm + lm_head on [S, D] β logits [S, V]. | |
| void final_logits_batch_(void* hidden /*[S, D]*/, int64_t S, DeviceBuffer& logits_out); | |
| AclRuntime rt_; | |
| SafetensorsLoader st_; | |
| ModelConfig cfg_; | |
| HcclCtx hccl_ctx_; | |
| int num_layers_ = 0; | |
| int64_t max_seq_ = 0; | |
| SharedWeights shared_; | |
| std::vector<LayerAttnWeights> attn_; | |
| std::vector<LayerMoEWeights> moe_; | |
| // Per-layer KV cache | |
| std::vector<DeviceBuffer> k_cache_; | |
| std::vector<DeviceBuffer> v_cache_; | |
| // Scratch (reallocated per-call sized by current S) | |
| DeviceBuffer q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_; | |
| DeviceBuffer moe_xn_, moe_rstd_, moe_logits_; | |
| DeviceBuffer moe_topk_w_, moe_topk_idx_, moe_row_idx_; | |
| DeviceBuffer moe_ex_x_, moe_ex_ri_, moe_tpe_; | |
| DeviceBuffer moe_fwd_; | |
| DeviceBuffer moe_gate_, moe_up_, moe_down_; | |
| DeviceBuffer moe_packed_, moe_weighted_, moe_out_; | |
| DeviceBuffer moe_norm_sum_; // BF16 [S, 1] for on-device topk_w normalize | |
| DeviceBuffer x_buf_a_, x_buf_b_; // ping-pong for residual chain | |
| // Causal mask for prefill (2048 x 2048 bool); decode uses nullptr | |
| DeviceBuffer prefill_mask_dev_; | |
| // Batch decode mask: S_MAX Γ KV_MAX bool, where mask[i, j] = 1 (masked out) if | |
| // j > past_len + i. Built on-demand per-call (past_len changes). | |
| DeviceBuffer batch_mask_dev_; | |
| // Pre-computed RoPE cos/sin table (sized for max_seq_) | |
| RopeCache rope_cache_; | |
| int64_t past_len_ = 0; | |
| int64_t cur_S_capacity_ = 0; // scratch sized for this many tokens | |
| }; | |