| """Facilities for reporting and collecting training statistics across
|
| multiple processes and devices. The interface is designed to minimize
|
| synchronization overhead as well as the amount of boilerplate in user
|
| code."""
|
|
|
| import re
|
| import numpy as np
|
| import torch
|
| import dnnlib
|
|
|
| from . import misc

_num_moments   = 3              # [num_scalars, sum_of_scalars, sum_of_squares]
_reduce_dtype  = torch.float32  # Data type to use for initial per-tensor reduction.
_counter_dtype = torch.float64  # Data type to use for the internal counters.
_rank          = 0              # Rank of the current process.
_sync_device   = None           # Device to use for multiprocess communication. None = single-process.
_sync_called   = False          # Has _sync() been called yet?
_counters      = dict()         # Running counters on each device, updated by report(): name => device => torch.Tensor
_cumulative    = dict()         # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor
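
# Each statistic is tracked as raw moments [count, sum, sum_of_squares], from
# which mean and std are recovered later. An illustrative sketch with the
# scalars 1, 2, 3 (values assumed for the example):
#
#   moments = [3, 6.0, 14.0]            # count, sum, sum of squares
#   mean    = 6.0 / 3                   # = 2.0
#   std     = (14.0/3 - 2.0**2) ** 0.5  # = sqrt(2/3) ~= 0.816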


def init_multiprocessing(rank, sync_device):
    r"""Initializes `torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after
    `torch.distributed.init_process_group()` and before `Collector.update()`.
    The call is not necessary if multi-process collection is not needed.

    Args:
        rank:           Rank of the current process.
        sync_device:    PyTorch device to use for inter-process
                        communication, or None to disable multi-process
                        collection. Typically `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device
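
# A minimal sketch of the intended call order, assuming `torch.distributed`
# has already been set up elsewhere (the backend choice is illustrative):
#
#   torch.distributed.init_process_group(backend='nccl', ...)
#   rank = torch.distributed.get_rank()
#   init_multiprocessing(rank=rank, sync_device=torch.device('cuda', rank))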


@misc.profiled_function
def report(name, value):
    r"""Broadcasts the given set of scalars to all interested instances of
    `Collector`, across device and process boundaries.

    This function is expected to be extremely cheap and can be safely
    called from anywhere in the training loop, loss function, or inside a
    `torch.nn.Module`.

    Warning: The current implementation expects the set of unique names to
    be consistent across processes. Please make sure that `report()` is
    called at least once for each unique name by each process, and in the
    same order. If a given process has no scalars to broadcast, it can do
    `report(name, [])` (empty list).

    Args:
        name:   Arbitrary string specifying the name of the statistic.
                Averages are accumulated separately for each unique name.
        value:  Arbitrary set of scalars. Can be a list, tuple,
                NumPy array, PyTorch tensor, or Python scalar.

    Returns:
        The same `value` that was passed in.
    """
    if name not in _counters:
        _counters[name] = dict()

    elems = torch.as_tensor(value)
    if elems.numel() == 0:
        return value

    elems = elems.detach().flatten().to(_reduce_dtype)
    moments = torch.stack([
        torch.ones_like(elems).sum(),
        elems.sum(),
        elems.square().sum(),
    ])
    assert moments.ndim == 1 and moments.shape[0] == _num_moments
    moments = moments.to(_counter_dtype)

    device = moments.device
    if device not in _counters[name]:
        _counters[name][device] = torch.zeros_like(moments)
    _counters[name][device].add_(moments)
    return value
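
# A minimal sketch of reporting from a training step; the statistic name and
# surrounding variables (`pred`, `target`) are illustrative assumptions:
#
#   loss = (pred - target).square().mean()
#   report('Loss/train', loss)
#   loss.backward()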


def report0(name, value):
    r"""Broadcasts the given set of scalars by the first process (`rank = 0`),
    but ignores any scalars provided by the other processes.
    See `report()` for further details.
    """
    report(name, value if _rank == 0 else [])
    return value
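
# Sketch: rank 0 broadcasts the value while the other ranks implicitly
# contribute an empty list, keeping the per-name call pattern consistent
# across processes (the statistic name and `epoch_time` are assumptions):
#
#   report0('Timing/epoch_sec', epoch_time)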


class Collector:
    r"""Collects the scalars broadcasted by `report()` and `report0()` and
    computes their long-term averages (mean and standard deviation) over
    user-defined periods of time.

    The averages are first collected into internal counters that are not
    directly visible to the user. They are then copied to the user-visible
    state as a result of calling `update()` and can then be queried using
    `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the
    internal counters for the next round, so that the user-visible state
    effectively reflects averages collected between the last two calls to
    `update()`.

    Args:
        regex:          Regular expression defining which statistics to
                        collect. The default is to collect everything.
        keep_previous:  Whether to retain the previous averages if no
                        scalars were collected on a given round
                        (default: True).
    """
    def __init__(self, regex='.*', keep_previous=True):
        self._regex = re.compile(regex)
        self._keep_previous = keep_previous
        self._cumulative = dict()
        self._moments = dict()
        self.update()
        self._moments.clear() # Discard any stats accumulated before construction.

    def names(self):
        r"""Returns the names of all statistics broadcasted so far that
        match the regular expression specified at construction time.
        """
        return [name for name in _counters if self._regex.fullmatch(name)]

    def update(self):
        r"""Copies current values of the internal counters to the
        user-visible state and resets them for the next round.

        If `keep_previous=True` was specified at construction time, the
        operation is skipped for statistics that have received no scalars
        since the last update, retaining their previous averages.

        This method performs a number of GPU-to-CPU transfers and one
        `torch.distributed.all_reduce()`. It is intended to be called
        periodically in the main training loop, typically once every
        N training steps.
        """
        if not self._keep_previous:
            self._moments.clear()
        for name, cumulative in _sync(self.names()):
            if name not in self._cumulative:
                self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
            delta = cumulative - self._cumulative[name]
            self._cumulative[name].copy_(cumulative)
            if float(delta[0]) != 0:
                self._moments[name] = delta

    def _get_delta(self, name):
        r"""Returns the raw moments that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        assert self._regex.fullmatch(name)
        if name not in self._moments:
            self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        return self._moments[name]

    def num(self, name):
        r"""Returns the number of scalars that were accumulated for the given
        statistic between the last two calls to `update()`, or zero if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        return int(delta[0])

    def mean(self, name):
        r"""Returns the mean of the scalars that were accumulated for the
        given statistic between the last two calls to `update()`, or NaN if
        no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0:
            return float('nan')
        return float(delta[1] / delta[0])

    def std(self, name):
        r"""Returns the standard deviation of the scalars that were
        accumulated for the given statistic between the last two calls to
        `update()`, or NaN if no scalars were collected.
        """
        delta = self._get_delta(name)
        if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
            return float('nan')
        if int(delta[0]) == 1:
            return float(0)
        mean = float(delta[1] / delta[0])
        raw_var = float(delta[2] / delta[0])
        return np.sqrt(max(raw_var - np.square(mean), 0)) # Var[x] = E[x^2] - E[x]^2, clamped to absorb rounding error.

    def as_dict(self):
        r"""Returns the averages accumulated between the last two calls to
        `update()` as a `dnnlib.EasyDict`. The contents are as follows:

            dnnlib.EasyDict(
                NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
                ...
            )
        """
        stats = dnnlib.EasyDict()
        for name in self.names():
            stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
        return stats

    def __getitem__(self, name):
        r"""Convenience getter.
        `collector[name]` is a synonym for `collector.mean(name)`.
        """
        return self.mean(name)
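
# A minimal sketch of driving a `Collector` from the main training loop; the
# statistic names, interval, and loop structure are illustrative assumptions:
#
#   collector = Collector(regex='Loss/.*')
#   for step in range(num_steps):
#       ...                            # training code calls report('Loss/train', ...)
#       if step % 100 == 0:
#           collector.update()         # sync across devices/processes, start a new round
#           print(collector.as_dict()) # {'Loss/train': EasyDict(num=..., mean=..., std=...)}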


def _sync(names):
    r"""Synchronize the global cumulative counters across devices and
    processes. Called internally by `Collector.update()`.
    """
    if len(names) == 0:
        return []
    global _sync_called
    _sync_called = True

    # Collect deltas within the current rank, zeroing each per-device counter.
    deltas = []
    device = _sync_device if _sync_device is not None else torch.device('cpu')
    for name in names:
        delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
        for counter in _counters[name].values():
            delta.add_(counter.to(device))
            counter.copy_(torch.zeros_like(counter))
        deltas.append(delta)
    deltas = torch.stack(deltas)

    # Sum deltas across ranks.
    if _sync_device is not None:
        torch.distributed.all_reduce(deltas)

    # Update cumulative values on the CPU.
    deltas = deltas.cpu()
    for idx, name in enumerate(names):
        if name not in _cumulative:
            _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
        _cumulative[name].add_(deltas[idx])

    # Return name-value pairs.
    return [(name, _cumulative[name]) for name in names]