| from PIL import Image
|
| from torch import Tensor, stack
|
| from typing import Union, List
|
|
|
| from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
|
| from timm import create_model
|
| from timm.data import resolve_data_config
|
| from timm.data.transforms_factory import create_transform
|
|
|
class EfficientNetImageProcessor(BaseImageProcessor):
    """Image processor that applies timm's evaluation transforms for a given model.

    The preprocessing pipeline is built with ``timm.data.create_transform`` from
    the data config that ``timm.data.resolve_data_config`` resolves for
    ``model_name``. The resulting tensors are returned under the
    ``"pixel_values"`` key.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_name: str,
        **kwargs,
    ):
        """Resolve and store the timm data config for ``model_name``.

        Parameters
        ----------
        model_name : str
            Name of a timm architecture, passed to ``timm.create_model``.
        **kwargs
            Forwarded unchanged to ``BaseImageProcessor.__init__``.
        """
        self.model_name = model_name
        # The model is built only so timm can resolve its data config
        # (input size, mean/std, interpolation, ...). pretrained=False means
        # no pretrained weights are requested.
        model = create_model(model_name, pretrained=False)
        self.config = resolve_data_config({}, model=model)
        # NOTE(review): super().__init__ runs after the assignments above. If the
        # base class sets leftover kwargs as attributes, a "config" kwarg (e.g.
        # from a saved processor) would override the freshly resolved config.
        # Confirm this is intended.
        super().__init__(**kwargs)

    def preprocess(
        self,
        images: Union[List[Union[Image.Image, Tensor]], Image.Image, Tensor],
        return_tensors: Union[str, None] = None,
    ) -> BatchFeature:
        """
        Preprocesses input images by applying transformations and returning them as a BatchFeature.

        Parameters
        ----------
        images : Union[List[PIL.Image.Image, torch.Tensor], PIL.Image.Image, torch.Tensor]
            A single image, or a list/tuple of images in one of the accepted formats.
        return_tensors : str, optional
            Forwarded to ``BatchFeature`` as ``tensor_type``. Accepted so callers
            may pass the conventional ``return_tensors=...`` keyword. ``None``
            (the default) returns the stacked ``torch.Tensor`` unconverted.

        Returns
        -------
        BatchFeature
            ``{"pixel_values": ...}`` holding the transformed images stacked
            with ``torch.stack`` along a new leading dimension.

        Raises
        ------
        ValueError
            If ``images`` is an empty list or tuple.
        TypeError
            If any element is not a ``PIL.Image.Image`` or ``torch.Tensor``.
        """
        # A bare image/tensor is treated as a batch of one; lists and tuples
        # are treated as collections of images.
        images = list(images) if isinstance(images, (list, tuple)) else [images]

        if not images:
            raise ValueError("Received an empty list of images")

        # Validate every element, not just the first, so a bad entry fails
        # here with a clear message instead of inside the transform pipeline.
        for image in images:
            if not isinstance(image, (Image.Image, Tensor)):
                raise TypeError(
                    f"Expected image to be of type PIL.Image.Image or torch.Tensor, "
                    f"but got {type(image).__name__} instead."
                )

        transforms = create_transform(**self.config)
        pixel_values = stack([transforms(image) for image in images])

        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)
|
|
|
# Names exported by ``from <module> import *``.
__all__ = ["EfficientNetImageProcessor"]