#include "device_weights.h"
#include "aclnn_ops.h"
#include <cstdio>
#include <cstring>
#include <vector>
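// Copy an entire tensor from the host weight file into a freshly allocated
// device buffer. The copy is a synchronous host-to-device aclrtMemcpy.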
bool DeviceWeightsLoader::load_tensor_full_(const std::string& name, DeviceBuffer& buf) {
const auto* m = st_.get(name);
if (!m) { fprintf(stderr, "load_tensor_full_: missing %s\n", name.c_str()); return false; }
const void* host = st_.data_ptr(*m);
if (!host) { fprintf(stderr, "load_tensor_full_: null host ptr %s\n", name.c_str()); return false; }
buf.alloc(m->nbytes);
ACL_CHECK(aclrtMemcpy(buf.get(), m->nbytes, host, m->nbytes, ACL_MEMCPY_HOST_TO_DEVICE));
return true;
}
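// Copy rows [row_lo, row_hi) of a tensor to the device. The tensor is row-major,
// so the requested rows form one contiguous host range and a single host-to-device
// memcpy is enough (no staging buffer needed).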
bool DeviceWeightsLoader::load_tensor_row_slice_(const std::string& name,
int64_t row_lo, int64_t row_hi,
DeviceBuffer& buf) {
const auto* m = st_.get(name);
if (!m) { fprintf(stderr, "load_tensor_row_slice_: missing %s\n", name.c_str()); return false; }
if (m->shape.empty()) { fprintf(stderr, "%s: empty shape\n", name.c_str()); return false; }
int64_t D0 = m->shape[0];
if (row_hi > D0 || row_lo < 0 || row_hi <= row_lo) {
fprintf(stderr, "load_tensor_row_slice_: %s bad range [%ld,%ld) vs D0=%ld\n",
name.c_str(), row_lo, row_hi, D0);
return false;
}
size_t elem = sdtype_size(m->dtype);
size_t inner = 1;
for (size_t i = 1; i < m->shape.size(); i++) inner *= m->shape[i];
size_t row_bytes = inner * elem;
size_t slice_bytes = (row_hi - row_lo) * row_bytes;
    const auto* host = (const char*)st_.data_ptr(*m);
    if (!host) { fprintf(stderr, "load_tensor_row_slice_: null host ptr %s\n", name.c_str()); return false; }
buf.alloc(slice_bytes);
ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes,
host + row_lo * row_bytes, slice_bytes,
ACL_MEMCPY_HOST_TO_DEVICE));
return true;
}
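// Copy columns [col_lo, col_hi) of a 2-D tensor to the device. The selected columns
// are not contiguous on the host (source row stride is D1, destination row stride is
// col_hi - col_lo), so each row is packed into a host staging buffer first and the
// packed slice is sent to the device in one transfer.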
bool DeviceWeightsLoader::load_tensor_col_slice_(const std::string& name,
int64_t col_lo, int64_t col_hi,
DeviceBuffer& buf) {
const auto* m = st_.get(name);
if (!m || m->shape.size() < 2) {
fprintf(stderr, "load_tensor_col_slice_: bad shape %s\n", name.c_str()); return false;
}
int64_t D0 = m->shape[0];
int64_t D1 = m->shape[1];
if (col_hi > D1 || col_lo < 0 || col_hi <= col_lo) {
fprintf(stderr, "load_tensor_col_slice_: bad range %ld-%ld D1=%ld\n",
col_lo, col_hi, D1); return false;
}
size_t elem = sdtype_size(m->dtype);
int64_t new_cols = col_hi - col_lo;
size_t slice_bytes = D0 * new_cols * elem;
buf.alloc(slice_bytes);
// Need to copy row-by-row since source has stride D1 but dest has stride new_cols.
    const auto* host = (const char*)st_.data_ptr(*m);
    if (!host) { fprintf(stderr, "load_tensor_col_slice_: null host ptr %s\n", name.c_str()); return false; }
std::vector<char> staging(slice_bytes);
size_t src_row = D1 * elem;
size_t dst_row = new_cols * elem;
size_t col_off = col_lo * elem;
for (int64_t r = 0; r < D0; r++) {
std::memcpy(staging.data() + r * dst_row, host + r * src_row + col_off, dst_row);
}
ACL_CHECK(aclrtMemcpy(buf.get(), slice_bytes, staging.data(), slice_bytes,
ACL_MEMCPY_HOST_TO_DEVICE));
return true;
}
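// Load the tensors shared across all layers: token embedding, LM head and final norm.
// Each is loaded in full, i.e. replicated on every rank rather than TP-sliced.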
bool DeviceWeightsLoader::load_shared(SharedWeights& out) {
if (!load_tensor_full_("model.embed_tokens.weight", out.embed_tokens)) return false;
if (!load_tensor_full_("lm_head.weight", out.lm_head)) return false;
if (!load_tensor_full_("model.norm.weight", out.final_norm)) return false;
return true;
}
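// Load the MoE weights for layer L: the router (replicated) plus the gate/up/down expert
// projections. Each expert tensor is TP-sliced for this rank, stacked into an [E, ...]
// device buffer, then permuted so its last two dims are swapped (presumably the layout
// the expert matmul path expects; the permute itself is all this function performs).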
bool DeviceWeightsLoader::load_moe(int L, aclrtStream stream, LayerMoEWeights& out) {
const int64_t E = cfg_.num_experts;
const int64_t D = cfg_.hidden_size;
const int64_t I_full = cfg_.moe_intermediate_size;
const int64_t I_rank = cfg_.i_per_rank;
    const size_t elem = 2; // BF16; assumes expert weights are stored as bf16 (a dtype mismatch is caught by the slice-size check below)
auto base = "model.layers." + std::to_string(L);
// 1. Router [E, D] — small, fully replicated
if (!load_tensor_full_(base + ".mlp.gate.weight", out.router)) return false;
// 2. MoE expert weights: need to stack 128 experts + TP slice + permute
// HF gate/up: each expert [I_full, D] → TP slice rows to [I_rank, D]
// HF down: each expert [D, I_full] → TP slice cols to [D, I_rank]
auto load_and_stack = [&](const std::string& proj_name,
bool is_down, DeviceBuffer& final_buf) -> bool {
// HF shape for gate/up: [I_full, D]; for down: [D, I_full]
// After TP slice: gate/up rows [I_rank, D]; down cols [D, I_rank]
// Stacked:
// gate/up: [E, I_rank, D] → permute to [E, D, I_rank]
// down: [E, D, I_rank] → permute to [E, I_rank, D]
int64_t K_in, N_out;
bool row_slice;
if (!is_down) {
K_in = I_rank; // HF first dim after row-slice
N_out = D;
row_slice = true;
} else {
K_in = D;
N_out = I_rank;
row_slice = false; // col slice
}
// Stage: stacked HF-layout [E, K_in, N_out] on device (before permute)
size_t elem_stack = K_in * N_out * elem;
DeviceBuffer stacked_hf(E * elem_stack);
// For each expert, load + TP slice + memcpy to stacked_hf[e]
// We use the existing row_slice/col_slice helpers on a per-expert basis.
DeviceBuffer tmp;
for (int e = 0; e < E; e++) {
std::string name = base + ".mlp.experts." + std::to_string(e) + "." + proj_name + ".weight";
            // gate/up slice rows of I_full and down slices cols of I_full, but the
            // per-rank index range over the intermediate dimension is the same.
            int64_t lo = cfg_.tp_rank * I_rank;
            int64_t hi = lo + I_rank;
            if (row_slice) {
                if (!load_tensor_row_slice_(name, lo, hi, tmp)) return false;
            } else {
                if (!load_tensor_col_slice_(name, lo, hi, tmp)) return false;
            }
if (tmp.size != elem_stack) {
fprintf(stderr, "load_moe: expert %d %s slice size %zu != expected %zu\n",
e, name.c_str(), tmp.size, elem_stack);
return false;
}
// Synchronous D2D: tmp is about to be reallocated in the next iteration,
// so we cannot enqueue an async copy that would still reference it.
ACL_CHECK(aclrtMemcpy(
(char*)stacked_hf.get() + e * elem_stack, elem_stack,
tmp.get(), elem_stack,
ACL_MEMCPY_DEVICE_TO_DEVICE));
}
// Now permute stacked_hf [E, K_in, N_out] → final [E, N_out, K_in] row-major
// (swap last two dims)
final_buf.alloc(E * elem_stack);
const int64_t d0 = E, d1 = K_in, d2 = N_out;
// View stacked_hf with permuted strides pointing to same data:
// logical shape [E, N_out, K_in], strides [K_in*N_out, 1, N_out]
// (since physical is [E, K_in, N_out] row-major with strides [K_in*N_out, N_out, 1])
auto t_src = make_acl_tensor(stacked_hf.get(), ACL_BF16,
{d0, d2, d1}, // [E, N_out, K_in]
{d1 * d2, 1, d2});
auto t_dst = make_contig_tensor(final_buf.get(), ACL_BF16, {d0, d2, d1});
inplace_copy(stream, t_dst.get(), t_src.get());
// Must sync before stacked_hf goes out of scope — the inplace_copy is async and
// reads from stacked_hf's memory. If we return without syncing, DeviceBuffer's
// destructor frees stacked_hf while the permute kernel is still running, producing
// garbage in final_buf.
ACL_CHECK(aclrtSynchronizeStream(stream));
return true;
};
if (!load_and_stack("gate_proj", false, out.gate_exps)) return false;
if (!load_and_stack("up_proj", false, out.up_exps)) return false;
if (!load_and_stack("down_proj", true, out.down_exps)) return false;
return true;
}
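// Load attention weights for layer L. Norm weights are replicated. q/k/v_proj are
// row-sliced by head (column-parallel linear); o_proj is col-sliced on its input
// dimension (row-parallel linear) using the same head range as q_proj.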
bool DeviceWeightsLoader::load_attention(int L, LayerAttnWeights& out) {
auto base = "model.layers." + std::to_string(L);
if (!load_tensor_full_(base + ".input_layernorm.weight", out.input_layernorm)) return false;
if (!load_tensor_full_(base + ".post_attention_layernorm.weight", out.post_attention_layernorm)) return false;
if (!load_tensor_full_(base + ".self_attn.q_norm.weight", out.q_norm)) return false;
if (!load_tensor_full_(base + ".self_attn.k_norm.weight", out.k_norm)) return false;
const int64_t head_dim = cfg_.head_dim;
const int64_t q_full = cfg_.num_attention_heads * head_dim; // 64 * 128 = 8192
// q_proj: [q_full, D], shard rows by head. Each rank gets n_heads_per_rank heads.
int64_t q_rows_per_rank = cfg_.n_heads_per_rank * head_dim;
int64_t q_row_lo = cfg_.tp_rank * q_rows_per_rank;
int64_t q_row_hi = q_row_lo + q_rows_per_rank;
if (!load_tensor_row_slice_(base + ".self_attn.q_proj.weight",
q_row_lo, q_row_hi, out.q_proj)) return false;
// k_proj, v_proj: HF shape [num_kv * head_dim, D].
// Case A (tp <= n_kv): split rows across ranks, each rank gets n_kv/tp KV heads.
// Case B (tp > n_kv): each rank gets exactly ONE KV head; group of (tp/n_kv) ranks share it.
// kv_head_idx = tp_rank / (tp_size / n_kv)
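    // Example: tp_size=16, n_kv=8 → ranks_per_kv=2, so ranks 0-1 load KV head 0,
    // ranks 2-3 load KV head 1, etc. (This assumes tp_size is a multiple of n_kv.)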
if (cfg_.tp_size <= cfg_.num_key_value_heads) {
int64_t kv_rows_per_rank = cfg_.n_kv_heads_per_rank * head_dim;
int64_t kv_row_lo = cfg_.tp_rank * kv_rows_per_rank;
int64_t kv_row_hi = kv_row_lo + kv_rows_per_rank;
if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
} else {
// GQA replicated-group mode: 1 KV head per rank, selected by group.
int64_t ranks_per_kv = cfg_.tp_size / cfg_.num_key_value_heads;
int64_t kv_head_idx = cfg_.tp_rank / ranks_per_kv;
int64_t kv_row_lo = kv_head_idx * head_dim;
int64_t kv_row_hi = kv_row_lo + head_dim;
if (!load_tensor_row_slice_(base + ".self_attn.k_proj.weight", kv_row_lo, kv_row_hi, out.k_proj)) return false;
if (!load_tensor_row_slice_(base + ".self_attn.v_proj.weight", kv_row_lo, kv_row_hi, out.v_proj)) return false;
}
// o_proj: [D, q_full], row-parallel → shard cols (input dim) by head.
int64_t o_col_lo = q_row_lo; // same slicing as q rows
int64_t o_col_hi = q_row_hi;
if (!load_tensor_col_slice_(base + ".self_attn.o_proj.weight",
o_col_lo, o_col_hi, out.o_proj)) return false;
return true;
}