from typing import Sequence, Dict, Tuple, Union, List, Mapping, Any, Optional
import io
import math
import random
import time

import cv2
import numpy as np
import torch.utils.data as data
from PIL import Image

from dataset.degradation import (
    random_mixed_kernels,
    random_add_gaussian_noise,
    random_add_jpg_compression
)
from dataset.utils import load_file_list, center_crop_arr, random_crop_arr
from utils.common import instantiate_from_config
| |
|
| |
|
| | class CodeformerDataset(data.Dataset): |
| | |
| | def __init__( |
| | self, |
| | file_list: str, |
| | file_backend_cfg: Mapping[str, Any], |
| | out_size: int, |
| | crop_type: str, |
| | blur_kernel_size: int, |
| | kernel_list: Sequence[str], |
| | kernel_prob: Sequence[float], |
| | blur_sigma: Sequence[float], |
| | downsample_range: Sequence[float], |
| | noise_range: Sequence[float], |
| | jpeg_range: Sequence[int] |
| | ) -> "CodeformerDataset": |
| | super(CodeformerDataset, self).__init__() |
| | self.file_list = file_list |
| | self.image_files = load_file_list(file_list) |
| | self.file_backend = instantiate_from_config(file_backend_cfg) |
| | self.out_size = out_size |
| | self.crop_type = crop_type |
| | assert self.crop_type in ["none", "center", "random"] |
| | |
| | self.blur_kernel_size = blur_kernel_size |
| | self.kernel_list = kernel_list |
| | self.kernel_prob = kernel_prob |
| | self.blur_sigma = blur_sigma |
| | self.downsample_range = downsample_range |
| | self.noise_range = noise_range |
| | self.jpeg_range = jpeg_range |
| |
|
| | def load_gt_image(self, image_path: str, max_retry: int=5) -> Optional[np.ndarray]: |
| | image_bytes = None |
| | while image_bytes is None: |
| | if max_retry == 0: |
| | return None |
| | image_bytes = self.file_backend.get(image_path) |
| | max_retry -= 1 |
| | if image_bytes is None: |
| | time.sleep(0.5) |
| | image = Image.open(io.BytesIO(image_bytes)).convert("RGB") |
| | if self.crop_type != "none": |
| | if image.height == self.out_size and image.width == self.out_size: |
| | image = np.array(image) |
| | else: |
| | if self.crop_type == "center": |
| | image = center_crop_arr(image, self.out_size) |
| | elif self.crop_type == "random": |
| | image = random_crop_arr(image, self.out_size, min_crop_frac=0.7) |
| | else: |
| | assert image.height == self.out_size and image.width == self.out_size |
| | image = np.array(image) |
| | |
| | return image |
| |
|
| | def __getitem__(self, index: int) -> Dict[str, Union[np.ndarray, str]]: |
| | |
| | img_gt = None |
| | while img_gt is None: |
| | |
| | image_file = self.image_files[index] |
| | gt_path = image_file["image_path"] |
| | prompt = image_file["prompt"] |
| | img_gt = self.load_gt_image(gt_path) |
| | if img_gt is None: |
| | print(f"filed to load {gt_path}, try another image") |
| | index = random.randint(0, len(self) - 1) |
| | |
| | |
| | img_gt = (img_gt[..., ::-1] / 255.0).astype(np.float32) |
| | h, w, _ = img_gt.shape |
| | if np.random.uniform() < 0.5: |
| | prompt = "" |
| | |
| | |
| | |
| | kernel = random_mixed_kernels( |
| | self.kernel_list, |
| | self.kernel_prob, |
| | self.blur_kernel_size, |
| | self.blur_sigma, |
| | self.blur_sigma, |
| | [-math.pi, math.pi], |
| | noise_range=None |
| | ) |
| | img_lq = cv2.filter2D(img_gt, -1, kernel) |
| | |
| | scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) |
| | img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) |
| | |
| | if self.noise_range is not None: |
| | img_lq = random_add_gaussian_noise(img_lq, self.noise_range) |
| | |
| | if self.jpeg_range is not None: |
| | img_lq = random_add_jpg_compression(img_lq, self.jpeg_range) |
| | |
| | |
| | img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) |
| | |
| | |
| | gt = (img_gt[..., ::-1] * 2 - 1).astype(np.float32) |
| | |
| | lq = img_lq[..., ::-1].astype(np.float32) |
| | |
| | return gt, lq, prompt |
| |
|
| | def __len__(self) -> int: |
| | return len(self.image_files) |
| |
|