import os
import re
import requests
import sys
import copy
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig, get_peft_model

p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_lora_fwd
from basicsr.archs.arch_util import default_init_weights
from my_utils.vaehook import VAEHook, perfcount

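# Map a UNet module name to a global block index: down_blocks -> 0-3, mid_block -> 4,
# up_blocks -> 5-8, conv_out -> 9. Used to build unet_layer_dict below.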
def get_layer_number(module_name):
    base_layers = {
        'down_blocks': 0,
        'mid_block': 4,
        'up_blocks': 5
    }

    if module_name == 'conv_out':
        return 9

    base_layer = None
    for key in base_layers:
        if key in module_name:
            base_layer = base_layers[key]
            break

    if base_layer is None:
        return None

    additional_layers = int(re.findall(r'\.(\d+)', module_name)[0])
    final_layer = base_layer + additional_layers
    return final_layer

class S3Diff(torch.nn.Module):
    def __init__(self, sd_path=None, pretrained_path=None, lora_rank_unet=32, lora_rank_vae=16, block_embedding_dim=64, args=None):
        super().__init__()
        self.args = args
        self.latent_tiled_size = args.latent_tiled_size
        self.latent_tiled_overlap = args.latent_tiled_overlap

        self.tokenizer = AutoTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder").cuda()
        self.sched = make_1step_sched(sd_path)
        self.guidance_scale = 1.07

        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")

        target_modules_vae = r"^encoder\..*(conv1|conv2|conv_in|conv_shortcut|conv|conv_out|to_k|to_q|to_v|to_out\.0)$"
        target_modules_unet = [
            "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
            "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
        ]

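        # Frozen random Fourier frequencies used to embed the degradation scores in forward().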
        num_embeddings = 64
        self.W = nn.Parameter(torch.randn(num_embeddings), requires_grad=False)

        self.vae_de_mlp = nn.Sequential(
            nn.Linear(num_embeddings * 4, 256),
            nn.ReLU(True),
        )

        self.unet_de_mlp = nn.Sequential(
            nn.Linear(num_embeddings * 4, 256),
            nn.ReLU(True),
        )

        self.vae_block_mlp = nn.Sequential(
            nn.Linear(block_embedding_dim, 64),
            nn.ReLU(True),
        )

        self.unet_block_mlp = nn.Sequential(
            nn.Linear(block_embedding_dim, 64),
            nn.ReLU(True),
        )

        self.vae_fuse_mlp = nn.Linear(256 + 64, lora_rank_vae ** 2)
        self.unet_fuse_mlp = nn.Linear(256 + 64, lora_rank_unet ** 2)

        default_init_weights([self.vae_de_mlp, self.unet_de_mlp, self.vae_block_mlp, self.unet_block_mlp,
                              self.vae_fuse_mlp, self.unet_fuse_mlp], 1e-5)

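        # One learnable embedding per VAE block (6) and per UNet block (10).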
        self.vae_block_embeddings = nn.Embedding(6, block_embedding_dim)
        self.unet_block_embeddings = nn.Embedding(10, block_embedding_dim)

        if pretrained_path is not None:
            sd = torch.load(pretrained_path, map_location="cpu")
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            _sd_vae = vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            vae.load_state_dict(_sd_vae)

            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
            unet.add_adapter(unet_lora_config)
            _sd_unet = unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            unet.load_state_dict(_sd_unet)

            for module, key in [
                (self.vae_de_mlp, "state_dict_vae_de_mlp"),
                (self.unet_de_mlp, "state_dict_unet_de_mlp"),
                (self.vae_block_mlp, "state_dict_vae_block_mlp"),
                (self.unet_block_mlp, "state_dict_unet_block_mlp"),
                (self.vae_fuse_mlp, "state_dict_vae_fuse_mlp"),
                (self.unet_fuse_mlp, "state_dict_unet_fuse_mlp"),
            ]:
                _msd = module.state_dict()
                for k in sd[key]:
                    _msd[k] = sd[key][k]
                module.load_state_dict(_msd)

            self.W = nn.Parameter(sd["w"], requires_grad=False)

            embeddings_state_dict = sd["state_embeddings"]
            self.vae_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_vae_block'])
            self.unet_block_embeddings.load_state_dict(embeddings_state_dict['state_dict_unet_block'])
        else:
            print("Initializing model with random weights")
            vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
                                         target_modules=target_modules_vae)
            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
                                          target_modules=target_modules_unet)
            unet.add_adapter(unet_lora_config)

        self.lora_rank_unet = lora_rank_unet
        self.lora_rank_vae = lora_rank_vae
        self.target_modules_vae = target_modules_vae
        self.target_modules_unet = target_modules_unet

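        # LoRA-wrapped layers are identified by their 'base_layer' children; their forward is
        # replaced with my_lora_fwd so the per-sample de_mod matrices assigned in forward()
        # can modulate each layer's LoRA update.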
        self.vae_lora_layers = []
        for name, module in vae.named_modules():
            if 'base_layer' in name:
                self.vae_lora_layers.append(name[:-len(".base_layer")])

        for name, module in vae.named_modules():
            if name in self.vae_lora_layers:
                module.forward = my_lora_fwd.__get__(module, module.__class__)

        self.unet_lora_layers = []
        for name, module in unet.named_modules():
            if 'base_layer' in name:
                self.unet_lora_layers.append(name[:-len(".base_layer")])

        for name, module in unet.named_modules():
            if name in self.unet_lora_layers:
                module.forward = my_lora_fwd.__get__(module, module.__class__)

        self.unet_layer_dict = {name: get_layer_number(name) for name in self.unet_lora_layers}

        unet.to("cuda")
        vae.to("cuda")
        self.unet, self.vae = unet, vae
        self.timesteps = torch.tensor([999], device="cuda").long()
        self.text_encoder.requires_grad_(False)

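        # Wrap the VAE encoder/decoder forwards with tiled hooks to bound memory on large inputs.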
        self._init_tiled_vae(encoder_tile_size=args.vae_encoder_tiled_size, decoder_tile_size=args.vae_decoder_tiled_size)

    def set_eval(self):
        self.unet.eval()
        self.vae.eval()
        self.vae_de_mlp.eval()
        self.unet_de_mlp.eval()
        self.vae_block_mlp.eval()
        self.unet_block_mlp.eval()
        self.vae_fuse_mlp.eval()
        self.unet_fuse_mlp.eval()

        self.vae_block_embeddings.requires_grad_(False)
        self.unet_block_embeddings.requires_grad_(False)

        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)

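    # set_train re-enables gradients only for the LoRA parameters, unet.conv_in and the block
    # embeddings; the modulation MLPs are trainable by default and the SD backbone stays frozen.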
    def set_train(self):
        self.unet.train()
        self.vae.train()
        self.vae_de_mlp.train()
        self.unet_de_mlp.train()
        self.vae_block_mlp.train()
        self.unet_block_mlp.train()
        self.vae_fuse_mlp.train()
        self.unet_fuse_mlp.train()

        self.vae_block_embeddings.requires_grad_(True)
        self.unet_block_embeddings.requires_grad_(True)

        for n, _p in self.unet.named_parameters():
            if "lora" in n:
                _p.requires_grad = True
        self.unet.conv_in.requires_grad_(True)

        for n, _p in self.vae.named_parameters():
            if "lora" in n:
                _p.requires_grad = True

    @perfcount
    @torch.no_grad()
    def forward(self, c_t, deg_score, pos_prompt, neg_prompt):
        # Encode the positive prompt; fall back to empty prompts when none is given.
        if pos_prompt is None:
            pos_prompt = [""] * c_t.shape[0]
        pos_caption_tokens = self.tokenizer(pos_prompt, max_length=self.tokenizer.model_max_length,
                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
        pos_caption_enc = self.text_encoder(pos_caption_tokens)[0]

        # Encode the negative prompt the same way; an empty prompt gives the unconditional branch.
        if neg_prompt is None:
            neg_prompt = [""] * c_t.shape[0]
        neg_caption_tokens = self.tokenizer(neg_prompt, max_length=self.tokenizer.model_max_length,
                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
        neg_caption_enc = self.text_encoder(neg_caption_tokens)[0]

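        # Fourier-embed the degradation scores: deg_score is expected to be [B, 2]; each score is
        # projected onto the frozen frequencies W, expanded with sin/cos, and the two scores'
        # embeddings are concatenated into a vector of length 4 * num_embeddings.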
        deg_proj = deg_score[..., None] * self.W[None, None, :] * 2 * np.pi
        deg_proj = torch.cat([torch.sin(deg_proj), torch.cos(deg_proj)], dim=-1)
        deg_proj = torch.cat([deg_proj[:, 0], deg_proj[:, 1]], dim=-1)

        vae_de_c_embed = self.vae_de_mlp(deg_proj)
        unet_de_c_embed = self.unet_de_mlp(deg_proj)

        vae_block_c_embeds = self.vae_block_mlp(self.vae_block_embeddings.weight)
        unet_block_c_embeds = self.unet_block_mlp(self.unet_block_embeddings.weight)

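        # Fuse the degradation embedding with every block embedding, producing one
        # rank x rank LoRA modulation matrix per block and per sample.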
        vae_embeds = self.vae_fuse_mlp(torch.cat([vae_de_c_embed.unsqueeze(1).repeat(1, vae_block_c_embeds.shape[0], 1),
                                                  vae_block_c_embeds.unsqueeze(0).repeat(vae_de_c_embed.shape[0], 1, 1)], -1))
        unet_embeds = self.unet_fuse_mlp(torch.cat([unet_de_c_embed.unsqueeze(1).repeat(1, unet_block_c_embeds.shape[0], 1),
                                                    unet_block_c_embeds.unsqueeze(0).repeat(unet_de_c_embed.shape[0], 1, 1)], -1))

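        # Hand each LoRA layer its block's modulation matrix via module.de_mod.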
        for layer_name, module in self.vae.named_modules():
            if layer_name in self.vae_lora_layers:
                split_name = layer_name.split(".")
                if split_name[1] == 'down_blocks':
                    block_id = int(split_name[2])
                    vae_embed = vae_embeds[:, block_id]
                elif split_name[1] == 'mid_block':
                    vae_embed = vae_embeds[:, -2]
                else:
                    vae_embed = vae_embeds[:, -1]
                module.de_mod = vae_embed.reshape(-1, self.lora_rank_vae, self.lora_rank_vae)

        for layer_name, module in self.unet.named_modules():
            if layer_name in self.unet_lora_layers:
                split_name = layer_name.split(".")
                if split_name[0] == 'down_blocks':
                    block_id = int(split_name[1])
                    unet_embed = unet_embeds[:, block_id]
                elif split_name[0] == 'mid_block':
                    unet_embed = unet_embeds[:, 4]
                elif split_name[0] == 'up_blocks':
                    block_id = int(split_name[1]) + 5
                    unet_embed = unet_embeds[:, block_id]
                else:
                    unet_embed = unet_embeds[:, -1]
                module.de_mod = unet_embed.reshape(-1, self.lora_rank_unet, self.lora_rank_unet)

        lq_latent = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor

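        # Tiled latent inference: small latents run through the UNet in one pass; larger latents are
        # split into overlapping tiles, each denoised with classifier-free guidance, then blended
        # back together with Gaussian weights.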
        _, _, h, w = lq_latent.size()
        tile_size, tile_overlap = (self.latent_tiled_size, self.latent_tiled_overlap)
        if h * w <= tile_size * tile_size:
            print("[Tiled Latent]: input latent is small enough; processing without tiling.")
            pos_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
            neg_model_pred = self.unet(lq_latent, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
            model_pred = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
        else:
            print(f"[Tiled Latent]: input size is {c_t.shape[-2]}x{c_t.shape[-1]}; processing in tiles.")
            tile_size = min(tile_size, min(h, w))
            tile_weights = self._gaussian_weights(tile_size, tile_size, 1).to(c_t.device)

            grid_rows = 0
            cur_x = 0
            while cur_x < lq_latent.size(-1):
                cur_x = max(grid_rows * tile_size - tile_overlap * grid_rows, 0) + tile_size
                grid_rows += 1

            grid_cols = 0
            cur_y = 0
            while cur_y < lq_latent.size(-2):
                cur_y = max(grid_cols * tile_size - tile_overlap * grid_cols, 0) + tile_size
                grid_cols += 1

            input_list = []
            noise_preds = []
            for row in range(grid_rows):
                noise_preds_row = []
                for col in range(grid_cols):
                    if col < grid_cols - 1 or row < grid_rows - 1:
                        ofs_x = max(row * tile_size - tile_overlap * row, 0)
                        ofs_y = max(col * tile_size - tile_overlap * col, 0)
                    if row == grid_rows - 1:
                        ofs_x = w - tile_size
                    if col == grid_cols - 1:
                        ofs_y = h - tile_size

                    input_start_x = ofs_x
                    input_end_x = ofs_x + tile_size
                    input_start_y = ofs_y
                    input_end_y = ofs_y + tile_size

                    input_tile = lq_latent[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
                    input_list.append(input_tile)

                    if len(input_list) == 1 or col == grid_cols - 1:
                        input_list_t = torch.cat(input_list, dim=0)
                        pos_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=pos_caption_enc).sample
                        neg_model_pred = self.unet(input_list_t, self.timesteps, encoder_hidden_states=neg_caption_enc).sample
                        model_out = neg_model_pred + self.guidance_scale * (pos_model_pred - neg_model_pred)
                        input_list = []
                        noise_preds.append(model_out)

            noise_pred = torch.zeros(lq_latent.shape, device=lq_latent.device)
            contributors = torch.zeros(lq_latent.shape, device=lq_latent.device)

            for row in range(grid_rows):
                for col in range(grid_cols):
                    if col < grid_cols - 1 or row < grid_rows - 1:
                        ofs_x = max(row * tile_size - tile_overlap * row, 0)
                        ofs_y = max(col * tile_size - tile_overlap * col, 0)
                    if row == grid_rows - 1:
                        ofs_x = w - tile_size
                    if col == grid_cols - 1:
                        ofs_y = h - tile_size

                    input_start_x = ofs_x
                    input_end_x = ofs_x + tile_size
                    input_start_y = ofs_y
                    input_end_y = ofs_y + tile_size

                    noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row * grid_cols + col] * tile_weights
                    contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights

            noise_pred /= contributors
            model_pred = noise_pred

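        # Single denoising step at t=999, then decode the estimated clean latent back to image space.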
        x_denoised = self.sched.step(model_pred, self.timesteps, lq_latent, return_dict=True).prev_sample
        output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)

        return output_image

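    # Persist only the adapted weights: LoRA tensors, unet.conv_in and VAE skip_conv weights,
    # the modulation MLPs, the block embeddings, and the Fourier frequencies W.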
    def save_model(self, outf):
        sd = {}
        sd["unet_lora_target_modules"] = self.target_modules_unet
        sd["vae_lora_target_modules"] = self.target_modules_vae
        sd["rank_unet"] = self.lora_rank_unet
        sd["rank_vae"] = self.lora_rank_vae
        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
        sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip_conv" in k}
        sd["state_dict_vae_de_mlp"] = {k: v for k, v in self.vae_de_mlp.state_dict().items()}
        sd["state_dict_unet_de_mlp"] = {k: v for k, v in self.unet_de_mlp.state_dict().items()}
        sd["state_dict_vae_block_mlp"] = {k: v for k, v in self.vae_block_mlp.state_dict().items()}
        sd["state_dict_unet_block_mlp"] = {k: v for k, v in self.unet_block_mlp.state_dict().items()}
        sd["state_dict_vae_fuse_mlp"] = {k: v for k, v in self.vae_fuse_mlp.state_dict().items()}
        sd["state_dict_unet_fuse_mlp"] = {k: v for k, v in self.unet_fuse_mlp.state_dict().items()}
        sd["w"] = self.W

        sd["state_embeddings"] = {
            "state_dict_vae_block": self.vae_block_embeddings.state_dict(),
            "state_dict_unet_block": self.unet_block_embeddings.state_dict(),
        }

        torch.save(sd, outf)

    def _set_latent_tile(self,
                         latent_tiled_size=96,
                         latent_tiled_overlap=32):
        self.latent_tiled_size = latent_tiled_size
        self.latent_tiled_overlap = latent_tiled_overlap

    def _init_tiled_vae(self,
                        encoder_tile_size=256,
                        decoder_tile_size=256,
                        fast_decoder=False,
                        fast_encoder=False,
                        color_fix=False,
                        vae_to_gpu=True):
        if not hasattr(self.vae.encoder, 'original_forward'):
            setattr(self.vae.encoder, 'original_forward', self.vae.encoder.forward)
        if not hasattr(self.vae.decoder, 'original_forward'):
            setattr(self.vae.decoder, 'original_forward', self.vae.decoder.forward)

        encoder = self.vae.encoder
        decoder = self.vae.decoder

        self.vae.encoder.forward = VAEHook(
            encoder, encoder_tile_size, is_decoder=False, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)
        self.vae.decoder.forward = VAEHook(
            decoder, decoder_tile_size, is_decoder=True, fast_decoder=fast_decoder, fast_encoder=fast_encoder, color_fix=color_fix, to_gpu=vae_to_gpu)

    def _gaussian_weights(self, tile_width, tile_height, nbatches):
        """Generates a gaussian mask of weights for tile contributions"""
        from numpy import pi, exp, sqrt

        latent_width = tile_width
        latent_height = tile_height

        var = 0.01
        midpoint = (latent_width - 1) / 2
        x_probs = [exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) for x in range(latent_width)]
        midpoint = latent_height / 2
        y_probs = [exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) for y in range(latent_height)]

        weights = np.outer(y_probs, x_probs)
        return torch.tile(torch.tensor(weights), (nbatches, self.unet.config.in_channels, 1, 1))