| | from typing import Dict, List, Any |
| | import onnxruntime as ort |
| | import numpy as np |
| | from PIL import Image |
| | import io |
| | import base64 |
| | import os |
| |
|
| |
|
class EndpointHandler:
    """Inference handler for EdgeSAM exported as two ONNX models.

    An image encoder session and a prompt decoder session are created once.
    Each request supplies an image plus optional point prompts; the response
    is a binary mask at the encoder's square input resolution, optionally
    PNG-encoded as base64.
    """

    # Side length of the square image fed to the encoder; the returned mask
    # is resized to this same resolution.
    _INPUT_SIZE = 1024
    # Execution providers in order of preference; any that the installed
    # onnxruntime build does not offer are dropped in __init__.
    _PREFERRED_PROVIDERS = ("CUDAExecutionProvider", "CPUExecutionProvider")

    def __init__(self, path=""):
        """Create the encoder and decoder inference sessions.

        Args:
            path: Directory holding ``edge_sam_3x_encoder.onnx`` and
                ``edge_sam_3x_decoder.onnx``. An empty value (the default)
                means the current working directory.
        """
        model_path = path or "."
        # Request only providers this onnxruntime build actually ships, so a
        # CPU-only install does not complain about a missing CUDA provider.
        available = set(ort.get_available_providers())
        providers = [
            p for p in self._PREFERRED_PROVIDERS if p in available
        ] or ["CPUExecutionProvider"]

        self.encoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_encoder.onnx"),
            providers=providers,
        )
        self.decoder = ort.InferenceSession(
            os.path.join(model_path, "edge_sam_3x_decoder.onnx"),
            providers=providers,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Segment the region indicated by the point prompts.

        Args:
            data: Request payload. The image is read from ``data["inputs"]``
                (or from ``data`` itself when that key is absent) and may be
                a PIL image, raw encoded image bytes, or a base64 string
                (optionally a ``data:<mime>;base64,<payload>`` URI).
                ``data["parameters"]`` (may be missing or null) can contain:

                * ``point_coords``: list of ``[x, y]`` pairs, default
                  ``[[512, 512]]``. They are passed to the decoder unscaled,
                  i.e. they are interpreted in the 1024x1024 resized frame,
                  not the original image's pixel frame.
                * ``point_labels``: one label per point, default ``[1]``.
                * ``return_mask_image``: include a base64 PNG of the mask,
                  default ``True``.

        Returns:
            A one-element list. On success the dict holds ``mask_shape``,
            ``has_object`` and (optionally) ``mask``; on any failure it holds
            ``error``, ``type`` and ``traceback`` instead.
        """
        try:
            inputs = data.get("inputs", data)
            params = data.get("parameters") or {}

            # Validate the cheap prompt inputs before paying for the encoder.
            coords, labels = self._parse_points(params)
            img_array = self._preprocess(self._load_image(inputs))

            embeddings = self.encoder.run(None, {"image": img_array})[0]
            decoder_outputs = self.decoder.run(None, {
                "image_embeddings": embeddings,
                "point_coords": coords[np.newaxis],  # (1, N, 2)
                "point_labels": labels[np.newaxis],  # (1, N)
            })

            # NOTE(review): output index 1 is assumed to be the mask tensor
            # and only its first mask is used; if the decoder returns several
            # candidate masks plus scores, picking the best-scoring one may be
            # preferable -- confirm against the exported decoder's outputs.
            mask = self._mask_from_logits(decoder_outputs[1][0, 0])

            result = {
                "mask_shape": list(mask.shape),
                "has_object": bool(mask.any()),
            }
            if params.get("return_mask_image", True):
                result["mask"] = self._encode_png(mask)
            return [result]

        except Exception as e:
            # Top-level endpoint boundary: report every failure as a payload.
            # NOTE(review): the full traceback is returned to the caller,
            # which exposes server internals; consider logging it instead.
            import traceback
            return [{
                "error": str(e),
                "type": type(e).__name__,
                "traceback": traceback.format_exc(),
            }]

    @staticmethod
    def _parse_points(params: Dict[str, Any]):
        """Return prompt points as float32 arrays of shape (N, 2) and (N,).

        Raises:
            ValueError: if the coordinates are not a non-empty sequence of
                ``[x, y]`` pairs or the label count differs from the point count.
        """
        # Default prompt: one point with label 1 at the centre of the
        # 1024x1024 encoder input.
        coords = np.asarray(params.get("point_coords", [[512, 512]]), dtype=np.float32)
        labels = np.asarray(params.get("point_labels", [1]), dtype=np.float32).reshape(-1)
        if coords.size == 0 or coords.size % 2:
            raise ValueError(
                "point_coords must be a non-empty list of [x, y] pairs, "
                f"got an array of shape {coords.shape}"
            )
        coords = coords.reshape(-1, 2)
        if labels.shape[0] != coords.shape[0]:
            raise ValueError(
                f"got {coords.shape[0]} point(s) but {labels.shape[0]} label(s)"
            )
        return coords, labels

    @staticmethod
    def _load_image(inputs: Any) -> "Image.Image":
        """Turn a request ``inputs`` value into a PIL image.

        Raises:
            TypeError: if *inputs* is not a str, bytes-like object, or PIL image.
        """
        if isinstance(inputs, str):
            # Accept both bare base64 and ``data:image/...;base64,<payload>``.
            if inputs.startswith("data:"):
                inputs = inputs.partition(",")[2]
            return Image.open(io.BytesIO(base64.b64decode(inputs)))
        if isinstance(inputs, (bytes, bytearray)):
            # Raw encoded image file contents (PNG, JPEG, ...).
            return Image.open(io.BytesIO(inputs))
        if not isinstance(inputs, Image.Image):
            raise TypeError(
                "inputs must be a base64 string, image bytes or a PIL image, "
                f"got {type(inputs).__name__}"
            )
        return inputs

    @classmethod
    def _preprocess(cls, image: "Image.Image") -> np.ndarray:
        """Return *image* as a contiguous (1, 3, 1024, 1024) float32 array in [0, 1]."""
        if image.mode != "RGB":
            image = image.convert("RGB")
        # NOTE(review): the aspect ratio is not preserved and no mean/std
        # normalization is applied here; confirm this matches how the encoder
        # was exported.
        image = image.resize((cls._INPUT_SIZE, cls._INPUT_SIZE), Image.BILINEAR)
        pixels = np.asarray(image, dtype=np.float32) / 255.0
        # HWC -> NCHW; make it contiguous so ONNX Runtime needs no hidden copy.
        return np.ascontiguousarray(pixels.transpose(2, 0, 1)[np.newaxis])

    @classmethod
    def _mask_from_logits(cls, mask_logits: np.ndarray) -> np.ndarray:
        """Resize a 2-D decoder mask to the input size and threshold it at 0.

        Returns a uint8 array whose pixels are 255 where the resized value is
        positive and 0 elsewhere.
        """
        resized = Image.fromarray(np.asarray(mask_logits, dtype=np.float32)).resize(
            (cls._INPUT_SIZE, cls._INPUT_SIZE), Image.BILINEAR
        )
        return (np.asarray(resized) > 0.0).astype(np.uint8) * 255

    @staticmethod
    def _encode_png(mask: np.ndarray) -> str:
        """Return a 2-D uint8 *mask* as a base64-encoded PNG string."""
        buffer = io.BytesIO()
        # A 2-D uint8 array is inferred as mode "L"; passing ``mode=`` to
        # ``fromarray`` is deprecated in recent Pillow releases.
        Image.fromarray(mask).save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("ascii")
| |
|