# NOTE(review): the three lines below were Hugging Face Space web-page scrape
# residue (author "tellurion", commit d066167, message "initialize huggingface
# space demo") accidentally captured into the source; commented out so the
# module parses.
# tellurion's picture
# initialize huggingface space demo
# d066167
import torch
import torch.nn as nn
from refnet.modules.layers import zero_module
from refnet.modules.attention import MemoryEfficientAttention
from refnet.modules.transformer import BasicTransformerBlock
from refnet.util import checkpoint_wrapper, exists
from refnet.util import load_weights
class NormalizedLinear(nn.Module):
    """Linear projection followed by layer normalization of the result."""

    def __init__(self, dim, output_dim, checkpoint=True):
        super().__init__()
        # Project dim -> output_dim, then normalize over the last axis.
        self.layers = nn.Sequential(
            nn.Linear(dim, output_dim),
            nn.LayerNorm(output_dim),
        )
        # Presumably consulted by checkpoint_wrapper to toggle activation
        # checkpointing — TODO confirm against refnet.util.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        """Apply the linear + norm stack to x."""
        out = self.layers(x)
        return out
class GlobalProjection(nn.Module):
    """Two-stage projection that expands a global feature into
    heads[0] * heads[1] tokens of width output_dim."""

    def __init__(self, input_dim, output_dim, heads, dim_head=128, checkpoint=True):
        super().__init__()
        self.c_dim = output_dim
        self.dim_head = dim_head
        # Token counts after the first and second projection stages.
        self.head = (heads[0], heads[0] * heads[1])
        self.proj1 = nn.Linear(input_dim, dim_head * heads[0])
        # zero_module zero-initializes the final linear so the second stage
        # starts as a no-op contribution.
        self.proj2 = nn.Sequential(
            nn.SiLU(),
            zero_module(nn.Linear(dim_head, output_dim * heads[1])),
        )
        self.norm = nn.LayerNorm(output_dim)
        # Presumably read by checkpoint_wrapper — TODO confirm.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        """Project x into (batch, head[1], c_dim) tokens and normalize."""
        h = self.proj1(x)
        h = h.reshape(-1, self.head[0], self.dim_head).contiguous()
        h = self.proj2(h)
        h = h.reshape(-1, self.head[1], self.c_dim).contiguous()
        return self.norm(h)
class ClusterConcat(nn.Module):
    """Attend over input tokens, keep the first token_length of them,
    concatenate an embedding on the channel axis, and project."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, token_length=196, checkpoint=True):
        super().__init__()
        self.attn = MemoryEfficientAttention(input_dim, dim_head=dim_head)
        self.norm = nn.LayerNorm(input_dim)
        # MLP over the concatenated (input_dim + c_dim) channels.
        self.proj = nn.Sequential(
            nn.Linear(input_dim + c_dim, output_dim),
            nn.SiLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim),
        )
        self.token_length = token_length
        # Presumably read by checkpoint_wrapper — TODO confirm.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x, emb, fgbg=False, *args, **kwargs):
        """Fuse attended tokens with emb; optionally fold fg/bg batch halves."""
        attended = self.attn(x)
        tokens = self.norm(attended[:, :self.token_length])
        fused = self.proj(torch.cat([tokens, emb], 2))
        if fgbg:
            # Split the batch in two and rejoin the halves along the token
            # axis (assumes the batch holds fg/bg pairs — TODO confirm).
            fused = torch.cat(torch.chunk(fused, 2), 1)
        return fused
class RecoveryClusterConcat(ClusterConcat):
    """ClusterConcat variant that optionally refines the fused tokens with a
    self-attention transformer block (used for the background path)."""

    def __init__(self, input_dim, c_dim, output_dim, dim_head=64, *args, **kwargs):
        super().__init__(input_dim, c_dim, output_dim, dim_head=dim_head, *args, **kwargs)
        # Self-attention only: cross-attention is disabled, and checkpointing
        # is handled by this module's own wrapper rather than the block's.
        self.transformer = BasicTransformerBlock(
            output_dim, output_dim // dim_head, dim_head,
            disable_cross_attn=True, checkpoint=False,
        )

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        """Same fusion as ClusterConcat; when bg is True, run the extra
        transformer pass on the result."""
        tokens = self.norm(self.attn(x)[:, :self.token_length])
        out = self.proj(torch.cat([tokens, emb], 2))
        return self.transformer(out) if bg else out
class LogitClusterConcat(ClusterConcat):
    """ClusterConcat that first maps the raw embedding through an AdaptiveMLP
    applied without gradients (a frozen encoder)."""

    def __init__(self, c_dim, mlp_in_dim, mlp_ckpt_path=None, *args, **kwargs):
        super().__init__(c_dim=c_dim, *args, **kwargs)
        # AdaptiveMLP is defined later in this module; resolved at call time.
        self.mlp = AdaptiveMLP(c_dim, mlp_in_dim)
        if exists(mlp_ckpt_path):
            # Restore pretrained MLP weights when a checkpoint path is given.
            self.mlp.load_state_dict(load_weights(mlp_ckpt_path), strict=True)

    @checkpoint_wrapper
    def forward(self, x, emb, bg=False):
        """Encode emb with the gradient-free MLP, then run the parent fusion."""
        with torch.no_grad():
            projected = self.mlp(emb).detach()
        return super().forward(x, projected, bg)
class AdaptiveMLP(nn.Module):
    """MLP stack whose per-layer activations are concatenated and fused back
    into a single dim-wide, normalized output."""

    def __init__(self, dim, in_dim, layers=4, checkpoint=True):
        super().__init__()
        # First block maps in_dim -> dim; the rest refine dim -> dim.
        blocks = [nn.Sequential(nn.Linear(in_dim, dim))]
        blocks.extend(
            nn.Sequential(nn.SiLU(), nn.LayerNorm(dim), nn.Linear(dim, dim))
            for _ in range(layers - 1)
        )
        self.mlp = nn.Sequential(*blocks)
        # Fuses the concatenated per-layer features (dim * layers) back to dim.
        self.fusion_layer = nn.Linear(dim * layers, dim, bias=False)
        self.norm = nn.LayerNorm(dim)
        # Presumably read by checkpoint_wrapper — TODO confirm.
        self.checkpoint = checkpoint

    @checkpoint_wrapper
    def forward(self, x):
        """Run every block, collect each intermediate activation, then fuse.

        Assumes x is 3-D (batch, tokens, channels) — the concat is on dim=2.
        """
        features = []
        h = x
        for block in self.mlp:
            h = block(h)
            features.append(h)
        fused = self.fusion_layer(torch.cat(features, dim=2))
        return self.norm(fused)
class Concat(nn.Module):
    """Trivial module that joins two tensors along their last dimension."""

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any arguments so it can stand in for the other
        # fusion modules in this file.
        super().__init__()

    def forward(self, x, y, *args, **kwargs):
        """Return x and y concatenated on the final axis; extras are ignored."""
        return torch.cat((x, y), dim=-1)