Instructions to use PakNin/Reuse-Trained-R3 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use PakNin/Reuse-Trained-R3 with PEFT:
from peft import PeftModel
from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-mini-MoE-instruct")
model = PeftModel.from_pretrained(base_model, "PakNin/Reuse-Trained-R3")
- Transformers
How to use PakNin/Reuse-Trained-R3 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="PakNin/Reuse-Trained-R3")
messages = [
    {"role": "user", "content": "Who are you?"},
]
pipe(messages)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("PakNin/Reuse-Trained-R3", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use PakNin/Reuse-Trained-R3 with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "PakNin/Reuse-Trained-R3"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "PakNin/Reuse-Trained-R3",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker
docker model run hf.co/PakNin/Reuse-Trained-R3
- SGLang
How to use PakNin/Reuse-Trained-R3 with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
    --model-path "PakNin/Reuse-Trained-R3" \
    --host 0.0.0.0 \
    --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "PakNin/Reuse-Trained-R3",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
Use Docker images
docker run --gpus all \
    --shm-size 32g \
    -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    --env "HF_TOKEN=<secret>" \
    --ipc=host \
    lmsysorg/sglang:latest \
    python3 -m sglang.launch_server \
        --model-path "PakNin/Reuse-Trained-R3" \
        --host 0.0.0.0 \
        --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/chat/completions" \
    -H "Content-Type: application/json" \
    --data '{
        "model": "PakNin/Reuse-Trained-R3",
        "messages": [
            {
                "role": "user",
                "content": "What is the capital of France?"
            }
        ]
    }'
- Docker Model Runner
How to use PakNin/Reuse-Trained-R3 with Docker Model Runner:
docker model run hf.co/PakNin/Reuse-Trained-R3
| import os | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0" | |
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from datasets import load_dataset | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import json | |
| import bitsandbytes as bnb | |
| from peft import LoraConfig, get_peft_model | |
| import argparse | |
| import logging | |
| from datetime import datetime | |
| from torch.optim.lr_scheduler import CosineAnnealingLR | |
| from typing import Optional | |
| from get_dataset import get_dataloader | |
# Default compute device, chosen once at import time: CUDA when any GPU is
# visible to torch, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
| # ==================== 0. LOGGING SETUP ==================== | |
def setup_logger(save_path="./logs"):
    """Configure the shared 'ReXMoE' logger to write to a new file and the console.

    Each call creates ``save_path`` if needed and opens a new log file named
    ``rexmoe_training_<DDMM_HHMMSS>.log`` inside it. Handlers left over from a
    previous call are closed and removed first, so the logger always ends up
    with exactly one file handler and one console handler.

    Args:
        save_path: Directory that receives the log file.

    Returns:
        Tuple ``(logger, log_file)``: the configured logger and the path of the
        log file it writes to.
    """
    os.makedirs(save_path, exist_ok=True)
    # Unique log filename per call: DDMM_HHMMSS
    timestamp = datetime.now().strftime("%d%m_%H%M%S")
    log_file = os.path.join(save_path, f"rexmoe_training_{timestamp}.log")

    logger = logging.getLogger('ReXMoE')
    logger.setLevel(logging.INFO)
    # getLogger() returns a process-wide singleton, so an earlier call may have
    # attached handlers already. Close them before dropping them: merely
    # clearing the list would leak the previous FileHandler's open file.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.info("=" * 80)
    logger.info(f"ReXMoE Training Log - {timestamp}")
    logger.info(f"Log file: {log_file}")
    logger.info("=" * 80)
    return logger, log_file
| # Format prompt | |
def format_alpaca_prompt(instruction: str, input_text: str = "") -> str:
    """Render an Alpaca-style prompt that ends with an open ``### Response:`` header.

    Matches the training prompt template used in main.py. The ``### Input:``
    section is included only when ``input_text`` is truthy; sections are
    separated by a blank line.
    """
    sections = [f"### Instruction:\n{instruction}"]
    if input_text:
        sections.append(f"### Input:\n{input_text}")
    sections.append("### Response:\n")
    return "\n\n".join(sections)
def build_model_input(tokenizer, instruction: str, input_text: str = "") -> str:
    """Return the text prompt to feed the model for one user turn.

    When the tokenizer has both ``apply_chat_template`` and a non-empty
    ``chat_template``, the instruction (with ``input_text`` appended after a
    blank line, if given) is rendered through that template with a generation
    prompt. Otherwise the Alpaca template from ``format_alpaca_prompt`` is used.
    """
    has_chat_template = hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
    if not has_chat_template:
        return format_alpaca_prompt(instruction=instruction, input_text=input_text)

    content = f"{instruction}\n\n{input_text}" if input_text else instruction
    print(f"Applying tokenizer's chat template: {tokenizer.chat_template}")
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )
# ==================== EVALUATION ====================
# Fixed prompts used by evaluate_prompt for quick qualitative checks.
_EVAL_PROMPTS = (
    {
        "instruction": "What is the capital of France?",
        "input_text": None,
    },
    {
        "instruction": "High-pressure systems stop air from rising into the colder regions of the atmosphere where water can condense. What will most likely result if a high-pressure system remains in an area for a long period of time?\nA. fog\nB. rain\nC. drought\nD. tornado\nAnswer:",
        "input_text": None,
    },
    {
        "instruction": "Given the fact: predators eat prey\nQuestion: Predators eat\nA. lions\nB. humans\nC. bunnies\nD. grass\nAnswer:",
        "input_text": None,
    },
)


def _emit(msg, logger=None):
    """Print ``msg`` and, when ``logger`` is given, also log it at INFO level."""
    print(msg)
    if logger:
        logger.info(msg)


def _reuse_window_start(layer_idx, current_r, num_layers):
    """Return the first layer of the ``current_r``-wide reuse window of ``layer_idx``.

    The window is centred on ``layer_idx`` and then shifted so it stays inside
    ``[0, num_layers)``; this is the same arithmetic ReXMoERouter uses.
    """
    half = (current_r - 1) // 2
    start = max(0, layer_idx - half)
    end = min(num_layers, start + current_r)
    return max(0, end - current_r)


def _report_pruning_status(backend_model, logger=None):
    """Print/log IG-MET pruning statistics, counting each UNIQUE expert once.

    An original (layer, expert) pair can appear in several routers' reuse
    windows, so router-level EMA utilizations are summed per unique expert and
    compared with the threshold of layer 0's router (all routers are expected
    to share one global threshold). Routers whose ``mask_threshold`` is not
    >= 0 (IG-MET disabled) contribute nothing. Nothing is reported when no
    router contributes or layer 0 exposes no ``mask_threshold``.
    """
    layers = backend_model.layers
    num_layers = len(layers)
    unique_experts_total = set()
    unique_experts_sum = {}  # (orig_layer, orig_expert) -> summed EMA score

    # 1. Aggregate EMA values of IG-MET-enabled routers per unique expert.
    for layer_idx, layer in enumerate(layers):
        block = getattr(layer, "block_sparse_moe", None)
        if not isinstance(block, ReXMoESparseMoeBlock):
            continue
        router = block.router
        if not router.mask_threshold.item() >= 0:  # IG-MET disabled
            continue
        current_r = router.get_candidate_layers(step=None, total_steps=None)
        start_layer = _reuse_window_start(layer_idx, current_r, num_layers)
        window = range(start_layer, min(start_layer + current_r, num_layers))
        mapping = [(l_id, e_id) for l_id in window for e_id in range(router.num_experts_per_layer)]
        # zip stops at the shorter of the mapping and the EMA buffer.
        for key, ema_val in zip(mapping, router.ema_utilization.tolist()):
            unique_experts_total.add(key)
            unique_experts_sum[key] = unique_experts_sum.get(key, 0.0) + ema_val

    if not unique_experts_sum:
        return
    first_router = getattr(getattr(layers[0], "block_sparse_moe", None), "router", None)
    if not hasattr(first_router, "mask_threshold"):
        return

    # 2. Global decision: summed usage below the shared threshold => pruned.
    threshold = first_router.mask_threshold.item()
    unique_experts_pruned = {k for k, v in unique_experts_sum.items() if v < threshold}
    _emit("\n[IG-MET Pruning Status during Evaluation]:", logger)
    total_unique_pruned = len(unique_experts_pruned)
    total_unique = len(unique_experts_total)
    pct = 100 * total_unique_pruned / total_unique if total_unique > 0 else 0
    _emit(
        f"Global: {total_unique_pruned}/{total_unique} UNIQUE experts pruned ({pct:.1f}%) | threshold={threshold:.6f}",
        logger,
    )

    # 3. Per-layer view of the global decision: count only each layer's OWN experts.
    pruning_info = []  # (layer_idx, threshold, pruned_in_layer, experts_in_layer)
    for layer_idx, layer in enumerate(layers):
        block = getattr(layer, "block_sparse_moe", None)
        if not isinstance(block, ReXMoESparseMoeBlock):
            continue
        n_local = block.router.num_experts_per_layer
        masked = sum((layer_idx, e_id) in unique_experts_pruned for e_id in range(n_local))
        pruning_info.append((layer_idx, threshold, masked, n_local))

    pruning_info.sort(key=lambda x: x[2] / x[3] if x[3] > 0 else 0, reverse=True)
    _emit("Top 5 most pruned layers:", logger)
    for layer_idx, _thr, masked, total in pruning_info[:5]:
        pct = 100 * masked / total if total > 0 else 0
        _emit(f" Layer {layer_idx:>2}: {masked:>2}/{total} pruned ({pct:>5.1f}%)", logger)


def evaluate_prompt(model, tokenizer, max_new_tokens=100, do_sample=True, temperature=0.7, logger=None):
    """Generate completions for the fixed sample prompts and print/log the results.

    Also reports IG-MET pruning statistics for the model's ReXMoE layers. The
    model is put in eval mode for the duration of the call and its previous
    training mode is restored afterwards. Any exception is reported (printed,
    and logged with traceback when a logger is given) instead of raised, so a
    failed evaluation never interrupts the caller.

    Args:
        model: model exposing ``generate``; ``get_backend_model(model)`` (defined
            elsewhere) must return an object with a ``.layers`` sequence.
        tokenizer: tokenizer matching ``model``.
        max_new_tokens: generation length cap per prompt.
        do_sample: sample instead of greedy decoding.
        temperature: sampling temperature; only passed to ``generate`` when
            ``do_sample`` is true (it has no effect in greedy decoding).
        logger: optional logger that mirrors the printed output.
    """
    n_prompts = len(_EVAL_PROMPTS)
    # Evaluate in eval mode: in train mode dropout stays active and every
    # ReXMoERouter.forward_with_logits call updates its EMA utilization
    # buffer, which would let these prompts skew the IG-MET statistics.
    was_training = bool(getattr(model, "training", False))
    try:
        model.eval()
        # Prefer model.device; fall back to the first parameter's device.
        if hasattr(model, "device"):
            device = model.device
        else:
            device = next(model.parameters()).device

        _emit(f"\nEvaluating model with {n_prompts} sample prompts...", logger)
        _report_pruning_status(get_backend_model(model), logger)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": getattr(tokenizer, "pad_token_id", None),
            "eos_token_id": getattr(tokenizer, "eos_token_id", None),
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature

        for prompt_idx, prompt_config in enumerate(_EVAL_PROMPTS, 1):
            instruction = prompt_config["instruction"]
            input_text = prompt_config["input_text"]
            print("\n" + "=" * 80)
            print(f"Prompt {prompt_idx}/{n_prompts}:")
            print(f" Instruction: {instruction}")
            print(f" Input: {input_text}")
            print("=" * 80)

            prompt = build_model_input(tokenizer, instruction=instruction, input_text=input_text)
            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}
            with torch.no_grad():
                outputs = model.generate(**inputs, **gen_kwargs)

            # The first sequence's leading prompt_len tokens are the prompt;
            # only the remainder is the completion.
            prompt_len = inputs["input_ids"].shape[-1]
            generated = outputs[0]
            completion_ids = generated[prompt_len:]
            n_new = int(completion_ids.numel())
            completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True)

            print("GENERATED RESPONSE:")
            print("-" * 80)
            print(completion_text)
            print("-" * 80)
            print(f"\n[debug] prompt_tokens={prompt_len}, new_tokens={n_new}")
            if n_new == 0:
                print("[debug] Model generated 0 new tokens (likely hit EOS immediately).")
                print("[debug] Full decoded text:\n" + tokenizer.decode(generated, skip_special_tokens=True))
            elif completion_text.strip() == "":
                print("[debug] Model generated tokens, but they decode to empty/whitespace or special tokens.")
                print("[debug] completion token ids:", completion_ids.tolist())

            if logger:
                logger.info(f"\n--- Prompt {prompt_idx}/{n_prompts} ---")
                logger.info(f"Instruction: {instruction}")
                logger.info(f"Input: {input_text}")
                logger.info(f"Generated completion (len {n_new}): {completion_text}")
            print("=" * 80)

        if logger:
            logger.info(f"Evaluation of all {n_prompts} prompts complete.")
    except Exception as e:
        err = f"evaluate_prompt failed: {e}"
        print(err)
        if logger:
            logger.exception(err)
    finally:
        if was_training:
            model.train()
| # ==================== 1. MODEL MODIFICATION ==================== | |
class ReXMoERouter(nn.Module):
    """Parameterless routing logic for cross-layer expert reuse (ReXMoE).

    The trainable gate weights intentionally live on the owning MoE block as
    ``block_sparse_moe.gate`` (matching the original Phi-MoE naming), so
    checkpoints never contain ``router.gate`` keys. This module only
    post-processes the gate logits: Progressive Scaling Routing (PSR) window
    masking, optional hard-pruning via an ``old_to_new`` attribute, IG-MET
    masking, EMA utilization tracking and the auxiliary load-balancing loss.

    Buffers:
        ema_utilization: [reuse_scale * num_experts_per_layer] EMA of the summed
            routing weight each pool slot receives (updated in training mode only).
        mask_threshold: scalar; a negative value disables local IG-MET masking.
    """

    def __init__(self, layer_idx, total_layers=32, num_experts_per_layer=16,
                 reuse_scale=3, num_experts_per_tok=2, all_experts_dict=None, aux_loss_weight=0.02):
        super().__init__()
        self.layer_idx = layer_idx
        self.total_layers = total_layers
        self.num_experts_per_layer = num_experts_per_layer
        self.reuse_scale = reuse_scale  # e.g. R=3: layers [i-1, i, i+1]
        self.num_experts_per_tok = num_experts_per_tok
        # Reference to all experts dict to get actual expert counts
        self.all_experts_dict = all_experts_dict
        # The static pool always spans reuse_scale layers' worth of experts.
        self.max_pool_size = reuse_scale * num_experts_per_layer
        # No utilization history yet, so the EMA starts at zero.
        self.register_buffer('ema_utilization', torch.zeros(self.max_pool_size))
        self.register_buffer('mask_threshold', torch.tensor(-1.0))  # -1 means no masking
        self.aux_loss_weight = aux_loss_weight

    def _layer_window(self, width):
        """Return ``(start, end)`` of the ``width``-layer window around ``self.layer_idx``.

        The window is centred on ``layer_idx`` and then shifted back so that it
        stays inside ``[0, total_layers)``.
        """
        half = (width - 1) // 2
        start = max(0, self.layer_idx - half)
        end = min(self.total_layers, start + width)
        return max(0, end - width), end

    @staticmethod
    def _expert_map_for_layer(old_to_new, orig_layer):
        """Return ``old_to_new``'s expert map for ``orig_layer``, or None if the layer is absent.

        Accepts both int layer keys and their string form (as produced by a
        JSON round-trip).
        """
        layer_key = int(orig_layer)
        if layer_key in old_to_new:
            return old_to_new[layer_key]
        return old_to_new.get(str(layer_key))

    def get_candidate_layers(self, step, total_steps, psr_enabled=True, initial_R=2, met_warmup=None):
        """Progressive Scaling Routing: number of layers currently in the reuse window.

        Args:
            step: current training step (None => full reuse_scale).
            total_steps: total training steps (None => full reuse_scale).
            psr_enabled: when False, always return reuse_scale.
            initial_R: reuse scale at step 0 (default 2).
            met_warmup: if given (fraction 0-1), PSR grows linearly during the
                first ``met_warmup * total_steps`` steps and then stays at
                reuse_scale; if None, the legacy schedule (first 50% of
                training) is used.

        Returns:
            ``initial_R + int(progress * (reuse_scale - initial_R))`` with
            ``progress`` clamped to 1.0. An empty PSR phase (``met_warmup == 0``
            or ``total_steps == 0``) counts as already complete.
        """
        if not psr_enabled or step is None or total_steps is None:
            return self.reuse_scale
        phase_fraction = 0.5 if met_warmup is None else met_warmup
        phase_steps = phase_fraction * total_steps
        # Guard against dividing by a zero-length PSR phase.
        progress = min(step / phase_steps, 1.0) if phase_steps > 0 else 1.0
        return initial_R + int(progress * (self.reuse_scale - initial_R))

    def update_ema(self, selection_counts, alpha=0.9):
        """In-place EMA update: ``ema <- alpha * ema + (1 - alpha) * selection_counts``.

        ``selection_counts`` is expected to have shape [max_pool_size]. It is
        cast to the buffer's device/dtype first, so the update is computed in
        the buffer's precision and the registered buffer tensor is kept
        (not rebound to a new tensor with a promoted dtype).
        """
        with torch.no_grad():
            counts = selection_counts.to(device=self.ema_utilization.device, dtype=self.ema_utilization.dtype)
            self.ema_utilization.mul_(alpha).add_(counts, alpha=1 - alpha)

    def forward_with_logits(self, all_logits, hidden_states, step=None, total_steps=None, met_enabled=False, met_warmup=None, logger=None):
        """Mask the gate logits and compute the auxiliary load-balancing loss.

        Args:
            all_logits: [batch_size * seq_len, max_pool_size] logits from the block's gate.
            hidden_states: [batch_size, seq_len, hidden_dim]; only its rank is used here.
            step, total_steps: PSR schedule position (None => full reuse_scale).
            met_enabled: apply IG-MET masking (global keep mask if one has been
                attached as ``global_keep_mask``, else the local EMA threshold).
            met_warmup: if provided (float 0-1), PSR runs only during this phase then stays at max R.
            logger: unused; kept for interface compatibility.

        Returns:
            router_logits: [batch_size * seq_len, max_pool_size], masked slots set to -inf.
            aux_loss: scalar, ``aux_loss_weight * CV^2`` of the active slots' summed routing weights.
            active_mask: [max_pool_size] boolean mask of slots still routable.
            layer_expert_mapping: list of (layer_idx, expert_idx) tuples, one per pool slot.
        """
        # Unpacking enforces a 3-D hidden_states; the sizes themselves are unused.
        _batch_size, _seq_len, _hidden_dim = hidden_states.shape
        device = all_logits.device

        current_r = self.get_candidate_layers(step, total_steps, met_warmup=met_warmup)
        # The pool layout is STATIC: it always spans the full reuse_scale window
        # so it lines up with the gate's max_pool_size outputs. PSR only decides
        # which part of that layout is currently active.
        base_start, _ = self._layer_window(self.reuse_scale)
        psr_start, psr_end = self._layer_window(current_r)

        layer_expert_mapping = []
        active_flags = []
        for layer_id in range(base_start, base_start + self.reuse_scale):
            in_psr_window = psr_start <= layer_id < psr_end and layer_id < self.total_layers
            for expert_id in range(self.num_experts_per_layer):
                layer_expert_mapping.append((layer_id, expert_id))
                active_flags.append(in_psr_window)
        # Built on the host in one go instead of element-by-element on the device.
        active_mask = torch.tensor(active_flags, dtype=torch.bool, device=device)

        masked_logits = all_logits.clone()
        masked_logits[:, ~active_mask] = float('-inf')

        # === DYNAMIC PRUNING MASK (OLD_TO_NEW) ===
        # For a hard-pruned checkpoint, old_to_new lists the experts kept per
        # original layer; slots of experts missing from that list are masked.
        old_to_new = getattr(self, 'old_to_new', None)
        if old_to_new:
            pruned_slots = []
            for pool_idx, (orig_layer, orig_expert) in enumerate(layer_expert_mapping):
                if not active_flags[pool_idx]:
                    continue
                expert_map = self._expert_map_for_layer(old_to_new, orig_layer)
                if expert_map is None:
                    continue  # layer absent from old_to_new => not pruned
                if orig_expert not in expert_map and str(orig_expert) not in expert_map:
                    pruned_slots.append(pool_idx)
            if pruned_slots:
                masked_logits[:, pruned_slots] = float('-inf')
                active_mask[pruned_slots] = False

        # === IMPORTANCE-GUIDED MASKED EXPERT TRAINING (IG-MET) ===
        if met_enabled:
            global_keep_mask = getattr(self, 'global_keep_mask', None)
            if global_keep_mask is not None:
                # Mode A: precise global pruning via a keep-mask pushed from the
                # training loop; it is aligned with the max_pool_size slots.
                global_keep_mask = global_keep_mask.to(device)
                cur_len = min(len(global_keep_mask), self.max_pool_size)
                prune_mask = torch.zeros_like(active_mask)
                prune_mask[:cur_len] = ~global_keep_mask[:cur_len]
                prune_mask = prune_mask & active_mask
                # Safety: never prune away every remaining active expert.
                if not (prune_mask.all() or prune_mask.sum() == active_mask.sum()):
                    masked_logits[:, prune_mask] = float('-inf')
                    active_mask[prune_mask] = False
            elif self.mask_threshold.item() >= 0:
                # Mode B: local threshold on this router's own EMA utilization,
                # applied only to currently active slots.
                threshold = self.mask_threshold.to(device)
                ema = self.ema_utilization.to(device)
                prune_mask = (ema < threshold) & active_mask
                if prune_mask.any() and prune_mask.sum().item() < active_mask.sum().item():
                    masked_logits[:, prune_mask] = float('-inf')
                    active_mask[prune_mask] = False

        routing_weights = torch.softmax(masked_logits, dim=-1)  # [B*S, max_pool_size]

        if self.training:
            # C_k: summed routing weight per pool slot (detached from the graph).
            self.update_ema(routing_weights.sum(dim=0).detach())

        # Auxiliary load-balancing loss (squared coefficient of variation) over active slots.
        active_mask = active_mask.to(routing_weights.device)
        if active_mask.sum().item() == 0:
            # Fallback: keep at least the first slot so the loss has a column.
            active_mask[0] = True
        expert_loads = routing_weights[:, active_mask].sum(0)  # [num_active_slots]
        if expert_loads.numel() > 1:
            cv_squared = (expert_loads.std() / (expert_loads.mean() + 1e-6)) ** 2
        else:
            cv_squared = torch.tensor(0.0, device=expert_loads.device, dtype=expert_loads.dtype)
        aux_loss = self.aux_loss_weight * cv_squared

        # Keep last mapping for introspection/debugging.
        self.last_layer_expert_mapping = layer_expert_mapping
        return masked_logits, aux_loss, active_mask, layer_expert_mapping
| class ReXMoESparseMoeBlock(nn.Module): | |
| """ | |
| Modified PhiMoE Sparse MoE block with cross-layer expert reuse. | |
| Keeps original experts intact but routes to adjacent layer experts. | |
| """ | |
def __init__(self, original_moe_block, layer_idx, total_layers, all_experts_dict, reuse_scale=3, logger=None, aux_loss_weight=0.02):
    """Wrap ``original_moe_block`` with a ReXMoE router and an expanded gate.

    Args:
        original_moe_block: the layer's original sparse-MoE block; its
            ``hidden_dim``, ``num_experts``, ``top_k``, ``gate`` and ``experts``
            are read (experts are shared by reference, not copied).
        layer_idx: index of this decoder layer.
        total_layers: number of decoder layers.
        all_experts_dict: ``{layer_idx: ModuleList of experts}`` for every layer.
        reuse_scale: R, the number of consecutive layers whose experts form this
            layer's routing pool.
        logger: optional logger stored for use in ``forward``.
        aux_loss_weight: weight of the router's load-balancing loss.
    """
    super().__init__()
    self.hidden_dim = original_moe_block.hidden_dim
    self.num_experts = original_moe_block.num_experts
    self.top_k = original_moe_block.top_k
    self.layer_idx = layer_idx
    self.reuse_scale = reuse_scale
    # Keep reference to experts from all layers (DO NOT copy parameters)
    self.all_experts_dict = all_experts_dict  # Dict: {layer_idx: ModuleList of experts}
    self.aux_loss_weight = aux_loss_weight
    # Parameterless ReXMoE router (masking, EMA tracking, aux loss).
    self.router = ReXMoERouter(
        layer_idx=layer_idx,
        total_layers=total_layers,
        num_experts_per_layer=self.num_experts,
        reuse_scale=reuse_scale,
        num_experts_per_tok=self.top_k,
        all_experts_dict=all_experts_dict,
        aux_loss_weight=self.aux_loss_weight
    )
    # The gate lives on the block itself so checkpoints keep the original
    # `block_sparse_moe.gate` key; it scores every slot of the reuse pool.
    # NOTE(review): created in the default dtype rather than the original
    # gate's dtype — confirm callers cast the model afterwards.
    self.gate = nn.Linear(self.hidden_dim, self.router.max_pool_size, bias=False)

    with torch.no_grad():
        orig_weight = original_moe_block.gate.weight.data
        orig_gate_shape = orig_weight.shape[0]
        if orig_gate_shape == self.router.max_pool_size:
            # Loading from a checkpoint whose gate is already expanded: copy it directly.
            self.gate.weight.data.copy_(orig_weight)
            print(f" Base block gate size already {orig_gate_shape}, copied fully for layer {layer_idx}")
        elif orig_gate_shape == self.num_experts:
            # Pool section s holds the experts of layer (base_start + s), with
            # base_start computed exactly like ReXMoERouter's static window:
            # centred on layer_idx, then shifted to stay inside the model. At
            # the first/last layers the window is shifted, so the local section
            # is NOT always section (R - 1) // 2.
            half = (reuse_scale - 1) // 2
            base_start = max(0, layer_idx - half)
            base_end = min(total_layers, base_start + reuse_scale)
            base_start = max(0, base_end - reuse_scale)
            local_section = layer_idx - base_start

            # Local section: exact copy of the base model's gate.
            local_start_idx = local_section * self.num_experts
            local_end_idx = local_start_idx + self.num_experts
            self.gate.weight[local_start_idx:local_end_idx, :] = orig_weight.clone()
            print(f" num_experts: {self.num_experts}, refilled router weights for layer {layer_idx} local section at indices {local_start_idx}:{local_end_idx}")

            # Neighbour sections: copied weights + small noise instead of zeros.
            # All-zero neighbour logits initially cause router collapse and large loss spikes.
            noise_scale = 0.1 * orig_weight.std().item()
            for section in range(reuse_scale):
                if section == local_section:
                    continue
                start_idx = section * self.num_experts
                end_idx = start_idx + self.num_experts
                noise = torch.randn_like(orig_weight) * noise_scale
                self.gate.weight[start_idx:end_idx, :] = orig_weight.clone() + noise
                print(f" Initialized neighbor section {start_idx}:{end_idx} with noise")
        else:
            # Fallback for pruned or irregular sized expert checkpoints:
            # zero everything, then copy whatever rows fit.
            self.gate.weight.zero_()
            print(f" Warning: Custom size mismatch for gate refill ({orig_gate_shape} vs {self.num_experts}). Filling with zeros.")
            min_experts = min(orig_gate_shape, self.router.max_pool_size)
            self.gate.weight[:min_experts, :] = orig_weight[:min_experts, :].clone()

    # Canonical expert container: match base PhiMoE which uses `.experts`
    # so that checkpoints save/load expert weights under the same keys.
    self.experts = original_moe_block.experts
    # Training step for PSR; None (default) makes inference use the full R.
    self.current_step = None
    self.total_steps = None
    self.met_warmup = None  # Set during training if MET is enabled
    # aux_loss of the latest forward, collected by the training loop.
    self.last_aux_loss = None
    # Routing selections kept for analysis.
    self.last_selected_experts = None  # (target_layer, target_expert) tuples
    self.last_selection_counts = None  # Count of tokens routed to each expert
    self.logger = logger  # Used by forward when no logger is passed in
def local_experts(self):
    """Alias for ``self.experts``, for callers that look up ``local_experts``."""
    experts = self.experts
    return experts
def map_pruned_expert(self, orig_layer: int, orig_expert: int, old_to_new: dict) -> Optional[int]:
    """
    Map an original (layer, expert) index to the index of the kept expert.

    Args:
        orig_layer: layer index in the unpruned model.
        orig_expert: expert index within ``orig_layer`` in the unpruned model.
        old_to_new: ``{layer: {old_expert: new_expert}}`` for hard-pruned
            layers. Both layer and expert keys may be ints or their string
            form (as produced by a JSON round-trip). May be empty or None.

    Returns:
        The expert's index in ``self.all_experts_dict[orig_layer]``, or None if
        the expert was pruned (absent from its layer map) or does not exist.
    """
    layer_map = None
    if old_to_new:
        layer_map = old_to_new.get(orig_layer)
        if layer_map is None:
            # JSON serialization turns int layer keys into strings.
            layer_map = old_to_new.get(str(orig_layer))
    if layer_map is not None:
        if orig_expert in layer_map:
            return layer_map[orig_expert]
        return layer_map.get(str(orig_expert))
    # Layer not pruned: the original index is valid if that expert still exists.
    experts = self.all_experts_dict
    if orig_layer in experts and orig_expert < len(experts[orig_layer]):
        return orig_expert
    return None
def forward(self, hidden_states, logger=None):
    """Route every token through the cross-layer expert pool and mix the outputs.

    Args:
        hidden_states: [batch_size, seq_len, hidden_dim]
        logger: optional logger passed on to the router; when None, the
            ``logger`` attribute attached to this block (if any) is used.

    Returns:
        (output, router_logits): ``output`` is [batch_size, seq_len, hidden_dim];
        ``router_logits`` is the tensor returned by
        ``self.router.forward_with_logits`` (top-k is taken over its last dim).

    Side effects:
        Sets ``self.last_aux_loss`` (the router's aux loss) and
        ``self.last_selection_counts`` ({(orig_layer, expert_idx): token_count}),
        and moves ``self.gate`` to the activations' device.
    """
    # If a logger wasn't passed down through the model forward call,
    # fall back to any logger attribute attached to this block instance.
    if logger is None:
        logger = getattr(self, 'logger', None)

    batch_size, seq_len, hidden_dim = hidden_states.shape
    # Flatten once; reshape (unlike view) also accepts non-contiguous input.
    hidden_states_flat = hidden_states.reshape(-1, hidden_dim)  # [B*S, H]
    device = hidden_states_flat.device
    # Ensure gate is on the same device as hidden_states (fixes device_map="auto" mismatch)
    self.gate = self.gate.to(device)

    # Gate logits are computed at block level (so weights are saved as
    # block_sparse_moe.gate); the router turns them into pool logits + masks.
    all_logits = self.gate(hidden_states_flat)
    router_logits, aux_loss, active_mask, layer_expert_mapping = self.router.forward_with_logits(
        all_logits, hidden_states, self.current_step, self.total_steps,
        met_enabled=getattr(self, 'met_enabled', False),
        met_warmup=self.met_warmup,
        logger=logger,
    )
    # Stored for collection in the training loop.
    self.last_aux_loss = aux_loss

    topk_logits, topk_indices = torch.topk(router_logits, self.top_k, dim=-1)  # [B*S, k]
    # Mixing weights are a softmax over the k selected logits only.
    topk_weights = torch.softmax(topk_logits, dim=-1)  # [B*S, k]

    final_hidden_states = torch.zeros_like(hidden_states_flat)
    selection_counts = {}
    old_to_new = getattr(self.router, 'old_to_new', {})
    processed_mask = torch.zeros(hidden_states_flat.shape[0], dtype=torch.bool, device=device)

    # Loop invariants, evaluated once instead of once per executed expert.
    # Pruned models report orig_expert (index before hard deletion); unpruned
    # models report new_idx (equal to orig_expert when there is no mapping).
    is_pruned_model = getattr(self, 'is_pruned_model', False) or hasattr(self, 'old_to_new') and self.old_to_new
    # Only pool positions that have a mapping entry can be executed.
    num_slots = min(self.router.max_pool_size, len(layer_expert_mapping))

    for k_idx in range(self.top_k):
        selected_positions = topk_indices[:, k_idx]  # [B*S]
        k_weights = topk_weights[:, k_idx:k_idx + 1]  # [B*S, 1]
        for pool_pos in range(num_slots):
            # Hard block for pruned / inactive experts: never executed.
            if not active_mask[pool_pos]:
                continue
            token_mask = selected_positions == pool_pos  # [B*S]
            if not token_mask.any():
                continue
            orig_layer, orig_expert = layer_expert_mapping[pool_pos]
            new_idx = self.map_pruned_expert(orig_layer, orig_expert, old_to_new)
            if new_idx is None:
                continue  # Pruned expert
            expert_module = self.all_experts_dict[orig_layer][new_idx]
            # Move the tokens to the expert rather than the expert to the tokens
            # (cheaper when the model is sharded across GPUs).
            expert_device = next(expert_module.parameters()).device
            expert_out = expert_module(hidden_states_flat[token_mask].to(expert_device))  # [N, H]
            final_hidden_states[token_mask] += expert_out.to(device) * k_weights[token_mask]

            reported_expert = orig_expert if is_pruned_model else new_idx
            key = (orig_layer, reported_expert)
            selection_counts[key] = selection_counts.get(key, 0) + token_mask.sum().item()
            processed_mask[token_mask] = True

    # NOTE(review): a token counts as processed if ANY of its k experts ran;
    # the weights of the remaining experts are not renormalised, so such tokens
    # get only a partial mixture.
    # NOTE(review): tokens that reached no executable expert get their input
    # copied through unchanged. If the caller adds this output to a residual,
    # the input is added twice — confirm whether a zero contribution was meant.
    unprocessed = ~processed_mask
    if unprocessed.any():
        final_hidden_states[unprocessed] = hidden_states_flat[unprocessed]

    # Store selections for analysis
    self.last_selection_counts = selection_counts
    # Return tuple like original PhiMoE: (hidden_states, router_logits)
    return final_hidden_states.view(batch_size, seq_len, hidden_dim), router_logits
| # ==================== 3. ROUTING ANALYSIS ==================== | |
def _print_and_log(msg, logger=None, level="info"):
    """Print ``msg`` and, if a logger is supplied, also emit it at ``level``."""
    print(msg)
    if logger:
        getattr(logger, level)(msg)


def _reconstruct_pool_mapping(router_r, layer_idx, num_layers, experts_per_layer):
    """Rebuild one router's pool_pos -> (orig_layer, orig_expert) list.

    A window of ``router_r`` layers is centred on ``layer_idx`` and shifted to
    stay within ``[0, num_layers)``; every layer in the window contributes
    ``experts_per_layer`` consecutive pool positions.
    """
    half = (router_r - 1) // 2
    start_layer = max(0, layer_idx - half)
    end_layer = min(num_layers, start_layer + router_r)
    start_layer = max(0, end_layer - router_r)
    mapping = []
    for l_id in range(start_layer, start_layer + router_r):
        if l_id >= num_layers:
            break
        mapping.extend((l_id, e_id) for e_id in range(experts_per_layer))
    return mapping


def analyze_routing_patterns(model, dataloader, current_r, total_layers, device, num_batches=10, logger=None):
    """
    Analyze ACTUAL routing patterns by tracking which experts were selected.

    Runs up to ``num_batches`` batches without gradients, accumulates each
    ReXMoE block's ``last_selection_counts`` into per-layer tallies, then
    prints (and logs, when ``logger`` is given):
    - the IG-MET pruning report built from the routers' summed EMA utilisation,
    - the split of selections into same / previous / next / distant layers,
    - the top-5 experts of a few sample layers,
    - a verdict on whether cross-layer reuse is happening.

    Args:
        model: bare or PEFT-wrapped causal LM containing ReXMoE blocks.
        dataloader: yields dicts with "input_ids" and "attention_mask".
        current_r: reuse scale reported in the header and the final verdict.
        total_layers: number of layers to create routing tallies for.
        device: device batches (and the EMA summary tensor) are placed on.
        num_batches: maximum number of batches to sample.
        logger: optional logger mirroring the printed output.

    The model's train/eval mode on entry is restored before returning.
    """
    was_training = model.training
    model.eval()
    # routing_counts[layer_idx][(target_layer, target_expert)] = count
    routing_counts = {layer_idx: {} for layer_idx in range(total_layers)}
    total_tokens = 0
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= num_batches:  # Sample only first N batches for efficiency
                    break
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # NOTE(review): this counts non-padding tokens only, while the
                # selection counts come from every position the MoE blocks saw;
                # the per-token percentages below may mix the two — confirm.
                total_tokens += attention_mask.sum().item()
                # Forward pass so every MoE block records its selections.
                model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)
                # Work on the stack that owns `.layers`, PEFT-wrapped or not.
                backend_model = get_backend_model(model)
                for layer_idx, layer in enumerate(backend_model.layers):
                    moe_block = getattr(layer, "block_sparse_moe", None)
                    if not isinstance(moe_block, ReXMoESparseMoeBlock):
                        continue
                    selections = getattr(moe_block, "last_selection_counts", None)
                    if selections is None:
                        continue
                    for key, count in selections.items():
                        routing_counts[layer_idx][key] = routing_counts[layer_idx].get(key, 0) + count
    finally:
        model.train(was_training)

    _print_and_log(f"\nAnalyzing ACTUAL routing patterns from {num_batches} batches ({total_tokens:,} tokens)", logger)
    _print_and_log(f"Current reuse scale: R={current_r}", logger)

    # === IG-MET PRUNING ANALYTICS (GLOBAL SUM AGGREGATION) ===
    # 1. Sum each unique expert's EMA over every router that can reach it.
    backend_model = get_backend_model(model)
    num_model_layers = len(backend_model.layers)
    unique_experts_sum = {}  # (orig_layer, orig_expert) -> summed_ema_score
    unique_experts_total = set()
    threshold = None
    for layer_idx, layer in enumerate(backend_model.layers):
        moe_block = getattr(layer, "block_sparse_moe", None)
        if not isinstance(moe_block, ReXMoESparseMoeBlock):
            continue
        router = moe_block.router
        thr = router.mask_threshold.item()
        if thr < 0:  # negative threshold: masking not active for this router
            continue
        # Separate name on purpose: `current_r` (the argument) is reported again
        # in the final verdict and must not be overwritten here.
        router_r = router.get_candidate_layers(step=None, total_steps=None)
        mapping = _reconstruct_pool_mapping(router_r, layer_idx, num_model_layers, router.num_experts_per_layer)
        ema = router.ema_utilization
        for pool_pos, key in enumerate(mapping):
            if pool_pos >= len(ema):
                break
            unique_experts_total.add(key)
            unique_experts_sum[key] = unique_experts_sum.get(key, 0) + ema[pool_pos].item()
        if threshold is None:
            threshold = thr

    # 2. Prune based on SUM aggregation and the (shared) global threshold.
    unique_experts_pruned = {k for k, v in unique_experts_sum.items() if threshold is not None and v < threshold}
    total_unique_pruned = len(unique_experts_pruned)
    total_unique = len(unique_experts_total)
    _print_and_log("\n[IG-MET Pruning Report]:", logger)
    pct = 100 * total_unique_pruned / total_unique if total_unique > 0 else 0
    _print_and_log(f"Global: {total_unique_pruned}/{total_unique} UNIQUE experts pruned ({pct:.1f}%) | threshold={threshold if threshold is not None else -1:.6f}", logger)
    if unique_experts_sum:
        global_ema_tensor = torch.tensor(list(unique_experts_sum.values()), device=device)
        _print_and_log(f"Aggregated EMA (sum across R layers): mean={global_ema_tensor.mean():.6f}, min={global_ema_tensor.min():.6f}, max={global_ema_tensor.max():.6f}", logger)
    print()

    # Cross-layer reuse statistics, bucketed by target-layer offset.
    cross_layer_usage = {"same_layer": 0, "adjacent_prev": 0, "adjacent_next": 0, "distant": 0}
    offset_buckets = {0: "same_layer", -1: "adjacent_prev", 1: "adjacent_next"}
    for layer_idx, layer_counts in routing_counts.items():
        for (target_layer, _target_expert), count in layer_counts.items():
            cross_layer_usage[offset_buckets.get(target_layer - layer_idx, "distant")] += count
    total_routing = sum(cross_layer_usage.values())
    if total_routing > 0:
        _print_and_log("Cross-Layer Routing Distribution (ACTUAL selections):", logger)
        for key, label in [
            ("same_layer", "Same layer (i):"),
            ("adjacent_prev", "Previous layer (i-1):"),
            ("adjacent_next", "Next layer (i+1):"),
            ("distant", "Distant layers:"),
        ]:
            # "Distant" is only shown when non-zero.
            if cross_layer_usage[key] > 0 or key != "distant":
                pct = 100 * cross_layer_usage[key] / total_routing
                _print_and_log(f" {label:25} {cross_layer_usage[key]:>10,} ({pct:>5.1f}%)", logger)
        print()

    # Sample detailed analysis for a few layers
    sample_layers = [8, 16, 24] if total_layers >= 32 else [total_layers // 4, total_layers // 2, 3 * total_layers // 4]
    _print_and_log("Sample Layer-Specific Routing Patterns:", logger)
    for layer_idx in sample_layers:
        if not routing_counts.get(layer_idx):
            continue
        _print_and_log(f"\n Layer {layer_idx}:", logger)
        # Top 5 most used experts
        top_experts = sorted(routing_counts[layer_idx].items(), key=lambda item: item[1], reverse=True)[:5]
        for (target_layer, target_expert), count in top_experts:
            pct = 100 * count / total_tokens if total_tokens > 0 else 0
            layer_relation = "same" if target_layer == layer_idx else f"L{target_layer}"
            _print_and_log(f" Expert {target_expert:>2} from layer {target_layer:>2} ({layer_relation:>4}): {count:>8,} times ({pct:>5.1f}%)", logger)
    print()

    # Verdict: is cross-layer reuse happening?
    reused = cross_layer_usage['adjacent_prev'] + cross_layer_usage['adjacent_next'] + cross_layer_usage['distant']
    cross_layer_pct = 100 * reused / total_routing if total_routing > 0 else 0
    if cross_layer_pct > 5:
        _print_and_log(f"✅ Cross-layer expert reuse detected: {cross_layer_pct:.1f}% of routing uses adjacent layers", logger)
    elif current_r > 1:
        _print_and_log(f"⚠️ Limited cross-layer reuse: {cross_layer_pct:.1f}% (expected >5% with R={current_r})", logger, level="warning")
        _print_and_log(" This may improve as training progresses and routers adapt.", logger)
    else:
        _print_and_log("ℹ️ R=1 mode: Only same-layer experts available (PSR warmup phase)", logger)
| # ==================== 3.5. CONVERGENCE MONITORING ==================== | |
def compute_routing_entropy(router_logits):
    """
    Compute the mean per-row entropy (in nats) of the routing distribution.

    High entropy = uniform routing (may indicate lack of specialization)
    Low entropy = concentrated routing (strong preferences)

    Args:
        router_logits: logits whose last dimension indexes experts; entropy is
            computed per row over that dimension and averaged over the rest.

    Returns:
        Python float.

    Low-precision inputs (bf16/fp16) are upcast to at least float32 before the
    softmax/log: routers in this file are cast to bf16, and computing in bf16
    visibly biases the metric (e.g. uniform over 8 experts gives ~2.078
    instead of ln 8 ~= 2.079). The 1e-10 epsilon keeps masked (-inf) logits
    finite: their probability is 0, so they contribute 0.
    """
    compute_dtype = torch.promote_types(router_logits.dtype, torch.float32)
    probs = torch.softmax(router_logits.to(compute_dtype), dim=-1)
    entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
    return entropy.item()
def check_router_convergence(model, total_layers, convergence_history, threshold=0.01, logger=None):
    """
    Check if routers have converged by analyzing:
    1. Router weight gradient norms (should be small)
    2. Gradient-norm stability over the last five checks (should be stable)
    3. Oscillation / growth of the gradient norm (should not fluctuate)

    Args:
        model: bare or PEFT-wrapped causal LM (unwrapped via get_backend_model).
        total_layers: unused; kept for call-site compatibility.
        convergence_history: list of previous average router grad norms;
            appended to in place when gradients are available.
        threshold: upper bound on ``grad_stability`` (variance / mean of the
            last five averages) required to report convergence.
        logger: optional logger for warnings.

    Returns:
        converged: bool
        metrics: dict of convergence metrics
        warnings: list of warning messages

    If no router gradients are found (e.g. grads were already cleared), the
    history is left untouched and ``converged`` is False: recording a 0 norm
    would otherwise make five such calls look perfectly converged.
    """
    # Always work on the underlying transformer stack that actually owns `.layers`,
    # whether `model` is a bare PhiMoEForCausalLM or a PEFT-wrapped model.
    backend_model = get_backend_model(model)
    router_grad_norms = []
    for layer in backend_model.layers:
        moe_block = getattr(layer, "block_sparse_moe", None)
        if not isinstance(moe_block, ReXMoESparseMoeBlock):
            continue
        gate = getattr(moe_block, 'gate', None)
        if gate is not None and gate.weight is not None and gate.weight.grad is not None:
            router_grad_norms.append(gate.weight.grad.norm().item())

    avg_grad_norm = sum(router_grad_norms) / len(router_grad_norms) if router_grad_norms else 0
    metrics = {
        'avg_router_grad_norm': avg_grad_norm,
        'max_router_grad_norm': max(router_grad_norms) if router_grad_norms else 0,
        'min_router_grad_norm': min(router_grad_norms) if router_grad_norms else 0,
    }
    warnings = []

    if not router_grad_norms:
        warnings.append("⚠️ No router gradients found - convergence cannot be assessed")
        if logger:
            logger.warning("No router gradients found - skipping convergence check")
        return False, metrics, warnings

    convergence_history.append(avg_grad_norm)
    # Need at least 5 checks of history
    if len(convergence_history) < 5:
        return False, metrics, warnings

    # Stability of the last five averages (variance relative to their mean).
    recent = torch.tensor(convergence_history[-5:])
    grad_variance = recent.var().item()
    grad_mean = recent.mean().item()
    metrics['grad_variance'] = grad_variance
    metrics['grad_stability'] = grad_variance / (grad_mean + 1e-10)

    # History has >= 5 entries here, so the last three always exist.
    before, middle, latest = convergence_history[-3:]
    # Oscillation: the middle value is a peak or a valley.
    if (middle > before and middle > latest) or (middle < before and middle < latest):
        warnings.append("⚠️ Oscillation detected in gradient norms - consider reducing learning rate")
        if logger:
            logger.warning("Oscillation detected in gradient norms")
    # Divergence: latest norm jumped by more than 20%.
    if latest > middle * 1.2:
        warnings.append("⚠️ Gradients increasing - possible divergence or high learning rate")
        if logger:
            logger.warning("Gradients increasing - possible divergence")
    if avg_grad_norm > 0.3:
        warnings.append(f"⚠️ High gradient norm ({avg_grad_norm:.4f}) - learning rate may be too high")
        if logger:
            logger.warning(f"High gradient norm: {avg_grad_norm:.4f}")

    # Converged if: gradients are small AND stable
    converged = (avg_grad_norm < 0.1) and (metrics['grad_stability'] < threshold)
    return converged, metrics, warnings
| # Helper: always get the underlying transformer stack (with `.layers`) | |
def get_backend_model(m):
    """Return the underlying transformer stack (the module owning ``.layers``).

    Peels off, in order, ``base_model``, ``model`` and ``model``; any hop whose
    attribute is missing leaves the current object unchanged. This covers
    both a PEFT-wrapped model (wrapper -> base_model -> causal LM -> stack)
    and a bare causal LM.
    """
    node = m
    for attr_name in ("base_model", "model", "model"):
        node = getattr(node, attr_name, node)
    return node
| # ==================== 4. TRAINING LOOP ==================== | |
| def train_rexmoe( | |
| model_name="microsoft/Phi-mini-MoE-instruct", | |
| model_path="../models/models/microsoft/Phi-mini-MoE-instruct", | |
| dataset_path="../dataset/alpaca_data_cleaned.json", | |
| dataset_mode: str = "IF", | |
| reuse_scale=3, | |
| num_samples=10000, | |
| num_epochs=5, | |
| batch_size=16, | |
| max_seq_length=512, | |
| lr=5e-6, | |
| warmup_steps=10, | |
| psr_enabled=True, | |
| save_path="./rexmoe_phi_mini_moe_r3", | |
| gradient_checkpointing=True, | |
| met_enabled=False, | |
| met_mask_ratio=0.1, | |
| met_warmup=0.5, | |
| eval_steps=1000, | |
| log_loss_steps_percent=10, | |
| full_lora=False, | |
| lora_r=16, | |
| use_scheduler=True, | |
| aux_loss_weight=0.02 | |
| ): | |
| # Setup logger | |
| logger, log_file = setup_logger(save_path=os.path.join(save_path, "logs")) | |
| print("="*80) | |
| print("ReXMoE Cross-Layer Expert Reuse Training") | |
| print("="*80) | |
| logger.info("="*80) | |
| logger.info("ReXMoE Cross-Layer Expert Reuse Training") | |
| logger.info("="*80) | |
| logger.info("MET enabled: {}".format(met_enabled)) | |
| config_msg = f""" | |
| Configuration: | |
| Model: {model_name} | |
| Dataset: {dataset_path} | |
| Dataset mode: {dataset_mode} | |
| Reuse Scale (R): {reuse_scale} | |
| Prune Ratio (MET): {met_mask_ratio if met_enabled else 'N/A'} | |
| Epochs: {num_epochs} | |
| Num of samples: {num_samples} | |
| Batch Size: {batch_size} | |
| Sequence Length: {max_seq_length} | |
| Learning Rate: {lr} | |
| PSR Enabled: {psr_enabled} | |
| LR Scheduler: {use_scheduler} | |
| Save Path: {save_path} | |
| Gradient Checkpointing: {gradient_checkpointing} | |
| LoRA Rank: {lora_r} (Full LoRA: {full_lora}) | |
| LoRA Alpha: {lora_r * 2} | |
| MET Enabled: {met_enabled} (Mask Ratio: {met_mask_ratio}, Warmup: {met_warmup}) | |
| Log File: {log_file} | |
| Aux loss weight: {aux_loss_weight} | |
| """ | |
| print(config_msg) | |
| logger.info(config_msg) | |
| print("="*80) | |
| if torch.cuda.is_available(): | |
| device = torch.device("cuda") | |
| else: | |
| device = torch.device("cpu") | |
| device_msg = f"💻 Using device: {device})" | |
| print(f"\n{device_msg}") | |
| logger.info(device_msg) | |
| if torch.cuda.is_available(): | |
| gpu_msg = f"GPU: {torch.cuda.get_device_name(0)}, Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB" | |
| print(f" {gpu_msg}") | |
| logger.info(gpu_msg) | |
| # Load tokenizer + model | |
| print(f"\n[1/7] Loading model: {model_name}") | |
| tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| print(f"Loading model to device {device} (no device_map sharding)...") | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_path, | |
| torch_dtype=torch.bfloat16, | |
| device_map=None, # Do NOT auto-shard - we'll place it manually | |
| trust_remote_code=False | |
| ) | |
| print(f"Moving model to {device}...") | |
| model = model.to(device) | |
| print(f"✓ Model moved to {device}") | |
| # Verify model is on correct device | |
| model_device = next(model.parameters()).device | |
| print(f"✓ Model device verified: {model_device}") | |
| if gradient_checkpointing: | |
| model.gradient_checkpointing_enable() | |
| print(f"Model loaded: {model.config.num_hidden_layers} layers") | |
| print(f"Hidden size: {model.config.hidden_size}") | |
| print(f"Experts per layer: {model.config.num_local_experts}") | |
| # Collect all experts from all layers | |
| print(f"\n[2/7] Collecting expert references from all layers...") | |
| total_layers = model.config.num_hidden_layers | |
| all_experts_dict = {} | |
| for layer_idx, layer in enumerate(model.model.layers): | |
| if hasattr(layer, "block_sparse_moe"): | |
| all_experts_dict[layer_idx] = layer.block_sparse_moe.experts | |
| print(f"Collected {len(all_experts_dict)} MoE layers") | |
| # Replace MoE blocks with ReXMoE blocks | |
| print(f"\n[3/7] Replacing MoE blocks with ReXMoE routers (R={reuse_scale})...") | |
| moe_count = 0 | |
| for layer_idx, layer in enumerate(model.model.layers): | |
| if hasattr(layer, "block_sparse_moe"): | |
| original_moe = layer.block_sparse_moe | |
| # Create ReXMoE block (keeps expert references, replaces router) | |
| rexmoe_block = ReXMoESparseMoeBlock( | |
| original_moe_block=original_moe, | |
| layer_idx=layer_idx, | |
| total_layers=total_layers, | |
| all_experts_dict=all_experts_dict, | |
| reuse_scale=reuse_scale, | |
| logger=logger, | |
| aux_loss_weight=aux_loss_weight | |
| ) | |
| rexmoe_block.met_enabled = met_enabled | |
| # Attach logger to block so its forward can access logging even when | |
| # the higher-level `model()` call doesn't pass a logger argument. | |
| rexmoe_block.logger = logger | |
| # Move ReXMoE block to correct device | |
| rexmoe_block = rexmoe_block.to(dtype=torch.bfloat16, device=device) | |
| # Replace the block | |
| layer.block_sparse_moe = rexmoe_block | |
| moe_count += 1 | |
| print(f"✓ ReXMoE blocks installed: {moe_count} layers modified") | |
| print(f" Each router can now access up to {reuse_scale * 16} experts (R={reuse_scale})") | |
| # Warmup phase: only routers (gates) trainable | |
| print(f"\n[4/7] Initial freeze: only routers trainable for warmup phase...") | |
| total_params = 0 | |
| trainable_params = 0 | |
| for name, param in model.named_parameters(): | |
| total_params += param.numel() | |
| if ".block_sparse_moe.gate" in name: | |
| # Router gates trainable | |
| param.requires_grad = True | |
| trainable_params += param.numel() | |
| else: | |
| param.requires_grad = False | |
| print(f"Total parameters: {total_params:,}") | |
| print(f"Trainable parameters (warmup): {trainable_params:,} ({100*trainable_params/total_params:.2f}%)") | |
| # Verify only routers are trainable at start | |
| trainable_layers = [name for name, param in model.named_parameters() if param.requires_grad] | |
| print(f"✓ Warmup trainable components: {len(trainable_layers)} parameters (router gates only)") | |
| if len(trainable_layers) > 0: | |
| print(f" First: {trainable_layers[0]}") | |
| print(f" Last: {trainable_layers[-1]}") | |
| # Optimizer (router params only) - Use 8-bit AdamW for memory efficiency | |
| print(f"\n[5/7] Setting up optimizer and dataset...") | |
| logger.info("[5/7] Setting up optimizer and dataset...") | |
| print("Using 8-bit AdamW optimizer for memory efficiency") | |
| logger.info("Using 8-bit AdamW optimizer") | |
| optimizer = bnb.optim.AdamW8bit( | |
| [p for p in model.parameters() if p.requires_grad], | |
| lr=lr, | |
| weight_decay=0.1 | |
| ) | |
| # Learning rate scheduler (cosine annealing) | |
| scheduler = None | |
| if use_scheduler: | |
| scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr * 0.1) | |
| print(f"Using CosineAnnealingLR scheduler: {lr} → {lr * 0.1}") | |
| logger.info(f"LR Scheduler: CosineAnnealingLR ({lr} → {lr * 0.1})") | |
| # Prepare dataset (Instruction Fine-tuning or Pretraining) | |
| print(f"[5/7] Preparing dataset: mode={dataset_mode}, path={dataset_path}") | |
| try: | |
| train_loader = get_dataloader( | |
| mode=dataset_mode, | |
| tokenizer=tokenizer, | |
| dataset_path=dataset_path, | |
| max_seq_length=max_seq_length, | |
| batch_size=batch_size, | |
| num_samples=num_samples, | |
| shuffle=True, | |
| ) | |
| except Exception as e: | |
| print(f"Could not prepare dataset: {e}") | |
| raise | |
| train_len = num_samples | |
| print(f"Training samples: {train_len}") | |
| print(f"Batch size: {batch_size}") | |
| print(f"Sequence length: {max_seq_length}") | |
| # Training loop | |
| print(f"\n[6/7] Starting training for {num_epochs} epochs...") | |
| print(f"PSR enabled: {psr_enabled}") | |
| total_steps = len(train_loader) * num_epochs | |
| # PSR schedule changes based on whether MET is enabled | |
| if psr_enabled: | |
| if met_enabled: | |
| warmup_steps_psr = int(met_warmup * total_steps) | |
| print(f"PSR schedule: R=2 → R={reuse_scale} during MET warmup phase (steps 0-{warmup_steps_psr})") | |
| print(f" then stays at R={reuse_scale} during pruning/finetuning phases") | |
| else: | |
| psr_completion_steps = int(0.5 * total_steps) | |
| print(f"PSR schedule: R=2 → R={reuse_scale} over first 50% of training (steps 0-{psr_completion_steps})") | |
| step = 0 | |
| # Track statistics | |
| first_batch_logged = False | |
| # Track routing patterns for analysis | |
| # Structure: routing_stats[layer_idx][(target_layer, target_expert)] = count | |
| routing_stats = {} | |
| for layer_idx in range(total_layers): | |
| routing_stats[layer_idx] = {} | |
| # Convergence tracking | |
| convergence_history = [] | |
| epoch_entropies = [] | |
| epoch_aux_losses = [] | |
| model.train() | |
| best_val = float("inf") | |
| best_epoch = -1 | |
| qlora_enabled = False # track switch from warmup (routers-only) to routers+LoRA | |
| print_met_active = False | |
| print_met_freeze = False | |
| for epoch in range(num_epochs): | |
| print(f"\n{'='*60}") | |
| print(f"Epoch {epoch+1}/{num_epochs}") | |
| print(f"{'='*60}") | |
| epoch_loss = 0 | |
| epoch_aux_loss = 0 | |
| epoch_entropy = 0 # Track routing entropy | |
| pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}") | |
| for batch_idx, batch in enumerate(pbar): | |
| # Switch from warmup (routers only) to routers + LoRA adapters on experts | |
| if (not qlora_enabled) and step >= warmup_steps: | |
| if full_lora: | |
| logger.info(f"Warmup completed at step {step}. Enabling FULL QLoRA with r = {lora_r} and alpha = {lora_r * 2} on experts and updating optimizer...") | |
| else: | |
| logger.info(f"Warmup completed at step {step}. Enabling QLoRA on experts.") | |
| # Attach LoRA adapters globally (wrap linear layers) | |
| # Note: PhiMoE experts use w1/w2/w3 linear layers, not gate_proj/up_proj/down_proj. | |
| # We target those names so LoRA attaches only to expert MLP weights. | |
| lora_config = LoraConfig( | |
| r=lora_r, | |
| lora_alpha=lora_r * 2, | |
| lora_dropout=0.00, | |
| bias="none", | |
| task_type="CAUSAL_LM", | |
| target_modules=[ | |
| "w1", "w2", "w3", | |
| # if full_lora, also target attention layers of transformer blocks (but NOT router gates) | |
| "q_proj" if full_lora else None, | |
| "k_proj" if full_lora else None, | |
| "v_proj" if full_lora else None, | |
| "o_proj" if full_lora else None | |
| ], | |
| ) | |
| model = get_peft_model(model, lora_config) | |
| # Freeze everything, then re-enable router gates and LoRA params | |
| total_params = 0 | |
| trainable_params = 0 | |
| for name, param in model.named_parameters(): | |
| total_params += param.numel() | |
| param.requires_grad = False | |
| for name, param in model.named_parameters(): | |
| if ".block_sparse_moe.gate" in name: | |
| param.requires_grad = True | |
| trainable_params += param.numel() | |
| elif "lora_" in name: | |
| param.requires_grad = True | |
| trainable_params += param.numel() | |
| optimizer = bnb.optim.AdamW8bit( | |
| [p for p in model.parameters() if p.requires_grad], | |
| lr=lr, | |
| weight_decay=0.1 | |
| ) | |
| print(f"Total parameters after QLoRA: {total_params:,}") | |
| print(f"Trainable parameters (routers + LoRA): {trainable_params:,} ({100*trainable_params/total_params:.4f}%)") | |
| logger.info(f"Trainable params (routers + LoRA): {trainable_params} ({100*trainable_params/total_params:.4f}%)") | |
| trainable_names = [n for n, p in model.named_parameters() if p.requires_grad] | |
| print("Sample trainable params after QLoRA:", trainable_names[:10]) | |
| logger.info(f"Sample trainable params after QLoRA: {trainable_names[:10]}") | |
| qlora_enabled = True | |
| # Update step counter in all MoE blocks (for PSR) | |
| # Note: after enabling QLoRA, `model` becomes a PeftModel whose | |
| # underlying transformer is in `model.base_model.model`. | |
| backend_model = get_backend_model(model) | |
| for layer in backend_model.layers: | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| layer.block_sparse_moe.current_step = step | |
| layer.block_sparse_moe.total_steps = total_steps | |
| # Pass met_warmup to PSR scheduler so it completes within warmup phase | |
| layer.block_sparse_moe.met_warmup = met_warmup if met_enabled else None | |
| # === IG-MET Global Threshold Update === | |
| # Key insight: Count UNIQUE experts (512 base), not router-level copies (up to 1024 with reuse). | |
| # If an expert is pruned, it's pruned everywhere it can be accessed. | |
| if met_enabled: | |
| # Calculate target mask ratio for this step with improved schedule | |
| final_mask_ratio = met_mask_ratio | |
| progress = min(step / total_steps, 1.0) | |
| # Improved Three-Phase Schedule (Aggressive): | |
| # Phase 1: 0-met_warmup - NO pruning (extended warmup for stability) | |
| # Phase 2: met_warmup-0.8 - Gradual pruning ramp with curve (avoid abrupt changes) | |
| # Phase 3: 0.8-100% - Freeze pruning, only fine-tune remaining experts | |
| phase2_end = 0.8 | |
| if progress < met_warmup: | |
| # Phase 1: Extended warmup, no pruning | |
| current_target_ratio = 0.0 | |
| elif progress < phase2_end: | |
| # Phase 2: Controlled pruning ramp with exponential curve | |
| pruning_window = phase2_end - met_warmup | |
| pruning_progress = (progress - met_warmup) / pruning_window # 0 to 1 | |
| # Use power curve to avoid aggressive early pruning | |
| # Power = 1.2 for smoother ramp since pruning window is longer | |
| current_target_ratio = final_mask_ratio * (pruning_progress ** 1.2) | |
| if not print_met_active: | |
| logger.info(f"[IG-MET] Entered pruning phase with gradual ramp (step={step}, target_ratio={current_target_ratio:.3f})") | |
| print_met_active = True | |
| else: | |
| # Phase 3: Freeze pruning decisions, only fine-tune remaining experts | |
| current_target_ratio = final_mask_ratio | |
| # Optionally freeze pruning masks here in the future | |
| if not hasattr(model, '_pruning_frozen'): | |
| model._pruning_frozen = True | |
| if not print_met_freeze: | |
| logger.info("[IG-MET] Entered fine-tuning phase (pruning decisions frozen)") | |
| print_met_freeze = True | |
| if current_target_ratio > 0: | |
| if step % 100 == 0: | |
| logger.info(f"[IG-MET] Masked Expert Training is now ACTIVE (step={step}, target_ratio={current_target_ratio:.3f})") | |
| # Collect EMA for UNIQUE (layer_idx, expert_idx) pairs only (not duplicates) | |
| # Use "SUM" aggregation: we care about the total utility of an expert across all contexts. | |
| unique_experts = {} # (orig_layer, orig_expert) -> summed_ema_score | |
| # First Pass: Aggregation | |
| for layer_idx, layer in enumerate(backend_model.layers): | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| router = layer.block_sparse_moe.router | |
| # Reuse same logic as Router to map pool_pos -> (orig_layer, orig_expert) | |
| current_r = router.get_candidate_layers(step, total_steps) | |
| half = (current_r - 1) // 2 | |
| start_layer = max(0, layer_idx - half) | |
| end_layer = min(total_layers, start_layer + current_r) | |
| start_layer = max(0, end_layer - current_r) | |
| # Reconstruct mapping for this router step | |
| current_mapping = [] | |
| for layer_offset in range(current_r): | |
| l_id = start_layer + layer_offset | |
| if l_id >= total_layers: break | |
| for e_id in range(router.num_experts_per_layer): | |
| current_mapping.append((l_id, e_id)) | |
| num_active = len(current_mapping) | |
| # Aggregate EMA | |
| for pool_pos in range(num_active): | |
| if pool_pos >= len(router.ema_utilization): break | |
| key = current_mapping[pool_pos] # (orig_layer, orig_expert) | |
| ema_val = router.ema_utilization[pool_pos].item() | |
| if key not in unique_experts: | |
| unique_experts[key] = ema_val | |
| else: | |
| # Sum EMA across all reused contexts | |
| unique_experts[key] += ema_val | |
| # Compute threshold based on UNIQUE SUMMED experts | |
| if unique_experts: | |
| all_ema_values = list(unique_experts.values()) | |
| all_ema_tensor = torch.tensor(all_ema_values, device=device) | |
| k = int(len(all_ema_values) * current_target_ratio) | |
| # Determine set of GLOBALLY pruned experts | |
| pruned_keys = set() | |
| threshold = 0.0 | |
| if k > 0 and all_ema_tensor.sum() > 0: | |
| sorted_vals, _ = torch.sort(all_ema_tensor) | |
| threshold = sorted_vals[k].item() | |
| # Identify which UNIQUE experts are below the global sum threshold | |
| pruned_keys = {key for key, val in unique_experts.items() if val < threshold} | |
| # Second Pass: Distribute Pruning Mask to Routers | |
| # Instead of a scalar threshold (which fails for summed aggregation), | |
| # we push a binary mask of "keep vs prune" to each router. | |
| for layer_idx, layer in enumerate(backend_model.layers): | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| router = layer.block_sparse_moe.router | |
| # Re-calculate mapping to generate mask | |
| current_r = router.get_candidate_layers(step, total_steps) | |
| half = (current_r - 1) // 2 | |
| start_layer = max(0, layer_idx - half) | |
| end_layer = min(total_layers, start_layer + current_r) | |
| start_layer = max(0, end_layer - current_r) | |
| current_mapping = [] | |
| for layer_offset in range(current_r): | |
| l_id = start_layer + layer_offset | |
| if l_id >= total_layers: break | |
| for e_id in range(router.num_experts_per_layer): | |
| current_mapping.append((l_id, e_id)) | |
| # Create binary mask: True = KEEP, False = PRUNE | |
| # Size = max_pool_size (pad with True to be safe) | |
| keep_mask = torch.ones(router.max_pool_size, dtype=torch.bool, device=device) | |
| for pool_pos, key in enumerate(current_mapping): | |
| if key in pruned_keys: | |
| keep_mask[pool_pos] = False | |
| # Push mask to router | |
| router.global_keep_mask = keep_mask | |
| # We also update mask_threshold for logging purposes | |
| router.mask_threshold.fill_(threshold) | |
| # Log statistics | |
| if step % 10 == 0: | |
| total_unique_pruned = len(pruned_keys) | |
| total_unique_active = len(unique_experts) | |
| logger.info(f"[IG-MET Global] Step {step}: Threshold={threshold:.6f}. Pruned {total_unique_pruned}/{total_unique_active} UNIQUE experts ({100*total_unique_pruned/total_unique_active:.1f}%). Target ratio: {current_target_ratio:.3f}") | |
| else: | |
| # Current ratio too small to mask any expert yet | |
| for layer in backend_model.layers: | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| layer.block_sparse_moe.router.mask_threshold.fill_(-1.0) | |
| if hasattr(layer.block_sparse_moe.router, 'global_keep_mask'): | |
| layer.block_sparse_moe.router.global_keep_mask = None | |
| else: | |
| # No experts found (shouldn't happen) | |
| for layer in backend_model.layers: | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| layer.block_sparse_moe.router.mask_threshold.fill_(-1.0) | |
| else: | |
| # Reset threshold (no masking) during 0-50% warmup phase | |
| for layer in backend_model.layers: | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| layer.block_sparse_moe.router.mask_threshold.fill_(-1.0) | |
| input_ids = batch["input_ids"].to(model.device) | |
| attention_mask = batch["attention_mask"].to(model.device) | |
| labels = batch["labels"].to(model.device) | |
| # Forward pass with PSR-aware routing | |
| # todo: Why not use | |
| ''' | |
| outputs = model( | |
| instructions=instructions, | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| labels=labels | |
| ) | |
| ''' | |
| outputs = model( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| labels=labels, | |
| use_cache=False | |
| ) | |
| loss = outputs.loss | |
| # Collect auxiliary losses from all ReXMoE routers | |
| # This is CRITICAL for load balancing and preventing routing collapse | |
| aux_loss_total = 0.0 | |
| for layer in backend_model.layers: | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| if layer.block_sparse_moe.last_aux_loss is not None: | |
| aux_loss_total += layer.block_sparse_moe.last_aux_loss | |
| # Collect routing statistics (which experts were selected) | |
| for layer_idx, layer in enumerate(backend_model.layers): | |
| if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): | |
| moe_block = layer.block_sparse_moe | |
| # Get the layer-expert mapping from the last forward pass | |
| router = moe_block.router | |
| # Don't call router with hidden_states=None (router.forward expects a tensor). | |
| # Prefer the last computed mapping from a real forward pass; if not | |
| # available, query the router with a small dummy tensor to get the mapping. | |
| layer_expert_mapping = getattr(router, 'last_layer_expert_mapping', None) | |
| if layer_expert_mapping is None: | |
| try: | |
| hidden_dim = router.gate.in_features if hasattr(router, 'gate') else moe_block.hidden_dim | |
| dummy = torch.zeros(1, 1, hidden_dim, device=model.device) | |
| _, _, _, layer_expert_mapping = router( | |
| hidden_states=dummy, | |
| step=step, | |
| total_steps=total_steps | |
| ) | |
| except Exception: | |
| layer_expert_mapping = [] | |
| # Note: For efficiency, we'll track routing patterns every N batches | |
| # to avoid slowdown. Full tracking can be enabled for analysis. | |
| pass # Detailed tracking will be done in a separate analysis pass | |
| # Compute routing entropy for convergence monitoring (sample from output) | |
| # We'll approximate entropy from the auxiliary loss and routing distribution | |
| # Note: Full entropy computation would require storing all router outputs | |
| # Total loss = Language modeling loss + Auxiliary load balancing loss | |
| total_loss = loss + aux_loss_total | |
| # Backward pass | |
| optimizer.zero_grad() | |
| total_loss.backward() | |
| torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) | |
| optimizer.step() | |
| # Logging | |
| epoch_loss += loss.item() | |
| epoch_aux_loss += aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total | |
| # Calculate current reuse scale for logging (match the PSR logic in router) | |
| if psr_enabled: | |
| if met_enabled: | |
| # New behavior: PSR completes within first phase (0 to met_warmup) | |
| progress = min(step / (met_warmup * total_steps), 1.0) | |
| current_r = 2 + int(progress * (reuse_scale - 2)) | |
| else: | |
| # Legacy behavior: Linear schedule over first 50% of training | |
| progress = min(step / (0.5 * total_steps), 1.0) | |
| current_r = 2 + int(progress * (reuse_scale - 2)) | |
| # print(f"current_r: {current_r}, progress: {progress}") | |
| else: | |
| current_r = reuse_scale | |
| # Log first batch details | |
| if not first_batch_logged: | |
| logger.info(f"\n First batch statistics:") | |
| logger.info(f" LM Loss: {loss.item():.4f}") | |
| logger.info(f" Aux Loss: {aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f}") | |
| logger.info(f" Total Loss: {total_loss.item():.4f}") | |
| logger.info(f" Current R: {current_r}") | |
| logger.info(f" Active experts per layer: {current_r * 16}") | |
| logger.info(f" Gradient norm: {torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')):.4f}") | |
| first_batch_logged = True | |
| logger.info(f" \n") | |
| # Print periodic updates (every log_loss_steps_percent% of epoch) | |
| if batch_idx > 0 and batch_idx % max(1, len(train_loader) // (100 // log_loss_steps_percent)) == 0: | |
| logger.info(f" [{batch_idx}/{len(train_loader)}] loss={loss.item():.4f} aux={aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f} R={current_r}") | |
| pbar.set_postfix({ | |
| "loss": f"{loss.item():.4f}", | |
| "aux": f"{aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f}", | |
| "total": f"{total_loss.item():.4f}", | |
| "R": current_r, | |
| "step": f"{step}/{total_steps}" | |
| }) | |
| step += 1 | |
| # Evaluate every eval_steps (after batching/optimization) | |
| if step % eval_steps == 0: | |
| model.eval() | |
| logger.info(f"\n[Step {step}/{total_steps}] Running evaluation at eval_steps...") | |
| evaluate_prompt(model, tokenizer, logger=logger) | |
| model.train() | |
| # Routing analysis at eval_steps | |
| logger.info(f"\n[Step {step}] Analyzing routing patterns at eval_steps...") | |
| analyze_routing_patterns(model, train_loader, current_r, total_layers, device, logger=logger) | |
| # Save a checkpoint at this step | |
| logger.info(f"\n[Step {step}] Saving checkpoint at eval_steps to {save_path}...") | |
| os.makedirs(save_path, exist_ok=True) | |
| # Option 1: Save only router weights (recommended - more portable) | |
| print("\nSaving trained router weights only...") | |
| router_state_dict = {} | |
| for name, param in model.named_parameters(): | |
| if ".block_sparse_moe.gate" in name and param.requires_grad: | |
| router_state_dict[name] = param.data.cpu() | |
| # IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation | |
| for name, buf in model.named_buffers(): | |
| if "ema_utilization" in name or "mask_threshold" in name: | |
| logger.info(f"Saving buffer {name} with shape {buf.shape} for pruning evaluation") | |
| router_state_dict[name] = buf.data.cpu() | |
| else: | |
| logger.info(f"Skipping buffer {name} (not related to routing)") | |
| torch.save({ | |
| 'router_state_dict': router_state_dict, | |
| 'config': { | |
| 'reuse_scale': reuse_scale, | |
| 'num_epochs': num_epochs, | |
| 'lr': lr, | |
| 'model_name': model_name | |
| } | |
| }, os.path.join(save_path, 'rexmoe_routers.pt')) | |
| tokenizer.save_pretrained(save_path) | |
| logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters") | |
| logger.info(f" File: {save_path}/rexmoe_routers.pt") | |
| logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB") | |
| # Option 2: Also save full model (includes architecture, but less portable) | |
| logger.info("\nAlso saving full model with ReXMoE architecture...") | |
| model.save_pretrained(save_path) | |
| # Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA enabled) | |
| # This produces a standard HF model directory that can be loaded without PEFT. | |
| if qlora_enabled: | |
| try: | |
| merged_path = os.path.join(save_path, "merged") | |
| os.makedirs(merged_path, exist_ok=True) | |
| logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}") | |
| # merge_and_unload() returns the base model with LoRA weights folded in. | |
| merged_model = model.merge_and_unload() | |
| merged_model.eval() | |
| tokenizer.save_pretrained(merged_path) | |
| merged_model.save_pretrained(merged_path) | |
| logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading") | |
| except Exception as e: | |
| logger.warning(f"Could not merge and save LoRA weights (continuing): {e}") | |
| avg_epoch_loss = epoch_loss / len(train_loader) | |
| avg_epoch_aux_loss = epoch_aux_loss / len(train_loader) | |
| # Store metrics for convergence tracking | |
| epoch_aux_losses.append(avg_epoch_aux_loss) | |
| logger.info(f"\n{'='*60}") | |
| logger.info(f"Epoch {epoch+1} Summary:") | |
| logger.info(f" Average LM Loss: {avg_epoch_loss:.4f}") | |
| logger.info(f" Average Aux Loss: {avg_epoch_aux_loss:.6f}") | |
| logger.info(f" Average Total Loss: {avg_epoch_loss + avg_epoch_aux_loss:.4f}") | |
| logger.info(f" Final R: {current_r}") | |
| # Evaluate at epoch end | |
| model.eval() | |
| evaluate_prompt(model, tokenizer, logger=logger) | |
| model.train() | |
| # Track and save best checkpoint based on average LM loss | |
| if avg_epoch_loss < best_val: | |
| best_val = avg_epoch_loss | |
| best_epoch = epoch + 1 | |
| logger.info(f"New best epoch {best_epoch} with avg LM loss {best_val:.4f} — saving checkpoint to {save_path}") | |
| os.makedirs(save_path, exist_ok=True) | |
| # Option 1: Save only router weights (recommended - more portable) | |
| print("\nSaving trained router weights only...") | |
| router_state_dict = {} | |
| for name, param in model.named_parameters(): | |
| if ".block_sparse_moe.gate" in name and param.requires_grad: | |
| router_state_dict[name] = param.data.cpu() | |
| # IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation | |
| for name, buf in model.named_buffers(): | |
| if "ema_utilization" in name or "mask_threshold" in name: | |
| logger.info(f"Saving buffer {name} with shape {buf.shape} for pruning evaluation") | |
| router_state_dict[name] = buf.data.cpu() | |
| else: | |
| logger.info(f"Skipping buffer {name} (not related to routing)") | |
| torch.save({ | |
| 'router_state_dict': router_state_dict, | |
| 'config': { | |
| 'reuse_scale': reuse_scale, | |
| 'num_epochs': num_epochs, | |
| 'lr': lr, | |
| 'model_name': model_name | |
| } | |
| }, os.path.join(save_path, 'rexmoe_routers.pt')) | |
| tokenizer.save_pretrained(save_path) | |
| logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters") | |
| logger.info(f" File: {save_path}/rexmoe_routers.pt") | |
| logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB") | |
| # Option 2: Also save full model (includes architecture, but less portable) | |
| logger.info("\nAlso saving full model with ReXMoE architecture...") | |
| model.save_pretrained(save_path) | |
| # Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA enabled) | |
| # This produces a standard HF model directory that can be loaded without PEFT. | |
| if qlora_enabled: | |
| try: | |
| merged_path = os.path.join(save_path, "merged") | |
| os.makedirs(merged_path, exist_ok=True) | |
| logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}") | |
| # merge_and_unload() returns the base model with LoRA weights folded in. | |
| merged_model = model.merge_and_unload() | |
| merged_model.eval() | |
| tokenizer.save_pretrained(merged_path) | |
| merged_model.save_pretrained(merged_path) | |
| logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading") | |
| except Exception as e: | |
| logger.warning(f"Could not merge and save LoRA weights (continuing): {e}") | |
| # Check convergence | |
| converged, conv_metrics, conv_warnings = check_router_convergence( | |
| model, total_layers, convergence_history, logger=logger | |
| ) | |
| # Get current learning rate | |
| current_lr = optimizer.param_groups[0]['lr'] | |
| logger.info(f"\n 📊 Convergence Metrics:") | |
| logger.info("Convergence Metrics:") | |
| logger.info(f" Avg Router Grad Norm: {conv_metrics['avg_router_grad_norm']:.6f}") | |
| if 'grad_stability' in conv_metrics: | |
| print(f" Grad Stability: {conv_metrics['grad_stability']:.6f}") | |
| logger.info(f" Grad Stability: {conv_metrics['grad_stability']:.6f}") | |
| print(f" Current Learning Rate: {current_lr:.2e}") | |
| logger.info(f" Current Learning Rate: {current_lr:.2e}") | |
| if len(epoch_aux_losses) >= 2: | |
| aux_change = abs(epoch_aux_losses[-1] - epoch_aux_losses[-2]) | |
| print(f" Aux Loss Change: {aux_change:.6f}") | |
| logger.info(f" Aux Loss Change: {aux_change:.6f}") | |
| # Print warnings | |
| if conv_warnings: | |
| for warning in conv_warnings: | |
| print(f"\n {warning}") | |
| # Already logged in check_router_convergence | |
| if converged : | |
| msg = "✅ CONVERGED: Router weights have stabilized! Gradient norm < 0.1 and stable for 5 epochs" | |
| print(f"\n {msg}") | |
| logger.info(msg) | |
| if not getattr(mode, '_routers_soft_fronze', False): | |
| logger.info("Soft freezing routers for remaining epochs to focus on fine-tuning experts") | |
| for param_group in optimizer.param_groups: | |
| for name, param in model.named_parameters() and param in param_group['params']: | |
| if ".block_sparse_moe.gate" in name: | |
| param_group['weight_decay'] = 0.5 | |
| param_group['lr'] = current_lr * 0.1 | |
| setattr(mode, '_routers_soft_fronze', True) | |
| model._routers_soft_fronze = True | |
| elif len(convergence_history) >= 5: | |
| msg = "⏳ Not yet converged - continuing training..." | |
| print(f"\n {msg}") | |
| logger.info(msg) | |
| else: | |
| msg = "ℹ️ Collecting convergence data (need 5 epochs minimum)..." | |
| print(f"\n {msg}") | |
| logger.info(msg) | |
| print(f"{'='*60}") | |
| # Analyze routing patterns at epoch end | |
| print(f"\n📊 Routing Pattern Analysis (Epoch {epoch+1}):") | |
| logger.info(f"Routing Pattern Analysis (Epoch {epoch+1}):") | |
| print("-" * 60) | |
| analyze_routing_patterns(model, train_loader, current_r, total_layers, device, logger=logger) | |
| print("-" * 60) | |
| # Step the learning rate scheduler | |
| if scheduler is not None: | |
| scheduler.step() | |
| logger.info(f"LR stepped to: {optimizer.param_groups[0]['lr']:.2e}") | |
| # Final convergence report | |
| print(f"\n{'='*80}") | |
| print("📈 Training Convergence Summary") | |
| print(f"{'='*80}") | |
| logger.info("="*80) | |
| logger.info("Training Convergence Summary") | |
| logger.info("="*80) | |
| if len(convergence_history) > 0: | |
| print(f"\nRouter Gradient Norms Over Epochs:") | |
| logger.info("Router Gradient Norms Over Epochs:") | |
| for i, grad_norm in enumerate(convergence_history): | |
| trend = "" | |
| if i > 0: | |
| change = grad_norm - convergence_history[i-1] | |
| trend = f" (Δ {change:+.6f})" | |
| msg = f" Epoch {i+1}: {grad_norm:.6f}{trend}" | |
| print(msg) | |
| logger.info(msg) | |
| if len(epoch_aux_losses) > 0: | |
| print(f"\nAuxiliary Loss Over Epochs:") | |
| logger.info("Auxiliary Loss Over Epochs:") | |
| for i, aux_loss in enumerate(epoch_aux_losses): | |
| trend = "" | |
| if i > 0: | |
| change = aux_loss - epoch_aux_losses[i-1] | |
| trend = f" (Δ {change:+.6f})" | |
| msg = f" Epoch {i+1}: {aux_loss:.6f}{trend}" | |
| print(msg) | |
| logger.info(msg) | |
| # Final convergence assessment | |
| if len(convergence_history) >= 5: | |
| final_converged, final_metrics, final_warnings = check_router_convergence( | |
| model, total_layers, convergence_history, logger=logger | |
| ) | |
| print(f"\n{'='*80}") | |
| print(f"Final Convergence Status:") | |
| logger.info("="*80) | |
| logger.info("Final Convergence Status:") | |
| if final_converged: | |
| msg = "✅ CONVERGED - Routers have reached stable configuration" | |
| print(f" {msg}") | |
| logger.info(msg) | |
| print(f" - Gradient norm: {final_metrics['avg_router_grad_norm']:.6f} (< 0.1)") | |
| print(f" - Stability: {final_metrics['grad_stability']:.6f} (< 0.01)") | |
| logger.info(f" Gradient norm: {final_metrics['avg_router_grad_norm']:.6f}") | |
| logger.info(f" Stability: {final_metrics['grad_stability']:.6f}") | |
| msg = "Safe to deploy or proceed to parameter merging" | |
| print(f" {msg}") | |
| logger.info(msg) | |
| else: | |
| msg = "⚠️ NOT FULLY CONVERGED" | |
| print(f" {msg}") | |
| logger.warning(msg) | |
| print(f" Current metrics:") | |
| print(f" - Gradient norm: {final_metrics['avg_router_grad_norm']:.6f} (target: < 0.1)") | |
| logger.info(f" Gradient norm: {final_metrics['avg_router_grad_norm']:.6f}") | |
| if 'grad_stability' in final_metrics: | |
| print(f" - Stability: {final_metrics['grad_stability']:.6f} (target: < 0.01)") | |
| logger.info(f" Stability: {final_metrics['grad_stability']:.6f}") | |
| print(f" Consider training for more epochs if:") | |
| print(f" - Aux loss still decreasing significantly") | |
| print(f" - Routing patterns still changing") | |
| print(f" - Gradient norms not stabilized") | |
| print(f"{'='*80}\n") | |
| logger.info("="*80) | |
| else: | |
| print(f"\n{'='*80}") | |
| print(f"Convergence Status: Insufficient data (< 5 epochs)") | |
| print(f" Run for at least 5 epochs for convergence analysis") | |
| print(f"{'='*80}\n") | |
| logger.info("Convergence Status: Insufficient data (< 5 epochs)") | |
| # Save model | |
| print(f"\n[7/7] Saving router-adapted checkpoint to: {save_path}") | |
| os.makedirs(save_path, exist_ok=True) | |
| # Option 1: Save only router weights (recommended - more portable) | |
| logger.info("\nSaving trained router weights only...") | |
| router_state_dict = {} | |
| for name, param in model.named_parameters(): | |
| if ".block_sparse_moe.gate" in name and param.requires_grad: | |
| router_state_dict[name] = param.data.cpu() | |
| # IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation | |
| for name, buf in model.named_buffers(): | |
| if "ema_utilization" in name or "mask_threshold" in name: | |
| router_state_dict[name] = buf.data.cpu() | |
| torch.save({ | |
| 'router_state_dict': router_state_dict, | |
| 'config': { | |
| 'reuse_scale': reuse_scale, | |
| 'num_epochs': num_epochs, | |
| 'lr': lr, | |
| 'model_name': model_name | |
| } | |
| }, os.path.join(save_path, 'rexmoe_routers.pt')) | |
| tokenizer.save_pretrained(save_path) | |
| logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters") | |
| logger.info(f" File: {save_path}/rexmoe_routers.pt") | |
| logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB") | |
| # Option 2: Also save full model (includes architecture, but less portable) | |
| logger.info("\nAlso saving full model with ReXMoE architecture...") | |
| model.save_pretrained(save_path) | |
| # Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA was enabled) | |
| if qlora_enabled: | |
| try: | |
| merged_path = os.path.join(save_path, "merged") | |
| os.makedirs(merged_path, exist_ok=True) | |
| logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}") | |
| merged_model = model.merge_and_unload() | |
| merged_model.eval() | |
| tokenizer.save_pretrained(merged_path) | |
| merged_model.save_pretrained(merged_path) | |
| logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading") | |
| except Exception as e: | |
| logger.warning(f"Could not merge and save LoRA weights (continuing): {e}") | |
| # Save the custom classes for reloading | |
| import shutil | |
| shutil.copy(__file__, os.path.join(save_path, 'rexmoe_architecture.py')) | |
| # Print final statistics | |
| full_model_size = sum(os.path.getsize(os.path.join(save_path, f)) | |
| for f in os.listdir(save_path) | |
| if f.endswith('.bin') or f.endswith('.safetensors')) / 1024 / 1024 / 1024 | |
| logger.info("="*80) | |
| logger.info("✓ Training complete. Two checkpoint formats saved:") | |
| logger.info(" 1. Router weights only: rexmoe_routers.pt (portable)") | |
| logger.info(" 2. Full model: pytorch_model.bin (requires rexmoe_architecture.py)") | |
| logger.info(f"\nCheckpoint directory: {save_path}") | |
| logger.info(f"Full model size: {full_model_size:.2f} GB") | |
| logger.info("="*80) | |
| return model | |
# ==================== 5. USAGE ====================
def build_arg_parser():
    """Build the command-line parser for ReXMoE training.

    Returns:
        argparse.ArgumentParser: parser whose namespace supplies every keyword
        argument passed to ``train_rexmoe`` in the ``__main__`` block below.

    Note:
        argparse %-formats help strings when rendering ``--help``, so a literal
        percent sign must be written as ``%%``; a bare ``%`` makes ``--help``
        raise ``TypeError``.
    """
    parser = argparse.ArgumentParser(description="ReXMoE Training")
    parser.add_argument("--model_name", type=str, default="microsoft/Phi-mini-MoE-instruct", help="Pretrained model name")
    parser.add_argument("--model_path", type=str, default="microsoft/Phi-mini-MoE-instruct", help="Path to pretrained model")  # ../models/models/microsoft/Phi-mini-MoE-instruct
    parser.add_argument("--dataset_path", type=str, default="../dataset/alpaca_data_cleaned.json", help="Path to dataset JSON")
    parser.add_argument("--mode", type=str, choices=["IF", "P", "IF_2"], default="IF", help="Dataset mode: IF = instruction-finetune (Alpaca), P = pretraining (C4)")
    parser.add_argument("--reuse_scale", type=int, default=3, help="Reuse scale R for cross-layer routing")
    parser.add_argument("--epoch", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--num_samples", type=int, default=10000, help="Number of training samples to use from the dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
    parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
    parser.add_argument("--warmup_steps", type=int, default=200, help="Number of steps to warm up with R=1 (routers only)")
    parser.add_argument("--psr_enabled", action='store_true', help="Enable Progressive Scaling Routing (PSR)")
    parser.add_argument("--use_scheduler", action='store_true', default=True, help="Use learning rate scheduler")
    # --use_scheduler defaults to True, so without an opposing flag the
    # scheduler could never be turned off from the command line.
    parser.add_argument("--no_scheduler", dest="use_scheduler", action='store_false', help="Disable the learning rate scheduler (enabled by default)")
    parser.add_argument("--gradient_checkpointing", action='store_true', help="Enable gradient checkpointing for memory efficiency")
    parser.add_argument("--met_enabled", action='store_true', help="Enable Masked Expert Training (MET)")
    # "%%" renders as "%" in --help; a bare "%" here crashed help formatting.
    parser.add_argument("--met_mask_ratio", type=float, default=0.1, help="MET mask ratio (0.1 = mask 10%% of experts)")
    parser.add_argument("--met_warmup", type=float, default=0.5, help="Proportion of steps to warm up MET (no masking)")
    parser.add_argument("--eval_steps", type=int, default=500, help="Evaluate every N steps during training")
    parser.add_argument("--log_loss_steps_percent", type=int, default=10, help="Log loss every N%% of total steps")
    parser.add_argument("--full_lora", action='store_true', help="Enable full LoRA training")
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--save_path", type=str, default="./rexmoe_natural_phi_mini_moe", help="Base path to save trained model (timestamp will be prefixed)")
    parser.add_argument("--aux_loss_weight", type=float, default=0.02, help="Auxiliary loss weight")
    return parser


def make_timed_save_path(save_path, met_mask_ratio, reuse_scale, timestamp=None):
    """Return a run-specific output path derived from ``save_path``.

    The basename is prefixed with ``<timestamp>_<mask percent>_`` and the
    whole path is suffixed with ``_R<reuse_scale>``; e.g. ``./out`` with
    ratio 0.1 and R=3 becomes ``./<timestamp>_10_out_R3``.

    Args:
        save_path: base output path from the command line.
        met_mask_ratio: MET mask ratio, encoded as a whole percentage.
        reuse_scale: reuse scale R, appended as ``_R<R>``.
        timestamp: prefix string; defaults to the current local time
            formatted as ``DDMM_HHMMSS``.

    Returns:
        str: the timestamped output path.
    """
    if timestamp is None:
        from datetime import datetime as _dt
        timestamp = _dt.now().strftime("%d%m_%H%M%S")
    # Strip trailing separators so "out/" is treated like "out" instead of
    # yielding an empty basename (and a malformed run directory name).
    base = save_path.rstrip("/" + os.sep) or save_path
    # round(), not int(): int(0.29 * 100) == 28 because of float error.
    mask_pct = round(met_mask_ratio * 100)
    return os.path.join(os.path.dirname(base), f"{timestamp}_{mask_pct}_" + os.path.basename(base)) + f"_R{reuse_scale}"


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    # Prefix save path with timestamp (DDMM_HHMMSS) and MET mask percentage to distinguish runs
    timed_save_path = make_timed_save_path(args.save_path, args.met_mask_ratio, args.reuse_scale)
    model = train_rexmoe(
        model_name=args.model_name,
        model_path=args.model_path,
        dataset_path=args.dataset_path,
        dataset_mode=args.mode,
        reuse_scale=args.reuse_scale,
        num_samples=args.num_samples,
        num_epochs=args.epoch,  # 5 epochs sufficient for router adaptation
        batch_size=args.batch_size,  # As specified
        max_seq_length=args.max_seq_length,  # As specified
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        psr_enabled=args.psr_enabled,  # Critical: prevents early collapse
        use_scheduler=args.use_scheduler,
        gradient_checkpointing=args.gradient_checkpointing,
        met_enabled=args.met_enabled,
        met_mask_ratio=args.met_mask_ratio,
        met_warmup=args.met_warmup,
        eval_steps=args.eval_steps,
        log_loss_steps_percent=args.log_loss_steps_percent,
        full_lora=args.full_lora,
        lora_r=args.lora_r,
        save_path=timed_save_path,
        aux_loss_weight=args.aux_loss_weight
    )
    print(f"\n✓ All done! Model saved to {timed_save_path}")