#pragma once
#include "acl_common.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "rope.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <tuple>
#include <vector>

// Host-side float32 -> bfloat16 conversion with round-to-nearest-even
// (NaN payloads are not special-cased).
static inline uint16_t _engine_f2bf16(float x) {
    uint32_t u; std::memcpy(&u, &x, 4);
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
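
// Illustrative counterpart (an assumption, not part of the original API):
// decode bf16 back to float32 by shifting into the high half of a float32
// bit pattern. The host-side MoE fallback below inlines this same decode.
static inline float _engine_bf162f(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float x; std::memcpy(&x, &u, 4);
    return x;
}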

// Precompute RoPE cos/sin tables in bf16 for positions [p0, p0 + L).
// Half-split pairing: element d rotates with element d + Dh/2, so both
// halves of a row store the same angle for a given pair index.
inline void fill_cos_sin_hf(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
                            int64_t p0, int64_t L, int64_t Dh, float theta) {
    cos_h.resize(L * Dh);
    sin_h.resize(L * Dh);
    int64_t half = Dh / 2;
    for (int64_t s = 0; s < L; s++) {
        for (int64_t d = 0; d < Dh; d++) {
            int64_t pair = (d < half) ? d : (d - half);
            float freq = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
            float angle = (float)(p0 + s) * freq;
            cos_h[s * Dh + d] = _engine_f2bf16(std::cos(angle));
            sin_h[s * Dh + d] = _engine_f2bf16(std::sin(angle));
        }
    }
}
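
// Worked example: with Dh = 4 and theta = 10000 the pair frequencies are
// 1/10000^(0/4) = 1.0 and 1/10000^(2/4) = 0.01, so the row for position p is
// {cos(p), cos(0.01p), cos(p), cos(0.01p)} (and likewise for sin).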

// Device-resident precomputed RoPE tables, sliced by absolute position.
struct RopeCache {
    DeviceBuffer cos;
    DeviceBuffer sin;
    int64_t max_seq = 0;
    int64_t head_dim = 0;
    float theta = 0.0f;
};

inline bool rope_cache_build(RopeCache& rc, int64_t max_seq, int64_t head_dim, float theta) {
    std::vector<uint16_t> cos_h, sin_h;
    fill_cos_sin_hf(cos_h, sin_h, 0, max_seq, head_dim, theta);
    rc.cos.alloc(max_seq * head_dim * 2);  // bf16: 2 bytes per element
    rc.sin.alloc(max_seq * head_dim * 2);
    ACL_CHECK(aclrtMemcpy(rc.cos.get(), cos_h.size() * 2, cos_h.data(), cos_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(rc.sin.get(), sin_h.size() * 2, sin_h.data(), sin_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    rc.max_seq = max_seq; rc.head_dim = head_dim; rc.theta = theta;
    return true;
}
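
// Usage sketch (illustrative; `max_model_len` stands in for whatever sequence
// bound the caller tracks and is not defined in this header):
//   RopeCache rc;
//   rope_cache_build(rc, max_model_len, cfg.head_dim, cfg.rope_theta);
//   // attention_forward below can then slice rc.cos / rc.sin at past_len.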

// One attention block: pre-norm -> Q/K/V projections -> per-head Q/K RMSNorm
// -> RoPE -> KV-cache append -> fused attention -> output projection, with an
// optional tensor-parallel all-reduce at the end. All activations are bf16.
inline void attention_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& w,
    void* x_in,
    int64_t S,
    int64_t past_len,
    void* k_cache, void* v_cache, int64_t max_len,
    aclTensor* mask_tensor,
    void* q_scratch, void* k_scratch, void* v_scratch,
    void* xn_scratch, void* rstd_scratch, void* rope_scratch,
    void* attn_out_scratch,
    void* x_out,
    HcclCtx* hccl_ctx = nullptr,
    const RopeCache* rope_cache = nullptr,
    int64_t sparse_mode = -1
) {
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.n_heads_per_rank;
    const int64_t Hkv = cfg.n_kv_heads_per_rank;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh);  // standard 1/sqrt(head_dim) scaling
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    // Pre-attention RMSNorm.
    auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(w.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // Q/K/V projections (head counts are per tensor-parallel rank).
    auto t_q = make_contig_tensor(q_scratch, ACL_BF16, {S, Q_DIM});
    auto t_k = make_contig_tensor(k_scratch, ACL_BF16, {S, KV_DIM});
    auto t_v = make_contig_tensor(v_scratch, ACL_BF16, {S, KV_DIM});
    linear_hf(stream, t_xn.get(), w.q_proj.get(), ACL_BF16, Q_DIM, D, t_q.get());
    linear_hf(stream, t_xn.get(), w.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
    linear_hf(stream, t_xn.get(), w.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());

    // Per-head RMSNorm on Q and K (QK-norm), applied in place over head_dim.
    auto t_q_4d = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Hq, Dh});
    auto t_k_4d = make_contig_tensor(k_scratch, ACL_BF16, {1, S, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(w.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(w.k_norm.get(), ACL_BF16, {Dh});

    // rstd_scratch is reused for both norms; the ops are issued back-to-back
    // on the same stream, so the second write cannot race the first.
    auto t_rstd_q = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hq});
    auto t_rstd_k = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hkv});
    rms_norm(stream, t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
    rms_norm(stream, t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());

    // RoPE: slice the precomputed device table when it covers past_len + S;
    // otherwise build a one-off table on the host for this step.
    if (rope_cache && rope_cache->cos.get() && past_len + S <= rope_cache->max_seq) {
        void* cos_ptr = (char*)rope_cache->cos.get() + past_len * Dh * 2;
        void* sin_ptr = (char*)rope_cache->sin.get() + past_len * Dh * 2;
        apply_rope_fused(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv, cos_ptr, sin_ptr);
    } else {
        std::vector<uint16_t> cos_h, sin_h;
        fill_cos_sin_hf(cos_h, sin_h, past_len, S, Dh, theta);
        DeviceBuffer cos_dev(S * Dh * 2), sin_dev(S * Dh * 2);
        ACL_CHECK(aclrtMemcpy(cos_dev.get(), S * Dh * 2, cos_h.data(), S * Dh * 2, ACL_MEMCPY_HOST_TO_DEVICE));
        ACL_CHECK(aclrtMemcpy(sin_dev.get(), S * Dh * 2, sin_h.data(), S * Dh * 2, ACL_MEMCPY_HOST_TO_DEVICE));
        apply_rope_manual(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv,
                          cos_dev.get(), sin_dev.get(), rope_scratch);
        // cos_dev/sin_dev are scope-local device buffers; drain the stream
        // before they are freed at the end of this block.
        ACL_CHECK(aclrtSynchronizeStream(stream));
    }
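
// Note (assumed convention, matching the half-split tables above): the rope
// kernels rotate each pair (d, d + Dh/2) as
//   q'[d]        = q[d] * cos - q[d + Dh/2] * sin
//   q'[d + Dh/2] = q[d + Dh/2] * cos + q[d] * sin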

    // Append this step's K/V to the cache (bf16: 2 bytes per element).
    ACL_CHECK(aclrtMemcpyAsync((char*)k_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               k_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
    ACL_CHECK(aclrtMemcpyAsync((char*)v_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               v_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));

    // Fused attention over the full cache, BSH layout.
    int64_t kv_len = past_len + S;
    auto t_q_bsh = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Q_DIM});
    auto t_k_bsh = make_contig_tensor(k_cache, ACL_BF16, {1, kv_len, KV_DIM});
    auto t_v_bsh = make_contig_tensor(v_cache, ACL_BF16, {1, kv_len, KV_DIM});
    auto t_attn_out_bsh = make_contig_tensor(attn_out_scratch, ACL_BF16, {1, S, Q_DIM});

    // Sparse-mode heuristic when the caller does not override: causal masking
    // for a multi-token prefill with a mask, dense everywhere else.
    int64_t sparse = sparse_mode;
    if (sparse < 0) {
        sparse = (mask_tensor != nullptr && past_len == 0 && S > 1) ? 3 : 0;
    }
    fused_infer_attention_score(
        stream, t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
        mask_tensor, {S}, {kv_len},
        Hq, Hkv, scale, sparse, t_attn_out_bsh.get());

    // Output projection back to the hidden size.
    auto t_attn_2d = make_contig_tensor(attn_out_scratch, ACL_BF16, {S, Q_DIM});
    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    linear_hf(stream, t_attn_2d.get(), w.o_proj.get(), ACL_BF16, D, Q_DIM, t_out.get());

    // Tensor-parallel reduction across ranks.
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
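
// Call sketch for a single decode step (illustrative only; scratch allocation,
// mask construction, and the `layer` aggregate are the caller's responsibility
// and are not defined in this header):
//   attention_forward(stream, cfg, layer.attn, x, /*S=*/1, past_len,
//                     k_cache, v_cache, max_len, /*mask_tensor=*/nullptr,
//                     q_s, k_s, v_s, xn_s, rstd_s, rope_s, attn_s, x_out,
//                     &hccl, &rope_cache);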

// One MoE block: post-attention RMSNorm -> router -> top-K softmax gating ->
// token routing -> grouped SwiGLU expert matmuls -> un-permute and weighted
// combine, with an optional tensor-parallel all-reduce at the end.
inline void moe_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& attn_w,
    LayerMoEWeights& w,
    void* x_in, int64_t S,
    void* xn_scratch, void* rstd_scratch,
    void* logits_scratch,
    void* topk_w_scratch, void* topk_idx_scratch, void* row_idx_scratch,
    void* expanded_x_scratch, void* expanded_ri_scratch, void* tpe_scratch,
    void* fwd_scratch,
    void* gate_out_scratch, void* up_out_scratch, void* down_out_scratch,
    void* packed_scratch, void* weighted_scratch,
    void* x_out,
    HcclCtx* hccl_ctx = nullptr,
    void* norm_sum_scratch = nullptr
) {
    const int64_t D = cfg.hidden_size;
    const int64_t I = cfg.i_per_rank;  // expert intermediate size per rank
    const int64_t E = cfg.num_experts;
    const int64_t K = cfg.num_experts_per_tok;
    const double eps = cfg.rms_norm_eps;
    const int64_t TOTAL = S * K;       // expanded token-expert row count

    // Post-attention RMSNorm.
    auto t_x = make_contig_tensor(x_in, ACL_BF16, {S, D});
    auto t_xn = make_contig_tensor(xn_scratch, ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(attn_w.post_attention_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // Router logits over all E experts.
    auto t_logits = make_contig_tensor(logits_scratch, ACL_BF16, {S, E});
    linear_hf(stream, t_xn.get(), w.router.get(), ACL_BF16, E, D, t_logits.get());

    // Softmax + top-K gating: weights {S, K}, expert ids {S, K}.
    auto t_topk_w = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K});
    auto t_topk_idx = make_contig_tensor(topk_idx_scratch, ACL_INT32, {S, K});
    auto t_row_idx = make_contig_tensor(row_idx_scratch, ACL_INT32, {S, K});
    moe_gating_topk_softmax(stream, t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());
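
// Worked example (E = 4, K = 2): logits {1.0, 3.0, 2.0, 0.0} softmax to
// roughly {0.09, 0.64, 0.24, 0.03}; the top-2 weights {0.64, 0.24} with
// expert ids {1, 2} are then renormalized to sum to 1 below.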

    // Renormalize the top-K weights to sum to 1. Prefer the on-device path;
    // fall back to a host round-trip when no scratch buffer is provided.
    if (norm_sum_scratch) {
        auto t_sum = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S, 1});
        auto t_sum_bf16 = make_contig_tensor(norm_sum_scratch, ACL_BF16, {S, 1});
        reduce_sum(stream, t_topk_w.get(), {-1}, true, ACL_FLOAT, t_sum.get());
        inplace_adds(stream, t_sum.get(), 1e-20);  // guard against a zero sum
        cast(stream, t_sum.get(), ACL_BF16, t_sum_bf16.get());
        div_tensor(stream, t_topk_w.get(), t_sum_bf16.get(), t_topk_w.get());
    } else {
        // Host fallback: decode bf16 (shift into the high half of a float32),
        // normalize in float32, re-encode with round-to-nearest-even.
        ACL_CHECK(aclrtSynchronizeStream(stream));
        std::vector<uint16_t> h_tw(S * K);
        ACL_CHECK(aclrtMemcpy(h_tw.data(), S * K * 2, topk_w_scratch, S * K * 2, ACL_MEMCPY_DEVICE_TO_HOST));
        for (int64_t s = 0; s < S; s++) {
            float sum = 0;
            for (int64_t k = 0; k < K; k++) {
                uint32_t u = (uint32_t)h_tw[s * K + k] << 16;
                float v; std::memcpy(&v, &u, 4);
                sum += v;
            }
            sum += 1e-20f;
            for (int64_t k = 0; k < K; k++) {
                uint32_t u = (uint32_t)h_tw[s * K + k] << 16;
                float v; std::memcpy(&v, &u, 4);
                h_tw[s * K + k] = _engine_f2bf16(v / sum);
            }
        }
        ACL_CHECK(aclrtMemcpy(topk_w_scratch, S * K * 2, h_tw.data(), S * K * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    }

    // Expand tokens to {TOTAL, D} rows grouped by expert; tpe holds the
    // tokens-per-expert counts consumed by the grouped matmuls below.
    auto t_ex_x = make_contig_tensor(expanded_x_scratch, ACL_BF16, {TOTAL, D});
    auto t_ex_ri = make_contig_tensor(expanded_ri_scratch, ACL_INT32, {TOTAL});
    auto t_tpe = make_contig_tensor(tpe_scratch, ACL_INT64, {E});
    moe_init_routing_v3(stream, t_xn.get(), t_topk_idx.get(),
                        E, TOTAL, t_ex_x.get(), t_ex_ri.get(), t_tpe.get());

    // Grouped expert matmuls: gate/up projections over the expert-sorted rows.
    auto t_gate_out = make_contig_tensor(gate_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_up_out = make_contig_tensor(up_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_w_gate = make_contig_tensor(w.gate_exps.get(), ACL_BF16, {E, D, I});
    auto t_w_up = make_contig_tensor(w.up_exps.get(), ACL_BF16, {E, D, I});
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_gate.get(), t_tpe.get(), t_gate_out.get(), 1);
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_up.get(), t_tpe.get(), t_up_out.get(), 1);

    // SwiGLU activation, computed in place: gate_out = silu(gate_out) * up_out.
    silu(stream, t_gate_out.get(), t_gate_out.get());
    mul(stream, t_gate_out.get(), t_up_out.get(), t_gate_out.get());

    // Down projection back to the hidden size.
    auto t_down_out = make_contig_tensor(down_out_scratch, ACL_BF16, {TOTAL, D});
    auto t_w_down = make_contig_tensor(w.down_exps.get(), ACL_BF16, {E, I, D});
    grouped_matmul_v4(stream, t_gate_out.get(), t_w_down.get(), t_tpe.get(), t_down_out.get(), 1);

    // Un-permute: down_out rows are grouped by expert. Argsorting the flattened
    // expert ids reproduces the routing permutation (assuming the same
    // tie-breaking order as the routing kernel), and argsorting that result
    // inverts it, yielding each original row's slot in the sorted layout.
    // weighted_scratch is borrowed for the intermediate permutation here and
    // only reused as the {S, K, D} weighted buffer further down.
    {
        auto t_topk_idx_flat = make_contig_tensor(topk_idx_scratch, ACL_INT32, {TOTAL});
        auto t_inv_fwd = make_contig_tensor(weighted_scratch, ACL_INT64, {TOTAL});
        auto t_fwd_64 = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
        argsort(stream, t_topk_idx_flat.get(), 0, false, t_inv_fwd.get());
        argsort(stream, t_inv_fwd.get(), 0, false, t_fwd_64.get());
    }
    auto t_fwd = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
    auto t_packed = make_contig_tensor(packed_scratch, ACL_BF16, {TOTAL, D});
    index_select(stream, t_down_out.get(), 0, t_fwd.get(), t_packed.get());
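
// Worked example (S = 2, K = 2): topk_idx flattened = {2, 0, 1, 2}
//   argsort           -> {1, 2, 0, 3}  (rows in expert-sorted order)
//   argsort(argsort)  -> {2, 0, 1, 3}  (row i's slot in the sorted layout)
// index_select with the second permutation restores original token order.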

    // Scale each expert output by its gate weight and sum over the K experts.
    auto t_packed_3d = make_contig_tensor(packed_scratch, ACL_BF16, {S, K, D});
    auto t_topk_w_3d = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K, 1});
    auto t_weighted = make_contig_tensor(weighted_scratch, ACL_BF16, {S, K, D});
    mul(stream, t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());

    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    reduce_sum(stream, t_weighted.get(), {1}, false, ACL_BF16, t_out.get());

    // Tensor-parallel reduction across ranks.
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}