import os os.environ["CUDA_VISIBLE_DEVICES"] = "0" import torch import torch.nn as nn from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset from torch.utils.data import DataLoader from tqdm import tqdm import json import bitsandbytes as bnb from peft import LoraConfig, get_peft_model import argparse import logging from datetime import datetime from torch.optim.lr_scheduler import CosineAnnealingLR from typing import Optional from get_dataset import get_dataloader if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") # ==================== 0. LOGGING SETUP ==================== def setup_logger(save_path="./logs"): """Setup logger with timestamp-based filename""" os.makedirs(save_path, exist_ok=True) # Create unique log filename: DDMM_HHMMSS timestamp = datetime.now().strftime("%d%m_%H%M%S") log_file = os.path.join(save_path, f"rexmoe_training_{timestamp}.log") # Create logger logger = logging.getLogger('ReXMoE') logger.setLevel(logging.INFO) # Remove existing handlers logger.handlers = [] # File handle file_handler = logging.FileHandler(log_file) file_handler.setLevel(logging.INFO) # Console handler console_handler = logging.StreamHandler() console_handler.setLevel(logging.INFO) # Formatter formatter = logging.Formatter( '%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) file_handler.setFormatter(formatter) console_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.addHandler(console_handler) logger.info(f"=" * 80) logger.info(f"ReXMoE Training Log - {timestamp}") logger.info(f"Log file: {log_file}") logger.info(f"=" * 80) return logger, log_file # Format prompt def format_alpaca_prompt(instruction: str, input_text: str = "") -> str: """Match the training prompt template used in main.py.""" if input_text: return f"### Instruction:\n{instruction}\n\n### Input:\n{input_text}\n\n### Response:\n" return f"### Instruction:\n{instruction}\n\n### Response:\n" def build_model_input(tokenizer, instruction: str, input_text: str = "") -> str: """Prefer the model chat template if available; fall back to Alpaca prompt.""" user_msg = instruction if not input_text else f"{instruction}\n\n{input_text}" # Newer HF tokenizers expose an explicit chat template for instruct models. if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None): messages = [{"role": "user", "content": user_msg}] print(f"Applying tokenizer's chat template: {tokenizer.chat_template}") return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) return format_alpaca_prompt(instruction=instruction, input_text=input_text) # Evaluate # ...existing code... def evaluate_prompt(model, tokenizer, max_new_tokens=100, do_sample=True, temperature=0.7, logger=None): """Generate completions for 3 sample prompts and print/log results.""" try: # Safely get device for tensors if hasattr(model, "device"): device = model.device else: # fallback: first parameter device device = next(model.parameters()).device msg = "\nEvaluating model with 3 sample prompts..." if logger: logger.info(msg) print(msg) # Display pruning status if IG-MET is enabled (count UNIQUE experts, not router-level copies) backend_model = get_backend_model(model) pruning_info = [] unique_experts_pruned = set() unique_experts_total = set() unique_experts_sum = {} # (orig_layer, orig_expert) -> summed_ema_score # 1. Aggregate EMA values for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): router = layer.block_sparse_moe.router threshold = router.mask_threshold.item() if threshold >= 0: # IG-MET enabled # Reconstruct mapping logic carefully current_r = router.get_candidate_layers(step=None, total_steps=None) half = (current_r - 1) // 2 start_layer = max(0, layer_idx - half) end_layer = min(len(backend_model.layers), start_layer + current_r) start_layer = max(0, end_layer - current_r) # Build mapping for this router current_mapping = [] for layer_offset in range(current_r): l_id = start_layer + layer_offset if l_id >= len(backend_model.layers): break for e_id in range(router.num_experts_per_layer): current_mapping.append((l_id, e_id)) num_active = len(current_mapping) # Accumulate EMA for pool_pos, key in enumerate(current_mapping): if pool_pos >= len(router.ema_utilization): break unique_experts_total.add(key) ema_val = router.ema_utilization[pool_pos].item() if key not in unique_experts_sum: unique_experts_sum[key] = ema_val else: unique_experts_sum[key] += ema_val # 2. Determine pruning status based on SUMMED usage vs Threshold # Note: All routers share the same threshold value derived from summed distribution if unique_experts_sum and hasattr(backend_model.layers[0].block_sparse_moe.router, "mask_threshold"): # Get current global threshold from first router threshold = backend_model.layers[0].block_sparse_moe.router.mask_threshold.item() unique_experts_pruned = {k for k, v in unique_experts_sum.items() if v < threshold} msg = f"\n[IG-MET Pruning Status during Evaluation]:" print(msg) if logger: logger.info(msg) total_unique_pruned = len(unique_experts_pruned) total_unique = len(unique_experts_total) pct = 100 * total_unique_pruned / total_unique if total_unique > 0 else 0 msg = f"Global: {total_unique_pruned}/{total_unique} UNIQUE experts pruned ({pct:.1f}%) | threshold={threshold:.6f}" print(msg) if logger: logger.info(msg) # Count per-layer pruning stats based on GLOBAL decision # Rerun loop just for stats for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): router = layer.block_sparse_moe.router # Mapping logic again current_r = router.get_candidate_layers(step=None, total_steps=None) half = (current_r - 1) // 2 start_layer = max(0, layer_idx - half) end_layer = min(len(backend_model.layers), start_layer + current_r) start_layer = max(0, end_layer - current_r) masked_in_layer = 0 # Only count experts from the CURRENT layer, not all reused layers current_layer_experts = [(layer_idx, e_id) for e_id in range(router.num_experts_per_layer)] for key in current_layer_experts: if key in unique_experts_pruned: masked_in_layer += 1 # num_active = total experts in current layer (always num_experts_per_layer for the layer itself) num_active = router.num_experts_per_layer pruning_info.append((layer_idx, threshold, masked_in_layer, num_active)) # Show top/bottom layers by pruning ratio pruning_by_ratio = sorted(pruning_info, key=lambda x: x[2]/x[3] if x[3] > 0 else 0, reverse=True) msg = "Top 5 most pruned layers:" print(msg) if logger: logger.info(msg) for layer_idx, thr, masked, total in pruning_by_ratio[:5]: pct = 100 * masked / total if total > 0 else 0 msg = f" Layer {layer_idx:>2}: {masked:>2}/{total} pruned ({pct:>5.1f}%)" print(msg) if logger: logger.info(msg) # Define 3 evaluation prompts eval_prompts = [ { "instruction": "What is the capital of France?", "input_text": None }, { "instruction": "High-pressure systems stop air from rising into the colder regions of the atmosphere where water can condense. What will most likely result if a high-pressure system remains in an area for a long period of time?\nA. fog\nB. rain\nC. drought\nD. tornado\nAnswer:", "input_text": None }, { "instruction": "Given the fact: predators eat prey\nQuestion: Predators eat\nA. lions\nB. humans\nC. bunnies\nD. grass\nAnswer:", "input_text": None } ] for prompt_idx, prompt_config in enumerate(eval_prompts, 1): print("\n" + "=" * 80) print(f"Prompt {prompt_idx}/3:") print(f" Instruction: {prompt_config['instruction']}") print(f" Input: {prompt_config['input_text']}") print("=" * 80) # Build prompt prompt = build_model_input( tokenizer, instruction=prompt_config['instruction'], input_text=prompt_config['input_text'], ) # Tokenize and move tensors to device inputs = tokenizer(prompt, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} # Generate with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, do_sample=do_sample, temperature=temperature, pad_token_id=getattr(tokenizer, "pad_token_id", None), eos_token_id=getattr(tokenizer, "eos_token_id", None), ) prompt_len = inputs["input_ids"].shape[-1] # outputs: tensor [batch, seq_len] generated = outputs[0] completion_ids = generated[prompt_len:] completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True) full_text = tokenizer.decode(generated, skip_special_tokens=True) print("GENERATED RESPONSE:") print("-" * 80) print(completion_text) print("-" * 80) # Debugging info print(f"\n[debug] prompt_tokens={prompt_len}, new_tokens={int(completion_ids.numel())}") if completion_ids.numel() == 0: print("[debug] Model generated 0 new tokens (likely hit EOS immediately).") print("[debug] Full decoded text:\n" + full_text) elif completion_text.strip() == "": print("[debug] Model generated tokens, but they decode to empty/whitespace or special tokens.") print("[debug] completion token ids:", completion_ids.tolist()) if logger: logger.info(f"\n--- Prompt {prompt_idx}/3 ---") logger.info(f"Instruction: {prompt_config['instruction']}") logger.info(f"Input: {prompt_config['input_text']}") logger.info(f"Generated completion (len {int(completion_ids.numel())}): {completion_text}") print("=" * 80) if logger: logger.info("Evaluation of all 3 prompts complete.") except Exception as e: err = f"evaluate_prompt failed: {e}" print(err) if logger: logger.exception(err) # ==================== 1. MODEL MODIFICATION ==================== class ReXMoERouter(nn.Module): """Router logic (no parameters) for cross-layer expert reuse. Note: This module is intentionally parameterless so that trainable gate weights live directly under the MoE block as `block_sparse_moe.gate`, matching the original Phi-MoE naming. This prevents saving keys like `router.gate` and keeps checkpoint compatibility. """ def __init__(self, layer_idx, total_layers=32, num_experts_per_layer=16, reuse_scale=3, num_experts_per_tok=2, all_experts_dict=None, aux_loss_weight=0.02): super().__init__() self.layer_idx = layer_idx self.total_layers = total_layers self.num_experts_per_layer = num_experts_per_layer self.reuse_scale = reuse_scale # R=3: layers [i-1, i, i+1] self.num_experts_per_tok = num_experts_per_tok # Store reference to all experts dict to get actual expert counts self.all_experts_dict = all_experts_dict # Max pool size = reuse_scale * num_experts_per_layer self.max_pool_size = reuse_scale * num_experts_per_layer # EMA tracking for expert utilization # Initialize with uniform probability (1/num_experts_per_tok?) # Or just zeros if we want to learn from scratch. # User says: "smoothed utilization... Ck = raw selection count" # Since we start with no history, let's init with zeros. self.register_buffer('ema_utilization', torch.zeros(self.max_pool_size)) self.register_buffer('mask_threshold', torch.tensor(-1.0)) # Default -1 means no masking self.aux_loss_weight = aux_loss_weight def get_candidate_layers(self, step, total_steps, psr_enabled=True, initial_R=2, met_warmup=None): """Progressive Scaling Routing: gradually expand reuse scale Args: step: current training step total_steps: total steps in training psr_enabled: whether PSR is enabled initial_R: initial reuse scale (default 2) met_warmup: if provided (float 0-1), PSR only runs during 0-met_warmup phase, then stays at max reuse_scale. If None, uses old schedule (0-50% of training). """ if not psr_enabled or step is None or total_steps is None: return self.reuse_scale if met_warmup is not None: # New behavior: PSR completes within the first phase (0 to met_warmup) # After met_warmup, stay at max R progress = min(step / (met_warmup * total_steps), 1.0) current_r = initial_R + int(progress * (self.reuse_scale - initial_R)) else: # Legacy behavior: Linear schedule R=2 → target_R over first 50% of training progress = min(step / (0.5 * total_steps), 1.0) current_r = initial_R + int(progress * (self.reuse_scale - initial_R)) return current_r def update_ema(self, selection_counts, alpha=0.9): """Update EMA tracking for expert utilization""" # selection_counts: tensor of shape [max_pool_size] with torch.no_grad(): self.ema_utilization = alpha * self.ema_utilization + (1 - alpha) * selection_counts def forward_with_logits(self, all_logits, hidden_states, step=None, total_steps=None, met_enabled=False, met_warmup=None, logger=None): """ Args: all_logits: [batch_size * seq_len, max_pool_size] precomputed logits from block's gate hidden_states: [batch_size, seq_len, hidden_dim] met_warmup: if provided (float 0-1), PSR runs only during this phase then stays at max R Returns: router_logits: [batch_size * seq_len, max_pool_size] aux_loss: scalar active_expert_mask: [max_pool_size] boolean mask layer_expert_mapping: list of (layer_idx, expert_idx) tuples """ batch_size, seq_len, hidden_dim = hidden_states.shape # all_logits already has shape [B*S, max_pool_size] # Get current reuse scale via PSR (pass met_warmup if available) current_r = self.get_candidate_layers(step, total_steps, met_warmup=met_warmup) # [CRITICAL FIX]: Ensure mapping aligns with STATIC physical gate size (self.max_pool_size) # The base mapping uses the FULL reuse_scale statically! base_half = (self.reuse_scale - 1) // 2 base_start = max(0, self.layer_idx - base_half) base_end = min(self.total_layers, base_start + self.reuse_scale) base_start = max(0, base_end - self.reuse_scale) # The PSR subset window defines which subset of the full mapping is CURRENTLY active psr_half = (current_r - 1) // 2 psr_start = max(0, self.layer_idx - psr_half) psr_end = min(self.total_layers, psr_start + current_r) psr_start = max(0, psr_end - current_r) # Create active_mask natively aligned to the self.gate output nodes num_active_experts = current_r * self.num_experts_per_layer active_mask = torch.zeros(self.max_pool_size, dtype=torch.bool, device=all_logits.device) # Create full layer-expert mapping for the expert selector. layer_expert_mapping = [] for layer_offset in range(self.reuse_scale): layer_id = base_start + layer_offset # Is this physical block currently enabled by PSR? is_active_psr_layer = (psr_start <= layer_id < psr_end) for expert_id in range(self.num_experts_per_layer): pool_idx = len(layer_expert_mapping) layer_expert_mapping.append((layer_id, expert_id)) # Activate in the mask if it falls within the PSR window AND is a valid layer if is_active_psr_layer and layer_id < self.total_layers: active_mask[pool_idx] = True # Mask out inactive experts by setting their logits to -inf masked_logits = all_logits.clone() masked_logits[:, ~active_mask] = float('-inf') # === DYNAMIC PRUNING MASK (OLD_TO_NEW) === # If the checkpoint is hard-pruned, the mapping omits pruned experts. # Natively mask logits here preventing explicit drops during top-k if hasattr(self, 'old_to_new') and self.old_to_new: for pool_idx, (orig_layer, orig_expert) in enumerate(layer_expert_mapping): if not active_mask[pool_idx]: continue orig_layer_int = int(orig_layer) # old_to_new is keyed by int layer if orig_layer_int in self.old_to_new: layer_map = self.old_to_new[orig_layer_int] if orig_expert not in layer_map and str(orig_expert) not in layer_map: masked_logits[:, pool_idx] = float('-inf') active_mask[pool_idx] = False # === IMPORTANCE-GUIDED MASKED EXPERT TRAINING (IG-MET) === # Apply mask based on global threshold if enabled # If router has a specifically pre-calculated pruning mask (from global analysis), use it. # Otherwise fall back to local thresholding (which may be inaccurate if aggregation is SUM). # Check for externally provided mask (from train_rexmoe global pass) global_keep_mask = getattr(self, 'global_keep_mask', None) if met_enabled: # Mode A: Precise Global Pruning (via mask pushed from training loop) if global_keep_mask is not None: # Ensure mask is on correct device if global_keep_mask.device != all_logits.device: global_keep_mask = global_keep_mask.to(all_logits.device) # Invert to get what we should prune (keep=False -> prune=True) # The global_keep_mask perfectly aligns with self.max_pool_size cur_len = min(len(global_keep_mask), self.max_pool_size) # We mask where global_keep_mask is FALSE, BUT only if it is currently active target_mask = torch.zeros_like(active_mask) target_mask[:cur_len] = ~global_keep_mask[:cur_len] target_mask = target_mask & active_mask # Safety: Don't prune everything if target_mask.all() or target_mask.sum() == active_mask.sum(): pass # Don't prune if it kills all remaining active experts else: masked_logits[:, target_mask] = float('-inf') active_mask[target_mask] = False # Mode B: Local Thresholding (Fallback / Original) elif self.mask_threshold.item() >= 0: # Mask experts with EMA utilization below threshold # Note: Only mask experts that are theoretically active (based on current_r) threshold = self.mask_threshold.to(all_logits.device) ema = self.ema_utilization.to(all_logits.device) under_utilized_mask = (ema < threshold) target_mask = under_utilized_mask & active_mask if target_mask.any(): num_active_before = active_mask.sum().item() num_to_mask = target_mask.sum().item() if num_to_mask < num_active_before: masked_logits[:, target_mask] = float('-inf') active_mask[target_mask] = False # Compute routing probabilities routing_weights = torch.softmax(masked_logits, dim=-1) # [B*S, max_pool_size] # Update EMA tracking (detached from graph) # Calculate C_k: raw selection count at step k # We use sum of routing weights as per user request: "= expert utilization counts (sum of routing weights per expert)" if self.training: current_counts = routing_weights.sum(dim=0).detach() # [max_pool_size] self.update_ema(current_counts) # Auxiliary load balancing loss (coefficient of variation) # Only compute over active experts # Ensure active_mask is on the same device as routing_weights active_mask = active_mask.to(routing_weights.device) # Safeguard: check that we have active experts num_true_active = active_mask.sum().item() if num_true_active == 0: # Fallback: mark at least the first expert as active active_mask[0] = True active_routing_weights = routing_weights[:, active_mask] expert_counts = active_routing_weights.sum(0) # [num_active_experts] # Safe CV calculation to prevent NaNs if expert_counts.numel() > 1: mean_count = expert_counts.mean() std_count = expert_counts.std() # Use larger epsilon and handle potential detached tensor cv_squared = (std_count / (mean_count + 1e-6)) ** 2 else: cv_squared = torch.tensor(0.0, device=active_routing_weights.device, dtype=active_routing_weights.dtype) # aux_loss = 0.01 * cv_squared # α=0.01 per ReXMoE aux_loss = self.aux_loss_weight * cv_squared # Higher weight on load balancing loss to encourage more even routing, especially important with PSR where early layers have fewer experts and are more likely to be overloaded. # Keep last mapping for introspection/debugging self.last_layer_expert_mapping = layer_expert_mapping return masked_logits, aux_loss, active_mask, layer_expert_mapping class ReXMoESparseMoeBlock(nn.Module): """ Modified PhiMoE Sparse MoE block with cross-layer expert reuse. Keeps original experts intact but routes to adjacent layer experts. """ def __init__(self, original_moe_block, layer_idx, total_layers, all_experts_dict, reuse_scale=3, logger=None, aux_loss_weight=0.02): super().__init__() self.hidden_dim = original_moe_block.hidden_dim self.num_experts = original_moe_block.num_experts self.top_k = original_moe_block.top_k self.layer_idx = layer_idx self.reuse_scale = reuse_scale # Keep reference to experts from all layers (DO NOT copy parameters) self.all_experts_dict = all_experts_dict # Dict: {layer_idx: ModuleList of experts} self.aux_loss_weight = aux_loss_weight # Replace router with ReXMoE router (parameterless) self.router = ReXMoERouter( layer_idx=layer_idx, total_layers=total_layers, num_experts_per_layer=self.num_experts, reuse_scale=reuse_scale, num_experts_per_tok=self.top_k, all_experts_dict=all_experts_dict, aux_loss_weight=self.aux_loss_weight ) # Install a gate on the block itself to match original naming self.gate = nn.Linear(self.hidden_dim, self.router.max_pool_size, bias=False) # [FIX] Initialize new gate with original router weights for BOTH local and neighbor sections with torch.no_grad(): orig_gate_shape = original_moe_block.gate.weight.data.shape[0] if orig_gate_shape == self.router.max_pool_size: # If loading from a checkpoint where gate is already expanded, copy it directly self.gate.weight.data.copy_(original_moe_block.gate.weight.data) print(f" Base block gate size already {orig_gate_shape}, copied fully for layer {layer_idx}") elif orig_gate_shape == self.num_experts: # Calculate where the local experts sit in the new expanded router half = (reuse_scale - 1) // 2 local_start_idx = half * self.num_experts local_end_idx = local_start_idx + self.num_experts # Standard init from base model: Copy local weights exactly self.gate.weight[local_start_idx:local_end_idx, :] = original_moe_block.gate.weight.data.clone() print(f" num_experts: {self.num_experts}, refilled router weights for layer {layer_idx} local section at indices {local_start_idx}:{local_end_idx}") # Crucial Fix for R > 2: Initialize neighbor sections with copied weights + noise instead of zero # If neighbor logits are exactly zero initially, it causes router collapse and massive loss spikes (>10) noise_scale = 0.1 * original_moe_block.gate.weight.data.std().item() # Fill all sections before the local section (previous layers) for section in range(half): start_idx = section * self.num_experts end_idx = start_idx + self.num_experts noise = torch.randn_like(original_moe_block.gate.weight.data) * noise_scale self.gate.weight[start_idx:end_idx, :] = original_moe_block.gate.weight.data.clone() + noise print(f" Initialized neighbor section {start_idx}:{end_idx} with noise") # Fill all sections after the local section (next layers) for section in range(half + 1, reuse_scale): start_idx = section * self.num_experts end_idx = start_idx + self.num_experts noise = torch.randn_like(original_moe_block.gate.weight.data) * noise_scale self.gate.weight[start_idx:end_idx, :] = original_moe_block.gate.weight.data.clone() + noise print(f" Initialized neighbor section {start_idx}:{end_idx} with noise") else: # Fallback for pruned or irregular sized expert checkpoints # Safe initialization: zero out, then carefully mapped copying self.gate.weight.zero_() print(f" Warning: Custom size mismatch for gate refill ({orig_gate_shape} vs {self.num_experts}). Filling with zeros.") # Try to map whatever we can based on minimum size min_experts = min(orig_gate_shape, self.router.max_pool_size) self.gate.weight[:min_experts, :] = original_moe_block.gate.weight.data[:min_experts, :].clone() # Canonical expert container: match base PhiMoE which uses `.experts` # so that checkpoints save/load expert weights under the same keys. self.experts = original_moe_block.experts # Store current training step for PSR # Default to None so inference uses full R (not PSR schedule) self.current_step = None self.total_steps = None self.met_warmup = None # Will be set during training if MET is enabled # Store aux_loss for backward pass self.last_aux_loss = None # Store actual routing selections for analysis self.last_selected_experts = None # Will store (target_layer, target_expert) tuples self.last_selection_counts = None # Count of tokens routed to each expert self.logger = logger # Store logger for potential use in forward pass @property def local_experts(self): # expose alias for any code that tries to access it return self.experts def map_pruned_expert(self, orig_layer: int, orig_expert: int, old_to_new: dict) -> Optional[int]: """ Map original (layer, expert) index to current kept expert index. Returns None if the expert was pruned. """ new_idx = None if orig_layer in old_to_new: layer_map = old_to_new[orig_layer] # Handle both int and str keys (robust to JSON serialization) new_idx = layer_map.get(orig_expert, layer_map.get(str(orig_expert), None)) else: # No pruning map: use original index if it still exists if orig_layer in self.all_experts_dict and orig_expert < len(self.all_experts_dict[orig_layer]): new_idx = orig_expert return new_idx def forward(self, hidden_states, logger=None): """ Args: hidden_states: [batch_size, seq_len, hidden_dim] Returns: output: [batch_size, seq_len, hidden_dim] """ # If a logger wasn't passed down through the model forward call, # fall back to any logger attribute attached to this block instance. if logger is None: logger = getattr(self, 'logger', None) batch_size, seq_len, hidden_dim = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_dim) # [B*S, H] # Ensure gate is on the same device as hidden_states (fixes device_map="auto" mismatch) device = hidden_states_flat.device self.gate = self.gate.to(device) # Get routing decisions # Compute gate logits at block level (so weights are saved as block_sparse_moe.gate) hidden_states_flat = hidden_states.view(-1, hidden_dim) all_logits = self.gate(hidden_states_flat) router_logits, aux_loss, active_mask, layer_expert_mapping = self.router.forward_with_logits( all_logits, hidden_states, self.current_step, self.total_steps, met_enabled=getattr(self, 'met_enabled', False), met_warmup=self.met_warmup, logger=logger ) num_pool = router_logits.shape[-1] BxS = hidden_states_flat.shape[0] # Store aux_loss for collection in training loop self.last_aux_loss = aux_loss # Get top-k indices and values in one operation topk_logits, topk_indices = torch.topk(router_logits, self.top_k, dim=-1) # [B*S, k] topk_weights = torch.softmax(topk_logits, dim=-1) # [B*S, k] # === VECTORIZED EXPERT EXECUTION === # Pre-allocate output final_hidden_states = torch.zeros_like(hidden_states_flat) selection_counts = {} old_to_new = getattr(self.router, 'old_to_new', {}) processed_mask = torch.zeros(BxS, dtype=torch.bool, device=hidden_states.device) # Process each expert position in the top-k (k=2 is small) for k_idx in range(self.top_k): # Get which expert each token selected at position k selected_positions = topk_indices[:, k_idx] # [B*S] # Gather logits for weighting k_weights = topk_weights[:, k_idx:k_idx+1] # [B*S, 1] # === BATCH EXPERT EXECUTION === # Group tokens by which expert they selected for pool_pos in range(self.router.max_pool_size): if pool_pos >= len(layer_expert_mapping): continue # HARD BLOCK FOR PRUNED / INACTIVE EXPERTS: # Completely bypass execution and prevent token leakage if not active_mask[pool_pos]: continue # Find all tokens that selected this expert token_mask = (selected_positions == pool_pos) # [B*S] if not token_mask.any(): continue # Get tokens for this expert selected_tokens = hidden_states_flat[token_mask] # [N, H] orig_layer, orig_expert = layer_expert_mapping[pool_pos] new_idx = self.map_pruned_expert(orig_layer, orig_expert, old_to_new) if new_idx is None: continue # Pruned expert expert_module = self.all_experts_dict[orig_layer][new_idx] # Move selected_tokens to expert's device instead of moving expert # This is more efficient when model is sharded across GPUs expert_device = next(expert_module.parameters()).device selected_tokens = selected_tokens.to(expert_device) expert_out = expert_module(selected_tokens) # [N, H] - BATCHED! # Move output back to original device expert_out = expert_out.to(device) weighted_out = expert_out * k_weights[token_mask] final_hidden_states[token_mask] += weighted_out # CRITICAL: Record the expert index for analysis # For pruned models: use orig_expert (original index before hard deletion) # For unpruned models: use new_idx (which equals orig_expert since no mapping) # The distinction is made at model load time via config.is_pruned or config.pruned is_pruned_model = getattr(self, 'is_pruned_model', False) or hasattr(self, 'old_to_new') and self.old_to_new reported_expert = orig_expert if is_pruned_model else new_idx key = (orig_layer, reported_expert) selection_counts[key] = selection_counts.get(key, 0) + token_mask.sum().item() processed_mask[token_mask] = True if not processed_mask.all(): final_hidden_states[~processed_mask] = hidden_states_flat[~processed_mask] # Store selections for analysis self.last_selection_counts = selection_counts # Reshape back final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim) # Return tuple like original PhiMoE: (hidden_states, router_logits) return final_hidden_states, router_logits # ==================== 3. ROUTING ANALYSIS ==================== def analyze_routing_patterns(model, dataloader, current_r, total_layers, device, num_batches=10, logger=None): """ Analyze ACTUAL routing patterns by tracking which experts were selected. For each layer, tracks: - Which experts are ACTUALLY selected most frequently - Whether experts from adjacent layers are being used - Distribution of routing across layers """ model.eval() # Track ACTUAL routing decisions: routing_counts[layer_idx][(target_layer, target_expert)] = count routing_counts = {} for layer_idx in range(total_layers): routing_counts[layer_idx] = {} total_tokens = 0 with torch.no_grad(): for batch_idx, batch in enumerate(dataloader): if batch_idx >= num_batches: # Sample only first N batches for efficiency break input_ids = batch["input_ids"].to(device) attention_mask = batch["attention_mask"].to(device) # Get batch size and sequence length batch_size, seq_len = input_ids.shape num_tokens = (attention_mask.sum()).item() # Count non-padding tokens total_tokens += num_tokens # Forward pass to get routing decisions _ = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False) # Always work on the underlying transformer stack that actually owns `.layers`, # whether `model` is a bare PhiMoEForCausalLM or a PEFT-wrapped model. backend_model = get_backend_model(model) # Collect ACTUAL routing decisions from each layer for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): moe_block = layer.block_sparse_moe # Get actual selections from the forward pass if moe_block.last_selection_counts is not None: for (target_layer, target_expert), count in moe_block.last_selection_counts.items(): key = (target_layer, target_expert) routing_counts[layer_idx][key] = routing_counts[layer_idx].get(key, 0) + count model.train() # Print analysis msg = f"\nAnalyzing ACTUAL routing patterns from {num_batches} batches ({total_tokens:,} tokens)" print(msg) if logger: logger.info(msg) msg = f"Current reuse scale: R={current_r}" print(msg) if logger: logger.info(msg) # === IG-MET PRUNING ANALYTICS (GLOBAL SUM AGGREGATION) === # 1. Aggregate EMA for each unique expert across all routers (SUM) backend_model = get_backend_model(model) unique_experts_sum = {} # (orig_layer, orig_expert) -> summed_ema_score unique_experts_total = set() threshold = None for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): router = layer.block_sparse_moe.router thr = router.mask_threshold.item() if thr >= 0: # Reconstruct mapping logic current_r = router.get_candidate_layers(step=None, total_steps=None) half = (current_r - 1) // 2 start_layer = max(0, layer_idx - half) end_layer = min(len(backend_model.layers), start_layer + current_r) start_layer = max(0, end_layer - current_r) current_mapping = [] for layer_offset in range(current_r): l_id = start_layer + layer_offset if l_id >= len(backend_model.layers): break for e_id in range(router.num_experts_per_layer): current_mapping.append((l_id, e_id)) num_active = len(current_mapping) for pool_pos, key in enumerate(current_mapping): if pool_pos >= len(router.ema_utilization): break unique_experts_total.add(key) ema_val = router.ema_utilization[pool_pos].item() if key not in unique_experts_sum: unique_experts_sum[key] = ema_val else: unique_experts_sum[key] += ema_val if threshold is None: threshold = thr # 2. Prune based on SUM aggregation and global threshold unique_experts_pruned = {k for k, v in unique_experts_sum.items() if threshold is not None and v < threshold} total_unique_pruned = len(unique_experts_pruned) total_unique = len(unique_experts_total) msg = "\n[IG-MET Pruning Report]:" print(msg) if logger: logger.info(msg) pct = 100 * total_unique_pruned / total_unique if total_unique > 0 else 0 msg = f"Global: {total_unique_pruned}/{total_unique} UNIQUE experts pruned ({pct:.1f}%) | threshold={threshold if threshold is not None else -1:.6f}" print(msg) if logger: logger.info(msg) if unique_experts_sum: global_ema_tensor = torch.tensor(list(unique_experts_sum.values()), device=device) msg = f"Aggregated EMA (sum across R layers): mean={global_ema_tensor.mean():.6f}, min={global_ema_tensor.min():.6f}, max={global_ema_tensor.max():.6f}" print(msg) if logger: logger.info(msg) print() # Analyze cross-layer reuse statistics cross_layer_usage = { "same_layer": 0, "adjacent_prev": 0, "adjacent_next": 0, "distant": 0 } for layer_idx in routing_counts: for (target_layer, target_expert), count in routing_counts[layer_idx].items(): if target_layer == layer_idx: cross_layer_usage["same_layer"] += count elif target_layer == layer_idx - 1: cross_layer_usage["adjacent_prev"] += count elif target_layer == layer_idx + 1: cross_layer_usage["adjacent_next"] += count else: cross_layer_usage["distant"] += count total_routing = sum(cross_layer_usage.values()) if total_routing > 0: msg = "Cross-Layer Routing Distribution (ACTUAL selections):" print(msg) if logger: logger.info(msg) for key, label in [ ("same_layer", "Same layer (i):"), ("adjacent_prev", "Previous layer (i-1):"), ("adjacent_next", "Next layer (i+1):"), ("distant", "Distant layers:") ]: if cross_layer_usage[key] > 0 or key != "distant": pct = 100 * cross_layer_usage[key] / total_routing msg = f" {label:25} {cross_layer_usage[key]:>10,} ({pct:>5.1f}%)" print(msg) if logger: logger.info(msg) print() # Sample detailed analysis for a few layers sample_layers = [8, 16, 24] if total_layers >= 32 else [total_layers // 4, total_layers // 2, 3 * total_layers // 4] msg = "Sample Layer-Specific Routing Patterns:" print(msg) if logger: logger.info(msg) for layer_idx in sample_layers: if layer_idx in routing_counts and routing_counts[layer_idx]: msg = f"\n Layer {layer_idx}:" print(msg) if logger: logger.info(msg) # Get top 5 most used experts sorted_experts = sorted(routing_counts[layer_idx].items(), key=lambda x: x[1], reverse=True)[:5] for (target_layer, target_expert), count in sorted_experts: pct = 100 * count / total_tokens if total_tokens > 0 else 0 layer_relation = "same" if target_layer == layer_idx else f"L{target_layer}" msg = f" Expert {target_expert:>2} from layer {target_layer:>2} ({layer_relation:>4}): {count:>8,} times ({pct:>5.1f}%)" print(msg) if logger: logger.info(msg) print() # Check if cross-layer reuse is happening cross_layer_pct = 100 * (cross_layer_usage['adjacent_prev'] + cross_layer_usage['adjacent_next'] + cross_layer_usage['distant']) / total_routing if total_routing > 0 else 0 if cross_layer_pct > 5: msg = f"✅ Cross-layer expert reuse detected: {cross_layer_pct:.1f}% of routing uses adjacent layers" print(msg) if logger: logger.info(msg) elif current_r > 1: msg = f"⚠️ Limited cross-layer reuse: {cross_layer_pct:.1f}% (expected >5% with R={current_r})" print(msg) if logger: logger.warning(msg) msg = " This may improve as training progresses and routers adapt." print(msg) if logger: logger.info(msg) else: msg = f"ℹ️ R=1 mode: Only same-layer experts available (PSR warmup phase)" print(msg) if logger: logger.info(msg) # ==================== 3.5. CONVERGENCE MONITORING ==================== def compute_routing_entropy(router_logits): """ Compute entropy of routing distribution. High entropy = uniform routing (may indicate lack of specialization) Low entropy = concentrated routing (strong preferences) """ probs = torch.softmax(router_logits, dim=-1) entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean() return entropy.item() def check_router_convergence(model, total_layers, convergence_history, threshold=0.01, logger=None): """ Check if routers have converged by analyzing: 1. Router weight gradient norms (should be small) 2. Routing entropy stability (should be stable) 3. Expert preference consistency (should not fluctuate) Returns: converged: bool metrics: dict of convergence metrics warnings: list of warning messages """ router_grad_norms = [] # Always work on the underlying transformer stack that actually owns `.layers`, # whether `model` is a bare PhiMoEForCausalLM or a PEFT-wrapped model. backend_model = get_backend_model(model) for layer in backend_model.layers: if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): gate = getattr(layer.block_sparse_moe, 'gate', None) if gate is not None and gate.weight is not None and gate.weight.grad is not None: grad_norm = gate.weight.grad.norm().item() router_grad_norms.append(grad_norm) avg_grad_norm = sum(router_grad_norms) / len(router_grad_norms) if router_grad_norms else 0 metrics = { 'avg_router_grad_norm': avg_grad_norm, 'max_router_grad_norm': max(router_grad_norms) if router_grad_norms else 0, 'min_router_grad_norm': min(router_grad_norms) if router_grad_norms else 0, } warnings = [] # Check convergence: gradients should be small and stable convergence_history.append(avg_grad_norm) # Need at least 5 epochs of history if len(convergence_history) < 5: return False, metrics, warnings # Check if gradient norm is stable (variance < threshold) recent_grads = convergence_history[-5:] grad_variance = torch.tensor(recent_grads).var().item() grad_mean = torch.tensor(recent_grads).mean().item() metrics['grad_variance'] = grad_variance metrics['grad_stability'] = grad_variance / (grad_mean + 1e-10) # Detect oscillations (gradient norm going up and down) if len(convergence_history) >= 3: last_3 = convergence_history[-3:] # Check if middle value is either a peak or valley if (last_3[1] > last_3[0] and last_3[1] > last_3[2]) or \ (last_3[1] < last_3[0] and last_3[1] < last_3[2]): warnings.append("⚠️ Oscillation detected in gradient norms - consider reducing learning rate") if logger: logger.warning("Oscillation detected in gradient norms") # Check for increasing gradients (divergence) if len(convergence_history) >= 2: if convergence_history[-1] > convergence_history[-2] * 1.2: warnings.append("⚠️ Gradients increasing - possible divergence or high learning rate") if logger: logger.warning("Gradients increasing - possible divergence") # Check if gradients are too high if avg_grad_norm > 0.3: warnings.append(f"⚠️ High gradient norm ({avg_grad_norm:.4f}) - learning rate may be too high") if logger: logger.warning(f"High gradient norm: {avg_grad_norm:.4f}") # Converged if: gradients are small AND stable converged = (avg_grad_norm < 0.1) and (metrics['grad_stability'] < threshold) return converged, metrics, warnings # Helper: always get the underlying transformer stack (with `.layers`) def get_backend_model(m): # If PEFT-wrapped, unwrap to base_model; else keep as is core = getattr(m, "base_model", m) # For PhiMoEForCausalLM, the transformer is in `.model` causalLM = getattr(core, "model", core) return getattr(causalLM, "model", causalLM) # ==================== 4. TRAINING LOOP ==================== def train_rexmoe( model_name="microsoft/Phi-mini-MoE-instruct", model_path="../models/models/microsoft/Phi-mini-MoE-instruct", dataset_path="../dataset/alpaca_data_cleaned.json", dataset_mode: str = "IF", reuse_scale=3, num_samples=10000, num_epochs=5, batch_size=16, max_seq_length=512, lr=5e-6, warmup_steps=10, psr_enabled=True, save_path="./rexmoe_phi_mini_moe_r3", gradient_checkpointing=True, met_enabled=False, met_mask_ratio=0.1, met_warmup=0.5, eval_steps=1000, log_loss_steps_percent=10, full_lora=False, lora_r=16, use_scheduler=True, aux_loss_weight=0.02 ): # Setup logger logger, log_file = setup_logger(save_path=os.path.join(save_path, "logs")) print("="*80) print("ReXMoE Cross-Layer Expert Reuse Training") print("="*80) logger.info("="*80) logger.info("ReXMoE Cross-Layer Expert Reuse Training") logger.info("="*80) logger.info("MET enabled: {}".format(met_enabled)) config_msg = f""" Configuration: Model: {model_name} Dataset: {dataset_path} Dataset mode: {dataset_mode} Reuse Scale (R): {reuse_scale} Prune Ratio (MET): {met_mask_ratio if met_enabled else 'N/A'} Epochs: {num_epochs} Num of samples: {num_samples} Batch Size: {batch_size} Sequence Length: {max_seq_length} Learning Rate: {lr} PSR Enabled: {psr_enabled} LR Scheduler: {use_scheduler} Save Path: {save_path} Gradient Checkpointing: {gradient_checkpointing} LoRA Rank: {lora_r} (Full LoRA: {full_lora}) LoRA Alpha: {lora_r * 2} MET Enabled: {met_enabled} (Mask Ratio: {met_mask_ratio}, Warmup: {met_warmup}) Log File: {log_file} Aux loss weight: {aux_loss_weight} """ print(config_msg) logger.info(config_msg) print("="*80) if torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") device_msg = f"💻 Using device: {device})" print(f"\n{device_msg}") logger.info(device_msg) if torch.cuda.is_available(): gpu_msg = f"GPU: {torch.cuda.get_device_name(0)}, Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB" print(f" {gpu_msg}") logger.info(gpu_msg) # Load tokenizer + model print(f"\n[1/7] Loading model: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token print(f"Loading model to device {device} (no device_map sharding)...") model = AutoModelForCausalLM.from_pretrained( model_path, torch_dtype=torch.bfloat16, device_map=None, # Do NOT auto-shard - we'll place it manually trust_remote_code=False ) print(f"Moving model to {device}...") model = model.to(device) print(f"✓ Model moved to {device}") # Verify model is on correct device model_device = next(model.parameters()).device print(f"✓ Model device verified: {model_device}") if gradient_checkpointing: model.gradient_checkpointing_enable() print(f"Model loaded: {model.config.num_hidden_layers} layers") print(f"Hidden size: {model.config.hidden_size}") print(f"Experts per layer: {model.config.num_local_experts}") # Collect all experts from all layers print(f"\n[2/7] Collecting expert references from all layers...") total_layers = model.config.num_hidden_layers all_experts_dict = {} for layer_idx, layer in enumerate(model.model.layers): if hasattr(layer, "block_sparse_moe"): all_experts_dict[layer_idx] = layer.block_sparse_moe.experts print(f"Collected {len(all_experts_dict)} MoE layers") # Replace MoE blocks with ReXMoE blocks print(f"\n[3/7] Replacing MoE blocks with ReXMoE routers (R={reuse_scale})...") moe_count = 0 for layer_idx, layer in enumerate(model.model.layers): if hasattr(layer, "block_sparse_moe"): original_moe = layer.block_sparse_moe # Create ReXMoE block (keeps expert references, replaces router) rexmoe_block = ReXMoESparseMoeBlock( original_moe_block=original_moe, layer_idx=layer_idx, total_layers=total_layers, all_experts_dict=all_experts_dict, reuse_scale=reuse_scale, logger=logger, aux_loss_weight=aux_loss_weight ) rexmoe_block.met_enabled = met_enabled # Attach logger to block so its forward can access logging even when # the higher-level `model()` call doesn't pass a logger argument. rexmoe_block.logger = logger # Move ReXMoE block to correct device rexmoe_block = rexmoe_block.to(dtype=torch.bfloat16, device=device) # Replace the block layer.block_sparse_moe = rexmoe_block moe_count += 1 print(f"✓ ReXMoE blocks installed: {moe_count} layers modified") print(f" Each router can now access up to {reuse_scale * 16} experts (R={reuse_scale})") # Warmup phase: only routers (gates) trainable print(f"\n[4/7] Initial freeze: only routers trainable for warmup phase...") total_params = 0 trainable_params = 0 for name, param in model.named_parameters(): total_params += param.numel() if ".block_sparse_moe.gate" in name: # Router gates trainable param.requires_grad = True trainable_params += param.numel() else: param.requires_grad = False print(f"Total parameters: {total_params:,}") print(f"Trainable parameters (warmup): {trainable_params:,} ({100*trainable_params/total_params:.2f}%)") # Verify only routers are trainable at start trainable_layers = [name for name, param in model.named_parameters() if param.requires_grad] print(f"✓ Warmup trainable components: {len(trainable_layers)} parameters (router gates only)") if len(trainable_layers) > 0: print(f" First: {trainable_layers[0]}") print(f" Last: {trainable_layers[-1]}") # Optimizer (router params only) - Use 8-bit AdamW for memory efficiency print(f"\n[5/7] Setting up optimizer and dataset...") logger.info("[5/7] Setting up optimizer and dataset...") print("Using 8-bit AdamW optimizer for memory efficiency") logger.info("Using 8-bit AdamW optimizer") optimizer = bnb.optim.AdamW8bit( [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=0.1 ) # Learning rate scheduler (cosine annealing) scheduler = None if use_scheduler: scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr * 0.1) print(f"Using CosineAnnealingLR scheduler: {lr} → {lr * 0.1}") logger.info(f"LR Scheduler: CosineAnnealingLR ({lr} → {lr * 0.1})") # Prepare dataset (Instruction Fine-tuning or Pretraining) print(f"[5/7] Preparing dataset: mode={dataset_mode}, path={dataset_path}") try: train_loader = get_dataloader( mode=dataset_mode, tokenizer=tokenizer, dataset_path=dataset_path, max_seq_length=max_seq_length, batch_size=batch_size, num_samples=num_samples, shuffle=True, ) except Exception as e: print(f"Could not prepare dataset: {e}") raise train_len = num_samples print(f"Training samples: {train_len}") print(f"Batch size: {batch_size}") print(f"Sequence length: {max_seq_length}") # Training loop print(f"\n[6/7] Starting training for {num_epochs} epochs...") print(f"PSR enabled: {psr_enabled}") total_steps = len(train_loader) * num_epochs # PSR schedule changes based on whether MET is enabled if psr_enabled: if met_enabled: warmup_steps_psr = int(met_warmup * total_steps) print(f"PSR schedule: R=2 → R={reuse_scale} during MET warmup phase (steps 0-{warmup_steps_psr})") print(f" then stays at R={reuse_scale} during pruning/finetuning phases") else: psr_completion_steps = int(0.5 * total_steps) print(f"PSR schedule: R=2 → R={reuse_scale} over first 50% of training (steps 0-{psr_completion_steps})") step = 0 # Track statistics first_batch_logged = False # Track routing patterns for analysis # Structure: routing_stats[layer_idx][(target_layer, target_expert)] = count routing_stats = {} for layer_idx in range(total_layers): routing_stats[layer_idx] = {} # Convergence tracking convergence_history = [] epoch_entropies = [] epoch_aux_losses = [] model.train() best_val = float("inf") best_epoch = -1 qlora_enabled = False # track switch from warmup (routers-only) to routers+LoRA print_met_active = False print_met_freeze = False for epoch in range(num_epochs): print(f"\n{'='*60}") print(f"Epoch {epoch+1}/{num_epochs}") print(f"{'='*60}") epoch_loss = 0 epoch_aux_loss = 0 epoch_entropy = 0 # Track routing entropy pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}") for batch_idx, batch in enumerate(pbar): # Switch from warmup (routers only) to routers + LoRA adapters on experts if (not qlora_enabled) and step >= warmup_steps: if full_lora: logger.info(f"Warmup completed at step {step}. Enabling FULL QLoRA with r = {lora_r} and alpha = {lora_r * 2} on experts and updating optimizer...") else: logger.info(f"Warmup completed at step {step}. Enabling QLoRA on experts.") # Attach LoRA adapters globally (wrap linear layers) # Note: PhiMoE experts use w1/w2/w3 linear layers, not gate_proj/up_proj/down_proj. # We target those names so LoRA attaches only to expert MLP weights. lora_config = LoraConfig( r=lora_r, lora_alpha=lora_r * 2, lora_dropout=0.00, bias="none", task_type="CAUSAL_LM", target_modules=[ "w1", "w2", "w3", # if full_lora, also target attention layers of transformer blocks (but NOT router gates) "q_proj" if full_lora else None, "k_proj" if full_lora else None, "v_proj" if full_lora else None, "o_proj" if full_lora else None ], ) model = get_peft_model(model, lora_config) # Freeze everything, then re-enable router gates and LoRA params total_params = 0 trainable_params = 0 for name, param in model.named_parameters(): total_params += param.numel() param.requires_grad = False for name, param in model.named_parameters(): if ".block_sparse_moe.gate" in name: param.requires_grad = True trainable_params += param.numel() elif "lora_" in name: param.requires_grad = True trainable_params += param.numel() optimizer = bnb.optim.AdamW8bit( [p for p in model.parameters() if p.requires_grad], lr=lr, weight_decay=0.1 ) print(f"Total parameters after QLoRA: {total_params:,}") print(f"Trainable parameters (routers + LoRA): {trainable_params:,} ({100*trainable_params/total_params:.4f}%)") logger.info(f"Trainable params (routers + LoRA): {trainable_params} ({100*trainable_params/total_params:.4f}%)") trainable_names = [n for n, p in model.named_parameters() if p.requires_grad] print("Sample trainable params after QLoRA:", trainable_names[:10]) logger.info(f"Sample trainable params after QLoRA: {trainable_names[:10]}") qlora_enabled = True # Update step counter in all MoE blocks (for PSR) # Note: after enabling QLoRA, `model` becomes a PeftModel whose # underlying transformer is in `model.base_model.model`. backend_model = get_backend_model(model) for layer in backend_model.layers: if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): layer.block_sparse_moe.current_step = step layer.block_sparse_moe.total_steps = total_steps # Pass met_warmup to PSR scheduler so it completes within warmup phase layer.block_sparse_moe.met_warmup = met_warmup if met_enabled else None # === IG-MET Global Threshold Update === # Key insight: Count UNIQUE experts (512 base), not router-level copies (up to 1024 with reuse). # If an expert is pruned, it's pruned everywhere it can be accessed. if met_enabled: # Calculate target mask ratio for this step with improved schedule final_mask_ratio = met_mask_ratio progress = min(step / total_steps, 1.0) # Improved Three-Phase Schedule (Aggressive): # Phase 1: 0-met_warmup - NO pruning (extended warmup for stability) # Phase 2: met_warmup-0.8 - Gradual pruning ramp with curve (avoid abrupt changes) # Phase 3: 0.8-100% - Freeze pruning, only fine-tune remaining experts phase2_end = 0.8 if progress < met_warmup: # Phase 1: Extended warmup, no pruning current_target_ratio = 0.0 elif progress < phase2_end: # Phase 2: Controlled pruning ramp with exponential curve pruning_window = phase2_end - met_warmup pruning_progress = (progress - met_warmup) / pruning_window # 0 to 1 # Use power curve to avoid aggressive early pruning # Power = 1.2 for smoother ramp since pruning window is longer current_target_ratio = final_mask_ratio * (pruning_progress ** 1.2) if not print_met_active: logger.info(f"[IG-MET] Entered pruning phase with gradual ramp (step={step}, target_ratio={current_target_ratio:.3f})") print_met_active = True else: # Phase 3: Freeze pruning decisions, only fine-tune remaining experts current_target_ratio = final_mask_ratio # Optionally freeze pruning masks here in the future if not hasattr(model, '_pruning_frozen'): model._pruning_frozen = True if not print_met_freeze: logger.info("[IG-MET] Entered fine-tuning phase (pruning decisions frozen)") print_met_freeze = True if current_target_ratio > 0: if step % 100 == 0: logger.info(f"[IG-MET] Masked Expert Training is now ACTIVE (step={step}, target_ratio={current_target_ratio:.3f})") # Collect EMA for UNIQUE (layer_idx, expert_idx) pairs only (not duplicates) # Use "SUM" aggregation: we care about the total utility of an expert across all contexts. unique_experts = {} # (orig_layer, orig_expert) -> summed_ema_score # First Pass: Aggregation for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): router = layer.block_sparse_moe.router # Reuse same logic as Router to map pool_pos -> (orig_layer, orig_expert) current_r = router.get_candidate_layers(step, total_steps) half = (current_r - 1) // 2 start_layer = max(0, layer_idx - half) end_layer = min(total_layers, start_layer + current_r) start_layer = max(0, end_layer - current_r) # Reconstruct mapping for this router step current_mapping = [] for layer_offset in range(current_r): l_id = start_layer + layer_offset if l_id >= total_layers: break for e_id in range(router.num_experts_per_layer): current_mapping.append((l_id, e_id)) num_active = len(current_mapping) # Aggregate EMA for pool_pos in range(num_active): if pool_pos >= len(router.ema_utilization): break key = current_mapping[pool_pos] # (orig_layer, orig_expert) ema_val = router.ema_utilization[pool_pos].item() if key not in unique_experts: unique_experts[key] = ema_val else: # Sum EMA across all reused contexts unique_experts[key] += ema_val # Compute threshold based on UNIQUE SUMMED experts if unique_experts: all_ema_values = list(unique_experts.values()) all_ema_tensor = torch.tensor(all_ema_values, device=device) k = int(len(all_ema_values) * current_target_ratio) # Determine set of GLOBALLY pruned experts pruned_keys = set() threshold = 0.0 if k > 0 and all_ema_tensor.sum() > 0: sorted_vals, _ = torch.sort(all_ema_tensor) threshold = sorted_vals[k].item() # Identify which UNIQUE experts are below the global sum threshold pruned_keys = {key for key, val in unique_experts.items() if val < threshold} # Second Pass: Distribute Pruning Mask to Routers # Instead of a scalar threshold (which fails for summed aggregation), # we push a binary mask of "keep vs prune" to each router. for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): router = layer.block_sparse_moe.router # Re-calculate mapping to generate mask current_r = router.get_candidate_layers(step, total_steps) half = (current_r - 1) // 2 start_layer = max(0, layer_idx - half) end_layer = min(total_layers, start_layer + current_r) start_layer = max(0, end_layer - current_r) current_mapping = [] for layer_offset in range(current_r): l_id = start_layer + layer_offset if l_id >= total_layers: break for e_id in range(router.num_experts_per_layer): current_mapping.append((l_id, e_id)) # Create binary mask: True = KEEP, False = PRUNE # Size = max_pool_size (pad with True to be safe) keep_mask = torch.ones(router.max_pool_size, dtype=torch.bool, device=device) for pool_pos, key in enumerate(current_mapping): if key in pruned_keys: keep_mask[pool_pos] = False # Push mask to router router.global_keep_mask = keep_mask # We also update mask_threshold for logging purposes router.mask_threshold.fill_(threshold) # Log statistics if step % 10 == 0: total_unique_pruned = len(pruned_keys) total_unique_active = len(unique_experts) logger.info(f"[IG-MET Global] Step {step}: Threshold={threshold:.6f}. Pruned {total_unique_pruned}/{total_unique_active} UNIQUE experts ({100*total_unique_pruned/total_unique_active:.1f}%). Target ratio: {current_target_ratio:.3f}") else: # Current ratio too small to mask any expert yet for layer in backend_model.layers: if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): layer.block_sparse_moe.router.mask_threshold.fill_(-1.0) if hasattr(layer.block_sparse_moe.router, 'global_keep_mask'): layer.block_sparse_moe.router.global_keep_mask = None else: # No experts found (shouldn't happen) for layer in backend_model.layers: if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): layer.block_sparse_moe.router.mask_threshold.fill_(-1.0) else: # Reset threshold (no masking) during 0-50% warmup phase for layer in backend_model.layers: if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): layer.block_sparse_moe.router.mask_threshold.fill_(-1.0) input_ids = batch["input_ids"].to(model.device) attention_mask = batch["attention_mask"].to(model.device) labels = batch["labels"].to(model.device) # Forward pass with PSR-aware routing # todo: Why not use ''' outputs = model( instructions=instructions, input_ids=input_ids, attention_mask=attention_mask, labels=labels ) ''' outputs = model( input_ids=input_ids, attention_mask=attention_mask, labels=labels, use_cache=False ) loss = outputs.loss # Collect auxiliary losses from all ReXMoE routers # This is CRITICAL for load balancing and preventing routing collapse aux_loss_total = 0.0 for layer in backend_model.layers: if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): if layer.block_sparse_moe.last_aux_loss is not None: aux_loss_total += layer.block_sparse_moe.last_aux_loss # Collect routing statistics (which experts were selected) for layer_idx, layer in enumerate(backend_model.layers): if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock): moe_block = layer.block_sparse_moe # Get the layer-expert mapping from the last forward pass router = moe_block.router # Don't call router with hidden_states=None (router.forward expects a tensor). # Prefer the last computed mapping from a real forward pass; if not # available, query the router with a small dummy tensor to get the mapping. layer_expert_mapping = getattr(router, 'last_layer_expert_mapping', None) if layer_expert_mapping is None: try: hidden_dim = router.gate.in_features if hasattr(router, 'gate') else moe_block.hidden_dim dummy = torch.zeros(1, 1, hidden_dim, device=model.device) _, _, _, layer_expert_mapping = router( hidden_states=dummy, step=step, total_steps=total_steps ) except Exception: layer_expert_mapping = [] # Note: For efficiency, we'll track routing patterns every N batches # to avoid slowdown. Full tracking can be enabled for analysis. pass # Detailed tracking will be done in a separate analysis pass # Compute routing entropy for convergence monitoring (sample from output) # We'll approximate entropy from the auxiliary loss and routing distribution # Note: Full entropy computation would require storing all router outputs # Total loss = Language modeling loss + Auxiliary load balancing loss total_loss = loss + aux_loss_total # Backward pass optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() # Logging epoch_loss += loss.item() epoch_aux_loss += aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total # Calculate current reuse scale for logging (match the PSR logic in router) if psr_enabled: if met_enabled: # New behavior: PSR completes within first phase (0 to met_warmup) progress = min(step / (met_warmup * total_steps), 1.0) current_r = 2 + int(progress * (reuse_scale - 2)) else: # Legacy behavior: Linear schedule over first 50% of training progress = min(step / (0.5 * total_steps), 1.0) current_r = 2 + int(progress * (reuse_scale - 2)) # print(f"current_r: {current_r}, progress: {progress}") else: current_r = reuse_scale # Log first batch details if not first_batch_logged: logger.info(f"\n First batch statistics:") logger.info(f" LM Loss: {loss.item():.4f}") logger.info(f" Aux Loss: {aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f}") logger.info(f" Total Loss: {total_loss.item():.4f}") logger.info(f" Current R: {current_r}") logger.info(f" Active experts per layer: {current_r * 16}") logger.info(f" Gradient norm: {torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')):.4f}") first_batch_logged = True logger.info(f" \n") # Print periodic updates (every log_loss_steps_percent% of epoch) if batch_idx > 0 and batch_idx % max(1, len(train_loader) // (100 // log_loss_steps_percent)) == 0: logger.info(f" [{batch_idx}/{len(train_loader)}] loss={loss.item():.4f} aux={aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f} R={current_r}") pbar.set_postfix({ "loss": f"{loss.item():.4f}", "aux": f"{aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f}", "total": f"{total_loss.item():.4f}", "R": current_r, "step": f"{step}/{total_steps}" }) step += 1 # Evaluate every eval_steps (after batching/optimization) if step % eval_steps == 0: model.eval() logger.info(f"\n[Step {step}/{total_steps}] Running evaluation at eval_steps...") evaluate_prompt(model, tokenizer, logger=logger) model.train() # Routing analysis at eval_steps logger.info(f"\n[Step {step}] Analyzing routing patterns at eval_steps...") analyze_routing_patterns(model, train_loader, current_r, total_layers, device, logger=logger) # Save a checkpoint at this step logger.info(f"\n[Step {step}] Saving checkpoint at eval_steps to {save_path}...") os.makedirs(save_path, exist_ok=True) # Option 1: Save only router weights (recommended - more portable) print("\nSaving trained router weights only...") router_state_dict = {} for name, param in model.named_parameters(): if ".block_sparse_moe.gate" in name and param.requires_grad: router_state_dict[name] = param.data.cpu() # IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation for name, buf in model.named_buffers(): if "ema_utilization" in name or "mask_threshold" in name: logger.info(f"Saving buffer {name} with shape {buf.shape} for pruning evaluation") router_state_dict[name] = buf.data.cpu() else: logger.info(f"Skipping buffer {name} (not related to routing)") torch.save({ 'router_state_dict': router_state_dict, 'config': { 'reuse_scale': reuse_scale, 'num_epochs': num_epochs, 'lr': lr, 'model_name': model_name } }, os.path.join(save_path, 'rexmoe_routers.pt')) tokenizer.save_pretrained(save_path) logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters") logger.info(f" File: {save_path}/rexmoe_routers.pt") logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB") # Option 2: Also save full model (includes architecture, but less portable) logger.info("\nAlso saving full model with ReXMoE architecture...") model.save_pretrained(save_path) # Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA enabled) # This produces a standard HF model directory that can be loaded without PEFT. if qlora_enabled: try: merged_path = os.path.join(save_path, "merged") os.makedirs(merged_path, exist_ok=True) logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}") # merge_and_unload() returns the base model with LoRA weights folded in. merged_model = model.merge_and_unload() merged_model.eval() tokenizer.save_pretrained(merged_path) merged_model.save_pretrained(merged_path) logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading") except Exception as e: logger.warning(f"Could not merge and save LoRA weights (continuing): {e}") avg_epoch_loss = epoch_loss / len(train_loader) avg_epoch_aux_loss = epoch_aux_loss / len(train_loader) # Store metrics for convergence tracking epoch_aux_losses.append(avg_epoch_aux_loss) logger.info(f"\n{'='*60}") logger.info(f"Epoch {epoch+1} Summary:") logger.info(f" Average LM Loss: {avg_epoch_loss:.4f}") logger.info(f" Average Aux Loss: {avg_epoch_aux_loss:.6f}") logger.info(f" Average Total Loss: {avg_epoch_loss + avg_epoch_aux_loss:.4f}") logger.info(f" Final R: {current_r}") # Evaluate at epoch end model.eval() evaluate_prompt(model, tokenizer, logger=logger) model.train() # Track and save best checkpoint based on average LM loss if avg_epoch_loss < best_val: best_val = avg_epoch_loss best_epoch = epoch + 1 logger.info(f"New best epoch {best_epoch} with avg LM loss {best_val:.4f} — saving checkpoint to {save_path}") os.makedirs(save_path, exist_ok=True) # Option 1: Save only router weights (recommended - more portable) print("\nSaving trained router weights only...") router_state_dict = {} for name, param in model.named_parameters(): if ".block_sparse_moe.gate" in name and param.requires_grad: router_state_dict[name] = param.data.cpu() # IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation for name, buf in model.named_buffers(): if "ema_utilization" in name or "mask_threshold" in name: logger.info(f"Saving buffer {name} with shape {buf.shape} for pruning evaluation") router_state_dict[name] = buf.data.cpu() else: logger.info(f"Skipping buffer {name} (not related to routing)") torch.save({ 'router_state_dict': router_state_dict, 'config': { 'reuse_scale': reuse_scale, 'num_epochs': num_epochs, 'lr': lr, 'model_name': model_name } }, os.path.join(save_path, 'rexmoe_routers.pt')) tokenizer.save_pretrained(save_path) logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters") logger.info(f" File: {save_path}/rexmoe_routers.pt") logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB") # Option 2: Also save full model (includes architecture, but less portable) logger.info("\nAlso saving full model with ReXMoE architecture...") model.save_pretrained(save_path) # Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA enabled) # This produces a standard HF model directory that can be loaded without PEFT. if qlora_enabled: try: merged_path = os.path.join(save_path, "merged") os.makedirs(merged_path, exist_ok=True) logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}") # merge_and_unload() returns the base model with LoRA weights folded in. merged_model = model.merge_and_unload() merged_model.eval() tokenizer.save_pretrained(merged_path) merged_model.save_pretrained(merged_path) logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading") except Exception as e: logger.warning(f"Could not merge and save LoRA weights (continuing): {e}") # Check convergence converged, conv_metrics, conv_warnings = check_router_convergence( model, total_layers, convergence_history, logger=logger ) # Get current learning rate current_lr = optimizer.param_groups[0]['lr'] logger.info(f"\n 📊 Convergence Metrics:") logger.info("Convergence Metrics:") logger.info(f" Avg Router Grad Norm: {conv_metrics['avg_router_grad_norm']:.6f}") if 'grad_stability' in conv_metrics: print(f" Grad Stability: {conv_metrics['grad_stability']:.6f}") logger.info(f" Grad Stability: {conv_metrics['grad_stability']:.6f}") print(f" Current Learning Rate: {current_lr:.2e}") logger.info(f" Current Learning Rate: {current_lr:.2e}") if len(epoch_aux_losses) >= 2: aux_change = abs(epoch_aux_losses[-1] - epoch_aux_losses[-2]) print(f" Aux Loss Change: {aux_change:.6f}") logger.info(f" Aux Loss Change: {aux_change:.6f}") # Print warnings if conv_warnings: for warning in conv_warnings: print(f"\n {warning}") # Already logged in check_router_convergence if converged : msg = "✅ CONVERGED: Router weights have stabilized! Gradient norm < 0.1 and stable for 5 epochs" print(f"\n {msg}") logger.info(msg) if not getattr(mode, '_routers_soft_fronze', False): logger.info("Soft freezing routers for remaining epochs to focus on fine-tuning experts") for param_group in optimizer.param_groups: for name, param in model.named_parameters() and param in param_group['params']: if ".block_sparse_moe.gate" in name: param_group['weight_decay'] = 0.5 param_group['lr'] = current_lr * 0.1 setattr(mode, '_routers_soft_fronze', True) model._routers_soft_fronze = True elif len(convergence_history) >= 5: msg = "⏳ Not yet converged - continuing training..." print(f"\n {msg}") logger.info(msg) else: msg = "ℹ️ Collecting convergence data (need 5 epochs minimum)..." print(f"\n {msg}") logger.info(msg) print(f"{'='*60}") # Analyze routing patterns at epoch end print(f"\n📊 Routing Pattern Analysis (Epoch {epoch+1}):") logger.info(f"Routing Pattern Analysis (Epoch {epoch+1}):") print("-" * 60) analyze_routing_patterns(model, train_loader, current_r, total_layers, device, logger=logger) print("-" * 60) # Step the learning rate scheduler if scheduler is not None: scheduler.step() logger.info(f"LR stepped to: {optimizer.param_groups[0]['lr']:.2e}") # Final convergence report print(f"\n{'='*80}") print("📈 Training Convergence Summary") print(f"{'='*80}") logger.info("="*80) logger.info("Training Convergence Summary") logger.info("="*80) if len(convergence_history) > 0: print(f"\nRouter Gradient Norms Over Epochs:") logger.info("Router Gradient Norms Over Epochs:") for i, grad_norm in enumerate(convergence_history): trend = "" if i > 0: change = grad_norm - convergence_history[i-1] trend = f" (Δ {change:+.6f})" msg = f" Epoch {i+1}: {grad_norm:.6f}{trend}" print(msg) logger.info(msg) if len(epoch_aux_losses) > 0: print(f"\nAuxiliary Loss Over Epochs:") logger.info("Auxiliary Loss Over Epochs:") for i, aux_loss in enumerate(epoch_aux_losses): trend = "" if i > 0: change = aux_loss - epoch_aux_losses[i-1] trend = f" (Δ {change:+.6f})" msg = f" Epoch {i+1}: {aux_loss:.6f}{trend}" print(msg) logger.info(msg) # Final convergence assessment if len(convergence_history) >= 5: final_converged, final_metrics, final_warnings = check_router_convergence( model, total_layers, convergence_history, logger=logger ) print(f"\n{'='*80}") print(f"Final Convergence Status:") logger.info("="*80) logger.info("Final Convergence Status:") if final_converged: msg = "✅ CONVERGED - Routers have reached stable configuration" print(f" {msg}") logger.info(msg) print(f" - Gradient norm: {final_metrics['avg_router_grad_norm']:.6f} (< 0.1)") print(f" - Stability: {final_metrics['grad_stability']:.6f} (< 0.01)") logger.info(f" Gradient norm: {final_metrics['avg_router_grad_norm']:.6f}") logger.info(f" Stability: {final_metrics['grad_stability']:.6f}") msg = "Safe to deploy or proceed to parameter merging" print(f" {msg}") logger.info(msg) else: msg = "⚠️ NOT FULLY CONVERGED" print(f" {msg}") logger.warning(msg) print(f" Current metrics:") print(f" - Gradient norm: {final_metrics['avg_router_grad_norm']:.6f} (target: < 0.1)") logger.info(f" Gradient norm: {final_metrics['avg_router_grad_norm']:.6f}") if 'grad_stability' in final_metrics: print(f" - Stability: {final_metrics['grad_stability']:.6f} (target: < 0.01)") logger.info(f" Stability: {final_metrics['grad_stability']:.6f}") print(f" Consider training for more epochs if:") print(f" - Aux loss still decreasing significantly") print(f" - Routing patterns still changing") print(f" - Gradient norms not stabilized") print(f"{'='*80}\n") logger.info("="*80) else: print(f"\n{'='*80}") print(f"Convergence Status: Insufficient data (< 5 epochs)") print(f" Run for at least 5 epochs for convergence analysis") print(f"{'='*80}\n") logger.info("Convergence Status: Insufficient data (< 5 epochs)") # Save model print(f"\n[7/7] Saving router-adapted checkpoint to: {save_path}") os.makedirs(save_path, exist_ok=True) # Option 1: Save only router weights (recommended - more portable) logger.info("\nSaving trained router weights only...") router_state_dict = {} for name, param in model.named_parameters(): if ".block_sparse_moe.gate" in name and param.requires_grad: router_state_dict[name] = param.data.cpu() # IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation for name, buf in model.named_buffers(): if "ema_utilization" in name or "mask_threshold" in name: router_state_dict[name] = buf.data.cpu() torch.save({ 'router_state_dict': router_state_dict, 'config': { 'reuse_scale': reuse_scale, 'num_epochs': num_epochs, 'lr': lr, 'model_name': model_name } }, os.path.join(save_path, 'rexmoe_routers.pt')) tokenizer.save_pretrained(save_path) logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters") logger.info(f" File: {save_path}/rexmoe_routers.pt") logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB") # Option 2: Also save full model (includes architecture, but less portable) logger.info("\nAlso saving full model with ReXMoE architecture...") model.save_pretrained(save_path) # Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA was enabled) if qlora_enabled: try: merged_path = os.path.join(save_path, "merged") os.makedirs(merged_path, exist_ok=True) logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}") merged_model = model.merge_and_unload() merged_model.eval() tokenizer.save_pretrained(merged_path) merged_model.save_pretrained(merged_path) logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading") except Exception as e: logger.warning(f"Could not merge and save LoRA weights (continuing): {e}") # Save the custom classes for reloading import shutil shutil.copy(__file__, os.path.join(save_path, 'rexmoe_architecture.py')) # Print final statistics full_model_size = sum(os.path.getsize(os.path.join(save_path, f)) for f in os.listdir(save_path) if f.endswith('.bin') or f.endswith('.safetensors')) / 1024 / 1024 / 1024 logger.info("="*80) logger.info("✓ Training complete. Two checkpoint formats saved:") logger.info(" 1. Router weights only: rexmoe_routers.pt (portable)") logger.info(" 2. Full model: pytorch_model.bin (requires rexmoe_architecture.py)") logger.info(f"\nCheckpoint directory: {save_path}") logger.info(f"Full model size: {full_model_size:.2f} GB") logger.info("="*80) return model # ==================== 5. USAGE ==================== if __name__ == "__main__": # arg parser parser = argparse.ArgumentParser(description="ReXMoE Training") parser.add_argument("--model_name", type=str, default="microsoft/Phi-mini-MoE-instruct", help="Pretrained model name") parser.add_argument("--model_path", type=str, default="microsoft/Phi-mini-MoE-instruct", help="Path to pretrained model") # ../models/models/microsoft/Phi-mini-MoE-instruct parser.add_argument("--dataset_path", type=str, default="../dataset/alpaca_data_cleaned.json", help="Path to dataset JSON") parser.add_argument("--mode", type=str, choices=["IF","P", "IF_2"], default="IF", help="Dataset mode: IF = instruction-finetune (Alpaca), P = pretraining (C4)") parser.add_argument("--reuse_scale", type=int, default=3, help="Reuse scale R for cross-layer routing") parser.add_argument("--epoch", type=int, default=5, help="Number of training epochs") parser.add_argument("--num_samples", type=int, default=10000, help="Number of training samples to use from the dataset") parser.add_argument("--batch_size", type=int, default=32, help="Training batch size") parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length") parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate") parser.add_argument("--warmup_steps", type=int, default=200, help="Number of steps to warm up with R=1 (routers only)") parser.add_argument("--psr_enabled", action='store_true', help="Enable Progressive Scaling Routing (PSR)") parser.add_argument("--use_scheduler", action='store_true', default=True, help="Use learning rate scheduler") parser.add_argument("--gradient_checkpointing", action='store_true', help="Enable gradient checkpointing for memory efficiency") parser.add_argument("--met_enabled", action='store_true', help="Enable Masked Expert Training (MET)") parser.add_argument("--met_mask_ratio", type=float, default=0.1, help="MET mask ratio (0.1 = mask 10% of experts)") parser.add_argument("--met_warmup", type=float, default=0.5, help="Proportion of steps to warm up MET (no masking)") parser.add_argument("--eval_steps", type=int, default=500, help="Evaluate every N steps during training") parser.add_argument("--log_loss_steps_percent", type=int, default=10, help="Log loss every N%% of total steps") parser.add_argument("--full_lora", action='store_true', help="Enable full LoRA training") parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank") parser.add_argument("--save_path", type=str, default="./rexmoe_natural_phi_mini_moe", help="Base path to save trained model (timestamp will be prefixed)") parser.add_argument("--aux_loss_weight", type=float, default=0.02, help="Auxiliary loss weight") args = parser.parse_args() # Prefix save path with timestamp (DDMM_HHMMSS) to distinguish runs from datetime import datetime as _dt timestamp = _dt.now().strftime("%d%m_%H%M%S") timed_save_path = os.path.join(os.path.dirname(args.save_path), f"{timestamp}_" + f"{int(args.met_mask_ratio*100)}_" + os.path.basename(args.save_path)) + f"_R{args.reuse_scale}" model = train_rexmoe( model_name=args.model_name, model_path=args.model_path, dataset_path=args.dataset_path, dataset_mode=args.mode, reuse_scale=args.reuse_scale, num_samples=args.num_samples, num_epochs=args.epoch, # 5 epochs sufficient for router adaptation batch_size=args.batch_size, # As specified max_seq_length=args.max_seq_length, # As specified lr=args.lr, warmup_steps=args.warmup_steps, psr_enabled=args.psr_enabled, # Critical: prevents early collapse use_scheduler=args.use_scheduler, gradient_checkpointing=args.gradient_checkpointing, met_enabled=args.met_enabled, met_mask_ratio=args.met_mask_ratio, met_warmup=args.met_warmup, eval_steps=args.eval_steps, log_loss_steps_percent=args.log_loss_steps_percent, full_lora=args.full_lora, lora_r=args.lora_r, save_path=timed_save_path, aux_loss_weight=args.aux_loss_weight ) print(f"\n✓ All done! Model saved to {timed_save_path}")