from math import ceil
from typing import Dict, List, Optional, Union

from torchvision.transforms import Resize

from transformers.image_utils import (
    ChannelDimension,
    ImageInput,
    PILImageResampling,
)
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
from transformers.utils import TensorType


def get_resize_output_image_size_long(
    image_size: tuple,
    PATCH_SIZE: int = 32,
    MAX_RESOLUTION: int = 1024,
    MIN_RESOLUTION: int = 448,
) -> tuple:
    """Compute an output size whose sides are both multiples of ``PATCH_SIZE``.

    The longer side is rounded up to a multiple of ``PATCH_SIZE`` and clamped to
    ``[MIN_RESOLUTION, MAX_RESOLUTION]``; the shorter side is scaled to keep the
    aspect ratio and then rounded up to a multiple of ``PATCH_SIZE`` as well.
    The returned tuple preserves the order of ``image_size``.
    """
    l1, l2 = image_size
    short, long = (l2, l1) if l2 <= l1 else (l1, l2)

    # Round the long side up to a multiple of PATCH_SIZE, capped at MAX_RESOLUTION.
    requested_new_long = min(ceil(long / PATCH_SIZE) * PATCH_SIZE, MAX_RESOLUTION)
    # Never let the long side drop below MIN_RESOLUTION.
    requested_new_long = max(requested_new_long, MIN_RESOLUTION)

    # Scale the short side proportionally, then round it up to a multiple of PATCH_SIZE.
    new_long, new_short = requested_new_long, int(requested_new_long * short / long)
    new_short = ceil(new_short / PATCH_SIZE) * PATCH_SIZE

    return (new_long, new_short) if l2 <= l1 else (new_short, new_long)
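# Illustrative sanity checks for get_resize_output_image_size_long (not part of
# the original module): with the defaults, the long side snaps to a multiple of
# 32 inside [448, 1024] and the short side follows the aspect ratio before being
# rounded up itself.
#
#   get_resize_output_image_size_long((800, 600))    # -> (800, 608)
#   get_resize_output_image_size_long((2000, 1500))  # -> (1024, 768)
#   get_resize_output_image_size_long((300, 200))    # -> (448, 320)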


class SoloCLIPImageProcessor(CLIPImageProcessor):
    """CLIP image processor that, after the standard CLIP preprocessing, resizes
    the result so that both spatial dimensions are multiples of ``PATCH_SIZE``,
    with the long side kept within ``[MIN_RESOLUTION, MAX_RESOLUTION]``."""

    def __init__(
        self,
        do_resize: bool = True,
        size: Dict[str, int] = None,
        resample: PILImageResampling = PILImageResampling.BICUBIC,
        do_center_crop: bool = True,
        crop_size: Dict[str, int] = None,
        do_rescale: bool = True,
        rescale_factor: Union[int, float] = 1 / 255,
        do_normalize: bool = True,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = True,
        PATCH_SIZE: int = 32,
        MAX_RESOLUTION: int = 1024,
        MIN_RESOLUTION: int = 448,
        **kwargs,
    ) -> None:
        super().__init__(
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_convert_rgb=do_convert_rgb,
            **kwargs,
        )
        self.PATCH_SIZE = PATCH_SIZE
        self.MAX_RESOLUTION = MAX_RESOLUTION
        self.MIN_RESOLUTION = MIN_RESOLUTION

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: int = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
        **kwargs,
    ):
        return_dict = super().preprocess(
            images=images,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_convert_rgb=do_convert_rgb,
            return_tensors=return_tensors,
            data_format=data_format,
            input_data_format=input_data_format,
            **kwargs,
        )
        # The post-processing below expects a single image and return_tensors="pt",
        # since it calls .size() on a torch tensor.
        pixel_values = return_dict["pixel_values"][0]
        _, height, width = pixel_values.size()
        # Snap both spatial dimensions to multiples of PATCH_SIZE, keeping the
        # long side within [MIN_RESOLUTION, MAX_RESOLUTION].
        height, width = get_resize_output_image_size_long(
            (height, width),
            PATCH_SIZE=self.PATCH_SIZE,
            MAX_RESOLUTION=self.MAX_RESOLUTION,
            MIN_RESOLUTION=self.MIN_RESOLUTION,
        )
        pixel_values = Resize(size=(height, width))(pixel_values)
        # Restore the batch dimension that was stripped above.
        return_dict["pixel_values"] = pixel_values.unsqueeze(0)
        return return_dict
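

# Minimal usage sketch, not part of the original module: it assumes Pillow and
# torch are available and passes return_tensors="pt", which the preprocess
# override above requires (it calls .size() on a torch tensor). The size and
# crop_size values are illustrative.
if __name__ == "__main__":
    from PIL import Image

    processor = SoloCLIPImageProcessor(
        size={"shortest_edge": 448},
        crop_size={"height": 448, "width": 448},
    )
    image = Image.new("RGB", (640, 480))
    batch = processor.preprocess(image, return_tensors="pt")
    # Both spatial dimensions end up as multiples of PATCH_SIZE (32 by default).
    print(batch["pixel_values"].shape)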