# FILE: app.py (Updated for Interactive SAM GUI)
import gradio as gr
from PIL import Image
import numpy as np
import torch
import uuid
import os
import cairosvg
from segment_anything import sam_model_registry, SamPredictor
from skimage.measure import find_contours
import cv2
# Load the SAM model once at import time and wrap it in a SamPredictor.
# The module-level `predictor` is the object the segmentation handler calls.
sam_checkpoint = "sam_vit_b_01ec64.pth"  # bare file name, so it is looked up relative to the working directory
model_type = "vit_b"  # key used to pick the model builder from sam_model_registry
device = "cpu"  # device the model is moved to below
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
predictor = SamPredictor(sam)
# Convert mask to SVG path
def mask_to_svg_paths(mask):
    """Trace the outlines of a binary mask and return them as SVG path elements.

    Args:
        mask: 2-D array-like holding a binary mask (bool or 0/1 values).

    Returns:
        A list of strings, one closed ``<path .../>`` element per contour
        found at the 0.5 level. The list is empty when no contour is found.
        Every contour is emitted as its own filled shape, so holes inside a
        region are not cut out.
    """
    # Cast to float so the 0.5 level lies strictly between the two mask
    # values regardless of the input dtype.
    contours = find_contours(np.asarray(mask, dtype=float), 0.5)
    paths = []
    for contour in contours:
        # Each contour vertex is unpacked as (y, x); SVG expects "x,y" pairs.
        points = " ".join(f"{x:.1f},{y:.1f}" for y, x in contour)
        # "M" starts the outline, the remaining pairs are implicit line
        # segments, and "Z" closes the shape.
        paths.append(f'<path d="M {points} Z" fill="black" stroke="none"/>')
    return paths
# Rasterize uploaded SVG files
def rasterize_svg(path):
    """Render the SVG at ``path`` to a PNG file and return the PNG's path.

    The PNG is written under ``/tmp`` with a random UUID4 file name and is
    rendered through ``cairosvg.svg2png`` with ``dpi=300``.
    """
    output_file = f"/tmp/{uuid.uuid4()}.png"
    cairosvg.svg2png(
        url=path,
        write_to=output_file,
        dpi=300,
    )
    return output_file
# Segment image using user-drawn mask
def segment_with_sketch(image, sketch):
    """Segment ``image`` with SAM, using the user's sketch strokes as prompts.

    Args:
        image: Image to segment (PIL image or anything ``np.array`` accepts),
            or ``None`` when nothing has been loaded yet.
        sketch: The user's strokes. Either an object with a PIL-style
            ``convert`` method, or the dict emitted by a sketch-enabled Gradio
            image component, in which case its ``"mask"`` entry is used.
            Pixels brighter than 128 after greyscale conversion count as
            strokes.

    Returns:
        A ``(preview, mask)`` tuple. ``preview`` is a PIL image in which every
        pixel outside the predicted mask is scaled to 20 % of its value, and
        ``mask`` is the predicted mask converted to nested lists. When there
        is nothing to segment (no image, no sketch, or no stroke pixels) the
        input ``image`` is returned unchanged together with ``None``.
    """
    max_prompt_points = 10

    if image is None or sketch is None:
        return image, None
    if isinstance(sketch, dict):
        # Sketch-enabled Gradio image inputs arrive as {"image": ..., "mask": ...}.
        sketch = sketch.get("mask")
        if sketch is None:
            return image, None

    image_np = np.array(image)
    sketch_mask = np.array(sketch.convert("L")) > 128
    yx = np.argwhere(sketch_mask)
    if len(yx) == 0:
        return image, None

    # Pick at most max_prompt_points stroke pixels, spread evenly over the
    # stroke. A plain stride of len // 10 could select up to 19 points.
    picks = np.linspace(
        0, len(yx) - 1, num=min(len(yx), max_prompt_points), dtype=int
    )
    # argwhere yields (row, col); the predictor is given (x, y) pairs.
    input_points = yx[picks][:, [1, 0]]
    # Label 1 marks every sampled point as belonging to the wanted region.
    input_labels = np.ones(len(input_points), dtype=int)

    predictor.set_image(image_np)
    masks, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
        multimask_output=False,
    )
    mask = masks[0]

    # Darken everything outside the mask so the selected region stands out.
    background = np.logical_not(mask)
    vis = image_np.copy()
    vis[background] = (vis[background] * 0.2).astype(np.uint8)
    return Image.fromarray(vis), mask.tolist()
# Export SVG from mask
def export_svg(mask, width, height):
    """Write the contours of ``mask`` to a standalone SVG file.

    Args:
        mask: Binary mask as nested lists or an array, or ``None`` when no
            segmentation is available.
        width: Canvas width; used for the ``width`` attribute and ``viewBox``.
        height: Canvas height; used for the ``height`` attribute and ``viewBox``.

    Returns:
        The path of the SVG file written under ``/tmp`` (random UUID4 file
        name), or ``None`` when ``mask`` is ``None``.
    """
    if mask is None:
        return None
    paths = mask_to_svg_paths(np.array(mask))
    # The viewBox matches the pixel grid of the mask so contour coordinates
    # can be used as-is.
    svg = [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" '
        f'height="{height}" viewBox="0 0 {width} {height}">'
    ]
    svg.extend(paths)
    svg.append("</svg>")
    svg_code = "\n".join(svg)
    path = f"/tmp/{uuid.uuid4()}.svg"
    with open(path, "w", encoding="utf-8") as f:
        f.write(svg_code)
    return path
# Load and optionally rasterize input image
def process_upload(file):
    """Load an uploaded file as an RGB PIL image, rasterizing SVGs first.

    Args:
        file: The upload: an object whose ``name`` attribute holds the file
            path, a plain path string, or ``None`` when the upload was cleared.

    Returns:
        The image converted to RGB, or ``None`` when ``file`` is ``None``.
    """
    if file is None:
        return None
    source_path = file if isinstance(file, str) else file.name
    is_svg = os.path.splitext(source_path)[1].lower() == ".svg"
    image_path = rasterize_svg(source_path) if is_svg else source_path
    try:
        # convert() returns an independent image, so the file handle can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as opened:
            return opened.convert("RGB")
    finally:
        if is_svg:
            # The intermediate PNG was created only for this call.
            try:
                os.remove(image_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file is harmless.
                pass
# Build the Gradio interface: upload -> sketch -> mask preview -> SVG export.
with gr.Blocks() as demo:
    gr.Markdown(
        "\n"
        "# 🧠Interactive SAM Vector Tool\n"
        "Upload an image, sketch over regions you want segmented, preview result, and export as clean SVG.\n"
    )
    # NOTE(review): the original indentation was lost, so which components
    # belong inside this Row is an assumption; confirm the intended layout.
    with gr.Row():
        uploaded_file = gr.File(label="Upload JPG, PNG or SVG")
        image_display = gr.Image(type="pil", label="Original Image")
    sketch_input = gr.Image(label="Draw on Regions to Segment", tool="sketch", type="pil")
    segmentation_output = gr.Image(label="Mask Preview")
    # Holds the latest predicted mask between the segment and export steps.
    mask_state = gr.State()
    export_btn = gr.Button("Export SVG")
    download_link = gr.File(label="Download SVG")

    def _load_upload(file):
        """Return the uploaded image twice: once for display, once for the sketch canvas."""
        # The event below has two outputs, so exactly two values must be returned.
        if file is None:
            return None, None
        image = process_upload(file)
        return image, image

    def _export_current_mask(mask, image):
        """Export the stored mask as SVG, or return None when nothing can be exported."""
        if mask is None or image is None:
            return None
        return export_svg(mask, image.width, image.height)

    uploaded_file.change(
        fn=_load_upload,
        inputs=uploaded_file,
        outputs=[image_display, sketch_input],
    )
    sketch_input.change(
        fn=segment_with_sketch,
        inputs=[image_display, sketch_input],
        outputs=[segmentation_output, mask_state],
    )
    export_btn.click(
        fn=_export_current_mask,
        inputs=[mask_state, image_display],
        outputs=download_link,
    )
    gr.Markdown("Built with Segment Anything + Gradio")
demo.launch()