import os

import torch
import torch.optim as optim

from data_loader import MultiStyleDataset
from model import ConditionalGenerator, Discriminator
|
|
|
|
# Train a style-conditioned generator G against a conditional discriminator D
# with least-squares GAN losses, one sample at a time.

# Root directory of the style data; used both to count styles and to build
# the dataset so the two cannot drift apart.
# NOTE(review): ".styles/" is a hidden directory relative to the CWD — confirm
# this is intended and not a typo for "./styles/".
STYLE_DIR = ".styles/"
NUM_EPOCHS = 100
LEARNING_RATE = 2e-4

# One conditioning class per entry in STYLE_DIR.
# NOTE(review): os.listdir counts every entry (including stray files); this
# presumably should match how MultiStyleDataset enumerates style ids — verify.
num_styles = len(os.listdir(STYLE_DIR))
G = ConditionalGenerator(num_styles)
D = Discriminator(num_styles)

opt_G = optim.Adam(G.parameters(), lr=LEARNING_RATE)
opt_D = optim.Adam(D.parameters(), lr=LEARNING_RATE)

dataset = MultiStyleDataset(STYLE_DIR)

for epoch in range(NUM_EPOCHS):
    for img, style_id in dataset:
        # Build the batch-of-one inputs once per sample instead of once per
        # G/D call. Assumes the dataset yields unbatched image tensors and an
        # integer style id — TODO confirm against MultiStyleDataset.
        real = img.unsqueeze(0)
        label = torch.tensor([style_id])

        fake_img = G(real, label)

        # Discriminator step (least-squares GAN): push D(real) -> 1 and
        # D(fake) -> 0. fake_img is detached so this loss does not
        # backpropagate into G.
        real_loss = torch.mean((D(real, label) - 1) ** 2)
        fake_loss = torch.mean(D(fake_img.detach(), label) ** 2)
        loss_D = (real_loss + fake_loss) / 2

        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # Generator step: push D(fake) -> 1, scored by the just-updated D.
        # Gradients this leaves on D's parameters are cleared by
        # opt_D.zero_grad() before the next discriminator step.
        loss_G = torch.mean((D(fake_img, label) - 1) ** 2)
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()