| import os |
| from io import BytesIO |
| from pathlib import Path |
|
|
| import lmdb |
| from PIL import Image |
| from torch.utils.data import Dataset |
| from torchvision import transforms |
| from torchvision.datasets import CIFAR10, LSUNClass |
| import torch |
| import pandas as pd |
|
|
| import torchvision.transforms.functional as Ftrans |
|
|
|
|
class ImageDataset(Dataset):
    """Dataset of image files discovered under a folder.

    Each item is a dict ``{'img': ..., 'index': ...}`` where ``img`` is the
    RGB-converted image after the configured transform pipeline.

    Args:
        folder: root directory that is searched for images.
        image_size: size given to ``Resize`` and ``CenterCrop``.
        exts: iterable of file extensions (without the dot) to collect.
        do_augment: append a random horizontal flip.
        do_transform: append ``ToTensor``.
        do_normalize: append ``Normalize`` with 0.5 / 0.5 on three channels.
        sort_names: sort the collected relative paths.
        has_subdir: search recursively instead of only the top level.
    """
    def __init__(
        self,
        folder,
        image_size,
        exts=('jpg', ),
        do_augment: bool = True,
        do_transform: bool = True,
        do_normalize: bool = True,
        sort_names=False,
        has_subdir: bool = True,
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size

        # '**/' makes the glob recursive; without it only direct children of
        # `folder` are matched. Paths are stored relative to `folder`.
        prefix = '**/' if has_subdir else ''
        self.paths = [
            p.relative_to(folder) for ext in exts
            for p in Path(f'{folder}').glob(f'{prefix}*.{ext}')
        ]
        if sort_names:
            self.paths = sorted(self.paths)

        transform = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        """Return the number of collected image paths."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load, convert to RGB and transform the image at ``index``."""
        path = os.path.join(self.folder, self.paths[index])
        # convert() yields an independent image, so the source file can be
        # closed deterministically as soon as the conversion is done.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}
|
|
|
|
class SubsetDataset(Dataset):
    """View over the first ``size`` items of another dataset.

    Args:
        dataset: the wrapped dataset (must support ``len`` and indexing).
        size: number of leading items to expose.

    Raises:
        ValueError: if ``size`` exceeds ``len(dataset)``.
    """
    def __init__(self, dataset, size):
        # Explicit check instead of `assert`, which disappears under `-O`.
        if len(dataset) < size:
            raise ValueError(
                f'size ({size}) exceeds dataset length ({len(dataset)})')
        self.dataset = dataset
        self.size = size

    def __len__(self):
        """Return the size of the subset."""
        return self.size

    def __getitem__(self, index):
        """Return item ``index`` of the subset.

        Negative indices count from the end of the subset. Out-of-range
        indices raise ``IndexError`` so that plain iteration terminates.
        """
        if index < 0:
            # Resolve relative to the subset, not the wrapped dataset.
            index += self.size
        if not 0 <= index < self.size:
            raise IndexError('SubsetDataset index out of range')
        return self.dataset[index]
|
|
|
|
class BaseLMDB(Dataset):
    """Read-only image store backed by an LMDB environment.

    Images are looked up under the key
    ``'{original_resolution}-{index zero-padded to zfill digits}'`` and the
    number of entries is read from the ``'length'`` key.

    Args:
        path: location of the LMDB environment.
        original_resolution: resolution prefix used in the image keys.
        zfill: width the index is zero-padded to inside the key.

    Raises:
        IOError: if the environment cannot be opened or has no ``'length'``
            entry.
    """
    def __init__(self, path, original_resolution, zfill: int = 5):
        self.original_resolution = original_resolution
        self.zfill = zfill
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
            # Fail with a clear message instead of an AttributeError on None.
            if raw_length is None:
                raise IOError('lmdb dataset has no "length" entry', path)
            self.length = int(raw_length.decode('utf-8'))

    def __len__(self):
        """Return the entry count stored under the ``'length'`` key."""
        return self.length

    def __getitem__(self, index):
        """Return the image stored for ``index`` as a PIL image.

        Raises:
            IndexError: if no entry exists for the derived key.
        """
        key = f'{self.original_resolution}-{str(index).zfill(self.zfill)}'.encode(
            'utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)

        # A missing key comes back as None; report it directly rather than
        # letting the image decoder fail on an empty buffer.
        if img_bytes is None:
            raise IndexError(f'no entry for key {key!r} in lmdb dataset')

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        return img
|
|
|
|
def make_transform(
    image_size,
    flip_prob=0.5,
    crop_d2c=False,
):
    """Compose the image preprocessing pipeline.

    With ``crop_d2c`` the fixed D2C crop is applied before resizing;
    otherwise the image is resized and center-cropped. In both cases a
    random horizontal flip (probability ``flip_prob``), ``ToTensor`` and a
    0.5 / 0.5 three-channel ``Normalize`` follow.
    """
    if crop_d2c:
        geometry = [d2c_crop(), transforms.Resize(image_size)]
    else:
        geometry = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
    common = [
        transforms.RandomHorizontalFlip(p=flip_prob),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return transforms.Compose(geometry + common)
|
|
|
|
class FFHQlmdb(Dataset):
    """FFHQ images read from an LMDB store, with an optional split.

    Args:
        path: location of the LMDB environment.
        image_size: size given to ``Resize``.
        original_resolution: resolution prefix of the stored keys.
        split: ``None`` for all entries, ``'test'`` for the first 10000
            entries, ``'train'`` for everything after them.
        as_tensor: append ``ToTensor``.
        do_augment: append a random horizontal flip.
        do_normalize: append ``Normalize`` with 0.5 / 0.5 on three channels.
        **kwargs: accepted and ignored.

    Raises:
        NotImplementedError: for any other ``split`` value.
    """
    def __init__(self,
                 path=os.path.expanduser('datasets/ffhq256.lmdb'),
                 image_size=256,
                 original_resolution=256,
                 split=None,
                 as_tensor: bool = True,
                 do_augment: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=5)
        self.length = len(self.data)

        if split is None:
            self.offset = 0
        elif split == 'train':
            # Everything after the first 10000 entries.
            self.length = self.length - 10000
            self.offset = 10000
        elif split == 'test':
            # The first 10000 entries.
            self.length = 10000
            self.offset = 0
        else:
            raise NotImplementedError()

        transform = [
            transforms.Resize(image_size),
        ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if as_tensor:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        """Return the number of items in the selected split."""
        return self.length

    def __getitem__(self, index):
        """Return the transformed image for ``index`` within the split.

        The returned ``'index'`` is the offset-adjusted position in the
        underlying store, not the split-local index.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        # Explicit bounds check (not `assert`): it survives `-O`, lets plain
        # iteration stop, and keeps negative indices from leaking into the
        # other split through the offset.
        if not 0 <= index < self.length:
            raise IndexError('FFHQlmdb index out of range')
        index = index + self.offset
        img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}
|
|
|
|
class Crop:
    """Callable transform that crops a fixed rectangle out of an image.

    ``x1`` / ``y1`` are forwarded as the first two position arguments of
    ``Ftrans.crop`` and ``x2 - x1`` / ``y2 - y1`` as its two extents.
    """
    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2 = x1, x2
        self.y1, self.y2 = y1, y2

    def __call__(self, img):
        """Crop ``img`` to the configured rectangle."""
        extent_x = self.x2 - self.x1
        extent_y = self.y2 - self.y1
        return Ftrans.crop(img, self.x1, self.y1, extent_x, extent_y)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return (f"{cls_name}(x1={self.x1}, x2={self.x2}, "
                f"y1={self.y1}, y2={self.y2})")
|
|
|
|
def d2c_crop():
    """Build the fixed 128x128 ``Crop`` centred on (cx=89, cy=121).

    Note that the x-range of the ``Crop`` is derived from ``cy`` and the
    y-range from ``cx``.
    """
    cx, cy = 89, 121
    half = 64
    return Crop(
        x1=cy - half,
        x2=cy + half,
        y1=cx - half,
        y2=cx + half,
    )
|
|
|
|
class CelebAlmdb(Dataset):
    """CelebA images read from an LMDB store; also supports the D2C crop.

    Args:
        path: location of the LMDB environment.
        image_size: target size for ``Resize`` (and ``CenterCrop`` when the
            D2C crop is not used).
        original_resolution: resolution prefix of the stored keys.
        split: only ``None`` is supported.
        as_tensor: append ``ToTensor``.
        do_augment: append a random horizontal flip.
        do_normalize: append ``Normalize`` with 0.5 / 0.5 on three channels.
        crop_d2c: apply the fixed D2C crop before resizing instead of
            resize + center crop.
        **kwargs: accepted and ignored.

    Raises:
        NotImplementedError: if ``split`` is not ``None``.
    """
    def __init__(self,
                 path,
                 image_size,
                 original_resolution=128,
                 split=None,
                 as_tensor: bool = True,
                 do_augment: bool = True,
                 do_normalize: bool = True,
                 crop_d2c: bool = False,
                 **kwargs):
        self.original_resolution = original_resolution
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)
        self.crop_d2c = crop_d2c

        if split is None:
            self.offset = 0
        else:
            raise NotImplementedError()

        if crop_d2c:
            transform = [
                d2c_crop(),
                transforms.Resize(image_size),
            ]
        else:
            transform = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]

        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if as_tensor:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

    def __len__(self):
        """Return the number of stored images."""
        return self.length

    def __getitem__(self, index):
        """Return the transformed image for ``index``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        # Explicit bounds check (not `assert`) so it survives `-O` and lets
        # plain iteration over the dataset terminate.
        if not 0 <= index < self.length:
            raise IndexError('CelebAlmdb index out of range')
        index = index + self.offset
        img = self.data[index]
        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index}
|
|
|
|
class Horse_lmdb(Dataset):
    """Horse images read from an LMDB store (keys zero-padded to 7 digits).

    The pipeline is resize + center crop to ``image_size``, then optionally
    a random horizontal flip, ``ToTensor`` and a 0.5 / 0.5 ``Normalize``.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self,
                 path=os.path.expanduser('datasets/horse256.lmdb'),
                 image_size=128,
                 original_resolution=256,
                 do_augment: bool = True,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        self.original_resolution = original_resolution
        print(path)
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            steps += [transforms.RandomHorizontalFlip()]
        if do_transform:
            steps += [transforms.ToTensor()]
        if do_normalize:
            steps += [
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        """Return the number of stored images."""
        return self.length

    def __getitem__(self, index):
        """Return ``{'img', 'index'}`` for the image at ``index``."""
        img = self.data[index]
        if self.transform is None:
            return {'img': img, 'index': index}
        return {'img': self.transform(img), 'index': index}
|
|
|
|
class Bedroom_lmdb(Dataset):
    """Bedroom images read from an LMDB store (keys zero-padded to 7 digits).

    The pipeline is resize + center crop to ``image_size``, then optionally
    a random horizontal flip, ``ToTensor`` and a 0.5 / 0.5 ``Normalize``.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self,
                 path=os.path.expanduser('datasets/bedroom256.lmdb'),
                 image_size=128,
                 original_resolution=256,
                 do_augment: bool = True,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 **kwargs):
        self.original_resolution = original_resolution
        print(path)
        self.data = BaseLMDB(path, original_resolution, zfill=7)
        self.length = len(self.data)

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            steps += [transforms.RandomHorizontalFlip()]
        if do_transform:
            steps += [transforms.ToTensor()]
        if do_normalize:
            steps += [
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        self.transform = transforms.Compose(steps)

    def __len__(self):
        """Return the number of stored images."""
        return self.length

    def __getitem__(self, index):
        """Return ``{'img', 'index'}`` for the transformed image at ``index``."""
        transformed = self.transform(self.data[index])
        return {'img': transformed, 'index': index}
|
|
|
|
class CelebAttrDataset(Dataset):
    """CelebA images from a folder paired with their attribute labels.

    Only annotation rows whose index matches an image found in ``folder``
    (with its extension swapped for ``.jpg``) are kept.

    Args:
        folder: root directory searched recursively for ``*.{ext}`` files.
        image_size: size given to ``Resize`` (and ``CenterCrop``).
        attr_path: whitespace-separated attribute file; its first line is
            skipped before parsing.
        ext: extension of the image files on disk.
        only_cls_name: if given, keep only rows where this attribute equals
            ``only_cls_value``.
        only_cls_value: value to match for ``only_cls_name``.
        do_augment: append a random horizontal flip.
        do_transform: append ``ToTensor``.
        do_normalize: append ``Normalize`` with 0.5 / 0.5 on three channels.
        d2c: apply the fixed D2C crop before resizing instead of
            resize + center crop.
    """

    id_to_cls = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
    ]
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

    def __init__(self,
                 folder,
                 image_size=64,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/list_attr_celeba.txt'),
                 ext='png',
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 d2c: bool = False):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        self.ext = ext

        # Image paths relative to `folder`, renamed to '.jpg' so they can be
        # matched against the annotation index below.
        paths = [
            str(p.relative_to(folder))
            for p in Path(f'{folder}').glob(f'**/*.{ext}')
        ]
        paths = [str(each).split('.')[0] + '.jpg' for each in paths]

        if d2c:
            transform = [
                d2c_crop(),
                transforms.Resize(image_size),
            ]
        else:
            transform = [
                transforms.Resize(image_size),
                transforms.CenterCrop(image_size),
            ]
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        with open(attr_path) as f:
            # Discard the first line; the table header starts on the second.
            f.readline()
            # sep=r'\s+' is the supported spelling of the deprecated
            # delim_whitespace=True.
            self.df = pd.read_csv(f, sep=r'\s+')
            self.df = self.df[self.df.index.isin(paths)]

        if only_cls_name is not None:
            self.df = self.df[self.df[only_cls_name] == only_cls_value]

    def pos_count(self, cls_name):
        """Return how many rows have ``cls_name`` equal to 1."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Return how many rows have ``cls_name`` equal to -1."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        """Return the number of annotation rows kept."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'img', 'index', 'labels'}`` for row ``index``.

        ``labels`` is a tensor ordered like ``id_to_cls``.
        """
        row = self.df.iloc[index]
        # The row index carries the annotation file name; swap in the
        # extension used on disk.
        name = row.name.split('.')[0]
        name = f'{name}.{self.ext}'

        path = os.path.join(self.folder, name)
        img = Image.open(path)

        labels = [0] * len(self.id_to_cls)
        for k, v in row.items():
            labels[self.cls_to_id[k]] = int(v)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
|
|
|
|
class CelebD2CAttrDataset(CelebAttrDataset):
    """
    CelebA attribute dataset as used in the D2C paper.

    Identical to the parent class except for its defaults: images are
    expected as ``jpg`` and the specific D2C crop is enabled.
    """
    def __init__(self,
                 folder,
                 image_size=64,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/list_attr_celeba.txt'),
                 ext='jpg',
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 d2c: bool = True):
        # Everything is forwarded by keyword; only the defaults differ.
        super().__init__(
            folder=folder,
            image_size=image_size,
            attr_path=attr_path,
            ext=ext,
            only_cls_name=only_cls_name,
            only_cls_value=only_cls_value,
            do_augment=do_augment,
            do_transform=do_transform,
            do_normalize=do_normalize,
            d2c=d2c,
        )
|
|
|
|
class CelebAttrFewshotDataset(Dataset):
    """Few-shot CelebA subset for a single attribute.

    The rows come from a CSV under ``data/celeba_fewshots`` selected by
    ``K``, ``cls_name``, ``seed`` and ``all_neg``; images are loaded from
    ``img_folder`` using extension ``ext``. Each item carries a one-element
    label tensor holding the value of ``cls_name``.
    """
    def __init__(
        self,
        cls_name,
        K,
        img_folder,
        img_size=64,
        ext='png',
        seed=0,
        only_cls_name: str = None,
        only_cls_value: int = None,
        all_neg: bool = False,
        do_augment: bool = False,
        do_transform: bool = True,
        do_normalize: bool = True,
        d2c: bool = False,
    ) -> None:
        self.cls_name = cls_name
        self.K = K
        self.img_folder = img_folder
        self.ext = ext

        csv_path = (f'data/celeba_fewshots/K{K}_allneg_{cls_name}_{seed}.csv'
                    if all_neg else
                    f'data/celeba_fewshots/K{K}_{cls_name}_{seed}.csv')
        frame = pd.read_csv(csv_path, index_col=0)
        if only_cls_name is not None:
            # Keep only rows whose attribute matches the requested value.
            frame = frame[frame[only_cls_name] == only_cls_value]
        self.df = frame

        if d2c:
            steps = [d2c_crop(), transforms.Resize(img_size)]
        else:
            steps = [
                transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
            ]
        if do_augment:
            steps += [transforms.RandomHorizontalFlip()]
        if do_transform:
            steps += [transforms.ToTensor()]
        if do_normalize:
            steps += [
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        self.transform = transforms.Compose(steps)

    def pos_count(self, cls_name):
        """Count rows where ``cls_name`` is 1."""
        is_pos = self.df[cls_name] == 1
        return is_pos.sum()

    def neg_count(self, cls_name):
        """Count rows where ``cls_name`` is -1."""
        is_neg = self.df[cls_name] == -1
        return is_neg.sum()

    def __len__(self):
        """Number of rows in the few-shot table."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'img', 'index', 'labels'}`` for row ``index``."""
        row = self.df.iloc[index]
        # Row index holds the file name; replace its extension with `ext`.
        stem = row.name.split('.')[0]
        file_name = f'{stem}.{self.ext}'

        img = Image.open(os.path.join(self.img_folder, file_name))

        # Single-attribute label with a trailing dimension of size one.
        label = torch.tensor(int(row[self.cls_name])).unsqueeze(-1)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': label}
|
|
|
|
class CelebD2CAttrFewshotDataset(CelebAttrFewshotDataset):
    """Few-shot CelebA attribute dataset with D2C defaults.

    Same as the parent class, but defaults to ``jpg`` images with the D2C
    crop enabled, and additionally records ``is_negative`` on the instance.
    """
    def __init__(self,
                 cls_name,
                 K,
                 img_folder,
                 img_size=64,
                 ext='jpg',
                 seed=0,
                 only_cls_name: str = None,
                 only_cls_value: int = None,
                 all_neg: bool = False,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True,
                 is_negative=False,
                 d2c: bool = True) -> None:
        # Everything except `is_negative` is forwarded to the parent.
        super().__init__(
            cls_name=cls_name,
            K=K,
            img_folder=img_folder,
            img_size=img_size,
            ext=ext,
            seed=seed,
            only_cls_name=only_cls_name,
            only_cls_value=only_cls_value,
            all_neg=all_neg,
            do_augment=do_augment,
            do_transform=do_transform,
            do_normalize=do_normalize,
            d2c=d2c,
        )
        self.is_negative = is_negative
|
|
|
|
class CelebHQAttrDataset(Dataset):
    """CelebA-HQ images from an LMDB store paired with attribute labels.

    Args:
        path: location of the LMDB environment.
        image_size: size given to ``Resize`` and ``CenterCrop``. With the
            default ``None`` no resize or crop is added and images keep the
            resolution they were stored at.
        attr_path: whitespace-separated attribute file; its first line is
            skipped before parsing.
        original_resolution: resolution prefix of the stored keys.
        do_augment: append a random horizontal flip.
        do_transform: append ``ToTensor``.
        do_normalize: append ``Normalize`` with 0.5 / 0.5 on three channels.
    """
    id_to_cls = [
        '5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
        'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
        'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
        'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
        'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
        'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
        'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
        'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
        'Wearing_Necklace', 'Wearing_Necktie', 'Young'
    ]
    cls_to_id = {v: k for k, v in enumerate(id_to_cls)}

    def __init__(self,
                 path=os.path.expanduser('datasets/celebahq256.lmdb'),
                 image_size=None,
                 attr_path=os.path.expanduser(
                     'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
                 original_resolution=256,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True):
        super().__init__()
        self.image_size = image_size
        self.data = BaseLMDB(path, original_resolution, zfill=5)

        # The default image_size=None used to be handed straight to
        # Resize/CenterCrop, which cannot work; treat it as "no resizing".
        # NOTE(review): assumed intent of the None default - confirm.
        transform = []
        if image_size is not None:
            transform.append(transforms.Resize(image_size))
            transform.append(transforms.CenterCrop(image_size))
        if do_augment:
            transform.append(transforms.RandomHorizontalFlip())
        if do_transform:
            transform.append(transforms.ToTensor())
        if do_normalize:
            transform.append(
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        self.transform = transforms.Compose(transform)

        with open(attr_path) as f:
            # Discard the first line; the table header starts on the second.
            f.readline()
            # sep=r'\s+' is the supported spelling of the deprecated
            # delim_whitespace=True.
            self.df = pd.read_csv(f, sep=r'\s+')

    def pos_count(self, cls_name):
        """Return how many rows have ``cls_name`` equal to 1."""
        return (self.df[cls_name] == 1).sum()

    def neg_count(self, cls_name):
        """Return how many rows have ``cls_name`` equal to -1."""
        return (self.df[cls_name] == -1).sum()

    def __len__(self):
        """Return the number of annotation rows."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'img', 'index', 'labels'}`` for row ``index``.

        ``labels`` is a tensor ordered like ``id_to_cls``.
        """
        row = self.df.iloc[index]
        img_name = row.name
        # The part before the dot is the key into the image store; the
        # extension is not needed.
        img_idx, _ext = img_name.split('.')
        img = self.data[img_idx]

        labels = [0] * len(self.id_to_cls)
        for k, v in row.items():
            labels[self.cls_to_id[k]] = int(v)

        if self.transform is not None:
            img = self.transform(img)
        return {'img': img, 'index': index, 'labels': torch.tensor(labels)}
|
|
|
|
class CelebHQAttrFewshotDataset(Dataset):
    """Few-shot CelebA-HQ subset for a single attribute.

    Rows are read from ``data/celebahq_fewshots/K{K}_{cls_name}.csv`` and
    images from an LMDB store; each item carries a one-element label tensor
    holding the value of ``cls_name``.
    """
    def __init__(self,
                 cls_name,
                 K,
                 path,
                 image_size,
                 original_resolution=256,
                 do_augment: bool = False,
                 do_transform: bool = True,
                 do_normalize: bool = True):
        super().__init__()
        self.image_size = image_size
        self.cls_name = cls_name
        self.K = K
        self.data = BaseLMDB(path, original_resolution, zfill=5)

        steps = [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
        ]
        if do_augment:
            steps += [transforms.RandomHorizontalFlip()]
        if do_transform:
            steps += [transforms.ToTensor()]
        if do_normalize:
            steps += [
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        self.transform = transforms.Compose(steps)

        csv_path = f'data/celebahq_fewshots/K{K}_{cls_name}.csv'
        self.df = pd.read_csv(csv_path, index_col=0)

    def pos_count(self, cls_name):
        """Count rows where ``cls_name`` is 1."""
        is_pos = self.df[cls_name] == 1
        return is_pos.sum()

    def neg_count(self, cls_name):
        """Count rows where ``cls_name`` is -1."""
        is_neg = self.df[cls_name] == -1
        return is_neg.sum()

    def __len__(self):
        """Number of rows in the few-shot table."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``{'img', 'index', 'labels'}`` for row ``index``."""
        row = self.df.iloc[index]
        # Row index is '<key>.<ext>'; the key part addresses the image store.
        parts = row.name.split('.')
        img_idx, ext = parts
        img = self.data[img_idx]

        # Single-attribute label with a trailing dimension of size one.
        value = int(row[self.cls_name])
        label = torch.tensor(value).unsqueeze(-1)

        if self.transform is not None:
            img = self.transform(img)

        return {'img': img, 'index': index, 'labels': label}
|
|
|
|
class Repeat(Dataset):
    """Present a dataset as if it had ``new_len`` items by cycling over it.

    Args:
        dataset: the wrapped dataset (must support ``len`` and indexing).
        new_len: the length reported by this wrapper.
    """
    def __init__(self, dataset, new_len) -> None:
        super().__init__()
        self.dataset = dataset
        self.original_len = len(dataset)
        self.new_len = new_len

    def __len__(self):
        """Return the repeated length."""
        return self.new_len

    def __getitem__(self, index):
        """Return the wrapped item at ``index`` modulo the original length.

        Raises:
            IndexError: if ``index`` is at or beyond ``new_len``. Without
                this, iterating the dataset directly would never stop.
        """
        if index >= self.new_len:
            raise IndexError('Repeat index out of range')
        index = index % self.original_len
        return self.dataset[index]
|
|