| |
|
|
import torch
from torch.nn.modules.batchnorm import _BatchNorm
|
|
|
|
class EMAModel:
    """
    Exponential Moving Average of model weights.

    Maintains a separate "averaged" copy of a model whose parameters are
    updated as a decayed running average of a live model's parameters, with
    a warmup schedule for the decay factor.
    """

    def __init__(
        self,
        model,
        update_after_step=0,
        inv_gamma=1.0,
        power=2 / 3,
        min_value=0.0,
        max_value=0.9999
    ):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).

        Args:
            model: module whose parameters will hold the EMA weights.
            update_after_step (int): number of optimization steps to skip before
                the EMA starts accumulating. Default: 0.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
        """
        # NOTE(review): `model` is stored by reference, NOT copied. The caller
        # must pass a dedicated copy (e.g. copy.deepcopy(model)), otherwise the
        # EMA weights alias the live model's weights. Kept as-is to preserve
        # the existing interface — confirm against call sites.
        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0          # decay used on the most recent step()
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 while still inside the warmup-delay window, otherwise the
        warmup curve ``1 - (1 + step/inv_gamma)**-power`` clamped to
        ``[min_value, max_value]``.
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        # Early-out before evaluating the warmup curve (the original computed
        # the curve value and then discarded it when step <= 0).
        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma) ** -self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        """
        Blend ``new_model``'s parameters into the averaged model in place and
        advance the internal step counter.

        BatchNorm parameters and parameters that don't require grad are copied
        verbatim instead of averaged.
        """
        self.decay = self.get_decay(self.optimization_step)

        # Walk both module trees in lockstep; recurse=False keeps each
        # parameter paired with the module that owns it so the BatchNorm
        # check below applies to the right module.
        for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()):
            for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)):
                if isinstance(param, dict):
                    raise RuntimeError('Dict parameter not supported')

                if isinstance(module, _BatchNorm) or not param.requires_grad:
                    # Copy verbatim: BatchNorm affine params and frozen params
                    # are not averaged. (Merged two previously duplicated
                    # branches with identical bodies.)
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)

        self.optimization_step += 1
|
|