// engine.h - single-layer forward functions for attention and MoE.
//
// Both functions operate on device tensors. The caller owns all buffers (input, output, weights,
// KV cache slots, scratch). They take RoPE cos/sin tables and act as pure forward kernels.
//
// Design goals:
//   - Zero allocations per call (all scratch is passed in)
//   - Same signature works for prefill (S>=1) and decode (S=1); caller picks sparse_mode.
//   - Residual connection is NOT included (caller decides when to add residual).
#pragma once
#include "acl_common.h"
#include "aclnn_ops.h"
#include "device_weights.h"
#include "hccl_comm.h"
#include "model_config.h"
#include "rope.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <tuple>
#include <vector>

// BF16 conversion helper (round-to-nearest-even) used by fill_cos_sin_hf.
static inline uint16_t _engine_f2bf16(float x) {
    uint32_t u; std::memcpy(&u, &x, 4);
    return (uint16_t)((u + 0x7FFF + ((u >> 16) & 1)) >> 16);
}
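
// Quick sanity check of the rounding above (illustrative values, easy to verify by hand):
//   _engine_f2bf16(1.0f)        == 0x3F80   // 0x3F800000 truncates exactly
//   _engine_f2bf16(1.00390625f) == 0x3F80   // 0x3F808000 is a halfway tie; rounds to even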

// Fill cos/sin tables for positions [p0, p0+L) with the HF half-half layout. Writes
// contiguous [L*Dh] BF16 into the provided host vectors (caller uploads to device).
inline void fill_cos_sin_hf(std::vector<uint16_t>& cos_h, std::vector<uint16_t>& sin_h,
                            int64_t p0, int64_t L, int64_t Dh, float theta) {
    cos_h.resize(L * Dh);
    sin_h.resize(L * Dh);
    int64_t half = Dh / 2;
    for (int64_t s = 0; s < L; s++) {
        for (int64_t d = 0; d < Dh; d++) {
            int64_t pair = (d < half) ? d : (d - half);
            float theta_pair = 1.0f / std::pow(theta, (2.0f * pair) / Dh);
            float angle = (float)(p0 + s) * theta_pair;
            cos_h[s * Dh + d] = _engine_f2bf16(std::cos(angle));
            sin_h[s * Dh + d] = _engine_f2bf16(std::sin(angle));
        }
    }
}
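
// These tables assume the standard HF "rotate_half" convention; for a head vector q[0..Dh)
// the rotation the RoPE kernels are expected to apply is (a sketch, not the device code):
//   half = Dh / 2;
//   for d in [0, half):   q'[d] = q[d] * cos[d] - q[d + half] * sin[d];
//   for d in [half, Dh):  q'[d] = q[d] * cos[d] + q[d - half] * sin[d];
// i.e. q * cos + rotate_half(q) * sin with rotate_half(q) = concat(-q[half:], q[:half]).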

// Precomputed RoPE cos/sin table: BF16 [max_seq, Dh]. Built once at startup and reused on every call.
struct RopeCache {
    DeviceBuffer cos;   // [max_seq, Dh] BF16
    DeviceBuffer sin;   // [max_seq, Dh] BF16
    int64_t max_seq = 0;
    int64_t head_dim = 0;
    float theta = 0.0f;
};

inline bool rope_cache_build(RopeCache& rc, int64_t max_seq, int64_t head_dim, float theta) {
    std::vector<uint16_t> cos_h, sin_h;
    fill_cos_sin_hf(cos_h, sin_h, /*p0=*/0, max_seq, head_dim, theta);
    rc.cos.alloc(max_seq * head_dim * 2);
    rc.sin.alloc(max_seq * head_dim * 2);
    ACL_CHECK(aclrtMemcpy(rc.cos.get(), cos_h.size() * 2, cos_h.data(), cos_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    ACL_CHECK(aclrtMemcpy(rc.sin.get(), sin_h.size() * 2, sin_h.data(), sin_h.size() * 2, ACL_MEMCPY_HOST_TO_DEVICE));
    rc.max_seq = max_seq; rc.head_dim = head_dim; rc.theta = theta;
    return true;
}
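
// Typical setup (illustrative; `max_ctx` and the cfg fields used here are assumptions of
// this sketch, not requirements of the API):
//   RopeCache rc;
//   rope_cache_build(rc, /*max_seq=*/max_ctx, cfg.head_dim, cfg.rope_theta);
//   // ... then pass &rc to attention_forward to skip the per-call H2D cos/sin upload.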

// Attention forward for a single layer.
//
//   x_in  [S, D]    (hidden state, pre input_layernorm)
//   x_out [S, D]    (attention output - NOT residual-added)
//
// K cache / V cache are contiguous [MAX_LEN, KV_DIM] BF16 buffers. This call writes new
// positions at [past_len, past_len+S) and then runs FIAS over [0, past_len+S).
//
// Scratch requirements:
//   q_scratch:    S * Q_DIM * 2 bytes
//   k_scratch:    S * KV_DIM * 2 bytes
//   v_scratch:    S * KV_DIM * 2 bytes
//   xn_scratch:   S * D * 2 bytes
//   rstd_scratch: S * max(Hq, Hkv) * 4 bytes (RmsNorm rstd output; reused for per-head q/k norm)
//   rope_scratch: S * Hq * Dh * 2 bytes
//
// mask: [1, 1, 2048, 2048] bool for prefill; ignored (pass nullptr) for decode.
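//
// One plausible host-side construction of that prefill mask (a sketch - the bool
// convention is assumed here to be true = position masked, i.e. the upper triangle is
// blocked for causal attention; verify against fused_infer_attention_score):
//   std::vector<uint8_t> mask_h(2048 * 2048);
//   for (int i = 0; i < 2048; i++)
//       for (int j = 0; j < 2048; j++)
//           mask_h[i * 2048 + j] = (j > i);   // 1 = do not attend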
inline void attention_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& w,
    void* x_in,                 // [S, D] BF16
    int64_t S,
    int64_t past_len,           // prior KV positions
    void* k_cache, void* v_cache, int64_t max_len,
    aclTensor* mask_tensor,     // may be nullptr for decode
    void* q_scratch, void* k_scratch, void* v_scratch,
    void* xn_scratch, void* rstd_scratch, void* rope_scratch,
    void* attn_out_scratch,     // S * Q_DIM * 2 bytes (FIAS output before o_proj)
    void* x_out,                // [S, D] BF16
    HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce x_out after o_proj
    const RopeCache* rope_cache = nullptr,  // if provided, use cached cos/sin table; avoids per-call H2D
    int64_t sparse_mode = -1    // -1=auto (3 for prefill, 0 for decode); explicit 0/3 overrides
) {
    const int64_t D = cfg.hidden_size;
    const int64_t Hq = cfg.n_heads_per_rank;
    const int64_t Hkv = cfg.n_kv_heads_per_rank;
    const int64_t Dh = cfg.head_dim;
    const int64_t Q_DIM = Hq * Dh;
    const int64_t KV_DIM = Hkv * Dh;
    const double scale = 1.0 / std::sqrt((double)Dh);
    const double eps = cfg.rms_norm_eps;
    const float theta = cfg.rope_theta;

    // 1. Input layernorm: xn = rmsnorm(x_in, input_layernorm_weight)
    auto t_x   = make_contig_tensor(x_in,         ACL_BF16, {S, D});
    auto t_xn  = make_contig_tensor(xn_scratch,   ACL_BF16, {S, D});
    auto t_lnw = make_contig_tensor(w.input_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // 2. Q/K/V projection
    auto t_q = make_contig_tensor(q_scratch, ACL_BF16, {S, Q_DIM});
    auto t_k = make_contig_tensor(k_scratch, ACL_BF16, {S, KV_DIM});
    auto t_v = make_contig_tensor(v_scratch, ACL_BF16, {S, KV_DIM});
    linear_hf(stream, t_xn.get(), w.q_proj.get(), ACL_BF16, Q_DIM,  D, t_q.get());
    linear_hf(stream, t_xn.get(), w.k_proj.get(), ACL_BF16, KV_DIM, D, t_k.get());
    linear_hf(stream, t_xn.get(), w.v_proj.get(), ACL_BF16, KV_DIM, D, t_v.get());

    // 3. Per-head q_norm, k_norm
    auto t_q_4d = make_contig_tensor(q_scratch, ACL_BF16, {1, S, Hq,  Dh});
    auto t_k_4d = make_contig_tensor(k_scratch, ACL_BF16, {1, S, Hkv, Dh});
    auto t_qn_w = make_contig_tensor(w.q_norm.get(), ACL_BF16, {Dh});
    auto t_kn_w = make_contig_tensor(w.k_norm.get(), ACL_BF16, {Dh});
    // rstd_scratch is reused here for the per-head rstd outputs, so the caller must size it
    // S * max(Hq, Hkv) * 4 bytes (see the scratch requirements above).
    auto t_rstd_q = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hq});
    auto t_rstd_k = make_contig_tensor(rstd_scratch, ACL_FLOAT, {1, S, Hkv});
    rms_norm(stream, t_q_4d.get(), t_qn_w.get(), eps, t_q_4d.get(), t_rstd_q.get());
    rms_norm(stream, t_k_4d.get(), t_kn_w.get(), eps, t_k_4d.get(), t_rstd_k.get());

    // 4. RoPE: positions [past_len, past_len + S). Fused aclnnApplyRotaryPosEmbV2 is 1 op
    // vs the 8-op manual version - saves ~7 kernel launches/layer x 94 layers = 658/token.
    if (rope_cache && rope_cache->cos.get() && past_len + S <= rope_cache->max_seq) {
        void* cos_ptr = (char*)rope_cache->cos.get() + past_len * Dh * 2;
        void* sin_ptr = (char*)rope_cache->sin.get() + past_len * Dh * 2;
        apply_rope_fused(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv, cos_ptr, sin_ptr);
    } else {
        std::vector<uint16_t> cos_h, sin_h;
        fill_cos_sin_hf(cos_h, sin_h, past_len, S, Dh, theta);
        DeviceBuffer cos_dev(S * Dh * 2), sin_dev(S * Dh * 2);
        ACL_CHECK(aclrtMemcpy(cos_dev.get(), S*Dh*2, cos_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
        ACL_CHECK(aclrtMemcpy(sin_dev.get(), S*Dh*2, sin_h.data(), S*Dh*2, ACL_MEMCPY_HOST_TO_DEVICE));
        apply_rope_manual(stream, q_scratch, 1, S, Hq, Dh, k_scratch, Hkv,
                          cos_dev.get(), sin_dev.get(), rope_scratch);
        // Local DeviceBuffers would be freed on return while async kernels still read them.
        ACL_CHECK(aclrtSynchronizeStream(stream));
    }

    // 5. Append K, V to cache at [past_len, past_len + S)
    ACL_CHECK(aclrtMemcpyAsync((char*)k_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               k_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));
    ACL_CHECK(aclrtMemcpyAsync((char*)v_cache + past_len * KV_DIM * 2, S * KV_DIM * 2,
                               v_scratch, S * KV_DIM * 2, ACL_MEMCPY_DEVICE_TO_DEVICE, stream));

    // 6. FIAS: q [1, S, Q_DIM], k/v [1, kv_len, KV_DIM] from cache
    int64_t kv_len = past_len + S;
    auto t_q_bsh = make_contig_tensor(q_scratch,  ACL_BF16, {1, S,      Q_DIM});
    auto t_k_bsh = make_contig_tensor(k_cache,    ACL_BF16, {1, kv_len, KV_DIM});
    auto t_v_bsh = make_contig_tensor(v_cache,    ACL_BF16, {1, kv_len, KV_DIM});
    // FIAS writes to a separate buffer (attn_out_scratch) - aliasing q->out is unsafe.
    auto t_attn_out_bsh = make_contig_tensor(attn_out_scratch, ACL_BF16, {1, S, Q_DIM});
    // sparse_mode selection:
    //   3 = left-top causal (prefill, q.S == kv.S with 2048 mask)
    //   0 = user mask (decode with cache, batch verify)
    //  -1 (sentinel) = auto: 3 if mask given & past_len==0 & S>1 (prefill), else 0
    int64_t sparse = sparse_mode;
    if (sparse < 0) {
        sparse = (mask_tensor != nullptr && past_len == 0 && S > 1) ? 3 : 0;
    }
    fused_infer_attention_score(
        stream, t_q_bsh.get(), t_k_bsh.get(), t_v_bsh.get(),
        mask_tensor, {S}, {kv_len},
        Hq, Hkv, scale, sparse, t_attn_out_bsh.get());

    // 7. O projection: y = attn_out @ o_proj.T -> [S, D]
    auto t_attn_2d = make_contig_tensor(attn_out_scratch, ACL_BF16, {S, Q_DIM});
    auto t_out     = make_contig_tensor(x_out,     ACL_BF16, {S, D});
    linear_hf(stream, t_attn_2d.get(), w.o_proj.get(), ACL_BF16, D, Q_DIM, t_out.get());

    // 8. TP AllReduce on x_out (row-parallel o_proj -> SUM across ranks)
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
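
// Illustrative decode-step call (a sketch: buffer names are hypothetical, and in practice
// every scratch buffer is allocated once up front and reused across layers and tokens):
//   int64_t S = 1;   // one new token; Q_DIM/KV_DIM/D/Hq/Dh as defined inside the function
//   DeviceBuffer q_s(S * Q_DIM * 2), k_s(S * KV_DIM * 2), v_s(S * KV_DIM * 2);
//   DeviceBuffer xn_s(S * D * 2), rstd_s(S * Hq * 4), rope_s(S * Hq * Dh * 2), ao_s(S * Q_DIM * 2);
//   attention_forward(stream, cfg, layer_w, x, S, past_len,
//                     k_cache, v_cache, max_len, /*mask_tensor=*/nullptr,
//                     q_s.get(), k_s.get(), v_s.get(), xn_s.get(), rstd_s.get(),
//                     rope_s.get(), ao_s.get(), x_attn, &hccl, &rope_cache);
//   // residual is the caller's job: x = x + x_attn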

// MoE forward for a single layer. Residual NOT applied here.
//
//   x_in  [S, D]    (hidden state, pre post_attention_layernorm)
//   x_out [S, D]    (MoE output)
//
// Scratch:
//   xn_scratch:         S * D * 2
//   rstd_scratch:       S * 4
//   logits_scratch:     S * E * 2
//   topk_w_scratch:     S * K * 2
//   topk_idx_scratch:   S * K * 4
//   row_idx_scratch:    S * K * 4  (gating output unused)
//   expanded_x_scratch: TOTAL * D * 2
//   expanded_ri_scratch: TOTAL * 4
//   tpe_scratch:        E * 8
//   fwd_scratch:        TOTAL * 8
//   gate_out_scratch:   TOTAL * I * 2
//   up_out_scratch:     TOTAL * I * 2
//   down_out_scratch:   TOTAL * D * 2
//   packed_scratch:     TOTAL * D * 2
//   weighted_scratch:   S * K * D * 2
//
// where TOTAL = S * K, I = cfg.i_per_rank, E = cfg.num_experts, K = cfg.num_experts_per_tok.
//
// IMPORTANT: the post_attention_layernorm weight lives in `attn_w` (not in LayerMoEWeights).
inline void moe_forward(
    aclrtStream stream,
    const ModelConfig& cfg,
    LayerAttnWeights& attn_w,            // for post_attention_layernorm
    LayerMoEWeights& w,
    void* x_in, int64_t S,
    void* xn_scratch, void* rstd_scratch,
    void* logits_scratch,
    void* topk_w_scratch, void* topk_idx_scratch, void* row_idx_scratch,
    void* expanded_x_scratch, void* expanded_ri_scratch, void* tpe_scratch,
    void* fwd_scratch,
    void* gate_out_scratch, void* up_out_scratch, void* down_out_scratch,
    void* packed_scratch, void* weighted_scratch,
    void* x_out,
    HcclCtx* hccl_ctx = nullptr, // if tp_size > 1, AllReduce after reduce_sum
    void* norm_sum_scratch = nullptr  // S * 2 bytes - persistent buffer for topk_w normalize
) {
    const int64_t D = cfg.hidden_size;
    const int64_t I = cfg.i_per_rank;
    const int64_t E = cfg.num_experts;
    const int64_t K = cfg.num_experts_per_tok;
    const double eps = cfg.rms_norm_eps;
    const int64_t TOTAL = S * K;

    // 1. post_attention_layernorm
    auto t_x    = make_contig_tensor(x_in,         ACL_BF16, {S, D});
    auto t_xn   = make_contig_tensor(xn_scratch,   ACL_BF16, {S, D});
    auto t_lnw  = make_contig_tensor(attn_w.post_attention_layernorm.get(), ACL_BF16, {D});
    auto t_rstd = make_contig_tensor(rstd_scratch, ACL_FLOAT, {S});
    rms_norm(stream, t_x.get(), t_lnw.get(), eps, t_xn.get(), t_rstd.get());

    // 2. Router linear: logits = xn @ router.T -> [S, E]
    auto t_logits = make_contig_tensor(logits_scratch, ACL_BF16, {S, E});
    linear_hf(stream, t_xn.get(), w.router.get(), ACL_BF16, E, D, t_logits.get());

    // 3. TopK softmax
    auto t_topk_w   = make_contig_tensor(topk_w_scratch,   ACL_BF16,  {S, K});
    auto t_topk_idx = make_contig_tensor(topk_idx_scratch, ACL_INT32, {S, K});
    auto t_row_idx  = make_contig_tensor(row_idx_scratch,  ACL_INT32, {S, K});
    moe_gating_topk_softmax(stream, t_logits.get(), K, t_topk_w.get(), t_topk_idx.get(), t_row_idx.get());

    // 4. Device-side normalize topk weights (Qwen3 norm_topk_prob=true).
    //   sum = reduce_sum(topk_w, dim=-1, keepdim=true)   # [S, 1]  F32 in rstd_scratch
    //   sum += 1e-20
    //   sum_bf16 = cast(sum, BF16)                        # [S, 1]  in norm_sum_scratch (caller-owned)
    //   topk_w /= sum_bf16                                # broadcast divide
    // No per-layer syncs - all scratch buffers persist across layers.
    if (norm_sum_scratch) {
        auto t_sum      = make_contig_tensor(rstd_scratch,      ACL_FLOAT, {S, 1});
        auto t_sum_bf16 = make_contig_tensor(norm_sum_scratch,  ACL_BF16,  {S, 1});
        reduce_sum(stream, t_topk_w.get(), {-1}, /*keep_dims=*/true, ACL_FLOAT, t_sum.get());
        inplace_adds(stream, t_sum.get(), 1e-20);
        cast(stream, t_sum.get(), ACL_BF16, t_sum_bf16.get());
        div_tensor(stream, t_topk_w.get(), t_sum_bf16.get(), t_topk_w.get());
    } else {
        // Fallback: host-side normalize (for callers that didn't provide scratch).
        ACL_CHECK(aclrtSynchronizeStream(stream));
        std::vector<uint16_t> h_tw(S * K);
        ACL_CHECK(aclrtMemcpy(h_tw.data(), S*K*2, topk_w_scratch, S*K*2, ACL_MEMCPY_DEVICE_TO_HOST));
        for (int s = 0; s < S; s++) {
            float sum = 0;
            for (int k = 0; k < K; k++) {
                uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
                float v; std::memcpy(&v, &u, 4);
                sum += v;
            }
            sum += 1e-20f;
            for (int k = 0; k < K; k++) {
                uint32_t u = (uint32_t)h_tw[s*K + k] << 16;
                float v; std::memcpy(&v, &u, 4);
                h_tw[s*K + k] = _engine_f2bf16(v / sum);   // same RNE rounding as the helper above
            }
        }
        ACL_CHECK(aclrtMemcpy(topk_w_scratch, S*K*2, h_tw.data(), S*K*2, ACL_MEMCPY_HOST_TO_DEVICE));
    }

    // 5. MoE init routing
    auto t_ex_x  = make_contig_tensor(expanded_x_scratch,  ACL_BF16,  {TOTAL, D});
    auto t_ex_ri = make_contig_tensor(expanded_ri_scratch, ACL_INT32, {TOTAL});
    auto t_tpe   = make_contig_tensor(tpe_scratch,         ACL_INT64, {E});
    moe_init_routing_v3(stream, t_xn.get(), t_topk_idx.get(),
                        E, TOTAL, t_ex_x.get(), t_ex_ri.get(), t_tpe.get());

    // 6. GMM gate + up
    auto t_gate_out = make_contig_tensor(gate_out_scratch, ACL_BF16, {TOTAL, I});
    auto t_up_out   = make_contig_tensor(up_out_scratch,   ACL_BF16, {TOTAL, I});
    auto t_w_gate   = make_contig_tensor(w.gate_exps.get(), ACL_BF16, {E, D, I});
    auto t_w_up     = make_contig_tensor(w.up_exps.get(),   ACL_BF16, {E, D, I});
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_gate.get(), t_tpe.get(), t_gate_out.get(), 1);
    grouped_matmul_v4(stream, t_ex_x.get(), t_w_up.get(),   t_tpe.get(), t_up_out.get(),   1);

    // 7. SwiGLU: gate_out = silu(gate_out) * up_out
    silu(stream, t_gate_out.get(), t_gate_out.get());
    mul(stream, t_gate_out.get(), t_up_out.get(), t_gate_out.get());

    // 8. GMM down
    auto t_down_out = make_contig_tensor(down_out_scratch, ACL_BF16, {TOTAL, D});
    auto t_w_down   = make_contig_tensor(w.down_exps.get(), ACL_BF16, {E, I, D});
    grouped_matmul_v4(stream, t_gate_out.get(), t_w_down.get(), t_tpe.get(), t_down_out.get(), 1);

    // 9. Device-side finalize: build forward perm via two consecutive argsorts on topk_idx.
    // No host sync - safe for graph capture.
    //   inv_fwd = argsort(topk_idx.flat)   // sorted position p -> original flat index (n*K + k)
    //   fwd     = argsort(inv_fwd)         // original flat index -> its sorted position (what IndexSelect needs)
    // Stability: aclnnArgsort preserves input order for equal keys; flat index = n*K + k orders
    // ties by n-then-k, matching our previous manual sort convention.
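    //
    // Worked example (illustrative): S=1, K=3, topk_idx.flat = [2, 0, 1]
    //   inv_fwd = argsort([2, 0, 1]) = [1, 2, 0]
    //   fwd     = argsort([1, 2, 0]) = [2, 0, 1]
    // so packed[n] = down_out[fwd[n]] restores token order from the expert-sorted rows.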
    //
    // Scratch for inv_fwd: reuse first TOTAL*8 bytes of weighted_scratch (gets overwritten
    // by the subsequent mul op, so aliasing is safe).
    {
        auto t_topk_idx_flat = make_contig_tensor(topk_idx_scratch,  ACL_INT32, {TOTAL});
        auto t_inv_fwd       = make_contig_tensor(weighted_scratch,  ACL_INT64, {TOTAL});
        auto t_fwd_64        = make_contig_tensor(fwd_scratch,       ACL_INT64, {TOTAL});
        argsort(stream, t_topk_idx_flat.get(), /*dim=*/0, /*descending=*/false, t_inv_fwd.get());
        argsort(stream, t_inv_fwd.get(),       /*dim=*/0, /*descending=*/false, t_fwd_64.get());
    }
    auto t_fwd = make_contig_tensor(fwd_scratch, ACL_INT64, {TOTAL});
    auto t_packed = make_contig_tensor(packed_scratch, ACL_BF16, {TOTAL, D});
    index_select(stream, t_down_out.get(), 0, t_fwd.get(), t_packed.get());

    auto t_packed_3d = make_contig_tensor(packed_scratch, ACL_BF16, {S, K, D});
    auto t_topk_w_3d = make_contig_tensor(topk_w_scratch, ACL_BF16, {S, K, 1});
    auto t_weighted  = make_contig_tensor(weighted_scratch, ACL_BF16, {S, K, D});
    mul(stream, t_packed_3d.get(), t_topk_w_3d.get(), t_weighted.get());

    auto t_out = make_contig_tensor(x_out, ACL_BF16, {S, D});
    reduce_sum(stream, t_weighted.get(), {1}, false, ACL_BF16, t_out.get());

    // TP AllReduce on MoE output (column-parallel experts -> SUM partial intermediate outputs)
    if (hccl_ctx && hccl_ctx->tp_size > 1) {
        hccl_allreduce_bf16(*hccl_ctx, x_out, S * D, stream);
    }
}
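
// Putting the two halves together: the per-layer loop belongs to the caller, which is also
// where the residual additions live (a sketch - `add_inplace`, the weight arrays, and the
// buffer names are assumptions of this example, not part of this header):
//   for (int l = 0; l < cfg.num_layers; l++) {
//       attention_forward(stream, cfg, attn_w[l], x, S, past_len, k_cache[l], v_cache[l],
//                         max_len, mask, /* ...scratch... */, x_attn, &hccl, &rope_cache);
//       add_inplace(stream, x, x_attn, S * D);   // x += attention output
//       moe_forward(stream, cfg, attn_w[l], moe_w[l], x, S, /* ...scratch... */,
//                   x_moe, &hccl, norm_sum);
//       add_inplace(stream, x, x_moe, S * D);    // x += MoE output
//   }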