#include "runner.h"

// NOTE(review): the names of the three standard headers that follow "runner.h"
// were lost in this copy of the file. The list below covers what this
// translation unit uses directly; reconcile it with the original include list.
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Expose HCCL context for the CLI broadcast helper.
HcclCtx* runner_hccl_ctx_shim(Runner& r) { return &r.hccl_ctx(); }

/// Loads the model configuration and weights and allocates per-layer state.
///
/// @param model_dir           directory containing config.json and the weights
///                            opened through st_.
/// @param tp_size             tensor-parallel world size (HCCL init is a no-op
///                            when this is 1, per the original comment).
/// @param tp_rank             this process's rank within the TP group.
/// @param num_layers_to_load  number of layers to load; must be in
///                            [1, cfg_.num_hidden_layers].
/// @param max_seq             maximum total sequence length; sizes the K/V
///                            caches and the RoPE table. Must be >= 1.
/// @param device_id           device passed to rt_.init().
/// @return true on success; false if any validation or load step fails.
bool Runner::init(const std::string& model_dir, int tp_size, int tp_rank,
                  int num_layers_to_load, int64_t max_seq, int device_id) {
    if (!cfg_.load_from_json(model_dir + "/config.json")) return false;
    cfg_.compute_derived(tp_size, tp_rank);

    if (num_layers_to_load < 1 ||
        num_layers_to_load > static_cast<int>(cfg_.num_hidden_layers)) {
        fprintf(stderr, "runner: invalid num_layers %d (max %ld)\n",
                num_layers_to_load, cfg_.num_hidden_layers);
        return false;
    }
    // A non-positive max_seq would turn the cache sizes below into zero or
    // negative byte counts, so reject it before anything is allocated.
    if (max_seq < 1) {
        fprintf(stderr, "runner: invalid max_seq %ld\n", static_cast<long>(max_seq));
        return false;
    }
    num_layers_ = num_layers_to_load;
    max_seq_ = max_seq;

    if (!st_.open(model_dir)) return false;
    rt_.init(device_id);

    // HCCL init (no-op if tp_size == 1)
    if (!hccl_init(hccl_ctx_, tp_size, tp_rank)) {
        fprintf(stderr, "runner: HCCL init failed\n");
        return false;
    }

    DeviceWeightsLoader dw(st_, cfg_);
    printf("runner: loading shared weights (embed, lm_head, final_norm)...\n");
    if (!dw.load_shared(shared_)) return false;

    attn_.resize(num_layers_);
    moe_.resize(num_layers_);
    k_cache_.resize(num_layers_);
    v_cache_.resize(num_layers_);

    // Per-layer K and V caches: max_seq_ rows of KV_DIM elements, 2 bytes each
    // (presumably BF16, matching the tensors used elsewhere in this file).
    const int64_t KV_DIM = cfg_.n_kv_heads_per_rank * cfg_.head_dim;
    for (int L = 0; L < num_layers_; L++) {
        printf("runner: loading layer %d/%d...\n", L + 1, num_layers_);
        if (!dw.load_attention(L, attn_[L])) return false;
        if (!dw.load_moe(L, rt_.stream(), moe_[L])) return false;
        k_cache_[L].alloc(max_seq_ * KV_DIM * 2);
        v_cache_[L].alloc(max_seq_ * KV_DIM * 2);
    }
    rt_.sync();

    // Prefill mask (2048x2048 bool causal): entry (i, j) is 1 (masked) for j > i.
    // One byte per entry, since MASK*MASK bytes are uploaded from mh.data().
    const int64_t MASK = 2048;
    std::vector<uint8_t> mh(MASK * MASK, 0);
    for (int64_t i = 0; i < MASK; i++) {
        for (int64_t j = i + 1; j < MASK; j++) {
            mh[i * MASK + j] = 1;
        }
    }
    prefill_mask_dev_.alloc(MASK * MASK);
    ACL_CHECK(aclrtMemcpy(prefill_mask_dev_.get(), MASK * MASK,
                          mh.data(), MASK * MASK,
                          ACL_MEMCPY_HOST_TO_DEVICE));

    // Pre-compute RoPE cos/sin table once (covers all positions up to max_seq_)
    rope_cache_build(rope_cache_, max_seq_, cfg_.head_dim, cfg_.rope_theta);

    past_len_ = 0;
    cur_S_capacity_ = 0;
    return true;
}

// Grows `buf` to at least `needed` bytes; never shrinks it.
static void ensure_sc_(DeviceBuffer& buf, size_t needed) {
    if (buf.size < needed) buf.alloc(needed);
}

// Sizes every scratch buffer used by one forward pass over S tokens.
// All sizes are byte counts; the trailing factor (2, 4 or 8) is the element
// width used by the original code. `self` is unused and kept only so existing
// call sites stay unchanged.
static void ensure_all_scratch_(Runner* self, int64_t S, const ModelConfig& cfg,
                                DeviceBuffer& q_sc, DeviceBuffer& k_sc, DeviceBuffer& v_sc,
                                DeviceBuffer& xn_sc, DeviceBuffer& rstd_sc, DeviceBuffer& rope_sc,
                                DeviceBuffer& attn_fias_sc, DeviceBuffer& attn_out_sc,
                                DeviceBuffer& moe_xn, DeviceBuffer& moe_rstd,
                                DeviceBuffer& moe_logits, DeviceBuffer& moe_topk_w,
                                DeviceBuffer& moe_topk_idx, DeviceBuffer& moe_row_idx,
                                DeviceBuffer& moe_ex_x, DeviceBuffer& moe_ex_ri,
                                DeviceBuffer& moe_tpe, DeviceBuffer& moe_fwd,
                                DeviceBuffer& moe_gate, DeviceBuffer& moe_up,
                                DeviceBuffer& moe_down, DeviceBuffer& moe_packed,
                                DeviceBuffer& moe_weighted, DeviceBuffer& moe_out,
                                DeviceBuffer& moe_norm_sum,
                                DeviceBuffer& x_buf_a, DeviceBuffer& x_buf_b) {
    (void)self;
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.n_heads_per_rank, Hkv = cfg.n_kv_heads_per_rank;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const int64_t I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok;
    const int64_t TOTAL = S * K;  // one expanded row per (token, selected expert)

    // Attention scratch.
    ensure_sc_(q_sc, S * Q_DIM * 2);
    ensure_sc_(k_sc, S * KV_DIM * 2);
    ensure_sc_(v_sc, S * KV_DIM * 2);
    ensure_sc_(xn_sc, S * D * 2);
    ensure_sc_(rstd_sc, S * std::max(Hq, Hkv) * 4);
    ensure_sc_(rope_sc, 1 * S * Hq * Dh * 2);
    ensure_sc_(attn_fias_sc, S * Q_DIM * 2);
    ensure_sc_(attn_out_sc, S * D * 2);

    // MoE scratch.
    ensure_sc_(moe_xn, S * D * 2);
    ensure_sc_(moe_rstd, S * 4);
    ensure_sc_(moe_logits, S * E * 2);
    ensure_sc_(moe_topk_w, S * K * 2);
    ensure_sc_(moe_topk_idx, S * K * 4);
    ensure_sc_(moe_row_idx, S * K * 4);
    ensure_sc_(moe_ex_x, TOTAL * D * 2);
    ensure_sc_(moe_ex_ri, TOTAL * 4);
    ensure_sc_(moe_tpe, E * 8);
    ensure_sc_(moe_fwd, TOTAL * 8);
    ensure_sc_(moe_gate, TOTAL * I * 2);
    ensure_sc_(moe_up, TOTAL * I * 2);
    ensure_sc_(moe_down, TOTAL * D * 2);
    ensure_sc_(moe_packed, TOTAL * D * 2);
    ensure_sc_(moe_weighted, S * K * D * 2);
    ensure_sc_(moe_out, S * D * 2);
    ensure_sc_(moe_norm_sum, S * 2);

    // Residual-stream scratch.
    ensure_sc_(x_buf_a, S * D * 2);
    ensure_sc_(x_buf_b, S * D * 2);
}

/// Runs one transformer layer: attention + residual, then MoE + residual.
///
/// @param layer_idx          index into attn_/moe_/k_cache_/v_cache_.
/// @param S                  number of tokens in this step.
/// @param x_in               device pointer to the [S, D] BF16 layer input.
/// @param x_out              device pointer receiving the [S, D] BF16 output.
/// @param batch_decode_mode  when true, attention uses the S x (past_len_+S)
///                           mask in batch_mask_dev_. The caller must have
///                           called build_batch_decode_mask_(S) for the current
///                           past_len_ beforehand; the mask is identical for
///                           every layer, so it is built once per step rather
///                           than once per layer.
///
/// All work is enqueued on rt_.stream(); this function does not synchronize.
void Runner::layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out,
                            bool batch_decode_mode) {
    const int64_t D = cfg_.hidden_size;

    // Attention mask selection:
    //   batch decode (batch_decode_mode): S x (past+S) causal-with-past mask,
    //                                     sparse_mode=0
    //   prefill (S>1):                    2048x2048 upper-tri mask, sparse_mode=3
    //   decode (S==1):                    mask=nullptr, sparse_mode stays -1
    //                                     ("auto"; how attention_forward maps -1
    //                                     is not visible in this file)
    aclTensor* mask = nullptr;
    int64_t sparse_mode = -1;  // auto
    AclTensorPtr t_mask_ptr;
    if (batch_decode_mode) {
        const int64_t kv_len = past_len_ + S;
        t_mask_ptr = make_contig_tensor(batch_mask_dev_.get(), ACL_BOOL, {1, 1, S, kv_len});
        mask = t_mask_ptr.get();
        sparse_mode = 0;
    } else if (S > 1) {
        // Pure prefill from past=0
        t_mask_ptr = make_contig_tensor(prefill_mask_dev_.get(), ACL_BOOL, {1, 1, 2048, 2048});
        mask = t_mask_ptr.get();
        sparse_mode = 3;
    }

    // Collectives are only wired in when running tensor-parallel.
    auto* hccl = (hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr;

    attention_forward(
        rt_.stream(), cfg_, attn_[layer_idx],
        x_in, S, past_len_,
        k_cache_[layer_idx].get(), v_cache_[layer_idx].get(), max_seq_,
        mask,
        q_sc_.get(), k_sc_.get(), v_sc_.get(),
        xn_sc_.get(), rstd_sc_.get(), rope_sc_.get(),
        attn_fias_sc_.get(), attn_out_sc_.get(),
        hccl,
        &rope_cache_,
        sparse_mode);

    // out = lhs + 1.0 * rhs via aclnnAdd, enqueued on the runner's stream.
    // NOTE(review): the workspace buffer goes out of scope right after the op
    // is enqueued, exactly as in the original code. If aclnnAdd ever reports a
    // non-zero workspace and DeviceBuffer frees without waiting for the stream,
    // the workspace should be kept alive until the next sync — confirm.
    auto residual_add = [this](const aclTensor* lhs, const aclTensor* rhs, aclTensor* out) {
        float alpha = 1.0f;
        aclScalar* alpha_sc = aclCreateScalar(&alpha, ACL_FLOAT);
        uint64_t ws = 0;
        aclOpExecutor* exec = nullptr;
        ACLNN_CHECK(aclnnAddGetWorkspaceSize(lhs, rhs, alpha_sc, out, &ws, &exec));
        DeviceBuffer wb;
        if (ws > 0) wb.alloc(ws);
        ACLNN_CHECK(aclnnAdd(wb.get(), ws, exec, rt_.stream()));
        aclDestroyScalar(alpha_sc);
    };

    // x1 = x_in + attn_out (residual)
    auto t_x_in     = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_attn_out = make_contig_tensor(attn_out_sc_.get(), ACL_BF16, {S, D});
    auto t_x1       = make_contig_tensor(x_buf_a_.get(), ACL_BF16, {S, D});
    residual_add(t_x_in.get(), t_attn_out.get(), t_x1.get());

    // MoE
    moe_forward(
        rt_.stream(), cfg_, attn_[layer_idx], moe_[layer_idx],
        x_buf_a_.get(), S,
        moe_xn_.get(), moe_rstd_.get(), moe_logits_.get(),
        moe_topk_w_.get(), moe_topk_idx_.get(), moe_row_idx_.get(),
        moe_ex_x_.get(), moe_ex_ri_.get(), moe_tpe_.get(), moe_fwd_.get(),
        moe_gate_.get(), moe_up_.get(), moe_down_.get(),
        moe_packed_.get(), moe_weighted_.get(), moe_out_.get(),
        hccl,
        moe_norm_sum_.get());

    // x_out = x1 + moe_out (residual)
    auto t_moe_out = make_contig_tensor(moe_out_.get(), ACL_BF16, {S, D});
    auto t_out     = make_contig_tensor(x_out, ACL_BF16, {S, D});
    residual_add(t_x1.get(), t_moe_out.get(), t_out.get());
}

/// Single-position variant: hidden_last is [1, D], output [1, V].
void Runner::final_logits_(void* hidden_last, DeviceBuffer& logits_out) {
    final_logits_batch_(hidden_last, 1, logits_out);
}

/// Applies the final RMS norm and the LM head to S hidden rows.
///
/// @param hidden      device pointer to [S, D] BF16 hidden states.
/// @param S           number of rows.
/// @param logits_out  (re)allocated to S * V * 2 bytes and filled with the
///                    [S, V] BF16 logits.
///
/// Synchronizes the stream before returning, so the logits are complete when
/// this function returns.
void Runner::final_logits_batch_(void* hidden, int64_t S, DeviceBuffer& logits_out) {
    const int64_t D = cfg_.hidden_size;
    const int64_t V = cfg_.vocab_size;

    DeviceBuffer hn(S * D * 2), rstd(S * 4);
    auto t_h    = make_contig_tensor(hidden, ACL_BF16, {S, D});
    auto t_hn   = make_contig_tensor(hn.get(), ACL_BF16, {S, D});
    auto t_lnw  = make_contig_tensor(shared_.final_norm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd.get(), ACL_FLOAT, {S});
    rms_norm(rt_.stream(), t_h.get(), t_lnw.get(), cfg_.rms_norm_eps, t_hn.get(), t_rstd.get());

    logits_out.alloc(S * V * 2);
    auto t_logits = make_contig_tensor(logits_out.get(), ACL_BF16, {S, V});
    linear_hf(rt_.stream(), t_hn.get(), shared_.lm_head.get(), ACL_BF16, V, D, t_logits.get());

    // `hn` and `rstd` are locals that the ops above read and write on the
    // stream. Wait for the stream here so they are not released while that
    // work may still be pending. Every caller already syncs right after this
    // function returns, so this adds no extra stall.
    rt_.sync();
}

/// Feeds S tokens through the model and returns logits for all S positions.
///
/// @param tokens          host array of S token ids.
/// @param S               number of tokens (>= 1).
/// @param all_logits_out  receives [S, V] BF16 logits.
/// @return false if the arguments are invalid or past_len_ + S exceeds max_seq_.
///
/// On success past_len_ advances by S.
bool Runner::decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out) {
    if (tokens == nullptr || S < 1) return false;
    if (past_len_ + S > max_seq_) {
        fprintf(stderr, "runner: decode_batch exceeds max_seq (%ld + %ld > %ld)\n",
                past_len_, S, max_seq_);
        return false;
    }
    const int64_t D = cfg_.hidden_size;
    ensure_all_scratch_(this, S, cfg_,
        q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_,
        attn_fias_sc_, attn_out_sc_,
        moe_xn_, moe_rstd_, moe_logits_, moe_topk_w_, moe_topk_idx_, moe_row_idx_,
        moe_ex_x_, moe_ex_ri_, moe_tpe_, moe_fwd_,
        moe_gate_, moe_up_, moe_down_, moe_packed_, moe_weighted_, moe_out_,
        moe_norm_sum_,
        x_buf_a_, x_buf_b_);

    // Embed S tokens
    DeviceBuffer tok_dev(S * 4);
    ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens, S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tok     = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
    auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
    DeviceBuffer x0(S * D * 2);
    auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
    index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

    // Layer chain: ping-pong between two buffers
    DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
    ACL_CHECK(aclrtMemcpyAsync(xping.get(), S * D * 2, x0.get(), S * D * 2,
                               ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
    void* cur_in = xping.get();
    void* cur_out = xpong.get();

    // With existing context, use the causal-with-past mask (S x past+S,
    // sparse_mode=0). It depends only on (past_len_, S), so build and upload
    // it once here instead of once per layer.
    const bool with_past = past_len_ > 0;
    if (with_past) build_batch_decode_mask_(S);

    for (int L = 0; L < num_layers_; L++) {
        layer_forward_(L, S, cur_in, cur_out, /*batch_decode_mode=*/with_past);
        std::swap(cur_in, cur_out);
    }
    rt_.sync();

    // Get logits for ALL S positions (not just last)
    final_logits_batch_(cur_in, S, all_logits_out);
    rt_.sync();

    past_len_ += S;
    return true;
}

/// Feeds S prompt tokens through the model and returns the last position's logits.
///
/// @param tokens      host array of S token ids.
/// @param S           number of tokens (>= 1).
/// @param logits_out  receives [1, V] BF16 logits for position S-1.
/// @return false if the arguments are invalid or past_len_ + S exceeds max_seq_.
///
/// On success past_len_ advances by S.
bool Runner::prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out) {
    if (tokens == nullptr || S < 1) return false;
    if (past_len_ + S > max_seq_) {
        fprintf(stderr, "runner: prefill exceeds max_seq (%ld + %ld > %ld)\n",
                past_len_, S, max_seq_);
        return false;
    }
    const int64_t D = cfg_.hidden_size;
    ensure_all_scratch_(this, S, cfg_,
        q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_,
        attn_fias_sc_, attn_out_sc_,
        moe_xn_, moe_rstd_, moe_logits_, moe_topk_w_, moe_topk_idx_, moe_row_idx_,
        moe_ex_x_, moe_ex_ri_, moe_tpe_, moe_fwd_,
        moe_gate_, moe_up_, moe_down_, moe_packed_, moe_weighted_, moe_out_,
        moe_norm_sum_,
        x_buf_a_, x_buf_b_);

    // Embed
    DeviceBuffer tok_dev(S * 4);
    ACL_CHECK(aclrtMemcpy(tok_dev.get(), S * 4, tokens, S * 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tok     = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
    auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});
    DeviceBuffer x0(S * D * 2);
    auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
    index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

    // Layer chain: ping-pong between two buffers
    DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
    ACL_CHECK(aclrtMemcpyAsync(xping.get(), S * D * 2, x0.get(), S * D * 2,
                               ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
    void* cur_in = xping.get();
    void* cur_out = xpong.get();
    for (int L = 0; L < num_layers_; L++) {
        layer_forward_(L, S, cur_in, cur_out);
        std::swap(cur_in, cur_out);
    }
    rt_.sync();

    // Take last position's hidden → final_logits
    DeviceBuffer last(1 * D * 2);
    ACL_CHECK(aclrtMemcpy(last.get(), 1 * D * 2,
                          (char*)cur_in + (S - 1) * D * 2, 1 * D * 2,
                          ACL_MEMCPY_DEVICE_TO_DEVICE));
    final_logits_(last.get(), logits_out);
    rt_.sync();

    past_len_ += S;
    return true;
}

/// Feeds one token through the model and returns its logits.
///
/// @param token       token id to decode.
/// @param logits_out  receives [1, V] BF16 logits.
/// @return false if past_len_ + 1 exceeds max_seq_.
///
/// On success past_len_ advances by 1. When profile_enabled is set, the wall
/// time of the embed, layer and final-logits phases is accumulated into
/// t_embed_ms / t_layers_ms / t_final_ms and profile_calls is incremented.
bool Runner::decode(int32_t token, DeviceBuffer& logits_out) {
    const int64_t D = cfg_.hidden_size;
    if (past_len_ + 1 > max_seq_) {
        fprintf(stderr, "runner: decode exceeds max_seq\n");
        return false;
    }
    const int64_t S = 1;
    ensure_all_scratch_(this, S, cfg_,
        q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_,
        attn_fias_sc_, attn_out_sc_,
        moe_xn_, moe_rstd_, moe_logits_, moe_topk_w_, moe_topk_idx_, moe_row_idx_,
        moe_ex_x_, moe_ex_ri_, moe_tpe_, moe_fwd_,
        moe_gate_, moe_up_, moe_down_, moe_packed_, moe_weighted_, moe_out_,
        moe_norm_sum_,
        x_buf_a_, x_buf_b_);

    DeviceBuffer tok_dev(1 * 4);
    ACL_CHECK(aclrtMemcpy(tok_dev.get(), 4, &token, 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tok     = make_contig_tensor(tok_dev.get(), ACL_INT32, {1});
    auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

    auto t0 = std::chrono::steady_clock::now();
    DeviceBuffer x0(1 * D * 2);
    auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {1, D});
    index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

    DeviceBuffer xping(1 * D * 2), xpong(1 * D * 2);
    ACL_CHECK(aclrtMemcpyAsync(xping.get(), 1 * D * 2, x0.get(), 1 * D * 2,
                               ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
    // Only when profiling: drain the stream so the embed phase is timed on its own.
    if (profile_enabled) {
        ACL_CHECK(aclrtSynchronizeStream(rt_.stream()));
    }
    auto t1 = std::chrono::steady_clock::now();

    void* cur_in = xping.get();
    void* cur_out = xpong.get();
    for (int L = 0; L < num_layers_; L++) {
        layer_forward_(L, 1, cur_in, cur_out);
        std::swap(cur_in, cur_out);
    }
    rt_.sync();
    auto t2 = std::chrono::steady_clock::now();

    final_logits_(cur_in, logits_out);
    rt_.sync();
    auto t3 = std::chrono::steady_clock::now();

    if (profile_enabled) {
        using ms = std::chrono::duration<double, std::milli>;
        t_embed_ms  += ms(t1 - t0).count();
        t_layers_ms += ms(t2 - t1).count();
        t_final_ms  += ms(t3 - t2).count();
        profile_calls++;
    }

    past_len_ += 1;
    return true;
}

/// Builds the S x (past_len_ + S) batch-decode mask on the host and uploads it
/// to batch_mask_dev_ (grown if too small). One byte per entry: 0 = visible,
/// 1 = masked. The upload is a blocking host-to-device copy.
void Runner::build_batch_decode_mask_(int64_t S) {
    const int64_t kv_len = past_len_ + S;
    const size_t bytes = (size_t)S * kv_len;  // bool = 1 byte
    if (batch_mask_dev_.size < bytes) batch_mask_dev_.alloc(bytes);

    std::vector<uint8_t> h_mask(bytes, 0);
    for (int64_t i = 0; i < S; i++) {
        // Row i: positions j ≤ past_len_+i are visible (0), j > past_len_+i are masked (1).
        for (int64_t j = past_len_ + i + 1; j < kv_len; j++) {
            h_mask[i * kv_len + j] = 1;
        }
    }
    ACL_CHECK(aclrtMemcpy(batch_mask_dev_.get(), bytes, h_mask.data(), bytes,
                          ACL_MEMCPY_HOST_TO_DEVICE));
}

/// Runs `iterations` dummy single-token decodes from position 0.
///
/// past_len_ and the profiling accumulators are restored afterwards, so warmup
/// calls do not show up in print_profile_summary().
///
/// NOTE(review): each dummy decode runs with past_len_ == 0 against the real
/// K/V caches, so it presumably overwrites cache position 0. Call this before
/// any prefill, or confirm that attention_forward leaves the cache intact.
void Runner::warmup(int iterations) {
    if (num_layers_ == 0) return;

    const int64_t saved_past = past_len_;
    const auto saved_embed_ms  = t_embed_ms;
    const auto saved_layers_ms = t_layers_ms;
    const auto saved_final_ms  = t_final_ms;
    const auto saved_calls     = profile_calls;

    int32_t dummy_tok = 0;  // token id 0, valid for Qwen3 (bos)
    DeviceBuffer dummy_logits;
    int completed = 0;
    for (; completed < iterations; completed++) {
        past_len_ = 0;  // every warmup step decodes at position 0
        if (!decode(dummy_tok, dummy_logits)) break;
    }

    past_len_ = saved_past;
    t_embed_ms    = saved_embed_ms;
    t_layers_ms   = saved_layers_ms;
    t_final_ms    = saved_final_ms;
    profile_calls = saved_calls;

    // Report what actually ran; a failed decode stops the loop early.
    fprintf(stderr, "[runner] warmup: %d iterations done\n", completed);
}

/// Prints accumulated decode() timings to stderr. Does nothing when profiling
/// is disabled or no decode call has been recorded.
void Runner::print_profile_summary() const {
    if (!profile_enabled || profile_calls == 0) return;

    const double total = t_embed_ms + t_layers_ms + t_final_ms;
    // Report 0% instead of NaN if the clock recorded no elapsed time.
    auto pct = [total](double part) { return total > 0.0 ? 100.0 * part / total : 0.0; };
    const double per_layer_ms =
        num_layers_ > 0 ? t_layers_ms / profile_calls / num_layers_ : 0.0;

    fprintf(stderr, "\n=== Runner profile (%ld decode calls) ===\n", profile_calls);
    fprintf(stderr, " phase total_ms avg_ms/call pct\n");
    fprintf(stderr, " embed %8.1f %10.3f %5.1f%%\n",
            t_embed_ms, t_embed_ms / profile_calls, pct(t_embed_ms));
    fprintf(stderr, " layers (x%d) %8.1f %10.3f %5.1f%% → %.3f ms/layer/call\n",
            num_layers_, t_layers_ms, t_layers_ms / profile_calls,
            pct(t_layers_ms), per_layer_ms);
    fprintf(stderr, " final+lm_hd %8.1f %10.3f %5.1f%%\n",
            t_final_ms, t_final_ms / profile_calls, pct(t_final_ms));
    fprintf(stderr, " total %8.1f %10.3f 100.0%%\n",
            total, total / profile_calls);
}