# World_Model / URSA / diffnext / schedulers / scheduling_cfm.py
# BryanW's picture
# Add files using upload-large-folder tool
# b6ff324 verified
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Simple implementation of continuous flow matching schedulers."""
import dataclasses
import math
import numpy as np
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import BaseOutput
from diffusers.schedulers.scheduling_utils import SchedulerMixin
@dataclasses.dataclass
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
    """Container returned by the scheduler's `step` method.

    It carries a single field, `prev_sample`, which is the sample produced
    by one scheduler update.
    """

    # Sample after one scheduler update has been applied.
    prev_sample: torch.FloatTensor
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """First-order Euler scheduler for continuous flow matching.

    A noisy sample at noise level ``sigma`` is the linear interpolation
    ``sigma * noise + (1 - sigma) * sample``. ``step`` advances a sample along
    the sigma schedule with ``sample + (sigma_next - sigma) * model_output``.

    The schedule has two states:

    * After ``__init__``: ``timesteps`` and ``sigmas`` are ``torch.Tensor``
      objects with one entry per training timestep (consumed by ``add_noise``).
    * After ``set_timesteps``: ``timesteps`` is a NumPy array and ``sigmas``
      is a Python list with a trailing ``0`` (consumed by ``scale_noise`` and
      ``step``).
    """

    order = 1

    @register_to_config
    def __init__(self, num_train_timesteps=1000, shift=1.0, use_dynamic_shifting=False):
        """Build the training schedule.

        Args:
            num_train_timesteps: Number of discrete training timesteps.
            shift: Static shift factor; sigmas are remapped as
                ``shift * s / (1 + (shift - 1) * s)``.
            use_dynamic_shifting: When true, the static shift is skipped and
                ``set_timesteps`` applies ``time_shift`` with a caller-given
                ``mu`` instead.
        """
        # Descending order N, N-1, ..., 1: index 0 is the noisiest level.
        timesteps = np.arange(1, num_train_timesteps + 1, dtype="float32")[::-1]
        sigmas, self._shift = timesteps / num_train_timesteps, shift
        if not use_dynamic_shifting:
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
        self.timesteps = torch.as_tensor(sigmas * num_train_timesteps)
        self.sigmas = torch.as_tensor(sigmas)
        self.sigma_min, self.sigma_max = float(sigmas[-1]), float(sigmas[0])
        self.timestep = self.sigma = None  # Training states.
        self._begin_index = self._step_index = None  # Inference counters.

    @property
    def shift(self):
        """The value used for shifting."""
        return self._shift

    @property
    def step_index(self):
        """The index counter for current timestep."""
        return self._step_index

    @property
    def begin_index(self):
        """The index for the first timestep."""
        return self._begin_index

    def _sigma_to_t(self, sigma):
        """Convert a sigma into a timestep on the training scale."""
        return sigma * self.config.num_train_timesteps

    def _init_step_index(self, timestep):
        """Initialize the step counter for the first inference call.

        Uses ``begin_index`` when it has been set; otherwise looks up the
        position of ``timestep`` in the current schedule.
        """
        if self.begin_index is None:
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """Return ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` elementwise.

        ``set_timesteps`` calls this with a NumPy array for ``t``; only
        elementwise arithmetic is applied to it.
        """
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    def set_shift(self, shift: float):
        """Override the shift used by subsequent ``set_timesteps`` calls."""
        self._shift = shift

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """Return the schedule position whose value equals ``timestep``.

        Args:
            timestep: Timestep value to look up.
            schedule_timesteps: Schedule to search; defaults to
                ``self.timesteps``.

        Returns:
            The matching index as a Python int. When several entries match,
            the second one is returned; with a single match, that one.

        Raises:
            ValueError: If no entry of the schedule equals ``timestep``.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps
        matches = (schedule_timesteps == timestep).nonzero()
        # NumPy's ``nonzero`` returns a tuple of per-dimension index arrays,
        # whereas torch returns an (N, 1) tensor. Normalize both to a flat
        # list so the duplicate check below counts matches, not dimensions.
        if isinstance(matches, tuple):
            matches = matches[0]
        matches = matches.reshape(-1).tolist()
        if not matches:
            raise ValueError(f"Timestep {timestep} is not in the current schedule.")
        return matches[1 if len(matches) > 1 else 0]

    def sample_timesteps(self, size, device=None):
        """Sample the discrete timesteps used for training.

        Draws standard-normal values, squashes them with a sigmoid, scales by
        ``num_train_timesteps`` and truncates to ``int64``. ``add_noise`` uses
        the result as indices into the training schedule.
        """
        dist = torch.normal(0, 1, size, device=device).sigmoid_()
        return dist.mul_(self.config.num_train_timesteps).to(dtype=torch.int64)

    def set_timesteps(self, num_inference_steps, mu=None):
        """Sets the discrete timesteps used for the diffusion chain.

        Args:
            num_inference_steps: Number of points in the inference schedule.
            mu: Shift parameter passed to ``time_shift``; required when the
                scheduler was configured with ``use_dynamic_shifting=True``.

        Raises:
            ValueError: If dynamic shifting is enabled and ``mu`` is ``None``.
        """
        self.num_inference_steps = num_inference_steps
        t_max, t_min = self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min)
        timesteps = np.linspace(t_max, t_min, num_inference_steps, dtype="float32")
        sigmas = timesteps / self.config.num_train_timesteps
        if self.config.use_dynamic_shifting:
            if mu is None:
                # Fail early instead of crashing inside ``math.exp(None)``.
                raise ValueError("`mu` must be provided when `use_dynamic_shifting` is enabled.")
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        # The trailing 0 lets ``step`` read ``sigmas[i + 1]`` on the last step.
        self.sigmas = sigmas.tolist() + [0]
        self.timesteps = sigmas * self.config.num_train_timesteps
        self._begin_index = self._step_index = None

    def add_noise(
        self,
        original_samples: torch.Tensor,
        noise: torch.Tensor,
        timesteps: torch.Tensor,
    ):
        """Add forward noise to samples for training.

        ``timesteps`` is used to index the schedule tensors. The selected
        timestep and sigma are stored on ``self.timestep`` / ``self.sigma``.
        This method calls ``.to(...)`` on ``self.timesteps`` / ``self.sigmas``
        and therefore expects the tensor schedule built by ``__init__``.

        Returns:
            ``sigma * noise + (1 - sigma) * original_samples``.
        """
        dtype, device = original_samples.dtype, original_samples.device
        self.timestep = self.timesteps.to(device=device)[timesteps]
        self.sigma = self.sigmas.to(device=device, dtype=dtype)[timesteps]
        # Append singleton dims so sigma broadcasts against ``noise``.
        self.sigma = self.sigma.view(timesteps.shape + (1,) * (noise.dim() - timesteps.dim()))
        return self.sigma * noise + (1.0 - self.sigma) * original_samples

    def scale_noise(self, sample: torch.Tensor, timestep: float, noise: torch.Tensor):
        """Add forward noise to samples for inference.

        Initializes the step counter from ``timestep`` on first use, then
        mixes ``sample`` and ``noise`` with the sigma at the current step.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self.step_index]
        return sigma * noise + (1.0 - sigma) * sample

    def step(
        self,
        model_output: torch.Tensor,
        timestep: float,
        sample: torch.FloatTensor,
        generator: torch.Generator = None,
        return_dict=True,
    ):
        """Predict the sample from the previous timestep.

        Args:
            model_output: Model prediction that is scaled by the sigma delta.
            timestep: Current timestep; only used to initialize the step
                counter on the first call.
            sample: Current sample.
            generator: Not used by this method.
            return_dict: When false, return a 1-tuple instead of the
                output dataclass.

        Returns:
            ``FlowMatchEulerDiscreteSchedulerOutput`` or ``(prev_sample,)``.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
        # ``mul`` allocates a new tensor, so the in-place ``add_`` leaves
        # ``model_output`` untouched.
        prev_sample = model_output.mul(dt).add_(sample)
        self._step_index += 1
        if not return_dict:
            return (prev_sample,)
        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)