| import gradio as gr |
| import requests |
| import io |
| import re |
| import random |
| import os |
| from PIL import Image |
| from datasets import load_dataset |
| from huggingface_hub import login |
|
|
# Read the Hugging Face token once. It is used to log in to the hub, to build
# the Authorization header sent with every Inference API call made by query(),
# and (via use_auth_token=True) for the word-list dataset download below.
# NOTE(review): if HF_READ_TOKEN is unset, API_TOKEN is None and the header
# becomes "Bearer None" -- confirm the deployment always sets it.
API_TOKEN = os.getenv("HF_READ_TOKEN")
login(token=API_TOKEN)

API_URL = "https://api-inference.huggingface.co/models/openskyml/open-diffusion-v1"
headers = {"Authorization": f"Bearer {API_TOKEN}"}

# Blocked-word list used by query() to reject unsafe prompts; one entry per
# row of the "text" column of en.txt.
word_list_dataset = load_dataset(
    "openskyml/bad-words-prompt-list",
    data_files="en.txt",
    use_auth_token=True,
)
word_list = word_list_dataset["train"]["text"]
|
|
def _find_blocked_word(prompt, words):
    """Return the first entry of ``words`` that occurs in ``prompt`` as a whole word.

    Each entry is stripped and regex-escaped, so it is matched literally rather
    than as a pattern. Matching is case-insensitive, so capitalisation cannot
    bypass the filter. Blank entries are skipped because an empty pattern would
    match at every word boundary and reject nearly any prompt.

    Returns ``None`` when no entry matches.
    """
    for word in words:
        word = word.strip()
        if word and re.search(rf"\b{re.escape(word)}\b", prompt, flags=re.IGNORECASE):
            return word
    return None


def _request_image(payload, timeout):
    """POST ``payload`` to ``API_URL`` and decode the response body as a PIL image.

    Raises:
        gr.Error: on network failure or timeout, on a non-2xx response, or when
            the body cannot be decoded as an image.
    """
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as err:
        raise gr.Error(f"Could not reach the image generation API: {err}") from err

    if not response.ok:
        # A failed request does not carry image bytes. Report the status and
        # the start of the body instead of letting PIL fail with a decode error.
        raise gr.Error(
            f"Image generation failed (HTTP {response.status_code}): {response.text[:200]}"
        )

    try:
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy; force decoding here so corrupt data fails now
        # with a readable message rather than later inside Gradio.
        image.load()
    except OSError as err:  # PIL.UnidentifiedImageError is a subclass of OSError
        raise gr.Error("The API response could not be decoded as an image.") from err
    return image


def query(prompt, is_negative=False, steps=5, cfg_scale=7, seed=None, num_images=4, timeout=120):
    """Generate ``num_images`` images for ``prompt`` through the Inference API.

    Args:
        prompt: Text prompt. ", 8k" is appended before it is sent.
        is_negative: Sent verbatim under the "is_negative" payload key. The UI
            passes the negative-prompt textbox value here.
        steps: Value of the "steps" payload key.
        cfg_scale: Value of the "cfg_scale" payload key.
        seed: Seed sent with every request. When ``None``, each image gets its
            own random seed.
        num_images: Number of sequential API requests (one image per request).
        timeout: Per-request timeout in seconds, passed to ``requests.post``.

    Returns:
        A list of ``num_images`` PIL images.

    Raises:
        gr.Error: if the prompt is blank, contains a blocked word, or any
            request fails.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt.")
    if _find_blocked_word(prompt, word_list) is not None:
        raise gr.Error("Unsafe content found. Please try again with different prompts.")

    images = []
    for _ in range(num_images):
        # NOTE(review): the HF Inference API documents generation options under
        # a nested "parameters" object (negative_prompt, num_inference_steps,
        # guidance_scale, ...). These top-level keys may be ignored by the
        # endpoint -- confirm before relying on them.
        payload = {
            "inputs": prompt + ", 8k",
            "is_negative": is_negative,
            "steps": steps,
            "cfg_scale": cfg_scale,
            "seed": seed if seed is not None else random.randint(-1, 2147483647),
        }
        images.append(_request_image(payload, timeout))
    return images
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=css) when the UI is built.
# - .gradio-container: page-wide font.
# - #gallery: targets the Gallery created with elem_id="gallery".
# NOTE(review): no component in this file sets elem_id="prompt-text-input",
#   "negative-prompt-text-input", or the class "image_duplication", so those
#   rules look inert -- confirm or remove.
# NOTE(review): #component-16 appears to target a Gradio auto-generated
#   component id; it will silently point at a different element if the layout
#   changes. Prefer an explicit elem_id.
css = """
.gradio-container {
font-family: 'IBM Plex Sans', sans-serif;
}
#gallery {
min-height: 22rem;
margin-bottom: 15px;
margin-left: auto;
margin-right: auto;
border-bottom-right-radius: .5rem !important;
border-bottom-left-radius: .5rem !important;
}
#gallery>div>.h-full {
min-height: 20rem;

}

#prompt-text-input, #negative-prompt-text-input{padding: .45rem 0.625rem}
#component-16{border-top-width: 1px!important;margin-top: 1em}
.image_duplication{position: absolute; width: 100px; left: 50px}
"""
|
|
# Build the UI:
# - a centered title;
# - a two-column image gallery;
# - prompt and negative-prompt boxes plus a Generate button.
# The results of query() are shown in the gallery.
with gr.Blocks(css=css) as demo:
    gr.HTML(
        """
<div style="text-align: center; margin: 0 auto;">
<div
style="
display: inline-flex;
align-items: center;
gap: 0.8rem;
font-size: 1.75rem;
"
>
<h1 style="font-weight: 900; margin-bottom: 7px;margin-top:5px">
Open Diffusion 1.0 Demo
</h1>
</div>
</div>
"""
    )

    with gr.Row():
        # .style(grid=[2]) lays the generated images out two per row.
        gallery_output = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery").style(grid=[2])

    with gr.Row():
        with gr.Box():
            text_prompt = gr.Textbox(show_label=False, placeholder="Enter your prompt", max_lines=1)
            negative_prompt = gr.Textbox(show_label=False, placeholder="Enter a negative prompt", max_lines=1)
            text_button = gr.Button("Generate", icon="https://www.gstatic.com/android/keyboard/emojikitchen/20210521/u1fa84/u1fa84_u1fa84.png")

    # The negative-prompt text is passed as query()'s second positional argument.
    # Clicking Generate or pressing Enter in the single-line prompt box both
    # run the same generation.
    generate_inputs = [text_prompt, negative_prompt]
    text_button.click(query, inputs=generate_inputs, outputs=gallery_output)
    text_prompt.submit(query, inputs=generate_inputs, outputs=gallery_output)
|
|
# Start the server, listening on all interfaces (0.0.0.0) at port 7860,
# with show_api disabled.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    show_api=False,
)