| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Paddle optimization for diffusion models.""" |
|
|
| import math |
| from enum import Enum |
| from typing import Optional, Union |
|
|
| from paddle.optimizer.lr import LambdaDecay |
|
|
| from .utils import logging |
|
|
# Module-level logger obtained from the package's own `.utils.logging` helpers.
logger = logging.get_logger(__name__)
|
|
|
|
class SchedulerType(Enum):
    """Names of the learning-rate schedules that `get_scheduler` can build.

    Each member's value is the string accepted by `get_scheduler(name=...)`.
    """

    # Linear warmup, then linear decay to 0.
    LINEAR = "linear"
    # Linear warmup, then cosine decay to 0.
    COSINE = "cosine"
    # Linear warmup, then cosine decay with hard restarts.
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    # Linear warmup, then polynomial decay towards an end learning rate.
    POLYNOMIAL = "polynomial"
    # Constant learning rate, no warmup.
    CONSTANT = "constant"
    # Linear warmup, then constant learning rate.
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
|
|
|
|
def get_constant_schedule(learning_rate: float, last_epoch: int = -1):
    """
    Build a schedule that keeps the learning rate equal to `learning_rate` at every step.

    Args:
        learning_rate (`float`):
            The base learning rate. It is a python float number.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
    """

    def _unit_factor(_step):
        # LambdaDecay scales the base learning rate by this factor; 1 leaves it unchanged.
        return 1

    return LambdaDecay(learning_rate, _unit_factor, last_epoch=last_epoch)
|
|
|
|
def get_constant_schedule_with_warmup(learning_rate: float, num_warmup_steps: int, last_epoch: int = -1):
    """
    Build a schedule that ramps the learning rate linearly from 0 up to `learning_rate` over
    `num_warmup_steps` steps, then holds it constant at `learning_rate`.

    Args:
        learning_rate (`float`):
            The base learning rate. It is a python float number.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
    """

    def _warmup_then_constant(step: int):
        # After warmup the multiplier is 1.0, i.e. the full base learning rate.
        if step >= num_warmup_steps:
            return 1.0
        # During warmup the multiplier grows linearly; the max() guards against dividing by 0.
        return float(step) / float(max(1.0, num_warmup_steps))

    return LambdaDecay(learning_rate, _warmup_then_constant, last_epoch=last_epoch)
|
|
|
|
def get_linear_schedule_with_warmup(
    learning_rate: float, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
):
    """
    Build a schedule that first ramps the learning rate linearly from 0 up to `learning_rate` over
    `num_warmup_steps` steps, then decreases it linearly back to 0 at `num_training_steps`.
    The rate stays at 0 after that.

    Args:
        learning_rate (`float`):
            The base learning rate. It is a python float number.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
    """

    def _linear_factor(step: int):
        if step >= num_warmup_steps:
            # Decay phase: fraction of the post-warmup span still remaining, clamped at 0.
            steps_left = float(num_training_steps - step)
            decay_span = float(max(1, num_training_steps - num_warmup_steps))
            return max(0.0, steps_left / decay_span)
        # Warmup phase: linear ramp; max() avoids a zero denominator.
        return float(step) / float(max(1, num_warmup_steps))

    return LambdaDecay(learning_rate, _linear_factor, last_epoch=last_epoch)
|
|
|
|
def get_cosine_schedule_with_warmup(
    learning_rate: float, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
):
    """
    Build a schedule that ramps the learning rate linearly from 0 up to `learning_rate` over
    `num_warmup_steps` steps, then follows a cosine curve over the remaining training steps.

    Args:
        learning_rate (`float`):
            The base learning rate. It is a python float number.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of cosine periods in the decay phase. The default of 0.5 is a single half-cosine
            going from the base learning rate down to 0.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
    """

    def _cosine_factor(step):
        if step >= num_warmup_steps:
            # Fraction of the post-warmup span already elapsed (max() avoids a zero denominator).
            decay_span = float(max(1, num_training_steps - num_warmup_steps))
            progress = float(step - num_warmup_steps) / decay_span
            angle = math.pi * float(num_cycles) * 2.0 * progress
            # Map cos() from [-1, 1] into [0, 1]; clamp guards against tiny negatives.
            return max(0.0, 0.5 * (1.0 + math.cos(angle)))
        return float(step) / float(max(1, num_warmup_steps))

    return LambdaDecay(learning_rate, _cosine_factor, last_epoch=last_epoch)
|
|
|
|
def get_cosine_with_hard_restarts_schedule_with_warmup(
    learning_rate: float, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
):
    """
    Build a schedule that ramps the learning rate linearly from 0 up to `learning_rate` over
    `num_warmup_steps` steps, then repeats `num_cycles` half-cosine decays from `learning_rate` to 0.
    Each decay jumps straight back to the full rate (a hard restart). From `num_training_steps` on,
    the rate is 0.

    Args:
        learning_rate (`float`):
            The base learning rate. It is a python float number.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.
    """

    def _restart_factor(step):
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        progress = float(step - num_warmup_steps) / decay_span
        # Training horizon reached: the learning rate is pinned at 0.
        if progress >= 1.0:
            return 0.0
        # Position inside the current cycle, in [0, 1); wrapping to 0 produces the hard restart.
        cycle_position = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))

    return LambdaDecay(learning_rate, _restart_factor, last_epoch=last_epoch)
|
|
|
|
def get_polynomial_decay_schedule_with_warmup(
    learning_rate: float,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from `learning_rate` to the end
    lr defined by *lr_end*. The decay follows a warmup period during which the rate increases linearly from 0 to
    `learning_rate`. From `num_training_steps` on, the learning rate stays at `lr_end`.

    Args:
        learning_rate (`float`):
            The base learning rate. It is a python float number.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR. Must be strictly smaller than `learning_rate`.
        power (`float`, *optional*, defaults to 1.0):
            Power factor applied to the remaining fraction of the decay phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `paddle.optimizer.lr.LambdaDecay` with the appropriate schedule.

    Raises:
        ValueError: If `lr_end` is not strictly smaller than `learning_rate`.
    """
    lr_init = learning_rate
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        if current_step >= num_training_steps:
            # At or past the end of training the rate is held at `lr_end`. LambdaDecay multiplies the returned
            # factor by the base learning rate, so return the ratio.
            # `>=` (not `>`) also routes the degenerate case num_training_steps == num_warmup_steps here;
            # otherwise `decay_steps` below would be 0 and raise ZeroDivisionError. For power > 0 the value
            # at step == num_training_steps is exactly what the decay formula would produce there.
            return lr_end / lr_init
        # Here num_warmup_steps <= current_step < num_training_steps, so decay_steps > 0 and
        # pct_remaining lies in (0, 1].
        lr_range = lr_init - lr_end
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decay = lr_range * pct_remaining**power + lr_end
        return decay / lr_init

    return LambdaDecay(learning_rate, lr_lambda, last_epoch)
|
|
|
|
# Dispatch table used by `get_scheduler`: maps each SchedulerType member to the factory that builds it.
TYPE_TO_SCHEDULER_FUNCTION = dict(
    [
        (SchedulerType.LINEAR, get_linear_schedule_with_warmup),
        (SchedulerType.COSINE, get_cosine_schedule_with_warmup),
        (SchedulerType.COSINE_WITH_RESTARTS, get_cosine_with_hard_restarts_schedule_with_warmup),
        (SchedulerType.POLYNOMIAL, get_polynomial_decay_schedule_with_warmup),
        (SchedulerType.CONSTANT, get_constant_schedule),
        (SchedulerType.CONSTANT_WITH_WARMUP, get_constant_schedule_with_warmup),
    ]
)
|
|
|
|
def get_scheduler(
    name: Union[str, SchedulerType],
    learning_rate: float = 0.1,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
):
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        learning_rate (`float`, *optional*, defaults to 0.1):
            The base learning rate. It is a python float number.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. Every scheduler except `CONSTANT` requires it; the function
            raises an error if it is unset and the scheduler type needs it.
        num_training_steps (`int`, *optional*):
            The number of training steps to do. Every scheduler except `CONSTANT` and `CONSTANT_WITH_WARMUP`
            requires it; the function raises an error if it is unset and the scheduler type needs it.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts. Only forwarded to the `COSINE_WITH_RESTARTS` scheduler; `COSINE` keeps
            its own default.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. Only forwarded to the `POLYNOMIAL` scheduler, which uses its own default `lr_end`.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Raises:
        ValueError: If `name` is not a valid `SchedulerType`, or if a required step count is missing.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]

    # The plain constant schedule needs only the base learning rate.
    if name == SchedulerType.CONSTANT:
        return schedule_func(learning_rate=learning_rate, last_epoch=last_epoch)

    # Every other schedule begins with a warmup phase.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")
    call_kwargs = dict(learning_rate=learning_rate, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)
    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(**call_kwargs)

    # The remaining schedules decay over a fixed training horizon.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")
    call_kwargs["num_training_steps"] = num_training_steps

    # Forward the one schedule-specific knob, where applicable.
    if name == SchedulerType.COSINE_WITH_RESTARTS:
        call_kwargs["num_cycles"] = num_cycles
    elif name == SchedulerType.POLYNOMIAL:
        call_kwargs["power"] = power
    return schedule_func(**call_kwargs)
|
|