# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Exponential Moving Average (EMA) of model updates."""

import copy

import torch
class ModelEMA(torch.nn.Module):
    """Model Exponential Moving Average.

    Keeps a frozen, deep-copied shadow of ``model`` whose parameters are
    blended toward the live model's parameters on each ``update`` call:
    ``ema = decay * ema + (1 - decay) * model``.

    Args:
        model (torch.nn.Module): The model to track.
        decay (float): EMA smoothing factor. Default: ``0.99``.
        update_every (int): Suggested step interval between ``update`` calls.
            Stored for callers to consult; not used by this class itself.
        device (str): ``"cpu"`` moves the EMA copy to CPU; any other value
            (including the default ``"gpu"``) leaves it on the source
            model's device. Default: ``"gpu"``.
    """

    def __init__(self, model, decay=0.99, update_every=100, device="gpu"):
        super().__init__()
        self.decay = decay
        self.update_every = update_every
        # Deep copy so the EMA weights evolve independently of the live model.
        self.model = copy.deepcopy(model).eval()
        if decay < 1:
            # Accumulate trainable tensors in float32 so repeated blending
            # does not lose precision if the live model trains in fp16/bf16.
            self.model._apply(lambda t: t.float() if t.requires_grad else t)
        # The EMA copy is never trained directly; freeze all parameters.
        for p in self.model.parameters():
            p.requires_grad = False
        if device == "cpu":
            self.model.cpu()

    def forward(self, *args, **kwargs):
        """Run the EMA model on the given inputs."""
        return self.model(*args, **kwargs)

    @torch.no_grad()
    def update(self, model):
        """Blend EMA parameters toward ``model``'s current parameters.

        Only parameters with ``requires_grad`` set on ``model`` are updated;
        buffers are not synchronized. Call frequency is the caller's
        responsibility (see ``update_every``).

        Args:
            model (torch.nn.Module): The live model providing new values.
                Must have parameters in the same order as the tracked model.
        """
        for ema_v, model_v in zip(self.model.parameters(), model.parameters()):
            if not model_v.requires_grad:
                continue
            new_value = model_v.data.float()
            # ``.to`` is a no-op on a matching device; otherwise it yields a
            # temporary that is mutated in place and copied back below.
            value = ema_v.to(device=new_value.device)
            ema_v.copy_(value.mul_(self.decay).add_(new_value, alpha=1 - self.decay))