import torch

class EMA(torch.nn.Module):
    """Maintains an exponential moving average (EMA) of a model's trainable
    parameters. Switching to eval mode swaps the EMA weights into the model;
    switching back to train mode restores the live weights."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        super().__init__()
        self.model = model
        self.decay = decay
        # Expose the wrapped model's time_geopath attribute, if it has one.
        if hasattr(self.model, "time_geopath"):
            self.time_geopath = self.model.time_geopath

        # Count of EMA updates, stored as a buffer so it travels with the
        # state dict.
        self.register_buffer("num_updates", torch.tensor(0))

        # Shadow copies of the trainable parameters; these hold the EMA
        # weights and are never optimized directly.
        self.shadow_params = torch.nn.ParameterList(
            [
                torch.nn.Parameter(p.clone().detach(), requires_grad=False)
                for p in model.parameters()
                if p.requires_grad
            ]
        )
        self.backup_params = []

    def train(self, mode: bool = True):
        if self.training and not mode:
            # Switching to eval: back up the live weights, then load the EMA
            # weights into the model.
            self.backup()
            self.copy_to_model()
        elif not self.training and mode:
            # Switching back to train: restore the live weights.
            self.restore_to_model()
        return super().train(mode)

    def update_ema(self):
        self.num_updates += 1
        num_updates = self.num_updates.item()
        # Warm up the decay so that early in training the EMA tracks the raw
        # parameters closely instead of being dominated by the initialization.
        decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        with torch.no_grad():
            params = [p for p in self.model.parameters() if p.requires_grad]
            for shadow, param in zip(self.shadow_params, params):
                # shadow <- decay * shadow + (1 - decay) * param
                shadow.sub_((1 - decay) * (shadow - param))

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def copy_to_model(self):
        # Load the EMA weights into the model's trainable parameters.
        params = [p for p in self.model.parameters() if p.requires_grad]
        for shadow, param in zip(self.shadow_params, params):
            param.data.copy_(shadow.data)

    def backup(self):
        # Save the model's current (live) weights so they can be restored
        # after evaluating with the EMA weights. Reuse the backup tensors
        # once they exist to avoid reallocating them on every call.
        if len(self.backup_params) > 0:
            for param, backup in zip(self.model.parameters(), self.backup_params):
                backup.data.copy_(param.data)
        else:
            self.backup_params = [param.clone() for param in self.model.parameters()]

    def restore_to_model(self):
        # Put the backed-up live weights back into the model.
        for param, backup in zip(self.model.parameters(), self.backup_params):
            param.data.copy_(backup.data)
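

# A minimal, self-contained usage sketch (illustrative only, not part of the
# module above): wrap a model, call update_ema() after each optimizer step,
# and let train()/eval() swap the EMA weights in and out. The toy linear
# model and random data are placeholders for a real training setup.
if __name__ == "__main__":
    net = EMA(torch.nn.Linear(4, 1), decay=0.999)
    # Optimize only the live model's parameters; the shadow (EMA) copies are
    # updated exclusively through update_ema().
    opt = torch.optim.Adam(net.model.parameters(), lr=1e-2)

    for _ in range(100):
        x = torch.randn(32, 4)
        y = x.sum(dim=1, keepdim=True)
        loss = torch.nn.functional.mse_loss(net(x), y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        net.update_ema()  # fold the freshly updated weights into the EMA

    net.eval()  # backs up the live weights and loads the EMA weights
    with torch.no_grad():
        x = torch.randn(32, 4)
        val_loss = torch.nn.functional.mse_loss(net(x), x.sum(dim=1, keepdim=True))
        print("eval loss with EMA weights:", val_loss.item())
    net.train()  # restores the live weights so training can continue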