| | """Visualize the diffusion generation process - capture images at each timestep.""" |
| | import torch |
| | import argparse |
| | import os |
| | import matplotlib.pyplot as plt |
| |
|
| | import config |
| | from model import TextConditionedUNet |
| | from scheduler import SimpleDDPMScheduler |
| | from text_encoder import CLIPTextEncoder |
| | from generate import tensor_to_image |
| |
|
| |
|
def sample_with_snapshots(scheduler, model, shape, text_embeddings, device='cuda',
                          guidance_scale=1.0, snapshot_steps=None):
    """Run the reverse diffusion loop, capturing intermediate images.

    Args:
        scheduler: DDPM scheduler exposing ``num_timesteps`` and ``p_sample_text``.
        model: denoising model, passed through to the scheduler.
        shape: output tensor shape, e.g. ``(batch, channels, H, W)``.
        text_embeddings: text conditioning passed through to the scheduler.
        device: device on which the initial noise is allocated.
        guidance_scale: classifier-free guidance weight.
        snapshot_steps: iterable of timesteps at which to snapshot; defaults
            to ~10 evenly spaced steps plus the final step 0.

    Returns:
        Tuple ``(final_image, snapshots)`` where ``snapshots`` maps each
        captured timestep to a detached copy of the image tensor.
    """
    b = shape[0]
    # Start from pure Gaussian noise.
    img = torch.randn(shape, device=device)

    if snapshot_steps is None:
        # max(1, ...) guards against a zero step size when num_timesteps < 10,
        # which would make range() raise ValueError.
        interval = max(1, scheduler.num_timesteps // 10)
        snapshot_steps = list(range(scheduler.num_timesteps - 1, -1, -interval))
        if 0 not in snapshot_steps:
            snapshot_steps.append(0)
    # Set membership is O(1) inside the per-timestep loop.
    snapshot_steps = set(snapshot_steps)

    snapshots = {}

    for i in reversed(range(scheduler.num_timesteps)):
        t = torch.full((b,), i, device=device, dtype=torch.long)
        img = scheduler.p_sample_text(model, img, t, text_embeddings, guidance_scale)
        # Clamp to keep pixel values from drifting over many denoising steps.
        img = torch.clamp(img, -2.0, 2.0)

        if i in snapshot_steps:
            snapshots[i] = img.clone().detach()

    return img, snapshots
| |
|
| |
|
def plot_denoising_process(snapshots, prompt, output_path, sample_idx=0):
    """Render the captured snapshots left-to-right: pure noise -> final image.

    Args:
        snapshots: mapping of timestep -> image tensor batch.
        prompt: text prompt, shown in the figure title.
        output_path: destination file for the saved figure.
        sample_idx: which sample of the batch to display.
    """
    ordered_steps = sorted(snapshots, reverse=True)
    n_frames = len(ordered_steps)

    fig, axes = plt.subplots(1, n_frames, figsize=(2.5 * n_frames, 3))
    # plt.subplots returns a bare Axes (not an array) when there is one column.
    axes = [axes] if n_frames == 1 else axes

    fig.suptitle(f'Denoising Process: "{prompt}"', fontsize=12, fontweight='bold')

    for ax, step in zip(axes, ordered_steps):
        ax.imshow(tensor_to_image(snapshots[step][sample_idx]), cmap='gray')
        ax.axis('off')
        # Step 0 is the finished sample; every other frame is labeled by timestep.
        ax.set_title(f't={step}' if step > 0 else 'Final', fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()
| |
|
| |
|
def main():
    """CLI entry point: load a checkpoint and visualize denoising for a prompt."""
    parser = argparse.ArgumentParser(description='Visualize denoising process')
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--prompt', type=str, default="a drawing of a cat")
    parser.add_argument('--guidance-scale', type=float, default=config.CFG_GUIDANCE_SCALE)
    parser.add_argument('--num-steps', type=int, default=10,
                        help='Number of snapshots to capture')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()

    # Fall back to CPU when CUDA was requested but is unavailable.
    if args.device == 'cuda' and not torch.cuda.is_available():
        args.device = 'cpu'

    checkpoint = torch.load(args.checkpoint, map_location=args.device)
    ckpt_config = checkpoint.get('config', {})

    # Rebuild the model with the dimensions it was trained with, falling back
    # to the current config when the checkpoint predates saved configs.
    model = TextConditionedUNet(text_dim=ckpt_config.get('text_dim', config.TEXT_DIM)).to(args.device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    text_encoder = CLIPTextEncoder(
        model_name=ckpt_config.get('clip_model', config.CLIP_MODEL), freeze=True
    ).to(args.device)
    text_encoder.eval()

    scheduler = SimpleDDPMScheduler(config.TIMESTEPS)

    # BUG FIX: --num-steps was parsed but never used, so the flag had no
    # effect. Derive the snapshot timesteps from it here and pass them down.
    interval = max(1, config.TIMESTEPS // max(1, args.num_steps))
    snapshot_steps = list(range(config.TIMESTEPS - 1, -1, -interval))
    if 0 not in snapshot_steps:
        snapshot_steps.append(0)

    with torch.no_grad():
        text_embedding = text_encoder(args.prompt)
        shape = (1, 1, config.IMAGE_SIZE, config.IMAGE_SIZE)

        _, snapshots = sample_with_snapshots(
            scheduler, model, shape, text_embedding, args.device,
            args.guidance_scale, snapshot_steps
        )

    os.makedirs("outputs", exist_ok=True)
    # Keep only alphanumerics, spaces and underscores so the prompt is a safe filename.
    safe_prompt = "".join(c if c.isalnum() or c in " _" else "" for c in args.prompt)[:50]
    output_path = f"outputs/denoising_{safe_prompt}.png"

    plot_denoising_process(snapshots, args.prompt, output_path)
    print(f"✅ Saved visualization: {output_path}")
| |
|
| |
|
# Script entry point: only run when executed directly, not when imported.
if __name__ == "__main__":
    main()
| |
|