from typing import Any, Dict, Optional

import torch
from einops import rearrange

from .attention import BasicTransformerBlock
from .attention import TemporalBasicTransformerBlock

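
# Depth-first traversal that flattens a module and all of its descendants into a list.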
def torch_dfs(model: torch.nn.Module):
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result

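
# Controls reference attention by patching the forward pass of BasicTransformerBlock /
# TemporalBasicTransformerBlock: a "write" instance stores normalized hidden states in
# per-block banks, and a "read" instance attends to those banked features (with
# configurable reference / audio attention weights).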
class ReferenceAttentionControl:
    def __init__(
        self,
        unet,
        mode="write",
        do_classifier_free_guidance=False,
        attention_auto_machine_weight=float("inf"),
        gn_auto_machine_weight=1.0,
        style_fidelity=1.0,
        reference_attn=True,
        reference_adain=False,
        fusion_blocks="midup",
        batch_size=1,
        reference_attention_weight=1.0,
        audio_attention_weight=1.0,
    ) -> None:
        self.unet = unet
        assert mode in ["read", "write"]
        assert fusion_blocks in ["midup", "full"]
        self.reference_attn = reference_attn
        self.reference_adain = reference_adain
        self.fusion_blocks = fusion_blocks
        self.reference_attention_weight = reference_attention_weight
        self.audio_attention_weight = audio_attention_weight
        self.register_reference_hooks(
            mode,
            do_classifier_free_guidance,
            attention_auto_machine_weight,
            gn_auto_machine_weight,
            style_fidelity,
            reference_attn,
            reference_adain,
            fusion_blocks,
            batch_size=batch_size,
        )
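
    # Defines the patched transformer-block forward and rebinds it onto every
    # BasicTransformerBlock / TemporalBasicTransformerBlock selected by fusion_blocks.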
    def register_reference_hooks(
        self,
        mode,
        do_classifier_free_guidance,
        attention_auto_machine_weight,
        gn_auto_machine_weight,
        style_fidelity,
        reference_attn,
        reference_adain,
        dtype=torch.float16,
        batch_size=1,
        num_images_per_prompt=1,
        device=torch.device("cpu"),
        fusion_blocks="midup",
    ):
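        # The locals below (MODE, the attention weights, uc_mask, ...) are captured by the
        # closure of hacked_basic_transformer_inner_forward defined further down.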
        MODE = mode
        do_classifier_free_guidance = do_classifier_free_guidance
        attention_auto_machine_weight = attention_auto_machine_weight
        gn_auto_machine_weight = gn_auto_machine_weight
        style_fidelity = style_fidelity
        reference_attn = reference_attn
        reference_adain = reference_adain
        fusion_blocks = fusion_blocks
        num_images_per_prompt = num_images_per_prompt
        reference_attention_weight = self.reference_attention_weight
        audio_attention_weight = self.audio_attention_weight
        dtype = dtype
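        # Mask marking the unconditional half of a classifier-free-guidance batch; it is
        # computed here but not referenced by the hooks defined below.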
        if do_classifier_free_guidance:
            uc_mask = (
                torch.Tensor(
                    [1] * batch_size * num_images_per_prompt * 16
                    + [0] * batch_size * num_images_per_prompt * 16
                )
                .to(device)
                .bool()
            )
        else:
            uc_mask = (
                torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
                .to(device)
                .bool()
            )
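
        # Replacement forward for the patched transformer blocks. In "write" mode it runs the
        # usual self-/cross-attention and snapshots the normalized hidden states into
        # self.bank; in "read" mode it additionally attends to those banked reference features.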
        def hacked_basic_transformer_inner_forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            timestep: Optional[torch.LongTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            class_labels: Optional[torch.LongTensor] = None,
            video_length=None,
        ):
            if self.use_ada_layer_norm:
                norm_hidden_states = self.norm1(hidden_states, timestep)
            elif self.use_ada_layer_norm_zero:
                (
                    norm_hidden_states,
                    gate_msa,
                    shift_mlp,
                    scale_mlp,
                    gate_mlp,
                ) = self.norm1(
                    hidden_states,
                    timestep,
                    class_labels,
                    hidden_dtype=hidden_states.dtype,
                )
            else:
                norm_hidden_states = self.norm1(hidden_states)
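
            # 1. Self-Attention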
            cross_attention_kwargs = (
                cross_attention_kwargs if cross_attention_kwargs is not None else {}
            )
            if self.only_cross_attention:
                attn_output = self.attn1(
                    norm_hidden_states,
                    encoder_hidden_states=encoder_hidden_states
                    if self.only_cross_attention
                    else None,
                    attention_mask=attention_mask,
                    **cross_attention_kwargs,
                )
            else:
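                # "write": standard self-/cross-attention; the normalized hidden states are
                # also stashed into self.bank for later "read" passes.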
                if MODE == "write":
                    attn_output = self.attn1(
                        norm_hidden_states,
                        encoder_hidden_states=encoder_hidden_states
                        if self.only_cross_attention
                        else None,
                        attention_mask=attention_mask,
                        **cross_attention_kwargs,
                    )

                    if self.use_ada_layer_norm_zero:
                        attn_output = gate_msa.unsqueeze(1) * attn_output
                    hidden_states = attn_output + hidden_states

                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm2(hidden_states)
                        )
                        self.bank.append(norm_hidden_states.clone())

                        attn_output = self.attn2(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=encoder_attention_mask,
                            **cross_attention_kwargs,
                        )
                        hidden_states = attn_output + hidden_states
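
                # "read": self-attention, then attn1_5 attends to the banked reference
                # features, followed by cross-attention, feed-forward, and optional
                # temporal attention; this branch returns early.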
                if MODE == "read":
                    hidden_states = (
                        self.attn1(
                            norm_hidden_states,
                            encoder_hidden_states=norm_hidden_states,
                            attention_mask=attention_mask,
                        )
                        + hidden_states
                    )

                    if self.use_ada_layer_norm:
                        norm_hidden_states = self.norm1_5(hidden_states, timestep)
                    elif self.use_ada_layer_norm_zero:
                        (
                            norm_hidden_states,
                            gate_msa,
                            shift_mlp,
                            scale_mlp,
                            gate_mlp,
                        ) = self.norm1_5(
                            hidden_states,
                            timestep,
                            class_labels,
                            hidden_dtype=hidden_states.dtype,
                        )
                    else:
                        norm_hidden_states = self.norm1_5(hidden_states)
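
                    # Broadcast each banked reference feature across the video frames so it
                    # matches the (batch * frames) layout of the current hidden states.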
                    bank_fea = []
                    for d in self.bank:
                        if len(d.shape) == 3:
                            d = d.unsqueeze(1).repeat(1, video_length, 1, 1)
                        bank_fea.append(rearrange(d, "b t l c -> (b t) l c"))

                    attn_hidden_states = self.attn1_5(
                        norm_hidden_states,
                        encoder_hidden_states=bank_fea[0],
                        attention_mask=attention_mask,
                    )

                    if reference_attention_weight != 1.0:
                        attn_hidden_states *= reference_attention_weight

                    hidden_states = attn_hidden_states + hidden_states
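
                    # Cross-attention on encoder_hidden_states; its output is optionally
                    # scaled by audio_attention_weight.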
                    if self.attn2 is not None:
                        norm_hidden_states = (
                            self.norm2(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm2(hidden_states)
                        )

                        attn_hidden_states = self.attn2(
                            norm_hidden_states,
                            encoder_hidden_states=encoder_hidden_states,
                            attention_mask=attention_mask,
                        )

                        if audio_attention_weight != 1.0:
                            attn_hidden_states *= audio_attention_weight

                        hidden_states = attn_hidden_states + hidden_states
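
                    # Feed-forward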
                    hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
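
                    # Temporal-Attention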
                    if self.unet_use_temporal_attention:
                        d = hidden_states.shape[1]
                        hidden_states = rearrange(
                            hidden_states, "(b f) d c -> (b d) f c", f=video_length
                        )
                        norm_hidden_states = (
                            self.norm_temp(hidden_states, timestep)
                            if self.use_ada_layer_norm
                            else self.norm_temp(hidden_states)
                        )
                        hidden_states = (
                            self.attn_temp(norm_hidden_states) + hidden_states
                        )
                        hidden_states = rearrange(
                            hidden_states, "(b d) f c -> (b f) d c", d=d
                        )

                    return hidden_states
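
            # Feed-forward for the remaining paths ("write" mode and only_cross_attention).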
            norm_hidden_states = self.norm3(hidden_states)

            if self.use_ada_layer_norm_zero:
                norm_hidden_states = (
                    norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
                )

            ff_output = self.ff(norm_hidden_states)

            if self.use_ada_layer_norm_zero:
                ff_output = gate_mlp.unsqueeze(1) * ff_output

            hidden_states = ff_output + hidden_states

            return hidden_states
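
        # Collect the target transformer blocks (mid + up blocks or the full UNet), order them
        # from largest to smallest hidden size, and rebind their forward to the patched one.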
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            attn_modules = sorted(
                attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )

            for i, module in enumerate(attn_modules):
                module._original_inner_forward = module.forward
                if isinstance(module, BasicTransformerBlock):
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, BasicTransformerBlock
                    )
                if isinstance(module, TemporalBasicTransformerBlock):
                    module.forward = hacked_basic_transformer_inner_forward.__get__(
                        module, TemporalBasicTransformerBlock
                    )

                module.bank = []
                module.attn_weight = float(i) / float(len(attn_modules))
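
    # Copies the banks written by `writer` into this (read-mode) control's blocks; under
    # classifier-free guidance each bank entry is prefixed with zeros for the unconditional half.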
    def update(
        self,
        writer,
        do_classifier_free_guidance=True,
        dtype=torch.float16,
    ):
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, TemporalBasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in (
                        torch_dfs(writer.unet.mid_block)
                        + torch_dfs(writer.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, TemporalBasicTransformerBlock)
                ]
                writer_attn_modules = [
                    module
                    for module in torch_dfs(writer.unet)
                    if isinstance(module, BasicTransformerBlock)
                ]
            reader_attn_modules = sorted(
                reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            writer_attn_modules = sorted(
                writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            for r, w in zip(reader_attn_modules, writer_attn_modules):
                if do_classifier_free_guidance:
                    r.bank = [
                        torch.cat([torch.zeros_like(v), v]).to(dtype) for v in w.bank
                    ]
                else:
                    r.bank = [v.clone().to(dtype) for v in w.bank]
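
    # Empties the bank of every patched block.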
    def clear(self):
        if self.reference_attn:
            if self.fusion_blocks == "midup":
                reader_attn_modules = [
                    module
                    for module in (
                        torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
                    )
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            elif self.fusion_blocks == "full":
                reader_attn_modules = [
                    module
                    for module in torch_dfs(self.unet)
                    if isinstance(module, BasicTransformerBlock)
                    or isinstance(module, TemporalBasicTransformerBlock)
                ]
            reader_attn_modules = sorted(
                reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
            )
            for r in reader_attn_modules:
                r.bank.clear()
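
# A minimal usage sketch (hypothetical; `reference_unet`, `denoising_unet`, and the calls
# around them are assumptions about the surrounding pipeline, not part of this module):
#
#     writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full")
#     reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full")
#     reference_unet(...)        # "write" pass fills each patched block's bank
#     reader.update(writer)      # copy the banks into the denoising UNet's blocks
#     denoising_unet(...)        # "read" pass attends to the banked reference features
#     reader.clear(); writer.clear()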