| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Image processor class for Magma.""" |
|
|
| from typing import List, Optional, Union |
| import ast |
| import numpy as np |
| import torchvision |
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
| from transformers.image_transforms import ( |
| convert_to_rgb, |
| ) |
| from transformers.image_utils import ( |
| OPENAI_CLIP_MEAN, |
| OPENAI_CLIP_STD, |
| ImageInput, |
| make_list_of_images, |
| valid_images, |
| ) |
| from transformers.utils import TensorType, is_vision_available, logging |
|
|
| from transformers import AutoImageProcessor |
|
|
logger = logging.get_logger(__name__)

# PIL is imported only when transformers reports that vision support is available.
# NOTE(review): neither `logger` nor `Image` is referenced elsewhere in the visible
# part of this file — confirm no other code relies on them before removing.
if is_vision_available():
    from PIL import Image
|
|
| import torch |
| import torchvision |
|
|
def select_best_resolution(original_size, possible_resolutions):
    """
    Pick the candidate canvas that preserves the most image pixels with the least padding.

    Each candidate is scored by the area the image occupies after an aspect-preserving
    fit (capped at the original area, so upscaling earns nothing extra). Ties on that
    score are broken by the smaller unused area. Among exact ties, the earliest
    candidate wins.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): Candidate sizes as [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The winning (width, height), or None when `possible_resolutions` is empty.
    """
    src_w, src_h = original_size
    winner = None
    # Rank is (usable pixels, negated wasted pixels); a strictly larger tuple is better.
    winner_rank = (0, -float('inf'))

    for cand_w, cand_h in possible_resolutions:
        fit = min(cand_w / src_w, cand_h / src_h)
        usable = min(int(src_w * fit) * int(src_h * fit), src_w * src_h)
        rank = (usable, -(cand_w * cand_h - usable))
        if rank > winner_rank:
            winner_rank = rank
            winner = (cand_w, cand_h)

    return winner
|
|
def process_anyres_image(image, max_num_crops=None, base_width=768, base_height=768):
    """
    Resize an image onto the best-fitting grid of base-sized tiles and split it into those tiles.

    The candidate canvases are every (cols * base_width, rows * base_height) with
    cols * rows <= max_num_crops. The canvas is chosen by `select_best_resolution`, the
    image is bilinearly resized to it, and the result is cut into non-overlapping
    base_height x base_width tiles.

    Args:
        image (torch.Tensor): Image indexed as (C, H, W); shape[2] is used as width and shape[1] as height.
        max_num_crops (int): Maximum number of tiles (cols * rows) a candidate canvas may contain. Required.
        base_width (int): Tile width in pixels.
        base_height (int): Tile height in pixels.

    Returns:
        tuple: (patches, grid), where `patches` has shape (rows * cols, C, base_height, base_width)
        in row-major tile order and `grid` is (rows, cols).

    Raises:
        ValueError: If `max_num_crops` is None or smaller than 1.
    """
    if max_num_crops is None or max_num_crops < 1:
        raise ValueError(f"max_num_crops must be a positive integer, got {max_num_crops!r}")

    # Every (cols, rows) layout with at most max_num_crops tiles, expressed in (width, height) pixels.
    possible_resolutions = [
        (int(cols * base_width), int(rows * base_height))
        for cols in range(1, max_num_crops + 1)
        for rows in range(1, max_num_crops // cols + 1)
    ]

    # select_best_resolution works in (width, height); the image is indexed (C, H, W).
    best_width, best_height = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)
    best_resolution_grid = (best_height // base_height, best_width // base_width)

    resized = torch.nn.functional.interpolate(image[None, :, :, :], size=(best_height, best_width), mode='bilinear')

    # (1, C, H, W) -> (1, C, rows, cols, bh, bw) -> (1, rows, cols, C, bh, bw) -> (rows * cols, C, bh, bw)
    patches = resized.unfold(2, base_height, base_height).unfold(3, base_width, base_width)
    patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(
        best_resolution_grid[0] * best_resolution_grid[1], -1, base_height, base_width
    )
    return (patches, best_resolution_grid)
|
|
def process_anyres_image_global(image, max_num_crops=None, base_width=768, base_height=768):
    """
    Resize an image to the best-fitting multi-tile canvas, without splitting it into tiles.

    The candidate canvases are every (cols * base_width, rows * base_height) with
    cols * rows <= max_num_crops. The canvas is chosen by `select_best_resolution`, and
    the whole image is bilinearly resized to it.

    Args:
        image (torch.Tensor): Image indexed as (C, H, W); shape[2] is used as width and shape[1] as height.
        max_num_crops (int): Maximum number of tiles (cols * rows) a candidate canvas may contain. Required.
        base_width (int): Tile width in pixels.
        base_height (int): Tile height in pixels.

    Returns:
        torch.Tensor: The resized image with a leading batch dimension, shape (1, C, best_height, best_width).

    Raises:
        ValueError: If `max_num_crops` is None or smaller than 1.
    """
    if max_num_crops is None or max_num_crops < 1:
        raise ValueError(f"max_num_crops must be a positive integer, got {max_num_crops!r}")

    # Every (cols, rows) layout with at most max_num_crops tiles, expressed in (width, height) pixels.
    possible_resolutions = [
        (int(cols * base_width), int(rows * base_height))
        for cols in range(1, max_num_crops + 1)
        for rows in range(1, max_num_crops // cols + 1)
    ]

    # select_best_resolution works in (width, height); interpolate wants (height, width).
    best_width, best_height = select_best_resolution((image.shape[2], image.shape[1]), possible_resolutions)

    return torch.nn.functional.interpolate(image[None, :, :, :], size=(best_height, best_width), mode='bilinear')
|
|
class preprocessor:
    """
    Adapter exposing a composed transform through a `preprocess(image)` method.

    Args:
        image_preprocessor: Callable transform with a `transforms` list whose last entry has a
            `mean` attribute (that value is copied to `self.image_mean`).
        base_resolution: Sequence whose items 0 and 1 become `crop_size['height']` and
            `crop_size['width']`, respectively.
    """

    def __init__(self, image_preprocessor, base_resolution=(256, 256)):
        self.image_preprocessor = image_preprocessor
        self.crop_size = dict(height=base_resolution[0], width=base_resolution[1])
        last_transform = image_preprocessor.transforms[-1]
        self.image_mean = last_transform.mean

    def preprocess(self, image, return_tensors='pt'):
        """Apply the wrapped transform and add a leading batch dimension.

        `return_tensors` is accepted but not used.
        """
        transformed = self.image_preprocessor(image)
        batched = transformed.unsqueeze(0)
        return dict(pixel_values=batched)
|
|
class MagmaImageProcessor(BaseImageProcessor):
    r"""
    Constructs a Magma image processor. Based on [`CLIPImageProcessor`] with incorporation of additional techniques
    for processing high resolution images as explained in the [InternLM-XComposer2-4KHD](https://arxiv.org/pdf/2404.06512)

    Args:
        anyres_strategy (`str`, *optional*, defaults to `'global'`):
            Strategy to cope with high-resolution images. One conventional way is multi-crop, used by many other
            works to accommodate CLIP-ViT models. However, since we are using ConvNeXt, which is essentially a
            convnet, we can use arbitrary-resolution images. As such, we use the global strategy by default, i.e.,
            directly resize the image holistically to a certain resolution.
            NOTE(review): this value is stored but `preprocess` below does not read it; every image is tiled via
            `process_anyres_image`. Confirm this is intended.
        base_img_size (`int`, *optional*, defaults to 768):
            As ConvNeXt has a 1/32 downsample rate, we use 768 as the base resolution so that the resulting feature
            map is 24x24. Used as both tile width and tile height in `preprocess`.
        num_crops (`int`, *optional*, defaults to 1):
            Number of effective crops when coping with images with higher resolution than 768x768. Note that
            num_crops > 1 does not mean we are cropping the image. Used by `preprocess` when its own `num_crops`
            is not given.
        do_convert_rgb (`bool`, *optional*, defaults to `True`):
            Whether `preprocess` converts inputs to RGB when its own `do_convert_rgb` is not given.
        image_mean (`List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`):
            Per-channel mean passed to `torchvision.transforms.Normalize`.
        image_std (`List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`):
            Per-channel standard deviation passed to `torchvision.transforms.Normalize`.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        anyres_strategy: str = 'global',
        base_img_size: int = 768,
        num_crops: int = 1,
        do_convert_rgb: bool = True,
        image_mean: List[float] = OPENAI_CLIP_MEAN,
        image_std: List[float] = OPENAI_CLIP_STD,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.base_img_size = base_img_size
        self.anyres_strategy = anyres_strategy
        self.num_crops = num_crops
        self.do_convert_rgb = do_convert_rgb
        self.image_mean = image_mean
        self.image_std = image_std

    def preprocess(
        self,
        images: Union[ImageInput, List[ImageInput]],
        do_pad: bool = False,
        do_convert_rgb: bool = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        num_crops: int = None,
    ):
        """
        Normalize images and split each one into base_img_size x base_img_size tiles.

        Args:
            images (`ImageInput` or `List[ImageInput]`):
                Image or batch of images. Each one is converted with `torchvision.transforms.ToTensor` and then
                normalized with `self.image_mean` / `self.image_std`.
            do_pad (`bool`, *optional*, defaults to `False`):
                Currently ignored.
            do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
                Whether to convert the image to RGB.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.
                - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
                - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
                - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
                - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
            num_crops (`int`, *optional*, defaults to `self.num_crops`):
                Maximum number of tiles per image.

        Returns:
            `BatchFeature` with:
            - "pixel_values": the tiles of all images concatenated along dim 0.
            - "image_sizes": one (rows, cols) tile grid per image.

        Raises:
            ValueError: If `images` contains an unsupported image type.
        """
        images = make_list_of_images(images)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
                "torch.Tensor, tf.Tensor or jax.ndarray."
            )

        # Fall back to the instance configuration when the caller does not override it.
        # Previously a `None` default silently disabled RGB conversion.
        do_convert_rgb = self.do_convert_rgb if do_convert_rgb is None else do_convert_rgb
        num_crops = self.num_crops if num_crops is None else num_crops

        if do_convert_rgb:
            images = [convert_to_rgb(image) for image in images]

        img_processor = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(self.image_mean, self.image_std),
        ])

        images = [img_processor(image) for image in images]
        # The output dtype follows the FIRST image only: fp16 in -> fp16 out.
        # Tiling itself runs on fp32 copies.
        image_data_type = 'half' if images[0].type() == 'torch.HalfTensor' else 'float'
        images = [image.float() for image in images]

        image_patches = [
            process_anyres_image(image, num_crops, base_width=self.base_img_size, base_height=self.base_img_size)
            for image in images
        ]
        pixel_values = torch.cat([patches for patches, _ in image_patches], dim=0)
        image_sizes = [grid for _, grid in image_patches]

        if image_data_type == 'half':
            pixel_values = pixel_values.half()

        data = {
            "pixel_values": pixel_values,
            "image_sizes": image_sizes,
        }
        return BatchFeature(data=data, tensor_type=return_tensors)
|
|
# Register MagmaImageProcessor with AutoImageProcessor at import time, under the string key "MagmaImageProcessor".
# NOTE(review): the first argument is a plain string rather than a config class — confirm this matches the
# AutoImageProcessor.register contract of the pinned transformers version.
AutoImageProcessor.register("MagmaImageProcessor", MagmaImageProcessor)