# Hosting-page header (not part of the module):
#   Path: World_Model / URSA / diffnext / models / autoencoders / autoencoder_kl_ltx.py
#   Uploader avatar: BryanW
#   Commit d403233 (verified): "Add files using upload-large-folder tool"
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
"""Simple implementation of AutoEncoderKL for LTX v0.95."""
from einops import rearrange
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffnext.models.autoencoders.modeling_utils import DiagonalGaussianDistribution
from diffnext.models.autoencoders.modeling_utils import DecoderOutput, TilingMixin
class Conv3d(nn.Conv3d):
    """3D convolution with replication padding on the first kernel axis.

    The two trailing kernel axes use the built-in convolution padding of
    ``k // 2``. The leading axis is padded explicitly by replication:
    ``k - 1`` entries at the front when ``causal`` is true, otherwise
    ``(k - 1) // 2`` entries on both sides. A leading kernel size of 1
    needs no padding, so both pads become identities.

    Args:
        *args: Positional arguments forwarded to ``nn.Conv3d``.
        **kwargs: Keyword arguments forwarded to ``nn.Conv3d``; the extra
            ``causal`` flag (default ``True``) is consumed here.
    """

    def __init__(self, *args, **kwargs):
        causal = kwargs.pop("causal", True)
        super().__init__(*args, **kwargs)
        self.causal = causal
        k_t, k_h, k_w = self.kernel_size
        # The leading axis is never padded by the convolution itself.
        self.padding = (0, k_h // 2, k_w // 2)
        if k_t == 1:
            self.pad1, self.pad2 = nn.Identity(), nn.Identity()
        else:
            front, half = k_t - 1, (k_t - 1) // 2
            # pad1: front-only (causal); pad2: symmetric (non-causal).
            self.pad1 = nn.ReplicationPad3d((0, 0, 0, 0, front, 0))
            self.pad2 = nn.ReplicationPad3d((0, 0, 0, 0, half, half))

    def forward(self, x):
        """Pad the leading axis by replication, then convolve."""
        padded = self.pad1(x) if self.causal else self.pad2(x)
        return super().forward(padded)
class RMSNorm(nn.Module):
    """Parameter-free RMS normalization over the last dimension."""

    def forward(self, x):
        """Return ``x`` divided by the root-mean-square of its last axis.

        The statistic is computed in float32 (to avoid float16 underflow)
        and cast back to the dtype of ``x`` before the multiplication.
        """
        mean_sq = x.float().square().mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + 1e-8)
        return x * inv_rms.to(x.dtype)
class TimeEmbed(nn.Module):
    """Time embedding layer.

    Turns timesteps into sinusoidal features (``cos`` half followed by
    ``sin`` half) and projects them with ``Linear -> SiLU -> Linear``.

    Args:
        embed_dim: Output width of the projection.
        freq_dim: Width of the sinusoidal features fed to the projection.
    """

    def __init__(self, embed_dim, freq_dim=256):
        super(TimeEmbed, self).__init__()
        self.timestep_proj = nn.Module()
        self.timestep_proj.fc1 = nn.Linear(freq_dim, embed_dim)
        self.timestep_proj.fc2 = nn.Linear(embed_dim, embed_dim)
        # The frequency table is built lazily on first use. It is a plain
        # attribute (not a buffer), so ``Module.to()`` does not move it.
        self.freq_dim, self.time_freq = freq_dim, None

    def get_freq_embed(self, timestep) -> torch.Tensor:
        """Return sinusoidal features for ``timestep``.

        Args:
            timestep: Tensor of timesteps; a trailing axis is appended for
                the frequencies.

        Returns:
            Tensor with ``2 * (freq_dim // 2)`` features on the last axis,
            cast to the dtype of ``timestep``.
        """
        cached = self.time_freq
        # Rebuild the table when it is missing or lives on another device
        # than the incoming timesteps (e.g. after the model was moved).
        if cached is None or cached.device != timestep.device:
            dim, log_theta = self.freq_dim // 2, 9.210340371976184  # math.log(10000)
            freq = torch.arange(dim, dtype=torch.float32, device=timestep.device)
            self.time_freq = freq.mul(-log_theta / dim).exp().unsqueeze_(0)
        emb = timestep.unsqueeze(-1).float() * self.time_freq
        return torch.cat([emb.cos(), emb.sin()], dim=-1).to(dtype=timestep.dtype)

    def forward(self, temb) -> torch.Tensor:
        """Project timesteps (1-D) or precomputed features (otherwise)."""
        # A 1-D input is treated as raw timesteps; anything else is assumed
        # to already be sinusoidal features and is projected as-is.
        x = self.get_freq_embed(temb) if temb.dim() == 1 else temb
        return self.timestep_proj.fc2(nn.functional.silu(self.timestep_proj.fc1(x)))
class ResBlock(nn.Module):
    """Residual block: two (RMSNorm -> SiLU -> Conv3d) stages plus a skip.

    Non-causal blocks additionally own a learned ``scale_shift_table`` of
    shape ``(4, dim)``. It is added to the time embedding and split into
    (shift1, scale1, shift2, scale2) to modulate the two normalized
    activations.

    Args:
        dim: Input channels.
        out_dim: Output channels.
        causal: Forwarded to both convolutions; also disables modulation.
    """

    def __init__(self, dim, out_dim, causal=True):
        super().__init__()
        self.norm1 = RMSNorm()
        self.norm2 = RMSNorm()
        self.conv1 = Conv3d(dim, out_dim, 3, causal=causal)
        self.conv2 = Conv3d(out_dim, out_dim, 3, causal=causal)
        self.nonlinearity = nn.SiLU()
        self.dropout = nn.Dropout(0, inplace=True)
        if causal:
            self.scale_shift_table = None
        else:
            self.scale_shift_table = nn.Parameter(torch.randn(4, dim) / dim**0.5)

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Apply the block; ``temb`` is only read when modulation is enabled."""
        residual = x
        modulated = self.scale_shift_table is not None
        if modulated:
            table = self.scale_shift_table.view(1, -1)
            # Append three singleton axes so the stats broadcast over the
            # trailing dimensions of the activations.
            biased = temb.add(table)[..., None, None, None]
            shift1, scale1, shift2, scale2 = biased.chunk(4, 1)

        # RMSNorm normalizes the last axis, so channels are moved there
        # and back around each normalization.
        h = self.norm1(x.movedim(1, -1)).movedim(-1, 1)
        if modulated:
            h = h.mul(1 + scale1).add_(shift1)
        h = self.conv1(self.nonlinearity(h))

        h = self.norm2(h.movedim(1, -1)).movedim(-1, 1)
        if modulated:
            h = h.mul(1 + scale2).add_(shift2)
        h = self.conv2(self.dropout(self.nonlinearity(h)))
        return h.add_(residual)
class MidBlock(nn.Module):
    """UNet mid block: a stack of same-width residual blocks.

    Non-causal mid blocks own a ``TimeEmbed`` of width ``dim * 4`` whose
    output conditions every residual block; causal ones ignore ``temb``.

    Args:
        dim: Channel width of every residual block.
        depth: Number of residual blocks.
        causal: Forwarded to the residual blocks.
    """

    def __init__(self, dim, depth=1, causal=True):
        super().__init__()
        self.time_embed = None if causal else TimeEmbed(dim * 4)
        blocks = [ResBlock(dim, dim, causal=causal) for _ in range(depth)]
        self.resnets = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Run all residual blocks, conditioning them when enabled."""
        cond = None
        if self.time_embed is not None:
            cond = self.time_embed(temb)
        for block in self.resnets:
            x = block(x, cond)
        return x
class Downsample(nn.Module):
    """Residual downsample layer.

    Both branches fold each ``stride``-sized patch of the (t, h, w) axes
    into channels. The main branch convolves first; the shortcut averages
    groups of ``group_size`` folded channels so that it ends up with
    ``out_dim`` channels as well.

    Args:
        dim: Input channels.
        out_dim: Output channels.
        stride: Int or 3-sequence of per-axis (t, h, w) factors.
        causal: Forwarded to the convolution.
    """

    def __init__(self, dim, out_dim, stride, causal=True):
        super().__init__()
        if not isinstance(stride, (tuple, list)):
            stride = (stride,) * 3
        self.stride = stride
        factor = torch.Size(stride).numel()  # product of the three strides
        self.group_size = (dim * factor) // out_dim
        self.pad_t = stride[0] - 1
        self.conv = Conv3d(dim, out_dim // factor, 3, 1, causal=causal)
        self.patch_args = {
            "r": stride[0],
            "p": stride[1],
            "q": stride[2],
            "pattern": "b c (t r) (h p) (w q) -> b (c r p q) t h w",
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` and add the channel-averaged shortcut."""
        if self.pad_t:
            # Replicate the first entries along dim 2 (front only) before
            # the fold.
            x = nn.functional.pad(x, (0, 0, 0, 0, self.pad_t, 0), "replicate")
        folded = rearrange(x, **self.patch_args)
        shortcut = folded.unflatten(1, (-1, self.group_size)).mean(dim=2)
        out = rearrange(self.conv(x), **self.patch_args)
        return out.add_(shortcut)
class Upsample(nn.Module):
    """Residual upsample layer.

    Both branches unfold channels into ``stride``-sized patches of the
    (t, h, w) axes. The main branch convolves first; the shortcut unfolds
    the raw input and tiles its channels ``repeats`` times. When the first
    stride is larger than 1, the first ``stride[0] - 1`` entries along
    dim 2 are dropped from both branches.

    Args:
        dim: Input channels.
        out_dim: Output channels.
        stride: Int or 3-sequence of per-axis (t, h, w) factors.
        causal: Forwarded to the convolution.
    """

    def __init__(self, dim, out_dim, stride, causal=False):
        super().__init__()
        if not isinstance(stride, (tuple, list)):
            stride = (stride,) * 3
        self.stride = stride
        factor = torch.Size(stride).numel()  # product of the three strides
        self.repeats = (out_dim * factor) // dim
        self.slice_t = stride[0] - 1
        self.conv = Conv3d(dim, out_dim * factor, 3, 1, causal=causal)
        self.patch_args = {
            "r": stride[0],
            "p": stride[1],
            "q": stride[2],
            "pattern": "b (c r p q) t h w -> b c (t r) (h p) (w q)",
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and add the channel-tiled shortcut."""
        skip = rearrange(x, **self.patch_args).repeat(1, self.repeats, 1, 1, 1)
        out = rearrange(self.conv(x), **self.patch_args)
        if self.slice_t:
            out = out[:, :, self.slice_t :]
            skip = skip[:, :, self.slice_t :]
        return out.add_(skip)
class DownBlock(nn.Module):
    """Downsample block: residual blocks followed by an optional downsampler.

    Args:
        dim: Channel width of the residual blocks.
        out_dim: Channels produced by the downsampler.
        depth: Number of residual blocks.
        causal: Forwarded to every sub-layer.
        downsample: ``""`` for no downsampler, otherwise one of
            ``"spatial"``, ``"temporal"`` or ``"spatiotemporal"``.
    """

    def __init__(self, dim, out_dim, depth=1, causal=True, downsample=""):
        super().__init__()
        self.resnets = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        for _ in range(depth):
            self.resnets.append(ResBlock(dim, dim, causal=causal))
        if downsample:
            strides = {"spatial": (1, 2, 2), "temporal": (2, 1, 1), "spatiotemporal": 2}
            layer = Downsample(dim, out_dim, strides[downsample], causal=causal)
            self.downsamplers.append(layer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual blocks, then the downsampler if present."""
        for layer in list(self.resnets) + list(self.downsamplers):
            x = layer(x)
        return x
class UpBlock(nn.Module):
    """Upsample block: an optional upsampler followed by residual blocks.

    The block owns a ``TimeEmbed`` of width ``out_dim * 4`` whose output
    conditions every residual block.

    Args:
        dim: Input channels.
        out_dim: Channels after upsampling and in the residual blocks.
        depth: Number of residual blocks.
        causal: Forwarded to every sub-layer.
        upscale_factor: An upsampler (always with stride 2) is added only
            when this is greater than 1.
    """

    def __init__(self, dim, out_dim, depth=1, causal=False, upscale_factor=2):
        super().__init__()
        self.time_embed = TimeEmbed(out_dim * 4)
        self.resnets = nn.ModuleList()
        self.upsamplers = nn.ModuleList()
        if upscale_factor > 1:
            self.upsamplers.append(Upsample(dim, out_dim, 2, causal=causal))
        for _ in range(depth):
            self.resnets.append(ResBlock(out_dim, out_dim, causal=causal))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Upsample ``x`` and refine it with time-conditioned residual blocks."""
        for layer in self.upsamplers:
            x = layer(x)
        cond = self.time_embed(temb)
        for block in self.resnets:
            x = block(x, cond)
        return x
class Encoder(nn.Module):
    """VAE encoder.

    Pipeline: fold ``patch_size x patch_size`` patches of the last two axes
    into channels, input convolution, down blocks (spatial, temporal,
    spatiotemporal, spatiotemporal), mid block, RMSNorm + SiLU, and an
    output convolution producing ``out_dim + 1`` channels.

    Args:
        dim: Input channels before patch folding.
        out_dim: Latent channels (the output has one extra channel).
        block_dims: Channel widths; entry ``i + 1`` is the output width of
            down block ``i``.
        block_depths: Residual depth per stage; the last entry is used by
            the mid block.
        patch_size: Side length of the folded patch.
    """

    def __init__(self, dim, out_dim, block_dims, block_depths, patch_size=4):
        super().__init__()
        self.patch_args = {"p": patch_size, "q": patch_size}
        stage_kinds = ["spatial", "temporal", "spatiotemporal", "spatiotemporal"]
        self.conv_in = Conv3d(dim * patch_size**2, block_dims[0], 3, 1)
        self.down_blocks = nn.ModuleList()
        stages = zip(block_dims, block_depths, stage_kinds)
        for idx, (width, depth, kind) in enumerate(stages):
            next_width = block_dims[idx + 1]
            self.down_blocks.append(DownBlock(width, next_width, depth, downsample=kind))
        self.mid_block = MidBlock(block_dims[-1], block_depths[-1])
        self.norm_out = RMSNorm()
        self.conv_act = nn.SiLU()
        self.conv_out = Conv3d(block_dims[-1], out_dim + 1, 3, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into ``out_dim + 1`` channels."""
        h = rearrange(x, "b c t (h p) (w q) -> b (c q p) t h w", **self.patch_args)
        h = self.conv_in(h)
        for stage in self.down_blocks:
            h = stage(h)
        h = self.mid_block(h)
        # RMSNorm works on the last axis: move channels there and back.
        h = self.norm_out(h.movedim(1, -1)).movedim(-1, 1)
        return self.conv_out(self.conv_act(h))
class Decoder(nn.Module):
    """VAE decoder (non-causal, timestep-conditioned).

    Pipeline: input convolution, mid block, up blocks that halve the
    channel width, RMSNorm, a final scale/shift modulation from the time
    embedding, SiLU, output convolution, and unfolding of channels into
    ``patch_size x patch_size`` patches of the last two axes.

    Args:
        dim: Latent (input) channels.
        out_dim: Output channels after patch unfolding.
        block_dims: Channel widths, given smallest first (reversed here).
        block_depths: Residual depth per up block; the last entry is used
            by the mid block.
        patch_size: Side length of the unfolded patch.
    """

    def __init__(self, dim, out_dim, block_dims, block_depths, patch_size=4):
        super().__init__()
        widths = tuple(reversed(block_dims))
        last_width = widths[-1]
        self.patch_args = {"p": patch_size, "q": patch_size}
        self.conv_in = Conv3d(dim, widths[0], 3, 1, causal=False)
        self.mid_block = MidBlock(widths[0], block_depths[-1], causal=False)
        self.up_blocks = nn.ModuleList([])
        for width, depth in zip(widths, block_depths[:-1]):
            self.up_blocks.append(UpBlock(width, width // 2, depth, upscale_factor=2))
        self.norm_out = RMSNorm()
        self.conv_act = nn.SiLU()
        self.conv_out = Conv3d(last_width, out_dim * patch_size**2, 3, 1, causal=False)
        self.time_embed = TimeEmbed(last_width * 2)
        self.scale_shift_table = nn.Parameter(torch.randn(2, last_width))
        self.timestep_scale = nn.Parameter(torch.tensor(1000, dtype=torch.float32))

    def forward(self, x: torch.Tensor, temb: torch.Tensor = None) -> torch.Tensor:
        """Decode latents ``x`` conditioned on timesteps ``temb``."""
        h = self.conv_in(x)
        # Sinusoidal features are computed once from the scaled timesteps
        # and shared; each block projects them with its own TimeEmbed.
        freq_emb = self.time_embed.get_freq_embed(temb * self.timestep_scale)
        h = self.mid_block(h, freq_emb)
        for stage in self.up_blocks:
            h = stage(h, freq_emb)
        # RMSNorm works on the last axis: move channels there and back.
        h = self.norm_out(h.movedim(1, -1)).movedim(-1, 1)
        table = self.scale_shift_table.view(1, -1)
        biased = self.time_embed(freq_emb).add(table)[..., None, None, None]
        shift, scale = biased.chunk(2, 1)
        h = h.mul(1 + scale).add_(shift)
        h = self.conv_out(self.conv_act(h))
        return rearrange(h, "b (c q p) t h w -> b c t (h p) (w q)", **self.patch_args)
class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, TilingMixin):
    """AutoEncoder KL.

    Wraps an ``Encoder``/``Decoder`` pair and exposes ``encode``/``decode``
    plus in-place latent (un)scaling helpers. ``encode`` and ``decode`` go
    through ``tiled_encoder``/``tiled_decoder``, which are not defined in
    this class (presumably provided by ``TilingMixin`` -- confirm).
    """

    @register_to_config
    def __init__(
        self,
        in_channels=3,
        out_channels=3,
        down_block_types=("LTXVideoDownBlock3D",) * 4,
        block_out_channels=(128, 256, 512, 1024, 2048),
        layers_per_block=(4, 6, 6, 2, 2),
        decoder_block_out_channels=(128, 256, 512, 1024),
        decoder_layers_per_block=(5, 5, 5, 5),
        act_fn="silu",
        latent_channels=128,
        sample_size=1024,
        scaling_factor=1.0,
        shift_factor=None,
        latents_mean=None,
        latents_std=None,
        patch_size=4,
    ):
        super(AutoencoderKLLTXVideo, self).__init__()
        TilingMixin.__init__(self, sample_min_t=249, latent_min_t=32, sample_ovr_t=1)
        # Forward the configured patch size; previously it was accepted but
        # the sub-networks always fell back to their own default of 4.
        self.encoder = Encoder(
            in_channels,
            latent_channels,
            block_out_channels,
            layers_per_block,
            patch_size=patch_size,
        )
        self.decoder = Decoder(
            latent_channels,
            out_channels,
            decoder_block_out_channels,
            decoder_layers_per_block,
            patch_size=patch_size,
        )
        # NOTE(review): ``latents_mean``/``latents_std`` are used as the
        # *size* of the buffers, so the actual statistics presumably come
        # from a loaded checkpoint -- confirm.
        self.register_buffer("shift_factors", torch.zeros(latents_mean) if latents_mean else None)
        self.register_buffer("scaling_factors", torch.ones(latents_std) if latents_std else None)
        self.latent_dist = DiagonalGaussianDistribution

    def scale_(self, x) -> torch.Tensor:
        """Scale the input latents in place and return them.

        Uses the ``shift_factors``/``scaling_factors`` buffers when present,
        otherwise the scalar ``shift_factor``/``scaling_factor`` config values.
        """
        if self.shift_factors is not None:
            return x.sub_(self.shift_factors).mul_(self.scaling_factors)
        if self.config.shift_factor:
            x.add_(-self.config.shift_factor)
        return x.mul_(self.config.scaling_factor)

    def unscale_(self, x) -> torch.Tensor:
        """Unscale the input latents in place and return them (inverse of ``scale_``)."""
        if self.shift_factors is not None:
            return x.div_(self.scaling_factors).add_(self.shift_factors)
        x.mul_(1 / self.config.scaling_factor)
        return x.add_(self.config.shift_factor) if self.config.shift_factor else x

    def encode(self, x) -> AutoencoderKLOutput:
        """Encode the input samples into a latent distribution."""
        z = self.tiled_encoder(self.forward(x))
        posterior = self.latent_dist(z)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z, temb: torch.Tensor = None) -> DecoderOutput:
        """Decode the input latents.

        Args:
            z: Latents. A 4-D tensor gets a singleton axis inserted at
                dim 2 for decoding, and that axis is removed from the result.
            temb: Per-sample timesteps; defaults to zeros with the dtype
                and device of ``z``.

        Returns:
            ``DecoderOutput`` holding the decoded sample.
        """
        if temb is None:
            temb = torch.zeros(z.size(0), dtype=z.dtype, device=z.device)
        extra_dim = 2 if z.dim() == 4 else None
        if extra_dim is not None:
            # Out-of-place view: the caller's tensor must keep its 4-D shape.
            z = z.unsqueeze(extra_dim)
        x = self.tiled_decoder(self.forward(z), temb=temb)
        if extra_dim is not None:
            x = x.squeeze(extra_dim)
        return DecoderOutput(sample=x)

    def forward(self, x):  # NOOP.
        """Return ``x`` unchanged."""
        return x