#include "device_weights.h"
#include "aclnn_ops.h"
#include <cstdio>
#include <cstring>
#include <vector>

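// Upload a whole tensor: host pointer from st_ → freshly allocated device buffer,
// one synchronous host-to-device copy.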
bool DeviceWeightsLoader::load_tensor_full_(const std::string& name, DeviceBuffer& buf) {
    const auto* m = st_.get(name);
    if (!m) { fprintf(stderr, "load_tensor_full_: missing %s\n", name.c_str()); return false; }
    const void* host = st_.data_ptr(*m);
    if (!host) { fprintf(stderr, "load_tensor_full_: null host ptr %s\n", name.c_str()); return false; }
    buf.alloc(m->nbytes);
    ACL_CHECK(aclrtMemcpy(buf.get(), m->nbytes, host, m->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
    return true;
}

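// Upload rows [row_lo, row_hi) of a row-major tensor. Leading-dim rows are
// contiguous on the host, so the slice is one contiguous range and a single
// host-to-device copy suffices.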
bool DeviceWeightsLoader::load_tensor_row_slice_(const std::string& name,
                                                 int64_t row_lo, int64_t row_hi,
                                                 DeviceBuffer& buf) {
    const auto* m = st_.get(name);
    if (!m) { fprintf(stderr, "load_tensor_row_slice_: missing %s\n", name.c_str()); return false; }
    if (m->shape.empty()) { fprintf(stderr, "%s: empty shape\n", name.c_str()); return false; }
    int64_t D0 = m->shape[0];
    if (row_hi > D0 || row_lo < 0 || row_hi <= row_lo) {
        fprintf(stderr, "load_tensor_row_slice_: %s bad range [%ld,%ld) vs D0=%ld\n",
                name.c_str(), row_lo, row_hi, D0);
        return false;
    }
    size_t elem = sdtype_size(m->dtype);
    size_t inner = 1;
    for (size_t i = 1; i < m->shape.size(); i++) inner *= m->shape[i];
    size_t row_bytes = inner * elem;
    size_t slice_bytes = (row_hi - row_lo) * row_bytes;

    const auto* host = (const char*)st_.data_ptr(*m);
    buf.alloc(slice_bytes);
    ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes,
                          host + row_lo * row_bytes, slice_bytes,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    return true;
}

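// Upload columns [col_lo, col_hi) of a 2-D row-major tensor. Columns are strided
// on the host, so each row's segment is gathered into a host staging buffer first
// and then uploaded with a single host-to-device copy.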
bool DeviceWeightsLoader::load_tensor_col_slice_(const std::string& name,
                                                 int64_t col_lo, int64_t col_hi,
                                                 DeviceBuffer& buf) {
    const auto* m = st_.get(name);
    if (!m || m->shape.size() != 2) {  // the slicing math below assumes a 2-D [D0, D1] tensor
        fprintf(stderr, "load_tensor_col_slice_: missing or non-2D tensor %s\n", name.c_str()); return false;
    }
    int64_t D0 = m->shape[0];
    int64_t D1 = m->shape[1];
    if (col_hi > D1 || col_lo < 0 || col_hi <= col_lo) {
        fprintf(stderr, "load_tensor_col_slice_: %s bad range [%ld,%ld) vs D1=%ld\n",
                name.c_str(), col_lo, col_hi, D1);
        return false;
    }
    size_t elem = sdtype_size(m->dtype);
    int64_t new_cols = col_hi - col_lo;
    size_t slice_bytes = D0 * new_cols * elem;
    buf.alloc(slice_bytes);

    // Need to copy row-by-row since source has stride D1 but dest has stride new_cols.
    const auto* host = (const char*)st_.data_ptr(*m);
    std::vector<char> staging(slice_bytes);
    size_t src_row = D1 * elem;
    size_t dst_row = new_cols * elem;
    size_t col_off = col_lo * elem;
    for (int64_t r = 0; r < D0; r++) {
        std::memcpy(staging.data() + r * dst_row, host + r * src_row + col_off, dst_row);
    }
    ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes, staging.data(), slice_bytes,
                          ACL_MEMCPY_HOST_TO_DEVICE));
    return true;
}

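// Tensors replicated on every TP rank: token embeddings, LM head, final norm.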
bool DeviceWeightsLoader::load_shared(SharedWeights& out) {
    if (!load_tensor_full_("model.embed_tokens.weight", out.embed_tokens)) return false;
    if (!load_tensor_full_("lm_head.weight",            out.lm_head))       return false;
    if (!load_tensor_full_("model.norm.weight",         out.final_norm))    return false;
    return true;
}

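// Per-layer MoE weights: the router is replicated across ranks; each expert's
// gate/up/down projection is TP-sliced, stacked into one [E, ...] device buffer
// in HF layout, then permuted so the last two dims are swapped (see below).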
bool DeviceWeightsLoader::load_moe(int L, aclrtStream stream, LayerMoEWeights& out) {
    const int64_t E = cfg_.num_experts;
    const int64_t D = cfg_.hidden_size;
    const int64_t I_rank = cfg_.i_per_rank;  // per-rank TP slice of the full moe_intermediate_size
    const size_t elem = 2;  // BF16

    auto base = "model.layers." + std::to_string(L);

    // 1. Router [E, D] — small, fully replicated
    if (!load_tensor_full_(base + ".mlp.gate.weight", out.router)) return false;

    // 2. MoE expert weights: stack all E experts + TP slice + permute.
    // HF gate/up: each expert [I_full, D] → TP slice rows to [I_rank, D]  (I_full = full moe_intermediate_size)
    // HF down:    each expert [D, I_full] → TP slice cols to [D, I_rank]

    auto load_and_stack = [&](const std::string& proj_name,
                              bool is_down, DeviceBuffer& final_buf) -> bool {
        // HF shape for gate/up: [I_full, D]; for down: [D, I_full]
        // After TP slice: gate/up rows [I_rank, D]; down cols [D, I_rank]
        // Stacked:
        //   gate/up: [E, I_rank, D] → permute to [E, D, I_rank]
        //   down:    [E, D, I_rank] → permute to [E, I_rank, D]

        int64_t K_in, N_out;
        bool row_slice;
        if (!is_down) {
            K_in = I_rank;            // HF first dim after row-slice
            N_out = D;
            row_slice = true;
        } else {
            K_in = D;
            N_out = I_rank;
            row_slice = false;        // col slice
        }

        // Stage: stacked HF-layout [E, K_in, N_out] on device (before permute)
        size_t elem_stack = K_in * N_out * elem;
        DeviceBuffer stacked_hf(E * elem_stack);

        // For each expert, load + TP slice + memcpy to stacked_hf[e]
        // We use the existing row_slice/col_slice helpers on a per-expert basis.
        DeviceBuffer tmp;
        for (int e = 0; e < E; e++) {
            std::string name = base + ".mlp.experts." + std::to_string(e) + "." + proj_name + ".weight";
            // Both gate/up (row slice) and down (col slice) use the same [lo, hi) range.
            int64_t lo = cfg_.tp_rank * I_rank;
            int64_t hi = lo + I_rank;
            bool ok = row_slice ? load_tensor_row_slice_(name, lo, hi, tmp)
                                : load_tensor_col_slice_(name, lo, hi, tmp);
            if (!ok) return false;
            if (tmp.size != elem_stack) {
                fprintf(stderr, "load_moe: expert %d %s slice size %zu != expected %zu\n",
                        e, name.c_str(), tmp.size, elem_stack);
                return false;
            }
            // Synchronous D2D: tmp is about to be reallocated in the next iteration,
            // so we cannot enqueue an async copy that would still reference it.
            ACL_CHECK(aclrtMemcpy(
                (char*)stacked_hf.get() + e * elem_stack, elem_stack,
                tmp.get(), elem_stack,
                ACL_MEMCPY_DEVICE_TO_DEVICE));
        }

        // Now permute stacked_hf [E, K_in, N_out] → final [E, N_out, K_in] row-major
        // (swap last two dims)
        final_buf.alloc(E * elem_stack);
        const int64_t d0 = E, d1 = K_in, d2 = N_out;
        // View stacked_hf with permuted strides pointing to same data:
        // logical shape [E, N_out, K_in], strides [K_in*N_out, 1, N_out]
        // (since physical is [E, K_in, N_out] row-major with strides [K_in*N_out, N_out, 1])
        auto t_src = make_acl_tensor(stacked_hf.get(), ACL_BF16,
                                     {d0, d2, d1},  // [E, N_out, K_in]
                                     {d1 * d2, 1, d2});
        auto t_dst = make_contig_tensor(final_buf.get(), ACL_BF16, {d0, d2, d1});
        inplace_copy(stream, t_dst.get(), t_src.get());
        // Must sync before stacked_hf goes out of scope — the inplace_copy is async and
        // reads from stacked_hf's memory. If we return without syncing, DeviceBuffer's
        // destructor frees stacked_hf while the permute kernel is still running, producing
        // garbage in final_buf.
        ACL_CHECK(aclrtSynchronizeStream(stream));
        return true;
    };

    if (!load_and_stack("gate_proj", false, out.gate_exps)) return false;
    if (!load_and_stack("up_proj",   false, out.up_exps))   return false;
    if (!load_and_stack("down_proj", true,  out.down_exps)) return false;

    return true;
}

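// Per-layer attention weights. Norms are replicated; q_proj is row-sliced by
// attention head, k/v_proj are row-sliced by KV head (or a single shared KV head
// is selected when tp_size > num_key_value_heads), and o_proj is column-sliced
// over its input dim using the same range as the q rows.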
bool DeviceWeightsLoader::load_attention(int L, LayerAttnWeights& out) {
    auto base = "model.layers." + std::to_string(L);

    if (!load_tensor_full_(base + ".input_layernorm.weight",          out.input_layernorm))          return false;
    if (!load_tensor_full_(base + ".post_attention_layernorm.weight", out.post_attention_layernorm)) return false;
    if (!load_tensor_full_(base + ".self_attn.q_norm.weight", out.q_norm)) return false;
    if (!load_tensor_full_(base + ".self_attn.k_norm.weight", out.k_norm)) return false;

    const int64_t head_dim = cfg_.head_dim;
    const int64_t q_full   = cfg_.num_attention_heads * head_dim;    // 64 * 128 = 8192

    // q_proj: [q_full, D], shard rows by head. Each rank gets n_heads_per_rank heads.
    int64_t q_rows_per_rank = cfg_.n_heads_per_rank * head_dim;
    int64_t q_row_lo = cfg_.tp_rank * q_rows_per_rank;
    int64_t q_row_hi = q_row_lo + q_rows_per_rank;
    if (!load_tensor_row_slice_(base + ".self_attn.q_proj.weight",
                                 q_row_lo, q_row_hi, out.q_proj)) return false;

    // k_proj, v_proj: HF shape [num_kv * head_dim, D].
    //   Case A (tp <= n_kv): split rows across ranks, each rank gets n_kv/tp KV heads.
    //   Case B (tp > n_kv):  each rank gets exactly ONE KV head; group of (tp/n_kv) ranks share it.
    //       kv_head_idx = tp_rank / (tp_size / n_kv)
    if (cfg_.tp_size <= cfg_.num_key_value_heads) {
        int64_t kv_rows_per_rank = cfg_.n_kv_heads_per_rank * head_dim;
        int64_t kv_row_lo = cfg_.tp_rank * kv_rows_per_rank;
        int64_t kv_row_hi = kv_row_lo + kv_rows_per_rank;
        if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
        if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
    } else {
        // GQA replicated-group mode: 1 KV head per rank, selected by group.
        int64_t ranks_per_kv = cfg_.tp_size / cfg_.num_key_value_heads;
        int64_t kv_head_idx = cfg_.tp_rank / ranks_per_kv;
        int64_t kv_row_lo = kv_head_idx * head_dim;
        int64_t kv_row_hi = kv_row_lo + head_dim;
        if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
        if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
    }

    // o_proj: [D, q_full], row-parallel → shard cols (input dim) by head.
    int64_t o_col_lo = q_row_lo;  // same slicing as q rows
    int64_t o_col_hi = q_row_hi;
    if (!load_tensor_col_slice_(base + ".self_attn.o_proj.weight",
                                 o_col_lo, o_col_hi, out.o_proj)) return false;

    return true;
}