import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a Pix2Struct model.

    Loads the model and processor once at startup, then for each request
    decodes an image payload, runs conditional generation with a text
    header (DePlot-style "chart to data table" prompt by default), and
    returns the generated text.
    """

    def __init__(self, path: str = ""):
        """Called when the endpoint starts. Load model and processor.

        Args:
            path: Local checkpoint directory or hub id of the
                Pix2Struct model to load.
        """
        self.processor = Pix2StructProcessor.from_pretrained(path)
        self.model = Pix2StructForConditionalGeneration.from_pretrained(path)

        # Prefer GPU when available; eval() disables dropout for inference.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Default DePlot prompt: extract the underlying data table of a chart.
        self.default_header = "Generate underlying data table of the figure below:"

    def _resolve_header(self, data: Dict[str, Any], parameters: Dict[str, Any]) -> str:
        """Return the first non-empty prompt found, else the default.

        Checks `parameters` before the top-level payload, trying the
        aliases 'header', 'text' and 'prompt' in that order (same
        precedence as the original truthy `or` chain).
        """
        for source in (parameters, data):
            for key in ("header", "text", "prompt"):
                value = source.get(key)
                if value:
                    return value
        return self.default_header

    def _decode_image(self, inputs: Any) -> "Image.Image":
        """Decode the request payload into an RGB PIL image.

        Accepts a base64-encoded string (optionally a full data URI such
        as 'data:image/png;base64,...') or raw image bytes.

        Raises:
            ValueError: if the payload is neither str nor bytes, or
                cannot be decoded into an image.
        """
        if isinstance(inputs, str):
            # Tolerate data-URI payloads by stripping the media-type prefix.
            if inputs.startswith("data:") and "," in inputs:
                inputs = inputs.split(",", 1)[1]
            try:
                image_bytes = base64.b64decode(inputs)
            except Exception as e:
                raise ValueError(f"Failed to decode base64 image: {e}")
        elif isinstance(inputs, (bytes, bytearray)):
            # Some clients POST raw image bytes instead of base64 text.
            image_bytes = bytes(inputs)
        else:
            raise ValueError("Expected base64 encoded image string in 'inputs'")
        try:
            return Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise ValueError(f"Failed to decode base64 image: {e}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Called on every request.

        Args:
            data: Dictionary containing:
                - inputs: base64 encoded image string (or raw bytes)
                - parameters (optional): dict with:
                    - header: text prompt for the model (default: DePlot prompt)
                    - max_new_tokens: max generation length (default: 512)

        Returns:
            List containing the generated table text, as
            `[{"generated_text": ...}]`.

        Raises:
            ValueError: if the image payload is missing or undecodable.
        """
        inputs = data.get("inputs")
        # `or {}` guards against clients sending an explicit null.
        parameters = data.get("parameters", {}) or {}

        header_text = self._resolve_header(data, parameters)
        image = self._decode_image(inputs)

        model_inputs = self.processor(
            images=image,
            text=header_text,
            return_tensors="pt",
        ).to(self.device)

        # int() tolerates JSON clients that send the count as a string.
        max_new_tokens = int(parameters.get("max_new_tokens", 512))

        # Inference only: no gradients needed.
        with torch.no_grad():
            predictions = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
            )

        output_text = self.processor.decode(
            predictions[0],
            skip_special_tokens=True,
        )

        return [{"generated_text": output_text}]