| | |
| | |
| | |
| | |
| | |
| | from functools import partial |
| |
|
| | import torch |
| | from omegaconf import DictConfig |
| |
|
| |
|
def get_optimizer(optimizer_name, parameters, lr=None, config: "DictConfig | None" = None):
    """Build a ``torch.optim`` optimizer from a name and optional settings.

    Args:
        optimizer_name: One of ``"adam"``, ``"adamw"``, ``"radam"`` or
            ``"sgd"``.
        parameters: Passed to the optimizer constructor as ``params``.
        lr: Learning rate. When ``None`` the argument is omitted and the
            optimizer's own default applies.
        config: Optional mapping-like object exposing ``.get(key)``. Only the
            keys relevant to the selected optimizer are read:
            ``betas``/``weight_decay``/``eps`` for adam and adamw,
            ``betas``/``weight_decay`` for radam, and
            ``momentum``/``weight_decay``/``nesterov`` for sgd. Keys that are
            missing or ``None`` are left at the optimizer's default.

    Returns:
        The constructed optimizer instance.

    Raises:
        NotImplementedError: If ``optimizer_name`` is not one of the
            supported names.
    """
    # name -> (class name in torch.optim, config keys it accepts).
    # Classes are looked up by name only after selection, so a torch build
    # lacking one of them does not break the others.
    supported = {
        "adam": ("Adam", ("betas", "weight_decay", "eps")),
        "adamw": ("AdamW", ("betas", "weight_decay", "eps")),
        "radam": ("RAdam", ("betas", "weight_decay")),
        "sgd": ("SGD", ("momentum", "weight_decay", "nesterov")),
    }
    if optimizer_name not in supported:
        raise NotImplementedError(f"Optimizer {optimizer_name} not implemented.")
    class_name, config_keys = supported[optimizer_name]

    kwargs = {"params": parameters}
    if lr is not None:
        kwargs["lr"] = lr

    # ``config`` defaults to None, so it must be guarded before any lookup.
    if config is not None:
        for key in config_keys:
            value = config.get(key)
            # Test against None rather than truthiness: an explicit 0 (for
            # example weight_decay=0 with AdamW) must override the
            # optimizer's default instead of being dropped.
            if value is None:
                continue
            if key == "betas":
                # Hand the optimizer a plain tuple rather than whatever
                # sequence type the config object returned.
                value = tuple(value)
            kwargs[key] = value

    return getattr(torch.optim, class_name)(**kwargs)
| |
|