| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import gc |
| | from time import time |
| | import math |
| | from tqdm import tqdm |
| |
|
| | import torch |
| | import torch.version |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| | import os |
| | import sys |
| | sys.path.append(os.getcwd()) |
| | import my_utils.devices as devices |
| |
|
| | try: |
| | import xformers |
| | import xformers.ops |
| | except ImportError: |
| | pass |
| |
|
# Selects the module layout this file adapts to: the rest of the code branches
# on sd_flag to pick attribute names (norm_out vs conv_norm_out, nin_shortcut
# vs conv_shortcut, mid_block access). False = diffusers-style VAE modules;
# True presumably means the original ldm layout — TODO confirm with callers.
sd_flag = False
| |
|
def get_recommend_encoder_tile_size():
    """Pick an encoder tile size (pixels) suited to the available VRAM."""
    if not torch.cuda.is_available():
        return 512
    # Total device memory in MiB.
    total_mb = torch.cuda.get_device_properties(
        devices.device).total_memory // 2**20
    # (VRAM threshold in MB, tile size) pairs, largest first.
    for threshold, size in ((16 * 1000, 3072), (12 * 1000, 2048), (8 * 1000, 1536)):
        if total_mb > threshold:
            return size
    return 960
| |
|
| |
|
def get_recommend_decoder_tile_size():
    """Pick a decoder tile size (latent pixels) suited to the available VRAM."""
    if not torch.cuda.is_available():
        return 64
    # Total device memory in MiB.
    total_mb = torch.cuda.get_device_properties(
        devices.device).total_memory // 2**20
    # (VRAM threshold in MB, tile size) pairs, largest first.
    for threshold, size in ((30 * 1000, 256), (16 * 1000, 192),
                            (12 * 1000, 128), (8 * 1000, 96)):
        if total_mb > threshold:
            return size
    return 64
| |
|
| |
|
# Global configuration defaults.
# (Previously wrapped in `if 'global const':` — a truthy string literal used
# as a visual grouping trick; the branch always executed, so the guard is
# removed without behavior change.)
DEFAULT_ENABLED = False
DEFAULT_MOVE_TO_GPU = False
DEFAULT_FAST_ENCODER = True
DEFAULT_FAST_DECODER = True
DEFAULT_COLOR_FIX = 0
DEFAULT_ENCODER_TILE_SIZE = get_recommend_encoder_tile_size()
DEFAULT_DECODER_TILE_SIZE = get_recommend_decoder_tile_size()
| |
|
| |
|
| | |
def inplace_nonlinearity(x):
    """SiLU activation applied in place (mutates *x*) to reduce peak memory."""
    return F.silu(x, inplace=True)
| |
|
| | |
| |
|
| | |
def attn_forward_new(self, h_):
    """
    Self-attention forward for a diffusers-style Attention module applied to a
    4D (B, C, H, W) feature map; returns a tensor of the same shape.
    """
    batch, channels, height, width = h_.shape
    # Flatten the spatial grid into a token sequence: (B, H*W, C).
    tokens = h_.view(batch, channels, height * width).transpose(1, 2)

    seq_len = tokens.shape[1]
    mask = self.prepare_attention_mask(None, seq_len, batch)

    query = self.to_q(tokens)
    # No cross-attention context is supplied, so attend over the tokens
    # themselves (the norm_cross path can never trigger here).
    key = self.to_k(tokens)
    value = self.to_v(tokens)

    query = self.head_to_batch_dim(query)
    key = self.head_to_batch_dim(key)
    value = self.head_to_batch_dim(value)

    scores = self.get_attention_scores(query, key, mask)
    attended = self.batch_to_head_dim(torch.bmm(scores, value))

    # Output projection, then the second to_out module (presumably dropout in
    # diffusers — not verifiable from this file).
    attended = self.to_out[0](attended)
    attended = self.to_out[1](attended)

    # Restore the (B, C, H, W) layout.
    return attended.transpose(-1, -2).reshape(batch, channels, height, width)
| |
|
def attn_forward(self, h_):
    """
    Vanilla ldm-style single-head self-attention over a (B, C, H, W) feature
    map; returns a tensor of the same shape.
    """
    query = self.q(h_)
    key = self.k(h_)
    value = self.v(h_)

    batch, channels, height, width = query.shape
    seq = height * width

    # Attention weights: softmax(Q^T K / sqrt(C)) with shape (B, HW, HW).
    query = query.reshape(batch, channels, seq).permute(0, 2, 1)
    key = key.reshape(batch, channels, seq)
    weights = torch.bmm(query, key)
    weights = weights * (int(channels) ** (-0.5))
    weights = torch.nn.functional.softmax(weights, dim=2)

    # Aggregate values with the transposed weights, then restore the
    # spatial layout.
    value = value.reshape(batch, channels, seq)
    weights = weights.permute(0, 2, 1)
    out = torch.bmm(value, weights)
    out = out.reshape(batch, channels, height, width)

    return self.proj_out(out)
| |
|
| |
|
def xformer_attn_forward(self, h_):
    """
    Single-head self-attention over a (B, C, H, W) feature map using
    xformers' memory-efficient attention kernel. Returns the same shape.
    """
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)

    # Flatten spatial dims into a token axis: (B, H*W, C).
    B, C, H, W = q.shape
    q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))

    # Insert a head axis of size 1 and fold it into the batch, forcing a
    # contiguous layout as required by the xformers kernel. With one head the
    # reshapes are value-preserving; only the memory layout changes.
    q, k, v = map(
        lambda t: t.unsqueeze(3)
        .reshape(B, t.shape[1], 1, C)
        .permute(0, 2, 1, 3)
        .reshape(B * 1, t.shape[1], C)
        .contiguous(),
        (q, k, v),
    )
    out = xformers.ops.memory_efficient_attention(
        q, k, v, attn_bias=None, op=self.attention_op)

    # Undo the head folding (inverse of the transform above).
    out = (
        out.unsqueeze(0)
        .reshape(B, 1, out.shape[1], C)
        .permute(0, 2, 1, 3)
        .reshape(B, out.shape[1], C)
    )
    # Back to the (B, C, H, W) layout, then output projection.
    out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
    out = self.proj_out(out)
    return out
| |
|
| |
|
def attn2task(task_queue, net):
    """
    Append the tasks of an attention block to the task queue.

    @param task_queue: the target task queue
    @param net: the attention module (diffusers-style, exposing .group_norm)
    """
    # NOTE: earlier revisions also built ldm- and xformers-attention task
    # sequences here, but both branches were permanently disabled with
    # `if False:` / `elif False:`; the dead code has been removed.
    task_queue.append(('store_res', lambda x: x))
    task_queue.append(('pre_norm', net.group_norm))
    task_queue.append(('attn', lambda x, net=net: attn_forward_new(net, x)))
    # add_res is a mutable placeholder: the executor fills in the stored
    # residual at run time.
    task_queue.append(['add_res', None])
| |
|
def resblock2task(queue, block):
    """
    Turn a ResNetBlock into a sequence of tasks and append them to the queue.

    @param queue: the target task queue
    @param block: ResNetBlock
    """
    if block.in_channels != block.out_channels:
        # Channel count changes: the residual must pass through a shortcut
        # layer. Attribute names differ between the two module layouts.
        if sd_flag:
            shortcut = block.conv_shortcut if block.use_conv_shortcut else block.nin_shortcut
        else:
            shortcut = block.conv_shortcut if block.use_in_shortcut else block.nin_shortcut
        queue.append(('store_res', shortcut))
    else:
        # Same channel count: the residual is the input itself.
        queue.append(('store_res', lambda x: x))
    queue.append(('pre_norm', block.norm1))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv1', block.conv1))
    queue.append(('pre_norm', block.norm2))
    queue.append(('silu', inplace_nonlinearity))
    queue.append(('conv2', block.conv2))
    queue.append(['add_res', None])
| |
|
| |
|
| |
|
def build_sampling(task_queue, net, is_decoder):
    """
    Build the sampling part of a task queue.
    @param task_queue: the target task queue
    @param net: the network
    @param is_decoder: currently building decoder or encoder
    """
    def add_mid_block():
        # Mid block = resnet -> attention -> resnet.
        resblock2task(task_queue, net.mid_block.resnets[0])
        attn2task(task_queue, net.mid_block.attentions[0])
        resblock2task(task_queue, net.mid_block.resnets[1])

    if is_decoder:
        # The decoder runs its mid block before upsampling.
        add_mid_block()
        blocks = net.up_blocks
        resnets_per_level = 2 + 1   # diffusers up blocks hold 3 resnets each
        sampler_name = 'upsamplers'
    else:
        blocks = net.down_blocks
        resnets_per_level = 2       # diffusers down blocks hold 2 resnets each
        sampler_name = 'downsamplers'

    last_level = len(blocks) - 1
    for level, stage in enumerate(blocks):
        for idx in range(resnets_per_level):
            resblock2task(task_queue, stage.resnets[idx])
        # Every level except the last re-scales the feature map.
        if level != last_level:
            task_queue.append((sampler_name, getattr(stage, sampler_name)[0]))

    if not is_decoder:
        # The encoder runs its mid block after downsampling.
        add_mid_block()
| |
|
| |
|
def build_task_queue(net, is_decoder):
    """
    Build a single task queue for the encoder or decoder.
    @param net: the VAE decoder or encoder network
    @param is_decoder: currently building decoder or encoder
    @return: the task queue
    """
    task_queue = [('conv_in', net.conv_in)]

    build_sampling(task_queue, net, is_decoder)

    # Diffusers decoders do not carry give_pre_end/tanh_out themselves;
    # set them so the output head below is always built.
    if is_decoder and not sd_flag:
        net.give_pre_end = False
        net.tanh_out = False

    if not is_decoder or not net.give_pre_end:
        final_norm = net.norm_out if sd_flag else net.conv_norm_out
        task_queue.append(('pre_norm', final_norm))
        task_queue.append(('silu', inplace_nonlinearity))
        task_queue.append(('conv_out', net.conv_out))
        if is_decoder and net.tanh_out:
            task_queue.append(('tanh', torch.tanh))

    return task_queue
| |
|
| |
|
def clone_task_queue(task_queue):
    """
    Clone a task queue one level deep.
    @param task_queue: the task queue to be cloned
    @return: the clone — each task becomes a fresh list (so the executor can
             mutate it), while the task items themselves are shared
    """
    return [list(task) for task in task_queue]
| |
|
| |
|
def get_var_mean(input, num_groups, eps=1e-6):
    """
    Compute the per-group variance and mean used by group norm.

    @param input: (B, C, H, W) tensor
    @param num_groups: number of channel groups
    @param eps: unused; kept for signature compatibility
    @return: (var, mean) over each of the B * num_groups groups
             (biased/population variance, matching group norm)
    """
    batch, channels = input.size(0), input.size(1)
    group_channels = int(channels / num_groups)
    # Collapse batch and group into one axis, then reduce over everything else.
    grouped = input.contiguous().view(
        1, int(batch * num_groups), group_channels, *input.size()[2:])
    return torch.var_mean(grouped, dim=[0, 2, 3, 4], unbiased=False)
| |
|
| |
|
def custom_group_norm(input, num_groups, mean, var, weight=None, bias=None, eps=1e-6):
    """
    Group norm applied with externally supplied statistics.

    @param input: input tensor (B, C, H, W)
    @param num_groups: number of groups; 32 throughout this file
    @param mean: means pre-computed by get_var_mean
    @param var: variances pre-computed by get_var_mean
    @param weight: affine weight fetched from the original group-norm layer
    @param bias: affine bias fetched from the original group-norm layer
    @param eps: epsilon, 1e-6 to match the original group norm

    @return: normalized tensor
    """
    batch, channels = input.size(0), input.size(1)
    group_channels = int(channels / num_groups)
    grouped = input.contiguous().view(
        1, int(batch * num_groups), group_channels, *input.size()[2:])

    # batch_norm in eval mode normalizes with the provided running stats,
    # i.e. exactly (x - mean) / sqrt(var + eps) per group.
    normalized = F.batch_norm(grouped, mean, var, weight=None, bias=None,
                              training=False, momentum=0, eps=eps)
    normalized = normalized.view(batch, channels, *input.size()[2:])

    # Per-channel affine transform, when the layer had one.
    if weight is not None:
        normalized *= weight.view(1, -1, 1, 1)
    if bias is not None:
        normalized += bias.view(1, -1, 1, 1)
    return normalized
| |
|
| |
|
def crop_valid_region(x, input_bbox, target_bbox, is_decoder):
    """
    Crop the valid (non-padding) region out of a processed tile.
    @param x: processed tile
    @param input_bbox: tile bounding box in network-input coordinates
    @param target_bbox: tile bounding box in network-output coordinates
    @param is_decoder: decoder output is 8x the latent; encoder output is 1/8
    @return: cropped tile
    """
    # Map the input bbox into output coordinates (x8 decode, /8 encode).
    if is_decoder:
        padded_bbox = [coord * 8 for coord in input_bbox]
    else:
        padded_bbox = [coord // 8 for coord in input_bbox]
    # Offsets of the target region within the padded tile; the right/bottom
    # entries are <= 0 and shrink the slice ends.
    margin = [t - p for t, p in zip(target_bbox, padded_bbox)]
    return x[:, :, margin[2]:x.size(2) + margin[3], margin[0]:x.size(3) + margin[1]]
| |
|
| | |
| |
|
| |
|
def perfcount(fn):
    """
    Decorator: report wall time (and peak VRAM on CUDA) of the wrapped call,
    garbage-collecting before and after so the measurement is meaningful.
    """
    from functools import wraps

    # FIX: functools.wraps was missing, so the wrapped function lost its
    # __name__/__doc__ (it showed up as 'wrapper' in tracebacks and tooling).
    @wraps(fn)
    def wrapper(*args, **kwargs):
        ts = time()

        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats(devices.device)
        devices.torch_gc()
        gc.collect()

        ret = fn(*args, **kwargs)

        devices.torch_gc()
        gc.collect()
        if torch.cuda.is_available():
            vram = torch.cuda.max_memory_allocated(devices.device) / 2**20
            torch.cuda.reset_peak_memory_stats(devices.device)
            print(
                f'[Tiled VAE]: Done in {time() - ts:.3f}s, max VRAM alloc {vram:.3f} MB')
        else:
            print(f'[Tiled VAE]: Done in {time() - ts:.3f}s')

        return ret
    return wrapper
| |
|
| | |
| |
|
| |
|
class GroupNormParam:
    """
    Accumulates group-norm statistics across tiles so that a single, shared
    normalization (weighted by tile area) can be applied to all of them.
    """

    def __init__(self):
        self.var_list = []      # per-tile group variances
        self.mean_list = []     # per-tile group means
        self.pixel_list = []    # per-tile pixel counts (H*W), used as weights
        self.weight = None      # affine weight of the last seen norm layer
        self.bias = None        # affine bias of the last seen norm layer

    def add_tile(self, tile, layer):
        """Record the group statistics of one tile and remember the affine
        parameters of its norm layer."""
        var, mean = get_var_mean(tile, 32)
        # fp16 variance can overflow to inf; redo the reduction in fp32.
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
        self.var_list.append(var)
        self.mean_list.append(mean)
        self.pixel_list.append(tile.shape[2] * tile.shape[3])
        if hasattr(layer, 'weight'):
            self.weight = layer.weight
            self.bias = layer.bias
        else:
            self.weight = None
            self.bias = None

    def summary(self):
        """
        summarize the mean and var and return a function
        that apply group norm on each tile
        """
        if not self.var_list:
            return None
        var = torch.vstack(self.var_list)
        mean = torch.vstack(self.mean_list)
        # Weight each tile's statistics by its pixel count, normalized by the
        # largest tile first to keep the values well-scaled.
        max_pixels = max(self.pixel_list)
        weights = torch.tensor(
            self.pixel_list, dtype=torch.float32, device=devices.device) / max_pixels
        weights = weights.unsqueeze(1) / torch.sum(weights)
        var = torch.sum(var * weights, dim=0)
        mean = torch.sum(mean * weights, dim=0)
        return lambda x: custom_group_norm(x, 32, mean, var, self.weight, self.bias)

    @staticmethod
    def from_tile(tile, norm):
        """
        create a function from a single tile without summary
        """
        var, mean = get_var_mean(tile, 32)
        if var.dtype == torch.float16 and var.isinf().any():
            var, mean = get_var_mean(tile.float(), 32)
            # On MPS, clamp before casting back so fp16 does not overflow.
            if var.device.type == 'mps':
                var = torch.clamp(var, 0, 60000)
                var = var.half()
                mean = mean.half()
        if hasattr(norm, 'weight'):
            weight = norm.weight
            bias = norm.bias
        else:
            weight = None
            bias = None

        def group_norm_func(x, mean=mean, var=var, weight=weight, bias=bias):
            return custom_group_norm(x, 32, mean, var, weight, bias, 1e-6)
        return group_norm_func
| |
|
| |
|
class VAEHook:
    # Callable hook that replaces a VAE encoder/decoder forward with a tiled
    # version so arbitrarily large images can be processed in bounded VRAM.
    def __init__(self, net, tile_size, is_decoder, fast_decoder, fast_encoder, color_fix, to_gpu=False):
        """
        @param net: the VAE encoder or decoder network to wrap
        @param tile_size: tile side length in the network's input space
        @param is_decoder: True when wrapping the decoder, False for the encoder
        @param fast_decoder / fast_encoder: enable fast mode (group-norm stats
               estimated once on a downsampled copy of the whole input)
        @param color_fix: encoder-only color-fix option used in fast mode
        @param to_gpu: move the net to the optimal device for the forward pass
        """
        self.net = net
        self.tile_size = tile_size
        self.is_decoder = is_decoder
        self.fast_mode = (fast_encoder and not is_decoder) or (
            fast_decoder and is_decoder)
        # color_fix only applies to the encoder path.
        self.color_fix = color_fix and not is_decoder
        self.to_gpu = to_gpu
        # Overlap padding between tiles: 11 (latent px) when decoding,
        # 32 (image px) when encoding.
        self.pad = 11 if is_decoder else 32

    def __call__(self, x):
        B, C, H, W = x.shape
        original_device = next(self.net.parameters()).device
        try:
            if self.to_gpu:
                self.net.to(devices.get_optimal_device())
            # Skip tiling when one padded tile already covers the whole input.
            if max(H, W) <= self.pad * 2 + self.tile_size:
                print("[Tiled VAE]: the input size is tiny and unnecessary to tile.")
                return self.net.original_forward(x)
            else:
                return self.vae_tile_forward(x)
        finally:
            # Always restore the net to where it came from.
            self.net.to(original_device)

    def get_best_tile_size(self, lowerbound, upperbound):
        """
        Get the best tile size for GPU memory
        """
        # Prefer a size divisible by the largest power of two (<= 32) that
        # still fits below the upper bound; fall back to the lower bound.
        divider = 32
        while divider >= 2:
            remainer = lowerbound % divider
            if remainer == 0:
                return lowerbound
            candidate = lowerbound - remainer + divider
            if candidate <= upperbound:
                return candidate
            divider //= 2
        return lowerbound

    def split_tiles(self, h, w):
        """
        Tool function to split the image into tiles
        @param h: height of the image
        @param w: width of the image
        @return: tile_input_bboxes, tile_output_bboxes
                 (bboxes are [x_start, x_end, y_start, y_end])
        """
        tile_input_bboxes, tile_output_bboxes = [], []
        tile_size = self.tile_size
        pad = self.pad
        num_height_tiles = math.ceil((h - 2 * pad) / tile_size)
        num_width_tiles = math.ceil((w - 2 * pad) / tile_size)
        # Guard against degenerate inputs smaller than the padding.
        num_height_tiles = max(num_height_tiles, 1)
        num_width_tiles = max(num_width_tiles, 1)

        # Distribute the interior evenly across tiles, then round up to a
        # friendlier (divisibility-aware) size.
        real_tile_height = math.ceil((h - 2 * pad) / num_height_tiles)
        real_tile_width = math.ceil((w - 2 * pad) / num_width_tiles)
        real_tile_height = self.get_best_tile_size(real_tile_height, tile_size)
        real_tile_width = self.get_best_tile_size(real_tile_width, tile_size)

        print(f'[Tiled VAE]: split to {num_height_tiles}x{num_width_tiles} = {num_height_tiles*num_width_tiles} tiles. ' +
              f'Optimal tile size {real_tile_width}x{real_tile_height}, original tile size {tile_size}x{tile_size}')

        for i in range(num_height_tiles):
            for j in range(num_width_tiles):
                # Tile interior (without the overlap padding).
                input_bbox = [
                    pad + j * real_tile_width,
                    min(pad + (j + 1) * real_tile_width, w),
                    pad + i * real_tile_height,
                    min(pad + (i + 1) * real_tile_height, h),
                ]

                # Border tiles extend their output to the image edge so the
                # stitched result has no seams at the boundary.
                output_bbox = [
                    input_bbox[0] if input_bbox[0] > pad else 0,
                    input_bbox[1] if input_bbox[1] < w - pad else w,
                    input_bbox[2] if input_bbox[2] > pad else 0,
                    input_bbox[3] if input_bbox[3] < h - pad else h,
                ]

                # Convert to output coordinates (x8 decode, /8 encode).
                output_bbox = [x * 8 if self.is_decoder else x // 8 for x in output_bbox]
                tile_output_bboxes.append(output_bbox)

                # The actual network input includes the overlap padding,
                # clamped to the image bounds.
                tile_input_bboxes.append([
                    max(0, input_bbox[0] - pad),
                    min(w, input_bbox[1] + pad),
                    max(0, input_bbox[2] - pad),
                    min(h, input_bbox[3] + pad),
                ])

        return tile_input_bboxes, tile_output_bboxes

    @torch.no_grad()
    def estimate_group_norm(self, z, task_queue, color_fix):
        # Fast-mode helper: run the task queue once on a small (downsampled)
        # input, replacing each 'pre_norm' with an 'apply_norm' whose stats
        # were estimated from that input. Mutates task_queue in place.
        device = z.device
        tile = z
        # Find the last pre_norm; everything after it is left untouched.
        last_id = len(task_queue) - 1
        while last_id >= 0 and task_queue[last_id][0] != 'pre_norm':
            last_id -= 1
        # NOTE(review): `last_id <= 0` also rejects a pre_norm at index 0 —
        # presumably intentional (a queue always starts with conv_in), but
        # worth confirming.
        if last_id <= 0 or task_queue[last_id][0] != 'pre_norm':
            raise ValueError('No group norm found in the task queue')
        for i in range(last_id + 1):
            task = task_queue[i]
            if task[0] == 'pre_norm':
                group_norm_func = GroupNormParam.from_tile(tile, task[1])
                task_queue[i] = ('apply_norm', group_norm_func)
                if i == last_id:
                    return True
                tile = group_norm_func(tile)
            elif task[0] == 'store_res':
                # Pre-compute the residual and stash it in the matching
                # add_res slot (skipped if the pair ends after last_id).
                task_id = i + 1
                while task_id < last_id and task_queue[task_id][0] != 'add_res':
                    task_id += 1
                if task_id >= last_id:
                    continue
                task_queue[task_id][1] = task[1](tile)
            elif task[0] == 'add_res':
                tile += task[1].to(device)
                task[1] = None
            elif color_fix and task[0] == 'downsample':
                # Color fix: from the first downsample on, keep residuals on
                # CPU and stop estimating.
                for j in range(i, last_id + 1):
                    if task_queue[j][0] == 'store_res':
                        task_queue[j] = ('store_res_cpu', task_queue[j][1])
                return True
            else:
                tile = task[1](tile)
            # NOTE(review): bare `except:` also swallows KeyboardInterrupt;
            # `except Exception:` would be safer.
            try:
                devices.test_for_nans(tile, "vae")
            except:
                print(f'Nan detected in fast mode estimation. Fast mode disabled.')
                return False

        raise IndexError('Should not reach here')

    @perfcount
    @torch.no_grad()
    def vae_tile_forward(self, z):
        """
        Decode a latent vector z into an image in a tiled manner.
        @param z: latent vector
        @return: image
        """
        device = next(self.net.parameters()).device
        net = self.net
        tile_size = self.tile_size
        is_decoder = self.is_decoder

        z = z.detach()

        N, height, width = z.shape[0], z.shape[2], z.shape[3]
        net.last_z_shape = z.shape

        print(f'[Tiled VAE]: input_size: {z.shape}, tile_size: {tile_size}, padding: {self.pad}')

        in_bboxes, out_bboxes = self.split_tiles(height, width)

        # Slice the input into tiles; keep them on CPU until needed.
        tiles = []
        for input_bbox in in_bboxes:
            tile = z[:, :, input_bbox[2]:input_bbox[3], input_bbox[0]:input_bbox[1]].cpu()
            tiles.append(tile)

        num_tiles = len(tiles)
        num_completed = 0

        single_task_queue = build_task_queue(net, is_decoder)
        if self.fast_mode:
            # Estimate group-norm statistics once on a downsampled copy of
            # the whole input instead of synchronizing stats across tiles.
            scale_factor = tile_size / max(height, width)
            z = z.to(device)
            downsampled_z = F.interpolate(z, scale_factor=scale_factor, mode='nearest-exact')
            print(f'[Tiled VAE]: Fast mode enabled, estimating group norm parameters on {downsampled_z.shape[3]} x {downsampled_z.shape[2]} image')

            # Match the downsampled copy's per-channel distribution (and
            # range) to the original so the estimated stats transfer.
            std_old, mean_old = torch.std_mean(z, dim=[0, 2, 3], keepdim=True)
            std_new, mean_new = torch.std_mean(downsampled_z, dim=[0, 2, 3], keepdim=True)
            downsampled_z = (downsampled_z - mean_new) / std_new * std_old + mean_old
            del std_old, mean_old, std_new, mean_new
            downsampled_z = torch.clamp_(downsampled_z, min=z.min(), max=z.max())
            estimate_task_queue = clone_task_queue(single_task_queue)
            if self.estimate_group_norm(downsampled_z, estimate_task_queue, color_fix=self.color_fix):
                single_task_queue = estimate_task_queue
            del downsampled_z

        # Each tile gets its own mutable copy of the task queue.
        task_queues = [clone_task_queue(single_task_queue) for _ in range(num_tiles)]

        result = None
        # NOTE(review): result_approx is never assigned in this version, so
        # the final fallback would raise if no tile ever completed.
        result_approx = None
        del z

        pbar = tqdm(total=num_tiles * len(task_queues[0]), desc=f"[Tiled VAE]: Executing {'Decoder' if is_decoder else 'Encoder'} Task Queue: ")

        # Tiles are swept alternately forward and backward so the tile at the
        # turning point stays on GPU between sweeps.
        forward = True
        interrupted = False
        while True:
            # Each sweep runs every tile up to its next group norm, collects
            # the stats, then (below) prepends the shared apply_norm.
            group_norm_param = GroupNormParam()
            for i in range(num_tiles) if forward else reversed(range(num_tiles)):
                tile = tiles[i].to(device)
                input_bbox = in_bboxes[i]
                task_queue = task_queues[i]

                interrupted = False
                while len(task_queue) > 0:
                    task = task_queue.pop(0)
                    if task[0] == 'pre_norm':
                        # Pause this tile at the group norm; resume next sweep.
                        group_norm_param.add_tile(tile, task[1])
                        break
                    elif task[0] == 'store_res' or task[0] == 'store_res_cpu':
                        task_id = 0
                        res = task[1](tile)
                        # In slow mode (or when forced), keep residuals on CPU
                        # to save VRAM at the cost of transfers.
                        if not self.fast_mode or task[0] == 'store_res_cpu':
                            res = res.cpu()
                        while task_queue[task_id][0] != 'add_res':
                            task_id += 1
                        task_queue[task_id][1] = res
                    elif task[0] == 'add_res':
                        tile += task[1].to(device)
                        task[1] = None
                    else:
                        tile = task[1](tile)
                    pbar.update(1)

                if interrupted: break

                # A drained queue means this tile is fully processed.
                if len(task_queue) == 0:
                    tiles[i] = None
                    num_completed += 1
                    if result is None:
                        result = torch.zeros((N, tile.shape[1], height * 8 if is_decoder else height // 8, width * 8 if is_decoder else width // 8), device=device, requires_grad=False)
                    result[:, :, out_bboxes[i][2]:out_bboxes[i][3], out_bboxes[i][0]:out_bboxes[i][1]] = crop_valid_region(tile, in_bboxes[i], out_bboxes[i], is_decoder)
                    del tile
                elif i == num_tiles - 1 and forward:
                    # Turning point: keep this tile on GPU for the next sweep.
                    forward = False
                    tiles[i] = tile
                elif i == 0 and not forward:
                    forward = True
                    tiles[i] = tile
                else:
                    tiles[i] = tile.cpu()
                    del tile

            if interrupted: break
            if num_completed == num_tiles: break

            # Share the aggregated group-norm stats with every tile by
            # prepending an apply_norm task to each remaining queue.
            group_norm_func = group_norm_param.summary()
            if group_norm_func is not None:
                for i in range(num_tiles):
                    task_queue = task_queues[i]
                    task_queue.insert(0, ('apply_norm', group_norm_func))

        pbar.close()
        return result if result is not None else result_approx.to(device)