// device_weights.h — load safetensors weights to device memory with proper TP sharding.
//
// For M3 (attention only): loads attention + norm weights. MoE expert weights come in M4.
//
// Per-layer MoE weights on device (BF16).
// After loading, each expert tensor is in GMM-ready layout [E, K_in, N_out],
// row-major contiguous, so a grouped matmul can consume it without further permutes:
//   gate/up: K_in = D,          N_out = I_per_rank
//   down:    K_in = I_per_rank, N_out = D
struct LayerMoEWeights {
  DeviceBuffer router;     // [E, D] BF16, replicated on every TP rank
  DeviceBuffer gate_exps;  // [E, D, I_per_rank] (permuted from HF [E, I, D]; I sharded across TP)
  DeviceBuffer up_exps;    // [E, D, I_per_rank] (same permutation/shard as gate)
  DeviceBuffer down_exps;  // [E, I_per_rank, D] (permuted from HF [E, D, I])
};
// Per-layer attention weights on device (BF16 unless noted).
struct LayerAttnWeights {
  DeviceBuffer input_layernorm;           // [D] BF16
  DeviceBuffer post_attention_layernorm;  // [D] BF16
  // Q/K/V/O projections. HF stores projection weights as [out, in] BF16.
  // For M3 we keep the HF layout as-is; the matmul wrappers handle the
  // transpose via aclnnMm semantics, so no device-side permute is needed.
  DeviceBuffer q_proj;  // logically [Q_full, D]; physically [Q_rank, D] (column-parallel, sliced by head)
  DeviceBuffer k_proj;  // [KV, D] (replicated if tp_size > num_kv_heads)
  DeviceBuffer v_proj;  // [KV, D]
  DeviceBuffer o_proj;  // [D, Q_rank] (row-parallel on the Q dim; pairs with the q_proj shard)
  DeviceBuffer q_norm;  // [head_dim] BF16 (Qwen3 per-head norm)
  DeviceBuffer k_norm;  // [head_dim] BF16
};
// Shared (non-layer) model weights, replicated across all TP ranks.
struct SharedWeights {
  DeviceBuffer embed_tokens;  // [vocab, D] input embedding table
  DeviceBuffer lm_head;       // [vocab, D] output projection (HF [out, in] layout)
  DeviceBuffer final_norm;    // [D] final pre-logits norm weight
};
// Reads tensors out of a SafetensorsLoader and materializes them in device
// memory with the TP sharding / layout conventions described on the weight
// structs above. Holds references only — the SafetensorsLoader and ModelConfig
// must outlive this loader.
class DeviceWeightsLoader {
 public:
  DeviceWeightsLoader(SafetensorsLoader& st, const ModelConfig& cfg)
      : st_(st), cfg_(cfg) {}
  // Load shared (embed, norm, lm_head). Replicated on every rank.
  // Returns false on failure.
  bool load_shared(SharedWeights& out);
  // Load ONE attention layer's weights with TP sharding. Returns false on failure.
  bool load_attention(int layer_idx, LayerAttnWeights& out);
  // Load ONE MoE layer's weights. Stacks 128 experts and permutes to GMM-ready layout.
  // stream: ACL stream for the permute op (aclnnInplaceCopy). Returns false on failure.
  bool load_moe(int layer_idx, aclrtStream stream, LayerMoEWeights& out);
  // Expose underlying safetensors for direct access (diagnostic use only;
  // production paths should go through the load_* methods).
  SafetensorsLoader& st() { return st_; }

 private:
  SafetensorsLoader& st_;   // non-owning; source of host-side tensor data
  const ModelConfig& cfg_;  // non-owning; shapes / TP geometry
  // Helper: load HF tensor (full shape) into device buffer (simple H2D copy).
  bool load_tensor_full_(const std::string& name, DeviceBuffer& buf);
  // Helper: load HF tensor and keep only rows [row_lo, row_hi) of the first dim
  // (TP shard by "out" dim). HF format: tensor has shape [D0, D1, ...] stored
  // row-major; taking rows [lo, hi) yields a sharded tensor [hi-lo, D1, ...]
  // that is still contiguous, so a single strided read suffices.
  bool load_tensor_row_slice_(const std::string& name,
                              int64_t row_lo, int64_t row_hi,
                              DeviceBuffer& buf);
  // Helper: TP shard by "in" dim (second axis for 2D, etc.) — used for o_proj
  // (row-parallel). Unlike a row slice, the kept columns are non-contiguous in
  // the HF row-major layout.
  bool load_tensor_col_slice_(const std::string& name,
                              int64_t col_lo, int64_t col_hi,
                              DeviceBuffer& buf);
};