# FILE: app.py (Updated for Interactive SAM GUI)
"""Interactive Segment Anything (SAM) tool.

Upload an image, sketch over the regions to segment, preview the SAM mask, and
export the mask outline as an SVG file.

Written against the Gradio 3.x API (``gr.Image(tool="sketch")``).
"""
import os
import tempfile
import uuid

import cairosvg
import cv2  # noqa: F401
import gradio as gr
import numpy as np
import torch  # noqa: F401
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry
from skimage.measure import find_contours

# Load SAM model once at import time; every request shares this predictor.
sam_checkpoint = "sam_vit_b_01ec64.pth"
model_type = "vit_b"
device = "cpu"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)

# Upper bound on the number of positive prompt points sampled from a sketch.
MAX_PROMPT_POINTS = 10

# Identifies the image whose embedding is currently loaded into `predictor`,
# so repeated sketches on the same image skip the expensive set_image() call.
_embedded_image_key = None


def _temp_path(suffix):
    """Return a unique, not-yet-created file path in the system temp dir."""
    return os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}{suffix}")


# Convert mask to SVG path
def mask_to_svg_paths(mask):
    """Trace the boundaries of a binary mask as SVG ``<path>`` elements.

    Args:
        mask: 2-D array-like of shape (H, W); non-zero/True pixels are foreground.

    Returns:
        list[str]: one closed ``<path>`` element per contour, in pixel coordinates.
    """
    mask = np.asarray(mask, dtype=float)
    # Pad with background so regions touching the image border still yield
    # closed contours; the coordinates are shifted back by one below.
    padded = np.pad(mask, 1, mode="constant", constant_values=0)
    paths = []
    for contour in find_contours(padded, 0.5):
        contour = contour - 1
        # find_contours returns (row, col) == (y, x); SVG wants x,y.
        points = " ".join(f"{x:.1f},{y:.1f}" for y, x in contour)
        paths.append(
            f'<path d="M {points} Z" fill="none" stroke="black" stroke-width="1"/>'
        )
    return paths


# Rasterize uploaded SVG files
def rasterize_svg(path):
    """Render the SVG at ``path`` to a 300-dpi PNG in the temp dir; return its path."""
    png_path = _temp_path(".png")
    cairosvg.svg2png(url=path, write_to=png_path, dpi=300)
    return png_path


def _sketch_to_mask(sketch):
    """Return a boolean (H, W) array of drawn pixels, or None if there is none.

    With ``tool="sketch"`` on an upload source, Gradio 3 passes a dict with
    ``"image"`` and ``"mask"`` keys; otherwise the drawn image itself is passed.
    """
    if isinstance(sketch, dict):
        sketch = sketch.get("mask")
    if sketch is None:
        return None
    return np.array(sketch.convert("L")) > 128


def _sample_prompt_points(sketch_mask, image_size, max_points=MAX_PROMPT_POINTS):
    """Pick at most ``max_points`` evenly spaced (x, y) prompts from a sketch mask.

    Coordinates are rescaled from the sketch's size to ``image_size``
    (width, height) in case the two differ. Returns None for an empty sketch.
    """
    yx = np.argwhere(sketch_mask)
    if len(yx) == 0:
        return None
    count = min(max_points, len(yx))
    idx = np.linspace(0, len(yx) - 1, num=count).round().astype(int)
    ys = yx[idx, 0].astype(float)
    xs = yx[idx, 1].astype(float)
    sketch_h, sketch_w = sketch_mask.shape
    width, height = image_size
    xs *= width / sketch_w
    ys *= height / sketch_h
    return np.stack([xs, ys], axis=1)


# Segment image using user-drawn mask
def segment_with_sketch(image, sketch):
    """Segment ``image`` with SAM, using strokes in ``sketch`` as positive prompts.

    Args:
        image: PIL image to segment (or None if nothing is loaded).
        sketch: Gradio sketch value, either a ``{"image", "mask"}`` dict or a
            PIL image of the strokes.

    Returns:
        (preview, mask): ``preview`` is a PIL image with everything outside the
        mask darkened to 20% brightness; ``mask`` is the boolean mask as nested
        lists (kept in ``gr.State``). ``mask`` is None when there is nothing to
        segment.
    """
    global _embedded_image_key
    if image is None:
        return None, None
    sketch_mask = _sketch_to_mask(sketch) if sketch is not None else None
    if sketch_mask is None:
        return image, None

    input_points = _sample_prompt_points(sketch_mask, image.size)
    if input_points is None:
        return image, None
    input_labels = np.ones(len(input_points), dtype=int)

    # SAM expects an RGB uint8 array; convert in case an RGBA/L image slipped in.
    image_np = np.array(image.convert("RGB"))
    key = (image_np.shape, hash(image_np.tobytes()))
    if key != _embedded_image_key:
        predictor.set_image(image_np)
        _embedded_image_key = key

    masks, _scores, _logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )
    mask = masks[0].astype(bool)

    vis = image_np.copy()
    vis[~mask] = (vis[~mask] * 0.2).astype(np.uint8)
    return Image.fromarray(vis), mask.tolist()


# Export SVG from mask
def export_svg(mask, width=None, height=None):
    """Write the mask outline to a temporary SVG file and return its path.

    Args:
        mask: 2-D boolean array-like (H, W), or None.
        width, height: SVG canvas size; each defaults to the mask's size.

    Returns:
        str | None: path of the written SVG, or None when there is no mask.
    """
    if mask is None:
        return None
    mask = np.asarray(mask, dtype=bool)
    if mask.ndim != 2 or mask.size == 0:
        return None
    mask_h, mask_w = mask.shape
    width = mask_w if width is None else width
    height = mask_h if height is None else height

    svg_lines = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" '
        f'viewBox="0 0 {width} {height}">'
    ]
    svg_lines.extend(mask_to_svg_paths(mask))
    svg_lines.append("</svg>")

    path = _temp_path(".svg")
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(svg_lines))
    return path


def _export_for_image(mask, image):
    """Export ``mask`` as SVG sized to ``image``, or to the mask if no image."""
    if image is None:
        return export_svg(mask)
    return export_svg(mask, image.width, image.height)


# Load and optionally rasterize input image
def process_upload(file):
    """Load an uploaded JPG/PNG/SVG as an RGB PIL image.

    Returns:
        (image, image): one copy for the display component and one to seed the
        sketch canvas; (None, None) when the upload is cleared.
    """
    if file is None:
        return None, None
    # Depending on the Gradio version, gr.File yields a tempfile wrapper with
    # a `.name` attribute or a plain path string; accept either.
    file_path = getattr(file, "name", file)
    ext = os.path.splitext(file_path)[1].lower()
    image_path = rasterize_svg(file_path) if ext == ".svg" else file_path
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return image, image


with gr.Blocks() as demo:
    gr.Markdown("""
# 🧠 Interactive SAM Vector Tool
Upload an image, sketch over regions you want segmented, preview result, and export as clean SVG.
""")

    with gr.Row():
        uploaded_file = gr.File(label="Upload JPG, PNG or SVG")
        image_display = gr.Image(type="pil", label="Original Image")

    sketch_input = gr.Image(label="Draw on Regions to Segment", tool="sketch", type="pil")
    segmentation_output = gr.Image(label="Mask Preview")
    mask_state = gr.State()

    export_btn = gr.Button("Export SVG")
    download_link = gr.File(label="Download SVG")

    uploaded_file.change(
        fn=process_upload,
        inputs=uploaded_file,
        outputs=[image_display, sketch_input],
    )
    sketch_input.change(
        fn=segment_with_sketch,
        inputs=[image_display, sketch_input],
        outputs=[segmentation_output, mask_state],
    )
    export_btn.click(
        fn=_export_for_image,
        inputs=[mask_state, image_display],
        outputs=download_link,
    )

    gr.Markdown("Built with Segment Anything + Gradio")

demo.launch()