| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Simple implementation of continuous flow matching schedulers.""" |
|
|
| import dataclasses |
| import math |
|
|
| import numpy as np |
| import torch |
|
|
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.models.modeling_outputs import BaseOutput |
| from diffusers.schedulers.scheduling_utils import SchedulerMixin |
|
|
|
|
@dataclasses.dataclass()
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """Result returned by ``FlowMatchEulerDiscreteScheduler.step``.

    Attributes:
        prev_sample: The updated sample produced by one Euler integration step.
    """

    prev_sample: torch.FloatTensor
|
|
|
|
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """Euler scheduler for continuous flow matching.

    A sample at noise level ``sigma`` is ``sigma * noise + (1 - sigma) * sample`` and
    ``step`` advances with one explicit Euler update
    ``x_next = x + (sigma_next - sigma) * model_output``.

    The instance holds one of two schedules:

    * after ``__init__`` (training): ``timesteps`` and ``sigmas`` are torch tensors of
      length ``num_train_timesteps``, ordered from the highest to the lowest sigma;
    * after ``set_timesteps`` (inference): ``timesteps`` is a float32 numpy array of
      length ``num_inference_steps`` and ``sigmas`` is a Python list of length
      ``num_inference_steps + 1`` whose last entry is ``0``.
    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        # Training timesteps N, N-1, ..., 1 map to sigmas 1, (N-1)/N, ..., 1/N (pre-shift).
        timesteps = np.arange(1, num_train_timesteps + 1, dtype="float32")[::-1]
        sigmas, self._shift = timesteps / num_train_timesteps, shift
        if not use_dynamic_shifting:
            # Static shift. With dynamic shifting, set_timesteps applies time_shift instead.
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = torch.as_tensor(sigmas * num_train_timesteps)
        self.sigmas = torch.as_tensor(sigmas)
        self.sigma_min, self.sigma_max = float(sigmas[-1]), float(sigmas[0])
        # Last timestep/sigma gathered by add_noise.
        self.timestep = self.sigma = None
        self._begin_index = self._step_index = None

    @property
    def shift(self):
        """The value used for shifting."""
        return self._shift

    @property
    def step_index(self):
        """The index counter for current timestep."""
        return self._step_index

    @property
    def begin_index(self):
        """The index for the first timestep."""
        return self._begin_index

    def _sigma_to_t(self, sigma):
        """Scale a noise level by ``num_train_timesteps`` to get a timestep value."""
        return sigma * self.config.num_train_timesteps

    def _init_step_index(self, timestep):
        """Initialize the step counter from ``begin_index``, or by looking up ``timestep``."""
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def time_shift(self, mu: float, sigma: float, t):
        """Return ``e^mu / (e^mu + (1/t - 1)^sigma)``.

        Args:
            mu: Log-scale shift; larger values move the result towards 1.
            sigma: Exponent applied to ``1/t - 1``.
            t: Noise levels in ``(0, 1]``. This is a numpy array when called from
                ``set_timesteps``; a tensor works as well.
        """
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_shift(self, shift: float):
        """Set the shift used by later ``set_timesteps`` calls.

        The training sigmas built in ``__init__`` are not recomputed.
        """
        self._shift = shift

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the position of ``timestep`` in the schedule.

        If the value occurs more than once, the index of the second occurrence is
        returned.

        Args:
            timestep: The timestep value to look up (Python/numpy scalar or tensor).
            schedule_timesteps: Schedule to search. Defaults to ``self.timesteps``.

        Raises:
            IndexError: If ``timestep`` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        # set_timesteps stores a numpy array while __init__ stores a tensor. A 1-D
        # numpy ``nonzero()`` returns a 1-tuple of index arrays, which would make the
        # duplicate check below always pick entry 0 and ``.item()`` fail on repeats.
        # Normalize both sides to tensors so nonzero() yields an (N, 1) index tensor.
        schedule_timesteps = torch.as_tensor(schedule_timesteps)
        timestep = torch.as_tensor(timestep, device=schedule_timesteps.device)
        indices = (schedule_timesteps == timestep).nonzero()
        return indices[1 if len(indices) > 1 else 0].item()

    def sample_timesteps(self, size, device=None):
        """Sample discrete timestep indices for training.

        Indices are drawn from a logit-normal distribution (sigmoid of a standard
        normal), scaled by ``num_train_timesteps`` and truncated to int64. They are
        meant to index the training schedule (see ``add_noise``).

        Args:
            size: Output shape, as accepted by ``torch.normal``.
            device: Device of the returned tensor.
        """
        dist = torch.normal(0, 1, size, device=device).sigmoid_()
        return dist.mul_(self.config.num_train_timesteps).to(dtype=torch.int64)

    def set_timesteps(self, num_inference_steps, mu=None):
        """Set the discrete timesteps used for the diffusion chain (inference).

        Args:
            num_inference_steps: Number of Euler steps.
            mu: Shift passed to ``time_shift``. Required when the scheduler was
                configured with ``use_dynamic_shifting=True``; ignored otherwise.

        Raises:
            ValueError: If dynamic shifting is enabled and ``mu`` is None.
        """
        if self.config.use_dynamic_shifting and mu is None:
            raise ValueError(
                "`mu` must be passed to `set_timesteps` when `use_dynamic_shifting` is True."
            )
        self.num_inference_steps = num_inference_steps
        t_max, t_min = self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min)
        timesteps = np.linspace(t_max, t_min, num_inference_steps, dtype="float32")
        sigmas = timesteps / self.config.num_train_timesteps
        if self.config.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        # Python floats (plus a terminal 0) so that step() computes dt as a plain float.
        self.sigmas = sigmas.tolist() + [0]
        self.timesteps = sigmas * self.config.num_train_timesteps
        self._begin_index = self._step_index = None

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ):
        """Add forward noise to samples for training.

        Gathers ``sigma = sigmas[timesteps]`` and reshapes it to broadcast over the
        trailing dimensions of ``noise``. It then returns
        ``sigma * noise + (1 - sigma) * original_samples``. The gathered timestep and
        sigma are also stored on ``self.timestep`` / ``self.sigma``.

        NOTE: this relies on ``timesteps``/``sigmas`` being the tensors built by
        ``__init__``. After ``set_timesteps``, ``sigmas`` is a list without ``.to``.

        Args:
            original_samples: Clean samples. Their dtype and device are used for sigma.
            noise: Noise tensor of the same shape as ``original_samples``.
            timesteps: Integer index tensor into the training schedule, e.g. from
                ``sample_timesteps``.
        """
        dtype, device = original_samples.dtype, original_samples.device
        self.timestep = self.timesteps.to(device=device)[timesteps]
        self.sigma = self.sigmas.to(device=device, dtype=dtype)[timesteps]
        # Append singleton dims so sigma broadcasts over the per-sample dimensions.
        self.sigma = self.sigma.view(timesteps.shape + (1,) * (noise.dim() - timesteps.dim()))
        return self.sigma * noise + (1.0 - self.sigma) * original_samples

    def scale_noise(self, sample: torch.Tensor, timestep: float, noise: torch.Tensor):
        """Add forward noise to samples for inference.

        Uses the sigma at the current step index, which is initialized from
        ``timestep`` on the first call. The step index is not advanced.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        return sigma * noise + (1.0 - sigma) * sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: float,
        sample: torch.FloatTensor,
        generator: torch.Generator = None,
        return_dict=True,
    ):
        """Advance ``sample`` by one Euler step to the next noise level.

        Computes ``sample + (sigmas[i + 1] - sigmas[i]) * model_output``, where ``i`` is
        the current step index, then increments the index. The result has
        ``model_output``'s dtype.

        Args:
            model_output: Model prediction, used as the derivative of the sample with
                respect to sigma.
            timestep: Current timestep. Only used to initialize the step index on the
                first call.
            sample: Current sample.
            generator: Unused here.
            return_dict: If False, return a 1-tuple instead of the output dataclass.

        Returns:
            ``FlowMatchEulerDiscreteSchedulerOutput`` or ``(prev_sample,)``.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        prev_sample = model_output.mul(dt).add_(sample)
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
|