| | """ |
| | Adafactor Optimizer for BitTransformerLM Extensions |
| | =================================================== |
| | |
| | Implementation of the Adafactor optimizer with memory-efficient factorization. |
| | Based on "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost" research. |
| | |
| | Key features: |
| | - Factorized second moment estimates for memory efficiency |
| | - Automatic scaling of learning rates |
| | - Relative step size and clip threshold |
| | - Compatible with BitTransformerLM's training infrastructure |
| | """ |
| |
|
| | import math |
| | import torch |
| | from torch.optim.optimizer import Optimizer |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| |
|
class Adafactor(Optimizer):
    """
    Adafactor optimizer implementation.

    Adafactor reduces memory usage by factorizing the second moment estimates
    for parameters with 2 or more dimensions, making it highly memory efficient
    for large transformer models.

    Args:
        params: Iterable of parameters to optimize
        lr: External learning rate (default: None, uses automatic scaling)
        eps2: Regularization constant for second moment (default: 1e-30)
        cliping_threshold: Threshold for adaptive clipping (default: 1.0).
            NOTE: the misspelling ("cliping") is kept for backward
            compatibility with existing callers.
        decay_rate: Coefficient used for computing running averages (default: -0.8)
        beta1: Coefficient for the first-moment running average
            (default: None, which disables the first-moment accumulator)
        weight_decay: Weight decay coefficient (default: 0.0)
        scale_parameter: If True, learning rate is scaled by root mean square
            of the parameter (default: True)
        relative_step_size: If True, use relative step size (default: True).
            NOTE(review): this flag is stored but the step computation only
            falls back to the relative step when ``lr`` is None — confirm
            against callers before changing.
        warmup_init: If True, warmup learning rate (default: False)
    """

    def __init__(
        self,
        params,
        lr: Optional[float] = None,
        eps2: float = 1e-30,
        cliping_threshold: float = 1.0,
        decay_rate: float = -0.8,
        beta1: Optional[float] = None,
        weight_decay: float = 0.0,
        scale_parameter: bool = True,
        relative_step_size: bool = True,
        warmup_init: bool = False,
    ):
        if lr is not None and lr <= 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            eps2=eps2,
            cliping_threshold=cliping_threshold,
            decay_rate=decay_rate,
            beta1=beta1,
            weight_decay=weight_decay,
            scale_parameter=scale_parameter,
            relative_step_size=relative_step_size,
            warmup_init=warmup_init,
        )
        super().__init__(params, defaults)

    def _get_lr(self, param_group, param_state):
        """Compute the (relative) learning rate for a parameter group."""
        # With warmup_init the cap grows linearly with the step count;
        # otherwise it is a fixed 1e-2.
        min_step = 1e-6 * param_state["step"] if param_group["warmup_init"] else 1e-2
        rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state["step"]))
        param_scale = 1.0
        if param_group["scale_parameter"]:
            # Scale by parameter RMS, floored at eps2 to avoid a zero lr.
            param_scale = max(param_group["eps2"], param_state["RMS"])
        return param_scale * rel_step_sz

    def _get_options(self, param_group, param_shape):
        """Return (factored, use_first_moment) flags for a parameter."""
        factored = len(param_shape) >= 2
        use_first_moment = param_group["beta1"] is not None
        return factored, use_first_moment

    def _rms(self, tensor):
        """Root mean square of a tensor."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
        """Rank-1 reconstruction of rsqrt of the factored second moment.

        Supports tensors of any rank >= 2: row statistics cover all but the
        last dimension, column statistics all but the second-to-last.
        """
        r_factor = (
            (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True))
            .rsqrt_()
            .unsqueeze(-1)
        )
        # BUGFIX: unsqueeze(-2) (not unsqueeze(0)) so tensors with more than
        # 2 dimensions broadcast correctly against r_factor.
        c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
        return torch.mul(r_factor, c_factor)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure`` if one was given, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                # Accumulate statistics in fp32 for numerical stability.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]
                grad_shape = grad.shape

                factored, use_first_moment = self._get_options(group, grad_shape)

                if len(state) == 0:
                    state["step"] = 0

                    if use_first_moment:
                        state["exp_avg"] = torch.zeros_like(grad).float()
                    if factored:
                        # BUGFIX: allocate factored statistics on the same
                        # device as the gradient (and in fp32); bare
                        # torch.zeros() would land on the CPU and crash for
                        # GPU parameters.
                        state["exp_avg_sq_row"] = torch.zeros(
                            grad_shape[:-1],
                            dtype=torch.float32, device=grad.device)
                        state["exp_avg_sq_col"] = torch.zeros(
                            grad_shape[:-2] + grad_shape[-1:],
                            dtype=torch.float32, device=grad.device)
                    else:
                        state["exp_avg_sq"] = torch.zeros_like(grad).float()

                    state["RMS"] = 0

                p_data_fp32 = p.data
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                state["step"] += 1
                state["RMS"] = self._rms(p_data_fp32)

                lr = group["lr"]
                if group["lr"] is None:
                    # No external lr: fall back to the relative step size.
                    lr = self._get_lr(group, state)

                # decay_rate < 0, so beta2t -> 1 as the step count grows.
                beta2t = 1.0 - math.pow(state["step"], group["decay_rate"])
                update = grad**2 + group["eps2"]

                if factored:
                    exp_avg_sq_row = state["exp_avg_sq_row"]
                    exp_avg_sq_col = state["exp_avg_sq_col"]

                    exp_avg_sq_row.mul_(beta2t).add_(
                        update.mean(dim=-1), alpha=1.0 - beta2t)
                    exp_avg_sq_col.mul_(beta2t).add_(
                        update.mean(dim=-2), alpha=1.0 - beta2t)

                    update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
                    update.mul_(grad)
                else:
                    exp_avg_sq = state["exp_avg_sq"]
                    exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
                    update = exp_avg_sq.rsqrt().mul_(grad)

                # Adaptive clipping: cap the update RMS at cliping_threshold.
                update.div_(max(1.0, self._rms(update) / group["cliping_threshold"]))

                if use_first_moment:
                    exp_avg = state["exp_avg"]
                    exp_avg.mul_(group["beta1"]).add_(update, alpha=1 - group["beta1"])
                    update = exp_avg

                if group["weight_decay"] != 0:
                    # Decoupled (AdamW-style) weight decay, scaled by lr.
                    p_data_fp32.mul_(1 - group["weight_decay"] * lr)

                p_data_fp32.add_(update, alpha=-lr)

                # Write the fp32 result back into low-precision parameters.
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

        return loss
| |
|
| |
|
def configure_adafactor_optimizer(
    model: torch.nn.Module,
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    cliping_threshold: float = 1.0,
    decay_rate: float = -0.8,
    beta1: Optional[float] = None,
    eps2: float = 1e-30,
    **adafactor_kwargs
) -> Tuple[Adafactor, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Build an Adafactor optimizer plus, when possible, a OneCycleLR scheduler.

    Drop-in replacement for BitTransformerLM's ``configure_optimizer``:
    same call shape, but the returned optimizer is Adafactor rather
    than AdamW.

    Args:
        model: PyTorch model whose trainable parameters are optimized
        lr: External learning rate; None enables Adafactor's automatic scaling
        weight_decay: Weight decay coefficient
        total_steps: Total training steps, required for scheduling
        warmup_ratio: Fraction of total steps spent warming up
        scale_parameter: Scale learning rate by parameter RMS
        relative_step_size: Use Adafactor's relative step size
        warmup_init: Use warmup-style learning-rate initialization
        cliping_threshold: Adaptive update-clipping threshold
        decay_rate: Decay rate for the second-moment estimates
        beta1: First-moment coefficient (None disables the first moment)
        eps2: Second-moment regularization constant
        **adafactor_kwargs: Extra keyword arguments forwarded to Adafactor

    Returns:
        Tuple of (optimizer, scheduler); scheduler is None when either
        ``lr`` or a positive ``total_steps`` is missing.
    """
    trainable_params = [p for p in model.parameters() if p.requires_grad]

    optimizer = Adafactor(
        trainable_params,
        lr=lr,
        eps2=eps2,
        cliping_threshold=cliping_threshold,
        decay_rate=decay_rate,
        beta1=beta1,
        weight_decay=weight_decay,
        scale_parameter=scale_parameter,
        relative_step_size=relative_step_size,
        warmup_init=warmup_init,
        **adafactor_kwargs
    )

    # OneCycleLR needs both an explicit peak lr and a step budget; with
    # automatic (relative) step sizes there is nothing external to schedule.
    can_schedule = lr is not None and total_steps is not None and total_steps > 0
    scheduler = (
        torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=warmup_ratio,
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )
        if can_schedule
        else None
    )

    return optimizer, scheduler
| |
|
| |
|
class AdafactorScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Custom scheduler with linear warmup and polynomial decay.

    Designed for Adafactor configured with a fixed ``lr`` (i.e.
    ``relative_step_size=False``), but works with any torch optimizer —
    the ``optimizer`` annotation is deliberately the generic base class.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int = 1000,
        total_steps: Optional[int] = None,
        min_lr_ratio: float = 0.1,
        polynomial_power: float = 1.0,
        last_epoch: int = -1,
    ):
        """
        Args:
            optimizer: Wrapped optimizer.
            warmup_steps: Steps of linear warmup from 0 to the base lr.
            total_steps: Total training steps; None disables decay.
            min_lr_ratio: Floor for the decay factor (fraction of base lr).
            polynomial_power: Exponent of the polynomial decay.
            last_epoch: Index of the last step (for resuming).
        """
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr_ratio = min_lr_ratio
        self.polynomial_power = polynomial_power
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1

        if step < self.warmup_steps:
            # Linear warmup toward the base learning rate.
            return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]

        # BUGFIX: guard total_steps <= warmup_steps, which previously caused
        # a ZeroDivisionError (or a negative "progress") in the decay below.
        if self.total_steps is None or self.total_steps <= self.warmup_steps:
            # Nothing to decay toward: hold the base learning rate.
            return self.base_lrs

        # Polynomial decay from base lr down to min_lr_ratio * base lr.
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        progress = min(progress, 1.0)
        decay_factor = (1 - progress) ** self.polynomial_power
        decay_factor = max(decay_factor, self.min_lr_ratio)

        return [base_lr * decay_factor for base_lr in self.base_lrs]
| |
|
| |
|
def configure_adafactor_with_scheduler(
    model: torch.nn.Module,
    lr: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: Optional[int] = None,
    weight_decay: float = 0.0,
    **kwargs
) -> Tuple[Adafactor, AdafactorScheduler]:
    """
    Pair an Adafactor optimizer with the custom AdafactorScheduler.

    Args:
        model: PyTorch model whose trainable parameters are optimized
        lr: Base learning rate
        warmup_steps: Number of linear-warmup steps
        total_steps: Total training steps (None disables decay)
        weight_decay: Weight decay coefficient
        **kwargs: Extra keyword arguments forwarded to Adafactor

    Returns:
        Tuple of (optimizer, scheduler)
    """
    trainable = [param for param in model.parameters() if param.requires_grad]

    # The external scheduler owns the lr schedule, so Adafactor's built-in
    # relative step sizing is disabled here.
    optimizer = Adafactor(
        trainable,
        lr=lr,
        weight_decay=weight_decay,
        relative_step_size=False,
        **kwargs
    )
    scheduler = AdafactorScheduler(
        optimizer, warmup_steps=warmup_steps, total_steps=total_steps
    )

    return optimizer, scheduler
| |
|
| |
|
def create_adafactor_training_config(
    lr: Optional[float] = None,
    weight_decay: float = 0.0,
    scale_parameter: bool = True,
    relative_step_size: bool = True,
    warmup_init: bool = False,
    **kwargs
) -> Dict[str, Any]:
    """
    Build a training-configuration dictionary for the Adafactor optimizer.

    Args:
        lr: External learning rate (None selects automatic scaling)
        weight_decay: Weight decay coefficient
        scale_parameter: Scale learning rate by parameter RMS
        relative_step_size: Use relative step size
        warmup_init: Use warmup-style learning-rate initialization
        **kwargs: Extra entries merged into the optimizer config
            (later keys override the named defaults above)

    Returns:
        Dict with "optimizer_type", "optimizer_config" and "scheduler_type".
    """
    optimizer_config: Dict[str, Any] = dict(
        lr=lr,
        weight_decay=weight_decay,
        scale_parameter=scale_parameter,
        relative_step_size=relative_step_size,
        warmup_init=warmup_init,
    )
    optimizer_config.update(kwargs)

    # With no external lr the custom Adafactor scheduler applies;
    # otherwise a standard one-cycle schedule is appropriate.
    scheduler_type = "onecycle" if lr is not None else "adafactor_custom"

    return {
        "optimizer_type": "adafactor",
        "optimizer_config": optimizer_config,
        "scheduler_type": scheduler_type,
    }
| |
|
| |
|
| | |
def integrate_with_bittransformerlm():
    """
    Example of how to integrate Adafactor optimizer with BitTransformerLM training.

    This function is documentation-only (its body is ``pass``); the snippets
    below are meant to be copied into a training script.

    Usage:
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_optimizer

        # Option 1: Use Adafactor with automatic learning rate scaling
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=None, total_steps=1000  # lr=None enables auto-scaling
        )

        # Option 2: Use Adafactor with fixed learning rate
        optimizer, scheduler = configure_adafactor_optimizer(
            model, lr=1e-3, total_steps=1000
        )

        # Option 3: Use Adafactor with custom scheduler
        from BTLM_Extensions.adafactor_optimizer import configure_adafactor_with_scheduler

        optimizer, scheduler = configure_adafactor_with_scheduler(
            model, lr=1e-3, warmup_steps=100, total_steps=1000
        )

        # Use in training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)
    """
    pass
| |
|
| |
|
def analyze_memory_usage(model: torch.nn.Module) -> Dict[str, float]:
    """
    Estimate optimizer memory usage: AdamW vs. Adafactor.

    All figures assume fp32 (4 bytes per element) and count parameters,
    gradients and optimizer state; activations are not included.

    Args:
        model: PyTorch model to analyze

    Returns:
        Dict with "adamw_mb", "adafactor_mb", "savings_mb" and
        "savings_percent" (0.0 when the model has no trainable params).
    """
    BYTES_PER_FLOAT = 4
    MB = 1024 * 1024

    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    param_bytes = param_count * BYTES_PER_FLOAT

    # AdamW: params + grads + exp_avg + exp_avg_sq -> 4 fp32 copies.
    adamw_memory = param_bytes * 4

    # Adafactor: params + grads, plus second-moment state computed below.
    adafactor_memory = param_bytes * 2

    # Factored second moments mirror Adafactor's state shapes:
    #   row stats:  shape[:-1]               -> numel // shape[-1] elements
    #   col stats:  shape[:-2] + shape[-1:]  -> numel // shape[-2] elements
    # (For 2-D this is shape[0] + shape[1]; the floor-division form is also
    # correct for rank > 2, which the old shape[0] + shape[1] estimate wasn't.)
    second_moment_elems = 0
    for p in model.parameters():
        if not p.requires_grad:
            continue
        if p.dim() >= 2:
            second_moment_elems += p.numel() // p.shape[-1]
            second_moment_elems += p.numel() // p.shape[-2]
        else:
            # 1-D params keep a full (unfactored) second moment.
            second_moment_elems += p.numel()

    adafactor_memory += second_moment_elems * BYTES_PER_FLOAT

    savings = adamw_memory - adafactor_memory
    return {
        "adamw_mb": adamw_memory / MB,
        "adafactor_mb": adafactor_memory / MB,
        "savings_mb": savings / MB,
        # Guard against division by zero for models with no trainable params.
        "savings_percent": (savings / adamw_memory) * 100 if adamw_memory else 0.0,
    }
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: exercise both learning-rate modes of Adafactor on a toy MLP.
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(100, 200),
        nn.ReLU(),
        nn.Linear(200, 50),
        nn.ReLU(),
        nn.Linear(50, 1),
    )

    print("Testing Adafactor optimizer...")

    # Pass 1: automatic (relative) learning-rate scaling (lr=None).
    optimizer, scheduler = configure_adafactor_optimizer(
        model, lr=None, total_steps=100
    )

    inputs = torch.randn(32, 100)
    targets = torch.randn(32, 1)

    loss = nn.functional.mse_loss(model(inputs), targets)
    initial_loss = loss.item()
    loss.backward()

    optimizer.step()
    if scheduler is not None:
        scheduler.step()

    # Pass 2: fixed learning rate with a one-cycle schedule.
    # NOTE: gradients deliberately accumulate across passes (no zero_grad),
    # matching the original smoke-test behavior.
    optimizer2, scheduler2 = configure_adafactor_optimizer(
        model, lr=1e-3, total_steps=100
    )

    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer2.step()
    if scheduler2 is not None:
        scheduler2.step()

    memory_analysis = analyze_memory_usage(model)

    print("Adafactor optimizer test completed successfully!")
    print(f"Initial loss: {initial_loss:.4f}")
    print(f"Final loss: {loss.item():.4f}")
    print(f"Memory analysis: {memory_analysis}")