| import os |
| import torch |
| import torchvision |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| from model import DiffusionModel, UNet |
| from torchvision.datasets import CocoCaptions |
| import argparse |
| from tqdm import tqdm |
|
|
| |
# --- Training hyperparameters ---
IMAGE_SIZE = 256  # square side length images are resized to
BATCH_SIZE = 16   # images per DataLoader batch (captions expand this ~5x in train())
EPOCHS = 50       # full passes over the dataset
LR = 2e-5         # AdamW learning rate
TIMESTEPS = 1000  # number of diffusion steps in the beta schedule
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when available
|
|
def _collate_coco(batch):
    """Collate COCO (image, captions) pairs into (image batch, list of caption lists).

    Defined at module level rather than as a lambda so it can be pickled by
    DataLoader worker processes — with num_workers > 0 under the 'spawn'
    start method (the default on Windows and macOS), a lambda collate_fn
    raises a PicklingError.
    """
    images = torch.stack([item[0] for item in batch])
    captions = [item[1] for item in batch]
    return images, captions


def load_coco_dataset():
    """Build a DataLoader over COCO train2017 image/caption pairs.

    Images are resized to IMAGE_SIZE x IMAGE_SIZE and normalized to [-1, 1].
    Each batch is a tuple:
        images:   Tensor of shape (B, 3, IMAGE_SIZE, IMAGE_SIZE)
        captions: list of length B, each element the list of caption strings
                  for the corresponding image.

    Returns:
        torch.utils.data.DataLoader
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = CocoCaptions(
        root='./train2017',
        annFile='./annotations/captions_train2017.json',
        transform=transform
    )

    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=(DEVICE == "cuda"),  # speeds up host-to-GPU transfers
        collate_fn=_collate_coco,
    )
|
|
def train():
    """Train the caption-conditioned diffusion model on COCO captions.

    For each batch, every image is paired with each of its captions, a random
    diffusion timestep is sampled per pair, and the diffusion loss
    (diffusion.p_losses) is minimized with AdamW. A checkpoint is written
    after every epoch as diffusion_model_epoch_{epoch}.pth.
    """
    model = UNet().to(DEVICE)
    model.train()  # ensure dropout/batch-norm layers are in training mode
    # Linear beta (noise-variance) schedule over TIMESTEPS diffusion steps.
    betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(DEVICE)
    diffusion = DiffusionModel(model, betas, DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    dataloader = load_coco_dataset()

    for epoch in range(EPOCHS):
        pbar = tqdm(dataloader)
        for images, captions in pbar:
            images = images.to(DEVICE)

            # COCO images carry a *variable* number of captions (usually 5,
            # sometimes 6-7). Repeat each image by its actual caption count
            # instead of a hard-coded 5 — a fixed 5 desynchronizes the image
            # batch from the flattened caption list whenever any image in the
            # batch has a different count.
            counts = torch.tensor([len(caps) for caps in captions],
                                  device=DEVICE)
            images = images.repeat_interleave(counts, dim=0)
            flat_captions = [cap for caps in captions for cap in caps]

            # Sample one random diffusion timestep per (image, caption) pair.
            t = torch.randint(0, TIMESTEPS, (images.shape[0],),
                              device=DEVICE).long()

            loss = diffusion.p_losses(images, flat_captions, t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            pbar.set_description(f"Epoch {epoch}, Loss: {loss.item():.4f}")

        # Checkpoint after every epoch so training can be resumed/inspected.
        torch.save(model.state_dict(), f"diffusion_model_epoch_{epoch}.pth")
|
|
# Script entry point: start training when run directly (not on import).
if __name__ == "__main__":
    train()