from dataclasses import dataclass
from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
# from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput, UNet2DConditionModel
# from diffusers.models.unets.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, UNetMidBlock2DCrossAttn
# from diffusers.models.resnet import ResnetBlock2D, Downsample2D
from diffusers.models.activations import get_activation
from diffusers.models.embeddings import TimestepEmbedding

from .layers.λ.vanillaλ import MQSλ, MQCλ, DEFAULT_λ_CONFIG
from .unet_lambda_dwconv_blocks import (
    custom_get_down_block,
    custom_get_mid_block,
    custom_get_up_block
)
from .layers.unet_blocks.custom_down_blocks import DWMixTFDownBlock2D
from .layers.unet_blocks.custom_mid_blocks import DWMixTFMidBlock2D
from .layers.unet_blocks.custom_up_blocks import DWMixTFUpBlock2D
from .utils import CustomOutput
from .layers._efficientnet_blocks import DepthwiseSeparableConv as DWConv2d


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class UNet2DLambdaDWConvMixFFNConditionModel_prune_down_mid_up_block_8x8(UNet2DConditionModel):
    """Conditional 2D UNet built from custom DW/Mix blocks with lambda attention injected.

    The model reuses the helper methods and forward-pass utilities of
    ``UNet2DConditionModel`` but builds its own sub-modules:

    * ``conv_in`` / ``conv_out`` are ``DWConv2d`` (``DepthwiseSeparableConv``) layers;
    * down / mid / up blocks come from the ``custom_get_*_block`` factories;
    * after construction, ``attn1`` (and optionally ``attn2``) of every transformer
      block inside ``DWMixTFDownBlock2D`` / ``DWMixTFMidBlock2D`` / ``DWMixTFUpBlock2D``
      is replaced by ``MQSλ`` / ``MQCλ`` layers.

    ``forward`` returns a ``CustomOutput`` that carries, besides the predicted sample,
    the output of each down block, the mid block and each up block.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = [
        "BasicTransformerBlock",
        "MixTransformerBlock",
        "ResnetBlock2D",
        "DWResnetBlock2D",
        "CrossAttnUpBlock2D",
        "DWTFUpBlock2D",
        "DWMixTFUpBlock2D",
    ]

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWDownBlock2D",
        ),
        mid_block_type: Optional[str] = "DWTFMidBlock2D",
        up_block_types: Tuple[str] = (
            "DWUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
        use_lambda_cross_attn=False,
        use_local_self_attn=True,
        num_embeddings=20,
        mix_mlp_ratio=2.5,
    ):
        """Build the UNet and swap the attention layers of the Mix-TF blocks for lambda layers.

        All arguments up to ``addition_embed_type_num_heads`` (plus ``mix_mlp_ratio``) are
        forwarded to :meth:`_init_unet2dcondmodel_blocks`.

        Args:
            sample_size: Spatial size used to track the feature-map resolution handed to the
                lambda layers. It is floor-divided / multiplied by 2 while walking the blocks,
                so ``None`` raises ``TypeError`` as soon as more than one resolution level
                is configured.
            use_lambda_cross_attn: When true, ``attn2`` of each transformer block is replaced
                by an ``MQCλ`` layer as well.
            use_local_self_attn: When true, the self-attention lambda layer is built with a
                local context (``r=15``).
            num_embeddings: Half of this value (integer division) is passed as ``m`` to the
                cross-attention lambda layer.
            mix_mlp_ratio: Forwarded as ``mlp_ratio`` to the custom block factories.
        """
        # NOTE: `mix_mlp_ratio` used to be dropped here, so the blocks were always built with
        # the helper's default (2.5) regardless of the configured value. It is now forwarded;
        # keyword arguments are used so the long argument list cannot silently drift.
        self._init_unet2dcondmodel_blocks(
            sample_size=sample_size,
            in_channels=in_channels,
            out_channels=out_channels,
            center_input_sample=center_input_sample,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            down_block_types=down_block_types,
            mid_block_type=mid_block_type,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            downsample_padding=downsample_padding,
            mid_block_scale_factor=mid_block_scale_factor,
            dropout=dropout,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_eps=norm_eps,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            encoder_hid_dim=encoder_hid_dim,
            encoder_hid_dim_type=encoder_hid_dim_type,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            class_embed_type=class_embed_type,
            addition_embed_type=addition_embed_type,
            addition_time_embed_dim=addition_time_embed_dim,
            num_class_embeds=num_class_embeds,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            resnet_skip_time_act=resnet_skip_time_act,
            resnet_out_scale_factor=resnet_out_scale_factor,
            time_embedding_type=time_embedding_type,
            time_embedding_dim=time_embedding_dim,
            time_embedding_act_fn=time_embedding_act_fn,
            timestep_post_act=timestep_post_act,
            time_cond_proj_dim=time_cond_proj_dim,
            conv_in_kernel=conv_in_kernel,
            conv_out_kernel=conv_out_kernel,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            attention_type=attention_type,
            class_embeddings_concat=class_embeddings_concat,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
            mix_mlp_ratio=mix_mlp_ratio,
        )

        # Inject lambda layers into every Transformer2D model of the Mix-TF blocks while
        # tracking the spatial size each block operates on.
        last_level = len(block_out_channels) - 1
        cur_hw_size = self.sample_size

        for i, cur_block in enumerate(self.down_blocks):
            logger.debug("%s", type(cur_block))
            if isinstance(cur_block, DWMixTFDownBlock2D):
                self._inject_lambda_into_block(
                    cur_block, cur_hw_size, use_lambda_cross_attn, use_local_self_attn
                )
            # Every down level except the last one halves the resolution, whether or not
            # the block itself carries attention layers.
            if i != last_level:
                cur_hw_size //= 2

        if isinstance(self.mid_block, DWMixTFMidBlock2D):
            self._inject_lambda_into_block(
                self.mid_block, cur_hw_size, use_lambda_cross_attn, use_local_self_attn
            )

        for i, cur_block in enumerate(self.up_blocks):
            if isinstance(cur_block, DWMixTFUpBlock2D):
                self._inject_lambda_into_block(
                    cur_block, cur_hw_size, use_lambda_cross_attn, use_local_self_attn
                )
            # Mirror of the down path: every up level except the last doubles the resolution.
            if i != last_level:
                cur_hw_size *= 2

        self.initialize_weights()

    def _inject_lambda_into_block(self, block, sample_size, use_lambda_cross_attn, use_local_self_attn):
        """Run :meth:`inject_lambda_into_tf2dmodel` on every entry of ``block.attentions``."""
        for tfmodel in block.attentions:
            self.inject_lambda_into_tf2dmodel(
                tfmodel,
                use_lambda_cross_attn=use_lambda_cross_attn,
                use_local_self_attn=use_local_self_attn,
                sample_size=sample_size,
            )

    def _init_unet2dcondmodel_blocks(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWDownBlock2D",
        ),
        mid_block_type: Optional[str] = "DWMixTFMidBlock2D",
        up_block_types: Tuple[str] = (
            "DWUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
        mix_mlp_ratio: float = 2.5
    ):
        """Create all sub-modules of the UNet (embeddings, down/mid/up blocks, in/out convs).

        This initialises the grandparent class (``super(UNet2DConditionModel, self)``), so
        ``UNet2DConditionModel.__init__`` itself is deliberately not run; the modules are
        instead built here with the custom block factories.

        Raises:
            ValueError: If ``num_attention_heads`` is not ``None``.
        """
        super(UNet2DConditionModel, self).__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too
        # backwards breaking which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Check inputs
        self._check_config(
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
        )

        # input
        self.conv_in = DWConv2d(in_channels, block_out_channels[0], dw_kernel_size=conv_in_kernel)

        # time
        time_embed_dim, timestep_input_dim = self._set_time_proj(
            time_embedding_type,
            block_out_channels=block_out_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            time_embedding_dim=time_embedding_dim,
        )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        self._set_encoder_hid_proj(
            encoder_hid_dim_type,
            cross_attention_dim=cross_attention_dim,
            encoder_hid_dim=encoder_hid_dim,
        )

        # class embedding
        self._set_class_embedding(
            class_embed_type,
            act_fn=act_fn,
            num_class_embeds=num_class_embeds,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            time_embed_dim=time_embed_dim,
            timestep_input_dim=timestep_input_dim,
        )

        self._set_add_embedding(
            addition_embed_type,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
            addition_time_embed_dim=addition_time_embed_dim,
            cross_attention_dim=cross_attention_dim,
            encoder_hid_dim=encoder_hid_dim,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            time_embed_dim=time_embed_dim,
        )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one entry per down block.
        num_levels = len(down_block_types)

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * num_levels

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * num_levels

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * num_levels

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * num_levels

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * num_levels

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_levels

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = custom_get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
                mlp_ratio=mix_mlp_ratio
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = custom_get_mid_block(
            mid_block_type,
            temb_channels=blocks_time_embed_dim,
            in_channels=block_out_channels[-1],
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            output_scale_factor=mid_block_scale_factor,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[-1],
            dropout=dropout,
            mlp_ratio=mix_mlp_ratio
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = (
            list(reversed(transformer_layers_per_block))
            if reverse_transformer_layers_per_block is None
            else reverse_transformer_layers_per_block
        )
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            add_upsample = not is_final_block
            if add_upsample:
                self.num_upsamplers += 1

            up_block = custom_get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resolution_idx=i,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                # NOTE: indexed with `i` on the non-reversed sequence, as in the original code.
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
                mlp_ratio=mix_mlp_ratio
            )
            self.up_blocks.append(up_block)

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        self.conv_out = DWConv2d(block_out_channels[0], out_channels, dw_kernel_size=conv_out_kernel)

    def initialize_weights(self):
        """Re-initialise every ``nn.Linear``: Xavier-uniform weights and zero biases."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet2DConditionOutput, Tuple]:
        """Run the UNet and collect the output of every down / mid / up block.

        Args:
            sample: Input tensor; its last two dimensions are treated as the spatial size.
            timestep: Timestep(s) passed to ``get_time_embed``.
            encoder_hidden_states: Conditioning passed to the cross-attention blocks.
            attention_mask / encoder_attention_mask: Masks with ``1 = keep, 0 = discard``;
                converted to additive biases (``0`` / ``-10000.0``).
            cross_attention_kwargs: Extra attention kwargs; ``"scale"`` is popped and used
                as the LoRA scale, ``"gligen"`` is routed through ``self.position_net``.
            down_block_additional_residuals / mid_block_additional_residual: ControlNet
                residuals (both must be given to enable the ControlNet path).
            down_intrablock_additional_residuals: T2I-Adapter residuals, consumed with
                ``pop(0)``.
            return_dict: When false, a 1-tuple ``(sample,)`` is returned.

        Returns:
            ``CustomOutput(sample=..., block_outputs=[...])`` or ``(sample,)``.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = any(dim % default_overall_up_factor != 0 for dim in sample.shape[-2:])
        upsample_size = None

        # recording each block's output sample, for REPA & featKD
        intermediate_samples = []

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch,                    1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )
        if self.config.addition_embed_type == "image_hint":
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        encoder_hidden_states = self.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        # 2. pre-process
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net
        # NOTE(review): nothing in this class assigns `self.position_net`; confirm it is
        # provided elsewhere before relying on the "gligen" path.
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
        # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
        if cross_attention_kwargs is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
        is_adapter = down_intrablock_additional_residuals is not None
        # maintain backward compatibility for legacy usage, where
        #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
        #       but can only use one or the other
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
                "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
                "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    sample += down_intrablock_additional_residuals.pop(0)

            down_block_res_samples += res_samples
            intermediate_samples.append(sample)

        if is_controlnet:
            # Add the ControlNet residuals to the skip connections, pairwise.
            down_block_res_samples = tuple(
                down_block_res_sample + down_block_additional_residual
                for down_block_res_sample, down_block_additional_residual in zip(
                    down_block_res_samples, down_block_additional_residuals
                )
            )

        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)

            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_intrablock_additional_residuals) > 0
                and sample.shape == down_intrablock_additional_residuals[0].shape
            ):
                sample += down_intrablock_additional_residuals.pop(0)

            intermediate_samples.append(sample)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block consumes one skip connection per resnet, taken from the tail.
            num_skips = len(upsample_block.resnets)
            res_samples = down_block_res_samples[-num_skips:]
            down_block_res_samples = down_block_res_samples[:-num_skips]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                )

            intermediate_samples.append(sample)

        # 6. post-process
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (sample,)

        return CustomOutput(sample=sample, block_outputs=intermediate_samples)

    def inject_lambda_into_tf2dmodel(self, tf2dmodel, use_lambda_cross_attn=False, use_local_self_attn=True, sample_size=None):
        """Replace the attention layers of ``tf2dmodel`` with lambda layers, in place.

        For every entry of ``tf2dmodel.transformer_blocks``, ``attn1`` is replaced by an
        ``MQSλ`` layer and, when ``use_lambda_cross_attn`` is true, ``attn2`` by an ``MQCλ``
        layer. The layer configuration is ``DEFAULT_λ_CONFIG`` overridden with values read
        from ``tf2dmodel.config``.

        Args:
            tf2dmodel: Model exposing ``config`` and ``transformer_blocks``.
            use_lambda_cross_attn: Also replace the cross-attention layer (``attn2``).
            use_local_self_attn: Add ``r=15`` to the self-attention lambda configuration.
            sample_size: Passed as ``n`` to the lambda layers.
        """
        vanilla_tf_cfg = tf2dmodel.config
        new_config = DEFAULT_λ_CONFIG | dict(
            n=sample_size,  # or vanilla_tf_cfg.sample_size or self.sample_size,
            dim=vanilla_tf_cfg.in_channels,
            dim_k=vanilla_tf_cfg.attention_head_dim,
            heads=vanilla_tf_cfg.num_attention_heads,
            dim_out=vanilla_tf_cfg.out_channels,
        )

        # Example of a vanilla transformer config, kept for reference (this used to be a
        # bare list literal that was built and discarded on every call):
        #   num_attention_heads=8, attention_head_dim=40, in_channels=320, out_channels=None,
        #   num_layers=1, dropout=0.0, norm_num_groups=32, cross_attention_dim=768,
        #   attention_bias=False, sample_size=None, num_vector_embeds=None, patch_size=None,
        #   activation_fn='geglu', num_embeds_ada_norm=None, use_linear_projection=False,
        #   only_cross_attention=False, double_self_attention=False, upcast_attention=False,
        #   norm_type='layer_norm', norm_elementwise_affine=True, norm_eps=1e-05,
        #   attention_type='default', caption_channels=None, interpolation_scale=None,
        #   use_additional_conditions=None

        for tfblock in tf2dmodel.transformer_blocks:
            # Each block gets its own copy so the per-block configs stay independent.
            new_SA_config = deepcopy(new_config)
            if use_local_self_attn:
                new_SA_config |= dict(r=15)
            tfblock.attn1 = MQSλ(**new_SA_config)

            if use_lambda_cross_attn:
                new_CA_config = new_config | dict(
                    m=self.config.num_embeddings // 2,
                    dim_cross=vanilla_tf_cfg.cross_attention_dim,
                )
                tfblock.attn2 = MQCλ(**new_CA_config)