Zero-Shot Image Classification
Transformers
Safetensors
siglip
vision
basiliskan commited on
Commit
36b465e
·
verified ·
1 Parent(s): 5b57b6c

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +118 -0
handler.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List
2
+ import torch
3
+ from PIL import Image
4
+ import requests
5
+ from io import BytesIO
6
+ import base64
7
+ from transformers import AutoProcessor, AutoModel
8
+
9
+
10
+ class EndpointHandler:
11
+ def __init__(self, path: str = ""):
12
+ """
13
+ Initialize the handler by loading the SigLIP2 model and processor.
14
+
15
+ Args:
16
+ path: Path to the model directory (provided by HF Inference Endpoints)
17
+ """
18
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
+ self.model = AutoModel.from_pretrained(path, trust_remote_code=True).to(self.device)
20
+ self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
21
+ self.model.eval()
22
+
23
+ def _load_image(self, image_data: Any) -> Image.Image:
24
+ """
25
+ Load an image from various input formats.
26
+
27
+ Args:
28
+ image_data: Can be a URL string, base64 string, or raw bytes
29
+
30
+ Returns:
31
+ PIL Image object
32
+ """
33
+ if isinstance(image_data, str):
34
+ # Check if it's a URL
35
+ if image_data.startswith(("http://", "https://")):
36
+ response = requests.get(image_data, timeout=10)
37
+ response.raise_for_status()
38
+ return Image.open(BytesIO(response.content)).convert("RGB")
39
+ # Otherwise assume base64
40
+ else:
41
+ # Handle data URI format
42
+ if "," in image_data:
43
+ image_data = image_data.split(",")[1]
44
+ image_bytes = base64.b64decode(image_data)
45
+ return Image.open(BytesIO(image_bytes)).convert("RGB")
46
+ elif isinstance(image_data, bytes):
47
+ return Image.open(BytesIO(image_data)).convert("RGB")
48
+ else:
49
+ raise ValueError(f"Unsupported image format: {type(image_data)}")
50
+
51
+ def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
52
+ """
53
+ Process inference requests for zero-shot image classification.
54
+
55
+ Args:
56
+ data: Dictionary containing:
57
+ - "inputs": Image data (URL, base64, or bytes)
58
+ - "parameters": Optional dict with:
59
+ - "candidate_labels": List of text labels to classify against
60
+
61
+ Returns:
62
+ List of dictionaries with "label" and "score" for each candidate
63
+ """
64
+ # Extract inputs
65
+ inputs = data.get("inputs")
66
+ parameters = data.get("parameters", {})
67
+
68
+ # Get candidate labels (required for zero-shot classification)
69
+ candidate_labels = parameters.get("candidate_labels", [])
70
+
71
+ if not candidate_labels:
72
+ # Default labels if none provided
73
+ candidate_labels = ["a photo", "an illustration", "a diagram"]
74
+
75
+ # Ensure candidate_labels is a list
76
+ if isinstance(candidate_labels, str):
77
+ candidate_labels = [label.strip() for label in candidate_labels.split(",")]
78
+
79
+ # Load the image
80
+ image = self._load_image(inputs)
81
+
82
+ # Process inputs
83
+ processed_inputs = self.processor(
84
+ text=candidate_labels,
85
+ images=image,
86
+ padding="max_length",
87
+ return_tensors="pt"
88
+ ).to(self.device)
89
+
90
+ # Run inference
91
+ with torch.no_grad():
92
+ outputs = self.model(**processed_inputs)
93
+
94
+ # Get image and text embeddings
95
+ image_embeds = outputs.image_embeds
96
+ text_embeds = outputs.text_embeds
97
+
98
+ # Normalize embeddings
99
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
100
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
101
+
102
+ # Compute similarity scores
103
+ logits_per_image = torch.matmul(image_embeds, text_embeds.t())
104
+
105
+ # Apply softmax to get probabilities
106
+ probs = torch.softmax(logits_per_image, dim=-1)
107
+
108
+ # Format results
109
+ scores = probs[0].cpu().tolist()
110
+ results = [
111
+ {"label": label, "score": score}
112
+ for label, score in zip(candidate_labels, scores)
113
+ ]
114
+
115
+ # Sort by score descending
116
+ results.sort(key=lambda x: x["score"], reverse=True)
117
+
118
+ return results