backgroundbif / app.py
userhugginggit's picture
Create app.py
5213167 verified
import os
import torch
from PIL import Image
from typing import Union, Tuple
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import gradio as gr
from loadimg import load_img
# =========================================================================
# DEVICE CONFIGURATION (CPU)
# =========================================================================
DEVICE = "cpu"
print(f"--- Cargando BiRefNet en {DEVICE.upper()} ---")

# Load the model straight from the Hugging Face Hub, in float32, then move it
# to DEVICE and switch it to evaluation mode.
# NOTE(review): trust_remote_code=True lets code shipped in the model
# repository run locally -- confirm the repository is trusted.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "merve/BiRefNet",
    trust_remote_code=True,
    torch_dtype=torch.float32,
)
birefnet = birefnet.to(DEVICE)
birefnet.eval()
print("Modelo cargado correctamente en CPU.")

# Preprocessing applied to every image before it is fed to the model:
# resize to 1024x1024, convert to a tensor, normalize the three channels.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# =========================================================================
# FUNCIONES DE PROCESAMIENTO
# =========================================================================
def process(image: Image.Image) -> Image.Image:
    """Remove the background of ``image`` with BiRefNet running on ``DEVICE``.

    The image goes through ``transform_image``, the last output of the model
    is passed through a sigmoid to obtain a mask, and that mask (resized back
    to the input dimensions) is installed as the alpha channel of a copy of
    the input.

    Args:
        image: Picture to cut out. The callers in this file convert it to RGB
            before calling.

    Returns:
        A copy of ``image`` whose alpha channel is the predicted mask.
    """
    original_size = image.size

    # 1. Build a batch of one for the network.
    batch = transform_image(image).unsqueeze(0).to(DEVICE)

    # 2. Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # 3. Turn the first (only) prediction into a PIL mask.
    alpha = transforms.ToPILImage()(probabilities[0].squeeze())

    # 4. Bring the mask back to the original size (LANCZOS resampling).
    alpha = alpha.resize(original_size, Image.LANCZOS)

    # 5. Attach the mask as transparency on a copy of the input.
    cutout = image.copy()
    cutout.putalpha(alpha)
    return cutout
def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
    """Build the before/after pair for the image-upload and URL tabs.

    Args:
        image: Input handed to ``load_img`` (the Gradio tabs supply either an
            uploaded image or a URL string).

    Returns:
        ``(original, processed)``: an RGB copy of the loaded image and the
        result of ``process`` on it, in the order the ImageSlider shows them.
    """
    rgb = load_img(image, output_type="pil").convert("RGB")
    return (rgb.copy(), process(rgb))
def process_file(f: str) -> str:
    """Remove the background of an image file and save the result as a PNG.

    The image at ``f`` is loaded, converted to RGB, passed through
    ``process`` and written next to the input file.

    Args:
        f: Path of the image file to process.

    Returns:
        Path of the written PNG: ``f`` with its extension (if any) replaced
        by ``.png``.
    """
    # os.path.splitext only looks for the extension in the last path
    # component. Splitting the whole string on "." would cut the path at a
    # dotted directory name when the file itself has no extension
    # (e.g. "/tmp/cache.v2/image" -> "/tmp/cache.png").
    stem, _ = os.path.splitext(f)
    name_path = stem + ".png"

    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path, "PNG")
    return name_path
# =========================================================================
# GRADIO INTERFACE
# =========================================================================
slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
slider2 = gr.ImageSlider(
    label="Processed Image from URL", type="pil", format="png"
)
image_upload = gr.Image(label="Upload an image")
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")

# Default examples
example_image_path = "butterfly.jpg"
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"

# Best-effort load of the local example image: if it cannot be loaded (for
# instance because it has not been uploaded yet), start without examples
# instead of crashing.
try:
    chameleon = load_img(example_image_path, output_type="pil")
except Exception:
    examples_img = None
    examples_file = None
else:
    examples_img = [chameleon]
    examples_file = [example_image_path]

tab1 = gr.Interface(
    fn=fn,
    inputs=image_upload,
    outputs=slider1,
    examples=examples_img,
    api_name="image",
)
tab2 = gr.Interface(
    fn=fn,
    inputs=url_input,
    outputs=slider2,
    examples=[url_example],
    api_name="text",
)
tab3 = gr.Interface(
    fn=process_file,
    inputs=image_file_upload,
    outputs=output_file,
    examples=examples_file,
    api_name="png",
)

demo = gr.TabbedInterface(
    [tab1, tab2, tab3],
    ["Image Upload", "URL Input", "File Output"],
    title="Background Removal Tool (CPU Edition)",
)

if __name__ == "__main__":
    demo.launch(show_error=True)