| | import typing as tp |
| |
|
| | import torch |
| |
|
| | from einops import rearrange |
| | from torch import nn |
| | from torch.nn import functional as F |
| | from x_transformers import ContinuousTransformerWrapper, Encoder |
| |
|
| | from .blocks import FourierFeatures |
| | from .transformer import ContinuousTransformer |
| |
|
class DiffusionTransformer(nn.Module):
    """Transformer backbone for a diffusion model over 1-D sequences.

    The input ``x`` is laid out as ``(batch, channels, time)``: it is
    concatenated with ``input_concat_cond`` along dim 1 and rearranged
    ``"b c t -> b t c"`` before entering the transformer. The diffusion
    timestep ``t`` is embedded (Fourier features + MLP) and summed with the
    optional projected global conditioning. That combined embedding is then
    either prepended to the sequence as one extra token
    (``global_cond_type="prepend"``) or passed to the transformer call as
    ``global_cond`` (``global_cond_type="adaLN"``).

    Args:
        io_channels: channels of ``x`` and of the output.
        patch_size: number of time steps folded into one transformer token.
        embed_dim: transformer width.
        cond_token_dim: feature size of cross-attention conditioning tokens;
            0 disables cross-attention.
        project_cond_tokens: project conditioning tokens to ``embed_dim``
            (True) or keep them at ``cond_token_dim`` (False).
        global_cond_dim: feature size of the global conditioning; 0 disables it.
        project_global_cond: project global conditioning to ``embed_dim``
            (True) or keep ``global_cond_dim`` (False). It is summed with the
            ``embed_dim`` timestep embedding, so presumably the sizes must match.
        input_concat_dim: channels of ``input_concat_cond`` appended to ``x``.
        prepend_cond_dim: feature size of prepended conditioning tokens;
            0 disables them.
        depth: number of transformer layers.
        num_heads: number of attention heads.
        transformer_type: ``"x-transformers"`` or ``"continuous_transformer"``.
        global_cond_type: ``"prepend"`` or ``"adaLN"``. Only the
            ``continuous_transformer`` constructor is configured for adaLN here.
        **kwargs: forwarded to the underlying transformer constructor.

    Raises:
        ValueError: if ``transformer_type`` is not recognised.
    """

    def __init__(self,
                 io_channels=32,
                 patch_size=1,
                 embed_dim=768,
                 cond_token_dim=0,
                 project_cond_tokens=True,
                 global_cond_dim=0,
                 project_global_cond=True,
                 input_concat_dim=0,
                 prepend_cond_dim=0,
                 depth=12,
                 num_heads=8,
                 transformer_type: tp.Literal["x-transformers", "continuous_transformer"] = "x-transformers",
                 global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
                 **kwargs):

        super().__init__()

        self.cond_token_dim = cond_token_dim

        # Timestep conditioning: scalar t -> Fourier features -> 2-layer MLP.
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim)

        self.to_timestep_embed = nn.Sequential(
            nn.Linear(timestep_features_dim, embed_dim, bias=True),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim, bias=True),
        )

        if cond_token_dim > 0:
            # Cross-attention conditioning tokens.
            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning vector (later summed with the timestep embedding).
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                nn.Linear(global_cond_dim, global_embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(global_embed_dim, global_embed_dim, bias=False)
            )

        if prepend_cond_dim > 0:
            # Conditioning tokens prepended to the input sequence.
            self.to_prepend_embed = nn.Sequential(
                nn.Linear(prepend_cond_dim, embed_dim, bias=False),
                nn.SiLU(),
                nn.Linear(embed_dim, embed_dim, bias=False)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "x-transformers":
            self.transformer = ContinuousTransformerWrapper(
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                max_seq_len=0,  # no absolute positions; rotary embeddings are used instead
                attn_layers=Encoder(
                    dim=embed_dim,
                    depth=depth,
                    heads=num_heads,
                    attn_flash=True,
                    cross_attend=cond_token_dim > 0,
                    dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
                    zero_init_branch_output=True,
                    use_abs_pos_emb=False,
                    rotary_pos_emb=True,
                    ff_swish=True,
                    ff_glu=True,
                    **kwargs
                )
            )

        elif self.transformer_type == "continuous_transformer":

            # Only adaLN needs the transformer to accept a global conditioning vector.
            global_dim = None

            if self.global_cond_type == "adaLN":
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend=cond_token_dim > 0,
                cond_token_dim=cond_embed_dim,
                global_cond_dim=global_dim,
                **kwargs
            )

        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        # Residual 1x1 convs around the transformer, zero-initialised so each
        # starts out as an identity mapping (x + 0).
        self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
        nn.init.zeros_(self.preprocess_conv.weight)
        self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
        nn.init.zeros_(self.postprocess_conv.weight)

    @staticmethod
    def _append_global_token(prepend_inputs, prepend_mask, global_embed, batch_size, device):
        """Append ``global_embed`` as one extra prepended token, keeping the mask aligned.

        Args:
            prepend_inputs: tokens already scheduled for prepending, expected
                ``(batch, n, dim)``, or None.
            prepend_mask: boolean mask for ``prepend_inputs``, or None. A
                missing mask means every existing prepended token is valid.
            global_embed: combined global/timestep embedding, expected
                ``(batch, dim)``.
            batch_size: batch size used for the new mask entries.
            device: device for the new mask entries.

        Returns:
            ``(inputs, mask)`` with the global token last and a True mask entry for it.
        """
        global_token = global_embed.unsqueeze(1)
        global_mask = torch.ones((batch_size, 1), device=device, dtype=torch.bool)

        if prepend_inputs is None:
            return global_token, global_mask

        if prepend_mask is None:
            # Without this, torch.cat would receive None and raise.
            prepend_mask = torch.ones(prepend_inputs.shape[:2], device=device, dtype=torch.bool)

        return (torch.cat([prepend_inputs, global_token], dim=1),
                torch.cat([prepend_mask, global_mask], dim=1))

    @staticmethod
    def _drop_cond(cond, prob):
        """Replace whole batch items of ``cond`` with zeros, each independently with probability ``prob``.

        The per-item draw has shape ``(batch, 1, 1)``, so ``cond`` is expected to be 3-D.
        """
        null_embed = torch.zeros_like(cond)
        dropout_mask = torch.bernoulli(
            torch.full((cond.shape[0], 1, 1), prob, device=cond.device)
        ).to(torch.bool)
        return torch.where(dropout_mask, null_embed, cond)

    def _forward(
            self,
            x,
            t,
            mask=None,
            cross_attn_cond=None,
            cross_attn_cond_mask=None,
            input_concat_cond=None,
            global_embed=None,
            prepend_cond=None,
            prepend_cond_mask=None,
            return_info=False,
            **kwargs):
        """Run one forward pass without classifier-free guidance.

        Args:
            x: input laid out as ``(batch, channels, time)``.
            t: diffusion timesteps, one per batch item (indexed as ``t[:, None]``).
            mask: forwarded to the transformer as ``mask``.
            cross_attn_cond: raw cross-attention tokens, projected by ``to_cond_embed``.
            cross_attn_cond_mask: forwarded to the transformer as ``context_mask``.
            input_concat_cond: extra channels concatenated to ``x`` along dim 1,
                nearest-resampled along time if its length differs.
            global_embed: raw global conditioning, projected by
                ``to_global_embed`` and summed with the timestep embedding.
            prepend_cond: raw tokens projected by ``to_prepend_embed`` and
                prepended to the sequence.
            prepend_cond_mask: boolean mask for ``prepend_cond``. When the
                global embedding is prepended, a missing mask is treated as
                all-valid.
            return_info: also return the info object produced by the
                continuous transformer.
            **kwargs: forwarded to the transformer call.

        Returns:
            The ``(batch, io_channels, time)`` output, or ``(output, info)``
            when ``return_info`` is True.

        Raises:
            NotImplementedError: if ``return_info`` is requested for a
                transformer type other than ``"continuous_transformer"``.
            ValueError: for an unknown ``transformer_type``.
        """
        # Only the continuous transformer produces an info object. Fail up front
        # rather than with an unbound variable after the whole forward pass.
        if return_info and self.transformer_type != "continuous_transformer":
            raise NotImplementedError(
                "return_info=True is only supported for transformer_type="
                f"'continuous_transformer', got {self.transformer_type!r}"
            )

        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        if prepend_cond is not None:
            prepend_inputs = self.to_prepend_embed(prepend_cond)
            prepend_mask = prepend_cond_mask  # may be None

        if input_concat_cond is not None:
            # Match the time resolution of x before concatenating channels.
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]))

        # The timestep embedding is always part of the global conditioning.
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        if self.global_cond_type == "prepend":
            prepend_inputs, prepend_mask = self._append_global_token(
                prepend_inputs, prepend_mask, global_embed, x.shape[0], x.device
            )

        # The transformer output still contains every prepended position; they
        # are sliced off below regardless of which path prepended them.
        prepend_length = 0 if prepend_inputs is None else prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        if self.transformer_type == "x-transformers":
            output = self.transformer(
                x,
                prepend_embeds=prepend_inputs,
                context=cross_attn_cond,
                context_mask=cross_attn_cond_mask,
                mask=mask,
                prepend_mask=prepend_mask,
                **extra_args,
                **kwargs
            )
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(
                x,
                prepend_embeds=prepend_inputs,
                context=cross_attn_cond,
                context_mask=cross_attn_cond_mask,
                mask=mask,
                prepend_mask=prepend_mask,
                return_info=return_info,
                **extra_args,
                **kwargs
            )

            if return_info:
                output, info = output
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        output = rearrange(output, "b t c -> b c t")[:, :, prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
            self,
            x,
            t,
            cross_attn_cond=None,
            cross_attn_cond_mask=None,
            negative_cross_attn_cond=None,
            negative_cross_attn_mask=None,
            input_concat_cond=None,
            global_embed=None,
            negative_global_embed=None,
            prepend_cond=None,
            prepend_cond_mask=None,
            cfg_scale=1.0,
            cfg_dropout_prob=0.0,
            causal=False,
            scale_phi=0.0,
            mask=None,
            return_info=False,
            **kwargs):
        """Run the model, optionally with CFG dropout and classifier-free guidance.

        Args:
            x, t, input_concat_cond, global_embed, prepend_cond,
            prepend_cond_mask, mask, return_info, **kwargs: see ``_forward``.
            cross_attn_cond: cross-attention conditioning tokens.
            cross_attn_cond_mask: currently ignored; it is always replaced
                with None (see the note in the body).
            negative_cross_attn_cond: tokens used for the unconditional half
                of guidance instead of zeros.
            negative_cross_attn_mask: boolean mask for
                ``negative_cross_attn_cond``. Masked-out tokens fall back to zeros.
            negative_global_embed: accepted but currently unused. Both guidance
                halves receive ``global_embed``.
            cfg_scale: guidance scale. When != 1.0 and cross-attention or
                prepend conditioning is present, the conditional and
                unconditional passes run as one doubled batch and are combined
                as ``uncond + (cond - uncond) * cfg_scale``.
            cfg_dropout_prob: if > 0, each batch item's ``cross_attn_cond`` and
                ``prepend_cond`` are independently zeroed with this probability.
            causal: must be falsy; causal mode is not supported.
            scale_phi: if non-zero, blend the guided output with a copy
                rescaled to the std (over dim 1) of the conditional output.

        Returns:
            The model output, or ``(output, info)`` when ``return_info`` is True.
        """
        assert not causal, "Causal mode is not supported for DiffusionTransformer"

        # NOTE(review): cross-attention masks are unconditionally discarded here.
        # This looks like a deliberate workaround (possibly related to attn_flash=True
        # in the x-transformers config); confirm before re-enabling.
        cross_attn_cond_mask = None

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.bool()

        # CFG dropout (training): zero whole batch items of the conditioning.
        if cfg_dropout_prob > 0.0:
            if cross_attn_cond is not None:
                cross_attn_cond = self._drop_cond(cross_attn_cond, cfg_dropout_prob)

            if prepend_cond is not None:
                prepend_cond = self._drop_cond(prepend_cond, cfg_dropout_prob)

        use_cfg = cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None)

        if not use_cfg:
            return self._forward(
                x,
                t,
                cross_attn_cond=cross_attn_cond,
                cross_attn_cond_mask=cross_attn_cond_mask,
                input_concat_cond=input_concat_cond,
                global_embed=global_embed,
                prepend_cond=prepend_cond,
                prepend_cond_mask=prepend_cond_mask,
                mask=mask,
                return_info=return_info,
                **kwargs
            )

        def doubled(tensor):
            # Same tensor for the conditional and unconditional halves; None passes through.
            return None if tensor is None else torch.cat([tensor, tensor], dim=0)

        # Classifier-free guidance: first half conditional, second half unconditional.
        batch_inputs = doubled(x)
        batch_timestep = doubled(t)
        batch_global_cond = doubled(global_embed)
        batch_input_concat_cond = doubled(input_concat_cond)
        batch_masks = doubled(mask)

        batch_cond = None
        batch_cond_masks = None

        if cross_attn_cond is not None:
            null_embed = torch.zeros_like(cross_attn_cond)

            if negative_cross_attn_cond is not None:
                if negative_cross_attn_mask is not None:
                    # Masked-out negative tokens fall back to the null (zero) embedding.
                    negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
                    negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)

                batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
            else:
                batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

            batch_cond_masks = doubled(cross_attn_cond_mask)

        batch_prepend_cond = None
        batch_prepend_cond_mask = None

        if prepend_cond is not None:
            batch_prepend_cond = torch.cat([prepend_cond, torch.zeros_like(prepend_cond)], dim=0)
            batch_prepend_cond_mask = doubled(prepend_cond_mask)

        batch_output = self._forward(
            batch_inputs,
            batch_timestep,
            cross_attn_cond=batch_cond,
            cross_attn_cond_mask=batch_cond_masks,
            mask=batch_masks,
            input_concat_cond=batch_input_concat_cond,
            global_embed=batch_global_cond,
            prepend_cond=batch_prepend_cond,
            prepend_cond_mask=batch_prepend_cond_mask,
            return_info=return_info,
            **kwargs)

        if return_info:
            batch_output, info = batch_output

        cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
        cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale

        if scale_phi != 0.0:
            # Rescale toward the conditional output's std (over dim 1), blended by scale_phi.
            cond_out_std = cond_output.std(dim=1, keepdim=True)
            out_cfg_std = cfg_output.std(dim=1, keepdim=True)
            output = scale_phi * (cfg_output * (cond_out_std / out_cfg_std)) + (1 - scale_phi) * cfg_output
        else:
            output = cfg_output

        if return_info:
            return output, info

        return output