| #include "model_config.h" |
|
|
| #include <cstdio> |
| #include <fstream> |
| #include <sstream> |
|
|
| #include "json.hpp" |
| using json = nlohmann::json; |
|
|
| bool ModelConfig::load_from_json(const std::string& path) { |
| std::ifstream f(path); |
| if (!f) { |
| fprintf(stderr, "ModelConfig: cannot open %s\n", path.c_str()); |
| return false; |
| } |
| json j; |
| try { f >> j; } catch (std::exception& e) { |
| fprintf(stderr, "ModelConfig: bad json: %s\n", e.what()); |
| return false; |
| } |
|
|
| auto get = [&](const char* k, auto def) { |
| if (j.contains(k) && !j[k].is_null()) return j[k].get<decltype(def)>(); |
| return def; |
| }; |
|
|
| vocab_size = get("vocab_size", (int64_t)0); |
| hidden_size = get("hidden_size", (int64_t)0); |
| intermediate_size = get("intermediate_size", (int64_t)0); |
| moe_intermediate_size = get("moe_intermediate_size", (int64_t)0); |
| num_hidden_layers = get("num_hidden_layers", (int64_t)0); |
| num_attention_heads = get("num_attention_heads", (int64_t)0); |
| num_key_value_heads = get("num_key_value_heads", (int64_t)0); |
| head_dim = get("head_dim", (int64_t)0); |
| num_experts = get("num_experts", (int64_t)0); |
| num_experts_per_tok = get("num_experts_per_tok", (int64_t)0); |
| max_position_embeddings = get("max_position_embeddings", (int64_t)0); |
| rope_theta = (float)get("rope_theta", (double)10000.0); |
| rms_norm_eps = (float)get("rms_norm_eps", (double)1e-6); |
| norm_topk_prob = get("norm_topk_prob", true); |
| tie_word_embeddings = get("tie_word_embeddings", false); |
| bos_token_id = get("bos_token_id", (int64_t)0); |
| eos_token_id = get("eos_token_id", (int64_t)0); |
|
|
| |
| if (num_attention_heads == 0 || head_dim == 0 || hidden_size == 0) { |
| fprintf(stderr, "ModelConfig: required fields missing\n"); |
| return false; |
| } |
| return true; |
| } |
|
|
| void ModelConfig::compute_derived(int tps, int tpr) { |
| tp_size = tps; |
| tp_rank = tpr; |
|
|
| |
| if (num_attention_heads % tp_size != 0) { |
| fprintf(stderr, "WARN: num_attention_heads=%ld not divisible by tp_size=%d\n", |
| num_attention_heads, tp_size); |
| } |
| n_heads_per_rank = num_attention_heads / tp_size; |
| q_dim_per_rank = n_heads_per_rank * head_dim; |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if (tp_size <= num_key_value_heads && num_key_value_heads % tp_size == 0) { |
| n_kv_heads_per_rank = num_key_value_heads / tp_size; |
| } else if (tp_size % num_key_value_heads == 0) { |
| n_kv_heads_per_rank = 1; |
| } else { |
| fprintf(stderr, "WARN: non-standard TP/KV head ratio: tp=%d kv=%ld — falling back to replicate-all\n", |
| tp_size, num_key_value_heads); |
| n_kv_heads_per_rank = num_key_value_heads; |
| } |
| kv_dim_per_rank = n_kv_heads_per_rank * head_dim; |
|
|
| |
| if (moe_intermediate_size % tp_size != 0) { |
| fprintf(stderr, "WARN: moe_intermediate_size=%ld not divisible by tp_size=%d\n", |
| moe_intermediate_size, tp_size); |
| } |
| i_per_rank = moe_intermediate_size / tp_size; |
| } |
|
|
| std::string ModelConfig::describe() const { |
| std::ostringstream os; |
| os << "Qwen3MoE config:\n" |
| << " vocab_size = " << vocab_size << "\n" |
| << " hidden_size = " << hidden_size << "\n" |
| << " num_hidden_layers = " << num_hidden_layers << "\n" |
| << " num_attention_heads = " << num_attention_heads << "\n" |
| << " num_key_value_heads = " << num_key_value_heads << "\n" |
| << " head_dim = " << head_dim << "\n" |
| << " num_experts = " << num_experts << "\n" |
| << " num_experts_per_tok = " << num_experts_per_tok << "\n" |
| << " moe_intermediate_size = " << moe_intermediate_size << "\n" |
| << " rope_theta = " << rope_theta << "\n" |
| << " rms_norm_eps = " << rms_norm_eps << "\n" |
| << " max_pos_embeddings = " << max_position_embeddings << "\n" |
| << "TP rank " << tp_rank << " / " << tp_size << " derived:\n" |
| << " n_heads_per_rank = " << n_heads_per_rank << "\n" |
| << " q_dim_per_rank = " << q_dim_per_rank << "\n" |
| << " n_kv_heads_per_rank = " << n_kv_heads_per_rank << "\n" |
| << " kv_dim_per_rank = " << kv_dim_per_rank << "\n" |
| << " i_per_rank = " << i_per_rank << "\n"; |
| return os.str(); |
| } |
|
|