| | import torch |
| | from ..models import SDUNet, SDMotionModel |
| | from ..models.sd_unet import PushBlock, PopBlock, ResnetBlock, AttentionBlock |
| | from ..models.tiler import TileWorker |
| | from ..controlnets import MultiControlNetManager |
| |
|
| |
|
def lets_dance(
    unet: SDUNet,
    motion_modules: SDMotionModel = None,
    controlnet: MultiControlNetManager = None,
    sample=None,
    timestep=None,
    encoder_hidden_states=None,
    controlnet_frames=None,
    unet_batch_size=1,
    controlnet_batch_size=1,
    cross_frame_attention=False,
    tiled=False,
    tile_size=64,
    tile_stride=32,
    device="cuda",
    vram_limit_level=0,
):
    """Run ``unet`` on ``sample`` with optional ControlNet and motion-module hooks.

    Args:
        unet: UNet providing ``time_proj``, ``time_embedding``, ``conv_in``,
            ``blocks``, ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        motion_modules: Optional motion model. After block ``i`` has run, if
            ``i`` is a key of ``motion_modules.call_block_id`` the mapped entry
            of ``motion_modules.motion_modules`` is applied (with ``batch_size=1``).
        controlnet: Optional ControlNet manager. It is only used when
            ``controlnet_frames`` is also given.
        sample: Latent input. Dim 0 is the axis that ``unet_batch_size`` and
            ``controlnet_batch_size`` chunk over.
        timestep: Timestep, indexed with ``[None]`` before ``unet.time_proj``
            (presumably a tensor — TODO confirm). The resulting time embedding
            is shared by every chunk and is not sliced.
        encoder_hidden_states: Conditioning, sliced along dim 0 in lockstep
            with ``sample``. It is expected to have the same dim-0 size.
        controlnet_frames: Control inputs, sliced along dim 1 per ControlNet
            chunk (presumably dim 0 indexes the individual ControlNets — TODO
            confirm).
        unet_batch_size: Chunk size along dim 0 for UNet blocks that are
            neither ``PushBlock`` nor ``PopBlock``.
        controlnet_batch_size: Chunk size along dim 0 for ControlNet calls.
        cross_frame_attention: Forwarded to the chunked UNet blocks.
        tiled: Forwarded to the ControlNet and to the chunked UNet blocks.
        tile_size: Forwarded to the ControlNet and to the chunked UNet blocks.
        tile_stride: Forwarded to the ControlNet and to the chunked UNet blocks.
        device: Device that CPU-parked residuals are moved back to before use.
        vram_limit_level: When ``>= 1``, skip-connection residuals and
            ControlNet residuals are kept on the CPU between uses.

    Returns:
        The result of ``unet.conv_out`` applied to the final hidden states.
    """
    num_frames = sample.shape[0]

    # 1. ControlNet residuals, computed in chunks along dim 0.
    # Index of the UNet block after which the ControlNet residuals are injected.
    # NOTE(review): hard-coded for the current UNet block layout — confirm it
    # still points at the intended block if ``unet.blocks`` changes.
    controlnet_insert_block_id = 30
    if controlnet is not None and controlnet_frames is not None:
        res_stacks = []
        for start in range(0, num_frames, controlnet_batch_size):
            end = min(start + controlnet_batch_size, num_frames)
            chunk_res_stack = controlnet(
                sample[start:end],
                timestep,
                encoder_hidden_states[start:end],
                controlnet_frames[:, start:end],
                tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
            )
            if vram_limit_level >= 1:
                chunk_res_stack = [res.cpu() for res in chunk_res_stack]
            res_stacks.append(chunk_res_stack)
        # Transpose "list of chunks, each a list of residuals" into one tensor
        # per residual position, re-joining the chunks along dim 0.
        additional_res_stack = [torch.cat(chunks, dim=0) for chunks in zip(*res_stacks)]
    else:
        additional_res_stack = None

    # 2. Timestep embedding, cast to the latent dtype.
    time_emb = unet.time_proj(timestep[None]).to(sample.dtype)
    time_emb = unet.time_embedding(time_emb)

    # 3. Input projection. Its output seeds the skip-connection stack.
    hidden_states = unet.conv_in(sample)
    text_emb = encoder_hidden_states
    res_stack = [hidden_states.cpu() if vram_limit_level >= 1 else hidden_states]

    # 4. UNet blocks.
    for block_id, block in enumerate(unet.blocks):
        if isinstance(block, PushBlock):
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
            if vram_limit_level >= 1:
                # Park the freshly pushed skip connection on the CPU.
                res_stack[-1] = res_stack[-1].cpu()
        elif isinstance(block, PopBlock):
            if vram_limit_level >= 1:
                # Bring the skip connection back before the block consumes it.
                res_stack[-1] = res_stack[-1].to(device)
            hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
        else:
            # Remaining blocks run in chunks of ``unet_batch_size`` along dim 0.
            # Only the hidden states and text embeddings are sliced. The block's
            # other return values are discarded.
            block_input = hidden_states
            chunk_outputs = []
            for start in range(0, num_frames, unet_batch_size):
                end = min(start + unet_batch_size, num_frames)
                chunk_output, _, _, _ = block(
                    block_input[start:end],
                    time_emb,
                    text_emb[start:end],
                    res_stack,
                    cross_frame_attention=cross_frame_attention,
                    tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
                )
                chunk_outputs.append(chunk_output)
            hidden_states = torch.cat(chunk_outputs, dim=0)

        # 4.2 Optional motion module attached after this block.
        if motion_modules is not None and block_id in motion_modules.call_block_id:
            motion_module_id = motion_modules.call_block_id[block_id]
            hidden_states, time_emb, text_emb, res_stack = motion_modules.motion_modules[motion_module_id](
                hidden_states, time_emb, text_emb, res_stack,
                batch_size=1,
            )

        # 4.3 ControlNet injection. The last ControlNet residual is added to the
        # current hidden states (in place, which keeps hidden_states' dtype).
        # The remaining residuals are added pairwise to the skip-connection
        # stack; zip stops at the shorter of the two lists.
        if block_id == controlnet_insert_block_id and additional_res_stack is not None:
            hidden_states += additional_res_stack.pop().to(device)
            if vram_limit_level >= 1:
                res_stack = [
                    (res.to(device) + additional_res.to(device)).cpu()
                    for res, additional_res in zip(res_stack, additional_res_stack)
                ]
            else:
                res_stack = [
                    res + additional_res
                    for res, additional_res in zip(res_stack, additional_res_stack)
                ]

    # 5. Output head.
    hidden_states = unet.conv_norm_out(hidden_states)
    hidden_states = unet.conv_act(hidden_states)
    hidden_states = unet.conv_out(hidden_states)

    return hidden_states
| |
|