import gc
import glob
import os

import torch
import torchvision
import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor


class TestModel(torch.nn.Module):
    """Small residual CNN that maps a 3-channel image to a 3-channel image.

    The network lifts the input to 16 feature channels, applies batch
    normalisation, runs three residual 3x3 convolutions, normalises again and
    projects back to 3 channels.  Every convolution uses a 3x3 kernel with
    stride 1 and padding 1, so the spatial size of the input is preserved.
    The output is clamped to ``[-1, 1]``.
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_channels, out_channels):
            # Bias-free 3x3 convolution; padding 1 keeps height and width.
            return torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )

        width = 16

        # Attribute names and creation order are part of the checkpoint
        # layout (state-dict keys, parameter order), so they stay fixed.
        self.start = conv3x3(3, width)
        self.conv1 = conv3x3(width, width)
        self.conv2 = conv3x3(width, width)
        self.conv3 = conv3x3(width, width)
        self.final = conv3x3(width, 3)
        self.bn1 = torch.nn.BatchNorm2d(width)
        self.bn2 = torch.nn.BatchNorm2d(width)

    def forward(self, x):
        """Return the network output for batch ``x``, clamped to ``[-1, 1]``."""
        features = self.bn1(self.start(x))

        # Three identical residual steps: features <- conv(features) + features.
        for residual_conv in (self.conv1, self.conv2, self.conv3):
            features = residual_conv(features) + features

        projected = self.final(self.bn2(features))
        return torch.clamp(projected, -1, 1)


class DS(Dataset):
    """Image dataset over the files matched by ``./15k/*``.

    Each item is a random 256x256 RGB crop converted to a float tensor and
    rescaled to ``[-1, 1]``.  ``RandomCrop`` is used without padding, so
    images smaller than 256 pixels on either side are expected to fail in
    the transform.
    """

    def __init__(self):
        super().__init__()
        # glob returns files in arbitrary order; sorting makes index 0 (the
        # preview image used by ``gettest``) the same file on every run.
        self.g = sorted(glob.glob("./15k/*"))
        self.trans = transforms.Compose([
            transforms.RandomCrop((256, 256)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of matched files."""
        return len(self.g)

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return it as an RGB PIL image."""
        # convert() loads the pixel data, so the file can be closed right after.
        with Image.open(path) as img:
            return img.convert("RGB")

    @staticmethod
    def _normalize(x):
        """Map a tensor with values in ``[0, 1]`` to ``[-1, 1]``."""
        # ToTensor()/to_tensor() already divide 8-bit pixels by 255, so the
        # input here is in [0, 1]; dividing by 127.5 again would collapse
        # the data into roughly [-1, -0.992].
        return x * 2 - 1

    def __getitem__(self, idx):
        """Return a random 256x256 crop of image ``idx`` scaled to ``[-1, 1]``."""
        image = self._load_rgb(self.g[idx])
        return self._normalize(self.trans(image))

    def gettest(self):
        """Return the first image, uncropped, scaled to ``[-1, 1]``."""
        image = self._load_rgb(self.g[0])
        return self._normalize(to_tensor(image))


def main():
    """Train ``TestModel`` to reconstruct its input and save preview images.

    Runs 10 epochs of Adam (lr 1e-4) with an MSE reconstruction loss over
    ``DS`` in shuffled batches of 64, on CUDA when available.  A preview of
    the model output for the dataset's test image is written to
    ``./test/test.png`` before training and to ``./test/<epoch>.png`` after
    each epoch.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 64
    epochs = 10
    preview_dir = "./test"

    model = TestModel().to(device)
    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss().to(device)
    optim = torch.optim.Adam(model.parameters(recurse=True), lr=1e-4)
    model.train()

    # Image.save does not create missing directories.
    os.makedirs(preview_dir, exist_ok=True)

    def log(l):
        """Save the model output for the test image as ``./test/<l>.png``."""
        model.eval()
        try:
            # Inference only: skip autograd bookkeeping for the full-size image.
            with torch.no_grad():
                x = dataset.gettest().to(device).unsqueeze(0)
                out = model(x)
            # The model output is clamped to [-1, 1]; map it to [0, 1] for PIL.
            preview = ((out[0] + 1) / 2).cpu()
            to_pil_image(preview).save(f"{preview_dir}/{l}.png")
        finally:
            # Always return to training mode, even if saving fails.
            model.train()

    log("test")

    for i in range(epochs):
        last_loss = None
        for j, k in enumerate(tqdm.tqdm(dataloader)):
            k = k.to(device)
            optim.zero_grad()
            out = model(k)
            # Autoencoder-style objective: the target is the input itself.
            loss = criterion(out, k)
            loss.backward()
            optim.step()
            # Keep only the value, not the graph, for end-of-epoch reporting.
            last_loss = loss.detach()
            if j % 100 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        print("EPOCH", i)
        print("LAST LOSS", None if last_loss is None else last_loss.item())
        log(i)


# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()