| | """ |
| | Lion Optimizer for BitTransformerLM Extensions |
| | ============================================== |
| | |
| | Implementation of the Lion optimizer (EvoLved Sign Momentum). |
| | Based on "Symbolic Discovery of Optimization Algorithms" research. |
| | |
| | Key features: |
| | - Sign-based momentum updates |
| | - Extremely memory efficient (only stores momentum) |
| | - Often outperforms Adam/AdamW with larger learning rates |
| | - Compatible with BitTransformerLM's training infrastructure |
| | """ |
| |
|
| | import torch |
| | from torch.optim.optimizer import Optimizer |
| | from typing import Any, Dict, List, Optional, Tuple, Union |
| |
|
| |
|
class Lion(Optimizer):
    """
    Lion optimizer (EvoLved Sign Momentum).

    The parameter update is the sign of an interpolation between the running
    momentum and the current gradient, so only one momentum buffer per
    parameter is kept — far less state than Adam/AdamW.

    Args:
        params: Iterable of parameters (or param groups) to optimize.
        lr: Learning rate (default: 1e-4; typically smaller than Adam's).
        betas: Coefficients for the momentum interpolation and EMA update
            (default: (0.9, 0.99)).
        weight_decay: Decoupled weight decay coefficient (default: 0.0).
        eps: Small constant accepted for API compatibility (default: 1e-8).
        maximize: Maximize the objective instead of minimizing (default: False).
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
    ):
        # `not 0.0 <= x` (rather than `x < 0.0`) also rejects NaN values.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        super().__init__(
            params,
            dict(
                lr=lr,
                betas=betas,
                weight_decay=weight_decay,
                eps=eps,
                maximize=maximize,
            ),
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure`` if one was supplied, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            decay = group["weight_decay"]
            beta1, beta2 = group["betas"]
            flip_sign = group["maximize"]

            for param in group["params"]:
                if param.grad is None:
                    continue

                grad = -param.grad if flip_sign else param.grad

                # Accumulate in fp32 when gradients are reduced precision.
                if grad.dtype in (torch.float16, torch.bfloat16):
                    grad = grad.float()

                state = self.state[param]
                if not state:
                    state["momentum"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                momentum = state["momentum"]

                # Decoupled weight decay (AdamW-style), applied before the update.
                if decay != 0:
                    param.mul_(1 - lr * decay)

                # Update direction: sign of the beta1-interpolated momentum.
                direction = torch.sign(momentum.mul(beta1).add_(grad, alpha=1 - beta1))
                param.add_(direction, alpha=-lr)

                # Momentum EMA with beta2 (uses the raw gradient).
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
| |
|
| |
|
def configure_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    total_steps: Optional[int] = None,
    warmup_ratio: float = 0.1,
    **lion_kwargs
) -> Tuple[Lion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """
    Build a Lion optimizer (and optional OneCycle LR scheduler) for ``model``.

    Drop-in replacement for BitTransformerLM's configure_optimizer helper,
    swapping AdamW for Lion. Lion usually wants a learning rate ~3-10x
    smaller than Adam's, paired with a larger weight decay (0.01-0.1).

    Args:
        model: PyTorch model to optimize.
        lr: Peak learning rate.
        betas: Beta coefficients for the momentum computation.
        weight_decay: Weight decay applied to matrix-shaped parameters only.
        total_steps: Total training steps; enables the OneCycle schedule.
        warmup_ratio: Fraction of steps spent warming up.
        **lion_kwargs: Extra keyword arguments forwarded to Lion.

    Returns:
        Tuple of (optimizer, scheduler); scheduler is None without total_steps.
    """
    # Split trainable params: matrices (dim >= 2) get weight decay,
    # vectors/scalars (biases, norm gains) do not.
    matrix_params, vector_params = [], []
    for param in model.parameters():
        if param.requires_grad:
            (matrix_params if param.dim() >= 2 else vector_params).append(param)

    optimizer = Lion(
        [
            {"params": matrix_params, "weight_decay": weight_decay},
            {"params": vector_params, "weight_decay": 0.0},
        ],
        lr=lr,
        betas=betas,
        **lion_kwargs
    )

    if total_steps is None or total_steps <= 0:
        return optimizer, None

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=total_steps,
        pct_start=warmup_ratio,
        anneal_strategy='cos',
        cycle_momentum=False,
        div_factor=25.0,
        final_div_factor=1e4,
    )
    return optimizer, scheduler
| |
|
| |
|
def create_lion_training_config(
    lr: float = 1e-4,
    betas: Tuple[float, float] = (0.9, 0.99),
    weight_decay: float = 0.01,
    **kwargs
) -> Dict[str, Any]:
    """
    Build a training-configuration dict describing a Lion + OneCycle setup.

    The result can be handed to BitTransformerLM's training scripts.

    Args:
        lr: Learning rate.
        betas: Beta coefficients for momentum.
        weight_decay: Weight decay coefficient.
        **kwargs: Extra entries merged into the optimizer config
            (they override the named arguments on key collision).

    Returns:
        Dictionary containing the training configuration.
    """
    optimizer_config: Dict[str, Any] = {
        "lr": lr,
        "betas": betas,
        "weight_decay": weight_decay,
    }
    optimizer_config.update(kwargs)

    return {
        "optimizer_type": "lion",
        "optimizer_config": optimizer_config,
        "scheduler_type": "onecycle",
    }
| |
|
| |
|
class AdaptiveLion(Lion):
    """
    Lion variant with an adaptive learning-rate scale.

    Each parameter's step size is scaled by a function of the ratio between
    its gradient norm and momentum norm, clamped to [min_scale, max_scale],
    which can improve stability when gradient magnitudes drift.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
        maximize: bool = False,
        adaptive_scale: float = 0.1,
        min_scale: float = 0.01,
        max_scale: float = 10.0,
    ):
        """
        Args:
            adaptive_scale: Strength of the adaptive adjustment (0 disables it).
            min_scale: Lower clamp on the learning-rate scale factor.
            max_scale: Upper clamp on the learning-rate scale factor.

        The remaining arguments are forwarded to :class:`Lion`.
        """
        # NOTE(review): these live on the instance rather than in param-group
        # defaults, so they are global to the optimizer and are NOT saved by
        # state_dict() — keep that in mind when checkpointing.
        self.adaptive_scale = adaptive_scale
        self.min_scale = min_scale
        self.max_scale = max_scale

        super().__init__(params, lr, betas, weight_decay, eps, maximize)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step with adaptive LR scaling.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure`` if one was supplied, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                grad = p.grad
                if group["maximize"]:
                    grad = -grad

                # Accumulate in fp32 when gradients are reduced precision.
                if grad.dtype in {torch.float16, torch.bfloat16}:
                    grad = grad.float()

                state = self.state[p]

                if len(state) == 0:
                    state["momentum"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state["step"] = 0

                momentum = state["momentum"]
                state["step"] += 1  # kept for state_dict compatibility/inspection
                beta1, beta2 = group["betas"]

                # Scalar norms; note .item() forces a host sync per parameter.
                grad_norm = grad.norm().item()
                momentum_norm = momentum.norm().item()

                # Scale the LR toward the grad/momentum norm ratio.
                if momentum_norm > 1e-8:
                    scale = 1.0 + self.adaptive_scale * (grad_norm / momentum_norm - 1.0)
                    # FIX: clamp in pure Python instead of allocating a
                    # throwaway tensor via torch.clamp(torch.tensor(scale), ...)
                    # for every parameter on every step. Numerically identical.
                    scale = min(max(scale, self.min_scale), self.max_scale)
                else:
                    scale = 1.0

                adaptive_lr = group["lr"] * scale

                # Decoupled weight decay, using the scaled learning rate.
                if group["weight_decay"] != 0:
                    p.mul_(1 - adaptive_lr * group["weight_decay"])

                # Lion update: sign of interpolated momentum, then momentum EMA.
                interpolated = momentum.mul(beta1).add_(grad, alpha=1 - beta1)
                p.add_(torch.sign(interpolated), alpha=-adaptive_lr)
                momentum.mul_(beta2).add_(grad, alpha=1 - beta2)

        return loss
| |
|
| |
|
def configure_adaptive_lion_optimizer(
    model: torch.nn.Module,
    lr: float = 1e-4,
    adaptive_scale: float = 0.1,
    **kwargs
) -> Tuple[AdaptiveLion, Optional[torch.optim.lr_scheduler._LRScheduler]]:
    """Configure AdaptiveLion optimizer with optional OneCycle scheduling.

    Args:
        model: Model whose trainable parameters will be optimized.
        lr: Peak learning rate.
        adaptive_scale: Adaptive scaling strength for AdaptiveLion.
        **kwargs: May include ``weight_decay`` (default 0.01, applied per
            param group), plus scheduler settings ``total_steps`` and
            ``warmup_ratio``; anything else is forwarded to AdaptiveLion.

    Returns:
        Tuple of (optimizer, scheduler); scheduler is None when
        ``total_steps`` is not provided or not positive.
    """

    decay_params = []
    no_decay_params = []

    # Matrices (dim >= 2) get weight decay; biases/norm scales do not.
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if param.dim() >= 2:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    param_groups = [
        {"params": decay_params, "weight_decay": kwargs.get("weight_decay", 0.01)},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    # FIX: scheduler-only options (total_steps, warmup_ratio) and
    # weight_decay (already applied per param group) must not be forwarded
    # to AdaptiveLion.__init__ — previously only weight_decay was filtered,
    # so passing total_steps raised TypeError: unexpected keyword argument.
    _non_optimizer_keys = {"weight_decay", "total_steps", "warmup_ratio"}
    optimizer_kwargs = {k: v for k, v in kwargs.items() if k not in _non_optimizer_keys}

    optimizer = AdaptiveLion(
        param_groups,
        lr=lr,
        adaptive_scale=adaptive_scale,
        **optimizer_kwargs
    )

    scheduler = None
    total_steps = kwargs.get("total_steps")
    if total_steps is not None and total_steps > 0:
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=lr,
            total_steps=total_steps,
            pct_start=kwargs.get("warmup_ratio", 0.1),
            anneal_strategy='cos',
            cycle_momentum=False,
            div_factor=25.0,
            final_div_factor=1e4,
        )

    return optimizer, scheduler
| |
|
| |
|
| | |
def integrate_with_bittransformerlm():
    """
    Documentation-only stub: how to use Lion with BitTransformerLM training.

    This function intentionally does nothing at runtime; its docstring is the
    integration example.

    Usage:
        from BTLM_Extensions.lion_optimizer import configure_lion_optimizer

        # Replace the standard optimizer configuration
        # Note: Lion typically needs smaller learning rates than Adam
        optimizer, scheduler = configure_lion_optimizer(
            model, lr=1e-4, weight_decay=0.01, total_steps=1000
        )

        # Use in training loop
        train_loop(model, data, optimizer=optimizer, scheduler=scheduler)

        # For adaptive version:
        from BTLM_Extensions.lion_optimizer import configure_adaptive_lion_optimizer

        optimizer, scheduler = configure_adaptive_lion_optimizer(
            model, lr=1e-4, adaptive_scale=0.1, total_steps=1000
        )
    """
    pass
| |
|
| |
|
if __name__ == "__main__":

    import torch.nn as nn

    def _make_net():
        # Tiny throwaway MLP used only for these smoke tests.
        return nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 1))

    net = _make_net()

    print("Testing standard Lion optimizer...")
    opt, sched = configure_lion_optimizer(net, lr=1e-4, total_steps=100)

    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    # One forward/backward/step cycle with the standard Lion optimizer.
    loss = nn.functional.mse_loss(net(inputs), targets)
    initial_loss = loss.item()
    loss.backward()

    opt.step()
    if sched:
        sched.step()

    print(f"Initial loss: {initial_loss:.4f}")

    print("Testing Adaptive Lion optimizer...")
    net2 = _make_net()

    opt2, sched2 = configure_adaptive_lion_optimizer(
        net2, lr=1e-4, adaptive_scale=0.1, total_steps=100
    )

    # Same smoke test for the adaptive variant on a fresh network.
    loss2 = nn.functional.mse_loss(net2(inputs), targets)
    loss2.backward()
    opt2.step()
    if sched2:
        sched2.step()

    print("Lion optimizers test completed successfully!")
    print(f"Standard Lion loss: {initial_loss:.4f}")
    print(f"Adaptive Lion loss: {loss2.item():.4f}")