| import torch |
| from ..models import SDUNet, SDMotionModel |
| from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock |
| from ..models.tiler import TileWorker |
| from ..controlnets import MultiControlNetManager |
|
|
|
|
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample = None,
    timestep = None,
    encoder_hidden_states = None,
    controlnet_frames = None,
    unet_batch_size = 1,
    controlnet_batch_size = 1,
    cross_frame_attention = False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device = "cuda",
    vram_limit_level = 0,
):
    """Run one forward pass through ``unet``, optionally mixing in ControlNet
    residuals and motion modules.

    The pass proceeds in five steps:

    1. If both ``controlnet`` and ``controlnet_frames`` are given, ``controlnet``
       is called on chunks of at most ``controlnet_batch_size`` items along
       dim 0 of ``sample``; the residual lists of all chunks are concatenated
       level by level along dim 0.
    2. The time embedding is computed with ``unet.time_proj`` followed by
       ``unet.time_embedding``.
    3. ``unet.conv_in`` is applied to ``sample`` and its output seeds the
       residual stack.
    4. ``unet.blocks`` are executed in order. ``PushBlock`` / ``PopBlock``
       instances receive the whole batch; every other block is called on
       chunks of at most ``unet_batch_size`` items and the chunk outputs are
       concatenated along dim 0. After each block, a motion module is applied
       if ``motion_modules.call_block_id`` contains the block index, and at
       block index 30 the ControlNet residuals are added.
    5. ``unet.conv_norm_out``, ``unet.conv_act`` and ``unet.conv_out`` produce
       the result.

    Args:
        unet: Object exposing ``time_proj``, ``time_embedding``, ``conv_in``,
            ``blocks``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        motion_modules: Optional. ``call_block_id`` maps a block index to an
            index into ``motion_modules.motion_modules``.
        controlnet: Optional callable returning a list of residual tensors.
        sample: Input tensor; dim 0 is the axis that is chunked
            (presumably frames / batch -- confirm against the caller).
        timestep: Tensor; it is indexed with ``[None]`` before being passed
            to ``unet.time_proj``.
        encoder_hidden_states: Conditioning tensor, chunked along dim 0 in
            lockstep with ``sample``.
        controlnet_frames: Tensor chunked along dim 1 (presumably dim 0
            enumerates ControlNet units -- TODO confirm).
        unet_batch_size: Maximum chunk size for non-push/pop UNet blocks.
        controlnet_batch_size: Maximum chunk size for ControlNet calls.
        cross_frame_attention: Forwarded to non-push/pop UNet blocks.
        tiled: Forwarded to ``controlnet`` and non-push/pop UNet blocks.
        tile_size: Forwarded together with ``tiled``.
        tile_stride: Forwarded together with ``tiled``.
        device: Device that offloaded residuals are moved back to, and that
            ControlNet residuals are moved to before being added.
        vram_limit_level: When ``>= 1``, residual tensors are kept on the CPU
            between uses.

    Returns:
        The output of ``unet.conv_out``.
    """
    offload_to_cpu = vram_limit_level >= 1
    num_items = sample.shape[0]

    # 1. ControlNet
    # Index of the block after which the ControlNet residuals are injected.
    controlnet_insert_block_id = 30
    additional_res_stack = None
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        # Run the ControlNet chunk by chunk.
        for start in range(0, num_items, controlnet_batch_size):
            end = min(start + controlnet_batch_size, num_items)
            chunk_res_stack = controlnet(
                sample[start:end],
                timestep,
                encoder_hidden_states[start:end],
                controlnet_frames[:, start:end],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            )
            if offload_to_cpu:
                chunk_res_stack = [res.cpu() for res in chunk_res_stack]
            res_stacks.append(chunk_res_stack)
        # Merge the chunks: one concatenated tensor per residual level.
        # Every chunk is expected to yield the same number of residuals;
        # zip stops at the shortest list if they differ.
        additional_res_stack = [
            torch.concat(level_residuals, dim=0)
            for level_residuals in zip(*res_stacks)
        ]

    # 2. Time embedding
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. Pre-process
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if offload_to_cpu else hidden_states]

    # 4. Blocks
    for block_id, block in enumerate(unet.blocks):
        # 4.1 UNet block
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(
                hidden_states, time_emb, text_emb, res_stack
            )
            if offload_to_cpu:
                # Park the stack's newest entry on the CPU until it is needed.
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if offload_to_cpu:
                # Bring the stack's top entry back before the block runs.
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(
                hidden_states, time_emb, text_emb, res_stack
            )
        else:
            # Other blocks are evaluated chunk by chunk; only the first return
            # value (the hidden states) of each call is kept.
            chunk_outputs = []
            for start in range(0, num_items, unet_batch_size):
                end = min(start + unet_batch_size, num_items)
                chunk_output, _, _, _ = block(
                    hidden_states[start:end],
                    time_emb,
                    text_emb[start:end],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                )
                chunk_outputs.append(chunk_output)
            hidden_states = torch.concat(chunk_outputs, dim=0)

        # 4.2 Motion module registered for this block index, if any.
        if motion_modules is not None and block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            # NOTE(review): batch_size is hard-coded to 1 here -- confirm this
            # matches what the motion module expects.
            hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                hidden_states, time_emb, text_emb, res_stack,
                batch_size=1,
            )

        # 4.3 ControlNet injection
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            # The last ControlNet residual is added (in place) to the current
            # hidden states; the remaining ones are added pairwise to the
            # entries of the residual stack.
            hidden_states += additional_res_stack.pop().to(device)
            if offload_to_cpu:
                res_stack = [
                    (res.to(device) + additional_res.to(device)).cpu()
                    for res, additional_res in zip(res_stack, additional_res_stack)
                ]
            else:
                res_stack = [
                    res + additional_res
                    for res, additional_res in zip(res_stack, additional_res_stack)
                ]

    # 5. Output
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
|
|