"""
SAM3 Static Image Segmentation - Correct Implementation

Uses Sam3Model (not Sam3VideoModel) for text-prompted static image segmentation.
"""
| import base64 |
| import io |
| import asyncio |
| import torch |
| import numpy as np |
| from PIL import Image |
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
| from transformers import AutoProcessor, AutoModel |
| import logging |
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the SAM3 processor and model from a local checkpoint directory.
# trust_remote_code=True executes custom model code shipped with the
# checkpoint — NOTE(review): ensure ./model comes from a trusted source.
processor = AutoProcessor.from_pretrained("./model", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "./model",
    # fp16 halves VRAM on GPU; fp32 on CPU where half precision is poorly supported.
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    trust_remote_code=True
)

model.eval()  # inference only — disables dropout / batch-norm updates
if torch.cuda.is_available():
    model.cuda()
    logger.info(f"GPU: {torch.cuda.get_device_name()}")
    logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

logger.info(f"✓ Loaded {model.__class__.__name__} for static image segmentation")
|
|
| |
class VRAMManager:
    """Caps concurrent GPU inference and reports VRAM usage snapshots."""

    def __init__(self):
        # At most two requests may run inference at the same time.
        self.semaphore = asyncio.Semaphore(2)
        self.processing_count = 0

    def get_vram_status(self):
        """Return a dict of GPU memory stats, or {} when CUDA is absent."""
        if not torch.cuda.is_available():
            return {}
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        status = {
            "total_gb": total_bytes / 1e9,
            "allocated_gb": torch.cuda.memory_allocated() / 1e9,
            "free_gb": (total_bytes - torch.cuda.memory_reserved()) / 1e9,
            "processing_now": self.processing_count,
        }
        return status

    async def acquire(self, rid):
        """Wait for a free processing slot, then claim it."""
        await self.semaphore.acquire()
        self.processing_count += 1

    def release(self, rid):
        """Return a slot and flush the CUDA allocator cache."""
        self.processing_count -= 1
        self.semaphore.release()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
# Process-wide concurrency gate shared by all request handlers.
vram_manager = VRAMManager()
app = FastAPI(title="SAM3 Static Image API")
|
|
class Request(BaseModel):
    # Base64-encoded image bytes (decoded inside run_inference).
    inputs: str
    # Free-form parameters; handlers read the "classes" key (list of text prompts).
    parameters: dict
|
|
|
|
def run_inference(image_b64: str, classes: list, request_id: str):
    """
    Sam3Model inference for static images with text prompts.

    Decodes a base64 image, feeds one (image, text) pair per class through
    the model, and post-processes the output into per-instance binary masks
    via the official SAM3 processor (with a manual sigmoid/threshold fallback).

    Args:
        image_b64: Base64-encoded image bytes in any PIL-readable format.
        classes: Text prompts, one per object class to segment.
        request_id: Short identifier used only for log correlation.

    Returns:
        A list of dicts with keys "label", "mask" (base64 PNG, mode "L"),
        "score", and — on the primary path only — "instance_id". The
        fallback path returns exactly one mask per class.

    Raises:
        Re-raises any exception after logging and printing a traceback.
    """
    try:
        image_bytes = base64.b64decode(image_b64)
        pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        logger.info(f"[{request_id}] Image: {pil_image.size}, Classes: {classes}")

        # The processor pairs images and texts element-wise, so the same
        # image is repeated once per class prompt.
        images_batch = [pil_image] * len(classes)
        inputs = processor(
            images=images_batch,
            text=classes,
            return_tensors="pt"
        )

        # Target sizes are (height, width); PIL's .size is (width, height).
        original_size = [pil_image.size[1], pil_image.size[0]]
        original_sizes = torch.tensor([original_size] * len(classes))
        inputs["original_sizes"] = original_sizes

        logger.info(f"[{request_id}] Processing {len(classes)} classes with batched images")
        logger.info(f"[{request_id}] Original size: {pil_image.size} (W x H)")

        if torch.cuda.is_available():
            # Cast floating tensors to the model dtype (fp16 on GPU);
            # integer tensors (token ids, sizes) are moved but keep their dtype.
            model_dtype = next(model.parameters()).dtype
            inputs = {
                k: v.cuda().to(model_dtype) if isinstance(v, torch.Tensor) and v.dtype.is_floating_point else v.cuda() if isinstance(v, torch.Tensor) else v
                for k, v in inputs.items()
            }
            logger.info(f"[{request_id}] Moved inputs to GPU (float tensors to {model_dtype})")

        with torch.no_grad():
            outputs = model(**inputs)
            logger.info(f"[{request_id}] Forward pass successful!")

        logger.info(f"[{request_id}] Output type: {type(outputs)}")

        logger.info(f"[{request_id}] Using processor.post_process_instance_segmentation()")

        try:
            processed = processor.post_process_instance_segmentation(
                outputs,
                threshold=0.3,       # detection confidence cutoff
                mask_threshold=0.5,  # per-pixel binarization cutoff
                target_sizes=original_sizes.tolist()
            )

            logger.info(f"[{request_id}] Post-processing successful!")
            logger.info(f"[{request_id}] Number of batched results: {len(processed)}")

        except Exception as proc_error:
            # Fallback: binarize raw mask logits manually when the official
            # post-processing is unavailable or incompatible.
            logger.error(f"[{request_id}] Post-processing failed: {proc_error}")
            logger.info(f"[{request_id}] Falling back to manual processing")

            results = []

            # Locate raw masks regardless of the output container's shape.
            if hasattr(outputs, 'pred_masks'):
                pred_masks = outputs.pred_masks
            elif hasattr(outputs, 'masks'):
                pred_masks = outputs.masks
            elif isinstance(outputs, dict) and 'pred_masks' in outputs:
                pred_masks = outputs['pred_masks']
            else:
                raise ValueError("Cannot find masks in model output")

            logger.info(f"[{request_id}] pred_masks shape: {pred_masks.shape}")

            for i, cls in enumerate(classes):
                # NOTE(review): indexes [0, i] — assumes per-class masks lie
                # along dim 1 of the first batch item, but the images were
                # batched along dim 0 above; confirm this layout against the
                # SAM3 output spec.
                if i < pred_masks.shape[1]:
                    mask_tensor = pred_masks[0, i]

                    # Upsample mask logits to the original (H, W) if needed.
                    if mask_tensor.shape[-2:] != pil_image.size[::-1]:
                        mask_tensor = torch.nn.functional.interpolate(
                            mask_tensor.unsqueeze(0).unsqueeze(0),
                            size=pil_image.size[::-1],
                            mode='bilinear',
                            align_corners=False
                        ).squeeze()

                    # Sigmoid then 0.5 threshold -> {0, 255} uint8 mask.
                    probs = torch.sigmoid(mask_tensor)
                    binary_mask = (probs > 0.5).float().cpu().numpy().astype("uint8") * 255
                else:
                    # No prediction for this class: emit an all-black mask.
                    binary_mask = np.zeros(pil_image.size[::-1], dtype="uint8")

                # Encode the mask as a base64 grayscale PNG.
                pil_mask = Image.fromarray(binary_mask, mode="L")
                buf = io.BytesIO()
                pil_mask.save(buf, format="PNG")
                mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                score = 1.0
                if hasattr(outputs, 'pred_logits') and i < outputs.pred_logits.shape[1]:
                    # NOTE(review): float() requires a single-element tensor —
                    # verify pred_logits[0, i] is scalar for this model.
                    score = float(torch.sigmoid(outputs.pred_logits[0, i]).cpu())

                results.append({
                    "label": cls,
                    "mask": mask_b64,
                    "score": score
                })

            logger.info(f"[{request_id}] Completed (fallback): {len(results)} masks generated")
            return results

        # Primary path: flatten per-class results into one list, one entry
        # per detected instance.
        results = []

        total_instances = 0
        for i, cls in enumerate(classes):
            class_result = processed[i]

            num_instances = len(class_result['masks']) if 'masks' in class_result else 0
            total_instances += num_instances

            if num_instances > 0:
                logger.info(f"[{request_id}] {cls}: {num_instances} instance(s) detected")

                for j in range(num_instances):
                    # Post-processed masks are already binarized; scale to {0, 255}.
                    mask_np = class_result['masks'][j].cpu().numpy().astype("uint8") * 255

                    pil_mask = Image.fromarray(mask_np, mode="L")
                    buf = io.BytesIO()
                    pil_mask.save(buf, format="PNG")
                    mask_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

                    score = float(class_result['scores'][j]) if 'scores' in class_result else 1.0

                    # Percent of image pixels covered — logged for debugging only.
                    coverage = (mask_np > 0).sum() / mask_np.size * 100

                    results.append({
                        "label": cls,
                        "mask": mask_b64,
                        "score": score,
                        "instance_id": j
                    })

                    logger.info(f"[{request_id}] └─ Instance {j}: score={score:.3f}, coverage={coverage:.2f}%")
            else:
                logger.info(f"[{request_id}] {cls}: No instances detected")

        logger.info(f"[{request_id}] Completed: {total_instances} instance(s) across {len(classes)} class(es)")
        return results

    except Exception as e:
        logger.error(f"[{request_id}] Failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise
|
|
|
|
@app.post("/")
async def predict(req: Request):
    """Run text-prompted segmentation on a base64 image and return masks."""
    request_id = str(id(req))[:8]
    try:
        await vram_manager.acquire(request_id)
        try:
            # Heavy, blocking inference runs off the event loop thread.
            return await asyncio.to_thread(
                run_inference,
                req.inputs,
                req.parameters.get("classes", []),
                request_id,
            )
        finally:
            # Slot is returned whether inference succeeded or failed.
            vram_manager.release(request_id)
    except Exception as e:
        logger.error(f"[{request_id}] Error: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: model class, GPU presence, and VRAM stats."""
    payload = {
        "status": "healthy",
        "model": model.__class__.__name__,
        "gpu_available": torch.cuda.is_available(),
    }
    payload["vram"] = vram_manager.get_vram_status()
    return payload
|
|
|
|
@app.get("/metrics")
async def metrics():
    """Expose the VRAM/concurrency snapshot for scraping."""
    status = vram_manager.get_vram_status()
    return status
|
|
|
|
if __name__ == "__main__":
    import uvicorn
    # workers must stay 1: the model and the VRAM semaphore are per-process
    # state and would not be shared across uvicorn worker processes.
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
|
|