import json
import torch
import numpy as np
import cv2
import os
from segment_anything import sam_model_registry, SamPredictor
from lvis import LVIS
import copy
from pathlib import Path
class Objects365SAM():
    """Annotate Objects365 images with SAM-generated instance masks.

    For each image in the configured index range, prompts SAM with the
    image's ground-truth bounding boxes and writes the resulting masks as
    per-image COCO-style polygon JSON files under ``output_dir/<subset>/``.
    """
    def __init__(self, index_low, index_high):
        """Load the SAM ViT-H checkpoint and remember the image slice bounds.

        Args:
            index_low: Inclusive lower slice bound into the image list.
            index_high: Exclusive upper slice bound into the image list.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        sam_checkpoint = "checkpoints/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=self.device)
        self.predictor = SamPredictor(sam)
        self.index_low = index_low
        self.index_high = index_high
    def load_annotations(self, annotation_file):
        """Load a COCO-format annotation JSON into ``self.json_data``."""
        with open(annotation_file, 'r') as f:
            self.json_data = json.load(f)
    def process_annotations_with_sam(self, images_dir, output_dir):
        """Run SAM box-prompted segmentation over the selected image slice.

        Writes one JSON file per image, named after the image with its
        extension replaced by ``.json``. Images missing on disk are skipped
        with a message; images with no annotations get an empty JSON.
        """
        image_info_list = self.json_data['images']
        # Index annotations by image id once, instead of rescanning the full
        # annotation list for every image (was O(images * annotations)).
        annos_by_image = {}
        for anno in self.json_data['annotations']:
            annos_by_image.setdefault(anno['image_id'], []).append(anno)
        counter = 0
        for image_info in image_info_list[self.index_low:self.index_high]:
            image_id = image_info['id']
            # file_name looks like ".../<subset>/<name>"; keep the subset so
            # the output mirrors the input directory layout.
            parts = image_info['file_name'].split('/')
            image_name, image_subset = parts[-1], parts[-2]
            output_json_dir = Path(os.path.join(output_dir, image_subset))
            # parents=True so a missing parent directory is not a hard error.
            output_json_dir.mkdir(parents=True, exist_ok=True)
            image_path = os.path.join(images_dir, image_subset, image_name)
            image = cv2.imread(image_path)
            if image is None:
                print(f"Image not found: {image_path}")
                continue
            # Path.stem handles any extension length ("a.jpg", "a.jpeg", ...);
            # the previous name[:-4] silently broke on 4-char extensions.
            output_json = os.path.join(output_json_dir,
                                       Path(image_name).stem + '.json')
            image_annotations = annos_by_image.get(image_id, [])
            if not image_annotations:
                # No boxes to prompt with: emit an empty annotation file
                # instead of crashing on an empty box tensor below.
                save_annotations_to_json(image_id, [], output_json)
                counter += 1
                print('Done image index: ', counter)
                continue
            self.predictor.set_image(image)
            # COCO [x, y, w, h] -> SAM-expected [xmin, ymin, xmax, ymax].
            bounding_boxes = []
            for anno in image_annotations:
                xmin, ymin, width, height = anno['bbox']
                bounding_boxes.append([xmin, ymin, xmin + width, ymin + height])
            input_boxes = torch.tensor(bounding_boxes, device=self.device).float()
            transformed_boxes = self.predictor.transform.apply_boxes_torch(
                input_boxes, image.shape[:2])
            with torch.no_grad():
                masks, scores, logits = self.predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=transformed_boxes,
                    multimask_output=False,
                )
            # Convert each mask to a COCO-style polygon annotation.
            mask_annotations = []
            for mask in masks:
                binary_mask = mask.squeeze().cpu().numpy().astype(np.uint8)
                contours, _ = cv2.findContours(
                    binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                if len(contours) == 0:
                    continue
                # Keep only the largest connected component of the mask.
                largest_contour = max(contours, key=cv2.contourArea)
                segmentation = largest_contour.flatten().tolist()
                x, y, w, h = cv2.boundingRect(largest_contour)
                mask_annotations.append({
                    "segmentation": [segmentation],
                    "bbox": [x, y, w, h],
                    "area": float(cv2.contourArea(largest_contour)),
                    # NOTE(review): category is hard-coded to 1 (class-agnostic
                    # masks); the source annotation's category_id is discarded —
                    # confirm this is intended.
                    "category_id": 1
                })
            save_annotations_to_json(image_id, mask_annotations, output_json)
            torch.cuda.empty_cache()
            counter += 1
            print('Done image index: ', counter)
def save_annotations_to_json(image_id, mask_annotations, output_file):
    """Serialize one image's mask annotations to a COCO-style JSON file.

    Args:
        image_id: Identifier of the annotated image.
        mask_annotations: List of annotation dicts for that image.
        output_file: Destination path for the JSON file.
    """
    payload = {"image_id": image_id, "annotations": mask_annotations}
    with open(output_file, 'w') as fh:
        json.dump(payload, fh)
if __name__ == "__main__":
    # Dataset sizes — train/test: 1,742,292 / 80,000 images.
    import argparse
    parser = argparse.ArgumentParser(description="Annotate labels with Segment Anything")
    parser.add_argument('--is_train', action='store_true', help="Train/Test")
    parser.add_argument("--index_low", type=int, default=0,
                        help="Lower bound of indexes for processing Objects365 dataset.")
    parser.add_argument("--index_high", type=int, default=1742292,
                        help="Upper bound of indexes for processing Objects365 dataset.")
    args = parser.parse_args()
    # Train and val layouts differ only by the subset name.
    subset = 'train' if args.is_train else 'val'
    input_json_dir = f'../data/object365/zhiyuan_objv2_{subset}.json'
    input_image_dir = f'../data/object365/images/{subset}/'
    output_dir = Path(f'../data/object365/labels/{subset}/')
    # parents=True: also create ../data/object365/labels/ when absent —
    # plain mkdir(exist_ok=True) raised FileNotFoundError in that case.
    output_dir.mkdir(parents=True, exist_ok=True)
    sam_annotator = Objects365SAM(args.index_low, args.index_high)
    sam_annotator.load_annotations(input_json_dir)
    sam_annotator.process_annotations_with_sam(input_image_dir, output_dir)