import os
import sys

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL
import requests
import torch
from PIL import Image

from engine import SegmentAnythingModel, StableDiffusionInpaintingPipeline
from utils import show_anns, create_image_grid
| matplotlib.use('Agg') |
|
|
| |
def _build_cuda_unavailable_demo():
    """Build a disabled copy of the app's UI, shown when no CUDA device is found.

    Every input and button is created with ``interactive=False`` and no event
    handlers are attached, so the page only shows the layout plus a notice
    pointing users to the project repository / Colab.

    Returns:
        gr.Blocks: the assembled, not yet launched, demo.
    """
    with gr.Blocks() as fallback_demo:
        gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")
        gr.Markdown("**CUDA is not available.** Please run it on Google Colab. You can find the Colab here: [Colab Link](https://github.com/SanshruthR/Stable-Diffusion-Inpainting_with_SAM)")

        with gr.Tab("Step 1: Segment Image"):
            with gr.Row():
                gr.Image(label="Input Image", interactive=False)
                gr.Plot(label="Available Masks")
            gr.Button("Generate Masks", interactive=False)

        with gr.Tab("Step 2: Inpaint"):
            with gr.Row():
                with gr.Column():
                    gr.Slider(minimum=0, maximum=20, step=1,
                              label="Mask Index (select based on mask numbers from Step 1)", interactive=False)
                    gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt", interactive=False)
                    gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt", interactive=False)
                    gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt", interactive=False)
                    gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt", interactive=False)
                gr.Plot(label="Inpainting Results")
            gr.Button("Generate Inpainting", interactive=False)
    return fallback_demo


# The rest of the script loads models onto "cuda". Without a GPU, serve the
# disabled UI and stop, so the checkpoint download and model loading never run.
if not torch.cuda.is_available():
    demo = _build_cuda_unavailable_demo()
    # Same launch options as the main app at the bottom of the file.
    demo.launch(share=True, debug=True, ssr_mode=False)
    # sys.exit() instead of the site-module exit(): the latter is intended for
    # the interactive shell and is missing when Python runs with -S.
    sys.exit()
|
|
| |
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
sam_checkpoint = "sam_vit_h_4b8939.pth"


def _download_checkpoint(url, path, chunk_size=1 << 20, timeout=60):
    """Download *url* to *path* unless *path* already exists; return *path*.

    The body is streamed to ``<path>.part`` in chunks, so the multi-gigabyte
    file is never held in memory. The file is then renamed into place only
    after the transfer finishes. An interrupted or failed download therefore
    never leaves a truncated file under *path*.

    Note: an existing *path* is reused as-is. A file left by an older,
    non-atomic download could be incomplete; delete it to force a re-download.

    Raises:
        requests.HTTPError: if the server answers with an error status, instead
            of silently saving the error page as the checkpoint.
    """
    if os.path.exists(path):
        return path
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(tmp_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=chunk_size):
                file.write(chunk)
    os.replace(tmp_path, path)
    return path


_download_checkpoint(SAM_CHECKPOINT_URL, sam_checkpoint)

# Segment Anything (ViT-H) on the GPU; the CUDA check above guarantees one exists.
model_type = "vit_h"
device = "cuda"
sam_model = SegmentAnythingModel(sam_checkpoint, model_type, device)

# Stable Diffusion 2 inpainting pipeline, loaded from this model identifier.
model_dir = "stabilityai/stable-diffusion-2-inpainting"
sd_pipeline = StableDiffusionInpaintingPipeline(model_dir)
|
|
| |
# Segmentation shared between the two Gradio callbacks: segment_image() stores
# the uploaded image and its SAM masks here, and inpaint_image() reads them back.
# Both start out empty until Step 1 has been run.
current_masks, current_image = None, None
|
|
def segment_image(image):
    """Run SAM on *image* and return a figure of the image overlaid with its masks.

    On success, the image and its masks are stored together in the
    module-level ``current_image`` / ``current_masks`` globals for
    ``inpaint_image``.
    NOTE(review): these globals are shared by every session of the app, so
    concurrent users overwrite each other's segmentation.

    Args:
        image: Value of the Step 1 ``gr.Image`` component, or ``None`` when
            nothing is uploaded. Presumably a numpy array (gr.Image's default
            ``type``); confirm against the component config.

    Returns:
        A matplotlib ``Figure`` for the ``gr.Plot`` output, or ``None`` when no
        image was supplied.
    """
    global current_masks, current_image

    if image is None:
        # No upload: drop any previous segmentation so Step 2 cannot inpaint
        # stale masks, and clear the plot.
        current_masks = None
        current_image = None
        return None

    masks = sam_model.generate_masks(np.array(image))
    # Publish image and masks together, only after SAM succeeded, so the pair
    # that inpaint_image() reads always belongs together.
    current_image = image
    current_masks = masks

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(sam_model.preprocess_image(image))
    show_anns(current_masks, ax)
    ax.axis("off")
    # Lay out this figure explicitly rather than pyplot's "current" figure.
    fig.tight_layout()
    # pyplot keeps every figure made by plt.figure() alive until it is closed.
    # Close it here so repeated clicks don't accumulate figures. The closed
    # figure can still be rendered by gr.Plot via savefig.
    plt.close(fig)
    return fig
|
|
def inpaint_image(mask_index, prompt1, prompt2, prompt3, prompt4):
    """Inpaint one SAM mask once per non-empty prompt and return a result grid.

    Reads the image and masks that ``segment_image`` stored in the
    module-level ``current_image`` / ``current_masks`` globals.

    Args:
        mask_index: Index into ``current_masks`` chosen with the Step 2 slider.
            Converted to ``int`` because Gradio may deliver slider values as
            floats.
        prompt1, prompt2, prompt3, prompt4: Inpainting prompts. Blank or
            ``None`` entries are skipped; the others are passed through as-is.

    Returns:
        The grid built by ``create_image_grid`` from the original image and the
        inpainted results, or ``None`` when Step 1 has not produced masks yet.

    Raises:
        gr.Error: if ``mask_index`` does not refer to one of the generated
            masks. The slider allows 0-20 regardless of how many masks SAM
            produced.
    """
    if current_masks is None or current_image is None:
        return None

    index = int(mask_index)
    if not 0 <= index < len(current_masks):
        raise gr.Error(
            f"Mask index {index} is out of range; Step 1 produced "
            f"{len(current_masks)} mask(s)."
        )

    segmentation_mask = current_masks[index]["segmentation"]
    stable_diffusion_mask = PIL.Image.fromarray((segmentation_mask * 255).astype(np.uint8))

    # Textbox values can be None; skip those and whitespace-only prompts.
    prompts = [p for p in (prompt1, prompt2, prompt3, prompt4) if p and p.strip()]

    # Convert the stored image once and reuse it for every call and the grid.
    # Assumes sd_pipeline.inpaint does not modify its input image in place
    # (TODO confirm).
    source_image = Image.fromarray(np.array(current_image))

    # One seeded generator per request, shared by all prompts, as before.
    generator = torch.Generator(device="cuda").manual_seed(42)

    inpainted_images = [
        sd_pipeline.inpaint(
            prompt=prompt,
            image=source_image,
            mask_image=stable_diffusion_mask,
            guidance_scale=7.5,
            num_inference_steps=50,
            generator=generator,
        )
        for prompt in prompts
    ]

    # 2 and 3 are presumably the grid's rows x columns: room for the original
    # plus up to four results.
    return create_image_grid(source_image, inpainted_images, prompts, 2, 3)
|
|
| |
# Main application. Step 1 segments the uploaded image with SAM; Step 2
# inpaints the chosen mask with up to four prompts.
with gr.Blocks() as demo:
    gr.Markdown("# Segment Anything + Stable Diffusion Inpainting")

    with gr.Tab("Step 1: Segment Image"):
        with gr.Row():
            input_image = gr.Image(label="Input Image")
            mask_output = gr.Plot(label="Available Masks")
        segment_btn = gr.Button("Generate Masks")

    with gr.Tab("Step 2: Inpaint"):
        with gr.Row():
            with gr.Column():
                mask_index = gr.Slider(
                    minimum=0,
                    maximum=20,
                    step=1,
                    label="Mask Index (select based on mask numbers from Step 1)",
                )
                prompt1 = gr.Textbox(label="Prompt 1", placeholder="Enter first inpainting prompt")
                prompt2 = gr.Textbox(label="Prompt 2", placeholder="Enter second inpainting prompt")
                prompt3 = gr.Textbox(label="Prompt 3", placeholder="Enter third inpainting prompt")
                prompt4 = gr.Textbox(label="Prompt 4", placeholder="Enter fourth inpainting prompt")
            inpaint_output = gr.Plot(label="Inpainting Results")
        inpaint_btn = gr.Button("Generate Inpainting")

    # Event wiring, registered in the same order as the buttons. It must stay
    # inside the Blocks context.
    segment_btn.click(
        fn=segment_image,
        inputs=[input_image],
        outputs=[mask_output],
    )
    inpaint_btn.click(
        fn=inpaint_image,
        inputs=[mask_index, prompt1, prompt2, prompt3, prompt4],
        outputs=[inpaint_output],
    )
|
|
# Script entry point: serve the app through a public share link with debug
# output. ssr_mode=False turns off Gradio's server-side rendering mode.
if __name__ == "__main__":
    demo.launch(
        share=True,
        debug=True,
        ssr_mode=False,
    )
|
|