| | from abc import ABC, abstractmethod |
| |
|
| | import numpy as np |
| | import torch as th |
| | import torch.distributed as dist |
| |
|
| |
|
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler. Only "uniform" is currently
        supported.
    :param diffusion: the diffusion object to sample for. It is expected to
        expose a ``num_timesteps`` attribute; a plain integer step count is
        also accepted.
    :return: a ScheduleSampler instance.
    :raises NotImplementedError: if ``name`` is not a known sampler.
    """
    if name == "uniform":
        # UniformSampler is built from a step count, not from the diffusion
        # object itself, so unwrap the count here (falling back to the value
        # unchanged when an integer was passed directly).
        num_timesteps = getattr(diffusion, "num_timesteps", diffusion)
        return UniformSampler(num_timesteps)
    raise NotImplementedError(f"unknown schedule sampler: {name}")
| |
|
| |
|
class ScheduleSampler(ABC):
    """
    Abstract distribution over diffusion timesteps, used to lower the
    variance of the training objective.

    The stock sample() performs unbiased importance sampling: every drawn
    timestep is paired with a weight that cancels out its non-uniform
    sampling probability, so the expected objective is left unchanged.
    Subclasses that override sample() may reweight differently and thereby
    change the objective itself.
    """

    @abstractmethod
    def weights(self):
        """
        Return a numpy array holding one weight per diffusion step.

        The weights do not have to sum to one, but they must be positive.
        """

    def sample(self, batch_size, device):
        """
        Draw a batch of timesteps by importance sampling.

        :param batch_size: how many timesteps to draw.
        :param device: the torch device the returned tensors are moved to.
        :return: a pair (timesteps, weights):
                 - timesteps: a long tensor of drawn timestep indices.
                 - weights: a float tensor of per-sample loss scales.
        """
        raw = self.weights()
        probs = raw / np.sum(raw)
        num_steps = len(probs)
        picked = np.random.choice(num_steps, size=(batch_size,), p=probs)
        # Scaling each sample by 1 / (T * p_t) makes the weighted loss an
        # unbiased estimate of the loss averaged uniformly over timesteps.
        scale = 1 / (num_steps * probs[picked])
        timesteps = th.from_numpy(picked).long().to(device)
        return timesteps, th.from_numpy(scale).float().to(device)
| |
|
| |
|
class UniformSampler(ScheduleSampler):
    """
    Sample every diffusion timestep with equal probability.

    Because all weights are equal, the importance weights produced by the
    base-class sample() are (up to float rounding) 1, so the objective is
    unchanged.
    """

    def __init__(self, num_timesteps):
        """
        :param num_timesteps: the number of diffusion steps, or an object
            exposing a ``num_timesteps`` attribute (such as a diffusion
            object).
        :raises ValueError: if the resolved step count is less than 1.
        """
        # Accept the diffusion object as well as a bare step count, so callers
        # that hand over the diffusion object do not crash inside np.ones.
        steps = getattr(num_timesteps, "num_timesteps", num_timesteps)
        if steps < 1:
            # Zero steps would yield an empty/NaN distribution in sample().
            raise ValueError(f"num_timesteps must be positive, got {steps}")
        self._weights = np.ones([steps])

    def weights(self):
        """Return the (unnormalized) all-ones weight vector."""
        return self._weights
| |
|