| | from numpy import sqrt |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| | import numpy as np |
| | from typing import Tuple, Literal |
| | from functools import partial |
| |
|
| | from pdb import set_trace as st |
| |
|
| | |
| | from vit.vision_transformer import MemEffAttention |
| |
|
| |
|
| | class MVAttention(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | dim: int, |
| | num_heads: int = 8, |
| | qkv_bias: bool = False, |
| | proj_bias: bool = True, |
| | attn_drop: float = 0.0, |
| | proj_drop: float = 0.0, |
| | groups: int = 32, |
| | eps: float = 1e-5, |
| | residual: bool = True, |
| | skip_scale: float = 1, |
| | num_frames: int = 4, |
| | ): |
| | super().__init__() |
| |
|
| | self.residual = residual |
| | self.skip_scale = skip_scale |
| | self.num_frames = num_frames |
| |
|
| | self.norm = nn.GroupNorm(num_groups=groups, |
| | num_channels=dim, |
| | eps=eps, |
| | affine=True) |
| | self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, |
| | attn_drop, proj_drop) |
| |
|
| | def forward(self, x): |
| | |
| | BV, C, H, W = x.shape |
| | B = BV // self.num_frames |
| |
|
| | res = x |
| | x = self.norm(x) |
| |
|
| | x = x.reshape(B, self.num_frames, C, H, |
| | W).permute(0, 1, 3, 4, 2).reshape(B, -1, C) |
| | x = self.attn(x) |
| | x = x.reshape(B, self.num_frames, H, W, |
| | C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W) |
| |
|
| | if self.residual: |
| | x = (x + res) * self.skip_scale |
| | return x |
| |
|
| |
|
| | class ResnetBlock(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int, |
| | out_channels: int, |
| | resample: Literal['default', 'up', 'down'] = 'default', |
| | groups: int = 32, |
| | eps: float = 1e-5, |
| | skip_scale: float = 1, |
| | ): |
| | super().__init__() |
| |
|
| | self.in_channels = in_channels |
| | self.out_channels = out_channels |
| | self.skip_scale = skip_scale |
| |
|
| | self.norm1 = nn.GroupNorm(num_groups=groups, |
| | num_channels=in_channels, |
| | eps=eps, |
| | affine=True) |
| | self.conv1 = nn.Conv2d(in_channels, |
| | out_channels, |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| |
|
| | self.norm2 = nn.GroupNorm(num_groups=groups, |
| | num_channels=out_channels, |
| | eps=eps, |
| | affine=True) |
| | self.conv2 = nn.Conv2d(out_channels, |
| | out_channels, |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| |
|
| | self.act = F.silu |
| |
|
| | self.resample = None |
| | if resample == 'up': |
| | self.resample = partial(F.interpolate, |
| | scale_factor=2.0, |
| | mode="nearest") |
| | elif resample == 'down': |
| | self.resample = nn.AvgPool2d(kernel_size=2, stride=2) |
| |
|
| | self.shortcut = nn.Identity() |
| | if self.in_channels != self.out_channels: |
| | self.shortcut = nn.Conv2d(in_channels, |
| | out_channels, |
| | kernel_size=1, |
| | bias=True) |
| |
|
| | def forward(self, x): |
| | res = x |
| |
|
| | x = self.norm1(x) |
| | x = self.act(x) |
| |
|
| | if self.resample: |
| | res = self.resample(res) |
| | x = self.resample(x) |
| |
|
| | x = self.conv1(x) |
| | x = self.norm2(x) |
| | x = self.act(x) |
| | x = self.conv2(x) |
| |
|
| | x = (x + self.shortcut(res)) * self.skip_scale |
| |
|
| | return x |
| |
|
| |
|
| | class DownBlock(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int, |
| | out_channels: int, |
| | num_layers: int = 1, |
| | downsample: bool = True, |
| | attention: bool = True, |
| | attention_heads: int = 16, |
| | skip_scale: float = 1, |
| | ): |
| | super().__init__() |
| |
|
| | nets = [] |
| | attns = [] |
| | for i in range(num_layers): |
| | in_channels = in_channels if i == 0 else out_channels |
| | nets.append( |
| | ResnetBlock(in_channels, out_channels, skip_scale=skip_scale)) |
| | if attention: |
| | attns.append( |
| | MVAttention(out_channels, |
| | attention_heads, |
| | skip_scale=skip_scale)) |
| | else: |
| | attns.append(None) |
| | self.nets = nn.ModuleList(nets) |
| | self.attns = nn.ModuleList(attns) |
| |
|
| | self.downsample = None |
| | if downsample: |
| | self.downsample = nn.Conv2d(out_channels, |
| | out_channels, |
| | kernel_size=3, |
| | stride=2, |
| | padding=1) |
| |
|
| | def forward(self, x): |
| | xs = [] |
| |
|
| | for attn, net in zip(self.attns, self.nets): |
| | x = net(x) |
| | if attn: |
| | x = attn(x) |
| | xs.append(x) |
| |
|
| | if self.downsample: |
| | x = self.downsample(x) |
| | xs.append(x) |
| |
|
| | return x, xs |
| |
|
| |
|
| | class MidBlock(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int, |
| | num_layers: int = 1, |
| | attention: bool = True, |
| | attention_heads: int = 16, |
| | skip_scale: float = 1, |
| | ): |
| | super().__init__() |
| |
|
| | nets = [] |
| | attns = [] |
| | |
| | nets.append( |
| | ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) |
| | |
| | for i in range(num_layers): |
| | nets.append( |
| | ResnetBlock(in_channels, in_channels, skip_scale=skip_scale)) |
| | if attention: |
| | attns.append( |
| | MVAttention(in_channels, |
| | attention_heads, |
| | skip_scale=skip_scale)) |
| | else: |
| | attns.append(None) |
| | self.nets = nn.ModuleList(nets) |
| | self.attns = nn.ModuleList(attns) |
| |
|
| | def forward(self, x): |
| | x = self.nets[0](x) |
| | for attn, net in zip(self.attns, self.nets[1:]): |
| | if attn: |
| | x = attn(x) |
| | x = net(x) |
| | return x |
| |
|
| |
|
| | class UpBlock(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int, |
| | prev_out_channels: int, |
| | out_channels: int, |
| | num_layers: int = 1, |
| | upsample: bool = True, |
| | attention: bool = True, |
| | attention_heads: int = 16, |
| | skip_scale: float = 1, |
| | ): |
| | super().__init__() |
| |
|
| | nets = [] |
| | attns = [] |
| | for i in range(num_layers): |
| | cin = in_channels if i == 0 else out_channels |
| | cskip = prev_out_channels if (i == num_layers - |
| | 1) else out_channels |
| |
|
| | nets.append( |
| | ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale)) |
| | if attention: |
| | attns.append( |
| | MVAttention(out_channels, |
| | attention_heads, |
| | skip_scale=skip_scale)) |
| | else: |
| | attns.append(None) |
| | self.nets = nn.ModuleList(nets) |
| | self.attns = nn.ModuleList(attns) |
| |
|
| | self.upsample = None |
| | if upsample: |
| | self.upsample = nn.Conv2d(out_channels, |
| | out_channels, |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| |
|
| | def forward(self, x, xs): |
| |
|
| | for attn, net in zip(self.attns, self.nets): |
| | res_x = xs[-1] |
| | xs = xs[:-1] |
| | x = torch.cat([x, res_x], dim=1) |
| | x = net(x) |
| | if attn: |
| | x = attn(x) |
| |
|
| | if self.upsample: |
| | x = F.interpolate(x, scale_factor=2.0, mode='nearest') |
| | x = self.upsample(x) |
| |
|
| | return x |
| |
|
| |
|
| | |
| | class MVUNet(nn.Module): |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int = 3, |
| | out_channels: int = 3, |
| | down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024), |
| | down_attention: Tuple[bool, |
| | ...] = (False, False, False, True, True), |
| | mid_attention: bool = True, |
| | up_channels: Tuple[int, ...] = (1024, 512, 256), |
| | up_attention: Tuple[bool, ...] = (True, True, False), |
| | layers_per_block: int = 2, |
| | skip_scale: float = np.sqrt(0.5), |
| | ): |
| | super().__init__() |
| |
|
| | |
| | self.conv_in = nn.Conv2d(in_channels, |
| | down_channels[0], |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| |
|
| | |
| | down_blocks = [] |
| | cout = down_channels[0] |
| | for i in range(len(down_channels)): |
| | cin = cout |
| | cout = down_channels[i] |
| |
|
| | down_blocks.append( |
| | DownBlock( |
| | cin, |
| | cout, |
| | num_layers=layers_per_block, |
| | downsample=(i |
| | != len(down_channels) - 1), |
| | attention=down_attention[i], |
| | skip_scale=skip_scale, |
| | )) |
| | self.down_blocks = nn.ModuleList(down_blocks) |
| |
|
| | |
| | self.mid_block = MidBlock(down_channels[-1], |
| | attention=mid_attention, |
| | skip_scale=skip_scale) |
| |
|
| | |
| | up_blocks = [] |
| | cout = up_channels[0] |
| | for i in range(len(up_channels)): |
| | cin = cout |
| | cout = up_channels[i] |
| | cskip = down_channels[max(-2 - i, |
| | -len(down_channels))] |
| |
|
| | up_blocks.append( |
| | UpBlock( |
| | cin, |
| | cskip, |
| | cout, |
| | num_layers=layers_per_block + 1, |
| | upsample=(i != len(up_channels) - 1), |
| | attention=up_attention[i], |
| | skip_scale=skip_scale, |
| | )) |
| | self.up_blocks = nn.ModuleList(up_blocks) |
| |
|
| | |
| | self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], |
| | num_groups=32, |
| | eps=1e-5) |
| | self.conv_out = nn.Conv2d(up_channels[-1], |
| | out_channels, |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| |
|
| | def forward(self, x): |
| | |
| |
|
| | |
| | x = self.conv_in(x) |
| |
|
| | |
| | xss = [x] |
| | for block in self.down_blocks: |
| | x, xs = block(x) |
| | xss.extend(xs) |
| |
|
| | |
| | x = self.mid_block(x) |
| |
|
| | |
| | for block in self.up_blocks: |
| | xs = xss[-len(block.nets):] |
| | xss = xss[:-len(block.nets)] |
| | x = block(x, xs) |
| |
|
| | |
| | x = self.norm_out(x) |
| | x = F.silu(x) |
| | x = self.conv_out(x) |
| |
|
| | return x |
| |
|
| |
|
| | class LGM_MVEncoder(MVUNet): |
| |
|
| | def __init__( |
| | self, |
| | in_channels: int = 3, |
| | out_channels: int = 3, |
| | down_channels: Tuple[int] = (64, 128, 256, 512, 1024), |
| | down_attention: Tuple[bool] = (False, False, False, True, True), |
| | mid_attention: bool = True, |
| | up_channels: Tuple[int] = (1024, 512, 256), |
| | up_attention: Tuple[bool] = (True, True, False), |
| | layers_per_block: int = 2, |
| | skip_scale: float = np.sqrt(0.5), |
| | z_channels=4, |
| | double_z=True, |
| | add_fusion_layer=True, |
| | ): |
| | super().__init__(in_channels, out_channels, down_channels, |
| | down_attention, mid_attention, up_channels, |
| | up_attention, layers_per_block, skip_scale) |
| | del self.up_blocks |
| |
|
| | self.conv_out = torch.nn.Conv2d(up_channels[0], |
| | 2 * |
| | z_channels if double_z else z_channels, |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| | if add_fusion_layer: |
| | self.fusion_layer = torch.nn.Conv2d( |
| | 2 * z_channels * 4 if double_z else z_channels * 4, |
| | 2 * z_channels if double_z else z_channels, |
| | kernel_size=3, |
| | stride=1, |
| | padding=1) |
| |
|
| | self.num_frames = 4 |
| | |
| | def forward(self, x): |
| | |
| | x = self.conv_in(x) |
| |
|
| | |
| | xss = [x] |
| | for block in self.down_blocks: |
| | x, xs = block(x) |
| | xss.extend(xs) |
| |
|
| | |
| | x = self.mid_block(x) |
| |
|
| | |
| | x = x.chunk(x.shape[0] // self.num_frames) |
| | |
| | x = [self.fusion_layer(torch.cat(feat.chunk(feat.shape[0]), dim=1)) for feat in x] |
| | st() |
| | return torch.cat(x, dim=0) |