| import math
|
| import torch
|
| import torch.nn.functional as F
|
| from math import log2
|
| from torch import nn, einsum
|
| from kornia.filters import filter2d
|
| from einops import reduce, rearrange, repeat
|
|
|
|
|
def exists(val):
    """Tell whether ``val`` holds an actual value, i.e. is not ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
|
|
|
|
|
def is_power_of_two(val):
    """Return ``True`` when ``val`` is a power of two.

    Integers are checked exactly with a bit test, so very large values are
    not subject to floating-point rounding. Non-integer numbers fall back to
    checking that their base-2 logarithm is a whole number.

    Zero and negative values are never powers of two; they yield ``False``
    rather than triggering the domain error ``log2`` would raise.
    """
    if isinstance(val, int):
        # A positive integer is a power of two iff exactly one bit is set.
        return val > 0 and (val & (val - 1)) == 0
    # Guard the logarithm: log2 is undefined for values <= 0.
    return bool(val > 0) and log2(val).is_integer()
|
|
|
|
|
def default(val, d):
    """Return ``val`` when it is set, otherwise the fallback ``d``.

    "Set" means ``exists(val)`` is true, so only ``None`` triggers the
    fallback; other falsy values are returned unchanged.
    """
    if exists(val):
        return val
    return d
|
|
|
|
|
def get_1d_dct(i, freq, L):
    """Evaluate one 1-D cosine basis value.

    Computes ``cos(pi * freq * (i + 0.5) / L) / sqrt(L)`` and, for every
    non-zero ``freq``, additionally scales the result by ``sqrt(2)``.

    Args:
        i: Sample position.
        freq: Frequency index; ``0`` selects the unscaled (constant) term.
        L: Length used both in the cosine argument and the normalisation.

    Returns:
        The basis value as a Python float.
    """
    angle = math.pi * freq * (i + 0.5) / L
    basis = math.cos(angle) / math.sqrt(L)
    if freq == 0:
        # The zero-frequency term is left without the extra sqrt(2) factor.
        return basis
    return basis * math.sqrt(2)
|
|
|
|
|
def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a ``(1, channel, width, width)`` tensor of 2-D cosine weights.

    The channel axis is divided into ``len(fidx_u)`` consecutive groups of
    ``channel // len(fidx_u)`` channels. Group ``i`` is filled with the
    separable basis ``get_1d_dct(x, fidx_u[i], width) *
    get_1d_dct(y, fidx_v[i], width)`` at spatial position ``(x, y)``.

    Args:
        width: Side length of the (square) spatial grid.
        channel: Total number of channels in the output.
        fidx_u: Frequency indices applied along the first spatial axis.
        fidx_v: Frequency indices applied along the second spatial axis;
            paired element-wise with ``fidx_u``.

    Returns:
        A tensor of shape ``(1, channel, width, width)`` in the default
        dtype of ``torch.zeros``.

    Note:
        Channels left over when ``channel`` is not a multiple of
        ``len(fidx_u)`` stay zero; if ``channel < len(fidx_u)`` the group
        size is 0 and the whole tensor stays zero.
    """
    dct_weights = torch.zeros(1, channel, width, width)
    c_part = channel // len(fidx_u)

    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        # The basis is separable, so evaluate each 1-D factor once per
        # position instead of once per (x, y) pair. float64 keeps the
        # products identical to multiplying the Python floats directly;
        # the cast to the buffer dtype happens on assignment.
        basis_x = torch.tensor(
            [get_1d_dct(x, u_x, width) for x in range(width)],
            dtype=torch.float64,
        )
        basis_y = torch.tensor(
            [get_1d_dct(y, v_y, width) for y in range(width)],
            dtype=torch.float64,
        )
        # The outer product broadcasts over the batch and channel-group axes.
        dct_weights[:, i * c_part : (i + 1) * c_part] = torch.outer(
            basis_x, basis_y
        )

    return dct_weights
|
|
|
|
|
class Blur(nn.Module):
    """Filter a feature map with a fixed 3x3 kernel built from ``[1, 2, 1]``.

    The 1-D taps are stored as the non-trainable buffer ``f``; the 2-D kernel
    is their outer product, rebuilt on each call.
    """

    def __init__(self):
        super().__init__()
        taps = torch.Tensor([1, 2, 1])
        self.register_buffer("f", taps)

    def forward(self, x):
        """Apply the kernel to ``x`` through ``filter2d`` with ``normalized=True``."""
        taps = self.f
        # Shapes (1, 1, 3) and (1, 3, 1) broadcast to a (1, 3, 3) kernel.
        row = taps[None, None, :]
        col = taps[None, :, None]
        return filter2d(x, row * col, normalized=True)
|
|
|
|
|
class ChanNorm(nn.Module):
    """Normalise over dimension 1 with a learned per-channel scale and shift.

    Args:
        dim: Size of dimension 1 of the inputs; sets the shape of the
            learned parameters ``g`` (initialised to ones) and ``b``
            (initialised to zeros), both shaped ``(1, dim, 1, 1)``.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        param_shape = (1, dim, 1, 1)
        self.g = nn.Parameter(torch.ones(*param_shape))
        self.b = nn.Parameter(torch.zeros(*param_shape))

    def forward(self, x):
        """Return ``(x - mean) / sqrt(var + eps) * g + b`` with stats over dim 1."""
        variance = torch.var(x, dim=1, unbiased=False, keepdim=True)
        centre = torch.mean(x, dim=1, keepdim=True)
        normalised = (x - centre) / (variance + self.eps).sqrt()
        return normalised * self.g + self.b
|
|
|
|
|
def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
    """Build a zero-padded convolution whose total padding is ``kernel_size - 1``.

    The padding is split so that the left/top side receives
    ``kernel_size // 2`` and the right/bottom side receives the remainder,
    which is one less for even kernel sizes.

    Args:
        dim_in: Number of input channels.
        dim_out: Number of output channels.
        kernel_size: Integer size of the square kernel.
        bias: Whether the convolution has a bias term.

    Returns:
        An ``nn.Sequential`` of ``nn.ZeroPad2d`` followed by ``nn.Conv2d``.
    """
    leading = kernel_size // 2
    trailing = kernel_size - 1 - leading

    padding = nn.ZeroPad2d((leading, trailing, leading, trailing))
    conv = nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias)
    return nn.Sequential(padding, conv)
|
|
|
|
|
class DepthWiseConv2d(nn.Module):
    """A per-channel convolution followed by a 1x1 channel-mixing convolution.

    Args:
        dim_in: Number of input channels; also the group count and output
            channel count of the first convolution.
        dim_out: Number of channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the first convolution.
        padding: Padding of the first convolution.
        stride: Stride of the first convolution.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
        super().__init__()
        # groups == channels: every input channel is filtered on its own.
        per_channel = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        mix_channels = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        self.net = nn.Sequential(per_channel, mix_channels)

    def forward(self, x):
        """Run ``x`` through both convolutions in sequence."""
        return self.net(x)
|
|
|
|
|
class FCANet(nn.Module):
    """Channel gate driven by a cosine-weighted spatial sum of the input.

    The input is multiplied element-wise by a fixed ``dct_weights`` buffer,
    summed over both spatial axes, and passed through a small
    conv / LeakyReLU / conv / sigmoid stack.

    Args:
        chan_in: Number of input channels.
        chan_out: Number of channels of the produced gate.
        reduction: Divisor for the hidden width, which is
            ``max(3, chan_out // reduction)``.
        width: Spatial side length used to build the weight buffer.
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()

        # 16 frequency pairs: (0, k) for k in 0..7, followed by (k, 0).
        zero_freqs = [0] * 8
        ramp_freqs = list(range(8))
        freqs_u = zero_freqs + ramp_freqs
        freqs_v = ramp_freqs + zero_freqs
        self.register_buffer(
            "dct_weights", get_dct_weights(width, chan_in, freqs_u, freqs_v)
        )

        chan_intermediate = max(3, chan_out // reduction)

        squeeze = nn.Conv2d(chan_in, chan_intermediate, 1)
        excite = nn.Conv2d(chan_intermediate, chan_out, 1)
        self.net = nn.Sequential(squeeze, nn.LeakyReLU(0.1), excite, nn.Sigmoid())

    def forward(self, x):
        """Return the gate computed from ``x``, with 1x1 spatial size."""
        weighted = x * self.dct_weights
        # Sum over the full spatial extent, keeping two singleton spatial axes.
        pooled = reduce(
            weighted, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
        )
        return self.net(pooled)
|
|
|
|
|
class Generator(nn.Module):
    """Upsampling generator that maps a latent vector to an image.

    A transposed convolution first turns the latent vector into a 4x4
    feature map. Each following layer doubles the spatial size. Selected
    layers also produce a per-channel gate (the ``sle`` module) that is
    multiplied into the feature map of a later, higher-resolution layer, as
    given by ``sle_map``.

    Args:
        image_size: Output side length; must be a power of two.
        latent_dim: Size of the input latent vector.
        fmap_max: Upper bound on per-layer channel counts; ``None`` falls
            back to ``latent_dim``.
        fmap_inverse_coef: Layer ``n`` targets ``2 ** (fmap_inverse_coef - n)``
            channels before clamping to ``fmap_max``.
        transparent: Emit 4 output channels.
        greyscale: Emit 1 output channel (ignored when ``transparent``).
        attn_res_layers: Feature-map widths at which an attention block is
            applied before upsampling.
        freq_chan_attn: Use ``FCANet`` instead of ``GlobalContext`` for the
            gating modules.
        syncbatchnorm: Use ``nn.SyncBatchNorm`` instead of ``nn.BatchNorm2d``.
        antialias: Insert a ``Blur`` after each upsampling step.
    """

    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=(),
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        super().__init__()
        # Validate before the size is used for anything else.
        assert is_power_of_two(image_size), "image size must be a power of 2"
        resolution = log2(image_size)

        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        # Bind to a distinct local name: assigning to ``Blur`` here would make
        # it local to this function, so reading the module-level class on the
        # right-hand side would raise UnboundLocalError when antialias=True.
        blur_class = Blur if antialias else nn.Identity

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        self.latent_dim = latent_dim

        fmap_max = default(fmap_max, latent_dim)

        # latent (b, c, 1, 1) -> (b, 2c, 4, 4); GLU halves the channels again.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        num_layers = int(resolution) - 2
        self.res_layers = range(2, num_layers + 2)

        # Channel count after each layer; layers at index >= 8 are fixed to 3.
        features = [latent_dim]
        for layer_res in self.res_layers:
            if layer_res >= 8:
                features.append(3)
            else:
                features.append(min(2 ** (fmap_inverse_coef - layer_res), fmap_max))

        in_out_features = list(zip(features[:-1], features[1:]))

        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # source layer -> resolution whose feature map receives the gate;
        # only pairs that fit inside the requested resolution are kept.
        self.sle_map = {
            source: target
            for source, target in ((3, 7), (4, 8), (5, 9), (6, 10))
            if source <= resolution and target <= resolution
        }

        self.num_layers_spatial_res = 1

        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2**res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # The gate must match the channels of the layer that outputs
                # the target resolution, i.e. the layer indexed one below it.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]

                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

            layer = nn.ModuleList(
                [
                    nn.Sequential(
                        PixelShuffleUpsample(chan_in),
                        blur_class(),
                        Conv2dSame(chan_in, chan_out * 2, 4),
                        Noise(),
                        norm_class(chan_out * 2),
                        nn.GLU(dim=1),
                    ),
                    sle,
                    attn,
                ]
            )
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        """Generate images from a batch of latent vectors shaped ``(b, c)``."""
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)

        # Gates produced so far, keyed by the resolution they apply to.
        residuals = dict()

        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            # After upsampling, x sits at resolution res + 1.
            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
|
|
|
|
|
class GlobalContext(nn.Module):
    """Channel gate computed from a softmax-weighted spatial pooling.

    A 1x1 convolution scores every spatial position; the softmax of those
    scores weights a sum of the input over all positions. The pooled vector
    then goes through a conv / LeakyReLU / conv / sigmoid stack.

    Args:
        chan_in: Number of input channels.
        chan_out: Number of channels of the produced gate; the hidden width
            is ``max(3, chan_out // 2)``.
    """

    def __init__(self, *, chan_in, chan_out):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)

        chan_intermediate = max(3, chan_out // 2)
        squeeze = nn.Conv2d(chan_in, chan_intermediate, 1)
        excite = nn.Conv2d(chan_intermediate, chan_out, 1)
        self.net = nn.Sequential(squeeze, nn.LeakyReLU(0.1), excite, nn.Sigmoid())

    def forward(self, x):
        """Return the gate for ``x``, shaped ``(b, chan_out, 1, 1)``."""
        # One score per position, normalised over all flattened positions.
        scores = self.to_k(x).flatten(2)
        weights = scores.softmax(dim=-1)
        pooled = einsum("b i n, b c n -> b c i", weights, x.flatten(2))
        return self.net(pooled.unsqueeze(-1))
|
|
|
|
|
class LinearAttention(nn.Module):
    """Two attention branches over a feature map, concatenated and projected.

    The first branch builds a key/value context over all positions and
    reads it with softmaxed queries. The second branch attends, for each
    position, over the ``kernel_size x kernel_size`` neighbourhood extracted
    with ``F.unfold``. Both outputs are concatenated along channels and
    mapped back to ``dim`` channels.

    Args:
        dim: Number of input and output channels.
        dim_head: Channels per head.
        heads: Number of heads.
        kernel_size: Neighbourhood size used by the second branch.
    """

    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads

        self.kernel_size = kernel_size
        self.nonlin = nn.GELU()

        # Projections for the global-context branch.
        self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)

        # Projections for the neighbourhood branch.
        self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)

        self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)

    def forward(self, fmap):
        """Return the attended feature map, same shape as ``fmap``."""
        heads = self.heads
        height, width = fmap.shape[-2:]

        # ---- global-context branch ----
        def to_tokens(t):
            return rearrange(t, "b (h c) x y -> (b h) (x y) c", h=heads)

        lin_q = self.to_lin_q(fmap)
        lin_k, lin_v = self.to_lin_kv(fmap).chunk(2, dim=1)

        # Queries are normalised over features, keys over positions.
        lin_q = to_tokens(lin_q).softmax(dim=-1) * self.scale
        lin_k = to_tokens(lin_k).softmax(dim=-2)
        lin_v = to_tokens(lin_v)

        context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
        lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
        lin_out = rearrange(
            lin_out, "(b h) (x y) d -> b (h d) x y", h=heads, x=height, y=width
        )
        lin_out = self.nonlin(lin_out)

        # ---- neighbourhood branch ----
        q = self.to_q(fmap)
        k, v = self.to_kv(fmap).chunk(2, dim=1)
        q, k, v = (
            rearrange(t, "b (h c) x y -> (b h) c x y", h=heads) for t in (q, k, v)
        )

        # Gather each position's padded neighbourhood of keys and values.
        window_pad = self.kernel_size // 2
        k = F.unfold(k, kernel_size=self.kernel_size, padding=window_pad)
        v = F.unfold(v, kernel_size=self.kernel_size, padding=window_pad)
        k, v = (
            rearrange(t, "b (d j) n -> b n j d", d=self.dim_head) for t in (k, v)
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        sim = einsum("b i d, b i j d -> b i j", q, k)
        # Subtract the row maximum before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        full_out = einsum("b i j, b i j d -> b i d", attn, v)
        full_out = rearrange(
            full_out, "(b h) (x y) d -> b (h d) x y", h=heads, x=height, y=width
        )

        # ---- merge both branches ----
        merged = torch.cat((lin_out, full_out), dim=1)
        return self.to_out(merged)
|
|
|
|
|
class Noise(nn.Module):
    """Add noise to a 4-D input, scaled by a single learned weight.

    The weight starts at zero, so a freshly built module returns its input
    unchanged.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        """Return ``x + weight * noise``.

        Args:
            x: Input with exactly four dimensions.
            noise: Optional noise tensor. When omitted, standard normal noise
                of shape ``(b, 1, h, w)`` is drawn on ``x``'s device, so one
                noise map is broadcast across all channels.
        """
        batch, _, height, width = x.shape

        if not exists(noise):
            noise = torch.randn(batch, 1, height, width, device=x.device)

        scaled = self.weight * noise
        return x + scaled
|
|
|
|
|
class PixelShuffleUpsample(nn.Module):
    """2x upsampling: a 1x1 convolution, SiLU, then ``nn.PixelShuffle(2)``.

    Args:
        dim: Number of input channels.
        dim_out: Number of output channels; defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        out_channels = default(dim_out, dim)
        # PixelShuffle(2) folds every 4 channels into a 2x2 spatial block.
        expand = nn.Conv2d(dim, out_channels * 4, 1)

        self.net = nn.Sequential(expand, nn.SiLU(), nn.PixelShuffle(2))

        self.init_conv_(expand)

    def init_conv_(self, conv):
        """Initialise ``conv`` so each group of 4 output filters is identical.

        A quarter-size weight is Kaiming-uniform initialised and every filter
        is repeated 4 times in a row; the bias is zeroed.
        """
        out_ch, in_ch, kernel_h, kernel_w = conv.weight.shape
        base_weight = torch.empty(out_ch // 4, in_ch, kernel_h, kernel_w)
        nn.init.kaiming_uniform_(base_weight)
        tiled_weight = repeat(base_weight, "o ... -> (o 4) ...")

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        """Return ``x`` upsampled by a factor of two."""
        return self.net(x)
|
|
|
|
|
class PreNorm(nn.Module):
    """Wrap ``fn`` so its input is first normalised by a ``ChanNorm``.

    Args:
        dim: Channel count handed to ``ChanNorm``.
        fn: Callable (typically a module) applied to the normalised input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        """Return ``fn(norm(x))``."""
        normalised = self.norm(x)
        return self.fn(normalised)
|
|
|