import base64
import logging
import os
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel


@dataclass
class ImageEncodingResult:
    """
    Container for the embeddings produced for a single image.

    Attributes:
        image_encoded: Nested list of floats, one inner list per row of the
            model output for the image.
        image_encoded_average: Flat list of floats holding the averaged
            embedding for the image.
    """

    image_encoded: List[List[float]]
    image_encoded_average: List[float]


class EndpointHandler:
    """
    A handler class for processing images and generating embeddings using a pre-trained model.

    Attributes:
        processor: The pre-trained image processor.
        model: The pre-trained model for generating embeddings.
        device: The device (CPU or CUDA) used to run model inference.
        logger: Logger used for progress and error messages.
    """

    def __init__(self, path: str = ""):
        """
        Initializes the EndpointHandler with the model and processor loaded from ``path``.

        Args:
            path (str): Location handed to ``from_pretrained`` for both the
                image processor and the model.

        Raises:
            Exception: Whatever ``from_pretrained`` raised; it is logged and
                re-raised unchanged.
        """
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Prefer the GPU when one is visible to torch.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.logger.info(f"Using device: {self.device}")

        self.logger.info("Loading model and processor from the current directory.")
        try:
            self.processor = AutoImageProcessor.from_pretrained(path)
            # NOTE(review): trust_remote_code=True lets the checkpoint ship and
            # run its own Python code; only point `path` at trusted sources.
            self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(
                self.device
            )
            self.logger.info("Model and processor loaded successfully.")
        except Exception as e:
            self.logger.error(f"Failed to load model or processor: {e}")
            raise

    def _resize_image_if_large(
        self, image: "Image.Image", max_size: int = 1080
    ) -> "Image.Image":
        """
        Resizes an image if its dimensions exceed the specified maximum size.

        The aspect ratio is kept: the longer side becomes ``max_size`` and the
        shorter side is scaled by the same factor. Images that already fit are
        returned untouched (same object).

        Args:
            image (Image.Image): Input image.
            max_size (int): Maximum size for the image dimensions.

        Returns:
            Image.Image: Resized image, or the input image when no resize is needed.
        """
        width, height = image.size
        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            # int() truncates, so a very thin image (e.g. 1 x 5000) would end
            # up with a zero-sized side, which resize() is expected to reject;
            # keep at least one pixel per side.
            new_width = max(1, int(width * scale))
            new_height = max(1, int(height * scale))
            image = image.resize((new_width, new_height), resample=Image.BILINEAR)
        return image

    def _encode_image(self, image: "Image.Image") -> "ImageEncodingResult":
        """
        Encodes an image into embeddings using the model.

        Args:
            image (Image.Image): Input image.

        Returns:
            ImageEncodingResult: Dataclass containing the encoded embeddings and their average.

        Raises:
            Exception: Any failure during preprocessing or inference is logged
                and re-raised unchanged.
        """
        try:
            image = self._resize_image_if_large(image)

            inputs = self.processor(image, return_tensors="pt").to(self.device)
            with torch.inference_mode():
                outputs = self.model(**inputs)
                last_hidden_state = outputs.last_hidden_state
                # Assumes last_hidden_state is (1, rows, features) for the one
                # image passed in -- TODO confirm for the deployed model.
                # Index the leading axis explicitly: a bare squeeze() would
                # also drop any other size-1 axis and change the result's rank.
                image_encoded = last_hidden_state[0].tolist()
                image_encoded_average = last_hidden_state.mean(dim=1)[0].tolist()

            return ImageEncodingResult(
                image_encoded=image_encoded,
                image_encoded_average=image_encoded_average,
            )
        except Exception as e:
            self.logger.error(f"Error encoding image: {e}")
            raise

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Processes input data containing base64-encoded images and generates embeddings.

        Args:
            data (Dict[str, Any]): Dictionary whose ``"inputs"`` entry is a
                list of base64-encoded image strings. A single base64 string
                is also accepted and treated as a one-element list.

        Returns:
            Dict[str, Any]: ``{"results": [...]}`` with one dict per image
            (keys ``image_encoded`` and ``image_encoded_average``), or
            ``{"error": "..."}`` describing the first problem encountered.
        """
        images_data = data.get("inputs", [])

        if not images_data:
            return {"error": "No image data provided."}

        if isinstance(images_data, str):
            # Iterating a bare string would yield single characters, each of
            # which would then fail to decode; treat it as one image instead.
            images_data = [images_data]
        elif not isinstance(images_data, (list, tuple)):
            # Anything else cannot be a sequence of base64 strings; report it
            # the same way as a wrongly typed element.
            self.logger.error("Images should be base64-encoded strings.")
            return {"error": "Images should be base64-encoded strings."}

        results = []
        for img_data in images_data:
            if not isinstance(img_data, str):
                self.logger.error("Images should be base64-encoded strings.")
                return {"error": "Images should be base64-encoded strings."}

            try:
                image_bytes = base64.b64decode(img_data)
                # convert() returns a new image, so the opened source image
                # can be closed as soon as the RGB copy exists.
                with Image.open(BytesIO(image_bytes)) as source_image:
                    image = source_image.convert("RGB")

                results.append(self._encode_image(image))
            except Exception as e:
                # The first failure aborts the whole request.
                self.logger.error(f"Invalid image data: {e}")
                return {"error": f"Invalid image data: {e}"}

        return {"results": [vars(result) for result in results]}