| |
|
|
| import argparse |
| import os |
| import random |
| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import cv2 |
| import numpy as np |
| import requests |
| import torch |
| from PIL import Image |
| from transformers import AutoModelForMaskGeneration, AutoProcessor, pipeline |
|
|
|
|
def create_palette():
    """Build a 768-entry (256 colors x RGB) palette for PIL "P"-mode images.

    Index 0 is black (background); the next entries are distinct colors for
    instance labels; the remainder is zero-padded to the full 256 * 3 length.
    """
    colors = [
        (0, 0, 0),        # background
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (255, 0, 255),
        (0, 255, 255),
        (128, 0, 0),
        (0, 128, 0),
        (0, 0, 128),
        (128, 128, 0),
        (128, 0, 128),
        (0, 128, 128),
        (64, 0, 0),
        (0, 64, 0),
        (0, 0, 64),
        (64, 64, 0),
        (64, 0, 64),
        (0, 64, 64),
        (192, 192, 192),
        (128, 128, 128),
        (255, 165, 0),
        (75, 0, 130),
        (238, 130, 238),
    ]
    palette = [channel for rgb in colors for channel in rgb]
    palette.extend([0] * (768 - len(palette)))
    return palette
|
|
|
|
# Shared module-level palette (768 ints = 256 RGB triples) for "P"-mode outputs.
PALETTE = create_palette()
|
|
|
|
| |
@dataclass
class BoundingBox:
    """Axis-aligned bounding box in integer pixel coordinates (top-left origin)."""

    xmin: int
    ymin: int
    xmax: int
    ymax: int

    @property
    def xyxy(self) -> List[int]:
        """Return the box as a flat [xmin, ymin, xmax, ymax] list.

        All four fields are ints, so the return annotation is List[int]
        (the previous List[float] was inaccurate).
        """
        return [self.xmin, self.ymin, self.xmax, self.ymax]
|
|
|
|
@dataclass
class DetectionResult:
    """One detected object: confidence score, text label, box, and optional mask."""

    score: Optional[float] = None
    label: Optional[str] = None
    box: Optional[BoundingBox] = None
    # np.ndarray is the array *type*; np.array (used before) is a factory
    # function and is not a valid annotation for the mask filled in later.
    mask: Optional[np.ndarray] = None

    @classmethod
    def from_dict(cls, detection_dict: Dict) -> "DetectionResult":
        """Build a DetectionResult from a zero-shot-object-detection pipeline dict.

        Expects keys "score", "label", and "box" (with "xmin"/"ymin"/"xmax"/"ymax").
        """
        return cls(
            score=detection_dict["score"],
            label=detection_dict["label"],
            box=BoundingBox(
                xmin=detection_dict["box"]["xmin"],
                ymin=detection_dict["box"]["ymin"],
                xmax=detection_dict["box"]["xmax"],
                ymax=detection_dict["box"]["ymax"],
            ),
        )
|
|
|
|
| |
def mask_to_polygon(mask: np.ndarray) -> List[List[int]]:
    """Convert a binary mask to the polygon of its largest external contour.

    Args:
        mask: 2D array; non-zero pixels are foreground.

    Returns:
        List of [x, y] vertices of the largest contour, or an empty list when
        the mask has no foreground (previously this raised ValueError from
        max() on an empty contour sequence).
    """
    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # All-background mask: no contours to choose from.
    if not contours:
        return []

    largest_contour = max(contours, key=cv2.contourArea)

    # OpenCV returns contours shaped (N, 1, 2); flatten to [x, y] pairs.
    polygon = largest_contour.reshape(-1, 2).tolist()

    return polygon
|
|
|
|
def polygon_to_mask(
    polygon: List[Tuple[int, int]], image_shape: Tuple[int, int]
) -> np.ndarray:
    """
    Convert a polygon to a segmentation mask.

    Args:
    - polygon (list): List of (x, y) coordinates representing the vertices of the polygon.
    - image_shape (tuple): Shape of the image (height, width) for the mask.

    Returns:
    - np.ndarray: uint8 mask with the polygon filled with 255; all zeros when
      the polygon is empty.
    """
    mask = np.zeros(image_shape, dtype=np.uint8)

    # Guard the empty-polygon case (e.g. no contour was found upstream):
    # cv2.fillPoly rejects an empty point array.
    if not polygon:
        return mask

    pts = np.array(polygon, dtype=np.int32)

    cv2.fillPoly(mask, [pts], color=(255,))

    return mask
|
|
|
|
def load_image(image_str: str) -> Image.Image:
    """Load an RGB PIL image from a URL or a local file path.

    Args:
        image_str: An http(s) URL or a filesystem path.

    Returns:
        The image converted to RGB.
    """
    # Match full URL schemes only: a bare startswith("http") would also
    # misroute local files named e.g. "http_photo.png" to requests.
    if image_str.startswith(("http://", "https://")):
        # A timeout keeps an unresponsive host from hanging the script forever.
        response = requests.get(image_str, stream=True, timeout=60)
        image = Image.open(response.raw).convert("RGB")
    else:
        image = Image.open(image_str).convert("RGB")

    return image
|
|
|
|
def get_boxes(results: List[DetectionResult]) -> List[List[List[float]]]:
    """Collect the xyxy boxes of all detections, batched for the SAM processor.

    Args:
        results: Detections whose `box` attribute is populated. (The previous
            annotation said a single DetectionResult, but the function iterates.)

    Returns:
        A single-image batch: [[[xmin, ymin, xmax, ymax], ...]].
    """
    boxes = [result.box.xyxy for result in results]
    return [boxes]
|
|
|
|
def refine_masks(
    masks: torch.BoolTensor, polygon_refinement: bool = False
) -> List[np.ndarray]:
    """Turn SAM's stacked mask tensor into a list of binary uint8 arrays.

    The per-prompt channels are averaged per pixel and thresholded at > 0.
    When polygon_refinement is set, each mask is replaced by the filled
    polygon of its largest contour.
    """
    # (N, C, H, W) bool -> (N, H, W, C) float -> per-pixel channel mean.
    averaged = masks.cpu().float().permute(0, 2, 3, 1).mean(dim=-1)
    binary = (averaged > 0).int().numpy().astype(np.uint8)
    mask_list = list(binary)

    if polygon_refinement:
        mask_list = [
            polygon_to_mask(mask_to_polygon(m), m.shape) for m in mask_list
        ]

    return mask_list
|
|
|
|
| |
def generate_colored_segmentation(label_image):
    """Render an integer label map as a palettized ("P"-mode) PIL image.

    Args:
        label_image: 2D array of small integer label ids (0 = background).

    Returns:
        PIL.Image in "P" mode with the module palette attached.
    """
    label_image_pil = Image.fromarray(label_image.astype(np.uint8), mode="P")

    # Reuse the module-level PALETTE instead of rebuilding it on every call.
    label_image_pil.putpalette(PALETTE)

    return label_image_pil
|
|
|
|
def plot_segmentation(image, detections):
    """Paint every detection's mask into one label map and colorize it.

    Args:
        image: Source PIL image (used only for its size).
        detections: DetectionResult objects with `mask` populated.

    Returns:
        A palettized PIL image; label i+1 marks detection i, 0 is background.
        Later detections overwrite earlier ones where masks overlap.
    """
    height_width = image.size[::-1]  # PIL size is (W, H); numpy wants (H, W)
    seg_map = np.zeros(height_width, dtype=np.uint8)
    for label_id, detection in enumerate(detections, start=1):
        seg_map[detection.mask > 0] = label_id
    return generate_colored_segmentation(seg_map)
|
|
|
|
| |
def prepare_model(
    device: str = "cuda",
    detector_id: Optional[str] = None,
    segmenter_id: Optional[str] = None,
):
    """Instantiate the Grounding-DINO detector and the SAM segmenter.

    Args:
        device: Torch device string for both models.
        detector_id: HF model id for zero-shot detection; defaults to
            "IDEA-Research/grounding-dino-tiny".
        segmenter_id: HF model id for mask generation; defaults to
            "facebook/sam-vit-base".

    Returns:
        (detection pipeline, SAM processor, SAM model moved to `device`).
    """
    if detector_id is None:
        detector_id = "IDEA-Research/grounding-dino-tiny"
    if segmenter_id is None:
        segmenter_id = "facebook/sam-vit-base"

    object_detector = pipeline(
        task="zero-shot-object-detection", model=detector_id, device=device
    )
    processor = AutoProcessor.from_pretrained(segmenter_id)
    segmentator = AutoModelForMaskGeneration.from_pretrained(segmenter_id).to(device)

    return object_detector, processor, segmentator
|
|
|
|
def detect(
    object_detector: Any,
    image: Image.Image,
    labels: List[str],
    threshold: float = 0.3,
) -> List[Dict[str, Any]]:
    """
    Use Grounding DINO to detect a set of labels in an image in a zero-shot fashion.
    """
    # Grounding DINO expects each text prompt to be terminated by a period.
    prompts = [lbl if lbl.endswith(".") else lbl + "." for lbl in labels]

    raw_results = object_detector(image, candidate_labels=prompts, threshold=threshold)

    return [DetectionResult.from_dict(raw) for raw in raw_results]
|
|
|
|
def segment(
    processor: Any,
    segmentator: Any,
    image: Image.Image,
    boxes: Optional[List[List[List[float]]]] = None,
    detection_results: Optional[List[Dict[str, Any]]] = None,
    polygon_refinement: bool = False,
) -> List[DetectionResult]:
    """
    Use Segment Anything (SAM) to generate masks given an image + a set of bounding boxes.

    At least one of `boxes` / `detection_results` must be supplied: either
    pre-batched xyxy prompts ([[box, ...]]) or detections whose boxes are
    extracted via get_boxes. Masks are written back onto `detection_results`
    (fresh empty DetectionResults are created when only `boxes` was given)
    and that list is returned.

    NOTE(review): `detection_results` is annotated List[Dict] but is used as
    List[DetectionResult] (attribute assignment below) — confirm callers.
    """
    if detection_results is None and boxes is None:
        raise ValueError(
            "Either detection_results or detection_boxes must be provided."
        )

    # Derive the box prompts from the detections when not given explicitly.
    if boxes is None:
        boxes = get_boxes(detection_results)

    # Move inputs onto the model's device/dtype so they match its weights.
    inputs = processor(images=image, input_boxes=boxes, return_tensors="pt").to(
        segmentator.device, segmentator.dtype
    )

    outputs = segmentator(**inputs)
    # [0]: masks for the first (and only) image in the batch.
    masks = processor.post_process_masks(
        masks=outputs.pred_masks,
        original_sizes=inputs.original_sizes,
        reshaped_input_sizes=inputs.reshaped_input_sizes,
    )[0]

    # Collapse per-prompt channels to binary numpy masks (optionally snapped
    # to the largest-contour polygon).
    masks = refine_masks(masks, polygon_refinement)

    # Box-only invocation: create empty results to attach the masks to.
    if detection_results is None:
        detection_results = [DetectionResult() for _ in masks]

    for detection_result, mask in zip(detection_results, masks):
        detection_result.mask = mask

    return detection_results
|
|
|
|
def grounded_segmentation(
    object_detector,
    processor,
    segmentator,
    image: Union[Image.Image, str],
    labels: Union[str, List[str]],
    threshold: float = 0.3,
    polygon_refinement: bool = False,
) -> Tuple[np.ndarray, List[DetectionResult], Image.Image]:
    """Detect `labels` in `image`, segment each detection, and colorize the result.

    Args:
        object_detector, processor, segmentator: As returned by prepare_model.
        image: PIL image, local path, or URL.
        labels: List of label strings, or one comma-separated string.
        threshold: Detection score threshold.
        polygon_refinement: Snap each mask to its largest contour polygon.

    Returns:
        (image as np.ndarray, detections with masks, palettized segmentation map).
    """
    if isinstance(image, str):
        image = load_image(image)
    if isinstance(labels, str):
        labels = labels.split(",")

    detections = detect(object_detector, image, labels, threshold)
    # BUG FIX: the previous positional call
    #   segment(processor, segmentator, image, detections, polygon_refinement)
    # bound `detections` to segment's `boxes` parameter and the bool to
    # `detection_results`. Keyword arguments land them on the right parameters.
    detections = segment(
        processor,
        segmentator,
        image,
        detection_results=detections,
        polygon_refinement=polygon_refinement,
    )

    seg_map_pil = plot_segmentation(image, detections)

    return np.array(image), detections, seg_map_pil
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: detect the requested labels, segment them, and write
    # a colorized segmentation map to the output directory.
    cli = argparse.ArgumentParser()
    cli.add_argument("--image", type=str, required=True)
    cli.add_argument("--labels", type=str, nargs="+", required=True)
    cli.add_argument("--output", type=str, default="./", help="Output directory")
    cli.add_argument("--threshold", type=float, default=0.3)
    cli.add_argument(
        "--detector_id", type=str, default="IDEA-Research/grounding-dino-base"
    )
    cli.add_argument("--segmenter_id", type=str, default="facebook/sam-vit-base")
    opts = cli.parse_args()

    run_device = "cuda" if torch.cuda.is_available() else "cpu"
    detector, sam_processor, sam_model = prepare_model(
        device=run_device, detector_id=opts.detector_id, segmenter_id=opts.segmenter_id
    )

    result_image, result_detections, seg_image = grounded_segmentation(
        detector,
        sam_processor,
        sam_model,
        image=opts.image,
        labels=opts.labels,
        threshold=opts.threshold,
        polygon_refinement=True,
    )

    os.makedirs(opts.output, exist_ok=True)
    seg_image.save(os.path.join(opts.output, "segmentation.png"))
|
|