| import functools |
| import math |
| from typing import Any, Callable, Dict, List, Optional, Type, Union |
|
|
| import torch |
| from torch.distributed.checkpoint.state_dict import ( |
| StateDictOptions, |
| get_optimizer_state_dict, |
| set_optimizer_state_dict, |
| ) |
| from torch.distributed.checkpoint.stateful import Stateful |
|
|
| from .parallel import ParallelBackendEnum |
| from .utils.import_utils import is_bitsandbytes_available |
|
|
|
|
| class OptimizerWrapper(Stateful): |
| r""" |
| Optimizer wrapper that: |
| - allows step/zero_grad on multiple optimizers needed for virtual pipeline stages |
| - saves/loading optimizer state_dict at checkpoint |
| """ |
|
|
| def __init__( |
| self, |
| model_parts: List[torch.nn.Module], |
| optimizer_cls: Type[torch.optim.Optimizer], |
| optimizer_kwargs: Dict[str, Any], |
| ) -> None: |
| self.optimizer_cls = optimizer_cls |
| self.optimizer_kwargs = optimizer_kwargs |
|
|
| self.optimizers = [] |
| self.model_parts = model_parts |
|
|
| for model in self.model_parts: |
| optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs) |
| self.optimizers.append(optimizer) |
|
|
| def step(self) -> None: |
| for optimizer in self.optimizers: |
| optimizer.step() |
|
|
| def zero_grad(self) -> None: |
| for optimizer in self.optimizers: |
| optimizer.zero_grad() |
|
|
| def state_dict(self) -> Dict[str, Any]: |
| func = functools.partial( |
| get_optimizer_state_dict, |
| options=StateDictOptions(flatten_optimizer_state_dict=True), |
| ) |
| return {k: v for sd in map(func, self.model_parts, self.optimizers) for k, v in sd.items()} |
|
|
| def load_state_dict(self, state_dict: Dict[str, Any]) -> None: |
| func = functools.partial( |
| set_optimizer_state_dict, |
| optim_state_dict=state_dict, |
| options=StateDictOptions(flatten_optimizer_state_dict=True), |
| ) |
| list(map(func, self.model_parts, self.optimizers)) |
|
|
|
|
class SchedulerWrapper:
    r"""
    Holds one `torch.optim.lr_scheduler.LambdaLR` per optimizer, all driven by the same multiplier function, and
    steps them together.
    """

    def __init__(
        self, optimizers, scheduler_lambda_fn: Callable[[int], float], last_epoch: int
    ) -> None:
        r"""
        Args:
            optimizers (iterable of `torch.optim.Optimizer`):
                The optimizers to schedule; one `LambdaLR` is created for each.
            scheduler_lambda_fn (`Callable[[int], float]`):
                Maps a step index to a multiplier applied to each optimizer's initial learning rate.
            last_epoch (`int`):
                Forwarded to every `LambdaLR`.
        """
        self.schedulers = [
            torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch) for optimizer in optimizers
        ]

    def step(self) -> None:
        """Advance every wrapped scheduler by one step."""
        for scheduler in self.schedulers:
            scheduler.step()

    def get_last_lr(self) -> Dict[str, List[float]]:
        """Return `{"lr_<idx>": scheduler.get_last_lr()}` for every wrapped scheduler, in order."""
        return {f"lr_{idx}": scheduler.get_last_lr() for idx, scheduler in enumerate(self.schedulers)}

    def get_lr_scheduler_state(self) -> Dict[str, Any]:
        """
        Return the wrapped scheduler objects (not their state dicts) keyed by name.

        A single scheduler is stored under `"lr_scheduler"`; multiple schedulers under `"lr_scheduler_<idx>"`.
        """
        if len(self.schedulers) == 1:
            return {"lr_scheduler": self.schedulers[0]}
        return {f"lr_scheduler_{idx}": lr_scheduler for idx, lr_scheduler in enumerate(self.schedulers)}
|
|
|
|
def get_optimizer(
    parallel_backend: ParallelBackendEnum,
    name: str,
    model_parts: List[torch.nn.Module],
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.999,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    fused: bool = False,
) -> Union[torch.optim.Optimizer, OptimizerWrapper]:
    r"""
    Build an optimizer for `model_parts` by name, packaged for the given parallel backend.

    Args:
        parallel_backend (`ParallelBackendEnum`):
            `ACCELERATE` returns a single optimizer over all trainable parameters of all model parts; `PTD` returns
            an `OptimizerWrapper` holding one optimizer per model part.
        name (`str`):
            Case-insensitive optimizer name: `adam`, `adamw`, `adam-bnb`, `adamw-bnb`, `adam-bnb-8bit` or
            `adamw-bnb-8bit`. The `-bnb` variants require `bitsandbytes`.
        model_parts (`List[torch.nn.Module]`):
            The model chunks whose parameters are optimized.
        learning_rate, beta1, beta2, epsilon, weight_decay:
            Passed to the optimizer as `lr`, `betas=(beta1, beta2)`, `eps` and `weight_decay`.
        beta3 (`float`):
            Currently unused by every optimizer constructed here.
        fused (`bool`):
            Forwarded as `fused` to `torch.optim.Adam`/`AdamW` only; ignored by the bitsandbytes variants.

    Returns:
        A `torch.optim.Optimizer` (ACCELERATE) or an `OptimizerWrapper` (PTD).

    Raises:
        ImportError: if a `-bnb` optimizer is requested but bitsandbytes is not installed.
        ValueError: if `name` or `parallel_backend` is not supported.
    """
    name = name.lower()

    # Fail fast with an actionable message when an optional package (bitsandbytes) is missing.
    _raise_errors_if_packages_not_available(name)

    # Hyperparameters shared by every supported optimizer; a fresh dict is built on each call.
    optimizer_kwargs = {
        "lr": learning_rate,
        "betas": (beta1, beta2),
        "eps": epsilon,
        "weight_decay": weight_decay,
    }

    if name == "adam":
        optimizer_cls = torch.optim.Adam
        optimizer_kwargs["fused"] = fused
    elif name == "adamw":
        optimizer_cls = torch.optim.AdamW
        optimizer_kwargs["fused"] = fused
    elif name == "adam-bnb":
        from bitsandbytes.optim import Adam

        optimizer_cls = Adam
    elif name == "adamw-bnb":
        from bitsandbytes.optim import AdamW

        optimizer_cls = AdamW
    elif name == "adam-bnb-8bit":
        from bitsandbytes.optim import Adam8bit

        optimizer_cls = Adam8bit
    elif name == "adamw-bnb-8bit":
        from bitsandbytes.optim import AdamW8bit

        optimizer_cls = AdamW8bit
    else:
        raise ValueError(f"Unsupported optimizer: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_optimizer_accelerate(model_parts, optimizer_cls, optimizer_kwargs)
    if parallel_backend == ParallelBackendEnum.PTD:
        return get_optimizer_ptd(model_parts, optimizer_cls, optimizer_kwargs)
    # Raise instead of implicitly returning None, which would only fail later at the first optimizer call.
    raise ValueError(f"Unsupported parallel backend: {parallel_backend}")
|
|
|
|
def get_optimizer_accelerate(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> torch.optim.Optimizer:
    r"""
    Build one optimizer over the trainable parameters of all model parts.

    Only parameters with `requires_grad=True` are handed to `optimizer_cls`, in model-part order.
    """
    trainable_params = []
    for part in model_parts:
        trainable_params.extend(p for p in part.parameters() if p.requires_grad)
    return optimizer_cls(trainable_params, **optimizer_kwargs)
|
|
|
|
def get_optimizer_ptd(
    model_parts: List[torch.nn.Module], optimizer_cls: Type[torch.optim.Optimizer], optimizer_kwargs: Dict[str, Any]
) -> OptimizerWrapper:
    r"""
    Build an `OptimizerWrapper` that holds one `optimizer_cls` instance per model part.
    """
    wrapper = OptimizerWrapper(
        model_parts=model_parts,
        optimizer_cls=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
    )
    return wrapper
|
|
|
|
def get_lr_scheduler(
    parallel_backend: ParallelBackendEnum,
    name: str,
    optimizer: Union[torch.optim.Optimizer, OptimizerWrapper],
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    lr_init: float = 1e-3,
    lr_end: float = 1e-7,
    last_epoch: int = -1,
) -> Union[torch.optim.lr_scheduler.LambdaLR, SchedulerWrapper]:
    r"""
    Build a learning-rate scheduler by name, packaged for the given parallel backend.

    Args:
        parallel_backend (`ParallelBackendEnum`):
            `ACCELERATE` expects a single optimizer and returns a `LambdaLR`; `PTD` expects an `OptimizerWrapper`
            and returns a `SchedulerWrapper`.
        name (`str`):
            Case-insensitive schedule name: `constant`, `constant_with_warmup`, `piecewise_constant`, `linear`,
            `cosine`, `cosine_with_restarts` or `polynomial`.
        optimizer: The optimizer (or optimizer wrapper) to schedule.
        step_rules (`str`, *optional*): Required for `piecewise_constant`.
        num_warmup_steps (`int`, *optional*): Required for every schedule except `constant` and `piecewise_constant`.
        num_training_steps (`int`, *optional*): Required for `linear`, `cosine`, `cosine_with_restarts`, `polynomial`.
        num_cycles (`int`): Used by `cosine` and `cosine_with_restarts`.
        power, lr_init, lr_end: Used by `polynomial`.
        last_epoch (`int`): Forwarded to `LambdaLR`.

    Raises:
        ValueError: if `name` or `parallel_backend` is unsupported, or an argument required by the schedule is None.
    """
    name = name.lower()

    # Validate up front: a missing value would otherwise fail later with a less descriptive TypeError
    # (e.g. `current_step < None`) or AttributeError (`None.split`) inside the schedule functions.
    requires_warmup = {"constant_with_warmup", "linear", "cosine", "cosine_with_restarts", "polynomial"}
    requires_training = {"linear", "cosine", "cosine_with_restarts", "polynomial"}
    if name in requires_warmup and num_warmup_steps is None:
        raise ValueError(f"The `{name}` scheduler requires `num_warmup_steps`, please provide that argument.")
    if name in requires_training and num_training_steps is None:
        raise ValueError(f"The `{name}` scheduler requires `num_training_steps`, please provide that argument.")
    if name == "piecewise_constant" and step_rules is None:
        raise ValueError(f"The `{name}` scheduler requires `step_rules`, please provide that argument.")

    if name == "constant":
        scheduler_lambda_fn = get_constant_schedule()
    elif name == "constant_with_warmup":
        scheduler_lambda_fn = get_constant_schedule_with_warmup(num_warmup_steps)
    elif name == "piecewise_constant":
        scheduler_lambda_fn = get_piecewise_constant_schedule(step_rules)
    elif name == "linear":
        scheduler_lambda_fn = get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps)
    elif name == "cosine":
        # NOTE(review): with the default num_cycles=1 this is a *full* cosine period, so the multiplier returns to
        # 1.0 at num_training_steps; get_cosine_schedule_with_warmup's own default (0.5) decays to 0. Confirm intent.
        scheduler_lambda_fn = get_cosine_schedule_with_warmup(num_warmup_steps, num_training_steps, num_cycles)
    elif name == "cosine_with_restarts":
        scheduler_lambda_fn = get_cosine_with_hard_restarts_schedule_with_warmup(
            num_warmup_steps, num_training_steps, num_cycles
        )
    elif name == "polynomial":
        scheduler_lambda_fn = get_polynomial_decay_schedule_with_warmup(
            num_warmup_steps, num_training_steps, lr_init, lr_end, power
        )
    else:
        raise ValueError(f"Unsupported scheduler: {name}")

    if parallel_backend == ParallelBackendEnum.ACCELERATE:
        return get_lr_scheduler_accelerate(optimizer, scheduler_lambda_fn, last_epoch)
    if parallel_backend == ParallelBackendEnum.PTD:
        return get_lr_scheduler_ptd(optimizer, scheduler_lambda_fn, last_epoch)
    # Raise instead of implicitly returning None for an unknown backend.
    raise ValueError(f"Unsupported parallel backend: {parallel_backend}")
|
|
|
|
def get_lr_scheduler_accelerate(
    optimizer: torch.optim.Optimizer,
    scheduler_lambda_fn: Callable[[int], float],
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    r"""
    Wrap `optimizer` in a `LambdaLR` driven by `scheduler_lambda_fn`.

    Args:
        optimizer (`torch.optim.Optimizer`): The optimizer to schedule.
        scheduler_lambda_fn (`Callable[[int], float]`):
            Maps a step index to a multiplier of the optimizer's initial learning rate (it is a plain function, not
            an `LRScheduler` class).
        last_epoch (`int`, defaults to -1): Forwarded to `LambdaLR`.
    """
    return torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_fn, last_epoch)
|
|
|
|
def get_lr_scheduler_ptd(
    optimizer: OptimizerWrapper, scheduler_lambda_fn: Callable[[int], float], last_epoch: int = -1
) -> SchedulerWrapper:
    r"""
    Build a `SchedulerWrapper` with one `LambdaLR` per optimizer held by `optimizer`.

    Args:
        optimizer (`OptimizerWrapper`): Its `.optimizers` list is scheduled.
        scheduler_lambda_fn (`Callable[[int], float]`):
            Maps a step index to a learning-rate multiplier (a plain function, not an `LRScheduler` class).
        last_epoch (`int`, defaults to -1): Forwarded to every `LambdaLR`.
    """
    return SchedulerWrapper(optimizer.optimizers, scheduler_lambda_fn, last_epoch)
|
|
|
|
| |
| |
| |
|
|
|
|
def get_constant_schedule() -> Callable[[int], float]:
    r"""
    Return a multiplier function that always yields 1.0, leaving the optimizer's learning rate unchanged.
    """

    def constant_multiplier(current_step: int):
        # The step index is deliberately ignored: the multiplier never changes.
        return 1.0

    return constant_multiplier
|
|
|
|
def get_constant_schedule_with_warmup(num_warmup_steps: int) -> Callable[[int], float]:
    r"""
    Return a multiplier that ramps linearly from 0 to 1 over the warmup phase and then stays at 1.

    Args:
        num_warmup_steps (`int`):
            Length of the warmup phase; steps `0 .. num_warmup_steps - 1` get `step / max(1, num_warmup_steps)`.
    """

    def warmup_then_constant(current_step: int):
        in_warmup = current_step < num_warmup_steps
        return float(current_step) / float(max(1.0, num_warmup_steps)) if in_warmup else 1.0

    return warmup_then_constant
|
|
|
|
def get_piecewise_constant_schedule(step_rules: str) -> Callable[[int], float]:
    r"""
    Create a piecewise-constant schedule: the learning rate set in the optimizer is multiplied by a fixed factor that
    changes at user-specified step boundaries.

    Args:
        step_rules (`str`):
            Comma-separated rules `"<multiplier>:<boundary_step>,...,<final_multiplier>"`. Each boundary is an
            absolute, exclusive step index; rules may be given in any order. For example
            `"1:10,0.1:20,0.01:30,0.005"` gives multiplier 1 for steps [0, 10), 0.1 for steps [10, 20), 0.01 for
            steps [20, 30) and 0.005 from step 30 onwards.

    Returns:
        A function mapping a step index to its multiplier.

    Raises:
        ValueError: if a rule is malformed (not `value:steps`, or non-numeric parts).
    """
    rules_dict = {}
    rule_list = step_rules.split(",")
    for rule_str in rule_list[:-1]:
        value_str, steps_str = rule_str.split(":")
        steps = int(steps_str)
        # A repeated boundary keeps the multiplier of its last occurrence.
        rules_dict[steps] = float(value_str)
    last_lr_multiple = float(rule_list[-1])

    # Sort boundaries once here rather than on every call of the returned function.
    sorted_rules = sorted(rules_dict.items())

    def rule_func(steps: int) -> float:
        for boundary, multiplier in sorted_rules:
            if steps < boundary:
                return multiplier
        return last_lr_multiple

    return rule_func
|
|
|
|
def get_linear_schedule_with_warmup(num_warmup_steps: int, num_training_steps: int) -> Callable[[int], float]:
    r"""
    Return a multiplier that rises linearly from 0 to 1 during warmup, then falls linearly to 0 at
    `num_training_steps` and stays at 0 afterwards.

    Args:
        num_warmup_steps (`int`):
            Length of the warmup phase.
        num_training_steps (`int`):
            Step at which the multiplier reaches 0.
    """

    def linear_multiplier(current_step: int):
        if current_step < num_warmup_steps:
            warmup_len = max(1, num_warmup_steps)
            return float(current_step) / float(warmup_len)
        decay_len = max(1, num_training_steps - num_warmup_steps)
        remaining = float(num_training_steps - current_step)
        # Clamp so steps past num_training_steps never produce a negative multiplier.
        return max(0.0, remaining / float(decay_len))

    return linear_multiplier
|
|
|
|
def get_cosine_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: float = 0.5,
) -> Callable[[int], float]:
    r"""
    Create a schedule whose multiplier follows a cosine curve after a linear warmup from 0 to 1.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of cosine periods between the end of warmup and `num_training_steps`. The default 0.5 is a
            half-cosine that decreases from 1 to 0; 1.0 goes down to 0 and back up to 1.

    Note:
        The multiplier is not clamped past `num_training_steps`: the cosine keeps going (e.g. with
        `num_cycles=0.5`, progress 1.5 gives 0.5). Only negative values are clipped to 0.
    """

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Fraction of the post-warmup span completed (0 at end of warmup, 1 at num_training_steps).
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return lr_lambda
|
|
|
|
def get_cosine_with_hard_restarts_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    num_cycles: int = 1,
) -> Callable[[int], float]:
    r"""
    Return a multiplier that warms up linearly from 0 to 1. It then repeats a half-cosine decay from 1 towards 0
    `num_cycles` times, jumping back to 1 at each restart, and is 0 from `num_training_steps` onwards.

    Args:
        num_warmup_steps (`int`):
            Length of the warmup phase.
        num_training_steps (`int`):
            Total number of training steps; the multiplier is 0 once this is reached.
        num_cycles (`int`, *optional*, defaults to 1):
            How many decay cycles (hard restarts) fit between the end of warmup and `num_training_steps`.
    """

    def restart_multiplier(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / span
        if progress >= 1.0:
            return 0.0
        # Position inside the current cycle, in [0, 1); wrapping to 0 is the hard restart.
        cycle_position = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))

    return restart_multiplier
|
|
|
|
def get_polynomial_decay_schedule_with_warmup(
    num_warmup_steps: int,
    num_training_steps: int,
    lr_init: float,
    lr_end: float = 1e-7,
    power: float = 1.0,
) -> Callable[[int], float]:
    r"""
    Create a schedule whose learning rate decays polynomially from `lr_init` to `lr_end` after a linear warmup from
    0 to the initial lr.

    Args:
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps; from this step on the multiplier is `lr_end / lr_init`.
        lr_init (`float`):
            The peak learning rate. The returned multiplier is normalized by `lr_init`, so the learning rate
            actually reaches `lr_end` only if `lr_init` equals the learning rate set in the optimizer.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR; must be smaller than `lr_init`.
        power (`float`, *optional*, defaults to 1.0):
            Power factor of the decay (1.0 is linear).

    Raises:
        ValueError: if `lr_end` is not smaller than `lr_init`.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37
    """
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        if current_step >= num_training_steps:
            # At num_training_steps the polynomial already equals lr_end (for power > 0). Returning it directly
            # also avoids dividing by a zero decay span when num_training_steps == num_warmup_steps.
            return lr_end / lr_init
        lr_range = lr_init - lr_end
        # Here num_warmup_steps <= current_step < num_training_steps, so decay_steps >= 1.
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        return decay / lr_init

    return lr_lambda
|
|
|
|
| def _raise_errors_if_packages_not_available(name: str) -> None: |
| name_split = name.split("-") |
| if len(name_split) < 2: |
| return |
| package_name = name_split[1] |
| if package_name == "bnb": |
| if not is_bitsandbytes_available(): |
| raise ImportError( |
| f"Please install bitsandbytes by running `pip install bitsandbytes` to use the {name} optimizer." |
| ) |
|
|