| from dataclasses import dataclass |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| import torch |
|
|
| from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput |
|
|
|
|
@dataclass
class CustomOutput(UNet2DConditionOutput):
    """
    Output of [`UNet2DConditionModel`] extended with two optional extra fields.

    Subclasses diffusers' `UNet2DConditionOutput` and adds `block_outputs` and
    `cross_attention_maps`. Every field defaults to `None`, so code that only
    fills in `sample` keeps working unchanged.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        block_outputs (`List[torch.Tensor]`, *optional*):
            Tensors collected per block by whatever forward pass builds this object.
            NOTE(review): presumably one entry per UNet block; ordering and shapes
            are decided by the producer and are not visible here — confirm.
        cross_attention_maps (`List[torch.Tensor]`, *optional*):
            Cross-attention tensors collected by the producing forward pass.
            NOTE(review): layout (layer order, head/batch axes) is set by the
            producer and is not visible here — confirm.
    """

    # Redeclares the inherited `sample` field; dataclasses keep a redefined
    # field at its original (inherited) position, so it stays first.
    sample: Optional[torch.Tensor] = None
    # `None` (rather than an empty list) means "not collected".
    block_outputs: Optional[List[torch.Tensor]] = None
    cross_attention_maps: Optional[List[torch.Tensor]] = None
|
|
|
|
|
|
|
|
# Shared keyword arguments for model construction.
# NOTE(review): the keys appear to mirror `UNet2DConditionModel.__init__`
# parameters (presumably used as `UNet2DConditionModel(**common_args)`) —
# confirm at call sites.
# This is a plain, mutable module-level dict: derive variants with a copy,
# e.g. `{**common_args, "in_channels": 8}`, instead of mutating it in place.
# Keyword order below matches the original literal, so iteration order is
# unchanged.
common_args = dict(
    sample_size=None,
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    flip_sin_to_cos=True,
    freq_shift=0,
    # Block layout.
    down_block_types=(
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    mid_block_type="UNetMidBlock2DCrossAttn",
    up_block_types=(
        "UpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
    ),
    only_cross_attention=False,
    block_out_channels=(320, 640, 1280, 1280),
    layers_per_block=2,
    downsample_padding=1,
    mid_block_scale_factor=1,
    dropout=0.0,
    act_fn="silu",
    # Normalization.
    norm_num_groups=32,
    norm_eps=1e-5,
    # Attention / conditioning.
    cross_attention_dim=1280,
    transformer_layers_per_block=1,
    reverse_transformer_layers_per_block=None,
    encoder_hid_dim=None,
    encoder_hid_dim_type=None,
    attention_head_dim=8,
    num_attention_heads=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    # Class / additional embeddings.
    class_embed_type=None,
    addition_embed_type=None,
    addition_time_embed_dim=None,
    num_class_embeds=None,
    upcast_attention=False,
    # ResNet blocks.
    resnet_time_scale_shift="default",
    resnet_skip_time_act=False,
    resnet_out_scale_factor=1.0,
    # Time embedding.
    time_embedding_type="positional",
    time_embedding_dim=None,
    time_embedding_act_fn=None,
    timestep_post_act=None,
    time_cond_proj_dim=None,
    # Convolution kernels.
    conv_in_kernel=3,
    conv_out_kernel=3,
    projection_class_embeddings_input_dim=None,
    attention_type="default",
    class_embeddings_concat=False,
    mid_block_only_cross_attention=None,
    cross_attention_norm=None,
    addition_embed_type_num_heads=64,
)
|
|