| import gradio as gr |
| import os |
| import cv2 |
| import numpy as np |
| import torch |
| from PIL import Image |
| from insightface.app import FaceAnalysis |
| from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, StableDiffusionXLPipeline |
| from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlus |
| import argparse |
| import random |
| from insightface.utils import face_align |
| from pyngrok import ngrok |
| import threading |
| import time |
| from ip_adapter.ip_adapter_faceid import IPAdapterFaceIDPlusXL |
| import hashlib |
| from datetime import datetime |
|
|
| |
# --- Command-line options ---------------------------------------------------
# Parsed once at import time; `args` is read by the model cache, the UI
# defaults, and the ngrok/launch logic at the bottom of the file.
parser = argparse.ArgumentParser()
parser.add_argument("--share", action="store_true", help="Enable Gradio share option")
parser.add_argument("--num_images", type=int, default=1, help="Number of images to generate")
parser.add_argument("--cache_limit", type=int, default=1, help="Limit for model cache")
parser.add_argument("--ngrok_token", type=str, default=None, help="ngrok authtoken for tunneling")


args = parser.parse_args()
|
|
| |
# Hugging Face repo IDs offered in the "Select Model" dropdown.
# NOTE(review): the entries containing "xl" appear to be SDXL checkpoints and
# presumably require the "Activate SDXL" checkbox — confirm before relying on it.
static_model_names = [
    "SG161222/Realistic_Vision_V6.0_B1_noVAE",
    "stablediffusionapi/rev-animated-v122-eol",
    "Lykon/DreamShaper",
    "stablediffusionapi/toonyou",
    "stablediffusionapi/real-cartoon-3d",
    "KBlueLeaf/kohaku-v2.1",
    "nitrosocke/Ghibli-Diffusion",
    "Linaqruf/anything-v3.0",
    "jinaai/flat-2d-animerge",
    "stablediffusionapi/realcartoon3d",
    "stablediffusionapi/disney-pixar-cartoon",
    "stablediffusionapi/pastel-mix-stylized-anime",
    "stablediffusionapi/anything-v5",
    "SG161222/Realistic_Vision_V2.0",
    "SG161222/Realistic_Vision_V4.0_noVAE",
    "SG161222/Realistic_Vision_V5.1_noVAE",
    "stablediffusionapi/anime-illust-diffusion-xl",
    "stabilityai/stable-diffusion-xl-base-1.0",

]
|
|
| |
# FIFO cache of loaded IP-Adapter pipelines, keyed by model name.
# Eviction happens in load_model() when the size reaches max_cache_size.
model_cache = {}
# Upper bound on cached pipelines (from --cache_limit; default 1).
max_cache_size = args.cache_limit


# Cache of (faceid_embeds, aligned_face_crop) keyed by input-image hash.
# NOTE(review): unbounded — grows for the lifetime of the process.
embeddings_cache = {}
|
|
def get_image_hash(image):
    """Return a SHA-256 hex digest identifying *image*'s pixel content.

    The digest covers the image's mode and size in addition to the raw pixel
    buffer: ``tobytes()`` alone is ambiguous — e.g. a 100x200 and a 200x100
    image can share an identical byte buffer — which would cause false hits
    in ``embeddings_cache``.
    """
    # Prefix the pixel data with a "mode + size" header to disambiguate layouts.
    header = f"{image.mode}{image.size}".encode("utf-8")
    return hashlib.sha256(header + image.tobytes()).hexdigest()
|
|
def convert_model(checkpoint_path, output_path, isSDXL):
    """Convert a single-file checkpoint into a diffusers folder layout.

    Picks the SDXL or SD1.5 pipeline class based on *isSDXL*, loads the
    checkpoint from *checkpoint_path* and writes the converted weights to
    *output_path*. Returns a human-readable status string and never raises:
    any failure is reported as an "Error: ..." message.
    """
    try:
        pipeline_cls = StableDiffusionXLPipeline if isSDXL else StableDiffusionPipeline
        pipeline = pipeline_cls.from_single_file(checkpoint_path)
        pipeline.save_pretrained(output_path)
    except Exception as e:
        return f"Error: {str(e)}"
    return f"Model converted and saved to {output_path}"
|
|
|
|
| |
def load_model(model_name, isSDXL):
    """Return an IP-Adapter FaceID model for *model_name*, loading on demand.

    Results are memoized in the module-level ``model_cache``; when the cache
    is full, the oldest entry (dict insertion order) is evicted first.
    Loads everything in fp16 onto CUDA.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached
    print(f"loading model {model_name}")

    # Make room: drop the oldest pipeline(s) before loading a new one.
    while len(model_cache) >= max_cache_size:
        oldest = next(iter(model_cache))
        model_cache.pop(oldest)

    device = "cuda"
    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )

    # Each family ships with a dedicated fixed VAE.
    vae_repo = "stabilityai/sdxl-vae" if isSDXL else "stabilityai/sd-vae-ft-mse"
    vae = AutoencoderKL.from_pretrained(vae_repo).to(dtype=torch.float16)

    if isSDXL:
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            vae=vae,
            scheduler=scheduler,
            add_watermarker=False,
        ).to(device)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            scheduler=scheduler,
            vae=vae,
            feature_extractor=None,
            safety_checker=None,
        ).to(device)

    # Wrap the pipeline with the matching FaceID-Plus adapter.
    image_encoder_path = "h94/IP-Adapter/models/image_encoder"
    if isSDXL:
        ip_model = IPAdapterFaceIDPlusXL(
            pipe, image_encoder_path, "adapters/ip-adapter-faceid-plusv2_sdxl.bin", device
        )
    else:
        ip_model = IPAdapterFaceIDPlus(
            pipe, image_encoder_path, "adapters/ip-adapter-faceid-plusv2_sd15.bin", device
        )

    model_cache[model_name] = ip_model
    return ip_model
|
|
| |
def _get_face_embeddings(input_image_cv2, image_hash):
    """Return (faceid_embeds, aligned_face_crop) for the first detected face.

    Results are memoized in ``embeddings_cache`` keyed by *image_hash* so
    repeated generations from the same photo skip face detection entirely.
    Raises ValueError when no face is found.
    """
    if image_hash in embeddings_cache:
        return embeddings_cache[image_hash]

    # NOTE(review): FaceAnalysis is rebuilt on every cache miss; could be
    # hoisted to module scope if detection ever becomes a bottleneck.
    app = FaceAnalysis(
        name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    app.prepare(ctx_id=0, det_size=(640, 640))
    faces = app.get(input_image_cv2)
    if not faces:
        raise ValueError("No faces found in the image.")

    faceid_embeds = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
    face_image = face_align.norm_crop(input_image_cv2, landmark=faces[0].kps, image_size=224)
    embeddings_cache[image_hash] = (faceid_embeds, face_image)
    return faceid_embeds, face_image


def generate_image(input_image, positive_prompt, negative_prompt, width, height, model_name, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path, isSDXL, cfg):
    """Generate face-conditioned images and save them under ``outputs/``.

    Runs ``num_images`` generation rounds of ``batch_size`` samples each with
    the selected (or custom) model, conditioning on the face found in
    *input_image*. Returns a tuple of (saved file paths, info string, the
    seed used for the last round) for the Gradio outputs.
    """
    saved_images = []
    # A custom local model path overrides the dropdown selection.
    if custom_model_path:
        model_name = custom_model_path

    ip_model = load_model(model_name, isSDXL)

    input_image = input_image.convert("RGB")
    input_image_cv2 = cv2.cvtColor(np.array(input_image), cv2.COLOR_RGB2BGR)
    image_hash = get_image_hash(input_image)

    faceid_embeds, face_image = _get_face_embeddings(input_image_cv2, image_hash)

    # Hoisted out of the generation loop; exist_ok avoids the check-then-create race.
    outputs_dir = "outputs"
    os.makedirs(outputs_dir, exist_ok=True)

    for image_index in range(num_images):
        # Re-roll the seed for every round after the first so repeated rounds
        # differ even when "Randomize Seed" is unchecked.
        if randomize_seed or image_index > 0:
            seed = random.randint(0, 2**32 - 1)

        generated_images = ip_model.generate(
            prompt=positive_prompt,
            negative_prompt=negative_prompt,
            faceid_embeds=faceid_embeds,
            face_image=face_image,
            num_samples=batch_size,
            shortcut=enable_shortcut,
            s_scale=s_scale,
            width=width,
            height=height,
            guidance_scale=cfg,
            num_inference_steps=num_inference_steps,
            seed=seed,
        )

        for i, img in enumerate(generated_images, start=1):
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            image_path = os.path.join(outputs_dir, f"{timestamp}_image_{len(os.listdir(outputs_dir)) + i}.png")
            img.save(image_path)
            saved_images.append(image_path)

    return saved_images, f"Saved images: {', '.join(saved_images)}", seed
|
|
| |
# --- Gradio UI ---------------------------------------------------------------
# Declares all widgets and wires the two actions (generate, convert).
with gr.Blocks() as demo:
    gr.Markdown("Developed by SECourses - only distributed on https://www.patreon.com/posts/95759342")
    with gr.Row():
        input_image = gr.Image(type="pil")
        generate_btn = gr.Button("Generate")
    with gr.Row():
        width = gr.Number(value=512, label="Width")
        height = gr.Number(value=768, label="Height")
        cfg = gr.Number(value=7.5, label="CFG")
    with gr.Row():
        num_inference_steps = gr.Number(value=30, label="Number of Inference Steps", step=1, minimum=10, maximum=100)
        seed = gr.Number(value=2023, label="Seed")
        randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
    with gr.Row():
        num_images = gr.Number(value=args.num_images, label="Number of Images to Generate", step=1, minimum=1)
        batch_size = gr.Number(value=1, label="Batch Size", step=1)
    with gr.Row():
        isSDXL = gr.Checkbox(value=False, label="Activate SDXL")
        enable_shortcut = gr.Checkbox(value=True, label="Enable Shortcut")
        s_scale = gr.Number(value=1.0, label="Scale Factor (s_scale)", step=0.1, minimum=0.5, maximum=4.0)
    with gr.Row():
        positive_prompt = gr.Textbox(label="Positive Prompt")
        negative_prompt = gr.Textbox(label="Negative Prompt")
    with gr.Row():
        model_selector = gr.Dropdown(label="Select Model", choices=static_model_names, value=static_model_names[0])
        custom_model_path = gr.Textbox(label="Custom Model Path (Optional)")

    with gr.Column():
        output_gallery = gr.Gallery(label="Generated Images")
        output_text = gr.Textbox(label="Output Info")
        display_seed = gr.Textbox(label="Used Seed", interactive=False)

    # Raw strings: "G:\model\..." contains the invalid escape "\m", which is a
    # SyntaxWarning on Python 3.12+ and slated to become an error.
    with gr.Row():
        checkpoint_path_input = gr.Textbox(label=r"Enter Checkpoint File Path e.g. G:\model\model.safetensors")
        output_path_input = gr.Textbox(label=r"Enter Output Folder Path, e.g. G:\model\model_diffusers")
        convert_btn = gr.Button("Convert Model")

    generate_btn.click(
        generate_image,
        inputs=[input_image, positive_prompt, negative_prompt, width, height, model_selector, num_inference_steps, seed, randomize_seed, num_images, batch_size, enable_shortcut, s_scale, custom_model_path, isSDXL, cfg],
        outputs=[output_gallery, output_text, display_seed]
    )

    convert_btn.click(
        convert_model,
        inputs=[checkpoint_path_input, output_path_input, isSDXL],
        outputs=[gr.Text(label="Conversion Status")],
    )
|
|
| |
def start_ngrok():
    """Open an ngrok tunnel to the local Gradio server (runs in a daemon thread).

    Waits briefly so Gradio has time to bind its port before the tunnel
    connects, then authenticates with the token from --ngrok_token.
    """
    print("Starting ngrok...")
    # Crude startup synchronization with demo.launch(); no readiness probe exists here.
    time.sleep(10)
    ngrok.set_auth_token(args.ngrok_token)
    # BUGFIX: pyngrok's connect() takes the port as its first (addr) argument;
    # the `port=` keyword was removed in pyngrok 5.x and raises TypeError.
    public_url = ngrok.connect(7860)
    print(f"ngrok tunnel started at {public_url}")
|
|
| if __name__ == "__main__": |
| if args.ngrok_token: |
| |
| ngrok_thread = threading.Thread(target=start_ngrok, daemon=True) |
| ngrok_thread.start() |
|
|
| |
| demo.launch(share=args.share, inbrowser=True) |
|
|