| """ |
| Gaussian Diffusion (DDPM) framework for PDE next-frame prediction. |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
class GaussianDiffusion(nn.Module):
    """Denoising diffusion (DDPM) wrapper for conditional next-frame prediction.

    The wrapped ``model`` is trained to predict the noise ("epsilon") that was
    mixed into a clean target frame, given a conditioning frame and a diffusion
    timestep. A linear beta schedule is precomputed once at construction and
    stored in buffers so it follows the module across devices.

    Args:
        model: epsilon-prediction network, called as ``model(x, t, cond=...)``.
        timesteps: length of the diffusion chain.
        beta_start: smallest beta in the linear schedule.
        beta_end: largest beta in the linear schedule.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps

        # Linear schedule and the cumulative products used throughout.
        beta = torch.linspace(beta_start, beta_end, timesteps)
        alpha = 1.0 - beta
        abar = torch.cumprod(alpha, dim=0)
        # alpha_bar shifted right by one step, with alpha_bar_{-1} defined as 1.
        abar_prev = torch.cat([torch.ones(1, dtype=abar.dtype), abar[:-1]])

        # Buffers (not parameters): move with .to()/.cuda() but take no grad.
        self.register_buffer("betas", beta)
        self.register_buffer("alphas", alpha)
        self.register_buffer("alpha_bar", abar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(abar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - abar))
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alpha))
        # Variance of q(x_{t-1} | x_t, x_0); zero at t = 0 by construction.
        self.register_buffer(
            "posterior_variance",
            beta * (1.0 - abar_prev) / (1.0 - abar),
        )

    def q_sample(self, x0, t, noise=None):
        """Diffuse ``x0`` forward to timestep ``t``.

        Args:
            x0: clean frames [B, C, H, W].
            t: per-sample timesteps [B].
            noise: optional epsilon; drawn with ``randn_like`` when omitted.

        Returns:
            tuple ``(x_t, noise)``: the noised sample and the epsilon used.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Broadcast the per-sample schedule coefficients over C, H, W.
        coef_signal = self.sqrt_alpha_bar[t][:, None, None, None]
        coef_noise = self.sqrt_one_minus_alpha_bar[t][:, None, None, None]
        return coef_signal * x0 + coef_noise * noise, noise

    def training_loss(self, x_target, x_cond):
        """Single-draw DDPM training objective (epsilon MSE).

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            scalar MSE between predicted and true noise.
        """
        batch = x_target.shape[0]
        # One uniformly random timestep per sample in the batch.
        t = torch.randint(0, self.T, (batch,), device=x_target.device)
        noise = torch.randn_like(x_target)
        x_noisy, _ = self.q_sample(x_target, t, noise)
        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Ancestral DDPM sampling conditioned on ``x_cond``.

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target; defaults to ``x_cond.shape``.

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        shape = x_cond.shape if shape is None else shape

        # Start from pure Gaussian noise and walk the chain backwards.
        x = torch.randn(shape, device=device)
        for step in reversed(range(self.T)):
            t = torch.full((shape[0],), step, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)

            alpha = self.alphas[step]
            alpha_bar = self.alpha_bar[step]
            beta = self.betas[step]

            # Posterior mean of q(x_{t-1} | x_t, predicted x0).
            mean = (1.0 / alpha.sqrt()) * (x - beta / (1 - alpha_bar).sqrt() * eps)

            if step == 0:
                # Last step: return the mean, no noise is added back.
                x = mean
            else:
                x = mean + self.posterior_variance[step].sqrt() * torch.randn_like(x)

        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0):
        """DDIM sampling over a sub-sampled timestep grid.

        Args:
            x_cond: condition [B, C_cond, H, W].
            shape: target shape; defaults to ``x_cond.shape``.
            steps: number of DDIM steps (typically << T for speed).
            eta: stochasticity knob (0 = deterministic DDIM, 1 = DDPM-like).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        shape = x_cond.shape if shape is None else shape

        # Evenly spaced timesteps covering 0..T-1, traversed high-to-low.
        grid = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        grid = grid.flip(0)

        x = torch.randn(shape, device=device)
        for t_cur, t_next in zip(grid[:-1], grid[1:]):
            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)

            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]

            # Implied clean frame; clamped to keep iterates in a sane range.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            x0_pred = x0_pred.clamp(-5, 5)

            # DDIM update: deterministic direction plus optional fresh noise.
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            dir_xt = (1 - ab_next - sigma**2).sqrt() * eps

            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)

        return x
|
|