import torch
import types


def _fake_device_capability(*args, **kwargs):
    """Report a fixed (major, minor) compute capability, ignoring all arguments."""
    return (8, 6)


def _fake_device_properties(*args, **kwargs):
    """Return a fresh namespace describing a fixed device, ignoring all arguments."""
    return types.SimpleNamespace(
        name='NVIDIA A10G',
        major=8,
        minor=6,
        total_memory=23836033024,
        multi_processor_count=80,
    )


# Replace the real CUDA queries with canned answers before anything else is
# imported. NOTE(review): presumably this lets libraries that probe the GPU at
# import time load in a process where no GPU is attached yet -- confirm.
torch.cuda.get_device_capability = _fake_device_capability
torch.cuda.get_device_properties = _fake_device_properties
| |
|
import huggingface_hub

# Pull the four PASD run folders from the Hub into PASD/runs as real files
# (symlinks disabled), restricting the snapshot to those folders only.
huggingface_hub.snapshot_download(
    repo_id='camenduru/PASD',
    allow_patterns=[
        f'{_run_name}/**'
        for _run_name in ('pasd', 'pasd_light', 'pasd_light_rrdb', 'pasd_rrdb')
    ],
    local_dir='PASD/runs',
    local_dir_use_symlinks=False,
)

# Single-file downloads: (repo, file name, destination directory).
# The first is the personalized base model loaded later via
# `dreambooth_lora_path`; the second goes into the annotator checkpoint folder.
for _repo_id, _filename, _target_dir in (
    ('camenduru/PASD', 'majicmixRealistic_v6.safetensors', 'PASD/checkpoints/personalized_models'),
    ('akhaliq/RetinaFace-R50', 'RetinaFace-R50.pth', 'PASD/annotator/ckpts'),
):
    huggingface_hub.hf_hub_download(
        repo_id=_repo_id,
        filename=_filename,
        local_dir=_target_dir,
        local_dir_use_symlinks=False,
    )
| |
|
| | import sys; sys.path.append('./PASD') |
| | import spaces |
| | import os |
| | import datetime |
| | import einops |
| | import gradio as gr |
| | from gradio_imageslider import ImageSlider |
| | import numpy as np |
| | import torch |
| | import random |
| | from PIL import Image |
| | from pathlib import Path |
| | from torchvision import transforms |
| | import torch.nn.functional as F |
| | from torchvision.models import resnet50, ResNet50_Weights |
| |
|
| | from pytorch_lightning import seed_everything |
| | from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor |
| | from diffusers import AutoencoderKL, DDIMScheduler, PNDMScheduler, DPMSolverMultistepScheduler, UniPCMultistepScheduler |
| |
|
| | from pipelines.pipeline_pasd import StableDiffusionControlNetPipeline |
| | from myutils.misc import load_dreambooth_lora, rand_name |
| | from myutils.wavelet_color_fix import wavelet_color_fix |
| | from annotator.retinaface import RetinaFaceDetection |
| |
|
# Toggle between the full PASD architecture and the "light" variant.
use_pasd_light = False
face_detector = RetinaFaceDetection()

# Both variants expose the same class names, so the rest of the script is
# agnostic to which one was imported.
if not use_pasd_light:
    from models.pasd.unet_2d_condition import UNet2DConditionModel
    from models.pasd.controlnet import ControlNetModel
else:
    from models.pasd_light.unet_2d_condition import UNet2DConditionModel
    from models.pasd_light.controlnet import ControlNetModel

# Where each set of weights comes from.
pretrained_model_path = "runwayml/stable-diffusion-v1-5"
ckpt_path = "PASD/runs/pasd/checkpoint-100000"
dreambooth_lora_path = "PASD/checkpoints/personalized_models/majicmixRealistic_v6.safetensors"

# NOTE(review): weights are cast to float16 but kept on "cpu" at module level;
# confirm this is what the GPU-decorated `inference` function expects.
weight_dtype = torch.float16
device = "cpu"

# Stock Stable Diffusion components come from the base model repository ...
scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
feature_extractor = CLIPImageProcessor.from_pretrained(pretrained_model_path, subfolder="feature_extractor")
# ... while the UNet and ControlNet come from the downloaded PASD checkpoint.
unet = UNet2DConditionModel.from_pretrained(ckpt_path, subfolder="unet")
controlnet = ControlNetModel.from_pretrained(ckpt_path, subfolder="controlnet")

# Inference only: disable gradient tracking on every network.
for _module in (vae, text_encoder, unet, controlnet):
    _module.requires_grad_(False)

# Apply the personalized model weights on top of the loaded networks.
unet, vae, text_encoder = load_dreambooth_lora(unet, vae, text_encoder, dreambooth_lora_path)

# Move/cast the (possibly re-bound) networks to the configured device and dtype.
for _module in (text_encoder, vae, unet, controlnet):
    _module.to(device, dtype=weight_dtype)

validation_pipeline = StableDiffusionControlNetPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=None,
    requires_safety_checker=False,
)
validation_pipeline._init_tiled_vae(decoder_tile_size=224)

# ResNet-50 classifier (with its matching preprocessing transform) used by
# `inference` to append a predicted category name to the prompt.
weights = ResNet50_Weights.DEFAULT
preprocess = weights.transforms()
resnet = resnet50(weights=weights)
resnet.eval()
| |
|
def resize_image(image_path, target_height):
    """Open an image file and scale it to ``target_height``, keeping aspect ratio.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        target_height: Height in pixels of the returned image.

    Returns:
        A new ``PIL.Image.Image`` of size ``(new_width, target_height)``,
        resampled with the LANCZOS filter. The source file is closed before
        returning.
    """
    with Image.open(image_path) as img:
        src_width, src_height = img.size
        ratio = target_height / float(src_height)
        # Truncation can yield 0 for extremely tall, narrow images, which
        # makes Image.resize fail; keep at least one pixel of width.
        new_width = max(1, int(float(src_width) * ratio))
        resized_img = img.resize((new_width, target_height), Image.LANCZOS)

    return resized_img
| |
|
@spaces.GPU(enable_queue=True)
def inference(input_image, prompt, a_prompt, n_prompt, denoise_steps, upscale, alpha, cfg, seed):
    """Run PASD super-resolution on one uploaded image.

    Args:
        input_image: Path of the uploaded image file.
        prompt: User prompt; the classifier's predicted category is appended
            to it when its score is at least 0.1.
        a_prompt: "Added prompt" appended after the user prompt.
        n_prompt: Negative prompt forwarded to the pipeline.
        denoise_steps: Number of inference steps.
        upscale: Integer factor by which the (512px-high) input is enlarged.
        alpha: Conditioning scale forwarded to the pipeline.
        cfg: Guidance scale forwarded to the pipeline.
        seed: Random seed; ``-1`` is mapped to ``0``.

    Returns:
        ``((input_path, result_path), result_path)`` where both paths are JPEG
        files written to the working directory: the pair feeds the
        before/after slider and the single path feeds the file download.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # UI sliders deliver plain numbers; PIL sizes, step counts and seeds must
    # be real ints, so coerce defensively before using them.
    seed = int(seed)
    denoise_steps = int(denoise_steps)
    rscale = int(upscale)

    # NOTE(review): -1 is pinned to seed 0 rather than a random seed; kept
    # as-is because callers may rely on it.
    if seed == -1:
        seed = 0

    input_image = resize_image(input_image, 512)
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")

    with torch.no_grad():
        seed_everything(seed)
        # A freshly constructed Generator starts from its own default seed and
        # ignores the global seed set above, so seed it explicitly; otherwise
        # the requested seed never reaches generator-driven sampling.
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

        input_image = input_image.convert('RGB')

        # Classify the image and, when confident enough, add the predicted
        # category name to the prompt.
        batch = preprocess(input_image).unsqueeze(0)
        prediction = resnet(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = weights.meta["categories"][class_id]
        if score >= 0.1:
            prompt += f"{category_name}" if prompt == '' else f", {category_name}"

        prompt = a_prompt if prompt == '' else f"{prompt}, {a_prompt}"

        ori_width, ori_height = input_image.size

        # Enlarge by the requested factor, then snap both sides down to a
        # multiple of 8 (never below 8, so the resize cannot hit a 0 size).
        input_image = input_image.resize((ori_width * rscale, ori_height * rscale))
        snapped_size = tuple(max(8, side // 8 * 8) for side in input_image.size)
        input_image = input_image.resize(snapped_size)
        width, height = input_image.size

        try:
            image = validation_pipeline(
                None, prompt, input_image, num_inference_steps=denoise_steps, generator=generator,
                height=height, width=width, guidance_scale=cfg,
                negative_prompt=n_prompt, conditioning_scale=alpha, eta=0.0,
            ).images[0]

            image = wavelet_color_fix(image, input_image)

            # Undo the multiple-of-8 snapping so the result has the exact
            # upscaled dimensions.
            image = image.resize((ori_width * rscale, ori_height * rscale))
        except Exception as e:
            # Best effort: report the failure and fall back to a blank image
            # so the UI still receives a valid result.
            print(e)
            image = Image.new(mode="RGB", size=(512, 512))

    result_path = f'result_{timestamp}.jpg'
    input_path = f'input_{timestamp}.jpg'
    image.save(result_path, 'JPEG')
    input_image.save(input_path, 'JPEG')

    return (input_path, result_path), result_path
| |
|
# Page metadata strings. They are not referenced by the layout code visible in
# this file, but are kept as module-level names.
title = "Pixel-Aware Stable Diffusion for Real-ISR"
description = (
    "Gradio Demo for PASD Real-ISR. "
    "To use it, simply upload your image, or click one of the examples to load them."
)
article = (
    "<a href='https://github.com/yangxy/PASD' target='_blank'>"
    "Github Repo Pytorch</a>"
)

# Stylesheet handed to gr.Blocks: centers the main column (max 720px wide) and
# lays the project badge links out in a single centered row.
css = """
#col-container{
margin: 0 auto;
max-width: 720px;
}
#project-links{
margin: 0 0 12px !important;
column-gap: 8px;
display: flex;
justify-content: center;
flex-wrap: nowrap;
flex-direction: row;
align-items: center;
}
"""
| |
|
# Build the demo page: a header, the input controls on the left, and the
# before/after slider plus downloadable result on the right.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Static header (plain string: it contains no placeholders).
        gr.HTML("""
<h2 style="text-align: center;">
PASD Magnify
</h2>
<p style="text-align: center;">
Pixel-Aware Stable Diffusion for Realistic Image Super-resolution and Personalized Stylization
</p>
<p id="project-links" align="center">
<a href='https://github.com/yangxy/PASD'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://huggingface.co/papers/2308.14469'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>
</p>
<p style="margin:12px auto;display: flex;justify-content: center;">
<a href="https://huggingface.co/spaces/fffiloni/PASD?duplicate=true"><img src="https://huggingface.co/datasets/huggingface/badges/resolve/main/duplicate-this-space-lg.svg" alt="Duplicate this Space"></a>
</p>

""")

        with gr.Row():
            with gr.Column():
                # The image is handed to `inference` as a file path.
                input_image = gr.Image(type="filepath", sources=["upload"], value="PASD/samples/frog.png")
                prompt_in = gr.Textbox(label="Prompt", value="Frog")
                with gr.Accordion(label="Advanced settings", open=False):
                    added_prompt = gr.Textbox(
                        label="Added Prompt",
                        value='clean, high-resolution, 8k, best quality, masterpiece',
                    )
                    neg_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value='dotted, noise, blur, lowres, oversmooth, longbody, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality',
                    )
                    denoise_steps = gr.Slider(label="Denoise Steps", minimum=10, maximum=50, value=20, step=1)
                    upsample_scale = gr.Slider(label="Upsample Scale", minimum=1, maximum=4, value=2, step=1)
                    condition_scale = gr.Slider(label="Conditioning Scale", minimum=0.5, maximum=1.5, value=1.1, step=0.1)
                    classifier_free_guidance = gr.Slider(label="Classifier-free Guidance", minimum=0.1, maximum=10.0, value=7.5, step=0.1)
                    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
                submit_btn = gr.Button("Submit")
            with gr.Column():
                b_a_slider = ImageSlider(label="B/A result", position=0.5)
                file_output = gr.File(label="Downloadable image result")

        # Input order must match the positional parameters of `inference`;
        # its two return values map to the slider and the file widget.
        submit_btn.click(
            fn=inference,
            inputs=[
                input_image, prompt_in,
                added_prompt, neg_prompt,
                denoise_steps,
                upsample_scale, condition_scale,
                classifier_free_guidance, seed,
            ],
            outputs=[
                b_a_slider,
                file_output,
            ],
        )

# Cap the request queue at 10 pending jobs and start the app without the API page.
demo.queue(max_size=10).launch(show_api=False)
| |
|