| | """ |
| | Custom Inference Handler for SigLIP2-base-patch16-512 |
| | Supports: zero_shot, image_embedding, text_embedding, similarity |
| | Returns 768D embeddings. |
| | """ |
| | from typing import Any, Dict, List, Union |
| | import torch |
| | from PIL import Image |
| | import requests |
| | from io import BytesIO |
| | import base64 |
| | from transformers import AutoProcessor, AutoModel |
| |
|
| |
|
class EndpointHandler:
    """Custom Inference Endpoints handler for SigLIP2-base-patch16-512.

    Supports four modes, selected via ``parameters["mode"]`` or auto-detected:
    ``zero_shot``, ``image_embedding``, ``text_embedding``, ``similarity``.
    Embeddings are L2-normalized (the base model projects to 768 dims —
    confirm against the checkpoint config).
    """

    def __init__(self, path: str = ""):
        """Load model and processor from *path*, placing the model on GPU when available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model.eval()

    def _load_image(self, image_data: Any) -> Image.Image:
        """Decode a URL, base64 string, data-URI, or raw bytes into an RGB PIL image.

        Raises:
            ValueError: if *image_data* is neither ``str`` nor ``bytes``.
        """
        if isinstance(image_data, str):
            if image_data.startswith(("http://", "https://")):
                response = requests.get(image_data, timeout=10)
                response.raise_for_status()
                return Image.open(BytesIO(response.content)).convert("RGB")
            # Base64 payload, possibly wrapped as "data:image/...;base64,<payload>".
            if "," in image_data:
                image_data = image_data.split(",", 1)[1]
            return Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
        if isinstance(image_data, bytes):
            return Image.open(BytesIO(image_data)).convert("RGB")
        raise ValueError(f"Unsupported image format: {type(image_data)}")

    def _ensure_images(self, inputs: Any) -> List[Image.Image]:
        """Normalize a single image spec or a list of them into a list of PIL images."""
        if isinstance(inputs, list):
            return [self._load_image(item) for item in inputs]
        return [self._load_image(inputs)]

    def _get_image_embeddings(self, images: List[Image.Image]) -> torch.Tensor:
        """Return L2-normalized image embeddings, shape (len(images), dim)."""
        inputs = self.processor(images=images, return_tensors="pt").to(self.device)
        with torch.no_grad():
            features = self.model.get_image_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    def _get_text_embeddings(self, texts: List[str]) -> torch.Tensor:
        """Return L2-normalized text embeddings, shape (len(texts), dim).

        SigLIP was trained with ``padding="max_length"``; keep it for parity.
        """
        inputs = self.processor(
            text=texts, padding="max_length", truncation=True, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            features = self.model.get_text_features(**inputs)
        return features / features.norm(dim=-1, keepdim=True)

    @staticmethod
    def _is_plain_text(value: Any) -> bool:
        """Heuristic: short strings that are not URLs or data-URIs are treated as text."""
        return (
            isinstance(value, str)
            and not value.startswith(("http", "data:"))
            and len(value) < 500
        )

    def _detect_mode(self, inputs: Any, parameters: Dict[str, Any]) -> str:
        """Infer the inference mode from the request shape.

        A dict with an image key only selects "similarity" when text is also
        present (previously a lone image key routed to _similarity, which then
        crashed on the missing text).
        """
        if isinstance(inputs, dict) and ("image" in inputs or "images" in inputs):
            if "text" in inputs or "texts" in inputs:
                return "similarity"
            if "candidate_labels" in parameters:
                return "zero_shot"
            return "image_embedding"
        if "candidate_labels" in parameters:
            return "zero_shot"
        if self._is_plain_text(inputs):
            return "text_embedding"
        if isinstance(inputs, list) and all(self._is_plain_text(i) for i in inputs):
            return "text_embedding"
        return "image_embedding"

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Entry point: route the request payload to the selected mode handler.

        Raises:
            ValueError: on an unknown explicit mode.
        """
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        mode = parameters.get("mode", "auto")

        if mode == "auto":
            mode = self._detect_mode(inputs, parameters)

        if mode == "zero_shot":
            return self._zero_shot(inputs, parameters)
        if mode == "image_embedding":
            return self._image_embedding(inputs)
        if mode == "text_embedding":
            return self._text_embedding(inputs)
        if mode == "similarity":
            return self._similarity(inputs)
        raise ValueError(f"Unknown mode: {mode}")

    def _zero_shot(self, inputs, parameters):
        """Classify image(s) against candidate labels; returns [{label, score}, ...]
        sorted by descending score (a list of such lists for multiple images)."""
        candidate_labels = parameters.get("candidate_labels", ["photo", "illustration", "diagram"])
        if isinstance(candidate_labels, str):
            candidate_labels = [label.strip() for label in candidate_labels.split(",")]

        # Accept {"image": ...} / {"images": [...]} dict payloads too.
        if isinstance(inputs, dict):
            image_input = inputs.get("image")
            inputs = image_input if image_input is not None else inputs.get("images")

        image_embeds = self._get_image_embeddings(self._ensure_images(inputs))
        text_embeds = self._get_text_embeddings(candidate_labels)

        # Raw cosine similarities lie in [-1, 1]; softmax over them is nearly
        # uniform regardless of match quality. Apply the model's learned
        # temperature (logit_scale) first. SigLIP's native scoring is
        # sigmoid(scale * sim + bias); softmax is kept here so scores sum to 1,
        # matching the zero-shot-classification output contract, and the
        # constant logit_bias cancels under softmax so it is omitted.
        logits = image_embeds @ text_embeds.T
        logit_scale = getattr(self.model, "logit_scale", None)
        if logit_scale is not None:
            logits = logits * logit_scale.exp()
        probs = torch.softmax(logits, dim=-1)

        results = []
        for prob in probs:
            scores = prob.cpu().tolist()
            ranked = sorted(zip(candidate_labels, scores), key=lambda pair: -pair[1])
            results.append([{"label": label, "score": score} for label, score in ranked])
        return results[0] if len(results) == 1 else results

    def _image_embedding(self, inputs):
        """Return [{"embedding": [...]}] for one image or a list of images."""
        if isinstance(inputs, dict):
            image_input = inputs.get("image")
            inputs = image_input if image_input is not None else inputs.get("images")
        embeddings = self._get_image_embeddings(self._ensure_images(inputs))
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _text_embedding(self, inputs):
        """Return [{"embedding": [...]}] for one text or a list of texts."""
        texts = [inputs] if isinstance(inputs, str) else inputs
        embeddings = self._get_text_embeddings(texts)
        return [{"embedding": emb.cpu().tolist()} for emb in embeddings]

    def _similarity(self, inputs):
        """Compute the image x text cosine-similarity matrix.

        Expects a dict with "image"/"images" and "text"/"texts" keys.

        Raises:
            ValueError: if the payload is not a dict or either side is missing.
        """
        if not isinstance(inputs, dict):
            raise ValueError("similarity mode expects a dict with image and text keys")

        # Explicit None checks: `a or b` would wrongly discard falsy-but-present values.
        image_input = inputs.get("image")
        if image_input is None:
            image_input = inputs.get("images")
        text_input = inputs.get("text")
        if text_input is None:
            text_input = inputs.get("texts")
        if image_input is None or text_input is None:
            raise ValueError(
                "similarity mode requires both 'image'/'images' and 'text'/'texts' inputs"
            )

        images = self._ensure_images(image_input)
        texts = [text_input] if isinstance(text_input, str) else text_input

        image_embeds = self._get_image_embeddings(images)
        text_embeds = self._get_text_embeddings(texts)

        similarity = (image_embeds @ text_embeds.T).cpu().tolist()
        return {"similarity_scores": similarity, "image_count": len(images), "text_count": len(texts)}