| | import argparse |
| | import os |
| | import torch |
| | from transformers import AutoConfig, AutoModelForCausalLM |
| |
|
def main():
    """Load a fine-tuned checkpoint and report its small-expert parameters.

    Command-line arguments:
        --model_path: required path to the fine-tuned checkpoint directory.
        --custom_model_path: optional path to the model implementation source
            (parsed but not used within this function).

    Raises:
        ValueError: if the loaded config lacks a ``num_small_experts`` attribute,
            i.e. the checkpoint was not produced by the expected architecture.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the fine-tuned checkpoint directory (e.g., ./checkpoints/checkpoint-16000)",
    )
    parser.add_argument(
        "--custom_model_path",
        type=str,
        required=False,
        help="(Optional) Path to the model implementation source if needed",
    )
    args = parser.parse_args()

    print(f"Loading config from: {args.model_path}")
    config = AutoConfig.from_pretrained(args.model_path)

    # Fail fast with a clear message if the config is not from the expected
    # custom MoE architecture.
    if hasattr(config, "num_small_experts"):
        num_small_experts = config.num_small_experts
    else:
        raise ValueError("The model config does not contain 'num_small_experts'.")

    print(f"Number of small experts: {num_small_experts}")

    print("Loading model...")
    # bfloat16 halves memory versus fp32; weights are only inspected, not trained.
    model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype=torch.bfloat16)
    model.eval()  # inference mode (disables dropout etc.) — we only read weights

    print("Inspecting small expert weights...")
    total_params = 0
    matched_params = 0
    for name, param in model.named_parameters():
        total_params += 1
        # Fixed: original used f"small_experts." — an f-string with no
        # placeholders; a plain literal is equivalent and intentional.
        if "small_experts." in name:
            matched_params += 1
            print(f"[Matched] {name} - shape: {tuple(param.shape)}")
    print(f"\nMatched {matched_params}/{total_params} parameters containing 'small_experts.'")

    print("Done.")
| |
|
# Run the CLI entry point only when executed as a script (not on import).
if __name__ == "__main__":
    main()
| |
|