| #include "runner.h" |
|
|
| #include <chrono> |
| #include <cstdio> |
| #include <cstring> |
|
|
| |
| HcclCtx* runner_hccl_ctx_shim(Runner& r) { return &r.hccl_ctx(); } |
|
|
// One-time setup of the runner.
//
// Sequence: parse <model_dir>/config.json and derive per-rank sizes for
// tensor parallelism; open the safetensors store; bind the device runtime;
// initialize the HCCL communicator; load shared weights (embed, lm_head,
// final_norm) and then each layer's attention + MoE weights; allocate the
// per-layer KV caches; upload the fixed prefill attention mask; precompute
// the RoPE table.
//
// Params:
//   model_dir          - directory containing config.json and the weights.
//   tp_size / tp_rank  - tensor-parallel world size and this rank's index.
//   num_layers_to_load - number of transformer layers to load; must be in
//                        [1, num_hidden_layers].
//   max_seq            - maximum sequence length; sizes KV caches and the
//                        RoPE cache.
//   device_id          - device to bind the runtime to.
// Returns false on any failure (config, validation, I/O, HCCL, weights).
bool Runner::init(const std::string& model_dir, int tp_size, int tp_rank,
                  int num_layers_to_load, int64_t max_seq, int device_id) {
    if (!cfg_.load_from_json(model_dir + "/config.json")) return false;
    cfg_.compute_derived(tp_size, tp_rank);
    if (num_layers_to_load < 1 || num_layers_to_load > (int)cfg_.num_hidden_layers) {
        fprintf(stderr, "runner: invalid num_layers %d (max %ld)\n",
                num_layers_to_load, cfg_.num_hidden_layers);
        return false;
    }
    num_layers_ = num_layers_to_load;
    max_seq_ = max_seq;

    if (!st_.open(model_dir)) return false;
    rt_.init(device_id);

    // Collective communication must be up before any TP all-reduce is issued.
    if (!hccl_init(hccl_ctx_, tp_size, tp_rank)) {
        fprintf(stderr, "runner: HCCL init failed\n");
        return false;
    }

    DeviceWeightsLoader dw(st_, cfg_);
    printf("runner: loading shared weights (embed, lm_head, final_norm)...\n");
    if (!dw.load_shared(shared_)) return false;

    attn_.resize(num_layers_);
    moe_.resize(num_layers_);
    k_cache_.resize(num_layers_);
    v_cache_.resize(num_layers_);

    // Per-rank KV width in elements; the trailing "* 2" is bytes per bf16
    // element (caches hold ACL_BF16 data).
    const int64_t KV_DIM = cfg_.n_kv_heads_per_rank * cfg_.head_dim;
    for (int L = 0; L < num_layers_; L++) {
        printf("runner: loading layer %d/%d...\n", L + 1, num_layers_);
        if (!dw.load_attention(L, attn_[L])) return false;
        if (!dw.load_moe(L, rt_.stream(), moe_[L])) return false;
        k_cache_[L].alloc(max_seq_ * KV_DIM * 2);
        v_cache_[L].alloc(max_seq_ * KV_DIM * 2);
    }
    rt_.sync();

    // Fixed 2048x2048 triangular prefill mask: entries with j > i (strict
    // upper triangle) are set to 1.  Used with sparse_mode 3 in
    // layer_forward_.  NOTE(review): this assumes 1 means "masked out" for
    // the attention kernel -- confirm against attention_forward's mask
    // convention.
    const int64_t MASK = 2048;
    std::vector<uint8_t> mh(MASK * MASK, 0);
    for (int i = 0; i < MASK; i++)
        for (int j = i+1; j < MASK; j++) mh[i*MASK + j] = 1;
    prefill_mask_dev_.alloc(MASK * MASK);
    ACL_CHECK(aclrtMemcpy(prefill_mask_dev_.get(), MASK*MASK, mh.data(), MASK*MASK, ACL_MEMCPY_HOST_TO_DEVICE));

    // Precompute rotary-embedding tables for positions [0, max_seq_).
    rope_cache_build(rope_cache_, max_seq_, cfg_.head_dim, cfg_.rope_theta);

    past_len_ = 0;
    cur_S_capacity_ = 0;
    return true;
}
|
|
| static void ensure_sc_(DeviceBuffer& buf, size_t needed) { |
| if (buf.size < needed) buf.alloc(needed); |
| } |
|
|
| static void ensure_all_scratch_(Runner* self, int64_t S, const ModelConfig& cfg, |
| DeviceBuffer& q_sc, DeviceBuffer& k_sc, DeviceBuffer& v_sc, |
| DeviceBuffer& xn_sc, DeviceBuffer& rstd_sc, DeviceBuffer& rope_sc, |
| DeviceBuffer& attn_fias_sc, DeviceBuffer& attn_out_sc, |
| DeviceBuffer& moe_xn, DeviceBuffer& moe_rstd, DeviceBuffer& moe_logits, |
| DeviceBuffer& moe_topk_w, DeviceBuffer& moe_topk_idx, DeviceBuffer& moe_row_idx, |
| DeviceBuffer& moe_ex_x, DeviceBuffer& moe_ex_ri, DeviceBuffer& moe_tpe, |
| DeviceBuffer& moe_fwd, |
| DeviceBuffer& moe_gate, DeviceBuffer& moe_up, DeviceBuffer& moe_down, |
| DeviceBuffer& moe_packed, DeviceBuffer& moe_weighted, DeviceBuffer& moe_out, |
| DeviceBuffer& moe_norm_sum, |
| DeviceBuffer& x_buf_a, DeviceBuffer& x_buf_b) { |
| (void)self; |
| const int64_t D = cfg.hidden_size; |
| const int64_t Hq = cfg.n_heads_per_rank, Hkv = cfg.n_kv_heads_per_rank; |
| const int64_t Dh = cfg.head_dim; |
| const int64_t Q_DIM = Hq * Dh; |
| const int64_t KV_DIM = Hkv * Dh; |
| const int64_t I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok; |
| const int64_t TOTAL = S * K; |
|
|
| ensure_sc_(q_sc, S * Q_DIM * 2); |
| ensure_sc_(k_sc, S * KV_DIM * 2); |
| ensure_sc_(v_sc, S * KV_DIM * 2); |
| ensure_sc_(xn_sc, S * D * 2); |
| ensure_sc_(rstd_sc, S * std::max(Hq, Hkv) * 4); |
| ensure_sc_(rope_sc, 1 * S * Hq * Dh * 2); |
| ensure_sc_(attn_fias_sc, S * Q_DIM * 2); |
| ensure_sc_(attn_out_sc, S * D * 2); |
|
|
| ensure_sc_(moe_xn, S * D * 2); |
| ensure_sc_(moe_rstd, S * 4); |
| ensure_sc_(moe_logits, S * E * 2); |
| ensure_sc_(moe_topk_w, S * K * 2); |
| ensure_sc_(moe_topk_idx, S * K * 4); |
| ensure_sc_(moe_row_idx, S * K * 4); |
| ensure_sc_(moe_ex_x, TOTAL * D * 2); |
| ensure_sc_(moe_ex_ri, TOTAL * 4); |
| ensure_sc_(moe_tpe, E * 8); |
| ensure_sc_(moe_fwd, TOTAL * 8); |
| ensure_sc_(moe_gate, TOTAL * I * 2); |
| ensure_sc_(moe_up, TOTAL * I * 2); |
| ensure_sc_(moe_down, TOTAL * D * 2); |
| ensure_sc_(moe_packed, TOTAL * D * 2); |
| ensure_sc_(moe_weighted, S * K * D * 2); |
| ensure_sc_(moe_out, S * D * 2); |
| ensure_sc_(moe_norm_sum, S * 2); |
|
|
| ensure_sc_(x_buf_a, S * D * 2); |
| ensure_sc_(x_buf_b, S * D * 2); |
| } |
|
|
// Runs one transformer layer over S tokens:
//   x1    = x_in + attention(x_in)        (x1 lives in x_buf_a_)
//   x_out = x1 + moe(x1)
// The layer's K/V for the new tokens are written into k_cache_/v_cache_
// by attention_forward (which receives past_len_ and max_seq_).
//
// Mask selection:
//   - batch_decode_mode: an S x (past_len_+S) causal mask is built on the
//     host and uploaded, passed with sparse_mode 0.
//     NOTE(review): the identical mask is rebuilt and re-uploaded for every
//     layer of the same step -- consider hoisting to once per forward pass.
//   - prefill (S > 1, not batch decode): the fixed 2048x2048 triangular
//     mask from init() is used with sparse_mode 3.
//   - single-token decode: no mask tensor, sparse_mode stays -1.
void Runner::layer_forward_(int layer_idx, int64_t S, void* x_in, void* x_out, bool batch_decode_mode) {
    const int64_t D = cfg_.hidden_size;

    aclTensor* mask = nullptr;
    int64_t sparse_mode = -1;   // -1 = no mask passed to the attention kernel
    AclTensorPtr t_mask_ptr;    // keeps the mask tensor alive for the whole call
    if (batch_decode_mode) {
        build_batch_decode_mask_(S);
        int64_t kv_len = past_len_ + S;
        t_mask_ptr = make_contig_tensor(batch_mask_dev_.get(), ACL_BOOL, {1, 1, S, kv_len});
        mask = t_mask_ptr.get();
        sparse_mode = 0;
    } else if (S > 1) {
        // Prefill path: reuse the fixed mask uploaded in init().
        // NOTE(review): shape is hard-coded to 2048x2048 -- presumably
        // sparse_mode 3 lets the kernel extrapolate the causal pattern for
        // S > 2048; confirm against the kernel's documentation.
        t_mask_ptr = make_contig_tensor(prefill_mask_dev_.get(), ACL_BOOL, {1, 1, 2048, 2048});
        mask = t_mask_ptr.get();
        sparse_mode = 3;
    }

    // Attention block: pre-norm, QKV projection, RoPE, KV-cache update,
    // fused attention, output projection; all-reduce across TP ranks when
    // tp_size > 1 (hccl ctx passed only in that case).
    attention_forward(
        rt_.stream(), cfg_, attn_[layer_idx],
        x_in, S, past_len_,
        k_cache_[layer_idx].get(), v_cache_[layer_idx].get(), max_seq_,
        mask,
        q_sc_.get(), k_sc_.get(), v_sc_.get(),
        xn_sc_.get(), rstd_sc_.get(), rope_sc_.get(),
        attn_fias_sc_.get(),
        attn_out_sc_.get(),
        (hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
        &rope_cache_,
        sparse_mode);

    // Residual add #1: x1 = x_in + attn_out (elementwise, bf16, alpha = 1).
    auto t_x_in = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_attn_out= make_contig_tensor(attn_out_sc_.get(), ACL_BF16, {S, D});
    auto t_x1 = make_contig_tensor(x_buf_a_.get(), ACL_BF16, {S, D});
    {
        float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
        uint64_t ws = 0; aclOpExecutor* e = nullptr;
        ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x_in.get(), t_attn_out.get(), al, t_x1.get(), &ws, &e));
        // NOTE(review): `wb` is destroyed at the end of this scope while
        // aclnnAdd was only *enqueued* on the stream -- this is safe only if
        // DeviceBuffer's free is stream-ordered / synchronizing; confirm.
        DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
        ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
        aclDestroyScalar(al);
    }

    // MoE block: pre-norm, router logits, top-K selection, expert dispatch,
    // gate/up/down projections, weighted combine; all-reduce when tp_size > 1.
    moe_forward(
        rt_.stream(), cfg_, attn_[layer_idx], moe_[layer_idx],
        x_buf_a_.get(), S,
        moe_xn_.get(), moe_rstd_.get(),
        moe_logits_.get(),
        moe_topk_w_.get(), moe_topk_idx_.get(), moe_row_idx_.get(),
        moe_ex_x_.get(), moe_ex_ri_.get(), moe_tpe_.get(),
        moe_fwd_.get(),
        moe_gate_.get(), moe_up_.get(), moe_down_.get(),
        moe_packed_.get(), moe_weighted_.get(),
        moe_out_.get(),
        (hccl_ctx_.tp_size > 1) ? &hccl_ctx_ : nullptr,
        moe_norm_sum_.get());

    // Residual add #2: x_out = x1 + moe_out (same workspace pattern as above).
    auto t_moe_out = make_contig_tensor(moe_out_.get(), ACL_BF16, {S, D});
    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    {
        float a = 1.0f; aclScalar* al = aclCreateScalar(&a, ACL_FLOAT);
        uint64_t ws = 0; aclOpExecutor* e = nullptr;
        ACLNN_CHECK(aclnnAddGetWorkspaceSize(t_x1.get(), t_moe_out.get(), al, t_out.get(), &ws, &e));
        DeviceBuffer wb; if (ws > 0) wb.alloc(ws);
        ACLNN_CHECK(aclnnAdd(wb.get(), ws, e, rt_.stream()));
        aclDestroyScalar(al);
    }
}
|
|
| void Runner::final_logits_(void* hidden_last, DeviceBuffer& logits_out) { |
| |
| final_logits_batch_(hidden_last, 1, logits_out); |
| } |
|
|
// Projects S hidden-state rows to vocabulary logits:
//   hn     = rms_norm(hidden, final_norm)
//   logits = linear(hn, lm_head)          -> shape {S, vocab_size}, bf16
// `logits_out` is (re)allocated here to S * vocab_size bf16 elements.
// Work is enqueued on rt_.stream(); the caller is responsible for syncing
// before reading the result.
void Runner::final_logits_batch_(void* hidden, int64_t S, DeviceBuffer& logits_out) {
    const int64_t D = cfg_.hidden_size;
    const int64_t V = cfg_.vocab_size;

    // Scratch: normalized hidden states (bf16) and per-row rstd (fp32).
    DeviceBuffer hn(S * D * 2), rstd(S * 4);
    auto t_h = make_contig_tensor(hidden, ACL_BF16, {S, D});
    auto t_hn = make_contig_tensor(hn.get(), ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(shared_.final_norm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd.get(), ACL_FLOAT, {S});
    rms_norm(rt_.stream(), t_h.get(), t_lnw.get(), cfg_.rms_norm_eps, t_hn.get(), t_rstd.get());

    logits_out.alloc(S * V * 2);
    auto t_logits = make_contig_tensor(logits_out.get(), ACL_BF16, {S, V});
    linear_hf(rt_.stream(), t_hn.get(), shared_.lm_head.get(), ACL_BF16, V, D, t_logits.get());
}
|
|
// Runs S known tokens through the model in one batched step and returns
// logits for ALL S positions in `all_logits_out` ({S, vocab}, bf16).
// Unlike prefill(), the causal mask accounts for an existing context
// (past_len_ > 0), so this can extend a sequence in bulk (e.g. scoring or
// verifying several candidate tokens at once).  Advances past_len_ by S on
// success.
bool Runner::decode_batch(const int32_t* tokens, int64_t S, DeviceBuffer& all_logits_out) {
    if (S < 1) return false;
    if (past_len_ + S > max_seq_) {
        fprintf(stderr, "runner: decode_batch exceeds max_seq (%ld + %ld > %ld)\n",
                past_len_, S, max_seq_);
        return false;
    }
    const int64_t D = cfg_.hidden_size;

    // Grow-only scratch sizing for an S-row pass.
    ensure_all_scratch_(this, S, cfg_,
        q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
        moe_xn_, moe_rstd_, moe_logits_,
        moe_topk_w_, moe_topk_idx_, moe_row_idx_,
        moe_ex_x_, moe_ex_ri_, moe_tpe_,
        moe_fwd_,
        moe_gate_, moe_up_, moe_down_,
        moe_packed_, moe_weighted_, moe_out_,
        moe_norm_sum_,
        x_buf_a_, x_buf_b_);

    // Upload token ids (int32) and gather their embedding rows.
    DeviceBuffer tok_dev(S * 4);
    ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
    auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

    DeviceBuffer x0(S * D * 2);
    auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
    index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

    // Ping-pong buffers: each layer reads cur_in and writes cur_out.
    DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
    ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
    void* cur_in = xping.get();
    void* cur_out = xpong.get();

    // batch_decode_mode only when there is existing context; with
    // past_len_ == 0 the layers fall back to the prefill-style mask.
    for (int L = 0; L < num_layers_; L++) {
        layer_forward_(L, S, cur_in, cur_out, past_len_ > 0);
        std::swap(cur_in, cur_out);
    }
    rt_.sync();

    // After the final swap, cur_in holds the last layer's output.
    final_logits_batch_(cur_in, S, all_logits_out);
    rt_.sync();

    past_len_ += S;
    return true;
}
|
|
// Processes the S-token prompt: embeds all tokens, runs every layer once
// (filling the KV caches), and returns logits for the LAST position only
// ({1, vocab}, bf16) -- the next-token distribution.  Advances past_len_
// by S on success.
bool Runner::prefill(const int32_t* tokens, int64_t S, DeviceBuffer& logits_out) {
    if (S < 1) return false;
    if (past_len_ + S > max_seq_) {
        fprintf(stderr, "runner: prefill exceeds max_seq (%ld + %ld > %ld)\n",
                past_len_, S, max_seq_);
        return false;
    }

    // Grow-only scratch sizing for an S-row pass.
    const int64_t D = cfg_.hidden_size;
    ensure_all_scratch_(this, S, cfg_,
        q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
        moe_xn_, moe_rstd_, moe_logits_,
        moe_topk_w_, moe_topk_idx_, moe_row_idx_,
        moe_ex_x_, moe_ex_ri_, moe_tpe_,
        moe_fwd_,
        moe_gate_, moe_up_, moe_down_,
        moe_packed_, moe_weighted_, moe_out_,
        moe_norm_sum_,
        x_buf_a_, x_buf_b_);

    // Upload token ids (int32) and gather their embedding rows.
    DeviceBuffer tok_dev(S * 4);
    ACL_CHECK(aclrtMemcpy(tok_dev.get(), S*4, tokens, S*4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {S});
    auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

    DeviceBuffer x0(S * D * 2);
    auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {S, D});
    index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

    // Ping-pong buffers: each layer reads cur_in and writes cur_out.
    DeviceBuffer xping(S * D * 2), xpong(S * D * 2);
    ACL_CHECK(aclrtMemcpyAsync(xping.get(), S*D*2, x0.get(), S*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));

    // Default layer_forward_ call: S > 1 here selects the prefill mask path.
    void* cur_in = xping.get();
    void* cur_out = xpong.get();
    for (int L = 0; L < num_layers_; L++) {
        layer_forward_(L, S, cur_in, cur_out);
        std::swap(cur_in, cur_out);
    }
    rt_.sync();

    // Copy only the last row of the final hidden states (cur_in after the
    // final swap) and project it to logits.
    DeviceBuffer last(1 * D * 2);
    ACL_CHECK(aclrtMemcpy(last.get(), 1*D*2,
                          (char*)cur_in + (S - 1) * D * 2, 1*D*2,
                          ACL_MEMCPY_DEVICE_TO_DEVICE));
    final_logits_(last.get(), logits_out);
    rt_.sync();

    past_len_ += S;
    return true;
}
|
|
// Single autoregressive step: embeds `token`, runs all layers against the
// current KV caches (attention span = past_len_ + 1), and writes the
// next-token logits ({1, vocab}, bf16) into `logits_out`.  Advances
// past_len_ by 1 on success.
//
// When profile_enabled, wall-clock time for the three phases
// (embed / layers / final+lm_head) is accumulated into t_*_ms and
// profile_calls for print_profile_summary().
bool Runner::decode(int32_t token, DeviceBuffer& logits_out) {
    const int64_t D = cfg_.hidden_size;
    if (past_len_ + 1 > max_seq_) {
        fprintf(stderr, "runner: decode exceeds max_seq\n");
        return false;
    }

    // Grow-only scratch sizing for a single-row pass.
    const int64_t S = 1;
    ensure_all_scratch_(this, S, cfg_,
        q_sc_, k_sc_, v_sc_, xn_sc_, rstd_sc_, rope_sc_, attn_fias_sc_, attn_out_sc_,
        moe_xn_, moe_rstd_, moe_logits_,
        moe_topk_w_, moe_topk_idx_, moe_row_idx_,
        moe_ex_x_, moe_ex_ri_, moe_tpe_,
        moe_fwd_,
        moe_gate_, moe_up_, moe_down_,
        moe_packed_, moe_weighted_, moe_out_,
        moe_norm_sum_,
        x_buf_a_, x_buf_b_);

    // Upload the single token id and look up its embedding row.
    DeviceBuffer tok_dev(1 * 4);
    ACL_CHECK(aclrtMemcpy(tok_dev.get(), 4, &token, 4, ACL_MEMCPY_HOST_TO_DEVICE));
    auto t_tok = make_contig_tensor(tok_dev.get(), ACL_INT32, {1});
    auto t_embed_w = make_contig_tensor(shared_.embed_tokens.get(), ACL_BF16, {cfg_.vocab_size, D});

    auto t0 = std::chrono::steady_clock::now();

    DeviceBuffer x0(1 * D * 2);
    auto t_x0 = make_contig_tensor(x0.get(), ACL_BF16, {1, D});
    index_select(rt_.stream(), t_embed_w.get(), 0, t_tok.get(), t_x0.get());

    // Ping-pong buffers: each layer reads cur_in and writes cur_out.
    DeviceBuffer xping(1 * D * 2), xpong(1 * D * 2);
    ACL_CHECK(aclrtMemcpyAsync(xping.get(), 1*D*2, x0.get(), 1*D*2, ACL_MEMCPY_DEVICE_TO_DEVICE, rt_.stream()));
    // Sync only under profiling so t1 bounds the embed phase alone; the
    // unprofiled path stays fully asynchronous here.
    if (profile_enabled) { ACL_CHECK(aclrtSynchronizeStream(rt_.stream())); }
    auto t1 = std::chrono::steady_clock::now();

    // S == 1 and default batch_decode_mode: layers run maskless decode.
    void* cur_in = xping.get();
    void* cur_out = xpong.get();
    for (int L = 0; L < num_layers_; L++) {
        layer_forward_(L, 1, cur_in, cur_out);
        std::swap(cur_in, cur_out);
    }
    rt_.sync();
    auto t2 = std::chrono::steady_clock::now();

    // After the final swap, cur_in holds the last layer's output.
    final_logits_(cur_in, logits_out);
    rt_.sync();
    auto t3 = std::chrono::steady_clock::now();

    if (profile_enabled) {
        using ms = std::chrono::duration<double, std::milli>;
        t_embed_ms += ms(t1 - t0).count();
        t_layers_ms += ms(t2 - t1).count();
        t_final_ms += ms(t3 - t2).count();
        profile_calls++;
    }

    past_len_ += 1;
    return true;
}
|
|
| void Runner::build_batch_decode_mask_(int64_t S) { |
| int64_t kv_len = past_len_ + S; |
| size_t bytes = (size_t)S * kv_len; |
| if (batch_mask_dev_.size < bytes) batch_mask_dev_.alloc(bytes); |
| std::vector<uint8_t> h_mask(bytes, 0); |
| for (int64_t i = 0; i < S; i++) { |
| |
| for (int64_t j = past_len_ + i + 1; j < kv_len; j++) { |
| h_mask[i * kv_len + j] = 1; |
| } |
| } |
| ACL_CHECK(aclrtMemcpy(batch_mask_dev_.get(), bytes, h_mask.data(), bytes, |
| ACL_MEMCPY_HOST_TO_DEVICE)); |
| } |
|
|
| void Runner::warmup(int iterations) { |
| if (num_layers_ == 0) return; |
| int64_t saved_past = past_len_; |
| past_len_ = 0; |
| int32_t dummy_tok = 0; |
| DeviceBuffer dummy_logits; |
| for (int i = 0; i < iterations; i++) { |
| past_len_ = 0; |
| if (!decode(dummy_tok, dummy_logits)) break; |
| } |
| past_len_ = saved_past; |
| fprintf(stderr, "[runner] warmup: %d iterations done\n", iterations); |
| } |
|
|
| void Runner::print_profile_summary() const { |
| if (!profile_enabled || profile_calls == 0) return; |
| double total = t_embed_ms + t_layers_ms + t_final_ms; |
| fprintf(stderr, "\n=== Runner profile (%ld decode calls) ===\n", profile_calls); |
| fprintf(stderr, " phase total_ms avg_ms/call pct\n"); |
| fprintf(stderr, " embed %8.1f %10.3f %5.1f%%\n", |
| t_embed_ms, t_embed_ms / profile_calls, 100.0 * t_embed_ms / total); |
| fprintf(stderr, " layers (x%d) %8.1f %10.3f %5.1f%% → %.3f ms/layer/call\n", |
| num_layers_, t_layers_ms, t_layers_ms / profile_calls, |
| 100.0 * t_layers_ms / total, |
| t_layers_ms / profile_calls / num_layers_); |
| fprintf(stderr, " final+lm_hd %8.1f %10.3f %5.1f%%\n", |
| t_final_ms, t_final_ms / profile_calls, 100.0 * t_final_ms / total); |
| fprintf(stderr, " total %8.1f %10.3f 100.0%%\n", |
| total, total / profile_calls); |
| } |
|
|