| | from typing import Dict, Tuple, Union |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| | from PIL import Image |
| | from PIL.Image import Image as PilImage |
| | from torchvision import transforms |
| | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| | from transformers.image_utils import ImageInput |
| |
|
| |
|
| | class RescaleT(object): |
| | def __init__(self, output_size: Union[int, Tuple[int, int]]) -> None: |
| | super().__init__() |
| | assert isinstance(output_size, (int, tuple)) |
| | self.output_size = output_size |
| |
|
| | def __call__(self, sample) -> Dict[str, np.ndarray]: |
| | image, label = sample["image"], sample["label"] |
| |
|
| | h, w = image.shape[:2] |
| |
|
| | if isinstance(self.output_size, int): |
| | if h > w: |
| | new_h, new_w = self.output_size * h / w, self.output_size |
| | else: |
| | new_h, new_w = self.output_size, self.output_size * w / h |
| | else: |
| | new_h, new_w = self.output_size |
| |
|
| | new_h, new_w = int(new_h), int(new_w) |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | img = ( |
| | cv2.resize( |
| | image, |
| | (self.output_size, self.output_size), |
| | interpolation=cv2.INTER_AREA, |
| | ) |
| | / 255.0 |
| | ) |
| | |
| | |
| | |
| | |
| | lbl = cv2.resize( |
| | label, (self.output_size, self.output_size), interpolation=cv2.INTER_NEAREST |
| | ) |
| | lbl = np.expand_dims(lbl, axis=-1) |
| | lbl = np.clip(lbl, np.min(label), np.max(label)) |
| |
|
| | return {"image": img, "label": lbl} |
| |
|
| |
|
| | class ToTensorLab(object): |
| | """Convert ndarrays in sample to Tensors.""" |
| |
|
| | def __init__(self, flag: int = 0) -> None: |
| | self.flag = flag |
| |
|
| | def __call__(self, sample): |
| | image, label = sample["image"], sample["label"] |
| |
|
| | tmpLbl = np.zeros(label.shape) |
| |
|
| | if np.max(label) < 1e-6: |
| | label = label |
| | else: |
| | label = label / np.max(label) |
| |
|
| | |
| | if self.flag == 2: |
| | tmpImg = np.zeros((image.shape[0], image.shape[1], 6)) |
| | tmpImgt = np.zeros((image.shape[0], image.shape[1], 3)) |
| | if image.shape[2] == 1: |
| | tmpImgt[:, :, 0] = image[:, :, 0] |
| | tmpImgt[:, :, 1] = image[:, :, 0] |
| | tmpImgt[:, :, 2] = image[:, :, 0] |
| | else: |
| | tmpImgt = image |
| | |
| | tmpImgtl = cv2.cvtColor(tmpImgt, cv2.COLOR_RGB2LAB) |
| |
|
| | |
| | tmpImg[:, :, 0] = (tmpImgt[:, :, 0] - np.min(tmpImgt[:, :, 0])) / ( |
| | np.max(tmpImgt[:, :, 0]) - np.min(tmpImgt[:, :, 0]) |
| | ) |
| | tmpImg[:, :, 1] = (tmpImgt[:, :, 1] - np.min(tmpImgt[:, :, 1])) / ( |
| | np.max(tmpImgt[:, :, 1]) - np.min(tmpImgt[:, :, 1]) |
| | ) |
| | tmpImg[:, :, 2] = (tmpImgt[:, :, 2] - np.min(tmpImgt[:, :, 2])) / ( |
| | np.max(tmpImgt[:, :, 2]) - np.min(tmpImgt[:, :, 2]) |
| | ) |
| | tmpImg[:, :, 3] = (tmpImgtl[:, :, 0] - np.min(tmpImgtl[:, :, 0])) / ( |
| | np.max(tmpImgtl[:, :, 0]) - np.min(tmpImgtl[:, :, 0]) |
| | ) |
| | tmpImg[:, :, 4] = (tmpImgtl[:, :, 1] - np.min(tmpImgtl[:, :, 1])) / ( |
| | np.max(tmpImgtl[:, :, 1]) - np.min(tmpImgtl[:, :, 1]) |
| | ) |
| | tmpImg[:, :, 5] = (tmpImgtl[:, :, 2] - np.min(tmpImgtl[:, :, 2])) / ( |
| | np.max(tmpImgtl[:, :, 2]) - np.min(tmpImgtl[:, :, 2]) |
| | ) |
| |
|
| | |
| |
|
| | tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std( |
| | tmpImg[:, :, 0] |
| | ) |
| | tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std( |
| | tmpImg[:, :, 1] |
| | ) |
| | tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std( |
| | tmpImg[:, :, 2] |
| | ) |
| | tmpImg[:, :, 3] = (tmpImg[:, :, 3] - np.mean(tmpImg[:, :, 3])) / np.std( |
| | tmpImg[:, :, 3] |
| | ) |
| | tmpImg[:, :, 4] = (tmpImg[:, :, 4] - np.mean(tmpImg[:, :, 4])) / np.std( |
| | tmpImg[:, :, 4] |
| | ) |
| | tmpImg[:, :, 5] = (tmpImg[:, :, 5] - np.mean(tmpImg[:, :, 5])) / np.std( |
| | tmpImg[:, :, 5] |
| | ) |
| |
|
| | elif self.flag == 1: |
| | tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) |
| |
|
| | if image.shape[2] == 1: |
| | tmpImg[:, :, 0] = image[:, :, 0] |
| | tmpImg[:, :, 1] = image[:, :, 0] |
| | tmpImg[:, :, 2] = image[:, :, 0] |
| | else: |
| | tmpImg = image |
| |
|
| | |
| | print("tmpImg:", tmpImg.min(), tmpImg.max()) |
| | exit() |
| | tmpImg = cv2.cvtColor(tmpImg, cv2.COLOR_RGB2LAB) |
| |
|
| | |
| |
|
| | tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.min(tmpImg[:, :, 0])) / ( |
| | np.max(tmpImg[:, :, 0]) - np.min(tmpImg[:, :, 0]) |
| | ) |
| | tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.min(tmpImg[:, :, 1])) / ( |
| | np.max(tmpImg[:, :, 1]) - np.min(tmpImg[:, :, 1]) |
| | ) |
| | tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.min(tmpImg[:, :, 2])) / ( |
| | np.max(tmpImg[:, :, 2]) - np.min(tmpImg[:, :, 2]) |
| | ) |
| |
|
| | tmpImg[:, :, 0] = (tmpImg[:, :, 0] - np.mean(tmpImg[:, :, 0])) / np.std( |
| | tmpImg[:, :, 0] |
| | ) |
| | tmpImg[:, :, 1] = (tmpImg[:, :, 1] - np.mean(tmpImg[:, :, 1])) / np.std( |
| | tmpImg[:, :, 1] |
| | ) |
| | tmpImg[:, :, 2] = (tmpImg[:, :, 2] - np.mean(tmpImg[:, :, 2])) / np.std( |
| | tmpImg[:, :, 2] |
| | ) |
| |
|
| | else: |
| | tmpImg = np.zeros((image.shape[0], image.shape[1], 3)) |
| | image = image / np.max(image) |
| | if image.shape[2] == 1: |
| | tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 |
| | tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229 |
| | tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229 |
| | else: |
| | tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229 |
| | tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224 |
| | tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225 |
| |
|
| | tmpLbl[:, :, 0] = label[:, :, 0] |
| |
|
| | |
| | |
| | tmpImg = tmpImg.transpose((2, 0, 1)) |
| | tmpLbl = label.transpose((2, 0, 1)) |
| |
|
| | return {"image": torch.from_numpy(tmpImg), "label": torch.from_numpy(tmpLbl)} |
| |
|
| |
|
| | def apply_transform( |
| | data: Dict[str, np.ndarray], rescale_size: int, to_tensor_lab_flag: int |
| | ) -> Dict[str, torch.Tensor]: |
| | transform = transforms.Compose( |
| | [RescaleT(output_size=rescale_size), ToTensorLab(flag=to_tensor_lab_flag)] |
| | ) |
| | return transform(data) |
| |
|
| |
|
| | class BASNetImageProcessor(BaseImageProcessor): |
| | model_input_names = ["pixel_values"] |
| |
|
| | def __init__( |
| | self, rescale_size: int = 256, to_tensor_lab_flag: int = 0, **kwargs |
| | ) -> None: |
| | super().__init__(**kwargs) |
| | self.rescale_size = rescale_size |
| | self.to_tensor_lab_flag = to_tensor_lab_flag |
| |
|
| | def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature: |
| | if not isinstance(images, PilImage): |
| | raise ValueError(f"Expected PIL.Image, got {type(images)}") |
| |
|
| | image_pil = images |
| | image_npy = np.array(image_pil, dtype=np.uint8) |
| | width, height = image_pil.size |
| | label_npy = np.zeros((height, width), dtype=np.uint8) |
| |
|
| | assert image_npy.shape[-1] == 3 |
| | output = apply_transform( |
| | {"image": image_npy, "label": label_npy}, |
| | rescale_size=self.rescale_size, |
| | to_tensor_lab_flag=self.to_tensor_lab_flag, |
| | ) |
| | image = output["image"] |
| |
|
| | assert isinstance(image, torch.Tensor) |
| |
|
| | return BatchFeature( |
| | data={"pixel_values": image.float().unsqueeze(dim=0)}, tensor_type="pt" |
| | ) |
| |
|
| | def postprocess( |
| | self, prediction: torch.Tensor, width: int, height: int |
| | ) -> PilImage: |
| | def _norm_prediction(d: torch.Tensor) -> torch.Tensor: |
| | ma, mi = torch.max(d), torch.min(d) |
| |
|
| | |
| | dn = (d - mi) / ((ma - mi) + torch.finfo(torch.float32).eps) |
| | return dn |
| |
|
| | prediction = _norm_prediction(prediction) |
| | prediction = prediction.squeeze() |
| | prediction = prediction * 255 + 0.5 |
| | prediction = prediction.clamp(0, 255) |
| |
|
| | prediction_np = prediction.cpu().numpy() |
| | image = Image.fromarray(prediction_np).convert("RGB") |
| | image = image.resize((width, height), resample=Image.Resampling.BILINEAR) |
| | return image |
| |
|