| | from base64 import b64encode |
| | import numpy |
| | import torch |
| | from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel |
| | from huggingface_hub import notebook_login |
| | import gradio as gr |
| |
|
| | |
| | from matplotlib import pyplot as plt |
| | from pathlib import Path |
| | from PIL import Image |
| | from torch import autocast |
| | from torchvision import transforms as tfms |
| | from tqdm.auto import tqdm |
| | from transformers import CLIPTextModel, CLIPTokenizer, logging |
| | import os |
| | import numpy as np |
| |
|
| |
|
| |
|
| | |
# Silence transformers' verbose warnings (e.g. unused-weight notices).
logging.set_verbosity_error()

# Pick the best available device: CUDA, then Apple MPS, then CPU.
torch_device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# VAE: encodes images to latents and decodes latents back to images.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# CLIP tokenizer + text encoder used for prompt conditioning.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# UNet: predicts the latent noise at each denoising step.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# K-LMS scheduler configured with SD v1's training beta schedule.
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# Move all models onto the chosen device.
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)

# Grab CLIP's token and position embedding layers so prompts can be
# embedded manually (and individual token embeddings swapped out).
token_emb_layer = text_encoder.text_model.embeddings.token_embedding
pos_emb_layer = text_encoder.text_model.embeddings.position_embedding

# Position embeddings for CLIP's fixed 77-token context.
# NOTE(review): `embeddings.position_ids` is a private buffer; newer
# transformers releases may not expose it — confirm against the pinned version.
position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
position_embeddings = pos_emb_layer(position_ids)
| |
|
| |
|
def get_output_embeds(input_embeddings):
    """Run pre-computed token+position embeddings through the CLIP text
    encoder and return the final (layer-normed) hidden states.

    Parameters:
        input_embeddings (torch.Tensor): (batch, seq_len, hidden) token
            embeddings already combined with position embeddings.

    Returns:
        torch.Tensor: encoder output after the final layer norm,
        shape (batch, seq_len, hidden).
    """
    bsz, seq_len = input_embeddings.shape[:2]

    # CLIP attends causally over the text tokens. The original code used
    # `_build_causal_attention_mask`, a private transformers helper that was
    # removed in newer releases — fall back to building the same mask by hand
    # (upper triangle filled with the dtype's most negative value,
    # shaped (bsz, 1, seq_len, seq_len)).
    try:
        causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(
            bsz, seq_len, dtype=input_embeddings.dtype
        )
    except AttributeError:
        causal_attention_mask = torch.full(
            (bsz, seq_len, seq_len),
            torch.finfo(input_embeddings.dtype).min,
            dtype=input_embeddings.dtype,
        ).triu_(1).unsqueeze(1)

    # Feed the embeddings straight into the transformer encoder, bypassing
    # the embedding layer (we supply our own, possibly modified, embeddings).
    encoder_outputs = text_encoder.text_model.encoder(
        inputs_embeds=input_embeddings,
        attention_mask=None,  # no padding mask — inputs are a fixed 77 tokens
        causal_attention_mask=causal_attention_mask.to(torch_device),
        output_attentions=None,
        output_hidden_states=True,
        return_dict=None,
    )

    # Last hidden state, then the model's final layer norm.
    output = encoder_outputs[0]
    output = text_encoder.text_model.final_layer_norm(output)
    return output
| |
|
| |
|
def set_timesteps(scheduler, num_inference_steps):
    """Configure the scheduler's step count and force float32 timesteps
    (float64 timesteps are not supported on the MPS backend)."""
    scheduler.set_timesteps(num_inference_steps)
    timesteps_f32 = scheduler.timesteps.to(torch.float32)
    scheduler.timesteps = timesteps_f32
| |
|
def pil_to_latent(input_im):
    """Encode a single PIL image into scaled VAE latents (batch of 1)."""
    # Map [0, 1] pixels to [-1, 1], the range the VAE expects.
    image_batch = tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1
    with torch.no_grad():
        posterior = vae.encode(image_batch)
    # 0.18215 is SD v1's latent scaling factor.
    return 0.18215 * posterior.latent_dist.sample()
| |
|
def latents_to_pil(latents):
    """Decode a batch of scaled latents into a list of PIL images."""
    # Undo SD v1's latent scaling before decoding.
    scaled = latents / 0.18215
    with torch.no_grad():
        decoded = vae.decode(scaled).sample
    # [-1, 1] -> [0, 1], then to HWC uint8 numpy arrays.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    arrays = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    arrays = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(arr) for arr in arrays]
| |
|
| |
|
def generate_with_embs(text_embeddings, text_input, seed, num_inference_steps, guidance_scale):
    """Run the diffusion sampling loop from pre-computed text embeddings.

    Parameters:
        text_embeddings (torch.Tensor): conditional CLIP embeddings for the prompt.
        text_input: tokenizer output for the prompt (its length sizes the uncond pass).
        seed (int): RNG seed for the initial latents.
        num_inference_steps (int): number of scheduler steps.
        guidance_scale (float): classifier-free guidance strength.

    Returns:
        PIL.Image.Image: the decoded image.
    """
    height = 512  # SD v1 native resolution
    width = 512
    generator = torch.manual_seed(seed)
    batch_size = 1

    # Unconditional ("") embeddings for classifier-free guidance,
    # padded to the same length as the prompt.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep the scheduler (float32 timesteps for MPS compatibility).
    set_timesteps(scheduler, num_inference_steps)

    # Random starting latents, scaled by the scheduler's initial sigma.
    # Use `unet.config.in_channels` — plain `unet.in_channels` is deprecated
    # (the other generation function in this file already uses the config form).
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop.
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Duplicate latents for the uncond/cond batch and scale for this step.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Classifier-free guidance. Bug fix: honor the caller's
        # `guidance_scale` — it was previously hard-coded to 8, silently
        # ignoring the parameter.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]
| |
|
| |
|
def generate_with_prompt_style(prompt, style, seed, num_inference_steps=30, guidance_scale=8):
    """Generate an image for `prompt` rendered in a learned textual-inversion style.

    Parameters:
        prompt (str): base prompt; ' in style of s' is appended and the 's'
            token is replaced by the learned style embedding.
        style (str): path to a learned-embedding .bin file (torch.load-able).
        seed (int): RNG seed.
        num_inference_steps (int): scheduler steps. New, defaulted parameter —
            the original call to generate_with_embs omitted these two required
            arguments and raised TypeError.
        guidance_scale (float): classifier-free guidance strength.

    Returns:
        PIL.Image.Image: the generated image.
    """
    prompt = prompt + ' in style of s'
    embed = torch.load(style)

    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    input_ids = text_input.input_ids.to(torch_device)

    # Raw token embeddings for the prompt.
    token_embeddings = token_emb_layer(input_ids)

    # The learned style embedding is the single value in the loaded dict.
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)

    # Swap in the style embedding wherever token id 338 appears — presumably
    # the id of the trailing 's' appended above; TODO confirm against the
    # tokenizer vocabulary.
    token_embeddings[0, torch.where(input_ids[0] == 338)] = replacement_token_embedding.to(torch_device)

    # Add position embeddings and re-run the encoder on the modified embeddings.
    input_embeddings = token_embeddings + position_embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)

    return generate_with_embs(modified_output_embeddings, text_input, seed, num_inference_steps, guidance_scale)
| |
|
def contrast_loss(images):
    """Negative variance of all pixel values — minimizing it raises contrast."""
    return -torch.var(images)
| |
|
| |
|
def blue_loss(images):
    """Negative variance of the blue channel of a batch of RGB images.

    Parameters:
        images (torch.Tensor): batch of shape (N, 3, H, W).

    Returns:
        torch.Tensor: scalar, -Var(blue channel).

    Raises:
        ValueError: if the channel dimension is not 3.
    """
    if images.shape[1] != 3:
        raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))

    # Channel 2 is blue in RGB layout.
    blue = images[:, 2, :, :]
    return -torch.var(blue)
| |
|
| |
|
def ymca_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
    """Compute the YMCA loss for a batch of images.

    Combines the mean of the Y (luminance) channel, the mean of the M
    (magenta) channel, the variance of the C (cyan) channel, and — for RGBA
    input — the absolute sum of the A (alpha) channel.

    Parameters:
        images (torch.Tensor): batch of shape (N, C, H, W) with C in {3, 4}.
        weights (tuple): four floats weighting each component
            (default (1.0, 1.0, 1.0, 1.0)).

    Returns:
        torch.Tensor: scalar loss
            w0*mean(Y) + w1*mean(M) - w2*var(C) [+ w3*sum(|A|)].

    Raises:
        ValueError: if the channel dimension is not 3 or 4.
    """
    num_channels = images.shape[1]
    if num_channels not in [3, 4]:
        raise ValueError("Expected images with 3 (RGB) or 4 (RGBA) channels, but got shape {}".format(images.shape))

    red, green, blue = images[:, 0, :, :], images[:, 1, :, :], images[:, 2, :, :]

    # Luminance with BT.601 coefficients; magenta/cyan as simple inversions.
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    magenta = 1 - green
    cyan = 1 - red

    loss = (
        weights[0] * torch.mean(luma)
        + weights[1] * torch.mean(magenta)
        - weights[2] * torch.var(cyan)
    )

    # RGBA input additionally contributes the weighted |alpha| sum.
    if num_channels == 4:
        alpha = images[:, 3, :, :]
        loss = loss + weights[3] * torch.sum(torch.abs(alpha))

    return loss
| |
|
| |
|
| |
|
def rgb_to_cmyk(images):
    """Convert a batch of RGB images to CMYK.

    Parameters:
        images (torch.Tensor): batch of shape (N, 3, H, W), values
            expected in [0, 1].

    Returns:
        torch.Tensor: CMYK channels stacked as (N, 4, H, W).
    """
    red = images[:, 0, :, :]
    green = images[:, 1, :, :]
    blue = images[:, 2, :, :]

    # Naive CMY inversion.
    cyan = 1 - red
    magenta = 1 - green
    yellow = 1 - blue

    # Pull out the black (K) component and renormalize the remaining
    # channels; the epsilon guards against division by zero at pure black.
    black = torch.min(torch.min(cyan, magenta), yellow)
    denom = 1 - black + 1e-8
    cyan = (cyan - black) / denom
    magenta = (magenta - black) / denom
    yellow = (yellow - black) / denom

    return torch.stack([cyan, magenta, yellow, black], dim=1)
| |
|
def cymk_loss(images, weights=(1.0, 1.0, 1.0, 1.0)):
    """Compute the CYMK loss for a batch of RGB images.

    Converts the batch to CMYK and combines the variance of the Cyan
    channel, the mean of the Yellow channel, the variance of the Magenta
    channel, and the absolute sum of the Black channel.

    Parameters:
        images (torch.Tensor): batch of shape (N, 3, H, W).
        weights (tuple): four floats weighting each component
            (default (1.0, 1.0, 1.0, 1.0)).

    Returns:
        torch.Tensor: scalar loss
            w0*var(C) + w1*mean(Y) + w2*var(M) + w3*sum(|K|).

    Raises:
        ValueError: if the channel dimension is not 3.
    """
    if images.shape[1] != 3:
        raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))

    cmyk = rgb_to_cmyk(images)
    cyan = cmyk[:, 0, :, :]
    magenta = cmyk[:, 1, :, :]
    yellow = cmyk[:, 2, :, :]
    black = cmyk[:, 3, :, :]

    return (
        weights[0] * torch.var(cyan)
        + weights[1] * torch.mean(yellow)
        + weights[2] * torch.var(magenta)
        + weights[3] * torch.sum(torch.abs(black))
    )
| |
|
| |
|
def blue_loss_variant(images, use_mean=False, alpha=1.0):
    """Blue loss with an optional mean term.

    The base loss is the negative variance of the blue channel; when
    `use_mean` is True, `alpha * mean(blue)` is added on top.

    Parameters:
        images (torch.Tensor): batch of shape (N, 3, H, W).
        use_mean (bool): include the blue-channel mean in the loss.
        alpha (float): weight of the mean term when use_mean is True.

    Returns:
        torch.Tensor: scalar loss.

    Raises:
        ValueError: if the channel dimension is not 3.
    """
    if images.shape[1] != 3:
        raise ValueError("Expected images with 3 channels (RGB), but got shape {}".format(images.shape))

    blue = images[:, 2, :, :]
    loss = -torch.var(blue)
    if use_mean:
        loss = loss + alpha * torch.mean(blue)
    return loss
| |
|
def generate_with_prompt_style_guidance(prompt, style, seed,num_inference_steps,loss_function):
    """Generate an image from `prompt` in a learned textual-inversion style,
    steering every 5th denoising step by the gradient of an image-space loss.

    Parameters:
        prompt (str): base prompt; ' in style of s' is appended.
        style (str): path to a learned-embedding .bin file (torch.load-able).
        seed (int): RNG seed for the initial latents.
        num_inference_steps (int): number of scheduler steps.
        loss_function (str): one of 'contrast', 'blue_original',
            'blue_modified', 'ymca', 'cmyk'; any other value falls back to ymca.

    Returns:
        PIL.Image.Image: the decoded image.
    """
    prompt = prompt + ' in style of s'
    embed = torch.load(style)

    height = 512
    width = 512
    num_inference_steps = num_inference_steps  # no-op self-assignment, left as-is
    guidance_scale = 8  # NOTE(review): CFG strength is fixed here, not a parameter
    generator = torch.manual_seed(seed)
    batch_size = 1

    # Tokenize the styled prompt and get its (unmodified) encoder output.
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]

    input_ids = text_input.input_ids.to(torch_device)

    # Raw token embeddings for the prompt.
    token_embeddings = token_emb_layer(input_ids)

    # The learned style embedding is the single value in the loaded dict.
    replacement_token_embedding = embed[list(embed.keys())[0]].to(torch_device)

    # Swap in the style embedding wherever token id 338 appears — presumably
    # the id of the trailing 's' appended above; TODO confirm against the
    # tokenizer vocabulary.
    token_embeddings[0, torch.where(input_ids[0]==338)] = replacement_token_embedding.to(torch_device)

    # Add position embeddings and re-run the text encoder on the modified
    # embeddings to get style-conditioned output embeddings.
    input_embeddings = token_embeddings + position_embeddings
    modified_output_embeddings = get_output_embeds(input_embeddings)

    # Unconditional ("") embeddings for classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

    text_embeddings = torch.cat([uncond_embeddings, modified_output_embeddings])

    scheduler.set_timesteps(num_inference_steps)

    # Random starting latents, scaled by the scheduler's initial sigma.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop with periodic loss guidance.
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):

        # Duplicate latents for the uncond/cond batch and scale for this step.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Every 5th step: nudge the latents down the gradient of an
        # image-space loss computed on the predicted clean image.
        if i%5 == 0:
            # Make latents a grad-tracking leaf; noise_pred stays constant
            # (it was computed under no_grad above).
            latents = latents.detach().requires_grad_()

            # One-step estimate of the fully denoised latents x0.
            latents_x0 = latents - sigma * noise_pred

            # Decode to pixel space, rescaled to roughly [0, 1], so the
            # loss functions see image-like values.
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            # Pick the guidance loss; 'cmyk' uses a much smaller scale
            # because its K term is an unnormalized absolute sum.
            if loss_function == "contrast":
                loss_scale = 200
                loss = contrast_loss(denoised_images) * loss_scale
            elif loss_function == "blue_original":
                loss_scale = 200
                loss = blue_loss(denoised_images) * loss_scale
            elif loss_function == "blue_modified":
                loss_scale = 200
                loss = blue_loss_variant(denoised_images) * loss_scale
            elif loss_function == "ymca":
                loss_scale = 200
                loss = ymca_loss(denoised_images) * loss_scale
            elif loss_function == "cmyk":
                loss_scale = 1
                loss = cymk_loss(denoised_images) * loss_scale
            else :
                # Unknown name: default to the ymca loss.
                loss_scale = 200
                loss = ymca_loss(denoised_images) * loss_scale

            # Gradient of the loss w.r.t. the latents (through the VAE decode).
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Gradient-descent step on the latents, scaled by sigma^2.
            latents = latents.detach() - cond_grad * sigma**2

        # Standard scheduler update to the next (less noisy) latents.
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]
| |
|
| |
|
| |
|
| |
|
# Display name (shown in the UI dropdown) -> path of the learned
# textual-inversion embedding (.bin) that torch.load reads.
dict_styles = {
    'Dr Strange': 'styles/learned_embeds_dr_strange.bin',
    'GTA-5':'styles/learned_embeds_gta5.bin',
    'Manga':'styles/learned_embeds_manga.bin',
    'Pokemon':'styles/learned_embeds_pokemon.bin',
    'Illustration': 'styles/learned_embeds_illustration.bin',
    'Matrix':'styles/learned_embeds_matrix.bin',
    'Oil Painting':'styles/learned_embeds_oil.bin',
}
| |
|
def inference(prompt, seed, style, num_inference_steps, loss_function):
    """Gradio entry point: generate a styled image, or None if inputs are missing.

    Parameters:
        prompt (str | None): text prompt.
        seed (str | int | None): RNG seed; arrives as a string from the
            gr.Textbox input.
        style (str | None): display name, looked up in dict_styles.
        num_inference_steps (int): scheduler steps from the slider.
        loss_function (str): guidance loss name from the radio buttons.

    Returns:
        numpy.ndarray | None: the generated image as an array, or None when
        any of prompt/style/seed is missing.
    """
    print(prompt, seed, style, num_inference_steps, loss_function)

    if prompt is not None and style is not None and seed is not None:
        print(loss_function)
        style = dict_styles[style]
        # The seed arrives from a gr.Textbox as a string; torch.manual_seed
        # requires an int, so coerce it first (previously raised TypeError).
        seed = int(seed)
        torch.manual_seed(seed)
        result = generate_with_prompt_style_guidance(prompt, style, seed, num_inference_steps, loss_function)
        return np.array(result)
    else:
        return None
| |
|
# UI metadata.
title = "Stable Diffusion with text input"
description = "Apply various stable diffusion styles with text prompt as input"
# Example rows: [prompt, seed, style, steps, loss function] — must match the
# input order of gr.Interface below.
examples = [["a hyper-realistic high definition render of bear sitting on a red rug, 4k", 24041975,"Manga", 10,'ymca'], ["A man dancing on bhutan costume",24041975, "GTA-5",10, 'contrast']]

# Wire the inference function to the Gradio UI: prompt textbox, seed textbox
# (note: delivers a string), style dropdown, step slider, and loss radio.
demo = gr.Interface(inference,
                    inputs = [gr.Textbox(label='Prompt', value='A man dancing on bhutan costume'), gr.Textbox(label='Seed', value=24041975),
                              gr.Dropdown(['Dr Strange', 'GTA-5', 'Manga', 'Pokemon','Illustration','Matrix','Oil Painting'], label='Style', value='Dr Strange'),
                              gr.Slider(
                                  minimum=5,
                                  maximum=20,
                                  value=10,
                                  step=5,
                                  label="Number of Steps",
                                  interactive=True,
                              ),
                              gr.Radio(["contrast", "blue_original", "blue_modified","ymca","cmyk"], label="loss-function", info="loss-function" , value="ymca"),
                              ],
                    outputs = [
                        gr.Image(label="Stable Diffusion Output"),
                    ],
                    title = title,
                    description = description,
                    examples = examples,
                    # cache_examples runs every example at startup and stores the results.
                    cache_examples=True
                    )
demo.launch()
| | |
| |
|