# coding=utf-8 # Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ HyperCLOVAX-SEED Image Processor (Fast) Implements dynamic resolution image processing: - Smart resize: adjusts image to fit within min_pixels and max_pixels - Vision token calculation: token reduction using merge_size - Discrete image processing: separate processing for discrete vision tokens Based on BaseImageProcessorFast with torchvision resize. """ import math import os import PIL from typing import List, Optional, Tuple, TypeAlias, Union import torch from torchvision.transforms.v2 import functional as F try: from transformers.image_processing_utils import BatchFeature except ImportError: from transformers import BatchFeature try: from transformers.image_processing_backends import ( BaseImageProcessorFast, group_images_by_shape, reorder_images, ) except ImportError: # transformers < v5.3.0 from transformers.image_processing_utils_fast import ( BaseImageProcessorFast, group_images_by_shape, reorder_images, ) try: from transformers.image_processing_utils_fast import DefaultFastImageProcessorKwargs except ImportError: from transformers.processing_utils import ImagesKwargs as DefaultFastImageProcessorKwargs # transformers < v5.3.0 try: from PIL.Image import Resampling as PILResampling except (ImportError, AttributeError): # Pillow < 9.1.0 class PILResampling: NEAREST = 0 LANCZOS = 1 BILINEAR = 2 BICUBIC = 3 BOX = 4 HAMMING = 5 
try:
    from transformers.image_utils import SizeDict
except ImportError:
    SizeDict = dict  # transformers < 4.46


# OpenAI CLIP normalization constants
# Source: transformers.image_utils.OPENAI_CLIP_MEAN / OPENAI_CLIP_STD
_OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
_OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]

# Integer PIL resampling filter ids -> torchvision interpolation modes.
# Ids not listed here fall back to BICUBIC when resolved.
_PIL_TO_TORCH_INTERPOLATION = {
    0: F.InterpolationMode.NEAREST,
    1: F.InterpolationMode.LANCZOS,
    2: F.InterpolationMode.BILINEAR,
    3: F.InterpolationMode.BICUBIC,
    4: F.InterpolationMode.BOX,
    5: F.InterpolationMode.HAMMING,
}


def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 56 * 56,
    max_pixels: int = 14 * 14 * 4 * 1280,
) -> Tuple[int, int]:
    """Smart resize for dynamic resolution.

    Computes target dimensions that:
    1. Are both multiples of ``factor``.
    2. Keep the total pixel count between ``min_pixels`` and ``max_pixels`` (up to
       rounding; each side is never shrunk below ``factor``).

    Adapted from the Qwen2.5-VL image processing implementation.
    Reference: https://github.com/QwenLM/Qwen2.5-VL (Apache 2.0 License)

    Args:
        height: Original image height. Must be positive.
        width: Original image width. Must be positive.
        factor: Rounding unit (default: 28 = patch_size * merge_size).
        min_pixels: Minimum pixel count (default: 3136).
        max_pixels: Maximum pixel count (default: 1003520).

    Returns:
        Tuple of (new_height, new_width).

    Raises:
        ValueError: If a dimension is not positive, or the aspect ratio
            (long side / short side) exceeds 200.
    """
    # Guard first: a zero side would otherwise surface as a ZeroDivisionError below.
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got height={height}, width={width}")
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}")

    # Snap each side to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too many pixels: scale both sides down by the same factor, rounding down.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, math.floor(height / beta / factor) * factor)
        w_bar = max(factor, math.floor(width / beta / factor) * factor)
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: scale both sides up by the same factor, rounding up.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar


class HyperCLOVAXSeedFastImageProcessorKwargs(DefaultFastImageProcessorKwargs, total=False):
    """Extra keyword arguments accepted by :class:`HyperCLOVAXSeedImageProcessor`."""

    # Continuous (dynamic-resolution) path
    min_pixels: Optional[int]
    max_pixels: Optional[int]
    patch_size: Optional[int]
    temporal_patch_size: Optional[int]
    merge_size: Optional[int]
    # Token parameters
    image_token: Optional[str]
    image_start_token: Optional[str]
    image_end_token: Optional[str]
    # Discrete image parameters
    discrete_image_size: Optional[int]
    discrete_token_size: Optional[int]
    discrete_image_ratios: Optional[List]
    discrete_image_token: Optional[str]
    discrete_image_start_token: Optional[str]
    discrete_image_end_token: Optional[str]
    use_discrete_token: Optional[bool]
    vision_eol_token: Optional[str]
    vision_eof_token: Optional[str]


class HyperCLOVAXSeedImageProcessor(BaseImageProcessorFast):
    """Fast image processor for HyperCLOVAX-SEED.

    Uses torchvision-based resize for dynamic resolution processing:
    1. Smart resize: adjusts image size so the pixel count is within min_pixels and max_pixels.
    2. Vision token calculation: uses merge_size for token reduction.
    3. Discrete image processing: optional fixed-size path for discrete vision tokens
       (enabled by ``use_discrete_token``).

    ``size["shortest_edge"]`` / ``size["longest_edge"]`` hold the same values as
    ``min_pixels`` / ``max_pixels`` (pixel counts, not edge lengths).
    """

    # Class-level defaults
    resample = PILResampling.BICUBIC
    image_mean = _OPENAI_CLIP_MEAN
    image_std = _OPENAI_CLIP_STD
    do_resize = True
    do_rescale = True
    do_normalize = True
    do_convert_rgb = True
    size = {"shortest_edge": 3136, "longest_edge": 2073600}
    default_to_square = False
    min_pixels = 3136
    max_pixels = 2073600
    patch_size = 14
    temporal_patch_size = 2
    merge_size = 2
    image_token = "<|IMAGE_PAD|>"
    image_start_token = "<|image_start|>"
    image_end_token = "<|image_end|>"
    discrete_image_size = 384
    discrete_token_size = 27
    discrete_image_ratios = []
    discrete_image_token = "<|DISCRETE_IMAGE_PAD|>"
    discrete_image_start_token = "<|discrete_image_start|>"
    discrete_image_end_token = "<|discrete_image_end|>"
    use_discrete_token = False
    vision_eol_token = "<|vision_eol|>"
    vision_eof_token = "<|vision_eof|>"
    model_input_names = ["pixel_values"]
    valid_kwargs = HyperCLOVAXSeedFastImageProcessorKwargs

    def __init__(self, **kwargs):
        """Initialize the processor, keeping ``size`` and ``min_pixels``/``max_pixels`` in sync.

        Explicit ``min_pixels`` / ``max_pixels`` override the matching entry of ``size``.
        The resolved bounds are always forwarded to the base class, so the instance
        attributes match ``self.size`` even when a config only provides ``size``.

        Raises:
            ValueError: If the resolved ``size`` lacks ``shortest_edge`` or ``longest_edge``.
        """
        size = kwargs.pop("size", None)
        min_pixels = kwargs.pop("min_pixels", None)
        max_pixels = kwargs.pop("max_pixels", None)

        # Copy so neither the caller's dict nor the class-level default is mutated.
        size = dict(self.size if size is None else size)
        if min_pixels is not None:
            size["shortest_edge"] = min_pixels
            size.pop("min_pixels", None)
        if max_pixels is not None:
            size["longest_edge"] = max_pixels
            size.pop("max_pixels", None)
        if "shortest_edge" not in size or "longest_edge" not in size:
            raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")

        # Normalize discrete_image_ratios: None -> []
        if kwargs.get("discrete_image_ratios") is None:
            kwargs["discrete_image_ratios"] = []

        # Derive the bounds from `size` so that a size-only config does not leave the
        # class defaults (3136 / 2073600) in min_pixels / max_pixels.
        super().__init__(
            size=size,
            min_pixels=size["shortest_edge"],
            max_pixels=size["longest_edge"],
            **kwargs,
        )

        # Safety net in case a base-class version did not set the attributes.
        if self.min_pixels is None:
            self.min_pixels = self.size["shortest_edge"]
        if self.max_pixels is None:
            self.max_pixels = self.size["longest_edge"]

        # Build "a:b" -> "<|vision_ratio_a:b|>" lookup from discrete_image_ratios.
        ratios = self.discrete_image_ratios if self.discrete_image_ratios is not None else []
        self.discrete_image_ratio_tokens = {
            f"{r[0]}:{r[1]}": f"<|vision_ratio_{r[0]}:{r[1]}|>" for r in ratios
        }

    def _further_process_kwargs(
        self,
        size: Optional[SizeDict] = None,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Synchronize ``size`` with ``min_pixels`` / ``max_pixels`` for a preprocessing call.

        Precedence:
        - Both bounds given: ``size`` is rebuilt from them.
        - Otherwise the base is ``size`` (validated) or ``self.size``, and any single
          explicit bound overrides its matching entry.
        ``min_pixels`` / ``max_pixels`` are then re-derived from the final ``size``.

        Raises:
            ValueError: If ``size`` is consulted and lacks ``shortest_edge`` or ``longest_edge``.
        """
        if min_pixels is not None and max_pixels is not None:
            size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
        else:
            if size is None:
                size = {**self.size}
            elif "shortest_edge" not in size or "longest_edge" not in size:
                raise ValueError("size must contain 'shortest_edge' and 'longest_edge' keys.")
            else:
                size = dict(size)
            # A lone explicit bound is applied instead of being silently dropped.
            if min_pixels is not None:
                size["shortest_edge"] = min_pixels
            if max_pixels is not None:
                size["longest_edge"] = max_pixels
            min_pixels = size["shortest_edge"]
            max_pixels = size["longest_edge"]
        return super()._further_process_kwargs(size=size, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs)

    def _find_best_ratio_token(
        self,
        original_size: List[int],
        discrete_image_ratios: Optional[List[List[int]]] = None,
    ) -> Union[List[int], Tuple[int, int]]:
        """Find the configured ratio pair closest to the image's aspect ratio.

        Args:
            original_size: ``[height, width]`` of the image before resizing.
            discrete_image_ratios: Candidate ``[a, b]`` pairs. Defaults to
                ``self.discrete_image_ratios``.

        Returns:
            The element of ``discrete_image_ratios`` whose ``a / b`` is closest to
            ``width / height`` (the first one wins on ties), or ``(1, 1)`` when there are
            no candidates or a side is zero.
        """
        # NOTE(review): a / b is compared with width / height, so pairs behave as
        # (width, height) ratios. Confirm the orientation of `discrete_image_ratios` in the
        # model config.
        discrete_image_ratios = discrete_image_ratios if discrete_image_ratios is not None else self.discrete_image_ratios
        if not discrete_image_ratios:
            return (1, 1)
        h, w = original_size
        if h == 0 or w == 0:
            return (1, 1)
        image_ratio = w / h
        diffs = [abs(image_ratio - i / j) for i, j in discrete_image_ratios]
        return discrete_image_ratios[diffs.index(min(diffs))]

    @staticmethod
    def _resolve_interpolation(interpolation, resample) -> F.InterpolationMode:
        """Normalize ``interpolation`` / ``resample`` to a torchvision InterpolationMode.

        An explicit ``interpolation`` wins. Otherwise an integer (or IntEnum) ``resample`` is
        mapped through ``_PIL_TO_TORCH_INTERPOLATION``, any other non-None ``resample`` is
        returned unchanged, and None falls back to BICUBIC.
        """
        if interpolation is not None:
            return interpolation
        if resample is None:
            return F.InterpolationMode.BICUBIC
        if isinstance(resample, int):
            return _PIL_TO_TORCH_INTERPOLATION.get(int(resample), F.InterpolationMode.BICUBIC)
        return resample  # already an InterpolationMode

    def _preprocess_continuous_image(
        self,
        images: List[torch.Tensor],
        do_resize: bool,
        size: SizeDict,
        interpolation: F.InterpolationMode,
        do_rescale: bool,
        rescale_factor: float,
        do_normalize: bool,
        image_mean: Optional[Union[float, list]],
        image_std: Optional[Union[float, list]],
        patch_size: int,
        temporal_patch_size: int,
        merge_size: int,
        disable_grouping: Optional[bool],
    ) -> dict:
        """Preprocess images for continuous vision features.

        Performs smart resize -> rescale+normalize -> patchify. Uses torchvision/torch
        directly to avoid transformers version dependencies.

        Args:
            images: Image tensors; the last two dimensions are height and width.
            do_resize: Whether to perform smart resizing.
            size: Mapping with ``shortest_edge`` (min pixels) and ``longest_edge`` (max pixels).
            interpolation: torchvision InterpolationMode for resize.
            do_rescale: Whether to multiply by ``rescale_factor``.
            rescale_factor: Rescale factor (e.g. 1/255).
            do_normalize: Whether to apply ``(x - mean) / std``.
            image_mean: Normalization mean (float or per-channel list/tuple).
            image_std: Normalization std (float or per-channel list/tuple).
            patch_size: ViT patch size.
            temporal_patch_size: Temporal patch size.
            merge_size: Token reduction merge size.
            disable_grouping: Whether to disable image grouping.

        Returns:
            Dictionary with:
                - "pixel_values": Tensor of shape (total_patches, patch_dim), all images concatenated.
                - "image_grid_thw": Tensor of shape (N, 3).
                - "num_image_tokens": Tensor of shape (N,) with per-image token counts.
        """
        # 1. Group same-shaped images and smart-resize each group with one call.
        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            if do_resize:
                height, width = stacked_images.shape[-2:]
                resized_height, resized_width = smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                # Use torchvision directly — avoids transformers version-specific
                # self.resize() signature differences (resample vs interpolation).
                stacked_images = F.resize(
                    stacked_images,
                    [resized_height, resized_width],
                    interpolation=interpolation,
                    antialias=True,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = reorder_images(resized_images_grouped, grouped_images_index)

        # Normalization constants do not depend on the group; resolve them once.
        mean_values = image_mean if isinstance(image_mean, (list, tuple)) else [image_mean]
        std_values = image_std if isinstance(image_std, (list, tuple)) else [image_std]

        # 2. Group again -> rescale+normalize -> patchify
        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            patches = stacked_images.to(torch.float32)
            if do_rescale:
                patches = patches * rescale_factor
            if do_normalize:
                mean = torch.tensor(mean_values, dtype=torch.float32, device=patches.device).view(1, -1, 1, 1)
                std = torch.tensor(std_values, dtype=torch.float32, device=patches.device).view(1, -1, 1, 1)
                patches = (patches - mean) / std

            # Add temporal dimension for images (ndim == 4 means no temporal dim)
            if patches.ndim == 4:
                patches = patches.unsqueeze(1)
            # Pad the temporal dimension up to a multiple of temporal_patch_size by
            # repeating the last frame.
            remainder = patches.shape[1] % temporal_patch_size
            if remainder != 0:
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - remainder, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)

            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

            # Patchify: reshape -> permute -> flatten. Patches of the same merge window
            # end up adjacent in the flattened sequence.
            patches = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
                patch_size,
            )
            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
            flatten_patches = patches.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )

            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
        processed_grids = reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)
        num_image_tokens = image_grid_thw.prod(dim=1) // (merge_size ** 2)

        return {
            "pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
            "num_image_tokens": num_image_tokens,
        }

    def _preprocess_discrete_image(
        self,
        images: List[torch.Tensor],
        original_sizes: List[Tuple[int, int]],
        interpolation: Optional[F.InterpolationMode],
    ) -> dict:
        """Preprocess images for discrete vision tokens.

        Resizes each image to a fixed square of ``discrete_image_size`` and finds the
        closest aspect-ratio pair from the original size.

        Args:
            images: Image tensors of shape (C, H, W).
            original_sizes: (height, width) of each image before any resizing.
            interpolation: Interpolation method passed to torchvision resize.

        Returns:
            Dictionary with:
                - "discrete_pixel_values": Tensor of shape (N, C, discrete_image_size, discrete_image_size),
                  values divided by 255 and not normalized.
                - "discrete_image_ratios": Tensor of shape (N, 2).
                - "num_discrete_image_tokens": Tensor of shape (N,) with per-image discrete token counts.
        """
        discrete_pixel_values_list = []
        discrete_image_ratios_list = []

        for img, (orig_h, orig_w) in zip(images, original_sizes):
            best_ratio = self._find_best_ratio_token([orig_h, orig_w])

            # Resize to fixed discrete_image_size x discrete_image_size (torchvision)
            discrete_img = F.resize(
                img.unsqueeze(0),
                [self.discrete_image_size, self.discrete_image_size],
                interpolation=interpolation,
                antialias=True,
            ).squeeze(0)
            # Match torchvision to_tensor: float32 / 255.0 (no normalize)
            discrete_img = discrete_img.to(torch.float32) / 255.0

            discrete_pixel_values_list.append(discrete_img)
            discrete_image_ratios_list.append(best_ratio)

        discrete_token_size = self.discrete_token_size
        # ratio_token(1) + discrete_token_size rows * (discrete_token_size tokens + vision_eol(1)) + vision_eof(1)
        # NOTE(review): get_num_image_tokens() counts discrete tokens as
        # discrete_token_size ** 2 (+2 boundary tokens), which differs from this count;
        # confirm which layout the prompt builder expects.
        num_discrete_per_image = 1 + discrete_token_size * (discrete_token_size + 1) + 1

        return {
            "discrete_pixel_values": torch.stack(discrete_pixel_values_list),
            "discrete_image_ratios": torch.tensor(discrete_image_ratios_list),
            "num_discrete_image_tokens": torch.full((len(images),), num_discrete_per_image, dtype=torch.long),
        }

    def _preprocess(
        self,
        images: List[torch.Tensor],
        **kwargs,
    ) -> BatchFeature:
        """Main preprocessing entry point called by BaseImageProcessorFast.

        Accepts all parameters via **kwargs to handle API differences across
        transformers versions (e.g. 'interpolation' in v4.57.x vs 'resample' in newer).
        Options that are missing or None fall back to the instance defaults.

        Returns:
            BatchFeature containing pixel_values, image_grid_thw, num_image_tokens, and,
            when ``use_discrete_token`` is enabled, the discrete processing results.
        """

        def _option(name, default):
            value = kwargs.get(name)
            return default if value is None else value

        do_resize = _option("do_resize", self.do_resize)
        size = _option("size", self.size)
        do_rescale = _option("do_rescale", self.do_rescale)
        rescale_factor = _option("rescale_factor", self.rescale_factor)
        do_normalize = _option("do_normalize", self.do_normalize)
        image_mean = _option("image_mean", self.image_mean)
        image_std = _option("image_std", self.image_std)
        patch_size = _option("patch_size", self.patch_size)
        temporal_patch_size = _option("temporal_patch_size", self.temporal_patch_size)
        merge_size = _option("merge_size", self.merge_size)
        # Declared in valid_kwargs, so honor a per-call override.
        use_discrete_token = _option("use_discrete_token", self.use_discrete_token)
        disable_grouping = kwargs.get("disable_grouping")
        return_tensors = kwargs.get("return_tensors")

        # 1. BaseImageProcessorFast passes "interpolation" (v4.57.x) or "resample" (v5.3.x);
        #    normalize both to a single InterpolationMode.
        interpolation = self._resolve_interpolation(
            kwargs.get("interpolation"), _option("resample", self.resample)
        )

        # 2. Record original sizes before any transforms (needed for discrete processing)
        original_sizes = [(img.shape[-2], img.shape[-1]) for img in images] if use_discrete_token else None

        # 3. Continuous processing
        data = self._preprocess_continuous_image(
            images,
            do_resize=do_resize,
            size=size,
            interpolation=interpolation,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            merge_size=merge_size,
            disable_grouping=disable_grouping,
        )

        # 4. Discrete processing
        if use_discrete_token:
            data.update(
                self._preprocess_discrete_image(
                    images,
                    original_sizes=original_sizes,
                    interpolation=interpolation,
                )
            )

        return BatchFeature(data=data, tensor_type=return_tensors)

    def get_num_image_tokens(
        self,
        image_width: Optional[int] = None,
        image_height: Optional[int] = None,
        pixel_values: Optional[torch.Tensor] = None,
        include_boundary_tokens: bool = False,
        min_pixels: Optional[int] = None,
        max_pixels: Optional[int] = None,
        patch_size: Optional[int] = None,
        merge_size: Optional[int] = None,
        return_tuple: Optional[bool] = None,
    ) -> Union[int, Tuple[int, int]]:
        """Compute the number of image tokens for the given input.

        Args:
            image_width: Image width (required when pixel_values is None).
            image_height: Image height (required when pixel_values is None).
            pixel_values: Pre-computed pixel values; a 2-D tensor is treated as one
                flattened patch sequence, otherwise each leading entry is counted separately.
            include_boundary_tokens: Whether to add start/end boundary tokens (2 each).
            min_pixels: Minimum pixel count. Defaults to self.min_pixels.
            max_pixels: Maximum pixel count. Defaults to self.max_pixels.
            patch_size: ViT patch size. Defaults to self.patch_size.
            merge_size: Token reduction merge size. Defaults to self.merge_size.
            return_tuple: If True, return (continuous, discrete) tuple. Otherwise return the sum.

        Returns:
            Token count as int, or (continuous, discrete) tuple if return_tuple is True.

        Raises:
            ValueError: If pixel_values is None and either image dimension is missing.
        """
        patch_size = patch_size if patch_size is not None else self.patch_size
        merge_size = merge_size if merge_size is not None else self.merge_size
        min_pixels = min_pixels if min_pixels is not None else self.min_pixels
        max_pixels = max_pixels if max_pixels is not None else self.max_pixels
        tokens_per_merge = merge_size ** 2

        num_continuous_tokens, num_discrete_tokens = 0, 0
        if pixel_values is None:
            if image_height is None or image_width is None:
                raise ValueError("image_height and image_width are required when pixel_values is None.")
            resized_height, resized_width = smart_resize(
                image_height, image_width, factor=patch_size * merge_size, min_pixels=min_pixels, max_pixels=max_pixels
            )
            grid_h = resized_height // patch_size
            grid_w = resized_width // patch_size
            num_continuous_tokens = (grid_h // merge_size) * (grid_w // merge_size)
        elif len(pixel_values.shape) == 2:
            num_continuous_tokens = pixel_values.shape[0] // tokens_per_merge
        else:
            num_continuous_tokens = sum(
                _pixel_values.shape[0] // tokens_per_merge for _pixel_values in pixel_values
            )

        if include_boundary_tokens:
            num_continuous_tokens += 2

        if self.use_discrete_token:
            num_discrete_tokens = self.discrete_token_size ** 2
            if include_boundary_tokens:
                num_discrete_tokens += 2

        if return_tuple:
            return (num_continuous_tokens, num_discrete_tokens)
        return num_continuous_tokens + num_discrete_tokens

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        *args,
        **kwargs,
    ) -> None:
        """Save the processor to a directory.

        Registers for auto class before saving.

        Args:
            save_directory: Directory path to save the processor.
        """
        self.register_for_auto_class()
        super().save_pretrained(save_directory, *args, **kwargs)