# Source: Reuse-Trained-R3 / rexmoe_architecture.py
# Uploaded by PakNin using huggingface_hub (commit 16059ba, verified); file size 102 kB.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model
import argparse
import logging
from datetime import datetime
from torch.optim.lr_scheduler import CosineAnnealingLR
from typing import Optional
from get_dataset import get_dataloader
# Module-wide default device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ==================== 0. LOGGING SETUP ====================
def setup_logger(save_path="./logs"):
    """Configure the 'ReXMoE' logger with a timestamped file and a console handler.

    Handlers attached by an earlier call are removed *and closed* before the
    new ones are installed, so calling this function repeatedly does not leak
    open log-file descriptors.

    Args:
        save_path: Directory that receives the log file; created if missing.

    Returns:
        Tuple ``(logger, log_file)``: the configured logger and the path of
        the log file it writes to.
    """
    os.makedirs(save_path, exist_ok=True)

    # Log filename carries a DDMM_HHMMSS timestamp.
    timestamp = datetime.now().strftime("%d%m_%H%M%S")
    log_file = os.path.join(save_path, f"rexmoe_training_{timestamp}.log")

    logger = logging.getLogger('ReXMoE')
    logger.setLevel(logging.INFO)

    # Detach and close previous handlers; merely clearing the list would leave
    # the old FileHandler's stream open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    # File handler first, console handler second (same order as before).
    for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    banner = "=" * 80
    logger.info(banner)
    logger.info(f"ReXMoE Training Log - {timestamp}")
    logger.info(f"Log file: {log_file}")
    logger.info(banner)
    return logger, log_file
# Format prompt
def format_alpaca_prompt(instruction: str, input_text: str = "") -> str:
    """Build an Alpaca-style prompt ending in an open ``### Response:`` section.

    The ``### Input:`` section is emitted only when ``input_text`` is truthy.
    """
    sections = [f"### Instruction:\n{instruction}"]
    if input_text:
        sections.append(f"### Input:\n{input_text}")
    sections.append("### Response:\n")
    return "\n\n".join(sections)
def build_model_input(tokenizer, instruction: str, input_text: str = "") -> str:
    """Render the prompt with the tokenizer's chat template, else as an Alpaca prompt.

    The chat template is used only when the tokenizer exposes
    ``apply_chat_template`` and has a non-empty ``chat_template``; in that case
    instruction and input are merged into a single user message.
    """
    has_chat_template = hasattr(tokenizer, "apply_chat_template") and bool(
        getattr(tokenizer, "chat_template", None)
    )
    if not has_chat_template:
        return format_alpaca_prompt(instruction=instruction, input_text=input_text)

    content = f"{instruction}\n\n{input_text}" if input_text else instruction
    print(f"Applying tokenizer's chat template: {tokenizer.chat_template}")
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": content}],
        tokenize=False,
        add_generation_prompt=True,
    )
# Evaluate
# ...existing code...
def evaluate_prompt(model, tokenizer, max_new_tokens=100, do_sample=True, temperature=0.7, logger=None):
    """Generate completions for the built-in sample prompts and print/log them.

    Before generating, the IG-MET pruning status is summarised: EMA
    utilisation is summed per unique (layer, expert) across every ReXMoE
    router whose ``mask_threshold`` is >= 0, and experts whose summed value is
    below the threshold are reported as pruned.

    The model is put into eval mode for the duration of the call and restored
    afterwards, because ReXMoE routers update their EMA utilisation buffers
    whenever they run in training mode.

    This is best-effort: any exception is printed (and logged with traceback
    when ``logger`` is given) instead of being propagated.

    Args:
        model: Causal LM exposing ``generate``; may be PEFT-wrapped.
        tokenizer: Tokenizer matching ``model``.
        max_new_tokens: Generation budget per prompt.
        do_sample: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        logger: Optional logger that mirrors the printed output.

    Returns:
        None.
    """
    eval_prompts = [
        {
            "instruction": "What is the capital of France?",
            "input_text": None
        },
        {
            "instruction": "High-pressure systems stop air from rising into the colder regions of the atmosphere where water can condense. What will most likely result if a high-pressure system remains in an area for a long period of time?\nA. fog\nB. rain\nC. drought\nD. tornado\nAnswer:",
            "input_text": None
        },
        {
            "instruction": "Given the fact: predators eat prey\nQuestion: Predators eat\nA. lions\nB. humans\nC. bunnies\nD. grass\nAnswer:",
            "input_text": None
        }
    ]
    num_prompts = len(eval_prompts)

    def emit(message):
        """Print a message and mirror it to the logger when one is supplied."""
        print(message)
        if logger:
            logger.info(message)

    was_training = None
    try:
        # Safely get device for tensors (fallback: first parameter's device).
        if hasattr(model, "device"):
            device = model.device
        else:
            device = next(model.parameters()).device

        # Generate in eval mode so the sample prompts do not feed the routers'
        # EMA statistics; the previous mode is restored in `finally`.
        was_training = bool(getattr(model, "training", False))
        if was_training:
            model.eval()

        emit(f"\nEvaluating model with {num_prompts} sample prompts...")

        # ---- IG-MET pruning status (counts UNIQUE experts, not router-level copies) ----
        backend_model = get_backend_model(model)
        layers = backend_model.layers
        num_layers = len(layers)
        rex_routers = [
            (layer_idx, layer.block_sparse_moe.router)
            for layer_idx, layer in enumerate(layers)
            if hasattr(layer, "block_sparse_moe")
            and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock)
        ]

        # 1. Sum EMA utilisation per unique (orig_layer, orig_expert).
        expert_usage = {}
        for layer_idx, router in rex_routers:
            if router.mask_threshold.item() < 0:  # IG-MET disabled for this router
                continue
            # Rebuild the router's layer window (same clamping as the router uses).
            window = router.get_candidate_layers(step=None, total_steps=None)
            half = (window - 1) // 2
            start_layer = max(0, layer_idx - half)
            end_layer = min(num_layers, start_layer + window)
            start_layer = max(0, end_layer - window)
            mapping = [
                (start_layer + offset, expert_id)
                for offset in range(window)
                if start_layer + offset < num_layers
                for expert_id in range(router.num_experts_per_layer)
            ]
            # zip() stops at the shorter of mapping / EMA buffer.
            for key, ema_val in zip(mapping, router.ema_utilization.tolist()):
                expert_usage[key] = expert_usage.get(key, 0.0) + ema_val

        # 2. Decide pruning from the SUMMED usage against the shared threshold.
        if expert_usage:
            # Read the threshold from the first ReXMoE router rather than
            # layers[0], which need not be a ReXMoE block.
            threshold = rex_routers[0][1].mask_threshold.item()
            pruned_experts = {key for key, total in expert_usage.items() if total < threshold}

            emit("\n[IG-MET Pruning Status during Evaluation]:")
            total_unique_pruned = len(pruned_experts)
            total_unique = len(expert_usage)
            pct = 100 * total_unique_pruned / total_unique if total_unique > 0 else 0
            emit(f"Global: {total_unique_pruned}/{total_unique} UNIQUE experts pruned ({pct:.1f}%) | threshold={threshold:.6f}")

            # Per-layer stats only count the layer's OWN experts, not reused ones.
            pruning_info = []
            for layer_idx, router in rex_routers:
                experts_here = router.num_experts_per_layer
                masked_here = sum(
                    1 for expert_id in range(experts_here)
                    if (layer_idx, expert_id) in pruned_experts
                )
                pruning_info.append((layer_idx, masked_here, experts_here))

            pruning_by_ratio = sorted(
                pruning_info,
                key=lambda item: item[1] / item[2] if item[2] > 0 else 0,
                reverse=True,
            )
            emit("Top 5 most pruned layers:")
            for layer_idx, masked, total in pruning_by_ratio[:5]:
                pct = 100 * masked / total if total > 0 else 0
                emit(f" Layer {layer_idx:>2}: {masked:>2}/{total} pruned ({pct:>5.1f}%)")

        # ---- Generation ----
        for prompt_idx, prompt_config in enumerate(eval_prompts, 1):
            print("\n" + "=" * 80)
            print(f"Prompt {prompt_idx}/{num_prompts}:")
            print(f" Instruction: {prompt_config['instruction']}")
            print(f" Input: {prompt_config['input_text']}")
            print("=" * 80)

            prompt = build_model_input(
                tokenizer,
                instruction=prompt_config['instruction'],
                input_text=prompt_config['input_text'],
            )

            # Tokenize and move tensors to the model's device.
            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    temperature=temperature,
                    pad_token_id=getattr(tokenizer, "pad_token_id", None),
                    eos_token_id=getattr(tokenizer, "eos_token_id", None),
                )

            # Everything after the prompt tokens is the completion.
            prompt_len = inputs["input_ids"].shape[-1]
            generated = outputs[0]
            completion_ids = generated[prompt_len:]
            completion_text = tokenizer.decode(completion_ids, skip_special_tokens=True)
            full_text = tokenizer.decode(generated, skip_special_tokens=True)

            print("GENERATED RESPONSE:")
            print("-" * 80)
            print(completion_text)
            print("-" * 80)

            # Debugging info
            print(f"\n[debug] prompt_tokens={prompt_len}, new_tokens={int(completion_ids.numel())}")
            if completion_ids.numel() == 0:
                print("[debug] Model generated 0 new tokens (likely hit EOS immediately).")
                print("[debug] Full decoded text:\n" + full_text)
            elif completion_text.strip() == "":
                print("[debug] Model generated tokens, but they decode to empty/whitespace or special tokens.")
                print("[debug] completion token ids:", completion_ids.tolist())

            if logger:
                logger.info(f"\n--- Prompt {prompt_idx}/{num_prompts} ---")
                logger.info(f"Instruction: {prompt_config['instruction']}")
                logger.info(f"Input: {prompt_config['input_text']}")
                logger.info(f"Generated completion (len {int(completion_ids.numel())}): {completion_text}")
            print("=" * 80)

        if logger:
            logger.info(f"Evaluation of all {num_prompts} prompts complete.")
    except Exception as e:
        # Deliberate best-effort boundary: evaluation must never abort the caller.
        err = f"evaluate_prompt failed: {e}"
        print(err)
        if logger:
            logger.exception(err)
    finally:
        # Put the model back into training mode if that is how we found it.
        if was_training:
            model.train()
# ==================== 1. MODEL MODIFICATION ====================
class ReXMoERouter(nn.Module):
    """Router logic (no parameters) for cross-layer expert reuse.

    Note: This module is intentionally parameterless so that trainable gate
    weights live directly under the MoE block as `block_sparse_moe.gate`,
    matching the original Phi-MoE naming. This prevents saving keys like
    `router.gate` and keeps checkpoint compatibility.

    The router owns two buffers: ``ema_utilization`` (one entry per pool slot)
    and ``mask_threshold`` (negative means "no masking").
    """

    def __init__(self, layer_idx, total_layers=32, num_experts_per_layer=16,
                 reuse_scale=3, num_experts_per_tok=2, all_experts_dict=None, aux_loss_weight=0.02):
        super().__init__()
        self.layer_idx = layer_idx
        self.total_layers = total_layers
        self.num_experts_per_layer = num_experts_per_layer
        self.reuse_scale = reuse_scale  # e.g. R=3: layers [i-1, i, i+1]
        self.num_experts_per_tok = num_experts_per_tok
        # Reference to the shared {layer_idx: experts} dict.
        self.all_experts_dict = all_experts_dict
        # The gate always has reuse_scale * num_experts_per_layer outputs.
        self.max_pool_size = reuse_scale * num_experts_per_layer
        # EMA of expert utilisation starts from zero (no history yet).
        self.register_buffer('ema_utilization', torch.zeros(self.max_pool_size))
        # Default -1 means no masking.
        self.register_buffer('mask_threshold', torch.tensor(-1.0))
        self.aux_loss_weight = aux_loss_weight

    def _layer_window(self, width):
        """Return ``(start, end)`` of a ``width``-layer window around this layer.

        The window is centred on ``layer_idx`` and shifted inward when it
        would run past layer 0 or ``total_layers``.
        """
        half = (width - 1) // 2
        start = max(0, self.layer_idx - half)
        end = min(self.total_layers, start + width)
        start = max(0, end - width)
        return start, end

    def get_candidate_layers(self, step, total_steps, psr_enabled=True, initial_R=2, met_warmup=None):
        """Progressive Scaling Routing: gradually expand the reuse scale.

        Args:
            step: Current training step.
            total_steps: Total steps in training.
            psr_enabled: Whether PSR is enabled.
            initial_R: Reuse scale at step 0 (default 2).
            met_warmup: If given (float 0-1), the ramp finishes after that
                fraction of training; if None, the ramp covers the first 50%.

        Returns:
            The reuse scale to use now. ``self.reuse_scale`` is returned when
            PSR is disabled, when ``step``/``total_steps`` is None, or when
            the warm-up span is not positive (which previously raised
            ZeroDivisionError).
        """
        if not psr_enabled or step is None or total_steps is None:
            return self.reuse_scale

        warmup_fraction = 0.5 if met_warmup is None else met_warmup
        warmup_steps = warmup_fraction * total_steps
        if warmup_steps <= 0:
            # No warm-up phase to ramp over: use the full reuse scale.
            return self.reuse_scale

        progress = min(step / warmup_steps, 1.0)
        return initial_R + int(progress * (self.reuse_scale - initial_R))

    def update_ema(self, selection_counts, alpha=0.9):
        """Blend ``selection_counts`` ([max_pool_size]) into ``ema_utilization``.

        The counts are moved to the buffer's device first so the update also
        works when gate outputs and router buffers live on different devices.
        """
        with torch.no_grad():
            counts = selection_counts.to(self.ema_utilization.device)
            self.ema_utilization = alpha * self.ema_utilization + (1 - alpha) * counts

    def forward_with_logits(self, all_logits, hidden_states, step=None, total_steps=None, met_enabled=False, met_warmup=None, logger=None):
        """Mask gate logits for inactive/pruned experts and compute the aux loss.

        Args:
            all_logits: [batch_size * seq_len, max_pool_size] logits from the block's gate.
            hidden_states: Block input; accepted for interface compatibility, not read here.
            step: Current training step (None disables PSR).
            total_steps: Total training steps (None disables PSR).
            met_enabled: Enable IG-MET masking (global keep-mask or local threshold).
            met_warmup: Forwarded to ``get_candidate_layers``.
            logger: Accepted for interface compatibility, not used here.

        Returns:
            masked_logits: [batch_size * seq_len, max_pool_size], -inf at masked slots.
            aux_loss: scalar load-balancing loss.
            active_mask: [max_pool_size] boolean mask of usable slots.
            layer_expert_mapping: list of (layer_idx, expert_idx), one per pool slot.
        """
        device = all_logits.device

        # Current reuse scale via PSR.
        current_r = self.get_candidate_layers(step, total_steps, met_warmup=met_warmup)

        # The mapping always spans the FULL reuse_scale so it stays aligned
        # with the static gate size; PSR only selects a sub-window of it.
        base_start, _ = self._layer_window(self.reuse_scale)
        psr_start, psr_end = self._layer_window(current_r)

        layer_expert_mapping = []
        active_flags = []
        for layer_offset in range(self.reuse_scale):
            layer_id = base_start + layer_offset
            slot_active = (psr_start <= layer_id < psr_end) and layer_id < self.total_layers
            for expert_id in range(self.num_experts_per_layer):
                layer_expert_mapping.append((layer_id, expert_id))
                active_flags.append(slot_active)
        active_mask = torch.tensor(active_flags, dtype=torch.bool, device=device)

        # Inactive experts can never win top-k.
        masked_logits = all_logits.clone()
        masked_logits[:, ~active_mask] = float('-inf')

        # === HARD-PRUNED CHECKPOINTS (old_to_new) ===
        # Experts absent from their layer's map were removed; mask them here.
        old_to_new = getattr(self, 'old_to_new', None)
        if old_to_new:
            removed_slots = []
            for pool_idx, (orig_layer, orig_expert) in enumerate(layer_expert_mapping):
                if not active_flags[pool_idx]:
                    continue
                orig_layer_int = int(orig_layer)  # old_to_new is keyed by int layer
                if orig_layer_int not in old_to_new:
                    continue
                layer_map = old_to_new[orig_layer_int]
                if orig_expert not in layer_map and str(orig_expert) not in layer_map:
                    removed_slots.append(pool_idx)
            if removed_slots:
                masked_logits[:, removed_slots] = float('-inf')
                active_mask[removed_slots] = False

        # === IMPORTANCE-GUIDED MASKED EXPERT TRAINING (IG-MET) ===
        if met_enabled:
            global_keep_mask = getattr(self, 'global_keep_mask', None)
            if global_keep_mask is not None:
                # Mode A: keep-mask pushed in from outside (global analysis).
                if global_keep_mask.device != device:
                    global_keep_mask = global_keep_mask.to(device)
                cur_len = min(len(global_keep_mask), self.max_pool_size)
                # Prune where keep is False, but only slots that are still active.
                target_mask = torch.zeros_like(active_mask)
                target_mask[:cur_len] = ~global_keep_mask[:cur_len]
                target_mask = target_mask & active_mask
                # Safety: skip if this would remove every remaining active
                # expert (target_mask is a subset of active_mask, so equal
                # counts means "all of them").
                if target_mask.sum().item() < active_mask.sum().item():
                    masked_logits[:, target_mask] = float('-inf')
                    active_mask[target_mask] = False
            elif self.mask_threshold.item() >= 0:
                # Mode B: local thresholding on this router's own EMA.
                threshold = self.mask_threshold.to(device)
                ema = self.ema_utilization.to(device)
                target_mask = (ema < threshold) & active_mask
                num_to_mask = target_mask.sum().item()
                # Same safety rule: never mask all active experts.
                if 0 < num_to_mask < active_mask.sum().item():
                    masked_logits[:, target_mask] = float('-inf')
                    active_mask[target_mask] = False

        # Routing probabilities over the whole pool.
        routing_weights = torch.softmax(masked_logits, dim=-1)  # [B*S, max_pool_size]

        # EMA tracks the per-expert sum of routing weights, training only.
        if self.training:
            self.update_ema(routing_weights.sum(dim=0).detach())

        # Auxiliary load-balancing loss (squared coefficient of variation)
        # over active experts only.
        if active_mask.sum().item() == 0:
            # Fallback: mark at least the first expert as active.
            active_mask[0] = True
        active_routing_weights = routing_weights[:, active_mask]
        expert_counts = active_routing_weights.sum(0)  # [num_active_experts]
        if expert_counts.numel() > 1:
            cv_squared = (expert_counts.std() / (expert_counts.mean() + 1e-6)) ** 2
        else:
            # std() of a single element is undefined; treat as perfectly balanced.
            cv_squared = torch.tensor(
                0.0, device=active_routing_weights.device, dtype=active_routing_weights.dtype
            )
        aux_loss = self.aux_loss_weight * cv_squared

        # Keep last mapping for introspection/debugging.
        self.last_layer_expert_mapping = layer_expert_mapping
        return masked_logits, aux_loss, active_mask, layer_expert_mapping
class ReXMoESparseMoeBlock(nn.Module):
    """
    Modified PhiMoE Sparse MoE block with cross-layer expert reuse.
    Keeps original experts intact but routes to adjacent layer experts.

    The trainable gate lives on the block (``self.gate``) and has one output
    per pool slot (``reuse_scale * num_experts``); routing policy is delegated
    to a parameterless ``ReXMoERouter``.
    """

    def __init__(self, original_moe_block, layer_idx, total_layers, all_experts_dict, reuse_scale=3, logger=None, aux_loss_weight=0.02):
        """Wrap ``original_moe_block`` and expand its gate to the reuse pool.

        Args:
            original_moe_block: Block providing ``hidden_dim``, ``num_experts``,
                ``top_k``, ``gate`` and ``experts``.
            layer_idx: Index of the layer this block belongs to.
            total_layers: Number of layers in the model.
            all_experts_dict: Shared dict {layer_idx: experts}; stored by
                reference, parameters are not copied.
            reuse_scale: Number of layers whose experts form the pool.
            logger: Optional logger kept on the block.
            aux_loss_weight: Weight of the router's load-balancing loss.
        """
        super().__init__()
        self.hidden_dim = original_moe_block.hidden_dim
        self.num_experts = original_moe_block.num_experts
        self.top_k = original_moe_block.top_k
        self.layer_idx = layer_idx
        self.reuse_scale = reuse_scale
        # Keep reference to experts from all layers (DO NOT copy parameters).
        self.all_experts_dict = all_experts_dict
        self.aux_loss_weight = aux_loss_weight

        # Parameterless router holding the routing policy and statistics.
        self.router = ReXMoERouter(
            layer_idx=layer_idx,
            total_layers=total_layers,
            num_experts_per_layer=self.num_experts,
            reuse_scale=reuse_scale,
            num_experts_per_tok=self.top_k,
            all_experts_dict=all_experts_dict,
            aux_loss_weight=self.aux_loss_weight
        )

        # Gate sits on the block itself to match the original parameter naming.
        self.gate = nn.Linear(self.hidden_dim, self.router.max_pool_size, bias=False)
        self._init_gate_from(original_moe_block.gate.weight.data, total_layers)

        # Canonical expert container: match base PhiMoE which uses `.experts`
        # so that checkpoints save/load expert weights under the same keys.
        self.experts = original_moe_block.experts

        # PSR state; None means "use the full reuse scale" (inference default).
        self.current_step = None
        self.total_steps = None
        self.met_warmup = None  # Will be set during training if MET is enabled
        # Results of the most recent forward pass, for the training loop / analysis.
        self.last_aux_loss = None
        self.last_selected_experts = None
        self.last_selection_counts = None  # {(layer, expert): token count}
        self.logger = logger

    @staticmethod
    def _local_section_index(layer_idx, total_layers, reuse_scale):
        """Return which gate section (0..reuse_scale-1) holds this layer's own experts.

        Uses the same clamped window as the router's pool mapping, so at the
        first/last layers the local section is not the middle one.
        """
        half = (reuse_scale - 1) // 2
        start = max(0, layer_idx - half)
        end = min(total_layers, start + reuse_scale)
        start = max(0, end - reuse_scale)
        return min(max(layer_idx - start, 0), max(reuse_scale - 1, 0))

    def _init_gate_from(self, orig_weight, total_layers):
        """Fill the expanded gate from the original block's gate weights."""
        pool_size = self.router.max_pool_size
        orig_rows = orig_weight.shape[0]
        with torch.no_grad():
            if orig_rows == pool_size:
                # Checkpoint already has an expanded gate: copy it directly.
                self.gate.weight.data.copy_(orig_weight)
                print(f" Base block gate size already {orig_rows}, copied fully for layer {self.layer_idx}")
            elif orig_rows == self.num_experts:
                # Local experts get the original weights exactly. Their section
                # follows the router's clamped window, not a fixed middle slot.
                local_section = self._local_section_index(self.layer_idx, total_layers, self.reuse_scale)
                local_start_idx = local_section * self.num_experts
                local_end_idx = local_start_idx + self.num_experts
                self.gate.weight[local_start_idx:local_end_idx, :] = orig_weight.clone()
                print(f" num_experts: {self.num_experts}, refilled router weights for layer {self.layer_idx} local section at indices {local_start_idx}:{local_end_idx}")

                # Neighbour sections start as noisy copies rather than zeros.
                noise_scale = 0.1 * orig_weight.std().item()
                for section in range(self.reuse_scale):
                    if section == local_section:
                        continue
                    start_idx = section * self.num_experts
                    end_idx = start_idx + self.num_experts
                    noise = torch.randn_like(orig_weight) * noise_scale
                    self.gate.weight[start_idx:end_idx, :] = orig_weight.clone() + noise
                    print(f" Initialized neighbor section {start_idx}:{end_idx} with noise")
            else:
                # Pruned / irregular checkpoints: zero the gate, then copy the
                # leading rows that fit.
                self.gate.weight.zero_()
                print(f" Warning: Custom size mismatch for gate refill ({orig_rows} vs {self.num_experts}). Filling with zeros.")
                min_experts = min(orig_rows, pool_size)
                self.gate.weight[:min_experts, :] = orig_weight[:min_experts, :].clone()

    @property
    def local_experts(self):
        """Alias of ``self.experts`` for code that expects this name."""
        return self.experts

    def map_pruned_expert(self, orig_layer: int, orig_expert: int, old_to_new: dict) -> Optional[int]:
        """
        Map original (layer, expert) index to current kept expert index.
        Returns None if the expert was pruned.

        When ``orig_layer`` has no entry in ``old_to_new`` the original index
        is returned as long as it exists in ``all_experts_dict``.
        """
        if orig_layer in old_to_new:
            layer_map = old_to_new[orig_layer]
            # Handle both int and str keys (robust to JSON serialization).
            return layer_map.get(orig_expert, layer_map.get(str(orig_expert), None))
        # No pruning map for this layer: use original index if it still exists.
        if orig_layer in self.all_experts_dict and orig_expert < len(self.all_experts_dict[orig_layer]):
            return orig_expert
        return None

    def forward(self, hidden_states, logger=None):
        """
        Args:
            hidden_states: [batch_size, seq_len, hidden_dim]
            logger: Optional logger; falls back to ``self.logger``.
        Returns:
            Tuple ``(output, router_logits)`` like the original PhiMoE block:
            output is [batch_size, seq_len, hidden_dim], router_logits is the
            masked [batch_size * seq_len, max_pool_size] logits.
        """
        if logger is None:
            logger = getattr(self, 'logger', None)

        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states_flat = hidden_states.reshape(-1, hidden_dim)  # [B*S, H]
        num_tokens = hidden_states_flat.shape[0]

        # Ensure gate is on the same device as hidden_states (device_map="auto").
        device = hidden_states_flat.device
        self.gate = self.gate.to(device)

        # Gate logits are computed at block level so weights save as block_sparse_moe.gate.
        all_logits = self.gate(hidden_states_flat)
        router_logits, aux_loss, active_mask, layer_expert_mapping = self.router.forward_with_logits(
            all_logits, hidden_states, self.current_step, self.total_steps,
            met_enabled=getattr(self, 'met_enabled', False),
            met_warmup=self.met_warmup,
            logger=logger
        )
        # Stored for collection in the training loop.
        self.last_aux_loss = aux_loss

        # Top-k selection; weights are renormalised over the selected k.
        topk_logits, topk_indices = torch.topk(router_logits, self.top_k, dim=-1)  # [B*S, k]
        topk_weights = torch.softmax(topk_logits, dim=-1)  # [B*S, k]

        final_hidden_states = torch.zeros_like(hidden_states_flat)
        processed_mask = torch.zeros(num_tokens, dtype=torch.bool, device=hidden_states.device)
        selection_counts = {}

        # The pruning map lives on the router; tolerate it being absent or None.
        old_to_new = getattr(self.router, 'old_to_new', None) or {}
        # For pruned models report the ORIGINAL expert index in the stats; a
        # non-empty pruning map on the router also marks the model as pruned.
        is_pruned_model = bool(
            getattr(self, 'is_pruned_model', False)
            or getattr(self, 'old_to_new', None)
            or old_to_new
        )

        # One host copy of the mask instead of a device sync per pool slot.
        active_flags = active_mask.tolist()
        num_mapped = min(self.router.max_pool_size, len(layer_expert_mapping))

        for k_idx in range(self.top_k):
            selected_positions = topk_indices[:, k_idx]  # [B*S]
            k_weights = topk_weights[:, k_idx:k_idx + 1]  # [B*S, 1]

            # Only visit pool slots that some token actually selected.
            for pool_pos in torch.unique(selected_positions).tolist():
                # Hard block for pruned / inactive experts: never execute them.
                if pool_pos >= num_mapped or not active_flags[pool_pos]:
                    continue

                orig_layer, orig_expert = layer_expert_mapping[pool_pos]
                new_idx = self.map_pruned_expert(orig_layer, orig_expert, old_to_new)
                if new_idx is None:
                    continue  # Pruned expert

                token_mask = selected_positions == pool_pos  # [B*S]
                expert_module = self.all_experts_dict[orig_layer][new_idx]

                # Move the tokens to the expert's device rather than the expert
                # to the tokens (cheaper when the model is sharded).
                expert_device = next(expert_module.parameters()).device
                expert_out = expert_module(hidden_states_flat[token_mask].to(expert_device))  # [N, H]
                expert_out = expert_out.to(device)

                final_hidden_states[token_mask] += expert_out * k_weights[token_mask]

                reported_expert = orig_expert if is_pruned_model else new_idx
                key = (orig_layer, reported_expert)
                selection_counts[key] = selection_counts.get(key, 0) + token_mask.sum().item()
                processed_mask[token_mask] = True

        # Tokens for which no expert ran pass through unchanged.
        if not processed_mask.all():
            final_hidden_states[~processed_mask] = hidden_states_flat[~processed_mask]

        # Store selections for analysis.
        self.last_selection_counts = selection_counts

        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
        # Return tuple like original PhiMoE: (hidden_states, router_logits)
        return final_hidden_states, router_logits
# ==================== 3. ROUTING ANALYSIS ====================
def _candidate_pool_mapping(layer_idx, reuse_r, num_layers, experts_per_layer):
    """
    Rebuild the pool-position -> (layer, expert) list for one router.

    The window covers ``reuse_r`` consecutive layers, starting
    ``(reuse_r - 1) // 2`` layers before ``layer_idx`` and shifted inwards at
    either end of the stack so that it stays inside ``[0, num_layers)``.
    NOTE(review): this is meant to mirror the router's own mapping logic,
    which is not visible here - keep the two in sync.

    Returns:
        list of ``(layer_id, expert_id)`` tuples ordered by pool position.
    """
    half = (reuse_r - 1) // 2
    start_layer = max(0, layer_idx - half)
    end_layer = min(num_layers, start_layer + reuse_r)
    start_layer = max(0, end_layer - reuse_r)
    last_layer = min(start_layer + reuse_r, num_layers)
    return [
        (l_id, e_id)
        for l_id in range(start_layer, last_layer)
        for e_id in range(experts_per_layer)
    ]


def analyze_routing_patterns(model, dataloader, current_r, total_layers, device, num_batches=10, logger=None):
    """
    Analyze ACTUAL routing patterns by tracking which experts were selected.

    Runs up to ``num_batches`` batches through ``model`` in eval mode (no
    gradients), accumulates each ReXMoE block's ``last_selection_counts`` and
    then prints (and logs, if ``logger`` is given):
    - an IG-MET pruning report built from the routers' ``ema_utilization``
      and ``mask_threshold`` buffers,
    - the distribution of selections over same / previous / next / distant layers,
    - the five most selected experts for a few sample layers,
    - a verdict on whether cross-layer reuse is happening.

    Args:
        model: bare or PEFT-wrapped causal LM; unwrapped with ``get_backend_model``.
        dataloader: iterable of dict batches with "input_ids" and "attention_mask".
        current_r: reuse scale supplied by the caller; only used in the report.
        total_layers: number of layer slots pre-created in the count table.
        device: device the batch tensors (and the EMA summary tensor) are moved to.
        num_batches: maximum number of batches to sample.
        logger: optional logger that mirrors every printed line.

    Returns:
        None. The model's previous train/eval mode is restored before returning,
        even if the forward pass raises.
    """

    def emit(msg, level="info"):
        # Every report line goes to stdout and, when a logger is given, to the log.
        print(msg)
        if logger:
            getattr(logger, level)(msg)

    # Always work on the underlying transformer stack that actually owns `.layers`,
    # whether `model` is a bare PhiMoEForCausalLM or a PEFT-wrapped model.
    backend_model = get_backend_model(model)
    num_layers = len(backend_model.layers)

    def rexmoe_blocks():
        # Yield (layer_idx, block) for every layer that carries a ReXMoE block.
        for idx, layer in enumerate(backend_model.layers):
            block = getattr(layer, "block_sparse_moe", None)
            if isinstance(block, ReXMoESparseMoeBlock):
                yield idx, block

    # Track ACTUAL routing decisions: routing_counts[layer_idx][(target_layer, target_expert)] = count
    routing_counts = {layer_idx: {} for layer_idx in range(total_layers)}
    total_tokens = 0
    batches_seen = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= num_batches:  # Sample only first N batches for efficiency
                    break
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                total_tokens += attention_mask.sum().item()  # Count non-padding tokens
                batches_seen += 1

                # Forward pass to get routing decisions
                model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False)

                # Collect ACTUAL routing decisions from each layer
                for layer_idx, moe_block in rexmoe_blocks():
                    selections = getattr(moe_block, "last_selection_counts", None)
                    if selections is None:
                        continue
                    layer_counts = routing_counts.setdefault(layer_idx, {})
                    for key, count in selections.items():
                        layer_counts[key] = layer_counts.get(key, 0) + count
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)

    # Print analysis
    emit(f"\nAnalyzing ACTUAL routing patterns from {batches_seen} batches ({total_tokens:,} tokens)")
    emit(f"Current reuse scale: R={current_r}")

    # === IG-MET PRUNING ANALYTICS (GLOBAL SUM AGGREGATION) ===
    # 1. Aggregate EMA for each unique expert across all routers (SUM)
    unique_experts_sum = {}  # (orig_layer, orig_expert) -> summed_ema_score
    threshold = None
    for layer_idx, moe_block in rexmoe_blocks():
        router = moe_block.router
        thr = router.mask_threshold.item()
        if thr < 0:
            continue  # negative threshold: masking is not active for this router
        # Use a separate name so the caller-supplied `current_r` is not clobbered.
        router_r = router.get_candidate_layers(step=None, total_steps=None)
        mapping = _candidate_pool_mapping(
            layer_idx, router_r, num_layers, router.num_experts_per_layer
        )
        # zip() stops at the shorter side, i.e. at the end of the EMA buffer.
        for key, ema_val in zip(mapping, router.ema_utilization.tolist()):
            unique_experts_sum[key] = unique_experts_sum.get(key, 0) + ema_val
        if threshold is None:
            threshold = thr

    # 2. Prune based on SUM aggregation and global threshold
    if threshold is None:
        total_unique_pruned = 0
    else:
        total_unique_pruned = sum(1 for v in unique_experts_sum.values() if v < threshold)
    total_unique = len(unique_experts_sum)

    emit("\n[IG-MET Pruning Report]:")
    pct = 100 * total_unique_pruned / total_unique if total_unique > 0 else 0
    emit(f"Global: {total_unique_pruned}/{total_unique} UNIQUE experts pruned ({pct:.1f}%) | threshold={threshold if threshold is not None else -1:.6f}")
    if unique_experts_sum:
        global_ema_tensor = torch.tensor(list(unique_experts_sum.values()), device=device)
        emit(f"Aggregated EMA (sum across R layers): mean={global_ema_tensor.mean():.6f}, min={global_ema_tensor.min():.6f}, max={global_ema_tensor.max():.6f}")
    print()

    # Analyze cross-layer reuse statistics
    cross_layer_usage = {
        "same_layer": 0,
        "adjacent_prev": 0,
        "adjacent_next": 0,
        "distant": 0
    }
    for layer_idx, layer_counts in routing_counts.items():
        for (target_layer, _target_expert), count in layer_counts.items():
            if target_layer == layer_idx:
                bucket = "same_layer"
            elif target_layer == layer_idx - 1:
                bucket = "adjacent_prev"
            elif target_layer == layer_idx + 1:
                bucket = "adjacent_next"
            else:
                bucket = "distant"
            cross_layer_usage[bucket] += count

    total_routing = sum(cross_layer_usage.values())
    if total_routing > 0:
        emit("Cross-Layer Routing Distribution (ACTUAL selections):")
        for key, label in [
            ("same_layer", "Same layer (i):"),
            ("adjacent_prev", "Previous layer (i-1):"),
            ("adjacent_next", "Next layer (i+1):"),
            ("distant", "Distant layers:")
        ]:
            # The "distant" row is only shown when it is non-zero.
            if cross_layer_usage[key] > 0 or key != "distant":
                pct = 100 * cross_layer_usage[key] / total_routing
                emit(f" {label:25} {cross_layer_usage[key]:>10,} ({pct:>5.1f}%)")
        print()

    # Sample detailed analysis for a few layers
    sample_layers = [8, 16, 24] if total_layers >= 32 else [total_layers // 4, total_layers // 2, 3 * total_layers // 4]
    emit("Sample Layer-Specific Routing Patterns:")
    for layer_idx in sample_layers:
        if layer_idx in routing_counts and routing_counts[layer_idx]:
            emit(f"\n Layer {layer_idx}:")
            # Get top 5 most used experts
            sorted_experts = sorted(routing_counts[layer_idx].items(), key=lambda x: x[1], reverse=True)[:5]
            for (target_layer, target_expert), count in sorted_experts:
                pct = 100 * count / total_tokens if total_tokens > 0 else 0
                layer_relation = "same" if target_layer == layer_idx else f"L{target_layer}"
                emit(f" Expert {target_expert:>2} from layer {target_layer:>2} ({layer_relation:>4}): {count:>8,} times ({pct:>5.1f}%)")
    print()

    # Check if cross-layer reuse is happening
    cross_layer_pct = 100 * (cross_layer_usage['adjacent_prev'] + cross_layer_usage['adjacent_next'] + cross_layer_usage['distant']) / total_routing if total_routing > 0 else 0
    if cross_layer_pct > 5:
        emit(f"✅ Cross-layer expert reuse detected: {cross_layer_pct:.1f}% of routing uses adjacent layers")
    elif current_r > 1:
        emit(f"⚠️ Limited cross-layer reuse: {cross_layer_pct:.1f}% (expected >5% with R={current_r})", level="warning")
        emit(" This may improve as training progresses and routers adapt.")
    else:
        emit(f"ℹ️ R=1 mode: Only same-layer experts available (PSR warmup phase)")
# ==================== 3.5. CONVERGENCE MONITORING ====================
def compute_routing_entropy(router_logits):
    """
    Compute the mean entropy of the routing distribution.

    High entropy = uniform routing (may indicate lack of specialization)
    Low entropy = concentrated routing (strong preferences)

    Args:
        router_logits: tensor of router logits; softmax is taken over the
            last dimension and the entropy is averaged over all others.

    Returns:
        float: mean entropy in nats.
    """
    # Run the softmax in float32: in float16 the 1e-10 epsilon underflows to 0,
    # so a fully peaked distribution would evaluate 0 * log(0) = NaN.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    entropy = -(probs * torch.log(probs + 1e-10)).sum(dim=-1).mean()
    return entropy.item()
def check_router_convergence(model, total_layers, convergence_history, threshold=0.01, logger=None):
    """
    Check whether the routers have converged, based on router-gate gradient norms.

    The average gradient norm of the ``gate.weight`` of every ReXMoE block is
    appended to ``convergence_history`` (mutated in place). Once at least five
    entries exist, the last five are checked for stability and the most recent
    ones for oscillation, growth and magnitude.

    If no gate gradient is available (e.g. no backward pass has run yet),
    nothing is appended and the function returns ``converged=False`` with a
    warning, so that missing gradients are never mistaken for convergence.

    Args:
        model: bare or PEFT-wrapped model; unwrapped with ``get_backend_model``.
        total_layers: not used by this function; kept for interface compatibility.
        convergence_history: list of past average gradient norms.
        threshold: upper bound on variance / mean of the last five averages.
        logger: optional logger that receives the warnings.

    Returns:
        converged: bool
        metrics: dict of convergence metrics
        warnings: list of warning messages
    """
    # Always work on the underlying transformer stack that actually owns `.layers`,
    # whether `model` is a bare PhiMoEForCausalLM or a PEFT-wrapped model.
    backend_model = get_backend_model(model)

    router_grad_norms = []
    for layer in backend_model.layers:
        moe_block = getattr(layer, "block_sparse_moe", None)
        if not isinstance(moe_block, ReXMoESparseMoeBlock):
            continue
        gate = getattr(moe_block, 'gate', None)
        weight = getattr(gate, 'weight', None)
        grad = getattr(weight, 'grad', None)
        if grad is not None:
            router_grad_norms.append(grad.norm().item())

    warnings = []

    if not router_grad_norms:
        # Without gradients a zero average would look like perfect convergence.
        metrics = {
            'avg_router_grad_norm': 0,
            'max_router_grad_norm': 0,
            'min_router_grad_norm': 0,
        }
        warnings.append("⚠️ No router gate gradients available - convergence cannot be assessed")
        if logger:
            logger.warning("No router gate gradients available - convergence cannot be assessed")
        return False, metrics, warnings

    avg_grad_norm = sum(router_grad_norms) / len(router_grad_norms)
    metrics = {
        'avg_router_grad_norm': avg_grad_norm,
        'max_router_grad_norm': max(router_grad_norms),
        'min_router_grad_norm': min(router_grad_norms),
    }

    # Check convergence: gradients should be small and stable
    convergence_history.append(avg_grad_norm)

    # Need at least 5 entries of history
    if len(convergence_history) < 5:
        return False, metrics, warnings

    # Check if gradient norm is stable (variance < threshold)
    recent_grads = torch.tensor(convergence_history[-5:])
    grad_variance = recent_grads.var().item()
    grad_mean = recent_grads.mean().item()
    metrics['grad_variance'] = grad_variance
    metrics['grad_stability'] = grad_variance / (grad_mean + 1e-10)

    # From here on the history holds >= 5 entries, so the last three always exist.
    before, middle, latest = convergence_history[-3:]

    # Detect oscillations: the middle value is either a peak or a valley
    is_peak = middle > before and middle > latest
    is_valley = middle < before and middle < latest
    if is_peak or is_valley:
        warnings.append("⚠️ Oscillation detected in gradient norms - consider reducing learning rate")
        if logger:
            logger.warning("Oscillation detected in gradient norms")

    # Check for increasing gradients (divergence)
    if latest > middle * 1.2:
        warnings.append("⚠️ Gradients increasing - possible divergence or high learning rate")
        if logger:
            logger.warning("Gradients increasing - possible divergence")

    # Check if gradients are too high
    if avg_grad_norm > 0.3:
        warnings.append(f"⚠️ High gradient norm ({avg_grad_norm:.4f}) - learning rate may be too high")
        if logger:
            logger.warning(f"High gradient norm: {avg_grad_norm:.4f}")

    # Converged if: gradients are small AND stable
    converged = (avg_grad_norm < 0.1) and (metrics['grad_stability'] < threshold)
    return converged, metrics, warnings
# Helper: always get the underlying transformer stack (with `.layers`)
def get_backend_model(m):
    """
    Return the innermost transformer stack reachable from ``m``.

    Follows ``base_model`` -> ``model`` -> ``model``; any attribute that is
    missing is skipped and the object reached so far is kept. This covers a
    PEFT wrapper (``base_model``), the causal-LM head (``model``) and the
    transformer itself.
    """
    current = m
    for attr_name in ("base_model", "model", "model"):
        # A missing attribute means this wrapping level is absent: stay put.
        current = getattr(current, attr_name, current)
    return current
# ==================== 4. TRAINING LOOP ====================
def train_rexmoe(
model_name="microsoft/Phi-mini-MoE-instruct",
model_path="../models/models/microsoft/Phi-mini-MoE-instruct",
dataset_path="../dataset/alpaca_data_cleaned.json",
dataset_mode: str = "IF",
reuse_scale=3,
num_samples=10000,
num_epochs=5,
batch_size=16,
max_seq_length=512,
lr=5e-6,
warmup_steps=10,
psr_enabled=True,
save_path="./rexmoe_phi_mini_moe_r3",
gradient_checkpointing=True,
met_enabled=False,
met_mask_ratio=0.1,
met_warmup=0.5,
eval_steps=1000,
log_loss_steps_percent=10,
full_lora=False,
lora_r=16,
use_scheduler=True,
aux_loss_weight=0.02
):
# Setup logger
logger, log_file = setup_logger(save_path=os.path.join(save_path, "logs"))
print("="*80)
print("ReXMoE Cross-Layer Expert Reuse Training")
print("="*80)
logger.info("="*80)
logger.info("ReXMoE Cross-Layer Expert Reuse Training")
logger.info("="*80)
logger.info("MET enabled: {}".format(met_enabled))
config_msg = f"""
Configuration:
Model: {model_name}
Dataset: {dataset_path}
Dataset mode: {dataset_mode}
Reuse Scale (R): {reuse_scale}
Prune Ratio (MET): {met_mask_ratio if met_enabled else 'N/A'}
Epochs: {num_epochs}
Num of samples: {num_samples}
Batch Size: {batch_size}
Sequence Length: {max_seq_length}
Learning Rate: {lr}
PSR Enabled: {psr_enabled}
LR Scheduler: {use_scheduler}
Save Path: {save_path}
Gradient Checkpointing: {gradient_checkpointing}
LoRA Rank: {lora_r} (Full LoRA: {full_lora})
LoRA Alpha: {lora_r * 2}
MET Enabled: {met_enabled} (Mask Ratio: {met_mask_ratio}, Warmup: {met_warmup})
Log File: {log_file}
Aux loss weight: {aux_loss_weight}
"""
print(config_msg)
logger.info(config_msg)
print("="*80)
if torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
device_msg = f"💻 Using device: {device})"
print(f"\n{device_msg}")
logger.info(device_msg)
if torch.cuda.is_available():
gpu_msg = f"GPU: {torch.cuda.get_device_name(0)}, Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB"
print(f" {gpu_msg}")
logger.info(gpu_msg)
# Load tokenizer + model
print(f"\n[1/7] Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print(f"Loading model to device {device} (no device_map sharding)...")
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch.bfloat16,
device_map=None, # Do NOT auto-shard - we'll place it manually
trust_remote_code=False
)
print(f"Moving model to {device}...")
model = model.to(device)
print(f"✓ Model moved to {device}")
# Verify model is on correct device
model_device = next(model.parameters()).device
print(f"✓ Model device verified: {model_device}")
if gradient_checkpointing:
model.gradient_checkpointing_enable()
print(f"Model loaded: {model.config.num_hidden_layers} layers")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Experts per layer: {model.config.num_local_experts}")
# Collect all experts from all layers
print(f"\n[2/7] Collecting expert references from all layers...")
total_layers = model.config.num_hidden_layers
all_experts_dict = {}
for layer_idx, layer in enumerate(model.model.layers):
if hasattr(layer, "block_sparse_moe"):
all_experts_dict[layer_idx] = layer.block_sparse_moe.experts
print(f"Collected {len(all_experts_dict)} MoE layers")
# Replace MoE blocks with ReXMoE blocks
print(f"\n[3/7] Replacing MoE blocks with ReXMoE routers (R={reuse_scale})...")
moe_count = 0
for layer_idx, layer in enumerate(model.model.layers):
if hasattr(layer, "block_sparse_moe"):
original_moe = layer.block_sparse_moe
# Create ReXMoE block (keeps expert references, replaces router)
rexmoe_block = ReXMoESparseMoeBlock(
original_moe_block=original_moe,
layer_idx=layer_idx,
total_layers=total_layers,
all_experts_dict=all_experts_dict,
reuse_scale=reuse_scale,
logger=logger,
aux_loss_weight=aux_loss_weight
)
rexmoe_block.met_enabled = met_enabled
# Attach logger to block so its forward can access logging even when
# the higher-level `model()` call doesn't pass a logger argument.
rexmoe_block.logger = logger
# Move ReXMoE block to correct device
rexmoe_block = rexmoe_block.to(dtype=torch.bfloat16, device=device)
# Replace the block
layer.block_sparse_moe = rexmoe_block
moe_count += 1
print(f"✓ ReXMoE blocks installed: {moe_count} layers modified")
print(f" Each router can now access up to {reuse_scale * 16} experts (R={reuse_scale})")
# Warmup phase: only routers (gates) trainable
print(f"\n[4/7] Initial freeze: only routers trainable for warmup phase...")
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
total_params += param.numel()
if ".block_sparse_moe.gate" in name:
# Router gates trainable
param.requires_grad = True
trainable_params += param.numel()
else:
param.requires_grad = False
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters (warmup): {trainable_params:,} ({100*trainable_params/total_params:.2f}%)")
# Verify only routers are trainable at start
trainable_layers = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"✓ Warmup trainable components: {len(trainable_layers)} parameters (router gates only)")
if len(trainable_layers) > 0:
print(f" First: {trainable_layers[0]}")
print(f" Last: {trainable_layers[-1]}")
# Optimizer (router params only) - Use 8-bit AdamW for memory efficiency
print(f"\n[5/7] Setting up optimizer and dataset...")
logger.info("[5/7] Setting up optimizer and dataset...")
print("Using 8-bit AdamW optimizer for memory efficiency")
logger.info("Using 8-bit AdamW optimizer")
optimizer = bnb.optim.AdamW8bit(
[p for p in model.parameters() if p.requires_grad],
lr=lr,
weight_decay=0.1
)
# Learning rate scheduler (cosine annealing)
scheduler = None
if use_scheduler:
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr * 0.1)
print(f"Using CosineAnnealingLR scheduler: {lr}{lr * 0.1}")
logger.info(f"LR Scheduler: CosineAnnealingLR ({lr}{lr * 0.1})")
# Prepare dataset (Instruction Fine-tuning or Pretraining)
print(f"[5/7] Preparing dataset: mode={dataset_mode}, path={dataset_path}")
try:
train_loader = get_dataloader(
mode=dataset_mode,
tokenizer=tokenizer,
dataset_path=dataset_path,
max_seq_length=max_seq_length,
batch_size=batch_size,
num_samples=num_samples,
shuffle=True,
)
except Exception as e:
print(f"Could not prepare dataset: {e}")
raise
train_len = num_samples
print(f"Training samples: {train_len}")
print(f"Batch size: {batch_size}")
print(f"Sequence length: {max_seq_length}")
# Training loop
print(f"\n[6/7] Starting training for {num_epochs} epochs...")
print(f"PSR enabled: {psr_enabled}")
total_steps = len(train_loader) * num_epochs
# PSR schedule changes based on whether MET is enabled
if psr_enabled:
if met_enabled:
warmup_steps_psr = int(met_warmup * total_steps)
print(f"PSR schedule: R=2 → R={reuse_scale} during MET warmup phase (steps 0-{warmup_steps_psr})")
print(f" then stays at R={reuse_scale} during pruning/finetuning phases")
else:
psr_completion_steps = int(0.5 * total_steps)
print(f"PSR schedule: R=2 → R={reuse_scale} over first 50% of training (steps 0-{psr_completion_steps})")
step = 0
# Track statistics
first_batch_logged = False
# Track routing patterns for analysis
# Structure: routing_stats[layer_idx][(target_layer, target_expert)] = count
routing_stats = {}
for layer_idx in range(total_layers):
routing_stats[layer_idx] = {}
# Convergence tracking
convergence_history = []
epoch_entropies = []
epoch_aux_losses = []
model.train()
best_val = float("inf")
best_epoch = -1
qlora_enabled = False # track switch from warmup (routers-only) to routers+LoRA
print_met_active = False
print_met_freeze = False
for epoch in range(num_epochs):
print(f"\n{'='*60}")
print(f"Epoch {epoch+1}/{num_epochs}")
print(f"{'='*60}")
epoch_loss = 0
epoch_aux_loss = 0
epoch_entropy = 0 # Track routing entropy
pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
for batch_idx, batch in enumerate(pbar):
# Switch from warmup (routers only) to routers + LoRA adapters on experts
if (not qlora_enabled) and step >= warmup_steps:
if full_lora:
logger.info(f"Warmup completed at step {step}. Enabling FULL QLoRA with r = {lora_r} and alpha = {lora_r * 2} on experts and updating optimizer...")
else:
logger.info(f"Warmup completed at step {step}. Enabling QLoRA on experts.")
# Attach LoRA adapters globally (wrap linear layers)
# Note: PhiMoE experts use w1/w2/w3 linear layers, not gate_proj/up_proj/down_proj.
# We target those names so LoRA attaches only to expert MLP weights.
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_r * 2,
lora_dropout=0.00,
bias="none",
task_type="CAUSAL_LM",
target_modules=[
"w1", "w2", "w3",
# if full_lora, also target attention layers of transformer blocks (but NOT router gates)
"q_proj" if full_lora else None,
"k_proj" if full_lora else None,
"v_proj" if full_lora else None,
"o_proj" if full_lora else None
],
)
model = get_peft_model(model, lora_config)
# Freeze everything, then re-enable router gates and LoRA params
total_params = 0
trainable_params = 0
for name, param in model.named_parameters():
total_params += param.numel()
param.requires_grad = False
for name, param in model.named_parameters():
if ".block_sparse_moe.gate" in name:
param.requires_grad = True
trainable_params += param.numel()
elif "lora_" in name:
param.requires_grad = True
trainable_params += param.numel()
optimizer = bnb.optim.AdamW8bit(
[p for p in model.parameters() if p.requires_grad],
lr=lr,
weight_decay=0.1
)
print(f"Total parameters after QLoRA: {total_params:,}")
print(f"Trainable parameters (routers + LoRA): {trainable_params:,} ({100*trainable_params/total_params:.4f}%)")
logger.info(f"Trainable params (routers + LoRA): {trainable_params} ({100*trainable_params/total_params:.4f}%)")
trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
print("Sample trainable params after QLoRA:", trainable_names[:10])
logger.info(f"Sample trainable params after QLoRA: {trainable_names[:10]}")
qlora_enabled = True
# Update step counter in all MoE blocks (for PSR)
# Note: after enabling QLoRA, `model` becomes a PeftModel whose
# underlying transformer is in `model.base_model.model`.
backend_model = get_backend_model(model)
for layer in backend_model.layers:
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
layer.block_sparse_moe.current_step = step
layer.block_sparse_moe.total_steps = total_steps
# Pass met_warmup to PSR scheduler so it completes within warmup phase
layer.block_sparse_moe.met_warmup = met_warmup if met_enabled else None
# === IG-MET Global Threshold Update ===
# Key insight: Count UNIQUE experts (512 base), not router-level copies (up to 1024 with reuse).
# If an expert is pruned, it's pruned everywhere it can be accessed.
if met_enabled:
# Calculate target mask ratio for this step with improved schedule
final_mask_ratio = met_mask_ratio
progress = min(step / total_steps, 1.0)
# Improved Three-Phase Schedule (Aggressive):
# Phase 1: 0-met_warmup - NO pruning (extended warmup for stability)
# Phase 2: met_warmup-0.8 - Gradual pruning ramp with curve (avoid abrupt changes)
# Phase 3: 0.8-100% - Freeze pruning, only fine-tune remaining experts
phase2_end = 0.8
if progress < met_warmup:
# Phase 1: Extended warmup, no pruning
current_target_ratio = 0.0
elif progress < phase2_end:
# Phase 2: Controlled pruning ramp with exponential curve
pruning_window = phase2_end - met_warmup
pruning_progress = (progress - met_warmup) / pruning_window # 0 to 1
# Use power curve to avoid aggressive early pruning
# Power = 1.2 for smoother ramp since pruning window is longer
current_target_ratio = final_mask_ratio * (pruning_progress ** 1.2)
if not print_met_active:
logger.info(f"[IG-MET] Entered pruning phase with gradual ramp (step={step}, target_ratio={current_target_ratio:.3f})")
print_met_active = True
else:
# Phase 3: Freeze pruning decisions, only fine-tune remaining experts
current_target_ratio = final_mask_ratio
# Optionally freeze pruning masks here in the future
if not hasattr(model, '_pruning_frozen'):
model._pruning_frozen = True
if not print_met_freeze:
logger.info("[IG-MET] Entered fine-tuning phase (pruning decisions frozen)")
print_met_freeze = True
if current_target_ratio > 0:
if step % 100 == 0:
logger.info(f"[IG-MET] Masked Expert Training is now ACTIVE (step={step}, target_ratio={current_target_ratio:.3f})")
# Collect EMA for UNIQUE (layer_idx, expert_idx) pairs only (not duplicates)
# Use "SUM" aggregation: we care about the total utility of an expert across all contexts.
unique_experts = {} # (orig_layer, orig_expert) -> summed_ema_score
# First Pass: Aggregation
for layer_idx, layer in enumerate(backend_model.layers):
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
router = layer.block_sparse_moe.router
# Reuse same logic as Router to map pool_pos -> (orig_layer, orig_expert)
current_r = router.get_candidate_layers(step, total_steps)
half = (current_r - 1) // 2
start_layer = max(0, layer_idx - half)
end_layer = min(total_layers, start_layer + current_r)
start_layer = max(0, end_layer - current_r)
# Reconstruct mapping for this router step
current_mapping = []
for layer_offset in range(current_r):
l_id = start_layer + layer_offset
if l_id >= total_layers: break
for e_id in range(router.num_experts_per_layer):
current_mapping.append((l_id, e_id))
num_active = len(current_mapping)
# Aggregate EMA
for pool_pos in range(num_active):
if pool_pos >= len(router.ema_utilization): break
key = current_mapping[pool_pos] # (orig_layer, orig_expert)
ema_val = router.ema_utilization[pool_pos].item()
if key not in unique_experts:
unique_experts[key] = ema_val
else:
# Sum EMA across all reused contexts
unique_experts[key] += ema_val
# Compute threshold based on UNIQUE SUMMED experts
if unique_experts:
all_ema_values = list(unique_experts.values())
all_ema_tensor = torch.tensor(all_ema_values, device=device)
k = int(len(all_ema_values) * current_target_ratio)
# Determine set of GLOBALLY pruned experts
pruned_keys = set()
threshold = 0.0
if k > 0 and all_ema_tensor.sum() > 0:
sorted_vals, _ = torch.sort(all_ema_tensor)
threshold = sorted_vals[k].item()
# Identify which UNIQUE experts are below the global sum threshold
pruned_keys = {key for key, val in unique_experts.items() if val < threshold}
# Second Pass: Distribute Pruning Mask to Routers
# Instead of a scalar threshold (which fails for summed aggregation),
# we push a binary mask of "keep vs prune" to each router.
for layer_idx, layer in enumerate(backend_model.layers):
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
router = layer.block_sparse_moe.router
# Re-calculate mapping to generate mask
current_r = router.get_candidate_layers(step, total_steps)
half = (current_r - 1) // 2
start_layer = max(0, layer_idx - half)
end_layer = min(total_layers, start_layer + current_r)
start_layer = max(0, end_layer - current_r)
current_mapping = []
for layer_offset in range(current_r):
l_id = start_layer + layer_offset
if l_id >= total_layers: break
for e_id in range(router.num_experts_per_layer):
current_mapping.append((l_id, e_id))
# Create binary mask: True = KEEP, False = PRUNE
# Size = max_pool_size (pad with True to be safe)
keep_mask = torch.ones(router.max_pool_size, dtype=torch.bool, device=device)
for pool_pos, key in enumerate(current_mapping):
if key in pruned_keys:
keep_mask[pool_pos] = False
# Push mask to router
router.global_keep_mask = keep_mask
# We also update mask_threshold for logging purposes
router.mask_threshold.fill_(threshold)
# Log statistics
if step % 10 == 0:
total_unique_pruned = len(pruned_keys)
total_unique_active = len(unique_experts)
logger.info(f"[IG-MET Global] Step {step}: Threshold={threshold:.6f}. Pruned {total_unique_pruned}/{total_unique_active} UNIQUE experts ({100*total_unique_pruned/total_unique_active:.1f}%). Target ratio: {current_target_ratio:.3f}")
else:
# Current ratio too small to mask any expert yet
for layer in backend_model.layers:
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
layer.block_sparse_moe.router.mask_threshold.fill_(-1.0)
if hasattr(layer.block_sparse_moe.router, 'global_keep_mask'):
layer.block_sparse_moe.router.global_keep_mask = None
else:
# No experts found (shouldn't happen)
for layer in backend_model.layers:
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
layer.block_sparse_moe.router.mask_threshold.fill_(-1.0)
else:
# Reset threshold (no masking) during 0-50% warmup phase
for layer in backend_model.layers:
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
layer.block_sparse_moe.router.mask_threshold.fill_(-1.0)
input_ids = batch["input_ids"].to(model.device)
attention_mask = batch["attention_mask"].to(model.device)
labels = batch["labels"].to(model.device)
# Forward pass with PSR-aware routing
# todo: Why not use
'''
outputs = model(
instructions=instructions,
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
'''
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
use_cache=False
)
loss = outputs.loss
# Collect auxiliary losses from all ReXMoE routers
# This is CRITICAL for load balancing and preventing routing collapse
aux_loss_total = 0.0
for layer in backend_model.layers:
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
if layer.block_sparse_moe.last_aux_loss is not None:
aux_loss_total += layer.block_sparse_moe.last_aux_loss
# Collect routing statistics (which experts were selected)
for layer_idx, layer in enumerate(backend_model.layers):
if hasattr(layer, "block_sparse_moe") and isinstance(layer.block_sparse_moe, ReXMoESparseMoeBlock):
moe_block = layer.block_sparse_moe
# Get the layer-expert mapping from the last forward pass
router = moe_block.router
# Don't call router with hidden_states=None (router.forward expects a tensor).
# Prefer the last computed mapping from a real forward pass; if not
# available, query the router with a small dummy tensor to get the mapping.
layer_expert_mapping = getattr(router, 'last_layer_expert_mapping', None)
if layer_expert_mapping is None:
try:
hidden_dim = router.gate.in_features if hasattr(router, 'gate') else moe_block.hidden_dim
dummy = torch.zeros(1, 1, hidden_dim, device=model.device)
_, _, _, layer_expert_mapping = router(
hidden_states=dummy,
step=step,
total_steps=total_steps
)
except Exception:
layer_expert_mapping = []
# Note: For efficiency, we'll track routing patterns every N batches
# to avoid slowdown. Full tracking can be enabled for analysis.
pass # Detailed tracking will be done in a separate analysis pass
# Compute routing entropy for convergence monitoring (sample from output)
# We'll approximate entropy from the auxiliary loss and routing distribution
# Note: Full entropy computation would require storing all router outputs
# Total loss = Language modeling loss + Auxiliary load balancing loss
total_loss = loss + aux_loss_total
# Backward pass
optimizer.zero_grad()
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# Logging
epoch_loss += loss.item()
epoch_aux_loss += aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total
# Calculate current reuse scale for logging (match the PSR logic in router)
if psr_enabled:
if met_enabled:
# New behavior: PSR completes within first phase (0 to met_warmup)
progress = min(step / (met_warmup * total_steps), 1.0)
current_r = 2 + int(progress * (reuse_scale - 2))
else:
# Legacy behavior: Linear schedule over first 50% of training
progress = min(step / (0.5 * total_steps), 1.0)
current_r = 2 + int(progress * (reuse_scale - 2))
# print(f"current_r: {current_r}, progress: {progress}")
else:
current_r = reuse_scale
# Log first batch details
if not first_batch_logged:
logger.info(f"\n First batch statistics:")
logger.info(f" LM Loss: {loss.item():.4f}")
logger.info(f" Aux Loss: {aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f}")
logger.info(f" Total Loss: {total_loss.item():.4f}")
logger.info(f" Current R: {current_r}")
logger.info(f" Active experts per layer: {current_r * 16}")
logger.info(f" Gradient norm: {torch.nn.utils.clip_grad_norm_(model.parameters(), float('inf')):.4f}")
first_batch_logged = True
logger.info(f" \n")
# Print periodic updates (every log_loss_steps_percent% of epoch)
if batch_idx > 0 and batch_idx % max(1, len(train_loader) // (100 // log_loss_steps_percent)) == 0:
logger.info(f" [{batch_idx}/{len(train_loader)}] loss={loss.item():.4f} aux={aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f} R={current_r}")
pbar.set_postfix({
"loss": f"{loss.item():.4f}",
"aux": f"{aux_loss_total.item() if isinstance(aux_loss_total, torch.Tensor) else aux_loss_total:.6f}",
"total": f"{total_loss.item():.4f}",
"R": current_r,
"step": f"{step}/{total_steps}"
})
step += 1
# Evaluate every eval_steps (after batching/optimization)
if step % eval_steps == 0:
model.eval()
logger.info(f"\n[Step {step}/{total_steps}] Running evaluation at eval_steps...")
evaluate_prompt(model, tokenizer, logger=logger)
model.train()
# Routing analysis at eval_steps
logger.info(f"\n[Step {step}] Analyzing routing patterns at eval_steps...")
analyze_routing_patterns(model, train_loader, current_r, total_layers, device, logger=logger)
# Save a checkpoint at this step
logger.info(f"\n[Step {step}] Saving checkpoint at eval_steps to {save_path}...")
os.makedirs(save_path, exist_ok=True)
# Option 1: Save only router weights (recommended - more portable)
print("\nSaving trained router weights only...")
router_state_dict = {}
for name, param in model.named_parameters():
if ".block_sparse_moe.gate" in name and param.requires_grad:
router_state_dict[name] = param.data.cpu()
# IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation
for name, buf in model.named_buffers():
if "ema_utilization" in name or "mask_threshold" in name:
logger.info(f"Saving buffer {name} with shape {buf.shape} for pruning evaluation")
router_state_dict[name] = buf.data.cpu()
else:
logger.info(f"Skipping buffer {name} (not related to routing)")
torch.save({
'router_state_dict': router_state_dict,
'config': {
'reuse_scale': reuse_scale,
'num_epochs': num_epochs,
'lr': lr,
'model_name': model_name
}
}, os.path.join(save_path, 'rexmoe_routers.pt'))
tokenizer.save_pretrained(save_path)
logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters")
logger.info(f" File: {save_path}/rexmoe_routers.pt")
logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB")
# Option 2: Also save full model (includes architecture, but less portable)
logger.info("\nAlso saving full model with ReXMoE architecture...")
model.save_pretrained(save_path)
# Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA enabled)
# This produces a standard HF model directory that can be loaded without PEFT.
if qlora_enabled:
try:
merged_path = os.path.join(save_path, "merged")
os.makedirs(merged_path, exist_ok=True)
logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}")
# merge_and_unload() returns the base model with LoRA weights folded in.
merged_model = model.merge_and_unload()
merged_model.eval()
tokenizer.save_pretrained(merged_path)
merged_model.save_pretrained(merged_path)
logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading")
except Exception as e:
logger.warning(f"Could not merge and save LoRA weights (continuing): {e}")
avg_epoch_loss = epoch_loss / len(train_loader)
avg_epoch_aux_loss = epoch_aux_loss / len(train_loader)
# Store metrics for convergence tracking
epoch_aux_losses.append(avg_epoch_aux_loss)
logger.info(f"\n{'='*60}")
logger.info(f"Epoch {epoch+1} Summary:")
logger.info(f" Average LM Loss: {avg_epoch_loss:.4f}")
logger.info(f" Average Aux Loss: {avg_epoch_aux_loss:.6f}")
logger.info(f" Average Total Loss: {avg_epoch_loss + avg_epoch_aux_loss:.4f}")
logger.info(f" Final R: {current_r}")
# Evaluate at epoch end
model.eval()
evaluate_prompt(model, tokenizer, logger=logger)
model.train()
# Track and save best checkpoint based on average LM loss
if avg_epoch_loss < best_val:
best_val = avg_epoch_loss
best_epoch = epoch + 1
logger.info(f"New best epoch {best_epoch} with avg LM loss {best_val:.4f} — saving checkpoint to {save_path}")
os.makedirs(save_path, exist_ok=True)
# Option 1: Save only router weights (recommended - more portable)
print("\nSaving trained router weights only...")
router_state_dict = {}
for name, param in model.named_parameters():
if ".block_sparse_moe.gate" in name and param.requires_grad:
router_state_dict[name] = param.data.cpu()
# IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation
for name, buf in model.named_buffers():
if "ema_utilization" in name or "mask_threshold" in name:
logger.info(f"Saving buffer {name} with shape {buf.shape} for pruning evaluation")
router_state_dict[name] = buf.data.cpu()
else:
logger.info(f"Skipping buffer {name} (not related to routing)")
torch.save({
'router_state_dict': router_state_dict,
'config': {
'reuse_scale': reuse_scale,
'num_epochs': num_epochs,
'lr': lr,
'model_name': model_name
}
}, os.path.join(save_path, 'rexmoe_routers.pt'))
tokenizer.save_pretrained(save_path)
logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters")
logger.info(f" File: {save_path}/rexmoe_routers.pt")
logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB")
# Option 2: Also save full model (includes architecture, but less portable)
logger.info("\nAlso saving full model with ReXMoE architecture...")
model.save_pretrained(save_path)
# Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA enabled)
# This produces a standard HF model directory that can be loaded without PEFT.
if qlora_enabled:
try:
merged_path = os.path.join(save_path, "merged")
os.makedirs(merged_path, exist_ok=True)
logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}")
# merge_and_unload() returns the base model with LoRA weights folded in.
merged_model = model.merge_and_unload()
merged_model.eval()
tokenizer.save_pretrained(merged_path)
merged_model.save_pretrained(merged_path)
logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading")
except Exception as e:
logger.warning(f"Could not merge and save LoRA weights (continuing): {e}")
# Check convergence
converged, conv_metrics, conv_warnings = check_router_convergence(
model, total_layers, convergence_history, logger=logger
)
# Get current learning rate
current_lr = optimizer.param_groups[0]['lr']
logger.info(f"\n 📊 Convergence Metrics:")
logger.info("Convergence Metrics:")
logger.info(f" Avg Router Grad Norm: {conv_metrics['avg_router_grad_norm']:.6f}")
if 'grad_stability' in conv_metrics:
print(f" Grad Stability: {conv_metrics['grad_stability']:.6f}")
logger.info(f" Grad Stability: {conv_metrics['grad_stability']:.6f}")
print(f" Current Learning Rate: {current_lr:.2e}")
logger.info(f" Current Learning Rate: {current_lr:.2e}")
if len(epoch_aux_losses) >= 2:
aux_change = abs(epoch_aux_losses[-1] - epoch_aux_losses[-2])
print(f" Aux Loss Change: {aux_change:.6f}")
logger.info(f" Aux Loss Change: {aux_change:.6f}")
# Print warnings
if conv_warnings:
for warning in conv_warnings:
print(f"\n {warning}")
# Already logged in check_router_convergence
if converged :
msg = "✅ CONVERGED: Router weights have stabilized! Gradient norm < 0.1 and stable for 5 epochs"
print(f"\n {msg}")
logger.info(msg)
if not getattr(mode, '_routers_soft_fronze', False):
logger.info("Soft freezing routers for remaining epochs to focus on fine-tuning experts")
for param_group in optimizer.param_groups:
for name, param in model.named_parameters() and param in param_group['params']:
if ".block_sparse_moe.gate" in name:
param_group['weight_decay'] = 0.5
param_group['lr'] = current_lr * 0.1
setattr(mode, '_routers_soft_fronze', True)
model._routers_soft_fronze = True
elif len(convergence_history) >= 5:
msg = "⏳ Not yet converged - continuing training..."
print(f"\n {msg}")
logger.info(msg)
else:
msg = "ℹ️ Collecting convergence data (need 5 epochs minimum)..."
print(f"\n {msg}")
logger.info(msg)
print(f"{'='*60}")
# Analyze routing patterns at epoch end
print(f"\n📊 Routing Pattern Analysis (Epoch {epoch+1}):")
logger.info(f"Routing Pattern Analysis (Epoch {epoch+1}):")
print("-" * 60)
analyze_routing_patterns(model, train_loader, current_r, total_layers, device, logger=logger)
print("-" * 60)
# Step the learning rate scheduler
if scheduler is not None:
scheduler.step()
logger.info(f"LR stepped to: {optimizer.param_groups[0]['lr']:.2e}")
# Final convergence report
print(f"\n{'='*80}")
print("📈 Training Convergence Summary")
print(f"{'='*80}")
logger.info("="*80)
logger.info("Training Convergence Summary")
logger.info("="*80)
if len(convergence_history) > 0:
print(f"\nRouter Gradient Norms Over Epochs:")
logger.info("Router Gradient Norms Over Epochs:")
for i, grad_norm in enumerate(convergence_history):
trend = ""
if i > 0:
change = grad_norm - convergence_history[i-1]
trend = f" (Δ {change:+.6f})"
msg = f" Epoch {i+1}: {grad_norm:.6f}{trend}"
print(msg)
logger.info(msg)
if len(epoch_aux_losses) > 0:
print(f"\nAuxiliary Loss Over Epochs:")
logger.info("Auxiliary Loss Over Epochs:")
for i, aux_loss in enumerate(epoch_aux_losses):
trend = ""
if i > 0:
change = aux_loss - epoch_aux_losses[i-1]
trend = f" (Δ {change:+.6f})"
msg = f" Epoch {i+1}: {aux_loss:.6f}{trend}"
print(msg)
logger.info(msg)
# Final convergence assessment
if len(convergence_history) >= 5:
final_converged, final_metrics, final_warnings = check_router_convergence(
model, total_layers, convergence_history, logger=logger
)
print(f"\n{'='*80}")
print(f"Final Convergence Status:")
logger.info("="*80)
logger.info("Final Convergence Status:")
if final_converged:
msg = "✅ CONVERGED - Routers have reached stable configuration"
print(f" {msg}")
logger.info(msg)
print(f" - Gradient norm: {final_metrics['avg_router_grad_norm']:.6f} (< 0.1)")
print(f" - Stability: {final_metrics['grad_stability']:.6f} (< 0.01)")
logger.info(f" Gradient norm: {final_metrics['avg_router_grad_norm']:.6f}")
logger.info(f" Stability: {final_metrics['grad_stability']:.6f}")
msg = "Safe to deploy or proceed to parameter merging"
print(f" {msg}")
logger.info(msg)
else:
msg = "⚠️ NOT FULLY CONVERGED"
print(f" {msg}")
logger.warning(msg)
print(f" Current metrics:")
print(f" - Gradient norm: {final_metrics['avg_router_grad_norm']:.6f} (target: < 0.1)")
logger.info(f" Gradient norm: {final_metrics['avg_router_grad_norm']:.6f}")
if 'grad_stability' in final_metrics:
print(f" - Stability: {final_metrics['grad_stability']:.6f} (target: < 0.01)")
logger.info(f" Stability: {final_metrics['grad_stability']:.6f}")
print(f" Consider training for more epochs if:")
print(f" - Aux loss still decreasing significantly")
print(f" - Routing patterns still changing")
print(f" - Gradient norms not stabilized")
print(f"{'='*80}\n")
logger.info("="*80)
else:
print(f"\n{'='*80}")
print(f"Convergence Status: Insufficient data (< 5 epochs)")
print(f" Run for at least 5 epochs for convergence analysis")
print(f"{'='*80}\n")
logger.info("Convergence Status: Insufficient data (< 5 epochs)")
# Save model
print(f"\n[7/7] Saving router-adapted checkpoint to: {save_path}")
os.makedirs(save_path, exist_ok=True)
# Option 1: Save only router weights (recommended - more portable)
logger.info("\nSaving trained router weights only...")
router_state_dict = {}
for name, param in model.named_parameters():
if ".block_sparse_moe.gate" in name and param.requires_grad:
router_state_dict[name] = param.data.cpu()
# IMPORTANT: Save EMA buffers and thresholds for permanent pruning evaluation
for name, buf in model.named_buffers():
if "ema_utilization" in name or "mask_threshold" in name:
router_state_dict[name] = buf.data.cpu()
torch.save({
'router_state_dict': router_state_dict,
'config': {
'reuse_scale': reuse_scale,
'num_epochs': num_epochs,
'lr': lr,
'model_name': model_name
}
}, os.path.join(save_path, 'rexmoe_routers.pt'))
tokenizer.save_pretrained(save_path)
logger.info(f"✓ Saved trained router weights: {len(router_state_dict)} parameters")
logger.info(f" File: {save_path}/rexmoe_routers.pt")
logger.info(f" Size: {os.path.getsize(os.path.join(save_path, 'rexmoe_routers.pt')) / 1024 / 1024:.2f} MB")
# Option 2: Also save full model (includes architecture, but less portable)
logger.info("\nAlso saving full model with ReXMoE architecture...")
model.save_pretrained(save_path)
# Option 3: Save a SINGLE-DIR full model with LoRA merged (if QLoRA was enabled)
if qlora_enabled:
try:
merged_path = os.path.join(save_path, "merged")
os.makedirs(merged_path, exist_ok=True)
logger.info(f"\nMerging LoRA adapters into base weights and saving to: {merged_path}")
merged_model = model.merge_and_unload()
merged_model.eval()
tokenizer.save_pretrained(merged_path)
merged_model.save_pretrained(merged_path)
logger.info("✓ Saved merged full model (base+routers+LoRA) for one-step loading")
except Exception as e:
logger.warning(f"Could not merge and save LoRA weights (continuing): {e}")
# Save the custom classes for reloading
import shutil
shutil.copy(__file__, os.path.join(save_path, 'rexmoe_architecture.py'))
# Print final statistics
full_model_size = sum(os.path.getsize(os.path.join(save_path, f))
for f in os.listdir(save_path)
if f.endswith('.bin') or f.endswith('.safetensors')) / 1024 / 1024 / 1024
logger.info("="*80)
logger.info("✓ Training complete. Two checkpoint formats saved:")
logger.info(" 1. Router weights only: rexmoe_routers.pt (portable)")
logger.info(" 2. Full model: pytorch_model.bin (requires rexmoe_architecture.py)")
logger.info(f"\nCheckpoint directory: {save_path}")
logger.info(f"Full model size: {full_model_size:.2f} GB")
logger.info("="*80)
return model
# ==================== 5. USAGE ====================
def build_arg_parser():
    """Build the command-line parser for a ReXMoE training run.

    Returns:
        argparse.ArgumentParser: parser whose parsed attributes are forwarded
        to ``train_rexmoe`` by the ``__main__`` block below (``--epoch`` is
        passed as ``num_epochs`` and ``--mode`` as ``dataset_mode``).
    """
    parser = argparse.ArgumentParser(description="ReXMoE Training")
    parser.add_argument("--model_name", type=str, default="microsoft/Phi-mini-MoE-instruct", help="Pretrained model name")
    # Local alternative used by the author: ../models/models/microsoft/Phi-mini-MoE-instruct
    parser.add_argument("--model_path", type=str, default="microsoft/Phi-mini-MoE-instruct", help="Path to pretrained model")
    parser.add_argument("--dataset_path", type=str, default="../dataset/alpaca_data_cleaned.json", help="Path to dataset JSON")
    parser.add_argument("--mode", type=str, choices=["IF", "P", "IF_2"], default="IF", help="Dataset mode: IF = instruction-finetune (Alpaca), P = pretraining (C4)")
    parser.add_argument("--reuse_scale", type=int, default=3, help="Reuse scale R for cross-layer routing")
    parser.add_argument("--epoch", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--num_samples", type=int, default=10000, help="Number of training samples to use from the dataset")
    parser.add_argument("--batch_size", type=int, default=32, help="Training batch size")
    parser.add_argument("--max_seq_length", type=int, default=512, help="Maximum sequence length")
    parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
    parser.add_argument("--warmup_steps", type=int, default=200, help="Number of steps to warm up with R=1 (routers only)")
    parser.add_argument("--psr_enabled", action='store_true', help="Enable Progressive Scaling Routing (PSR)")
    # The scheduler is on by default.  ``--use_scheduler`` is kept so existing
    # launch commands still parse; ``--no_scheduler`` is what actually lets a
    # run disable it (a store_true flag defaulting to True could never be off).
    parser.add_argument("--use_scheduler", dest="use_scheduler", action='store_true', default=True, help="Use learning rate scheduler (enabled by default)")
    parser.add_argument("--no_scheduler", dest="use_scheduler", action='store_false', help="Disable the learning rate scheduler")
    parser.add_argument("--gradient_checkpointing", action='store_true', help="Enable gradient checkpointing for memory efficiency")
    parser.add_argument("--met_enabled", action='store_true', help="Enable Masked Expert Training (MET)")
    # argparse %-formats help strings, so a literal percent sign must be "%%";
    # a bare "10% of" makes --help raise instead of printing.
    parser.add_argument("--met_mask_ratio", type=float, default=0.1, help="MET mask ratio (0.1 = mask 10%% of experts)")
    parser.add_argument("--met_warmup", type=float, default=0.5, help="Proportion of steps to warm up MET (no masking)")
    parser.add_argument("--eval_steps", type=int, default=500, help="Evaluate every N steps during training")
    parser.add_argument("--log_loss_steps_percent", type=int, default=10, help="Log loss every N%% of total steps")
    parser.add_argument("--full_lora", action='store_true', help="Enable full LoRA training")
    parser.add_argument("--lora_r", type=int, default=16, help="LoRA rank")
    parser.add_argument("--save_path", type=str, default="./rexmoe_natural_phi_mini_moe", help="Base path to save trained model (timestamp will be prefixed)")
    parser.add_argument("--aux_loss_weight", type=float, default=0.02, help="Auxiliary loss weight")
    return parser


def build_timed_save_path(save_path: str, met_mask_ratio: float, reuse_scale: int,
                          timestamp: Optional[str] = None) -> str:
    """Return a per-run checkpoint directory derived from ``save_path``.

    The last path component is rewritten to
    ``<timestamp>_<mask percent>_<name>_R<reuse_scale>`` while the parent
    directory is left untouched.

    Args:
        save_path: Base output path given on the command line.
        met_mask_ratio: MET mask ratio as a fraction; rendered as a whole percent.
        reuse_scale: Reuse scale R, appended as an ``_R<value>`` suffix.
        timestamp: Pre-formatted stamp to embed. When ``None`` the current
            local time is formatted as ``DDMM_HHMMSS``.

    Returns:
        The derived path as a string. Nothing is created on disk.
    """
    if timestamp is None:
        from datetime import datetime as _dt
        timestamp = _dt.now().strftime("%d%m_%H%M%S")
    # Drop trailing separators so "out/run/" is treated like "out/run";
    # otherwise the last component is empty and the prefix ends up *inside*
    # the target directory with no name attached.
    separators = os.sep + (os.altsep or "")
    trimmed = save_path.rstrip(separators) or save_path
    parent, leaf = os.path.split(trimmed)
    # round() rather than int(): 0.29 * 100 is 28.999999999999996 in binary
    # floating point and would otherwise be labelled as 28.
    mask_percent = round(met_mask_ratio * 100)
    return os.path.join(parent, f"{timestamp}_{mask_percent}_{leaf}_R{reuse_scale}")


if __name__ == "__main__":
    parser = build_arg_parser()
    args = parser.parse_args()
    # Reject values that would only fail later inside the training loop
    # (both are used as divisors there), before any model is loaded.
    if args.eval_steps < 1:
        parser.error("--eval_steps must be >= 1")
    if not 1 <= args.log_loss_steps_percent <= 100:
        parser.error("--log_loss_steps_percent must be between 1 and 100")
    # Prefix save path with timestamp (DDMM_HHMMSS) to distinguish runs
    timed_save_path = build_timed_save_path(args.save_path, args.met_mask_ratio, args.reuse_scale)
    model = train_rexmoe(
        model_name=args.model_name,
        model_path=args.model_path,
        dataset_path=args.dataset_path,
        dataset_mode=args.mode,
        reuse_scale=args.reuse_scale,
        num_samples=args.num_samples,
        num_epochs=args.epoch,
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        lr=args.lr,
        warmup_steps=args.warmup_steps,
        psr_enabled=args.psr_enabled,
        use_scheduler=args.use_scheduler,
        gradient_checkpointing=args.gradient_checkpointing,
        met_enabled=args.met_enabled,
        met_mask_ratio=args.met_mask_ratio,
        met_warmup=args.met_warmup,
        eval_steps=args.eval_steps,
        log_loss_steps_percent=args.log_loss_steps_percent,
        full_lora=args.full_lora,
        lora_r=args.lora_r,
        save_path=timed_save_path,
        aux_loss_weight=args.aux_loss_weight
    )
    print(f"\n✓ All done! Model saved to {timed_save_path}")