from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.activations import get_activation
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput, UNet2DConditionModel
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers

from .layers._efficientnet_blocks import DepthwiseSeparableConv as DWConv2d
from .layers.unet_blocks.custom_down_blocks import DWMixTFDownBlock2D
from .layers.unet_blocks.custom_mid_blocks import DWMixTFMidBlock2D
from .layers.unet_blocks.custom_up_blocks import DWMixTFUpBlock2D
from .layers.λ.vanillaλ import MQSλ, MQCλ, DEFAULT_λ_CONFIG
from .unet_lambda_dwconv_blocks import (
    custom_get_down_block,
    custom_get_mid_block,
    custom_get_up_block
)
from .utils import CustomOutput
|
|
# Module-level logger, created through the `logging` helper imported from `diffusers.utils`.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
class UNet2DLambdaDWConvMixFFNConditionModel_prune_down_mid_up_block_8x8(UNet2DConditionModel):
    """Conditional 2D UNet assembled from the project's custom down/mid/up blocks.

    The class reuses the helper methods and config machinery inherited from
    ``UNet2DConditionModel`` but builds its own layers in
    :meth:`_init_unet2dcondmodel_blocks`: ``conv_in`` / ``conv_out`` are
    ``DWConv2d`` modules and the blocks are created by
    ``custom_get_down_block``, ``custom_get_mid_block`` and
    ``custom_get_up_block``. After construction, the attention modules inside
    every ``DWMixTF*`` block are replaced by lambda layers (``MQSλ`` / ``MQCλ``)
    through :meth:`inject_lambda_into_tf2dmodel`.

    :meth:`forward` additionally collects the output of every down, mid and up
    block and returns them in ``CustomOutput.block_outputs``.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = [
        "BasicTransformerBlock",
        "MixTransformerBlock",
        "ResnetBlock2D",
        "DWResnetBlock2D",
        "CrossAttnUpBlock2D",
        "DWTFUpBlock2D",
        "DWMixTFUpBlock2D",
    ]

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWDownBlock2D",
        ),
        mid_block_type: Optional[str] = "DWTFMidBlock2D",
        up_block_types: Tuple[str] = (
            "DWUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
        use_lambda_cross_attn=False,
        use_local_self_attn=True,
        num_embeddings=20,
        mix_mlp_ratio=2.5,
    ):
        """Build the UNet, swap in the lambda layers and re-initialise linear weights.

        All arguments up to ``addition_embed_type_num_heads`` are forwarded
        unchanged to :meth:`_init_unet2dcondmodel_blocks`.

        Args:
            use_lambda_cross_attn: When true, ``attn2`` of every transformer
                block inside a ``DWMixTF*`` block is replaced by ``MQCλ``.
            use_local_self_attn: When true, the ``MQSλ`` self-attention
                replacement is configured with ``r = 15``.
            num_embeddings: Read back from ``self.config`` when building the
                cross-attention lambda layer (``m = num_embeddings // 2``).
            mix_mlp_ratio: Passed to the block factories as ``mlp_ratio``.
        """
        self._init_unet2dcondmodel_blocks(
            sample_size,
            in_channels,
            out_channels,
            center_input_sample,
            flip_sin_to_cos,
            freq_shift,
            down_block_types,
            mid_block_type,
            up_block_types,
            only_cross_attention,
            block_out_channels,
            layers_per_block,
            downsample_padding,
            mid_block_scale_factor,
            dropout,
            act_fn,
            norm_num_groups,
            norm_eps,
            cross_attention_dim,
            transformer_layers_per_block,
            reverse_transformer_layers_per_block,
            encoder_hid_dim,
            encoder_hid_dim_type,
            attention_head_dim,
            num_attention_heads,
            dual_cross_attention,
            use_linear_projection,
            class_embed_type,
            addition_embed_type,
            addition_time_embed_dim,
            num_class_embeds,
            upcast_attention,
            resnet_time_scale_shift,
            resnet_skip_time_act,
            resnet_out_scale_factor,
            time_embedding_type,
            time_embedding_dim,
            time_embedding_act_fn,
            timestep_post_act,
            time_cond_proj_dim,
            conv_in_kernel,
            conv_out_kernel,
            projection_class_embeddings_input_dim,
            attention_type,
            class_embeddings_concat,
            mid_block_only_cross_attention,
            cross_attention_norm,
            addition_embed_type_num_heads,
            # Previously omitted, which silently pinned the blocks to the
            # helper's default ratio regardless of the configured value.
            mix_mlp_ratio=mix_mlp_ratio,
        )

        inject_kwargs = dict(
            use_lambda_cross_attn=use_lambda_cross_attn,
            use_local_self_attn=use_local_self_attn,
        )
        last_block_idx = len(block_out_channels) - 1

        # Spatial size handed to the lambda layers as `n`; it follows the
        # resolution through the network (halved after every non-final down
        # block, doubled after every non-final up block). It stays None when
        # no `sample_size` was configured.
        hw_size = self.sample_size

        for idx, block in enumerate(self.down_blocks):
            logger.debug("down block %d: %s", idx, type(block))
            if isinstance(block, DWMixTFDownBlock2D):
                self._inject_lambda_into_block(block, sample_size=hw_size, **inject_kwargs)
            if idx != last_block_idx and hw_size is not None:
                hw_size //= 2

        if isinstance(self.mid_block, DWMixTFMidBlock2D):
            self._inject_lambda_into_block(self.mid_block, sample_size=hw_size, **inject_kwargs)

        for idx, block in enumerate(self.up_blocks):
            if isinstance(block, DWMixTFUpBlock2D):
                self._inject_lambda_into_block(block, sample_size=hw_size, **inject_kwargs)
            if idx != last_block_idx and hw_size is not None:
                hw_size *= 2

        self.initialize_weights()

    def _inject_lambda_into_block(self, block, *, use_lambda_cross_attn, use_local_self_attn, sample_size):
        """Run :meth:`inject_lambda_into_tf2dmodel` on every entry of ``block.attentions``."""
        for tfmodel in block.attentions:
            self.inject_lambda_into_tf2dmodel(
                tfmodel,
                use_lambda_cross_attn=use_lambda_cross_attn,
                use_local_self_attn=use_local_self_attn,
                sample_size=sample_size,
            )

    def _init_unet2dcondmodel_blocks(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWMixTFDownBlock2D",
            "DWDownBlock2D",
        ),
        mid_block_type: Optional[str] = "DWMixTFMidBlock2D",
        up_block_types: Tuple[str] = (
            "DWUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D",
            "DWMixTFUpBlock2D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        dropout: float = 0.0,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads: int = 64,
        mix_mlp_ratio: float = 2.5
    ):
        """Create every sub-module of the UNet.

        Sets ``sample_size``, ``conv_in``, ``time_embedding``,
        ``time_embed_act``, ``down_blocks``, ``mid_block``, ``up_blocks``,
        ``num_upsamplers``, ``conv_norm_out``, ``conv_act`` and ``conv_out`` on
        ``self``; the inherited ``_set_*`` helpers are called for the time
        projection and the encoder-hidden / class / additional embeddings.

        Raises:
            ValueError: If ``num_attention_heads`` is not ``None``.
        """
        # Start the initialiser chain *after* UNet2DConditionModel in the MRO,
        # so that class's own __init__ (which would build the stock blocks) is
        # not executed.
        super(UNet2DConditionModel, self).__init__()
        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # `num_attention_heads` is always None at this point, so this falls
        # back to `attention_head_dim`.
        num_attention_heads = num_attention_heads or attention_head_dim

        # Validate the configuration.
        self._check_config(
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            only_cross_attention=only_cross_attention,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            cross_attention_dim=cross_attention_dim,
            transformer_layers_per_block=transformer_layers_per_block,
            reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
        )

        # Input convolution.
        self.conv_in = DWConv2d(in_channels, block_out_channels[0], dw_kernel_size=conv_in_kernel)

        # Time embedding.
        time_embed_dim, timestep_input_dim = self._set_time_proj(
            time_embedding_type,
            block_out_channels=block_out_channels,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            time_embedding_dim=time_embedding_dim,
        )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )

        self._set_encoder_hid_proj(
            encoder_hid_dim_type,
            cross_attention_dim=cross_attention_dim,
            encoder_hid_dim=encoder_hid_dim,
        )

        # Class embedding.
        self._set_class_embedding(
            class_embed_type,
            act_fn=act_fn,
            num_class_embeds=num_class_embeds,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            time_embed_dim=time_embed_dim,
            timestep_input_dim=timestep_input_dim,
        )

        self._set_add_embedding(
            addition_embed_type,
            addition_embed_type_num_heads=addition_embed_type_num_heads,
            addition_time_embed_dim=addition_time_embed_dim,
            cross_attention_dim=cross_attention_dim,
            encoder_hid_dim=encoder_hid_dim,
            flip_sin_to_cos=flip_sin_to_cos,
            freq_shift=freq_shift,
            projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
            time_embed_dim=time_embed_dim,
        )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Broadcast scalar settings to one entry per down block.
        num_blocks = len(down_block_types)

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * num_blocks

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * num_blocks

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * num_blocks

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * num_blocks

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * num_blocks

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_blocks

        if class_embeddings_concat:
            # forward() concatenates the class embedding onto the time
            # embedding in this mode, so the blocks see twice the width.
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim

        # Down blocks.
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = custom_get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
                mlp_ratio=mix_mlp_ratio,
            )
            self.down_blocks.append(down_block)

        # Mid block.
        self.mid_block = custom_get_mid_block(
            mid_block_type,
            temb_channels=blocks_time_embed_dim,
            in_channels=block_out_channels[-1],
            resnet_eps=norm_eps,
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            output_scale_factor=mid_block_scale_factor,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            mid_block_only_cross_attention=mid_block_only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,
            attention_type=attention_type,
            resnet_skip_time_act=resnet_skip_time_act,
            cross_attention_norm=cross_attention_norm,
            attention_head_dim=attention_head_dim[-1],
            dropout=dropout,
            mlp_ratio=mix_mlp_ratio,
        )

        # Number of up blocks that upsample; forward() derives the overall
        # upsampling factor from it.
        self.num_upsamplers = 0

        # Up blocks walk the per-block settings in reverse order.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = (
            list(reversed(transformer_layers_per_block))
            if reverse_transformer_layers_per_block is None
            else reverse_transformer_layers_per_block
        )
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # Every up block except the final one gets an upsampling layer.
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = custom_get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resolution_idx=i,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                # NOTE(review): indexes the un-reversed `attention_head_dim`,
                # unlike the other per-block settings above; kept as in the
                # original. Only matters when a per-block tuple is passed.
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                dropout=dropout,
                mlp_ratio=mix_mlp_ratio,
            )
            self.up_blocks.append(up_block)

        # Output normalisation / activation / convolution.
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        self.conv_out = DWConv2d(block_out_channels[0], out_channels, dw_kernel_size=conv_out_kernel)

    def initialize_weights(self):
        """Re-initialise every ``nn.Linear`` sub-module.

        Weights get Xavier-uniform initialisation; biases, when present, are
        set to zero. Other module types are left untouched.
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[CustomOutput, Tuple]:
        """Run the UNet and collect the output of every block.

        Args:
            sample: Input tensor; its last two dimensions are treated as the
                spatial dimensions.
            timestep: Timestep(s) passed to the inherited ``get_time_embed``.
            encoder_hidden_states: Conditioning passed to the blocks that
                report ``has_cross_attention``.
            class_labels: Optional labels for the class embedding.
            timestep_cond: Optional conditioning for ``time_embedding``.
            attention_mask: Optional mask, converted to an additive bias.
            cross_attention_kwargs: Optional kwargs forwarded to the attention
                blocks; the ``"scale"`` entry is popped and used as LoRA scale,
                a ``"gligen"`` entry is run through ``self.position_net``.
            added_cond_kwargs: Extra inputs for the additional embeddings.
            down_block_additional_residuals: ControlNet residuals added to the
                down-block skip connections (used together with
                ``mid_block_additional_residual``).
            mid_block_additional_residual: ControlNet residual added after the
                mid block.
            down_intrablock_additional_residuals: Adapter residuals consumed
                one per block with ``pop(0)``.
                NOTE(review): annotated as a tuple but ``pop`` requires a
                list; confirm against callers.
            encoder_attention_mask: Optional mask for ``encoder_hidden_states``,
                converted to an additive bias.
            return_dict: When false, return a plain ``(sample,)`` tuple.

        Returns:
            ``CustomOutput(sample=..., block_outputs=...)`` where
            ``block_outputs`` lists the hidden state after each down block,
            the mid block and each up block; or ``(sample,)`` when
            ``return_dict`` is false.
        """
        # The spatial dimensions should be a multiple of the overall
        # upsampling factor; otherwise the target size is passed explicitly to
        # the up blocks.
        default_overall_up_factor = 2**self.num_upsamplers
        forward_upsample_size = any(dim % default_overall_up_factor != 0 for dim in sample.shape[-2:])
        upsample_size = None

        # Per-block outputs returned through `CustomOutput.block_outputs`.
        intermediate_samples = []

        # Turn masks into additive biases (0 where the mask is 1, -10000.0
        # where it is 0) and add a singleton dimension at index 1.
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. Optionally rescale the input with 2 * x - 1.
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. Time, class and additional embeddings.
        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
        emb = self.time_embedding(t_emb, timestep_cond)

        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
        if class_emb is not None:
            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        aug_emb = self.get_aug_embed(
            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )
        if self.config.addition_embed_type == "image_hint":
            aug_emb, hint = aug_emb
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        encoder_hidden_states = self.process_encoder_hidden_states(
            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
        )

        # 2. Input convolution.
        sample = self.conv_in(sample)

        # 2.5 GLIGEN position net; work on a copy so the caller's dict is
        # left unmodified.
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. Down path. "scale" is removed from the (copied) kwargs before
        # they reach the blocks and applied to the LoRA layers instead.
        if cross_attention_kwargs is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            lora_scale = cross_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            scale_lora_layers(self, lora_scale)

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = down_intrablock_additional_residuals is not None
        # Backwards compatibility: adapter residuals used to be passed through
        # `down_block_additional_residuals`.
        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
            deprecate(
                "T2I should not use down_block_additional_residuals",
                "1.3.0",
                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated "
                "and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used "
                "for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
                standard_warn=False,
            )
            down_intrablock_additional_residuals = down_block_additional_residuals
            is_adapter = True

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                additional_residuals = {}
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                if is_adapter and len(down_intrablock_additional_residuals) > 0:
                    sample += down_intrablock_additional_residuals.pop(0)

            down_block_res_samples += res_samples
            intermediate_samples += [sample]

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. Mid block.
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = self.mid_block(sample, emb)

            # A remaining adapter residual is added only when its shape
            # matches the mid-block output.
            if (
                is_adapter
                and len(down_intrablock_additional_residuals) > 0
                and sample.shape == down_intrablock_additional_residuals[0].shape
            ):
                sample += down_intrablock_additional_residuals.pop(0)
            # NOTE(review): the mid-block output is recorded whenever a mid
            # block exists (the original `else: pass` is read as belonging to
            # the `mid_block is not None` check) - confirm.
            intermediate_samples += [sample]

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. Up path.
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Each up block consumes as many skip connections as it has resnets.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # When the input size is not a multiple of the upsampling factor,
            # upsample to the size of the next skip connection.
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                )

            intermediate_samples += [sample]

        # 6. Output head.
        if self.conv_norm_out is not None:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if USE_PEFT_BACKEND:
            # Undo the LoRA scaling applied before the down path.
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (sample,)

        return CustomOutput(sample=sample, block_outputs=intermediate_samples)

    def inject_lambda_into_tf2dmodel(self, tf2dmodel,
                                     use_lambda_cross_attn=False, use_local_self_attn=True, sample_size=None):
        """Replace the attention modules of ``tf2dmodel`` with lambda layers.

        For every entry of ``tf2dmodel.transformer_blocks``, ``attn1`` is
        replaced by a new ``MQSλ`` and, when requested, ``attn2`` by a new
        ``MQCλ``. Layer settings start from ``DEFAULT_λ_CONFIG`` and are
        overridden with values read from ``tf2dmodel.config``
        (``in_channels``, ``attention_head_dim``, ``num_attention_heads``,
        ``out_channels`` and, for cross-attention, ``cross_attention_dim``).

        Args:
            tf2dmodel: Module exposing ``config`` and ``transformer_blocks``;
                modified in place.
            use_lambda_cross_attn: Also replace ``attn2`` with ``MQCλ``.
            use_local_self_attn: Add ``r = 15`` to the ``MQSλ`` settings.
            sample_size: Passed to the lambda layers as ``n``.
        """
        tf_cfg = tf2dmodel.config

        base_config = DEFAULT_λ_CONFIG | dict(
            n=sample_size,
            dim=tf_cfg.in_channels,
            dim_k=tf_cfg.attention_head_dim,
            heads=tf_cfg.num_attention_heads,
            dim_out=tf_cfg.out_channels,
        )

        for tfblock in tf2dmodel.transformer_blocks:
            # Deep copy so per-block tweaks never leak into `base_config`.
            sa_config = deepcopy(base_config)
            if use_local_self_attn:
                sa_config |= dict(r=15)
            # NOTE(review): self-attention is replaced for every block; only
            # the `r` setting depends on `use_local_self_attn` - confirm.
            tfblock.attn1 = MQSλ(**sa_config)

            if use_lambda_cross_attn:
                ca_config = base_config | dict(
                    m=self.config.num_embeddings // 2,
                    dim_cross=tf_cfg.cross_attention_dim,
                )
                tfblock.attn2 = MQCλ(**ca_config)
|
|
|
|