import base64
import io
from typing import Dict, Any

import requests
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
|
|
|
|
class EndpointHandler:
    """Custom Hugging Face Inference Endpoint handler for a CLIP model.

    Accepts a text prompt plus an image (fetched from a URL or supplied as
    base64) and returns the CLIP text embedding, image embedding, and their
    cosine similarity.
    """

    def __init__(self, path=""):
        # `path` is the model repository/directory the endpoint was deployed with.
        self.processor = CLIPProcessor.from_pretrained(path)
        self.model = CLIPModel.from_pretrained(path)
        # Inference-only handler: switch off dropout/batch-norm training behavior.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict:
        """Handle one inference request.

        Expected payload (optionally nested under an "inputs" key):
            text: prompt string (or list of strings) to embed.
            image_url: URL of the image to fetch, OR
            image: base64-encoded image bytes.

        Returns:
            dict with keys "text_embedding" (list[float]),
            "image_embedding" (list[float]) and "embedding_similarity"
            (cosine similarity between the two, as a float).

        Raises:
            KeyError: if "text" is missing, or neither image field is present.
            requests.HTTPError: if the image URL returns an error status.
        """
        # Read, don't pop: avoid mutating the caller's payload dict.
        inputs = data.get("inputs", data)
        text = inputs["text"]
        image = self._load_image(inputs)

        processed = self.processor(
            text=text,
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        # No gradients are needed at inference time; skipping the autograd
        # graph saves memory and latency on every request.
        with torch.no_grad():
            outputs = self.model(**processed)

        text_embeds = outputs.text_embeds.cpu().numpy()
        image_embeds = outputs.image_embeds.cpu().numpy()
        similarity = cosine_similarity(text_embeds, image_embeds)[0][0].item()
        return {
            "text_embedding": text_embeds[0].tolist(),
            "image_embedding": image_embeds[0].tolist(),
            "embedding_similarity": similarity,
        }

    def _load_image(self, inputs: Dict[str, Any]) -> Image.Image:
        """Decode the request image: fetch `image_url`, or b64-decode `image`."""
        if "image_url" in inputs:
            # Use `response.content` (fully decoded bytes) rather than the
            # `.raw` stream: `.raw` does not undo Content-Encoding (gzip etc.),
            # which corrupts the image bytes for some servers.
            response = requests.get(inputs["image_url"], timeout=30)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            payload = response.content
        else:
            payload = base64.b64decode(inputs["image"])
        return Image.open(io.BytesIO(payload))
|
|