| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Exponential Moving Average (EMA) of model updates.""" |
|
|
| import copy |
| import torch |
|
|
|
|
class ModelEMA(torch.nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    Holds a deep copy of the tracked model, switched to eval mode and with
    gradients disabled on every parameter. Each call to :meth:`update` blends
    the tracked model's current parameters into the copy. Only parameters
    are averaged; buffers are not touched by :meth:`update`.
    """

    def __init__(self, model, decay=0.99, update_every=100, device="gpu"):
        """Create the averaged copy of ``model``.

        Args:
            model: Module to track. It is deep-copied and never modified here.
            decay: Weight given to the previous average on each update; the
                incoming parameters get ``1 - decay``.
            update_every: Stored on the instance only; this class never reads
                it (presumably the caller's update interval -- confirm
                against the training loop).
            device: When equal to ``"cpu"`` the copy is moved to the CPU. Any
                other value leaves the copy on the device ``model`` was on.
        """
        super().__init__()
        self.decay = decay
        self.update_every = update_every
        self.model = copy.deepcopy(model).eval()

        if decay < 1:
            # Cast trainable tensors to float32; tensors that do not require
            # grad keep their dtype. Presumably this avoids losing the small
            # (1 - decay) increments in reduced precision.
            self.model._apply(lambda t: t.float() if t.requires_grad else t)

        # The averaged copy is never trained directly.
        for param in self.model.parameters():
            param.requires_grad_(False)

        if device == "cpu":
            self.model.cpu()

    def forward(self, *args, **kwargs):
        """Run the averaged model and return its output unchanged."""
        return self.model(*args, **kwargs)

    @torch.no_grad()
    def update(self, model):
        """Blend the parameters of ``model`` into the averaged copy.

        For every positionally paired parameter the new average is
        ``decay * ema + (1 - decay) * param``. Parameters of ``model`` with
        ``requires_grad=False`` are skipped, leaving the corresponding
        averaged parameter as it was.

        Args:
            model: Module whose parameters are iterated in lockstep with the
                averaged copy's. Pairing is by iteration order, and ``zip``
                stops at the shorter of the two sequences.
        """
        keep = self.decay
        pairs = zip(self.model.parameters(), model.parameters())
        for ema_param, src_param in pairs:
            if not src_param.requires_grad:
                continue
            incoming = src_param.data.float()
            # ``.to`` returns ``ema_param`` itself when it already lives on
            # the incoming tensor's device; otherwise it returns a temporary
            # copy on that device.
            averaged = ema_param.to(device=incoming.device)
            averaged.mul_(keep).add_(incoming, alpha=1 - keep)
            if averaged is not ema_param:
                # The maths ran on a temporary copy; write the result back.
                ema_param.copy_(averaged)
|
|