"""Reference-attention hooks for the reference unet and the denoising unet.

This implementation refers to Multi-ControlNet, thanks to the authors.
Paper: Adding Conditional Control to Text-to-Image Diffusion Models
Link: https://github.com/Mikubill/sd-webui-controlnet
"""

import torch
import torch.nn as nn

from refnet.modules.transformer import BasicTransformerBlock, SelfInjectedTransformerBlock
from refnet.util import checkpoint_wrapper


def exists(v):
    """Return True when *v* is not None."""
    return v is not None


def torch_dfs(model: nn.Module):
    """Return *model* followed by all of its descendants, depth-first, parents first."""
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


class AutoMachine():
    """Names of the two phases used by ``UnetHook`` self-injection."""

    Read = "read"
    Write = "write"


class ReferenceAttentionControl:
    """Control the attentions of the reference unet and the denoising unet.

    Transformer blocks of the writer unet store their normalised self-attention
    input in ``module.bank``; ``update()`` copies those banks to the reader
    blocks, whose self-attention then attends over ``[x, bank]``.

    Args:
        reader_module: unet whose transformer blocks consume the banks.
        writer_module: unet whose transformer blocks produce the banks.
        time_embed_ch: when > 0, an ``nn.Linear(time_embed_ch, module.dim)`` is
            attached to every reader block as ``time_proj`` and collected in
            ``trainable_layers``.
        only_decoder: when True only ``unet.output_blocks`` is traversed,
            otherwise the whole unet.
    """

    # Kept for backward compatibility with code that reads these names from the
    # class itself. Each instance gets its own lists in __init__, so modules
    # registered by one controller no longer leak into every other controller.
    writer_modules = []
    reader_modules = []

    def __init__(
            self,
            reader_module,
            writer_module,
            time_embed_ch=0,
            only_decoder=True,
            *args,
            **kwargs
    ):
        self.time_embed_ch = time_embed_ch
        self.trainable_layers = []
        self.only_decoder = only_decoder
        self.hooked = False
        # Per-instance registries (previously shared, mutable class attributes).
        self.writer_modules = []
        self.reader_modules = []

        self.register("read", reader_module)
        self.register("write", writer_module)

        if time_embed_ch > 0:
            self.insert_time_emb_proj(reader_module)

    def insert_time_emb_proj(self, unet):
        """Attach a ``time_proj`` linear layer to every transformer block of *unet*."""
        for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
            if isinstance(module, BasicTransformerBlock):
                module.time_proj = nn.Linear(self.time_embed_ch, module.dim)
                self.trainable_layers.append(module.time_proj)

    def register(self, mode, unet):
        """Hook the transformer blocks of *unet* for ``"read"`` or ``"write"`` mode."""

        @checkpoint_wrapper
        def transformer_forward_write(self, x, context=None, mask=None, emb=None, **kwargs):
            x_in = self.norm1(x)
            x = self.attn1(x_in) + x
            if not self.disable_cross_attn:
                x = self.attn2(self.norm2(x), context) + x
            x = self.ff(self.norm3(x)) + x
            # Remember the self-attention input so a reader block can attend to it.
            self.bank = x_in
            return x

        @checkpoint_wrapper
        def transformer_forward_read(self, x, context=None, mask=None, emb=None, **kwargs):
            # A hooked block may run before update() ever assigned a bank;
            # treat a missing attribute like an empty bank.
            bank = getattr(self, "bank", None)
            if exists(bank):
                if bank.shape[0] != x.shape[0]:
                    bank = bank.repeat(x.shape[0], 1, 1)
                if hasattr(self, "time_proj"):
                    bank = bank + self.time_proj(emb).unsqueeze(1)
                x_in = self.norm1(x)
                x = self.attn1(
                    x=x_in,
                    context=torch.cat([x_in, bank], 1),
                    mask=mask,
                    scale_factor=self.scale_factor,
                    **kwargs
                ) + x
                x = self.attn2(
                    x=self.norm2(x),
                    context=context,
                    mask=mask,
                    scale=self.reference_scale,
                    scale_factor=self.scale_factor
                ) + x
                x = self.ff(self.norm3(x)) + x
            else:
                x = self.original_forward(x, context, mask, emb)
            return x

        assert mode in ["write", "read"]

        if mode == "read":
            self.hooked = True

        for module in torch_dfs(unet.output_blocks if self.only_decoder else unet):
            if not isinstance(module, BasicTransformerBlock):
                continue

            if mode == "write":
                module.original_forward = module.forward
                module.forward = transformer_forward_write.__get__(module, BasicTransformerBlock)
                self.writer_modules.append(module)
            else:
                if not isinstance(module, SelfInjectedTransformerBlock):
                    print(f"Hooking transformer block {module.__class__.__name__} for read mode")
                    module.original_forward = module.forward
                    module.forward = transformer_forward_read.__get__(module, BasicTransformerBlock)
                # NOTE(review): SelfInjectedTransformerBlock instances are not
                # re-hooked but are still registered as readers so update() can
                # hand them a bank - confirm this nesting against the upstream file.
                self.reader_modules.append(module)

    @staticmethod
    def _restore_forward(module):
        """Put back the forward saved at hook time, if this module was hooked."""
        if hasattr(module, "original_forward"):
            module.forward = module.original_forward

    def update(self):
        """Copy the bank of every writer block to the reader block at the same index.

        Raises:
            IndexError: if there are fewer reader blocks than writer blocks.
        """
        if len(self.reader_modules) < len(self.writer_modules):
            raise IndexError(
                f"{len(self.writer_modules)} writer blocks but only "
                f"{len(self.reader_modules)} reader blocks are registered"
            )
        for reader, writer in zip(self.reader_modules, self.writer_modules):
            reader.bank = writer.bank

    def restore(self):
        """Undo the hooks on writer and reader blocks and drop their banks."""
        # The two lists are handled independently: they may differ in length,
        # and readers that were never hooked carry no ``original_forward``.
        for writer in self.writer_modules:
            self._restore_forward(writer)
            writer.bank = None  # release the stored reference activations
        for reader in self.reader_modules:
            self._restore_forward(reader)
            reader.bank = None
        self.hooked = False

    def clean(self):
        """Drop all banks on reader and writer blocks; forwards are left as they are."""
        for reader in self.reader_modules:
            reader.bank = None
        for writer in self.writer_modules:
            writer.bank = None
        self.hooked = False

    def reader_restore(self):
        """Undo the hooks on the reader blocks only and drop their banks."""
        for reader in self.reader_modules:
            self._restore_forward(reader)
            reader.bank = None
        self.hooked = False

    def get_trainable_layers(self):
        """Return the ``time_proj`` layers created by ``insert_time_emb_proj``."""
        return self.trainable_layers


class UnetHook:
    """Self-injection inside the denoising unet."""

    def __init__(self):
        super().__init__()
        self.attention_auto_machine = AutoMachine.Read

    def enhance_reference(
            self,
            model,
            ldm,
            bs,
            s,
            r,
            style_cfg=0.5,
            control_cfg=0,
            gr_indice=None,
            injection=False,
            start_step=0,
    ):
        """Hook ``model.diffusion_model`` so each step first runs a write pass on the reference.

        Args:
            model: object exposing the unet as ``model.diffusion_model``.
            ldm: provides ``control_encoder``, ``add_noise`` and ``num_timesteps``.
            bs: repeat count applied to the first dimension of every encoded control.
            s: input passed to ``ldm.control_encoder``.
            r: reference that is noised with ``ldm.add_noise`` for the write pass.
            style_cfg: weight blending the bank-free and bank-conditioned
                self-attention outputs in the read pass.
            control_cfg: multiplier applied to the encoded controls of the write pass.
            gr_indice: batch indices whose self-attention is recomputed without
                the bank; ``None`` means no such indices.
            injection: stored on the hook as ``self.injection``.
            start_step: the write pass runs only while
                ``1 - t[0] / (ldm.num_timesteps - 1) >= start_step``.
        """

        def forward(self, x, t, control, context, **kwargs):
            if 1 - t[0] / (ldm.num_timesteps - 1) >= outer.start_step:
                # Write
                outer.attention_auto_machine = AutoMachine.Write
                rx = ldm.add_noise(outer.r.cpu(), torch.round(t.float()).long().cpu())
                # Follow the device of x instead of assuming the current CUDA device.
                rx = rx.to(device=x.device, dtype=x.dtype)
                self.original_forward(rx, t, control=outer.s, context=context, **kwargs)

                # Read
                outer.attention_auto_machine = AutoMachine.Read
            return self.original_forward(x, t, control=control, context=context, **kwargs)

        def hacked_basic_transformer_inner_forward(self, x, context=None, mask=None, emb=None, **kwargs):
            x_norm1 = self.norm1(x)
            self_attn1 = None
            if self.disable_self_attn:
                # Do not use self-attention
                self_attn1 = self.attn1(x_norm1, context=context, **kwargs)
            else:
                # Use self-attention
                self_attention_context = x_norm1
                if outer.attention_auto_machine == AutoMachine.Write:
                    self.bank.append(self_attention_context.detach().clone())
                    self.style_cfgs.append(outer.current_style_fidelity)
                if outer.attention_auto_machine == AutoMachine.Read:
                    if len(self.bank) > 0:
                        style_cfg = sum(self.style_cfgs) / float(len(self.style_cfgs))
                        self_attn1_uc = self.attn1(
                            x_norm1,
                            context=torch.cat([self_attention_context] + self.bank, dim=1),
                            **kwargs
                        )
                        self_attn1_c = self_attn1_uc.clone()
                        if len(outer.current_uc_indices) > 0 and style_cfg > 1e-5:
                            self_attn1_c[outer.current_uc_indices] = self.attn1(
                                x_norm1[outer.current_uc_indices],
                                context=self_attention_context[outer.current_uc_indices],
                                **kwargs
                            )
                        self_attn1 = style_cfg * self_attn1_c + (1.0 - style_cfg) * self_attn1_uc
                    # The bank is consumed once per read pass.
                    self.bank = []
                    self.style_cfgs = []
                if self_attn1 is None:
                    self_attn1 = self.attn1(x_norm1, context=self_attention_context)

            x = self_attn1.to(x.dtype) + x
            x = self.attn2(self.norm2(x), context, mask, self.reference_scale, self.scale_factor, **kwargs) + x
            x = self.ff(self.norm3(x)) + x
            return x

        self.s = [feat.repeat(bs, 1, 1, 1) * control_cfg for feat in ldm.control_encoder(s)]
        self.r = r
        self.injection = injection
        self.start_step = start_step
        # The read pass calls len() on this value, so the None default is
        # normalised to an empty list.
        self.current_uc_indices = gr_indice if gr_indice is not None else []
        self.current_style_fidelity = style_cfg

        outer = self
        model = model.diffusion_model
        model.original_forward = model.forward
        # TODO: change the class name to target
        model.forward = forward.__get__(model, model.__class__)

        all_modules = torch_dfs(model)
        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                module._unet_hook_original_forward = module.forward
                module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock)
                module.bank = []
                module.style_cfgs = []

    def restore(self, model):
        """Remove every hook and per-block state installed by ``enhance_reference``."""
        model = model.diffusion_model
        if hasattr(model, "original_forward"):
            model.forward = model.original_forward
            del model.original_forward

        all_modules = torch_dfs(model)
        for module in all_modules:
            if isinstance(module, BasicTransformerBlock):
                if hasattr(module, "_unet_hook_original_forward"):
                    module.forward = module._unet_hook_original_forward
                    del module._unet_hook_original_forward
                if hasattr(module, "bank"):
                    module.bank = None
                if hasattr(module, "style_cfgs"):
                    del module.style_cfgs