import base64
import io
import logging
from typing import Any, Dict, List

import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor, pipeline


class EndpointHandler:
    """Request handler that turns text prompts into CLIPSeg segmentation masks.

    ``__init__`` loads a CLIPSeg processor/model pair and a depth-estimation
    pipeline. ``__call__`` serves segmentation requests; ``process_depth`` is a
    separate helper that is not invoked by ``__call__``.
    """

    SEG_MODEL_ID = "CIDAS/clipseg-rd64-refined"
    DEPTH_MODEL_ID = "depth-anything/Depth-Anything-V2-Small-hf"

    _logger = logging.getLogger(__name__)

    def __init__(self, path=""):
        """Load the models.

        Args:
            path: Accepted for interface compatibility; not used, because the
                model identifiers are fixed class constants.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained(self.SEG_MODEL_ID)
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            self.SEG_MODEL_ID
        ).to(self.device)
        self.depth_pipe = pipeline("depth-estimation", model=self.DEPTH_MODEL_ID)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment an image according to one or more text prompts.

        Args:
            data: Request payload. ``data["inputs"]`` must be a dict with:
                ``"image"``: base64-encoded image bytes;
                ``"text"``: a prompt string or a list of prompt strings.

        Returns:
            A list with one ``{"seg_image": <base64 PNG>}`` dict per prompt, in
            prompt order, or a single ``{"error": <message>}`` dict when the
            payload is invalid or processing fails.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            return [{"error": "Missing 'inputs' key"}]

        inputs_data = data["inputs"]
        if (
            not isinstance(inputs_data, dict)
            or "image" not in inputs_data
            or "text" not in inputs_data
        ):
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        prompts = self._normalize_prompts(inputs_data["text"])
        if not prompts:
            return [
                {"error": "'text' must be a string or a non-empty list of strings"}
            ]

        try:
            image = self.decode_image(inputs_data["image"])

            # The processor pairs texts and images one-to-one, so the image is
            # repeated once per prompt.
            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)  # same device as the model, not a fixed "cuda"

            with torch.no_grad():
                outputs = self.model(**inputs)

            logits = outputs.logits.cpu().numpy()
            # Force a (num_masks, H, W) layout so a missing or singleton batch
            # axis is handled the same way as a real batch.
            masks = logits.reshape((-1,) + logits.shape[-2:])

            return [
                {"seg_image": self.encode_image(Image.fromarray(self._to_uint8(mask)))}
                for mask in masks
            ]
        except Exception as e:
            # Top-level request boundary: report the failure to the caller but
            # keep a traceback in the logs.
            self._logger.exception("Segmentation request failed")
            return [{"error": str(e)}]

    @staticmethod
    def _normalize_prompts(text: Any) -> List[str]:
        """Return prompts as a list of strings; empty list if ``text`` is unusable."""
        if isinstance(text, str):
            # A bare string is one prompt, not a sequence of characters.
            return [text]
        if isinstance(text, (list, tuple)) and all(isinstance(t, str) for t in text):
            return list(text)
        return []

    @staticmethod
    def _to_uint8(values: np.ndarray) -> np.ndarray:
        """Min-max scale ``values`` and truncate to ``uint8``."""
        arr = np.asarray(values)
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        lo = arr.min()
        # The epsilon keeps a constant-valued input from dividing by zero.
        span = arr.max() - lo + 1e-6
        return ((arr - lo) / span * 255).astype(np.uint8)

    def decode_image(self, image_data: str) -> "Image.Image":
        """Decode a base64-encoded image into an RGB PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: "Image.Image") -> str:
        """Encode a PIL image as a base64 string of its PNG bytes."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def process_depth(self, image) -> "Image.Image":
        """Run the depth pipeline on ``image`` and return a normalised depth map.

        Args:
            image: A PIL image, or a NumPy array (converted to ``uint8`` first).

        Returns:
            The pipeline's ``"depth"`` output, min-max scaled to ``uint8`` and
            wrapped in a PIL image.
        """
        self._logger.debug(
            "Processing depth for input of type %s", type(image).__name__
        )
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        return Image.fromarray(self._to_uint8(np.array(output["depth"])))