Instructions to use michaelriedl/MonsterForge-medium with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use michaelriedl/MonsterForge-medium with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("feature-extraction", model="michaelriedl/MonsterForge-medium", trust_remote_code=True)

# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("michaelriedl/MonsterForge-medium", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from math import log2 | |
| from torch import nn, einsum | |
| from kornia.filters import filter2d | |
| from einops import reduce, rearrange, repeat | |
def exists(val):
    """Return ``True`` when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing.
    """
    if val is None:
        return False
    return True
def is_power_of_two(val):
    """Return ``True`` when *val* is a positive power of two.

    Integers are checked exactly with a bit test: a power of two has a
    single set bit, so ``val & (val - 1)`` is zero.  This avoids the
    rounding of ``log2`` for large integers (``2**53 + 1`` would otherwise
    be reported as a power of two).  Non-integer numbers fall back to the
    ``log2`` check.  Zero and negative values return ``False`` instead of
    raising a math domain error.
    """
    if isinstance(val, int):
        return val > 0 and (val & (val - 1)) == 0
    # Non-int numerics (e.g. 8.0): keep the logarithm-based check.
    return val > 0 and log2(val).is_integer()
def default(val, d):
    """Return *val* unless it is ``None``, in which case return the fallback *d*."""
    if exists(val):
        return val
    return d
def get_1d_dct(i, freq, L):
    """Evaluate one 1-D DCT basis function.

    Computes ``cos(pi * freq * (i + 0.5) / L) / sqrt(L)`` and, for every
    non-zero frequency, additionally scales the value by ``sqrt(2)``.

    Args:
        i: Sample position.
        freq: Frequency index of the basis function.
        L: Length of the signal the basis is defined over.
    """
    angle = math.pi * freq * (i + 0.5) / L
    basis = math.cos(angle) / math.sqrt(L)
    if freq == 0:
        # The zero-frequency term is left unscaled.
        return basis
    return basis * math.sqrt(2)
def get_dct_weights(width, channel, fidx_u, fidx_v):
    """Build a ``(1, channel, width, width)`` tensor of 2-D DCT basis weights.

    The channel axis is split into ``len(fidx_u)`` consecutive groups of
    ``channel // len(fidx_u)`` channels.  Group ``i`` is filled with the
    separable basis ``get_1d_dct(x, fidx_u[i], width) *
    get_1d_dct(y, fidx_v[i], width)``.  Channels left over when ``channel``
    is not a multiple of the number of frequency pairs stay at zero.

    Args:
        width: Side length of the (square) spatial grid.
        channel: Number of channels in the returned tensor.
        fidx_u: Frequency indices used along the first spatial axis.
        fidx_v: Frequency indices used along the second spatial axis.
    """
    dct_weights = torch.zeros(1, channel, width, width)
    group_size = channel // len(fidx_u)
    positions = range(width)

    for group, (freq_x, freq_y) in enumerate(zip(fidx_u, fidx_v)):
        # Evaluate each 1-D basis once per position, then combine the two
        # axes as an outer product.
        basis_x = [get_1d_dct(x, freq_x, width) for x in positions]
        basis_y = [get_1d_dct(y, freq_y, width) for y in positions]
        patch = torch.tensor([[bx * by for by in basis_y] for bx in basis_x])

        start = group * group_size
        dct_weights[:, start : start + group_size] = patch

    return dct_weights
class Blur(nn.Module):
    """Filter a feature map with a normalized 3x3 kernel built from ``[1, 2, 1]``."""

    def __init__(self):
        super().__init__()
        # 1-D taps; stored as a buffer so they follow the module across devices.
        self.register_buffer("f", torch.Tensor([1, 2, 1]))

    def forward(self, x):
        taps = self.f
        # Outer product of the taps with themselves, with a leading axis of
        # size 1: shape (1, 3, 3).
        kernel = taps.view(1, 1, -1) * taps.view(1, -1, 1)
        return filter2d(x, kernel, normalized=True)
class ChanNorm(nn.Module):
    """Normalize over the channel axis (dim 1) with a learned scale and shift.

    Args:
        dim: Number of channels; sets the size of the learned ``g`` (scale)
            and ``b`` (shift) parameters, both shaped ``(1, dim, 1, 1)``.
        eps: Value added to the variance before taking the square root.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x):
        # Biased (population) variance and mean, computed across channels.
        channel_var = x.var(dim=1, unbiased=False, keepdim=True)
        channel_mean = x.mean(dim=1, keepdim=True)
        normalized = (x - channel_mean) / (channel_var + self.eps).sqrt()
        return normalized * self.g + self.b
def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
    """Return a zero-pad + ``nn.Conv2d`` pair that keeps the spatial size.

    The padding before each spatial axis is ``kernel_size // 2``.  The
    padding after it is the same for odd kernels and one less for even
    kernels, so the total padding is always ``kernel_size - 1``.

    Args:
        dim_in: Input channels.
        dim_out: Output channels.
        kernel_size: Side length of the square kernel.
        bias: Whether the convolution has a bias term.
    """
    before = kernel_size // 2
    after = before if kernel_size % 2 else before - 1
    pad = nn.ZeroPad2d((before, after, before, after))
    conv = nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias)
    return nn.Sequential(pad, conv)
class DepthWiseConv2d(nn.Module):
    """Per-channel convolution followed by a 1x1 convolution.

    Args:
        dim_in: Input channels; also the group count of the first convolution.
        dim_out: Output channels produced by the 1x1 convolution.
        kernel_size: Kernel size of the per-channel convolution.
        padding: Padding of the per-channel convolution.
        stride: Stride of the per-channel convolution.
        bias: Whether both convolutions have a bias term.
    """

    def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
        super().__init__()
        # groups == channels: each input channel is filtered on its own.
        per_channel = nn.Conv2d(
            dim_in,
            dim_in,
            kernel_size=kernel_size,
            padding=padding,
            groups=dim_in,
            stride=stride,
            bias=bias,
        )
        # The 1x1 convolution mixes channels and sets the output width.
        pointwise = nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
        self.net = nn.Sequential(per_channel, pointwise)

    def forward(self, x):
        return self.net(x)
class FCANet(nn.Module):
    """Channel gate driven by fixed DCT-weighted spatial pooling.

    The input is multiplied by a precomputed DCT weight buffer, summed over
    both spatial axes, and passed through two 1x1 convolutions ending in a
    sigmoid.

    Args:
        chan_in: Channels of the incoming feature map.
        chan_out: Channels of the produced gate.
        reduction: Divisor for the hidden width, ``max(3, chan_out // reduction)``.
        width: Spatial side length the DCT weights are built for.
    """

    def __init__(self, *, chan_in, chan_out, reduction=4, width):
        super().__init__()
        # in paper, it seems 16 frequencies was ideal
        zero_freqs = [0] * 8
        ramp_freqs = list(range(8))
        freqs_u = zero_freqs + ramp_freqs
        freqs_v = ramp_freqs + zero_freqs
        self.register_buffer(
            "dct_weights", get_dct_weights(width, chan_in, freqs_u, freqs_v)
        )

        hidden = max(3, chan_out // reduction)
        squeeze = nn.Conv2d(chan_in, hidden, 1)
        activation = nn.LeakyReLU(0.1)
        excite = nn.Conv2d(hidden, chan_out, 1)
        self.net = nn.Sequential(squeeze, activation, excite, nn.Sigmoid())

    def forward(self, x):
        weighted = x * self.dct_weights
        # h1 = w1 = 1 sums away both spatial axes, leaving a 1x1 map per channel.
        pooled = reduce(
            weighted, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
        )
        return self.net(pooled)
class Generator(nn.Module):
    """Upsampling image generator that maps a latent vector to an image.

    A transposed convolution turns the latent vector into a 4x4 feature map.
    A stack of upsampling layers then doubles the spatial size per layer
    until ``image_size`` is reached.  Selected layers also emit a per-channel
    gate (``GlobalContext`` or ``FCANet``) that multiplies the feature map of
    a later, higher-resolution layer, as listed in ``sle_map``.

    Args:
        image_size: Side length of the generated image; must be a power of 2.
        latent_dim: Size of the input latent vector.
        fmap_max: Upper bound on the channels of any layer (falls back to
            ``latent_dim`` when ``None``).
        fmap_inverse_coef: The layer at resolution index ``n`` uses
            ``2 ** (fmap_inverse_coef - n)`` channels before clamping.
        transparent: Emit 4 output channels.
        greyscale: Emit 1 output channel (ignored when ``transparent``).
        attn_res_layers: Input widths (in pixels) of the layers that get an
            attention block ahead of the upsample.
        freq_chan_attn: Use ``FCANet`` instead of ``GlobalContext`` for the
            gates.
        syncbatchnorm: Use ``nn.SyncBatchNorm`` instead of ``nn.BatchNorm2d``.
        antialias: Insert a ``Blur`` after every upsample.
    """

    def __init__(
        self,
        *,
        image_size,
        latent_dim=256,
        fmap_max=512,
        fmap_inverse_coef=12,
        transparent=False,
        greyscale=False,
        attn_res_layers=(),
        freq_chan_attn=False,
        syncbatchnorm=False,
        antialias=False,
    ):
        super().__init__()
        # Validate before the value is used for anything else.
        assert is_power_of_two(image_size), "image size must be a power of 2"
        resolution = log2(image_size)

        # Set the normalization and blur.
        norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
        # Must not be named ``Blur``: assigning to that name would make it a
        # local of ``__init__`` and the lookup of the module-level class
        # would raise ``UnboundLocalError`` whenever ``antialias`` is true.
        blur_class = Blur if antialias else nn.Identity

        if transparent:
            init_channel = 4
        elif greyscale:
            init_channel = 1
        else:
            init_channel = 3

        self.latent_dim = latent_dim
        fmap_max = default(fmap_max, latent_dim)

        # On a 1x1 input the 4x4 transposed conv yields a 4x4 map; GLU then
        # halves the channels back to ``latent_dim``.
        self.initial_conv = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
            norm_class(latent_dim * 2),
            nn.GLU(dim=1),
        )

        num_layers = int(resolution) - 2
        self.res_layers = range(2, num_layers + 2)

        # Channels per layer: exponentially decreasing, clamped to
        # ``fmap_max``, and fixed to 3 from resolution index 8 upwards.
        features = [
            3 if n >= 8 else min(2 ** (fmap_inverse_coef - n), fmap_max)
            for n in self.res_layers
        ]
        features = [latent_dim, *features]
        in_out_features = list(zip(features[:-1], features[1:]))

        self.layers = nn.ModuleList([])
        self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))

        # Source layer -> target layer of each gate; pairs that reach beyond
        # the requested resolution are dropped.
        self.sle_map = {
            source: target
            for source, target in ((3, 7), (4, 8), (5, 9), (6, 10))
            if source <= resolution and target <= resolution
        }

        self.num_layers_spatial_res = 1

        for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
            image_width = 2**res

            attn = None
            if image_width in attn_res_layers:
                attn = PreNorm(chan_in, LinearAttention(chan_in))

            sle = None
            if res in self.sle_map:
                residual_layer = self.sle_map[res]
                # The gate is applied to the output of the layer just below
                # ``residual_layer``, so it must match that layer's channels.
                sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
                if freq_chan_attn:
                    sle = FCANet(
                        chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
                    )
                else:
                    sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)

            layer = nn.ModuleList(
                [
                    nn.Sequential(
                        PixelShuffleUpsample(chan_in),
                        blur_class(),
                        Conv2dSame(chan_in, chan_out * 2, 4),
                        Noise(),
                        norm_class(chan_out * 2),
                        nn.GLU(dim=1),
                    ),
                    sle,
                    attn,
                ]
            )
            self.layers.append(layer)

        self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)

    def forward(self, x):
        """Generate images from a 2-D ``(batch, channels)`` latent tensor."""
        x = rearrange(x, "b c -> b c () ()")
        x = self.initial_conv(x)
        x = F.normalize(x, dim=1)

        # Gates produced so far, keyed by the layer index they target.
        residuals = dict()

        for res, (up, sle, attn) in zip(self.res_layers, self.layers):
            if exists(attn):
                x = attn(x) + x

            x = up(x)

            if exists(sle):
                out_res = self.sle_map[res]
                residual = sle(x)
                residuals[out_res] = residual

            next_res = res + 1
            if next_res in residuals:
                x = x * residuals[next_res]

        return self.out_conv(x)
class GlobalContext(nn.Module):
    """Channel gate computed from a softmax-weighted spatial pooling.

    A 1x1 convolution scores every spatial position.  The softmax of those
    scores weights a sum of the input over all positions, and two 1x1
    convolutions ending in a sigmoid turn the pooled vector into the gate.

    Args:
        chan_in: Channels of the incoming feature map.
        chan_out: Channels of the produced gate.
    """

    def __init__(self, *, chan_in, chan_out):
        super().__init__()
        self.to_k = nn.Conv2d(chan_in, 1, 1)

        hidden = max(3, chan_out // 2)
        squeeze = nn.Conv2d(chan_in, hidden, 1)
        activation = nn.LeakyReLU(0.1)
        excite = nn.Conv2d(hidden, chan_out, 1)
        self.net = nn.Sequential(squeeze, activation, excite, nn.Sigmoid())

    def forward(self, x):
        # One score per position, normalized over all positions.
        weights = self.to_k(x).flatten(2).softmax(dim=-1)
        pooled = einsum("b i n, b c n -> b c i", weights, x.flatten(2))
        # Restore a trailing spatial axis so the 1x1 convs get a 4-D input.
        return self.net(pooled.unsqueeze(-1))
class LinearAttention(nn.Module):
    """Attention block with a linear-attention branch and a local-window branch.

    Both branches read the same feature map.  Their outputs are concatenated
    on the channel axis and projected back to ``dim`` channels.

    Args:
        dim: Channels of the input and output feature map.
        dim_head: Channels per attention head.
        heads: Number of attention heads.
        kernel_size: Side length of the neighbourhood used by the local branch.
    """

    def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
        super().__init__()
        hidden = dim_head * heads

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        self.kernel_size = kernel_size

        self.nonlin = nn.GELU()

        # Projections for the linear-attention branch.
        self.to_lin_q = nn.Conv2d(dim, hidden, 1, bias=False)
        self.to_lin_kv = DepthWiseConv2d(dim, hidden * 2, 3, padding=1, bias=False)

        # Projections for the local-window branch.
        self.to_q = nn.Conv2d(dim, hidden, 1, bias=False)
        self.to_kv = nn.Conv2d(dim, hidden * 2, 1, bias=False)

        self.to_out = nn.Conv2d(hidden * 2, dim, 1)

    def forward(self, fmap):
        heads = self.heads
        height, width = fmap.shape[-2:]

        def merge_heads(t):
            # Fold the heads back into channels and restore the 2-D grid.
            return rearrange(
                t, "(b h) (x y) d -> b (h d) x y", h=heads, x=height, y=width
            )

        # linear attention
        lin_q = self.to_lin_q(fmap)
        lin_k, lin_v = self.to_lin_kv(fmap).chunk(2, dim=1)
        lin_q, lin_k, lin_v = (
            rearrange(t, "b (h c) x y -> (b h) (x y) c", h=heads)
            for t in (lin_q, lin_k, lin_v)
        )

        # Queries are normalized over features, keys over positions.
        lin_q = lin_q.softmax(dim=-1) * self.scale
        lin_k = lin_k.softmax(dim=-2)

        context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
        lin_out = merge_heads(einsum("b n d, b d e -> b n e", lin_q, context))

        # conv-like full attention
        q = self.to_q(fmap)
        k, v = self.to_kv(fmap).chunk(2, dim=1)
        q, k, v = (
            rearrange(t, "b (h c) x y -> (b h) c x y", h=heads) for t in (q, k, v)
        )

        # Collect, for every position, its kernel_size x kernel_size window.
        window_pad = self.kernel_size // 2
        k, v = (
            rearrange(
                F.unfold(t, kernel_size=self.kernel_size, padding=window_pad),
                "b (d j) n -> b n j d",
                d=self.dim_head,
            )
            for t in (k, v)
        )

        q = rearrange(q, "b c ... -> b (...) c") * self.scale

        sim = einsum("b i d, b i j d -> b i j", q, k)
        # Subtract the row maximum before the softmax for numerical stability.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        full_out = merge_heads(einsum("b i j, b i j d -> b i d", attn, v))

        # add outputs of linear attention + conv like full attention
        out = torch.cat((self.nonlin(lin_out), full_out), dim=1)
        return self.to_out(out)
class Noise(nn.Module):
    """Add noise scaled by a single learned weight (initialized to zero).

    ``forward`` returns ``x + weight * noise``.  When no ``noise`` is given,
    standard-normal noise of shape ``(batch, 1, height, width)`` is sampled
    on the device of ``x`` and in the dtype of ``x``, so a module running in
    reduced precision is not promoted to the default dtype.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, noise=None):
        if noise is None:
            b, _, h, w = x.shape
            # One noise channel, broadcast across all channels of ``x``.
            noise = torch.randn(b, 1, h, w, device=x.device, dtype=x.dtype)
        return x + self.weight * noise
class PixelShuffleUpsample(nn.Module):
    """2x upsampling: 1x1 convolution, SiLU, then ``nn.PixelShuffle(2)``.

    Args:
        dim: Input channels.
        dim_out: Output channels after the shuffle; defaults to ``dim``.
    """

    def __init__(self, dim, dim_out=None):
        super().__init__()
        out_channels = default(dim_out, dim)
        # PixelShuffle(2) folds groups of 4 channels into a 2x2 patch, so the
        # convolution must produce four times the target channel count.
        expand = nn.Conv2d(dim, out_channels * 4, 1)
        self.net = nn.Sequential(expand, nn.SiLU(), nn.PixelShuffle(2))
        self.init_conv_(expand)

    def init_conv_(self, conv):
        """Re-initialize *conv*: repeat each sampled filter 4 times, zero the bias."""
        out_ch, in_ch, k_h, k_w = conv.weight.shape
        base = torch.empty(out_ch // 4, in_ch, k_h, k_w)
        nn.init.kaiming_uniform_(base)
        # Each sampled filter fills 4 consecutive output channels.
        tiled = repeat(base, "o ... -> (o 4) ...")
        conv.weight.data.copy_(tiled)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
class PreNorm(nn.Module):
    """Apply ``ChanNorm`` to the input before handing it to the wrapped module.

    Args:
        dim: Channel count passed to ``ChanNorm``.
        fn: Callable invoked on the normalized input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = ChanNorm(dim)

    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)