# refnet conditioning modules: projection, cluster-concat, and fusion layers.
import torch
import torch.nn as nn

from refnet.modules.layers import zero_module
from refnet.modules.attention import MemoryEfficientAttention
from refnet.modules.transformer import BasicTransformerBlock
from refnet.util import checkpoint_wrapper, exists
from refnet.util import load_weights
class NormalizedLinear(nn.Module):
    """Linear projection followed by LayerNorm over the output features."""

    def __init__(self, dim, output_dim, checkpoint=True):
        super().__init__()
        project = nn.Linear(dim, output_dim)
        normalize = nn.LayerNorm(output_dim)
        self.layers = nn.Sequential(project, normalize)
        # stored but not consumed here; presumably read by an outer
        # checkpointing wrapper — TODO confirm against callers
        self.checkpoint = checkpoint

    def forward(self, x):
        """Project `x` to `output_dim` features and layer-normalize them."""
        return self.layers(x)
class GlobalProjection(nn.Module):
    """Expand a global feature vector into a sequence of normalized tokens.

    `proj1` splits the input into `heads[0]` chunks of width `dim_head`;
    `proj2` maps each chunk to `heads[1]` tokens of width `output_dim`,
    yielding `heads[0] * heads[1]` tokens overall.
    """

    def __init__(self, input_dim, output_dim, heads, dim_head=128, checkpoint=True):
        super().__init__()
        self.c_dim = output_dim
        self.dim_head = dim_head
        # token counts after each stage: (after proj1, after proj2)
        self.head = (heads[0], heads[0] * heads[1])
        self.proj1 = nn.Linear(input_dim, dim_head * heads[0])
        # second stage is zero-initialized so the module starts as a
        # no-op contribution and learns its output gradually
        self.proj2 = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Linear(dim_head, output_dim * heads[1])),
        )
        self.norm = nn.LayerNorm(output_dim)
        self.checkpoint = checkpoint

    def forward(self, x):
        h = self.proj1(x)
        h = h.reshape(-1, self.head[0], self.dim_head).contiguous()
        h = self.proj2(h)
        h = h.reshape(-1, self.head[1], self.c_dim).contiguous()
        return self.norm(h)
class ClusterConcat(nn.Module):
    """Attend over input tokens, fuse with a conditioning embedding, project.

    The attention output is truncated to the first `token_length` tokens,
    normalized, concatenated with `emb` on the feature axis, and projected.
    """

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, token_length=196, checkpoint=True):
        super().__init__()
        self.attn = MemoryEfficientAttention(input_dim, dim_head=dim_head)
        self.norm = nn.LayerNorm(input_dim)
        self.proj = nn.Sequential(
            nn.Linear(input_dim + c_dim, output_dim),
            nn.SiLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.token_length = token_length
        self.checkpoint = checkpoint

    def forward(self, x, emb, fgbg=False, *args, **kwargs):
        tokens = self.attn(x)[:, :self.token_length]
        fused = self.proj(torch.cat([self.norm(tokens), emb], 2))
        if fgbg:
            # fold the two batch halves (presumably foreground/background
            # — TODO confirm) side by side along the token axis
            fused = torch.cat(torch.chunk(fused, 2), 1)
        return fused
class RecoveryClusterConcat(ClusterConcat):
    """ClusterConcat variant that refines the output with a self-attention
    transformer block when `bg=True`.

    Fix: the original `forward` re-implemented the parent's entire
    attn → norm → cat → proj pipeline line for line; it now delegates to
    `super().forward` (with `fgbg=False`, matching the original, which
    never chunked), so the two stay in sync if the parent changes.
    """

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, *args, **kwargs):
        super().__init__(input_dim, c_dim, output_dim, dim_head=dim_head, *args, **kwargs)
        # self-attention only refinement (cross-attention disabled)
        self.transformer = BasicTransformerBlock(
            output_dim, output_dim // dim_head, dim_head,
            disable_cross_attn=True, checkpoint=False,
        )

    def forward(self, x, emb, bg=False):
        out = super().forward(x, emb, fgbg=False)
        if bg:
            out = self.transformer(out)
        return out
class LogitClusterConcat(ClusterConcat):
    """ClusterConcat that first passes the embedding through a frozen AdaptiveMLP.

    When `mlp_ckpt_path` is given, the MLP weights are loaded from it; the
    MLP is always used as a fixed feature extractor (no gradients).
    """

    def __init__(self, c_dim, mlp_in_dim, mlp_ckpt_path=None, *args, **kwargs):
        super().__init__(c_dim=c_dim, *args, **kwargs)
        self.mlp = AdaptiveMLP(c_dim, mlp_in_dim)
        if exists(mlp_ckpt_path):
            self.mlp.load_state_dict(load_weights(mlp_ckpt_path), strict=True)

    def forward(self, x, emb, bg=False):
        # gradient-free pass through the (possibly pretrained) MLP
        with torch.no_grad():
            cond = self.mlp(emb)
        return super().forward(x, cond.detach(), bg)
class AdaptiveMLP(nn.Module):
    """MLP whose intermediate activations are all fused into the output.

    Every stage's output is kept, concatenated along the feature axis,
    linearly fused back down to `dim`, and layer-normalized.
    """

    def __init__(self, dim, in_dim, layers=4, checkpoint=True):
        super().__init__()
        # first stage maps in_dim -> dim; the rest are SiLU+LN+Linear blocks
        stages = [nn.Sequential(nn.Linear(in_dim, dim))]
        stages.extend(
            nn.Sequential(nn.SiLU(), nn.LayerNorm(dim), nn.Linear(dim, dim))
            for _ in range(layers - 1)
        )
        self.mlp = nn.Sequential(*stages)
        self.fusion_layer = nn.Linear(dim * layers, dim, bias=False)
        self.norm = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x):
        # run the stages manually so each intermediate output can be kept
        features = []
        h = x
        for stage in self.mlp:
            h = stage(h)
            features.append(h)
        # NOTE(review): cat on dim 2 assumes a 3-D (batch, tokens, features)
        # input — confirm against callers
        return self.norm(self.fusion_layer(torch.cat(features, dim=2)))
class Concat(nn.Module):
    """Parameter-free module that joins two tensors along the last dimension."""

    def __init__(self, *args, **kwargs):
        # extra args accepted (and ignored) so this can stand in for the
        # other concat modules in this file
        super().__init__()

    def forward(self, x, y, *args, **kwargs):
        """Return `x` and `y` concatenated on their final dimension."""
        return torch.cat([x, y], dim=-1)