| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| HyperCLOVAX-SEED Image Processor (Fast) |
| |
| Implements dynamic resolution image processing: |
| - Smart resize: adjusts image to fit within min_pixels and max_pixels |
| - Vision token calculation: token reduction using merge_size |
| - Discrete image processing: separate processing for discrete vision tokens |
| |
| Based on BaseImageProcessorFast with torchvision resize. |
| """ |
|
|
| import math |
| import os |
| import PIL |
| from typing import List, Optional, Tuple, TypeAlias, Union |
|
|
| import torch |
| from torchvision.transforms.v2 import functional as F |
| try: |
| from transformers.image_processing_utils import BatchFeature |
| except ImportError: |
| from transformers import BatchFeature |
| try: |
| from transformers.image_processing_backends import ( |
| BaseImageProcessorFast, |
| group_images_by_shape, |
| reorder_images, |
| ) |
| except ImportError: |
| |
| from transformers.image_processing_utils_fast import ( |
| BaseImageProcessorFast, |
| group_images_by_shape, |
| reorder_images, |
| ) |
| try: |
| from transformers.image_processing_utils_fast import DefaultFastImageProcessorKwargs |
| except ImportError: |
| from transformers.processing_utils import ImagesKwargs as DefaultFastImageProcessorKwargs |
| try: |
| from PIL.Image import Resampling as PILResampling |
| except (ImportError, AttributeError): |
| |
| class PILResampling: |
| NEAREST = 0 |
| LANCZOS = 1 |
| BILINEAR = 2 |
| BICUBIC = 3 |
| BOX = 4 |
| HAMMING = 5 |
| try: |
| from transformers.image_utils import SizeDict |
| except ImportError: |
| SizeDict = dict |
| |
| |
| _OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] |
| _OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] |
|
|
|
|
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
) -> Tuple[int, int]:
    """Compute a target size for dynamic-resolution resizing.

    The returned dimensions satisfy:
      1. Both are positive multiples of ``factor``.
      2. Their product is at most ``max_pixels`` and, when the rounded size
         would fall below ``min_pixels``, the image is scaled up so that the
         product reaches at least ``min_pixels``.

    The aspect ratio is approximately preserved; exact preservation is not
    possible because each side is snapped to a multiple of ``factor``.

    Adapted from the Qwen2.5-VL image processing implementation.
    Reference: https://github.com/QwenLM/Qwen2.5-VL (Apache 2.0 License)

    Args:
        height: Original image height (must be > 0).
        width: Original image width (must be > 0).
        factor: Rounding unit (default: 28 = patch_size * merge_size).
        min_pixels: Minimum pixel count (default: 3136).
        max_pixels: Maximum pixel count (default: 1003520).

    Returns:
        Tuple of (new_height, new_width).

    Raises:
        ValueError: If ``height`` or ``width`` is not positive, or if the
            long side is more than 200 times the short side.
    """
    # Reject degenerate sizes up front; otherwise the aspect-ratio check
    # below divides by zero (or silently accepts negative sizes).
    if height <= 0 or width <= 0:
        raise ValueError(
            f"height and width must be positive, got height={height}, width={width}"
        )

    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    # Snap each side to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor

    if h_bar * w_bar > max_pixels:
        # Too large: shrink uniformly and round DOWN so we stay under budget.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Too small: grow uniformly and round UP so we reach the minimum.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor

    # A side can only still be 0 here when rounding produced 0 and
    # min_pixels <= 0 disabled the upscale branch; never return an empty size.
    return max(factor, h_bar), max(factor, w_bar)
|
|
|
|
class HyperCLOVAXSeedFastImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False):
    """Extra keyword arguments accepted on top of the default fast-processor kwargs.

    All keys are optional (``total=False``).
    """

    # Pixel budget and patch geometry.
    min_pixels: Optional[int]
    max_pixels: Optional[int]
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]

    # Placeholder / boundary token strings for continuous image features.
    image_token: Optional[str]
    image_start_token: Optional[str]
    image_end_token: Optional[str]

    # Settings and token strings for the discrete image branch.
    discrete_image_size: Optional[int]
    discrete_token_size: Optional[int]
    discrete_image_ratios: Optional[List]
    discrete_image_token: Optional[str]
    discrete_image_start_token: Optional[str]
    discrete_image_end_token: Optional[str]
    use_discrete_token: Optional[bool]
    vision_eol_token: Optional[str]
    vision_eof_token: Optional[str]
|
|
|
|
class HyperCLOVAXSeedImageProcessor(BaseImageProcessorFast):
    """Fast image processor for HyperCLOVAX-SEED.

    Uses torchvision-based resize for dynamic resolution processing:
      1. Smart resize: adjusts image size so the pixel count lies within
         ``min_pixels`` and ``max_pixels`` and both sides are multiples of
         ``patch_size * merge_size``.
      2. Patchify: flattens each image into patches and reports the
         (t, h, w) patch grid plus the merged token count per image.
      3. Discrete image processing (only when ``use_discrete_token`` is set):
         resizes each image to a fixed square and picks the closest
         configured aspect-ratio pair.
    """

    resample = PILResampling.BICUBIC
    image_mean = _OPENAI_CLIP_MEAN
    image_std = _OPENAI_CLIP_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    size = {"shortest_edge": 3136, "longest_edge": 2073600}
    default_to_square = False
    min_pixels = 3136
    max_pixels = 2073600
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    image_token = "<|IMAGE_PAD|>"
    image_start_token = "<|image_start|>"
    image_end_token = "<|image_end|>"
    discrete_image_size = 384
    discrete_token_size = 27
    discrete_image_ratios = []
    discrete_image_token = "<|DISCRETE_IMAGE_PAD|>"
    discrete_image_start_token = "<|discrete_image_start|>"
    discrete_image_end_token = "<|discrete_image_end|>"
    use_discrete_token = False
    vision_eol_token = "<|vision_eol|>"
    vision_eof_token = "<|vision_eof|>"
    model_input_names = ["pixel_values"]
    valid_kwargs = HyperCLOVAXSeedFastImageProcessorKwargs

    def __init__(self, **kwargs):
        """Build the processor, reconciling ``size`` with ``min_pixels``/``max_pixels``.

        ``size`` uses the keys ``shortest_edge`` / ``longest_edge`` to hold the
        minimum / maximum pixel count. Explicit ``min_pixels`` / ``max_pixels``
        take precedence over the corresponding ``size`` entry.

        Raises:
            ValueError: If the resolved ``size`` lacks ``shortest_edge`` or
                ``longest_edge``.
        """
        size = kwargs.pop("size", None)
        min_pixels = kwargs.pop("min_pixels", None)
        max_pixels = kwargs.pop("max_pixels", None)

        # Always work on a private copy: the edits below must not leak into
        # the class-level default or into a dict owned by the caller.
        if size is None:
            size = {**self.size}
        elif isinstance(size, dict):
            size = dict(size)

        if min_pixels is not None:
            size["shortest_edge"] = min_pixels
            size.pop("min_pixels", None)
        if max_pixels is not None:
            size["longest_edge"] = max_pixels
            size.pop("max_pixels", None)
        if "shortest_edge" not in size or "longest_edge" not in size:
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")

        # Guarantee a fresh per-instance list instead of the shared class default.
        if kwargs.get("discrete_image_ratios") is None:
            kwargs["discrete_image_ratios"] = []

        super().__init__(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

        # `size` already incorporates any explicit min/max override, so it is
        # the single source of truth. Syncing unconditionally matters because
        # the class-level defaults are non-None: a "fill in only when None"
        # check would never fire and a size-only configuration would be
        # silently ignored by everything that reads min_pixels / max_pixels.
        self.min_pixels = self.size["shortest_edge"]
        self.max_pixels = self.size["longest_edge"]

        # Map "a:b" -> "<|vision_ratio_a:b|>" for every configured ratio pair.
        ratios = self.discrete_image_ratios if self.discrete_image_ratios is not None else []
        self.discrete_image_ratio_tokens = {
            f"{r[0]}:{r[1]}": f"<|vision_ratio_{r[0]}:{r[1]}|>"
            for r in ratios
        }

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Synchronize size <-> min_pixels/max_pixels before delegating to the base class.

        Precedence:
          * both ``min_pixels`` and ``max_pixels`` given -> they define ``size``;
          * otherwise, a given ``size`` defines ``min_pixels`` / ``max_pixels``;
          * otherwise, ``size`` falls back to a copy of ``self.size``.

        Raises:
            ValueError: If ``size`` is used as the source but lacks
                ``shortest_edge`` or ``longest_edge``.
        """
        if min_pixels is not None and max_pixels is not None:
            size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
        elif size is not None:
            if "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
            min_pixels = size["shortest_edge"]
            max_pixels = size["longest_edge"]
        else:
            size = {**self.size}

        return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

    def _find_best_ratio_token(
        self,
        original_size: List[int],
        discrete_image_ratios: Optional[List[List[int]]] = None,
    ) -> List[int]:
        """Pick the configured ratio pair closest to the image's aspect ratio.

        Each pair ``[a, b]`` is scored by ``abs(width / height - a / b)`` and
        the first pair with the smallest score wins.

        NOTE(review): the comparison above means pairs are effectively treated
        as [width, height]; earlier documentation described them as [h, w] —
        confirm the intended ordering against the config that defines them.

        Args:
            original_size: Original [height, width] of the image.
            discrete_image_ratios: Candidate ratio pairs. Defaults to
                ``self.discrete_image_ratios``.

        Returns:
            The best matching element of ``discrete_image_ratios``, or the
            tuple ``(1, 1)`` when no ratios are configured or either image
            dimension is 0.
        """
        if discrete_image_ratios is None:
            discrete_image_ratios = self.discrete_image_ratios

        if not discrete_image_ratios:
            return (1, 1)

        h, w = original_size
        if h == 0 or w == 0:
            return (1, 1)

        image_ratio = w / h
        diffs = [abs(image_ratio - i / j) for i, j in discrete_image_ratios]
        return discrete_image_ratios[diffs.index(min(diffs))]

    def _preprocess_continuous_image(
        self,
        images: List[torch.Tensor],
        do_resize: bool,
        size: SizeDict,
        interpolation: F.InterpolationMode,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list]],
        image_std: Optional[Union[float, list]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
    ) -> dict:
        """Preprocess images for continuous vision features.

        Performs smart resize -> rescale + normalize -> patchify.
        Uses torchvision/torch directly to avoid transformers version dependencies.

        Args:
            images: List of image tensors to preprocess (height and width are
                read from the last two dimensions).
            do_resize: Whether to perform resizing.
            size: Mapping whose ``shortest_edge`` / ``longest_edge`` entries
                hold the min / max pixel counts passed to ``smart_resize``.
            interpolation: torchvision InterpolationMode for resize.
            do_rescale: Whether to multiply by ``rescale_factor``.
            rescale_factor: Rescale factor (e.g. 1/255).
            do_normalize: Whether to perform normalization.
            image_mean: Normalization mean (float or per-channel list).
            image_std: Normalization std (float or per-channel list).
            patch_size: ViT patch size.
            temporal_patch_size: Temporal patch size.
            merge_size: Token reduction merge size.
            disable_grouping: Forwarded to ``group_images_by_shape``.

        Returns:
            Dictionary with:
              - "pixel_values": flattened patches of all images concatenated
                along dim 0, i.e. shape (total_patches, patch_dim) with
                patch_dim = C * temporal_patch_size * patch_size ** 2.
              - "image_grid_thw": Tensor of shape (N, 3) holding each image's
                (grid_t, grid_h, grid_w).
              - "num_image_tokens": Tensor of shape (N,) with per-image token
                counts after merging (grid product // merge_size ** 2).
        """
        # Pass 1: resize. Images are batched by shape so each group is resized
        # with a single call.
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                height, width = stacked_images.shape[-2:]
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                stacked_images = F.resize(
                    stacked_images,
                    [resized_height, resized_width],
                    interpolation=interpolation,
                    antialias=True,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Pass 2: regroup by the post-resize shape, then normalize and patchify.
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]

            patches = stacked_images.to(torch.float32)
            if do_rescale:
                patches = patches * rescale_factor
            if do_normalize:
                # Shape the statistics as (1, C, 1, 1) so they broadcast over
                # the batch and spatial dimensions.
                mean = torch.tensor(
                    image_mean if isinstance(image_mean, (list, tuple)) else [image_mean],
                    dtype=torch.float32,
                    device=patches.device,
                ).view(1, -1, 1, 1)
                std = torch.tensor(
                    image_std if isinstance(image_std, (list, tuple)) else [image_std],
                    dtype=torch.float32,
                    device=patches.device,
                ).view(1, -1, 1, 1)
                patches = (patches - mean) / std

            # Insert a temporal axis: (B, C, H, W) -> (B, 1, C, H, W).
            if patches.ndim == 4:
                patches = patches.unsqueeze(1)

            # Pad the temporal axis up to the next multiple of
            # temporal_patch_size by repeating the last frame. Padding by
            # exactly the missing amount keeps the reshape below valid for any
            # frame count (for a single frame this equals
            # temporal_patch_size - 1).
            remainder = patches.shape[1] % temporal_patch_size
            if remainder != 0:
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - remainder, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)

            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

            # Split H and W into (merge block, position in block, pixel in
            # patch); reshape (rather than view) also accepts non-contiguous
            # input.
            patches = patches.reshape(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            # Reorder so that the patches of one merge block are adjacent and
            # each patch's (C, t, ph, pw) values are contiguous.
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )

            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)

        num_image_tokens = image_grid_thw.prod(dim=1) // (merge_size ** 2)

        return {
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
            "num_image_tokens": num_image_tokens,
        }

    def _preprocess_discrete_image(
        self,
        images: List[torch.Tensor],
        original_sizes: List[Tuple[int, int]],
        interpolation: Optional[F.InterpolationMode],
    ) -> dict:
        """Preprocess images for discrete vision tokens.

        Resizes each image to a fixed square of side ``self.discrete_image_size``,
        scales it by 1/255, and finds the closest configured aspect-ratio pair.

        Args:
            images: List of image tensors to preprocess.
            original_sizes: (height, width) for each image, aligned with ``images``.
            interpolation: Interpolation method passed to the resize.

        Returns:
            Dictionary with:
              - "discrete_pixel_values": the resized images stacked along a
                new leading dimension, float32.
              - "discrete_image_ratios": Tensor of shape (N, 2).
              - "num_discrete_image_tokens": Tensor of shape (N,) with
                per-image discrete token counts.
        """
        target_size = [self.discrete_image_size, self.discrete_image_size]
        discrete_pixel_values_list = []
        discrete_image_ratios_list = []

        for img, (orig_h, orig_w) in zip(images, original_sizes):
            discrete_image_ratios_list.append(self._find_best_ratio_token([orig_h, orig_w]))

            # Resize with a temporary batch dimension, then drop it again.
            discrete_img = F.resize(
                img.unsqueeze(0),
                target_size,
                interpolation=interpolation,
                antialias=True,
            ).squeeze(0)

            # Fixed 1/255 scaling, independent of rescale_factor.
            # NOTE(review): assumes pixel values are in 0..255 — confirm.
            discrete_pixel_values_list.append(discrete_img.to(torch.float32) / 255.0)

        n = len(images)
        discrete_token_size = self.discrete_token_size
        # NOTE(review): this count (1 + size * (size + 1) + 1) differs from
        # get_num_image_tokens, which uses size ** 2 for the discrete branch;
        # presumably this one also counts ratio/eol/eof tokens — confirm which
        # count the text-side processor expects.
        num_discrete_per_image = 1 + discrete_token_size * (discrete_token_size + 1) + 1

        return {
            "discrete_pixel_values": torch.stack(discrete_pixel_values_list),
            "discrete_image_ratios": torch.tensor(discrete_image_ratios_list),
            "num_discrete_image_tokens": torch.full((n,), num_discrete_per_image, dtype=torch.long),
        }

    def _resolve_interpolation(self, kwargs: dict):
        """Return the torchvision interpolation mode requested via ``kwargs``.

        An explicit ``interpolation`` wins. Otherwise ``resample`` (falling
        back to ``self.resample``) is used: integer PIL codes are mapped to
        torchvision modes (unknown codes -> BICUBIC), non-integer values are
        passed through unchanged, and ``None`` yields BICUBIC.
        """
        resample = kwargs.get("resample", self.resample)
        interpolation = kwargs.get("interpolation")
        if interpolation is not None:
            return interpolation
        if resample is None:
            return F.InterpolationMode.BICUBIC
        if isinstance(resample, int):
            pil_to_torch = {
                0: F.InterpolationMode.NEAREST,
                1: F.InterpolationMode.LANCZOS,
                2: F.InterpolationMode.BILINEAR,
                3: F.InterpolationMode.BICUBIC,
                4: F.InterpolationMode.BOX,
                5: F.InterpolationMode.HAMMING,
            }
            return pil_to_torch.get(int(resample), F.InterpolationMode.BICUBIC)
        return resample

    def _preprocess(
        self,
        images: List[torch.Tensor],
        **kwargs,
    ) -> BatchFeature:
        """Main preprocessing entry point called by BaseImageProcessorFast.

        Accepts all parameters via **kwargs to handle API differences across
        transformers versions (e.g. 'interpolation' in v4.57.x vs 'resample' in newer).
        Any option missing from ``kwargs`` falls back to the instance attribute
        of the same name.

        Returns:
            BatchFeature containing pixel_values, image_grid_thw,
            num_image_tokens, and — when ``self.use_discrete_token`` is set —
            the discrete processing results.
        """
        do_resize = kwargs.get("do_resize", self.do_resize)
        size = kwargs.get("size", self.size)
        do_rescale = kwargs.get("do_rescale", self.do_rescale)
        rescale_factor = kwargs.get("rescale_factor", self.rescale_factor)
        do_normalize = kwargs.get("do_normalize", self.do_normalize)
        image_mean = kwargs.get("image_mean", self.image_mean)
        image_std = kwargs.get("image_std", self.image_std)
        patch_size = kwargs.get("patch_size", self.patch_size)
        temporal_patch_size = kwargs.get("temporal_patch_size", self.temporal_patch_size)
        merge_size = kwargs.get("merge_size", self.merge_size)
        disable_grouping = kwargs.get("disable_grouping", None)
        return_tensors = kwargs.get("return_tensors", None)

        interpolation = self._resolve_interpolation(kwargs)

        # Record the input sizes up front; the discrete branch picks its
        # aspect-ratio pair from these.
        if self.use_discrete_token:
            original_sizes = [(img.shape[-2], img.shape[-1]) for img in images]

        data = self._preprocess_continuous_image(
            images,
            do_resize=do_resize,
            size=size,
            interpolation=interpolation,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            merge_size=merge_size,
            disable_grouping=disable_grouping,
        )

        if self.use_discrete_token:
            discrete_result = self._preprocess_discrete_image(
                images,
                original_sizes=original_sizes,
                interpolation=interpolation,
            )
            data.update(discrete_result)

        return BatchFeature(data=data, tensor_type=return_tensors)

    def get_num_image_tokens(
        self,
        image_width: Optional[int] = None,
        image_height: Optional[int] = None,
        pixel_values: Optional[torch.Tensor] = None,
        include_boundary_tokens: bool = False,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        return_tuple: Optional[bool] = None,
    ) -> Union[int, Tuple[int, int]]:
        """Compute the number of image tokens for the given input.

        Args:
            image_width: Image width (used when pixel_values is None).
            image_height: Image height (used when pixel_values is None).
            pixel_values: Pre-computed patches. A 2-D tensor is treated as one
                image's patches; anything else is iterated and each element's
                first dimension is counted.
            include_boundary_tokens: Whether to add 2 tokens (start/end) to
                each of the continuous and discrete counts.
            min_pixels: Minimum pixel count. Defaults to self.min_pixels.
            max_pixels: Maximum pixel count. Defaults to self.max_pixels.
            patch_size: ViT patch size. Defaults to self.patch_size.
            merge_size: Token reduction merge size. Defaults to self.merge_size.
            return_tuple: If True, return (continuous, discrete) tuple.
                Otherwise return the sum.

        Returns:
            Token count as int, or (continuous, discrete) tuple if return_tuple is True.

        Raises:
            ValueError: If ``pixel_values`` is None and ``image_width`` or
                ``image_height`` is missing.
        """
        patch_size = patch_size if patch_size is not None else self.patch_size
        merge_size = merge_size if merge_size is not None else self.merge_size
        min_pixels = min_pixels if min_pixels is not None else self.min_pixels
        max_pixels = max_pixels if max_pixels is not None else self.max_pixels
        merge_area = merge_size ** 2

        num_continuous_tokens, num_discrete_tokens = 0, 0
        if pixel_values is None:
            if image_height is None or image_width is None:
                raise ValueError(
                    "image_height and image_width are required when pixel_values is not provided."
                )
            resized_height, resized_width = smart_resize(
                image_height,
                image_width,
                factor=patch_size * merge_size,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
            grid_h = resized_height // patch_size
            grid_w = resized_width // patch_size
            num_continuous_tokens = (grid_h // merge_size) * (grid_w // merge_size)
        elif len(pixel_values.shape) == 2:
            num_continuous_tokens = pixel_values.shape[0] // merge_area
        else:
            num_continuous_tokens = sum(
                _pixel_values.shape[0] // merge_area for _pixel_values in pixel_values
            )
        if include_boundary_tokens:
            num_continuous_tokens += 2

        if self.use_discrete_token:
            # NOTE(review): _preprocess_discrete_image reports
            # 1 + size * (size + 1) + 1 tokens per image, which differs from
            # this size ** 2 count — confirm which one callers should rely on.
            num_discrete_tokens = self.discrete_token_size ** 2
            if include_boundary_tokens:
                num_discrete_tokens += 2

        if return_tuple:
            return (num_continuous_tokens, num_discrete_tokens)
        return num_continuous_tokens + num_discrete_tokens

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        *args,
        **kwargs,
    ):
        """Save the processor to a directory.

        Registers the class for auto-class loading before saving.

        Args:
            save_directory: Directory path to save the processor.

        Returns:
            Whatever the parent ``save_pretrained`` returns, passed through
            unchanged instead of being discarded.
        """
        self.register_for_auto_class()
        return super().save_pretrained(save_directory, *args, **kwargs)
|
|