| import os |
| import numpy as np |
| import streamlit as st |
| from PIL import Image, ImageDraw, ImageFilter |
| import numpy as np |
| import torch |
| from streamlit_js_eval import streamlit_js_eval |
|
|
|
|
|
|
| |
| from streamlit_image_coordinates import streamlit_image_coordinates |
|
|
| |
| from diffusers import StableDiffusionInpaintPipeline |
|
|
|
|
| |
| from ultralytics import FastSAM |
|
|
| |
# Local filename of the FastSAM checkpoint and the model id handed to
# StableDiffusionInpaintPipeline.from_pretrained().
FASTSAM_CHECKPOINT = "FastSAM-x.pt"
SD_MODEL_ID = "runwayml/stable-diffusion-inpainting"

st.set_page_config(page_title="Inpainting Demo", layout="centered")

# Evaluate `window.innerWidth` in the browser; the value is used further down
# to convert click positions into image pixel coordinates.
page_width = streamlit_js_eval(
    js_expressions="window.innerWidth",
    key="WIDTH",
    want_output=True,
)
|
|
| |
def crop_resize_image(image, target_width=480, target_height=640):
    """Center-crop *image* to the target aspect ratio, then resize it.

    Args:
        image: Object exposing ``size`` (``(width, height)``), ``crop(box)``
            and ``resize(size)`` -- a PIL image in this app.
        target_width: Output width in pixels.
        target_height: Output height in pixels.

    Returns:
        Whatever ``resize((target_width, target_height))`` returns on the
        (possibly cropped) image.
    """
    goal_ratio = target_width / target_height
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    if src_ratio > goal_ratio:
        # Too wide: keep the full height and trim equal margins left/right.
        crop_w = int(src_h * goal_ratio)
        x0 = (src_w - crop_w) // 2
        box = (x0, 0, x0 + crop_w, src_h)
    elif src_ratio < goal_ratio:
        # Too tall: keep the full width and trim equal margins top/bottom.
        crop_h = int(src_w / goal_ratio)
        y0 = (src_h - crop_h) // 2
        box = (0, y0, src_w, y0 + crop_h)
    else:
        # Aspect ratio already matches; no crop needed.
        box = None

    cropped = image if box is None else image.crop(box)
    return cropped.resize((target_width, target_height))
|
|
| |
# Download the FastSAM checkpoint on first launch if it is not on disk yet.
if not os.path.exists(FASTSAM_CHECKPOINT):
    import requests

    fastsam_url = "https://github.com/ultralytics/assets/releases/download/v8.2.0/FastSAM-x.pt"

    # Stream into a temporary file and rename only after a complete, successful
    # download. A failed or interrupted request must not leave a file behind
    # that would satisfy the exists() check above on the next start.
    partial_path = FASTSAM_CHECKPOINT + ".part"
    with requests.get(fastsam_url, stream=True, timeout=60) as resp:
        # Fail loudly on HTTP errors instead of saving an error page as weights.
        resp.raise_for_status()
        with open(partial_path, "wb") as checkpoint_file:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                checkpoint_file.write(chunk)
    os.replace(partial_path, FASTSAM_CHECKPOINT)
|
|
| |
@st.cache_resource
def load_models():
    """Create the FastSAM segmenter and the Stable Diffusion inpainting pipeline.

    The function is wrapped in ``st.cache_resource``, so Streamlit can hand
    back the same model objects instead of rebuilding them.

    Returns:
        tuple: ``(fastsam_model, sd_pipe)`` -- the FastSAM model built from
        ``FASTSAM_CHECKPOINT`` and the inpainting pipeline built from
        ``SD_MODEL_ID``, moved to CUDA when available (CPU otherwise) with
        attention slicing enabled.
    """
    segmenter = FastSAM(FASTSAM_CHECKPOINT)

    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
        SD_MODEL_ID,
        torch_dtype=None,
    )

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    pipeline = pipeline.to(device)

    pipeline.enable_attention_slicing()
    return segmenter, pipeline
|
|
| |
# Obtain the segmentation model and the inpainting pipeline.
fastsam_model, sd_pipe = load_models()

# Seed the dot-removal flag the first time it is missing from session state.
if "is_removing_dot" not in st.session_state:
    st.session_state["is_removing_dot"] = False

st.subheader("InteractiveInpainting")
|
|
|
|
|
|
| |
| |
|
|
|
|
| |
| |
|
|
|
|
|
|
# Until a photo has been stored in session state there is nothing to segment:
# show the camera widget and end this script run.
if "img" not in st.session_state:
    enable = st.checkbox("Enable camera")
    picture = st.camera_input("Take a picture", disabled=not enable)
    if picture is not None:
        # Normalise the capture to the 480x640 working size used by the app.
        st.session_state.img = crop_resize_image(
            Image.open(picture), target_width=480, target_height=640
        )
        # A fresh image starts with no selected points.
        st.session_state.coords_list = []
        st.rerun()
    # No picture yet: halt here so the code below never runs with `img` unbound.
    st.stop()

img = st.session_state.img
|
|
| |
if "coords_list" not in st.session_state:
    st.session_state.coords_list = []

# Bound up front so later code can test `combined_mask_img is not None` even
# when there are no points or the model produced no mask.
combined_mask_img = None

if st.session_state.coords_list:
    # One [x, y] prompt per stored click, each labelled 1 (used here as a
    # positive prompt -- see the Ultralytics point-prompt documentation).
    points = [[int(pt["x"]), int(pt["y"])] for pt in st.session_state.coords_list]
    labels = [1] * len(points)
    results = fastsam_model(img, points=points, labels=labels)

    # `masks` is None when no segmentation was produced for the prompts;
    # reading `.data` unconditionally would raise AttributeError.
    mask_result = results[0].masks if results else None
    if mask_result is not None:
        masks = mask_result.data.cpu().numpy()
        if masks.ndim == 3 and masks.shape[0] > 0:
            # Union of all returned masks, as an image matching `img` in size.
            combined_mask = np.max(masks, axis=0)
            combined_mask_img = Image.fromarray((combined_mask * 255).astype(np.uint8))
            combined_mask_img = combined_mask_img.resize(img.size, Image.NEAREST)

if combined_mask_img is None:
    seg_overlay = img.copy()
else:
    # Tint the selected region red: the overlay is opaque-red everywhere and
    # its alpha channel is 80 inside the mask, 0 outside.
    overlay = Image.new("RGBA", img.size, (255, 0, 0, 100))
    base = img.convert("RGBA")
    mask_alpha = combined_mask_img.point(lambda p: 80 if p > 0 else 0)
    overlay.putalpha(mask_alpha)
    seg_overlay = Image.alpha_composite(base, overlay)
|
|
| |
# Paint a red dot (radius 5 px) at every stored click on a copy of the overlay.
final_img = seg_overlay.copy()
draw = ImageDraw.Draw(final_img)
dot_radius = 5
for marker in st.session_state.coords_list:
    center_x = int(marker["x"])
    center_y = int(marker["y"])
    bounding_box = (
        center_x - dot_radius,
        center_y - dot_radius,
        center_x + dot_radius,
        center_y + dot_radius,
    )
    draw.ellipse(bounding_box, fill="red")
|
|
| |
| |
# A click closer than this (in image pixels) to a stored dot removes that dot.
DOT_HIT_RADIUS = 10

original_width = st.session_state.img.width

raw_click = streamlit_image_coordinates(final_img, key="click_img", use_column_width="always")

# The component keeps returning its last value on every rerun. Remember the
# click that was already handled so a stale value is not processed again
# (which would re-add a dot that was just removed, or add a dot to a freshly
# inpainted image).
if raw_click and raw_click != st.session_state.get("last_click"):
    st.session_state["last_click"] = raw_click

    # Click positions are relative to the image as rendered in the browser.
    # Prefer the rendered width reported with the click when the component
    # provides one; otherwise fall back to the browser window width, and to
    # no scaling at all if that is unavailable (it may be None or 0).
    displayed_width = raw_click.get("width") or page_width
    scale_factor = original_width / displayed_width if displayed_width else 1.0

    new_coord = {
        "x": raw_click["x"] * scale_factor,
        "y": raw_click["y"] * scale_factor,
    }

    # Clicking on (or next to) an existing dot toggles it off; any other
    # click adds a new dot.
    nearby = next(
        (
            coord
            for coord in st.session_state.coords_list
            if np.hypot(coord["x"] - new_coord["x"], coord["y"] - new_coord["y"]) < DOT_HIT_RADIUS
        ),
        None,
    )
    if nearby is not None:
        st.session_state.coords_list.remove(nearby)
    else:
        st.session_state.coords_list.append(new_coord)
    st.rerun()

st.write("Stored coordinates:", st.session_state.coords_list)
|
|
|
|
| |
| |
prompt = st.text_input("Prompt for inpainting (describe what should replace the selected area):")

# The mask name is only bound when segmentation actually produced a mask;
# treat "never assigned" the same as "no mask" instead of raising NameError.
try:
    selection_mask = combined_mask_img
except NameError:
    selection_mask = None

if prompt and selection_mask is not None:
    if st.button("Run Inpainting"):
        with st.spinner("Inpainting..."):
            # Prepare the mask only when it is needed: grow it by a 5x5 max
            # filter so object borders are covered, then feather the edge.
            # NOTE: the feathered mask used to be computed but the raw mask
            # was what got passed to the pipeline; it is passed here so the
            # dilation/blur actually takes effect.
            mask_gray = selection_mask.convert("L")
            dilated_mask = mask_gray.filter(ImageFilter.MaxFilter(5))
            blurred_mask = dilated_mask.filter(ImageFilter.GaussianBlur(radius=3))

            inpainted_img = sd_pipe(
                prompt=prompt,
                image=img,
                mask_image=blurred_mask,
                width=img.width,
                height=img.height,
                guidance_scale=8,
                num_inference_steps=50,
            ).images[0]

        # The result becomes the new working image; earlier clicks no longer
        # refer to anything in it, so the selection is cleared.
        st.session_state.img = inpainted_img
        st.session_state.coords_list = []
        st.rerun()