| | import torch |
| | import torch.nn.functional as F |
| | import math |
| |
|
| |
|
"""
This scheduler has three main responsibilities:

1. Setup (``__init__``) - pre-compute the noise schedule.
2. Training (``q_sample``) - add noise to clean images at a given timestep.
3. Generation (``p_sample_text`` + ``sample_text``) - remove noise step by
   step, with optional classifier-free guidance.
"""
| |
|
| |
|
class SimpleDDPMScheduler:
    """DDPM noise scheduler with text conditioning and classifier-free guidance.

    Pre-computes a linear beta schedule plus the derived coefficients used by
    the forward (noising) process in ``q_sample`` and the reverse (denoising)
    process in ``p_sample_text`` / ``sample_text``.
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Pre-compute the noise schedule.

        Args:
            num_timesteps: Number of diffusion steps T.
            beta_start: Beta at the first timestep.
            beta_end: Beta at the last timestep.
        """
        self.num_timesteps = num_timesteps

        # Linear beta schedule and its cumulative products (alpha_bar_t).
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}, with a leading 1.0 so that index 0 is defined.
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)

        # Forward-process coefficients used by q_sample.
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Reverse-process coefficient 1/sqrt(alpha_t); computed once here rather
        # than on every sampling step.
        self.sqrt_recip_alphas = 1.0 / torch.sqrt(self.alphas)

        # Variance of the posterior q(x_{t-1} | x_t, x_0). Index 0 is 0 because
        # alphas_cumprod_prev[0] == 1.
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )

    def q_sample(self, x_start, t, noise=None):
        """Add noise to clean images according to the noise schedule.

        Samples x_t from q(x_t | x_0) in closed form, so training examples can be
        produced at any timestep of the forward process directly.

        Args:
            x_start: Clean images x_0.
            t: Timestep indices, one per batch element.
            noise: Optional pre-sampled noise shaped like ``x_start``; drawn with
                ``torch.randn_like`` when omitted.

        Returns:
            ``sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )

        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_sample_text(self, model, x, t, text_embeddings, guidance_scale=1.0):
        """Sample x_{t-1} from x_t using the model with text conditioning and CFG.

        Args:
            model: Noise-prediction network, called as
                ``model(x, t, text_embeddings)``.
            x: Current noisy images x_t.
            t: Timestep indices, one per batch element (long tensor).
            text_embeddings: Text embeddings for conditioning.
            guidance_scale: Classifier-free guidance scale. 1.0 = no guidance,
                higher = stronger, 0.0 = purely unconditional prediction.

        Returns:
            The sample x_{t-1}, with the same shape as ``x``.
        """
        predicted_noise = model(x, t, text_embeddings)

        # Classifier-free guidance: move from the unconditional prediction toward
        # (or past) the conditional one. At exactly 1.0 the formula reduces to the
        # conditional prediction, so the extra forward pass is skipped.
        if guidance_scale != 1.0:
            # NOTE(review): assumes the model was trained with all-zero embeddings
            # as its "null" condition -- confirm against the training code.
            uncond_embeddings = torch.zeros_like(text_embeddings)
            uncond_noise = model(x, t, uncond_embeddings)
            predicted_noise = uncond_noise + guidance_scale * (predicted_noise - uncond_noise)

        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Posterior mean (Ho et al. 2020, Eq. 11):
        # mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps)
        model_mean = sqrt_recip_alphas_t * (
            x - betas_t * predicted_noise / sqrt_one_minus_alphas_cumprod_t
        )

        is_final_step = t == 0
        # Whole batch at t=0: the last step is deterministic, so no noise is drawn.
        if bool(is_final_step.all()):
            return model_mean

        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        noisy_sample = model_mean + torch.sqrt(posterior_variance_t) * noise
        # Decide per element rather than from t[0] alone, so a batch mixing t=0
        # and t>0 only adds noise where t>0. torch.where also keeps a NaN variance
        # at t=0 (possible when beta_start == 0) out of the result.
        final_mask = is_final_step.reshape(-1, *([1] * (x.dim() - 1)))
        return torch.where(final_mask, model_mean, noisy_sample)

    def sample_text(self, model, shape, text_embeddings, device="cuda", guidance_scale=1.0):
        """Generate samples using DDPM sampling with text conditioning and CFG.

        Args:
            model: The diffusion model
            shape: Output shape (B, C, H, W)
            text_embeddings: Text embeddings for conditioning
            device: Device to use
            guidance_scale: Classifier-free guidance scale (1.0 = no guidance,
                3.0-7.0 typical)

        Returns:
            The final denoised samples, shaped ``shape``, clamped to [-2, 2].
        """
        batch_size = shape[0]

        # Sampling never needs gradients; without no_grad the autograd graph would
        # chain through all num_timesteps model calls and exhaust memory.
        with torch.no_grad():
            img = torch.randn(shape, device=device)

            for i in reversed(range(self.num_timesteps)):
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                img = self.p_sample_text(model, img, t, text_embeddings, guidance_scale)

                # Presumably keeps intermediate samples bounded (e.g. at high
                # guidance scales).
                # NOTE(review): the source's indentation was lost; this clamp may
                # originally have run once after the loop -- confirm.
                img = torch.clamp(img, -2.0, 2.0)

        return img
| |
|
| |
|
def extract(a, t, x_shape):
    """Gather per-timestep coefficients and reshape them for broadcasting.

    Args:
        a: 1-D tensor of coefficients indexed by timestep.
        t: 1-D long tensor of timestep indices, one per batch element.
        x_shape: Shape of the tensor the result will be broadcast against;
            its first dimension is the batch dimension.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dims, on
        ``t``'s device, where ``out[i] == a[t[i]]``.
    """
    batch_size = t.shape[0]
    # Index on a's own device so coefficient tables may live on CPU or GPU
    # (forcing t to CPU fails whenever `a` is on an accelerator).
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
| |
|