# File: moebius/model_lib/nets/utils.py
# Commit: "Implement Moebius Gradio Space" by Mike0021 (166ab04, verified)
# Size: 2.48 kB
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
@dataclass
class CustomOutput(UNet2DConditionOutput):
    """
    Output container that extends [`UNet2DConditionOutput`] with two extra fields.

    Every field defaults to `None`.

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        block_outputs (`List[torch.Tensor]`, *optional*):
            Presumably the intermediate outputs collected from the individual UNet blocks -- TODO confirm
            against the model that populates this field.
        cross_attention_maps (`List[torch.Tensor]`, *optional*):
            Presumably the cross-attention maps gathered during the forward pass -- TODO confirm against the
            model that populates this field.
    """

    # `Optional[...]` spells out that `None` is a legal value; an annotation of
    # `List[torch.Tensor] = None` relies on implicit Optional, which PEP 484 disallows.
    sample: Optional[torch.Tensor] = None
    block_outputs: Optional[List[torch.Tensor]] = None
    cross_attention_maps: Optional[List[torch.Tensor]] = None
common_args = {
"sample_size": None,
"in_channels": 4,
"out_channels": 4,
"center_input_sample": False,
"flip_sin_to_cos": True,
"freq_shift": 0,
"down_block_types": (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
"mid_block_type": "UNetMidBlock2DCrossAttn",
"up_block_types": (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
"only_cross_attention": False,
"block_out_channels": (320, 640, 1280, 1280),
"layers_per_block": 2,
"downsample_padding": 1,
"mid_block_scale_factor": 1,
"dropout": 0.0,
"act_fn": "silu",
"norm_num_groups": 32,
"norm_eps": 1e-5,
"cross_attention_dim": 1280,
"transformer_layers_per_block": 1,
"reverse_transformer_layers_per_block": None,
"encoder_hid_dim": None,
"encoder_hid_dim_type": None,
"attention_head_dim": 8,
"num_attention_heads": None,
"dual_cross_attention": False,
"use_linear_projection": False,
"class_embed_type": None,
"addition_embed_type": None,
"addition_time_embed_dim": None,
"num_class_embeds": None,
"upcast_attention": False,
"resnet_time_scale_shift": "default",
"resnet_skip_time_act": False,
"resnet_out_scale_factor": 1.0,
"time_embedding_type": "positional",
"time_embedding_dim": None,
"time_embedding_act_fn": None,
"timestep_post_act": None,
"time_cond_proj_dim": None,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"projection_class_embeddings_input_dim": None,
"attention_type": "default",
"class_embeddings_concat": False,
"mid_block_only_cross_attention": None,
"cross_attention_norm": None,
"addition_embed_type_num_heads": 64,
}