// device_weights.h — load safetensors weights to device memory with proper TP shard. // // For M3 (attention only): loads attention + norm weights. MoE expert weights come in M4. // #pragma once #include "acl_common.h" #include "model_config.h" #include "safetensors_loader.h" #include #include #include // Per-layer MoE weights on device (BF16). // After loading: weights are in GMM-ready layout [E, K_in, N_out] row-major contiguous. // For gate/up: K_in=D, N_out=I_per_rank // For down: K_in=I, N_out=D struct LayerMoEWeights { DeviceBuffer router; // [E, D] BF16 replicated DeviceBuffer gate_exps; // [E, D, I_per_rank] (permuted from HF [E, I, D]) DeviceBuffer up_exps; // [E, D, I_per_rank] DeviceBuffer down_exps; // [E, I_per_rank, D] (permuted from HF [E, D, I]) }; // Per-layer attention weights on device (BF16 unless noted). struct LayerAttnWeights { DeviceBuffer input_layernorm; // [D] BF16 DeviceBuffer post_attention_layernorm; // [D] BF16 // Q/K/V/O projections. HF stores as [out, in] BF16. // For M3 we keep HF layout as-is; matmul wrappers handle the transpose via aclnnMm semantics. DeviceBuffer q_proj; // [Q_full, D] on rank, but physical stored as [Q_rank, D] (sliced by head) DeviceBuffer k_proj; // [KV, D] (replicated if tp_size > num_kv_heads) DeviceBuffer v_proj; // [KV, D] DeviceBuffer o_proj; // [D, Q_rank] (row-parallel on Q dim) DeviceBuffer q_norm; // [head_dim] BF16 (Qwen3 per-head norm) DeviceBuffer k_norm; // [head_dim] BF16 }; // Shared model weights (replicated across ranks). struct SharedWeights { DeviceBuffer embed_tokens; // [vocab, D] DeviceBuffer lm_head; // [vocab, D] DeviceBuffer final_norm; // [D] }; class DeviceWeightsLoader { public: DeviceWeightsLoader(SafetensorsLoader& st, const ModelConfig& cfg) : st_(st), cfg_(cfg) {} // Load shared (embed, norm, lm_head). Replicated on every rank. bool load_shared(SharedWeights& out); // Load ONE attention layer's weights with TP sharding. bool load_attention(int layer_idx, LayerAttnWeights& out); // Load ONE MoE layer's weights. Stacks 128 experts and permutes to GMM-ready layout. // stream: ACL stream for the permute op (aclnnInplaceCopy). bool load_moe(int layer_idx, aclrtStream stream, LayerMoEWeights& out); // Expose underlying safetensors for direct access (diagnostic use). SafetensorsLoader& st() { return st_; } private: SafetensorsLoader& st_; const ModelConfig& cfg_; // Helper: load HF tensor (full shape) into device buffer (simple H2D). bool load_tensor_full_(const std::string& name, DeviceBuffer& buf); // Helper: load HF tensor and keep only [row_lo, row_hi) of first dim (TP shard by "out" dim). // HF format: tensor has shape [D0, D1, ...] stored row-major. We take rows [lo, hi) to form // a sharded tensor of shape [hi-lo, D1, ...]. bool load_tensor_row_slice_(const std::string& name, int64_t row_lo, int64_t row_hi, DeviceBuffer& buf); // TP shard by "in" dim (second axis for 2D, etc.) — used for o_proj (row-parallel). bool load_tensor_col_slice_(const std::string& name, int64_t col_lo, int64_t col_hi, DeviceBuffer& buf); };