| | |
| | |
| | |
| |
|
| | |
| |
|
| | import math |
| | import copy |
| | from random import random |
| | from typing import List, Union |
| | from tqdm.auto import tqdm |
| | from functools import partial, wraps |
| | from contextlib import contextmanager, nullcontext |
| | from collections import namedtuple |
| | from pathlib import Path |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch.nn.parallel import DistributedDataParallel |
| | from torch import nn, einsum |
| | from torch.cuda.amp import autocast |
| | from torch.special import expm1 |
| | import torchvision.transforms as T |
| |
|
| | import kornia.augmentation as K |
| |
|
| | from einops import rearrange, repeat, reduce |
| | from einops.layers.torch import Rearrange, Reduce |
| |
|
| | from einops_exts import rearrange_many, repeat_many, check_shape |
| | from einops_exts.torch import EinopsToAndFrom |
| |
|
| | |
| | from tensorflow.keras.preprocessing import text, sequence |
| | from tensorflow.keras.preprocessing.text import Tokenizer |
| | |
| | |
| | |
| | from PD_pLMProbXDiff.UtilityPack import ( |
| | prepare_UNet_keys, modify_keys, params |
| | ) |
| |
|
| | |
| | from torchinfo import summary |
| | import json |
| | |
# Module-level default device: the first CUDA device when CUDA is available,
# otherwise the CPU. The choice is printed once at import time.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('identify the device independently', device)
| |
|
| | |
| | |
| | |
| |
|
def exists(val):
    """Return True for every value except None (falsy values such as 0 or '' still count)."""
    return not (val is None)
| |
|
def identity(t, *args, **kwargs):
    """Return `t` unchanged; any further positional or keyword arguments are accepted and ignored."""
    result = t
    return result
| |
|
def first(arr, d = None):
    """Return the element at index 0 of `arr`, or `d` when `arr` has length zero."""
    return d if len(arr) == 0 else arr[0]
| |
|
def maybe(fn):
    """Wrap a one-argument function so that None is passed through instead of being given to `fn`.

    Returns:
        A function `inner(x)` that returns `x` when `x` is None and `fn(x)` otherwise.
    """
    @wraps(fn)
    def inner(x):
        return x if x is None else fn(x)
    return inner
| |
|
def once(fn):
    """Wrap `fn` so that only the first call is forwarded; every later call returns None.

    The wrapper forwards all positional and keyword arguments, so it also works for
    callables that take several arguments (e.g. `print`).

    Args:
        fn: the callable to guard.

    Returns:
        A wrapper that returns `fn(*args, **kwargs)` on its first call and None afterwards.
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        # flip the flag before calling so that an exception in `fn` still counts as "called"
        called = True
        return fn(*args, **kwargs)

    return inner


# `print` that emits output at most once per process.
print_once = once(print)
| |
|
def default(val, d):
    """Return `val` unless it is None; otherwise return the fallback `d`.

    When `d` is callable it is invoked (with no arguments) only if the fallback is needed.
    """
    if val is not None:
        return val
    if callable(d):
        return d()
    return d
| |
|
def cast_tuple(val, length = None):
    """Coerce `val` to a tuple.

    Lists become tuples, tuples are returned as they are, and any other value is
    repeated `length` times (once when `length` is None). When `length` is given, the
    result is asserted to have exactly that many elements.
    """
    if isinstance(val, list):
        val = tuple(val)

    if isinstance(val, tuple):
        output = val
    else:
        copies = 1 if length is None else length
        output = (val,) * copies

    if length is not None:
        assert len(output) == length

    return output
| |
|
def is_float_dtype(dtype):
    """Return True when `dtype` is one of torch.float64, float32, float16 or bfloat16."""
    float_dtypes = (torch.float64, torch.float32, torch.float16, torch.bfloat16)
    return any(dtype == candidate for candidate in float_dtypes)
| |
|
def cast_uint8_images_to_float(images):
    """Divide uint8 tensors by 255; tensors of any other dtype are returned unchanged."""
    is_uint8 = images.dtype == torch.uint8
    return images / 255 if is_uint8 else images
| |
|
def module_device(module):
    """Return the device of the first parameter yielded by `module.parameters()`."""
    first_param = next(module.parameters())
    return first_param.device
| |
|
def zero_init_(m):
    """Fill `m.weight` with zeros in place, and `m.bias` too when the bias is not None."""
    nn.init.zeros_(m.weight)
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)
| |
|
def eval_decorator(fn):
    """Decorate a function so that it runs with its model in eval mode.

    The wrapper records `model.training`, calls `model.eval()`, runs `fn`, and then
    restores the recorded mode with `model.train(was_training)`. The mode is restored
    even when `fn` raises, so a failure cannot leave the model stuck in eval mode.

    Args:
        fn: callable whose first argument is the model (an object exposing
            `training`, `eval()` and `train(mode)`).

    Returns:
        The wrapped callable, carrying `fn`'s name and docstring.
    """
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        try:
            return fn(model, *args, **kwargs)
        finally:
            # always put the model back into the mode it was in before the call
            model.train(was_training)
    return inner
| |
|
def pad_tuple_to_length(t, length, fillvalue = None):
    """Right-pad `t` with `fillvalue` up to `length` items; `t` is returned as is when already long enough."""
    missing = length - len(t)
    if missing > 0:
        padding = (fillvalue,) * missing
        return (*t, *padding)
    return t
| |
|
| | |
| | |
| | |
| |
|
class Identity(nn.Module):
    """Pass-through module: constructor and forward accept and ignore any extra arguments."""

    def __init__(self, *args, **kwargs):
        # extra constructor arguments are deliberately not forwarded to nn.Module
        super().__init__()

    def forward(self, x, *args, **kwargs):
        """Return `x` unchanged."""
        out = x
        return out
| |
|
| | |
| | |
| | |
| |
|
def log(t, eps: float = 1e-12):
    """Natural logarithm of `t` after clamping its values from below at `eps`."""
    floored = t.clamp(min = eps)
    return torch.log(floored)
| |
|
def l2norm(t):
    """Normalize `t` along its last dimension with `F.normalize`."""
    normalized = F.normalize(t, dim = -1)
    return normalized
| |
|
def right_pad_dims_to(x, t):
    """View `t` with trailing singleton axes appended until it has as many dims as `x`.

    `t` is returned unchanged when it already has at least `x.ndim` dimensions.
    """
    extra_dims = x.ndim - t.ndim
    if extra_dims > 0:
        return t.view(*t.shape, *((1,) * extra_dims))
    return t
| |
|
def masked_mean(t, *, dim, mask = None):
    """Mean of `t` over `dim`, counting only positions where `mask` is True.

    Without a mask this is a plain `t.mean(dim = dim)`. With a mask, the mask is
    rearranged from 'b n' to 'b n 1', masked-out entries of `t` are zeroed before
    summing over `dim`, and the sum is divided by the number of True entries
    (clamped from below at 1e-5).
    """
    if mask is None:
        return t.mean(dim = dim)

    count = mask.sum(dim = dim, keepdim = True)
    keep = rearrange(mask, 'b n -> b n 1')
    total = t.masked_fill(~keep, 0.).sum(dim = dim)

    return total / count.clamp(min = 1e-5)
| |
|
def resize_image_to(
    image,
    target_image_size,
    clamp_range = None
):
    """Resize the last axis of `image` to `target_image_size` by linear interpolation.

    The input is returned untouched when its last dimension already has the target
    size. Otherwise it is cast to float and passed to `F.interpolate` with
    `mode = 'linear'` and `align_corners = True`.

    NOTE(review): `clamp_range` is accepted but never used by this function.
    """
    if image.shape[-1] == target_image_size:
        return image

    return F.interpolate(image.float(), target_image_size, mode = 'linear', align_corners = True)
| |
|
| | |
| | |
| | |
def normalize_neg_one_to_one(img):
    """Apply the affine map `img * 2 - 1` (sends 0 to -1 and 1 to 1)."""
    doubled = img * 2
    return doubled - 1
| |
|
def unnormalize_zero_to_one(normed_img):
    """Apply the affine map `(normed_img + 1) * 0.5` (sends -1 to 0 and 1 to 1)."""
    shifted = normed_img + 1
    return shifted * 0.5
| |
|
| | |
def prob_mask_like(shape, prob, device):
    """Boolean tensor of `shape` whose entries are True with probability `prob`.

    `prob == 1` and `prob == 0` return all-True / all-False tensors without sampling;
    any other value compares uniform samples from [0, 1) against `prob`.
    """
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    if prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)

    uniform = torch.zeros(shape, device = device).float().uniform_(0, 1)
    return uniform < prob
| |
|
| | |
| | |
| | |
| | |
| |
|
@torch.jit.script
def beta_linear_log_snr(t):
    """Log-SNR of the "linear" schedule: -log(expm1(1e-4 + 10 * t^2))."""
    exponent = 1e-4 + 10 * (t ** 2)
    return -torch.log(expm1(exponent))
| |
|
@torch.jit.script
def alpha_cosine_log_snr(t, s: float = 0.008):
    """Log-SNR of the "cosine" schedule with offset `s`: -log(cos(angle)^-2 - 1), clamped at 1e-5 inside the log."""
    angle = (t + s) / (1 + s) * math.pi * 0.5
    inv_cos_sq = torch.cos(angle) ** -2
    return -log(inv_cos_sq - 1, eps = 1e-5)
| |
|
def log_snr_to_alpha_sigma(log_snr):
    """Return `(alpha, sigma) = (sqrt(sigmoid(log_snr)), sqrt(sigmoid(-log_snr)))`."""
    alpha = torch.sqrt(torch.sigmoid(log_snr))
    sigma = torch.sqrt(torch.sigmoid(-log_snr))
    return alpha, sigma
| |
|
class GaussianDiffusionContinuousTimes(nn.Module):
    """Noising / posterior helpers for a diffusion process indexed by continuous times.

    The noise level at time `t` comes from a log-SNR schedule stored in `self.log_snr`;
    the signal scale `alpha` and noise scale `sigma` are derived from it with
    `log_snr_to_alpha_sigma`.

    Args:
        noise_schedule: "linear" or "cosine"; selects the log-SNR function.
        timesteps: number of sampling steps; used by `get_sampling_timesteps` and as
            the default step size in `q_posterior`.

    Raises:
        ValueError: when `noise_schedule` is neither "linear" nor "cosine".
    """

    def __init__(self, *, noise_schedule, timesteps = 1000):
        super().__init__()

        if noise_schedule == "linear":
            schedule_fn = beta_linear_log_snr
        elif noise_schedule == "cosine":
            schedule_fn = alpha_cosine_log_snr
        else:
            raise ValueError(f'invalid noise schedule {noise_schedule}')

        self.log_snr = schedule_fn
        self.num_timesteps = timesteps

    def get_times(self, batch_size, noise_level, *, device):
        """Return a float32 vector of length `batch_size` filled with `noise_level`."""
        return torch.full((batch_size,), noise_level, device = device, dtype = torch.float32)

    def sample_random_times(self, batch_size, max_thres = 0.999, *, device):
        """Draw `batch_size` times uniformly from [0, max_thres)."""
        times = torch.zeros((batch_size,), device = device).float()
        return times.uniform_(0, max_thres)

    def get_condition(self, times):
        """Return the log-SNR at `times`, or None when `times` is None."""
        if not exists(times):
            return times
        return self.log_snr(times)

    def get_sampling_timesteps(self, batch, *, device):
        """Return one stacked (2, batch) tensor per sampling step.

        A grid of `num_timesteps + 1` times from 1 down to 0 is repeated for each
        batch element; row 0 of every returned tensor holds the current time and
        row 1 the next time.
        """
        grid = torch.linspace(1., 0., self.num_timesteps + 1, device = device)
        grid = repeat(grid, 't -> b t', b = batch)
        pairs = torch.stack((grid[:, :-1], grid[:, 1:]), dim = 0)
        return pairs.unbind(dim = -1)

    def q_posterior(self, x_start, x_t, t, *, t_next = None):
        """Mean, variance and clipped log-variance of the posterior q(x_{t_next} | x_t, x_start).

        `t_next` defaults to one step (`1 / num_timesteps`) before `t`, clamped at 0.
        """
        t_next = default(t_next, lambda: (t - 1. / self.num_timesteps).clamp(min = 0.))

        # reference: https://openreview.net/attachment?id=2LdBqxc1Yv&name=supplementary_material
        pad_like_x_t = partial(right_pad_dims_to, x_t)
        log_snr = pad_like_x_t(self.log_snr(t))
        log_snr_next = pad_like_x_t(self.log_snr(t_next))

        alpha, sigma = log_snr_to_alpha_sigma(log_snr)
        alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

        # c = 1 - exp(log_snr - log_snr_next)
        c = -expm1(log_snr - log_snr_next)
        posterior_mean = alpha_next * (x_t * (1 - c) / alpha + c * x_start)

        posterior_variance = (sigma_next ** 2) * c
        posterior_log_variance_clipped = log(posterior_variance, eps = 1e-20)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def q_sample(self, x_start, t, noise = None):
        """Noise `x_start` to time `t`; returns `(alpha * x_start + sigma * noise, log_snr)`.

        A float `t` is broadcast to one time per batch element. Fresh Gaussian noise is
        drawn when `noise` is None. The returned `log_snr` is the un-padded per-sample value.
        """
        dtype = x_start.dtype

        if isinstance(t, float):
            t = torch.full((x_start.shape[0],), t, device = x_start.device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_start))
        log_snr = self.log_snr(t).type(dtype)
        alpha, sigma = log_snr_to_alpha_sigma(right_pad_dims_to(x_start, log_snr))

        noised = alpha * x_start + sigma * noise
        return noised, log_snr

    def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
        """Move a sample from noise level `from_t` to `to_t`.

        Float times are broadcast to one time per batch element; fresh Gaussian noise
        is drawn when `noise` is None.
        """
        shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
        batch = shape[0]

        if isinstance(from_t, float):
            from_t = torch.full((batch,), from_t, device = device, dtype = dtype)

        if isinstance(to_t, float):
            to_t = torch.full((batch,), to_t, device = device, dtype = dtype)

        noise = default(noise, lambda: torch.randn_like(x_from))

        alpha, sigma = log_snr_to_alpha_sigma(right_pad_dims_to(x_from, self.log_snr(from_t)))
        alpha_to, sigma_to = log_snr_to_alpha_sigma(right_pad_dims_to(x_from, self.log_snr(to_t)))

        return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

    def predict_start_from_noise(self, x_t, t, noise):
        """Invert `q_sample`: `(x_t - sigma * noise) / alpha`, with `alpha` clamped at 1e-8."""
        padded_log_snr = right_pad_dims_to(x_t, self.log_snr(t))
        alpha, sigma = log_snr_to_alpha_sigma(padded_log_snr)
        return (x_t - sigma * noise) / alpha.clamp(min = 1e-8)
| |
|
| | |
| | |
| | |
| |
|
class LayerNorm(nn.Module):
    """Bias-free layer norm over a configurable axis with a learned gain `g`.

    Args:
        feats: size of the normalized axis.
        stable: when True, the input is first divided by its (detached) maximum
            along the normalized axis.
        dim: negative index of the axis to normalize; `g` gets `-dim - 1` trailing
            singleton axes so that it broadcasts along that axis.
    """

    def __init__(self, feats, stable = False, dim = -1):
        super().__init__()
        self.stable = stable
        self.dim = dim

        trailing_ones = (1,) * (-dim - 1)
        self.g = nn.Parameter(torch.ones(feats, *trailing_ones))

    def forward(self, x):
        dtype, dim = x.dtype, self.dim

        if self.stable:
            x = x / x.amax(dim = dim, keepdim = True).detach()

        # a larger epsilon is used for every dtype other than float32
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
        var = torch.var(x, dim = dim, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = dim, keepdim = True)

        inv_std = (var + eps).rsqrt().type(dtype)
        return (x - mean) * inv_std * self.g.type(dtype)


# LayerNorm over the second-to-last axis (the channel axis of a (b, c, n) tensor).
ChanLayerNorm = partial(LayerNorm, dim = -2)
| |
|
class Always():
    """Callable that returns a fixed value regardless of the arguments it is called with."""

    def __init__(self, val):
        self.val = val

    def __call__(self, *args, **kwargs):
        constant = self.val
        return constant
| |
|
class Residual(nn.Module):
    """Wrap `fn` with a skip connection: `forward(x, **kwargs)` returns `fn(x, **kwargs) + x`."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        branch = self.fn(x, **kwargs)
        return branch + x
| |
|
class Parallel(nn.Module):
    """Apply every given module to the same input and return the sum of their outputs."""

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        return sum(branch(x) for branch in self.fns)
| | |
| | |
| | |
| | |
| | |
class PerceiverAttention(nn.Module):
    """Cross-attention from a set of latents to the concatenation of a sequence and the latents.

    Queries are computed from `latents`; keys and values are computed from `x` and
    `latents` concatenated along the token axis.

    Args:
        dim: feature size of both `x` and `latents`.
        dim_head: size of each attention head.
        heads: number of attention heads.
        cosine_sim_attn: when True, queries and keys are l2-normalized and the
            similarities are multiplied by 16 instead of scaling queries by `dim_head ** -0.5`.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = 1 if cosine_sim_attn else dim_head ** -0.5
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, latents, mask = None):
        """Attend from `latents` to `[x, latents]`.

        `mask`, when given, is a 'b j' boolean mask over the tokens of `x`; it is
        extended with True for the latent positions appended after `x`.
        """
        x = self.norm(x)
        latents = self.norm_latents(latents)

        heads = self.heads

        queries = self.to_q(latents)

        # keys / values see the sequence tokens followed by the latents themselves
        keys, values = self.to_kv(torch.cat((x, latents), dim = -2)).chunk(2, dim = -1)

        queries, keys, values = rearrange_many((queries, keys, values), 'b n (h d) -> b h n d', h = heads)

        queries = queries * self.scale

        if self.cosine_sim_attn:
            queries, keys = map(l2norm, (queries, keys))

        sim = einsum('... i d, ... j d -> ... i j', queries, keys) * self.cosine_sim_scale

        if exists(mask):
            fill_value = -torch.finfo(sim.dtype).max
            # latent positions are always attendable
            mask = F.pad(mask, (0, latents.shape[-2]), value = True)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, fill_value)

        # softmax is computed in float32 and cast back to the similarity dtype
        attn = sim.softmax(dim = -1, dtype = torch.float32).to(sim.dtype)

        out = einsum('... i j, ... j d -> ... i d', attn, values)
        out = rearrange(out, 'b h n d -> b n (h d)', h = heads)
        return self.to_out(out)
| |
|
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence into a fixed number of latent tokens.

    Learned latents (optionally preceded by latents projected from the mean-pooled
    sequence) are refined by `depth` rounds of `PerceiverAttention` + `FeedForward`,
    each with a residual connection.

    Args:
        dim: feature size of the input sequence and of the latents.
        depth: number of attention / feed-forward rounds.
        dim_head, heads, cosine_sim_attn: forwarded to `PerceiverAttention`.
        num_latents: number of learned latent tokens.
        num_latents_mean_pooled: number of extra latents derived from the mean-pooled
            sequence; 0 disables them.
        max_seq_len: size of the learned positional embedding table.
        ff_mult: hidden-size multiplier forwarded to `FeedForward`.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        num_latents = 64,
        num_latents_mean_pooled = 4,
        max_seq_len = 512,
        ff_mult = 4,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.latents = nn.Parameter(torch.randn(num_latents, dim))

        self.to_latents_from_mean_pooled_seq = None

        if num_latents_mean_pooled > 0:
            self.to_latents_from_mean_pooled_seq = nn.Sequential(
                LayerNorm(dim),
                nn.Linear(dim, dim * num_latents_mean_pooled),
                Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
            )

        blocks = []
        for _ in range(depth):
            attn = PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn)
            ff = FeedForward(dim = dim, mult = ff_mult)
            blocks.append(nn.ModuleList([attn, ff]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask = None):
        """Return the refined latents for sequence `x`; `mask` is forwarded to every attention layer."""
        batch, seq_len = x.shape[0], x.shape[1]
        positions = torch.arange(seq_len, device = x.device)

        x_with_pos = x + self.pos_emb(positions)

        latents = repeat(self.latents, 'n d -> b n d', b = batch)

        if exists(self.to_latents_from_mean_pooled_seq):
            # the pooling uses an all-True mask, i.e. every position of x contributes
            all_positions = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)
            pooled = masked_mean(x, dim = 1, mask = all_positions)
            pooled_latents = self.to_latents_from_mean_pooled_seq(pooled)
            latents = torch.cat((pooled_latents, latents), dim = -2)

        for attn, ff in self.layers:
            latents = attn(x_with_pos, latents, mask = mask) + latents
            latents = ff(latents) + latents

        return latents
| |
|
| | |
| | |
| | |
class Attention(nn.Module):
    """Multi-head self-attention whose keys and values use a single shared head.

    Queries are split into `heads` heads of size `dim_head`; keys and values are
    projected to one `dim_head`-sized head that every query head attends to. A
    learned null key/value pair is always prepended, and keys/values projected from
    an optional `context` are prepended in front of that.

    Args:
        dim: feature size of the input tokens.
        dim_head: size of each attention head.
        heads: number of query heads.
        context_dim: feature size of the optional context; when None, passing a
            context to `forward` fails an assertion.
        cosine_sim_attn: when True, queries and keys are l2-normalized and the
            similarities are multiplied by 16 instead of scaling queries by `dim_head ** -0.5`.
    """

    def __init__(
        self,
        dim,
        *,
        dim_head = 64,
        heads = 8,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context = None, mask = None, attn_bias = None):
        """Attend over the tokens of `x` (plus the null key/value and optional context).

        Args:
            x: input of shape (b, n, dim).
            context: optional tensor projected to extra keys/values.
            mask: optional 'b j' boolean mask over the tokens of `x`; True keeps a key.
            attn_bias: optional tensor added to the similarities before masking.

        NOTE(review): the mask is only padded for the null key/value, so it does not
        line up with the keys when `context` is given as well — confirm the intended
        use before combining `mask` and `context`.
        """
        b = x.shape[0]

        x = self.norm(x)

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
        q = q * self.scale

        # prepend the learned null key / value so every query has something to attend to
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        # prepend keys / values projected from the conditioning context
        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            k = torch.cat((ck, k), dim = -2)
            v = torch.cat((cv, v), dim = -2)

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale

        if exists(attn_bias):
            sim = sim + attn_bias

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # account for the null key prepended above
            mask = F.pad(mask, (1, 0), value = True)
            # sim is (b, h, i, j): the mask needs singleton head AND query axes so that
            # its batch axis broadcasts against sim's batch axis rather than its head axis
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax is computed in float32 and cast back to the similarity dtype
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b j d -> b h i d', attn, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
| |
|
| | |
| | |
| | |
def Upsample(dim, dim_out = None):
    """Build a 2x nearest-neighbour upsampling followed by a kernel-3 Conv1d from `dim` to `dim_out`.

    `dim_out` falls back to `dim` when it is None.
    """
    out_channels = dim if dim_out is None else dim_out

    stages = (
        nn.Upsample(scale_factor = 2, mode = 'nearest'),
        nn.Conv1d(dim, out_channels, 3, padding = 1),
    )
    return nn.Sequential(*stages)
| |
|
class PixelShuffleUpsample(nn.Module):
    """Upsampling via a 1x1 convolution, SiLU and `nn.PixelShuffle(2)`.

    Code shared by @MalumaDev at DALLE2-pytorch for addressing checkerboard artifacts:
    https://arxiv.org/ftp/arxiv/papers/1707/1707.02937.pdf

    NOTE(review): the convolution is a Conv1d (3-D output) while `nn.PixelShuffle` is
    documented for inputs shaped (*, C * r^2, H, W) — confirm this combination
    behaves as intended before relying on this module.
    """

    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)
        expand = nn.Conv1d(dim, dim_out * 4, 1)

        self.net = nn.Sequential(
            expand,
            nn.SiLU(),
            nn.PixelShuffle(2)
        )

        self.init_conv_(expand)

    def init_conv_(self, conv):
        """Initialise `conv` so that each group of 4 consecutive output channels shares one kaiming-uniform filter; the bias is zeroed."""
        out_channels, in_channels, kernel = conv.weight.shape
        base_weight = torch.empty(out_channels // 4, in_channels, kernel)
        nn.init.kaiming_uniform_(base_weight)
        tiled_weight = repeat(base_weight, 'o ... -> (o 4) ...')

        conv.weight.data.copy_(tiled_weight)
        nn.init.zeros_(conv.bias.data)

    def forward(self, x):
        return self.net(x)
| |
|
def Downsample(dim, dim_out = None):
    """Build a 2x downsampling block for (b, c, n) tensors.

    Pairs of neighbouring positions are folded into the channel axis (doubling the
    channels and halving the length), then a 1x1 Conv1d maps `dim * 2` channels to
    `dim_out` (which falls back to `dim` when None).
    """
    out_channels = dim if dim_out is None else dim_out

    fold_pairs = Rearrange('b c (h s1) -> b (c s1) h', s1 = 2)
    project = nn.Conv1d(dim * 2, out_channels, 1)
    return nn.Sequential(fold_pairs, project)
| |
|
class SinusoidalPosEmb(nn.Module):
    """Fixed sinusoidal embedding of a 1-D tensor of positions / times.

    The output concatenates `dim // 2` sine features with `dim // 2` cosine features,
    using geometrically spaced frequencies from 1 down to 1 / 10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half_dim = self.dim // 2
        log_step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim, device = x.device) * -log_step)
        # outer product: one row per input value, one column per frequency
        angles = rearrange(x, 'i -> i 1') * rearrange(freqs, 'j -> 1 j')
        return torch.cat((angles.sin(), angles.cos()), dim = -1)
| |
|
class LearnedSinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding with learned frequencies, following @crowsonkb's lead.

    https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8

    For a 1-D input of length b the output has shape (b, dim + 1): the raw input
    value followed by `dim // 2` sine and `dim // 2` cosine features.
    """

    def __init__(self, dim):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x):
        x = rearrange(x, 'b -> b 1')
        angles = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
        features = torch.cat((angles.sin(), angles.cos()), dim = -1)
        # the raw input is kept as the first feature
        return torch.cat((x, features), dim = -1)
| |
|
class Block(nn.Module):
    """GroupNorm -> optional scale/shift -> SiLU -> kernel-3 Conv1d.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        groups: number of groups for the GroupNorm.
        norm: when False the GroupNorm is replaced by `Identity`.
    """

    def __init__(
        self,
        dim,
        dim_out,
        groups = 8,
        norm = True
    ):
        super().__init__()
        self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
        self.activation = nn.SiLU()
        self.project = nn.Conv1d(dim, dim_out, 3, padding = 1)

    def forward(self, x, scale_shift = None):
        """Run the block; `scale_shift`, when given, is a `(scale, shift)` pair applied as `x * (scale + 1) + shift` after the norm."""
        normed = self.groupnorm(x)

        if scale_shift is not None:
            scale, shift = scale_shift
            normed = normed * (scale + 1) + shift

        return self.project(self.activation(normed))
| |
|
class ResnetBlock(nn.Module):
    """Residual block of two `Block`s with optional time conditioning, cross-attention and global context.

    Args:
        dim: number of input channels.
        dim_out: number of output channels.
        cond_dim: feature size of the conditioning tokens; enables cross-attention when given.
        time_cond_dim: size of the time embedding; enables the scale/shift MLP when given.
        groups: GroupNorm groups for both blocks.
        linear_attn: use `LinearCrossAttention` instead of `CrossAttention`.
        use_gca: gate the output with `GlobalContext` (otherwise multiply by 1).
        squeeze_excite: accepted but not used by this class.
        **attn_kwargs: forwarded to the cross-attention class.
    """

    def __init__(
        self,
        dim,
        dim_out,
        *,
        cond_dim = None,
        time_cond_dim = None,
        groups = 8,
        linear_attn = False,
        use_gca = False,
        squeeze_excite = False,
        **attn_kwargs
    ):
        super().__init__()

        self.time_mlp = None

        if exists(time_cond_dim):
            # produces a scale and a shift per output channel
            self.time_mlp = nn.Sequential(
                nn.SiLU(),
                nn.Linear(time_cond_dim, dim_out * 2)
            )

        self.cross_attn = None

        if exists(cond_dim):
            attn_klass = LinearCrossAttention if linear_attn else CrossAttention
            cross_attention = attn_klass(
                dim = dim_out,
                context_dim = cond_dim,
                **attn_kwargs
            )

            # attention works on (b, n, c); the feature map is (b, c, n)
            self.cross_attn = EinopsToAndFrom('b c h ', 'b h c', cross_attention)

        self.block1 = Block(dim, dim_out, groups = groups)
        self.block2 = Block(dim_out, dim_out, groups = groups)

        self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)

        self.res_conv = nn.Conv1d(dim, dim_out, 1) if dim != dim_out else Identity()

    def forward(self, x, time_emb = None, cond = None):
        """Run the block on `x`; `cond` is required whenever cross-attention is enabled."""
        scale_shift = None
        if exists(self.time_mlp) and exists(time_emb):
            projected = self.time_mlp(time_emb)
            projected = rearrange(projected, 'b c -> b c 1')
            scale_shift = projected.chunk(2, dim = 1)

        hidden = self.block1(x)

        if exists(self.cross_attn):
            assert exists(cond)
            hidden = self.cross_attn(hidden, context = cond) + hidden

        # the time conditioning is applied inside the second block only
        hidden = self.block2(hidden, scale_shift = scale_shift)

        hidden = hidden * self.gca(hidden)

        return hidden + self.res_conv(x)
| |
|
class CrossAttention(nn.Module):
    """Multi-head cross-attention from `x` to a context sequence.

    A learned null key/value pair is prepended to the context keys/values, so every
    query always has at least one position to attend to.

    Args:
        dim: feature size of `x`.
        context_dim: feature size of the context; falls back to `dim` when None.
        dim_head: size of each attention head.
        heads: number of attention heads.
        norm_context: apply `LayerNorm` to the context before projecting it.
        cosine_sim_attn: when True, queries and keys are l2-normalized and the
            similarities are multiplied by 16 instead of scaling queries by `dim_head ** -0.5`.
    """

    def __init__(
        self,
        dim,
        *,
        context_dim = None,
        dim_head = 64,
        heads = 8,
        norm_context = False,
        cosine_sim_attn = False
    ):
        super().__init__()
        self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
        self.cosine_sim_attn = cosine_sim_attn
        self.cosine_sim_scale = 16 if cosine_sim_attn else 1

        self.heads = heads
        inner_dim = dim_head * heads

        context_dim = default(context_dim, dim)

        self.norm = LayerNorm(dim)
        self.norm_context = LayerNorm(context_dim) if norm_context else Identity()

        self.null_kv = nn.Parameter(torch.randn(2, dim_head))
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            LayerNorm(dim)
        )

    def forward(self, x, context, mask = None):
        """Attend from `x` to `context`.

        Args:
            x: queries source of shape (b, n, dim).
            context: keys/values source of shape (b, m, context_dim).
            mask: optional 'b j' boolean mask over the context tokens; True keeps a key.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)

        # prepend the learned null key / value (shared by every head)
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        q = q * self.scale

        if self.cosine_sim_attn:
            q, k = map(l2norm, (q, k))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale

        max_neg_value = -torch.finfo(sim.dtype).max

        if exists(mask):
            # account for the null key prepended above
            mask = F.pad(mask, (1, 0), value = True)
            # sim is (b, h, i, j): the mask needs singleton head AND query axes so that
            # its batch axis broadcasts against sim's batch axis rather than its head axis
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, max_neg_value)

        # softmax is computed in float32 and cast back to the similarity dtype
        attn = sim.softmax(dim = -1, dtype = torch.float32)
        attn = attn.to(sim.dtype)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
| |
|
class LinearCrossAttention(CrossAttention):
    """Linear-complexity variant of `CrossAttention`.

    Reuses the projections of `CrossAttention` but replaces the softmax over the
    similarity matrix by separate softmaxes on queries (over features) and keys
    (over tokens), so the key/value summary is computed once per head.
    """

    def forward(self, x, context, mask = None):
        """Attend from `x` to `context`.

        Args:
            x: queries source of shape (b, n, dim).
            context: keys/values source of shape (b, m, context_dim).
            mask: optional 'b n' boolean mask over the context tokens; True keeps a key.
        """
        b = x.shape[0]

        x = self.norm(x)
        context = self.norm_context(context)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))

        # heads are folded into the batch axis for the rest of the computation
        q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)

        # prepend the learned null key / value
        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)

        k = torch.cat((nk, k), dim = -2)
        v = torch.cat((nv, v), dim = -2)

        max_neg_value = -torch.finfo(x.dtype).max

        if exists(mask):
            # account for the null key prepended above
            mask = F.pad(mask, (1, 0), value = True)
            # k and v carry a folded (b h) batch axis, so the per-sample mask has to be
            # repeated once per head (same b-outer / h-inner order) to line up with them
            mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads)
            k = k.masked_fill(~mask, max_neg_value)
            v = v.masked_fill(~mask, 0.)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
        return self.to_out(out)
| |
|
class LinearAttention(nn.Module):
    """Linear-complexity self-attention over a 1D feature map of shape (b, dim, n).

    Queries, keys and values come from a 1x1 Conv1d followed by a depthwise kernel-3
    Conv1d (with dropout on the input). Keys/values projected from an optional
    `context` are appended to those of the feature map.

    Args:
        dim: number of channels of the feature map.
        dim_head: size of each attention head.
        heads: number of attention heads.
        dropout: dropout probability applied before each projection.
        context_dim: feature size of the optional context; when None, passing a
            context to `forward` fails an assertion.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        dim,
        dim_head = 32,
        heads = 8,
        dropout = 0.05,
        context_dim = None,
        **kwargs
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = ChanLayerNorm(dim)

        self.nonlin = nn.SiLU()

        self.to_q = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_k = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_v = nn.Sequential(
            nn.Dropout(dropout),
            nn.Conv1d(dim, inner_dim, 1, bias = False),
            nn.Conv1d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
        )

        self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None

        self.to_out = nn.Sequential(
            nn.Conv1d(inner_dim, dim, 1, bias = False),
            ChanLayerNorm(dim)
        )

    def forward(self, fmap, context = None):
        """Apply linear attention to `fmap` (b, dim, n) and return a tensor of the same shape.

        Args:
            fmap: feature map of shape (b, dim, n).
            context: optional (b, m, context_dim) tensor providing extra keys/values.
        """
        h = self.heads

        fmap = self.norm(fmap)
        q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
        # the projections are Conv1d, so q / k / v are 3-D (b, h * c, n): fold the heads
        # into the batch axis and move the sequence axis in front of the features
        q, k, v = rearrange_many((q, k, v), 'b (h c) n -> (b h) n c', h = h)

        if exists(context):
            assert exists(self.to_context)
            ck, cv = self.to_context(context).chunk(2, dim = -1)
            ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
            k = torch.cat((k, ck), dim = -2)
            v = torch.cat((v, cv), dim = -2)

        q = q.softmax(dim = -1)
        k = k.softmax(dim = -2)

        q = q * self.scale

        context = einsum('b n d, b n e -> b d e', k, v)
        out = einsum('b n d, b d e -> b n e', q, context)
        # back to the channel-first (b, h * d, n) layout expected by the output Conv1d
        out = rearrange(out, '(b h) n d -> b (h d) n', h = h)

        out = self.nonlin(out)
        return self.to_out(out)
| |
|
class GlobalContext(nn.Module):
    """Attention-style squeeze-excitation gate.

    A 1x1 convolution scores every position; the softmax of those scores pools the
    input into one vector per sample, which a small bottleneck network ending in a
    sigmoid turns into per-channel gates.

    Args:
        dim_in: number of input channels.
        dim_out: number of gate channels returned.
    """

    def __init__(
        self,
        *,
        dim_in,
        dim_out
    ):
        super().__init__()
        self.to_k = nn.Conv1d(dim_in, 1, 1)
        hidden_dim = max(3, dim_out // 2)

        self.net = nn.Sequential(
            nn.Conv1d(dim_in, hidden_dim, 1),
            nn.SiLU(),
            nn.Conv1d(hidden_dim, dim_out, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        scores = self.to_k(x)
        x, scores = rearrange_many((x, scores), 'b n ... -> b n (...)')
        # softmax over positions, then a weighted sum of the input per channel
        weights = scores.softmax(dim = -1)
        pooled = einsum('b i n, b c n -> b c i', weights, x)
        return self.net(pooled)
| |
|
def FeedForward(dim, mult = 2):
    """Build a bias-free MLP on the last axis: LayerNorm -> Linear -> GELU -> LayerNorm -> Linear.

    The hidden size is `int(dim * mult)`.
    """
    hidden_dim = int(dim * mult)
    layers = [
        LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias = False),
        nn.GELU(),
        LayerNorm(hidden_dim),
        nn.Linear(hidden_dim, dim, bias = False),
    ]
    return nn.Sequential(*layers)
| |
|
def ChanFeedForward(dim, mult = 2):
    """Build the channel-first counterpart of `FeedForward`: ChanLayerNorm -> 1x1 Conv1d -> GELU -> ChanLayerNorm -> 1x1 Conv1d.

    The hidden size is `int(dim * mult)`; both convolutions are bias-free.
    """
    hidden_dim = int(dim * mult)
    layers = [
        ChanLayerNorm(dim),
        nn.Conv1d(dim, hidden_dim, 1, bias = False),
        nn.GELU(),
        ChanLayerNorm(hidden_dim),
        nn.Conv1d(hidden_dim, dim, 1, bias = False),
    ]
    return nn.Sequential(*layers)
| |
|
class TransformerBlock(nn.Module):
    """Stack of `depth` (attention, feed-forward) pairs for channel-first (b, c, n) tensors.

    Each `Attention` is wrapped so that it sees the input as (b, n, c); both the
    attention and the `ChanFeedForward` are applied with a residual connection.

    Args:
        dim: number of channels.
        depth: number of (attention, feed-forward) pairs.
        heads, dim_head, context_dim, cosine_sim_attn: forwarded to `Attention`.
        ff_mult: hidden-size multiplier forwarded to `ChanFeedForward`.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        cosine_sim_attn = False
    ):
        super().__init__()

        pairs = []
        for _ in range(depth):
            attention = Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)
            wrapped_attention = EinopsToAndFrom('b c h', 'b h c', attention)
            feed_forward = ChanFeedForward(dim = dim, mult = ff_mult)
            pairs.append(nn.ModuleList([wrapped_attention, feed_forward]))

        self.layers = nn.ModuleList(pairs)

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
| |
|
class LinearAttentionTransformerBlock(nn.Module):
    """Stack of `depth` (LinearAttention, ChanFeedForward) pairs, each applied with a residual connection.

    Args:
        dim: number of channels.
        depth: number of (attention, feed-forward) pairs.
        heads, dim_head, context_dim: forwarded to `LinearAttention`.
        ff_mult: hidden-size multiplier forwarded to `ChanFeedForward`.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        dim,
        *,
        depth = 1,
        heads = 8,
        dim_head = 32,
        ff_mult = 2,
        context_dim = None,
        **kwargs
    ):
        super().__init__()

        pairs = []
        for _ in range(depth):
            attention = LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim)
            feed_forward = ChanFeedForward(dim = dim, mult = ff_mult)
            pairs.append(nn.ModuleList([attention, feed_forward]))

        self.layers = nn.ModuleList(pairs)

    def forward(self, x, context = None):
        for attn, ff in self.layers:
            x = attn(x, context = context) + x
            x = ff(x) + x
        return x
| |
|
class CrossEmbedLayer(nn.Module):
    """Multi-scale strided Conv1d embedding whose branch outputs are concatenated on the channel axis.

    One convolution is created per kernel size (sorted ascending). The output
    channels are split so that the i-th smallest kernel gets `dim_out / 2 ** i`
    channels and the largest kernel receives whatever remains.

    Args:
        dim_in: number of input channels.
        kernel_sizes: kernel sizes; each must have the same parity as `stride`.
        dim_out: total number of output channels; falls back to `dim_in` when None.
        stride: stride shared by all convolutions.
    """

    def __init__(
        self,
        dim_in,
        kernel_sizes,
        dim_out = None,
        stride = 2
    ):
        super().__init__()
        assert all((kernel % 2) == (stride % 2) for kernel in kernel_sizes)
        dim_out = dim_in if dim_out is None else dim_out

        kernel_sizes = sorted(kernel_sizes)
        num_scales = len(kernel_sizes)

        # halve the channel budget for every scale; the last scale takes the remainder
        dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
        dim_scales.append(dim_out - sum(dim_scales))

        self.convs = nn.ModuleList([])
        for kernel, dim_scale in zip(kernel_sizes, dim_scales):
            padding = (kernel - stride) // 2
            self.convs.append(nn.Conv1d(dim_in, dim_scale, kernel, stride = stride, padding = padding))

    def forward(self, x):
        branch_outputs = tuple(conv(x) for conv in self.convs)
        return torch.cat(branch_outputs, dim = 1)
| |
|
class UpsampleCombiner(nn.Module):
    """Optionally concatenate resized, convolved earlier feature maps onto the final feature map.

    Args:
        dim: number of channels of the main feature map.
        enabled: when False the module is a pass-through and `dim_out == dim`.
        dim_ins: channel counts of the extra feature maps.
        dim_outs: channel counts each extra feature map is projected to (a single
            value is repeated for every entry of `dim_ins`).

    Attributes set here: `dim_out` is the channel count of the tensor returned by
    `forward` when all feature maps are combined.
    """

    def __init__(
        self,
        dim,
        *,
        enabled = False,
        dim_ins = tuple(),
        dim_outs = tuple()
    ):
        super().__init__()
        dim_outs = cast_tuple(dim_outs, len(dim_ins))
        assert len(dim_ins) == len(dim_outs)

        self.enabled = enabled

        if self.enabled:
            self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
            self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
        else:
            # fmap_convs is intentionally not created for a disabled combiner
            self.dim_out = dim

    def forward(self, x, fmaps = None):
        """Return `x` with the processed `fmaps` concatenated on the channel axis (or `x` alone when there is nothing to combine)."""
        target_size = x.shape[-1]

        fmaps = default(fmaps, tuple())

        # `enabled` is checked first because fmap_convs only exists on enabled combiners
        if self.enabled and len(fmaps) > 0 and len(self.fmap_convs) > 0:
            resized = [resize_image_to(fmap, target_size) for fmap in fmaps]
            projected = [conv(fmap) for fmap, conv in zip(resized, self.fmap_convs)]
            return torch.cat((x, *projected), dim = 1)

        return x
| |
|
| | |
| | |
| | |
| | class OneD_Unet(nn.Module): |
| | def __init__( |
| | self, |
| | *, |
| | |
| | CKeys=None, |
| | PKeys=None, |
| | ): |
| | super().__init__() |
| | |
| | |
| |
|
| | self._locals = locals() |
| | self._locals.pop('self', None) |
| | self._locals.pop('__class__', None) |
| | |
| | |
| | |
| | |
| | if CKeys['Debug_ModelPack']==1: |
| | print(json.dumps(PKeys, indent=4)) |
| | |
| | dim = PKeys['dim'] |
| | |
| | text_embed_dim = default(PKeys['text_embed_dim'], 768) |
| | num_resnet_blocks = default(PKeys['num_resnet_blocks'], 1) |
| | cond_dim = default(PKeys['cond_dim'], None) |
| | num_image_tokens = default(PKeys['num_image_tokens'], 4) |
| | num_time_tokens = default(PKeys['num_time_tokens'], 2) |
| | learned_sinu_pos_emb_dim = default(PKeys['learned_sinu_pos_emb_dim'], 16) |
| | out_dim = default(PKeys['out_dim'], None) |
| | dim_mults = default(PKeys['dim_mults'], (1, 2, 4, 8)) |
| | |
| | cond_images_channels = default(PKeys['cond_images_channels'], 0) |
| | channels = default(PKeys['channels'], 3) |
| | channels_out = default(PKeys['channels_out'], None) |
| | |
| | attn_dim_head = default(PKeys['attn_dim_head'], 64) |
| | attn_heads = default(PKeys['attn_heads'], 8) |
| | ff_mult = default(PKeys['ff_mult'], 2.) |
| | lowres_cond = default(PKeys['lowres_cond'], False) |
| | layer_attns = default(PKeys['layer_attns'], True) |
| | layer_attns_depth = default(PKeys['layer_attns_depth'], 1) |
| | layer_attns_add_text_cond = default(PKeys['layer_attns_add_text_cond'], True) |
| | attend_at_middle = default(PKeys['attend_at_middle'], True) |
| | layer_cross_attns = default(PKeys['layer_cross_attns'], True) |
| | use_linear_attn = default(PKeys['use_linear_attn'], False) |
| | use_linear_cross_attn = default(PKeys['use_linear_cross_attn'], False) |
| | |
| | cond_on_text = default(PKeys['cond_on_text'], True) |
| | max_text_len = default(PKeys['max_text_len'], 256) |
| | init_dim = default(PKeys['init_dim'], None) |
| | resnet_groups = default(PKeys['resnet_groups'], 8) |
| | init_conv_kernel_size = default(PKeys['init_conv_kernel_size'], 7) |
| | init_cross_embed = default(PKeys['init_cross_embed'], False) |
| | init_cross_embed_kernel_sizes = default(PKeys['init_cross_embed_kernel_sizes'], (3, 7, 15)) |
| | cross_embed_downsample = default(PKeys['cross_embed_downsample'], False) |
| | cross_embed_downsample_kernel_sizes = default(PKeys['cross_embed_downsample_kernel_sizes'], (2,4)) |
| | |
| | |
| | attn_pool_text = default(PKeys['attn_pool_text'], True) |
| | attn_pool_num_latents = default(PKeys['attn_pool_num_latents'], 32) |
| | dropout = default(PKeys['dropout'], 0.) |
| | memory_efficient = default(PKeys['memory_efficient'], False) |
| | init_conv_to_final_conv_residual = default(PKeys['init_conv_to_final_conv_residual'], False) |
| | |
| | |
| | use_global_context_attn = default(PKeys['use_global_context_attn'], True) |
| | scale_skip_connection = default(PKeys['scale_skip_connection'], True) |
| | final_resnet_block = default(PKeys['final_resnet_block'], True) |
| | final_conv_kernel_size = default(PKeys['final_conv_kernel_size'], 3) |
| | |
| | |
| | cosine_sim_attn = default(PKeys['cosine_sim_attn'], False) |
| | self_cond = default(PKeys['self_cond'], False) |
| | combine_upsample_fmaps = default(PKeys['combine_upsample_fmaps'], False) |
| | pixel_shuffle_upsample = default(PKeys['pixel_shuffle_upsample'], False) |
| | beginning_and_final_conv_present = default(PKeys['beginning_and_final_conv_present'], True) |
| | |
| | |
| | self.CKeys=CKeys |
| | |
| | |
| | if CKeys['Debug_ModelPack']==1: |
| | print("Check the inputs:") |
| | print(json.dumps(PKeys, indent=4)) |
| | |
| | |
| |
|
| | assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' |
| | |
| | if dim < 128: |
| | print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') |
| |
|
| |
|
| | |
| |
|
| | |
| |
|
| | self.channels = channels |
| | self.channels_out = default(channels_out, channels) |
| |
|
| | |
| | |
| | init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) |
| | init_dim = default(init_dim, dim) |
| |
|
| | self.self_cond = self_cond |
| |
|
| | |
| |
|
| | self.has_cond_image = cond_images_channels > 0 |
| | self.cond_images_channels = cond_images_channels |
| |
|
| | init_channels += cond_images_channels |
| | |
| |
|
| | self.beginning_and_final_conv_present=beginning_and_final_conv_present |
| | |
| | |
| | if self.beginning_and_final_conv_present: |
| | self.init_conv = CrossEmbedLayer( |
| | init_channels, dim_out = init_dim, |
| | kernel_sizes = init_cross_embed_kernel_sizes, |
| | stride = 1 |
| | ) if init_cross_embed else nn.Conv1d( |
| | init_channels, init_dim, |
| | init_conv_kernel_size, |
| | padding = init_conv_kernel_size // 2) |
| | |
| | if self.CKeys['Debug_ModelPack']==1 and self.beginning_and_final_conv_present: |
| | print("On self.init_conv:") |
| | print(f"init_channels: {str(init_channels)}\n init_dim: {str(init_dim)}") |
| | print("On self.init_conv, batch#=1, init_ch=2*#seq_channel, seq_len=128") |
| | summary(self.init_conv, (1, init_channels, 128), verbose=1) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] |
| | in_out = list(zip(dims[:-1], dims[1:])) |
| |
|
| | |
| |
|
| | cond_dim = default(cond_dim, dim) |
| | time_cond_dim = dim * 4 * (2 if lowres_cond else 1) |
| |
|
| | |
| |
|
| | sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) |
| | sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 |
| | |
| | |
| | self.to_time_hiddens = nn.Sequential( |
| | sinu_pos_emb, |
| | nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), |
| | nn.SiLU() |
| | ) |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | self.to_time_cond = nn.Sequential( |
| | nn.Linear(time_cond_dim, time_cond_dim) |
| | ) |
| |
|
| | |
| |
|
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | print(" to_time_tokens") |
| | print(" in dim: ", time_cond_dim) |
| | print(" ou dim: ", num_time_tokens) |
| | self.to_time_tokens = nn.Sequential( |
| | nn.Linear(time_cond_dim, cond_dim * num_time_tokens), |
| | Rearrange('b (r d) -> b r d', r = num_time_tokens) |
| | ) |
| |
|
| | |
| |
|
| | self.lowres_cond = lowres_cond |
| |
|
| | if lowres_cond: |
| | self.to_lowres_time_hiddens = nn.Sequential( |
| | LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), |
| | nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), |
| | nn.SiLU() |
| | ) |
| |
|
| | self.to_lowres_time_cond = nn.Sequential( |
| | nn.Linear(time_cond_dim, time_cond_dim) |
| | ) |
| |
|
| | self.to_lowres_time_tokens = nn.Sequential( |
| | nn.Linear(time_cond_dim, cond_dim * num_time_tokens), |
| | Rearrange('b (r d) -> b r d', r = num_time_tokens) |
| | ) |
| |
|
| | |
| |
|
| | self.norm_cond = nn.LayerNorm(cond_dim) |
| |
|
| | |
| |
|
| | self.text_to_cond = None |
| |
|
| | if cond_on_text: |
| | assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' |
| | if text_embed_dim != cond_dim: |
| | self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) |
| | self.text_cond_linear=True |
| | |
| | else: |
| | print ("Text conditioning is equatl to cond_dim - no linear layer used") |
| | self.text_cond_linear=False |
| | |
| | |
| |
|
| | self.cond_on_text = cond_on_text |
| |
|
| | |
| |
|
| | self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, |
| | dim_head = attn_dim_head, heads = attn_heads, |
| | num_latents = attn_pool_num_latents, |
| | cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None |
| |
|
| | |
| |
|
| | self.max_text_len = max_text_len |
| |
|
| | self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) |
| | self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) |
| |
|
| | |
| |
|
| | self.to_text_non_attn_cond = None |
| |
|
| | if cond_on_text: |
| | self.to_text_non_attn_cond = nn.Sequential( |
| | nn.LayerNorm(cond_dim), |
| | nn.Linear(cond_dim, time_cond_dim), |
| | nn.SiLU(), |
| | nn.Linear(time_cond_dim, time_cond_dim) |
| | ) |
| |
|
| | |
| |
|
| | attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) |
| |
|
| | num_layers = len(in_out) |
| |
|
| | |
| |
|
| | num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) |
| | resnet_groups = cast_tuple(resnet_groups, num_layers) |
| |
|
| | resnet_klass = partial(ResnetBlock, **attn_kwargs) |
| |
|
| | layer_attns = cast_tuple(layer_attns, num_layers) |
| | layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) |
| | layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) |
| |
|
| | use_linear_attn = cast_tuple(use_linear_attn, num_layers) |
| | use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers) |
| |
|
| | assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) |
| |
|
| | |
| |
|
| | downsample_klass = Downsample |
| |
|
| | if cross_embed_downsample: |
| | downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) |
| |
|
| | |
| |
|
| | self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None |
| |
|
| | |
| |
|
| | self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5) |
| |
|
| | |
| |
|
| | self.downs = nn.ModuleList([]) |
| | self.ups = nn.ModuleList([]) |
| | num_resolutions = len(in_out) |
| |
|
| | layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn] |
| | reversed_layer_params = list(map(reversed, layer_params)) |
| |
|
| | |
| |
|
| | skip_connect_dims = [] |
| |
|
| | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)): |
| | is_last = ind >= (num_resolutions - 1) |
| |
|
| | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None |
| |
|
| | if layer_attn: |
| | transformer_block_klass = TransformerBlock |
| | elif layer_use_linear_attn: |
| | transformer_block_klass = LinearAttentionTransformerBlock |
| | else: |
| | transformer_block_klass = Identity |
| |
|
| | current_dim = dim_in |
| |
|
| | |
| |
|
| | pre_downsample = None |
| |
|
| | if memory_efficient: |
| | pre_downsample = downsample_klass(dim_in, dim_out) |
| | current_dim = dim_out |
| |
|
| | skip_connect_dims.append(current_dim) |
| |
|
| | |
| |
|
| | post_downsample = None |
| | if not memory_efficient: |
| | post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1)) |
| |
|
| | self.downs.append( |
| | nn.ModuleList([ |
| | pre_downsample, |
| | resnet_klass( |
| | current_dim, current_dim, |
| | cond_dim = layer_cond_dim, |
| | linear_attn = layer_use_linear_cross_attn, |
| | time_cond_dim = time_cond_dim, groups = groups |
| | ), |
| | nn.ModuleList([ |
| | ResnetBlock( |
| | current_dim, current_dim, |
| | time_cond_dim = time_cond_dim, |
| | groups = groups, |
| | use_gca = use_global_context_attn |
| | ) for _ in range(layer_num_resnet_blocks)]), |
| | transformer_block_klass( |
| | dim = current_dim, |
| | depth = layer_attn_depth, |
| | ff_mult = ff_mult, |
| | context_dim = cond_dim, |
| | **attn_kwargs), |
| | post_downsample |
| | ]) |
| | ) |
| |
|
| | |
| |
|
| | mid_dim = dims[-1] |
| |
|
| | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) |
| | self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None |
| | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) |
| |
|
| | |
| |
|
| | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample |
| |
|
| | |
| |
|
| | upsample_fmap_dims = [] |
| |
|
| | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): |
| | is_last = ind == (len(in_out) - 1) |
| |
|
| | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None |
| |
|
| | if layer_attn: |
| | transformer_block_klass = TransformerBlock |
| | elif layer_use_linear_attn: |
| | transformer_block_klass = LinearAttentionTransformerBlock |
| | else: |
| | transformer_block_klass = Identity |
| |
|
| | skip_connect_dim = skip_connect_dims.pop() |
| |
|
| | upsample_fmap_dims.append(dim_out) |
| |
|
| | self.ups.append(nn.ModuleList([ |
| | resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), |
| | nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), |
| | transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), |
| | upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() |
| | ])) |
| |
|
| | |
| |
|
| | self.upsample_combiner = UpsampleCombiner( |
| | dim = dim, |
| | enabled = combine_upsample_fmaps, |
| | dim_ins = upsample_fmap_dims, |
| | dim_outs = dim |
| | ) |
| |
|
| | |
| |
|
| | self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual |
| | final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) |
| |
|
| | |
| |
|
| | self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None |
| |
|
| | final_conv_dim_in = dim if final_resnet_block else final_conv_dim |
| | final_conv_dim_in += (channels if lowres_cond else 0) |
| |
|
| | if self.beginning_and_final_conv_present: |
| | print (final_conv_dim_in, self.channels_out) |
| | self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) |
| |
|
| | if self.beginning_and_final_conv_present: |
| | zero_init_(self.final_conv) |
| |
|
| | |
| | |
| | def cast_model_parameters( |
| | self, |
| | *, |
| | lowres_cond, |
| | text_embed_dim, |
| | channels, |
| | channels_out, |
| | cond_on_text |
| | ): |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | if lowres_cond == self.lowres_cond and \ |
| | channels == self.channels and \ |
| | cond_on_text == self.cond_on_text and \ |
| | text_embed_dim == self._locals['PKeys']['text_embed_dim'] and \ |
| | channels_out == self.channels_out: |
| | return self |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | write_key=dict( |
| | lowres_cond = lowres_cond, |
| | text_embed_dim = text_embed_dim, |
| | channels = channels, |
| | channels_out = channels_out, |
| | cond_on_text = cond_on_text |
| | ) |
| | |
| | old_PKeys=self._locals['PKeys'] |
| | this_PKeys=modify_keys(old_PKeys, write_key) |
| | |
| | updated_kwargs = dict( |
| | CKeys=self.CKeys, |
| | PKeys=this_PKeys, |
| | ) |
| |
|
| | return self.__class__(**{**self._locals, **updated_kwargs}) |
| |
|
| | |
| |
|
| | def to_config_and_state_dict(self): |
| | return self._locals, self.state_dict() |
| |
|
| | |
| |
|
| | @classmethod |
| | def from_config_and_state_dict(klass, config, state_dict): |
| | unet = klass(**config) |
| | unet.load_state_dict(state_dict) |
| | return unet |
| |
|
| | |
| |
|
| | def persist_to_file(self, path): |
| | path = Path(path) |
| | path.parents[0].mkdir(exist_ok = True, parents = True) |
| |
|
| | config, state_dict = self.to_config_and_state_dict() |
| | pkg = dict(config = config, state_dict = state_dict) |
| | torch.save(pkg, str(path)) |
| |
|
| | |
| |
|
| | @classmethod |
| | def hydrate_from_file(klass, path): |
| | path = Path(path) |
| | assert path.exists() |
| | pkg = torch.load(str(path)) |
| |
|
| | assert 'config' in pkg and 'state_dict' in pkg |
| | config, state_dict = pkg['config'], pkg['state_dict'] |
| |
|
| | return Unet.from_config_and_state_dict(config, state_dict) |
| |
|
| | |
| |
|
| | def forward_with_cond_scale( |
| | self, |
| | *args, |
| | cond_scale = 1., |
| | **kwargs |
| | ): |
| | logits = self.forward(*args, **kwargs) |
| |
|
| | if cond_scale == 1: |
| | return logits |
| |
|
| | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) |
| | return null_logits + (logits - null_logits) * cond_scale |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | def forward( |
| | self, |
| | x, |
| | time, |
| | *, |
| | lowres_cond_img = None, |
| | lowres_noise_times = None, |
| | text_embeds = None, |
| | text_mask = None, |
| | cond_images = None, |
| | self_cond = None, |
| | cond_drop_prob = 0. |
| | ): |
| | |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | print("========================================") |
| | print("Here are OneD_Unet:forward") |
| | ii=0 |
| | |
| | batch_size, device = x.shape[0], x.device |
| | |
| |
|
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | print("Check inputs: ") |
| | print(" x-0 .dim: ", x.shape, f"[batch, {str(self.channels)}, seq_len]") |
| | print(" time .dim: ", time.shape, "batch") |
| | if lowres_cond_img==None: |
| | print(" lowres_cond_img: None") |
| | else: |
| | print(" lowres_cond_img.dim: ", lowres_cond_img.shape) |
| | if lowres_noise_times==None: |
| | print(" lowres_noise_times: None") |
| | else: |
| | print(" lowres_noise_times.dim: ", lowres_noise_times.shape) |
| | if cond_images==None: |
| | print(" cond_images dim: None") |
| | else: |
| | |
| | print(" cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, seq_len]") |
| | if text_embeds==None: |
| | print(" text_embeds.dim: None") |
| | else: |
| | |
| | print(" text_embeds.dim: ", text_embeds.shape) |
| | if self_cond==None: |
| | print(" self_cond: None") |
| | else: |
| | print(" self_cond.dim: ", self_cond.shape) |
| | print("\n\n") |
| | |
| | if self.self_cond: |
| | self_cond = default(self_cond, lambda: torch.zeros_like(x)) |
| | x = torch.cat((x, self_cond), dim = 1) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("self_cond dim: ", self_cond.shape) |
| | print("After cat(x, self_cond)-> x. dim: ", x.shape) |
| |
|
| | |
| |
|
| | assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' |
| | assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' |
| |
|
| | if exists(lowres_cond_img): |
| | x = torch.cat((x, lowres_cond_img), dim = 1) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("lowres_cond_img dim: ", lowres_cond_img.shape) |
| | print("After cat(x, lowres_cond_img) dim: ", x.shape) |
| |
|
| | |
| |
|
| | assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' |
| |
|
| | if exists(cond_images): |
| | assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' |
| | cond_images = resize_image_to(cond_images, x.shape[-1]) |
| | |
| | |
| | |
| | x = torch.cat((cond_images.to(device), x.to(device)), dim = 1) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("cond_images dim: ", cond_images.shape, f"[batch, {str(self.cond_images_channels)}, max_seq_len]") |
| | print("After cat(cond_images, x), x dim: ", x.shape, "[batch, cond_images_channels+images_channels, max_seq_len]") |
| |
|
| | |
| | |
| | if self.beginning_and_final_conv_present: |
| | x = self.init_conv(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("After self.init_conv(x)-> x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]") |
| |
|
| | |
| |
|
| | if self.init_conv_to_final_conv_residual: |
| | init_conv_residual = x.clone() |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("x.clone()->init_conv_resi, dim: ", init_conv_residual.shape) |
| |
|
| | |
| |
|
| | time_hiddens = self.to_time_hiddens(time) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("time dim: ", time.shape, "[batch]") |
| | print("self.to_time_hiddens(time)-> time_hiddens .dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]") |
| |
|
| | |
| |
|
| | time_tokens = self.to_time_tokens(time_hiddens) |
| | t = self.to_time_cond(time_hiddens) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("self.to_time_tokens(time_hiddens)-> time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]") |
| | print("self.to_time_cond(time_hiddens)-> t dim: ", t.shape, "[batch, 4xUNet:dim]") |
| |
|
| | |
| | |
| |
|
| | if self.lowres_cond: |
| | lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times) |
| | lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) |
| | lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("self.to_lowres_time_hiddens(lowres_noise_times)-> lowres_time_hiddens .dim: ", lowres_time_hiddens.shape) |
| | print("self.to_lowres_time_tokens(lowres_time_hiddens)-> lowres_time_tokens .dim: ", lowres_time_tokens.shape) |
| | print("self.to_lowres_time_cond(lowres_time_hiddens)-> lowres_t.dim: ", lowres_t.shape) |
| |
|
| | t = t + lowres_t |
| | |
| | time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii +=1 |
| | print(ii) |
| | print("After cat(time_tokens, lowres_time_tokens)-> time_tokens dim: ", time_tokens.shape) |
| |
|
| | |
| |
|
| | text_tokens = None |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | print("UNet.cond_on_text: ", self.cond_on_text) |
| |
|
| | if exists(text_embeds) and self.cond_on_text: |
| |
|
| | |
| |
|
| | text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) |
| | |
| | text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') |
| | text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') |
| |
|
| | |
| | |
| | if self.text_cond_linear: |
| | text_tokens = self.text_to_cond(text_embeds) |
| | else: |
| | text_tokens=text_embeds |
| | if self.CKeys['Debug_ModelPack']==1: |
| | print("On text conditioning part...") |
| | ii +=1 |
| | print(ii) |
| | print("text_embeds->text_tokens.dim: ", text_tokens.shape) |
| |
|
| | text_tokens = text_tokens[:, :self.max_text_len] |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii +=1 |
| | print(ii) |
| | print("text_tokens[:,:max_text_len]-> text_tokens.dim: ", text_tokens.shape) |
| | print("do the same for text_mask") |
| | |
| | if exists(text_mask): |
| | text_mask = text_mask[:, :self.max_text_len] |
| |
|
| | text_tokens_len = text_tokens.shape[1] |
| | remainder = self.max_text_len - text_tokens_len |
| | |
| | if remainder > 0: |
| | |
| | text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) |
| |
|
| | if exists(text_mask): |
| | if remainder > 0: |
| | text_mask = F.pad(text_mask, (0, remainder), value = False) |
| |
|
| | |
| | text_mask = rearrange(text_mask, 'b n -> b n 1') |
| | text_keep_mask_embed = text_mask & text_keep_mask_embed |
| | |
| | null_text_embed = self.null_text_embed.to(text_tokens.dtype) |
| | text_tokens = torch.where( |
| | text_keep_mask_embed, |
| | text_tokens, |
| | null_text_embed |
| | ) |
| | |
| | if exists(self.attn_pool): |
| | text_tokens = self.attn_pool(text_tokens) |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii+=1 |
| | print(ii) |
| | print("self.attn_pool(text_tokens)->text_tokens.dim: ", text_tokens.shape) |
| |
|
| | |
| | |
| | |
| | mean_pooled_text_tokens = text_tokens.mean(dim = -2) |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii +=1 |
| | print(ii) |
| | print("text_tokens.mean(dim=-2)->mean_pooled_text_tokens.dim: ", mean_pooled_text_tokens.shape) |
| |
|
| | text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii +=1 |
| | print(ii) |
| | print("self.to_text_non_attn_cond(mean_pooled_text_tokens)->text_hiddens.dim: ",text_hiddens.shape) |
| |
|
| | null_text_hidden = self.null_text_hidden.to(t.dtype) |
| |
|
| | text_hiddens = torch.where( |
| | text_keep_mask_hidden, |
| | text_hiddens, |
| | null_text_hidden |
| | ) |
| |
|
| | t = t + text_hiddens |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii +=1 |
| | print(ii) |
| | print("t+text_hiddens.dim: ", t.shape) |
| | |
| |
|
| | |
| | |
| | c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("cat(time_tokens, text_tokens)-> c dim: ", c.shape) |
| | |
| | |
| |
|
| | c = self.norm_cond(c) |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("self.norm_cond(c)->c dim:", c.shape) |
| | |
| | |
| |
|
| | if exists(self.init_resnet_block): |
| | x = self.init_resnet_block(x, t) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(ii) |
| | print("self.init_resnet_block(x,t)-> x dim: ", x.shape) |
| | |
| | |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | print("Before unet, down and up, ") |
| | print(" x dim: ", x.shape) |
| | print(" t dim: ", t.shape) |
| | print(" c dim: ", c.shape) |
| | |
| |
|
| | hiddens = [] |
| |
|
| | for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs: |
| | if exists(pre_downsample): |
| | x = pre_downsample(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, pre_downsample(x)=>x dim: ", x.shape) |
| | |
| | x = init_block(x, t, c) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, init_block(x,t,c)=>x dim: ", x.shape) |
| | |
| | for resnet_block in resnet_blocks: |
| | x = resnet_block(x, t) |
| | hiddens.append(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape) |
| | |
| | |
| | x = attn_block(x, c) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape) |
| | |
| | hiddens.append(x) |
| |
|
| | if exists(post_downsample): |
| | |
| | x = post_downsample(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, post_downsample(x)=> x dim: ", x.shape) |
| |
|
| | x = self.mid_block1(x, t, c) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, mid_block_1(x,t,c)=> x dim: ", x.shape) |
| | |
| | if exists(self.mid_attn): |
| | x = self.mid_attn(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, mid_attn(x)=> x dim: ", x.shape) |
| |
|
| | x = self.mid_block2(x, t, c) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, mid_block_2(x,t,c)=> x dim: ", x.shape) |
| | |
| | add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) |
| |
|
| | up_hiddens = [] |
| | |
| | for init_block, resnet_blocks, attn_block, upsample in self.ups: |
| | x = add_skip_connection(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape) |
| | |
| | x = init_block(x, t, c) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, init_block(x,t,c)=> x dim: ", x.shape) |
| | |
| | for resnet_block in resnet_blocks: |
| | x = add_skip_connection(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, add_skip_connection(x)=> x dim: ", x.shape) |
| | x = resnet_block(x, t) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, resnet_block(x,t)=> x dim: ", x.shape) |
| |
|
| | x = attn_block(x, c) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, attn_block(x,c)=> x dim: ", x.shape) |
| | |
| | up_hiddens.append(x.contiguous()) |
| | x = upsample(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, upsample(x)=> x dim: ", x.shape) |
| |
|
| | |
| |
|
| | x = self.upsample_combiner(x, up_hiddens) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, upsample_combiner(x,up_hiddens)=> x dim: ", x.shape) |
| |
|
| | |
| |
|
| | if self.init_conv_to_final_conv_residual: |
| | x = torch.cat((x, init_conv_residual), dim = 1) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, cat(x,init_conv_residual)=> x dim: ", x.shape) |
| |
|
| | if exists(self.final_res_block): |
| | x = self.final_res_block(x, t) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, final_res_block(x,t)=> x dim: ", x.shape) |
| |
|
| | if exists(lowres_cond_img): |
| | x = torch.cat((x, lowres_cond_img), dim = 1) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, cat(x,lowres_cond_img)=> x dim: ", x.shape) |
| | |
| | if self.beginning_and_final_conv_present: |
| | x=self.final_conv(x) |
| | |
| | if self.CKeys['Debug_ModelPack']==1: |
| | ii += 1 |
| | print(F" {str(ii)}, final_conv(x)=> x dim: ", x.shape) |
| | |
| | return x |
| |
|
| | |
| | |
| | |
| | class OneD_Unet_Old(nn.Module): |
def __init__(
    self,
    *,
    dim,
    text_embed_dim = 768,
    num_resnet_blocks = 1,
    cond_dim = None,
    num_image_tokens = 4,
    num_time_tokens = 2,
    learned_sinu_pos_emb_dim = 16,
    out_dim = None,
    dim_mults=(1, 2, 4, 8),
    cond_images_channels = 0,
    channels = 3,
    channels_out = None,
    attn_dim_head = 64,
    attn_heads = 8,
    ff_mult = 2.,
    lowres_cond = False,
    layer_attns = True,
    layer_attns_depth = 1,
    layer_attns_add_text_cond = True,
    attend_at_middle = True,
    layer_cross_attns = True,
    use_linear_attn = False,
    use_linear_cross_attn = False,
    cond_on_text = True,
    max_text_len = 256,
    init_dim = None,
    resnet_groups = 8,
    init_conv_kernel_size = 7,
    init_cross_embed = False,
    init_cross_embed_kernel_sizes = (3, 7, 15),
    cross_embed_downsample = False,
    cross_embed_downsample_kernel_sizes = (2, 4),
    attn_pool_text = True,
    attn_pool_num_latents = 32,
    dropout = 0.,
    memory_efficient = False,
    init_conv_to_final_conv_residual = False,
    use_global_context_attn = True,
    scale_skip_connection = True,
    final_resnet_block = True,
    final_conv_kernel_size = 3,
    cosine_sim_attn = False,
    self_cond = False,
    combine_upsample_fmaps = False,
    pixel_shuffle_upsample = False,
    beginning_and_final_conv_present = True,
    CKeys=None,
):
    """Build the 1D U-Net (time/text conditioning, down path, middle, up path).

    All arguments are keyword-only. The most structural ones:

    dim                 base channel width; stage widths are ``dim * m`` for
                        each ``m`` in ``dim_mults``.
    channels            channels of the signal ``x``; the first conv sees
                        ``channels * (1 + lowres_cond + self_cond)
                        + cond_images_channels`` input channels.
    channels_out        output channels of the final conv (defaults to
                        ``channels``).
    cond_dim            width of the conditioning tokens (defaults to ``dim``).
    text_embed_dim      width of incoming text embeddings; a linear projection
                        to ``cond_dim`` is created only when the two differ.
    num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth,
    layer_cross_attns, use_linear_attn, use_linear_cross_attn
                        scalar or per-stage values (broadcast via
                        ``cast_tuple``).
    memory_efficient    downsample before (instead of after) each stage and
                        add an initial resnet block.
    beginning_and_final_conv_present
                        when False neither ``init_conv`` nor ``final_conv`` is
                        created.
    CKeys               mapping that must contain ``'Debug_ModelPack'``; the
                        value ``1`` enables debug printing. ``None`` is
                        treated as debug off.

    ``num_image_tokens``, ``out_dim``, ``layer_attns_add_text_cond`` and
    ``dropout`` are accepted but not read by this constructor.

    The constructor arguments are snapshotted in ``self._locals`` so that the
    module can be re-created (``cast_model_parameters``) or serialised
    (``to_config_and_state_dict``).
    """
    super().__init__()

    # A missing CKeys used to crash on CKeys['Debug_ModelPack']; fall back to
    # a config with debugging switched off so forward() can index it too.
    self.CKeys = CKeys if CKeys is not None else {'Debug_ModelPack': 0}

    assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'

    if dim < 128:
        print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')

    # Snapshot the constructor arguments BEFORE any new local is introduced,
    # and copy so later frame activity cannot alter the stored config.
    config = dict(locals())
    config.pop('self', None)
    config.pop('__class__', None)
    self._locals = config

    debug = self.CKeys['Debug_ModelPack'] == 1

    if debug:
        print("Showing the input:")
        print(json.dumps(self._locals, indent=4))

    # channel bookkeeping
    self.channels = channels
    self.channels_out = default(channels_out, channels)

    # low-res conditioning and self-conditioning are concatenated on the
    # channel axis in forward(), hence the multiplier
    init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
    init_dim = default(init_dim, dim)

    self.self_cond = self_cond

    # optional conditioning signal concatenated on the channel axis
    self.has_cond_image = cond_images_channels > 0
    self.cond_images_channels = cond_images_channels

    init_channels += cond_images_channels

    # initial convolution
    self.beginning_and_final_conv_present = beginning_and_final_conv_present

    if self.beginning_and_final_conv_present:
        if init_cross_embed:
            self.init_conv = CrossEmbedLayer(
                init_channels, dim_out = init_dim,
                kernel_sizes = init_cross_embed_kernel_sizes,
                stride = 1
            )
        else:
            self.init_conv = nn.Conv1d(
                init_channels, init_dim,
                init_conv_kernel_size,
                padding = init_conv_kernel_size // 2
            )

    dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
    in_out = list(zip(dims[:-1], dims[1:]))

    # time conditioning
    cond_dim = default(cond_dim, dim)
    time_cond_dim = dim * 4 * (2 if lowres_cond else 1)

    # embed the (log-snr-like) time scalar
    sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
    sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1

    if debug:
        print(" to_time_hiddens")
        print(" ou dim: ", time_cond_dim)
    self.to_time_hiddens = nn.Sequential(
        sinu_pos_emb,
        nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
        nn.SiLU()
    )

    self.to_time_cond = nn.Sequential(
        nn.Linear(time_cond_dim, time_cond_dim)
    )

    # project time hiddens to `num_time_tokens` conditioning tokens
    if debug:
        print(" to_time_tokens")
        print(" in dim: ", time_cond_dim)
        print(" ou dim: ", num_time_tokens)
    self.to_time_tokens = nn.Sequential(
        nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
        Rearrange('b (r d) -> b r d', r = num_time_tokens)
    )

    # low-resolution noise-level conditioning (mirrors the time branch)
    self.lowres_cond = lowres_cond

    if lowres_cond:
        self.to_lowres_time_hiddens = nn.Sequential(
            LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
            nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
            nn.SiLU()
        )

        self.to_lowres_time_cond = nn.Sequential(
            nn.Linear(time_cond_dim, time_cond_dim)
        )

        self.to_lowres_time_tokens = nn.Sequential(
            nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
            Rearrange('b (r d) -> b r d', r = num_time_tokens)
        )

    # normalisation applied to the concatenated conditioning tokens
    self.norm_cond = nn.LayerNorm(cond_dim)

    # text conditioning
    self.text_to_cond = None
    # always defined, so forward() can read it however the flags are toggled
    self.text_cond_linear = False

    if cond_on_text:
        assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
        if text_embed_dim != cond_dim:
            self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
            self.text_cond_linear = True
        else:
            print ("Text conditioning is equatl to cond_dim - no linear layer used")
            self.text_cond_linear = False

    self.cond_on_text = cond_on_text

    # optional attention pooling of the text tokens
    self.attn_pool = PerceiverResampler(
        dim = cond_dim, depth = 2,
        dim_head = attn_dim_head, heads = attn_heads,
        num_latents = attn_pool_num_latents,
        cosine_sim_attn = cosine_sim_attn
    ) if attn_pool_text else None

    # classifier-free guidance: learned "null" replacements for dropped text
    self.max_text_len = max_text_len

    self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
    self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))

    # non-attention text conditioning, added to the time conditioning vector
    self.to_text_non_attn_cond = None

    if cond_on_text:
        self.to_text_non_attn_cond = nn.Sequential(
            nn.LayerNorm(cond_dim),
            nn.Linear(cond_dim, time_cond_dim),
            nn.SiLU(),
            nn.Linear(time_cond_dim, time_cond_dim)
        )

    # attention settings shared by every attention-bearing block
    attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)

    num_layers = len(in_out)

    # broadcast scalar settings to one value per stage
    num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
    resnet_groups = cast_tuple(resnet_groups, num_layers)

    resnet_klass = partial(ResnetBlock, **attn_kwargs)

    layer_attns = cast_tuple(layer_attns, num_layers)
    layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
    layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)

    use_linear_attn = cast_tuple(use_linear_attn, num_layers)
    use_linear_cross_attn = cast_tuple(use_linear_cross_attn, num_layers)

    assert all(len(setting) == num_layers for setting in (resnet_groups, layer_attns, layer_cross_attns))

    # downsample operator
    downsample_klass = Downsample

    if cross_embed_downsample:
        downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)

    # extra resnet block, only in the memory-efficient layout
    self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None

    # factor applied to every skip connection in forward()
    self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)

    # stages
    self.downs = nn.ModuleList([])
    self.ups = nn.ModuleList([])
    num_resolutions = len(in_out)

    layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns, use_linear_attn, use_linear_cross_attn]
    reversed_layer_params = [reversed(setting) for setting in layer_params]

    # down path ---------------------------------------------------------
    skip_connect_dims = []  # channel width of the skips produced per stage

    for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(in_out, *layer_params)):
        is_last = ind >= (num_resolutions - 1)

        layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

        if layer_attn:
            transformer_block_klass = TransformerBlock
        elif layer_use_linear_attn:
            transformer_block_klass = LinearAttentionTransformerBlock
        else:
            transformer_block_klass = Identity

        current_dim = dim_in

        # memory-efficient layout: downsample first, then run the blocks
        pre_downsample = None

        if memory_efficient:
            pre_downsample = downsample_klass(dim_in, dim_out)
            current_dim = dim_out

        skip_connect_dims.append(current_dim)

        # default layout: downsample after the blocks; the last stage only
        # changes the channel count (two parallel convs) without downsampling
        post_downsample = None
        if not memory_efficient:
            post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(nn.Conv1d(dim_in, dim_out, 3, padding = 1), nn.Conv1d(dim_in, dim_out, 1))

        self.downs.append(
            nn.ModuleList([
                pre_downsample,
                resnet_klass(
                    current_dim, current_dim,
                    cond_dim = layer_cond_dim,
                    linear_attn = layer_use_linear_cross_attn,
                    time_cond_dim = time_cond_dim, groups = groups
                ),
                nn.ModuleList([
                    ResnetBlock(
                        current_dim, current_dim,
                        time_cond_dim = time_cond_dim,
                        groups = groups,
                        use_gca = use_global_context_attn
                    ) for _ in range(layer_num_resnet_blocks)]),
                transformer_block_klass(
                    dim = current_dim,
                    depth = layer_attn_depth,
                    ff_mult = ff_mult,
                    context_dim = cond_dim,
                    **attn_kwargs),
                post_downsample
            ])
        )

    # middle ------------------------------------------------------------
    mid_dim = dims[-1]

    self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
    self.mid_attn = EinopsToAndFrom('b c h', 'b h c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
    self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

    # up path -----------------------------------------------------------
    upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample

    upsample_fmap_dims = []  # widths of the per-stage outputs fed to the combiner

    for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn, layer_use_linear_attn, layer_use_linear_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
        is_last = ind == (len(in_out) - 1)

        layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None

        if layer_attn:
            transformer_block_klass = TransformerBlock
        elif layer_use_linear_attn:
            transformer_block_klass = LinearAttentionTransformerBlock
        else:
            transformer_block_klass = Identity

        skip_connect_dim = skip_connect_dims.pop()

        upsample_fmap_dims.append(dim_out)

        self.ups.append(nn.ModuleList([
            resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
            nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
            transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
            upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
        ]))

    # optional combination of the feature maps from every up stage
    self.upsample_combiner = UpsampleCombiner(
        dim = dim,
        enabled = combine_upsample_fmaps,
        dim_ins = upsample_fmap_dims,
        dim_outs = dim
    )

    # optional residual from the initial conv output to the final block
    self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
    final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)

    # final resnet block
    self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None

    final_conv_dim_in = dim if final_resnet_block else final_conv_dim
    # forward() concatenates the low-res conditioning once more before the final conv
    final_conv_dim_in += (channels if lowres_cond else 0)

    if self.beginning_and_final_conv_present:
        print (final_conv_dim_in, self.channels_out)
        self.final_conv = nn.Conv1d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
        zero_init_(self.final_conv)
| |
|
| | |
| | |
def cast_model_parameters(
    self,
    *,
    lowres_cond,
    text_embed_dim,
    channels,
    channels_out,
    cond_on_text
):
    """Return a U-Net whose key settings match the requested ones.

    When every requested value already equals what this instance was built
    with, ``self`` is returned unchanged. Otherwise a NEW instance of the same
    class is constructed from the stored constructor arguments with the five
    requested values overriding them (the new instance has freshly
    initialised weights; nothing is copied from ``self``).
    """
    needs_rebuild = (
        lowres_cond != self.lowres_cond
        or channels != self.channels
        or cond_on_text != self.cond_on_text
        or text_embed_dim != self._locals['text_embed_dim']
        or channels_out != self.channels_out
    )

    if not needs_rebuild:
        return self

    overrides = {
        'lowres_cond': lowres_cond,
        'text_embed_dim': text_embed_dim,
        'channels': channels,
        'channels_out': channels_out,
        'cond_on_text': cond_on_text,
    }

    merged_config = {**self._locals, **overrides}
    return self.__class__(**merged_config)
| |
|
| | |
| |
|
def to_config_and_state_dict(self):
    """Return ``(config, state_dict)``: the stored constructor arguments and the module weights."""
    config = self._locals
    weights = self.state_dict()
    return config, weights
| |
|
| | |
| |
|
@classmethod
def from_config_and_state_dict(klass, config, state_dict):
    """Alternate constructor: build ``klass(**config)`` and load ``state_dict`` into it."""
    model = klass(**config)
    model.load_state_dict(state_dict)
    return model
| |
|
| | |
| |
|
def persist_to_file(self, path):
    """Save ``{'config': ..., 'state_dict': ...}`` to ``path`` with ``torch.save``.

    Missing parent directories of ``path`` are created first.
    """
    target = Path(path)
    target.parents[0].mkdir(parents = True, exist_ok = True)

    config, state_dict = self.to_config_and_state_dict()
    package = {'config': config, 'state_dict': state_dict}
    torch.save(package, str(target))
| |
|
| | |
| |
|
@classmethod
def hydrate_from_file(klass, path):
    """Re-create a model from a file written by ``persist_to_file``.

    The file must hold a dict with the keys ``'config'`` and ``'state_dict'``.
    The model is rebuilt through ``klass.from_config_and_state_dict`` so the
    result is an instance of the class this method is called on. (The previous
    code delegated to the module-level ``Unet`` stub, which does not define
    ``from_config_and_state_dict``.)

    Raises AssertionError if the file is missing or lacks the expected keys.
    """
    path = Path(path)
    assert path.exists(), f'checkpoint file {str(path)} does not exist'

    # Load onto CPU so checkpoints saved from a GPU can be restored anywhere;
    # load_state_dict copies the values into the model's own parameters.
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    pkg = torch.load(str(path), map_location = 'cpu')

    assert 'config' in pkg and 'state_dict' in pkg, 'checkpoint must contain "config" and "state_dict"'
    config, state_dict = pkg['config'], pkg['state_dict']

    return klass.from_config_and_state_dict(config, state_dict)
| |
|
| | |
| |
|
def forward_with_cond_scale(
    self,
    *args,
    cond_scale = 1.,
    **kwargs
):
    """Run ``forward`` with classifier-free-guidance scaling.

    With ``cond_scale == 1`` this is a single plain forward pass. Otherwise a
    second pass is made with ``cond_drop_prob = 1.`` and the result is
    ``null + (cond - null) * cond_scale``.
    """
    conditioned = self.forward(*args, **kwargs)

    if cond_scale != 1:
        unconditioned = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return unconditioned + (conditioned - unconditioned) * cond_scale

    return conditioned
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
def forward(
    self,
    x,
    time,
    *,
    lowres_cond_img = None,
    lowres_noise_times = None,
    text_embeds = None,
    text_mask = None,
    cond_images = None,
    self_cond = None,
    cond_drop_prob = 0.
):
    """Run the U-Net on ``x`` conditioned on ``time`` and optional extras.

    x                  input signal; extra conditioning is concatenated to it
                       along dim 1 (the debug labels suggest
                       [batch, channels, seq_len] - TODO confirm with callers).
    time               per-sample time/noise conditioning fed to the learned
                       sinusoidal embedding.
    lowres_cond_img,
    lowres_noise_times required when the model was built with ``lowres_cond``.
    text_embeds        optional text embeddings; used only if ``cond_on_text``.
    text_mask          optional boolean mask over the text tokens.
    cond_images        must be given exactly when ``cond_images_channels > 0``.
    self_cond          optional self-conditioning input; zeros are used when the
                       model has ``self_cond`` enabled and none is given.
    cond_drop_prob     probability of replacing the text conditioning of a
                       sample by the learned null embeddings.

    Returns the output of ``final_conv`` (or of the last block when the model
    was built without the initial/final convolutions).

    Debug printing is enabled when ``self.CKeys['Debug_ModelPack'] == 1``; a
    ``CKeys`` of ``None`` is treated as debug off.
    """
    debug = self.CKeys is not None and self.CKeys['Debug_ModelPack'] == 1
    step = 0  # running index of the debug trace lines

    def trace(label, tensor):
        # numbered shape trace, printed only in debug mode
        nonlocal step
        if debug:
            step += 1
            print(F" {str(step)}, {label} x dim: ", tensor.shape)

    if debug:
        print("========================================")
        print("Here are OneD_Unet:forward")

    batch_size, device = x.shape[0], x.device

    if debug:
        print("Check inputs: ")
        print(" x-0 dim: ", x.shape, "[batch, 1, seq_len]")
        print(" time dim: ", time.shape, "")
        # cond_images is optional; do not dereference None in the debug print
        print(" cond_images dim: ", cond_images.shape if exists(cond_images) else None, "")

    # self-conditioning: concatenated on the channel axis (zeros if absent)
    if self.self_cond:
        self_cond = default(self_cond, lambda: torch.zeros_like(x))
        x = torch.cat((x, self_cond), dim = 1)

        if debug:
            print("After self_cond x dim: ", x.shape)

    # low-resolution conditioning
    assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
    assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'

    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)

        if debug:
            print("After lowres_cond_img x dim: ", x.shape)

    # conditioning "image": must be supplied iff the model was built for it
    assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa'

    if exists(cond_images):
        assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet'
        cond_images = resize_image_to(cond_images, x.shape[-1])

        x = torch.cat((cond_images.to(device), x.to(device)), dim = 1)

        if debug:
            print("cond_images dim: ", cond_images.shape, "[batch, 1, max_seq_len]")
            print("After cond_images, x dim: ", x.shape, "[batch, 2, max_seq_len]")

    # initial convolution
    if self.beginning_and_final_conv_present:
        x = self.init_conv(x)

        if debug:
            print("After init_conv, x dim: ", x.shape, "[batch, UNet:dim, max_seq_len]")

    # keep a copy for the optional init-conv -> final-conv residual
    if self.init_conv_to_final_conv_residual:
        init_conv_residual = x.clone()

    # time conditioning
    time_hiddens = self.to_time_hiddens(time)

    if debug:
        print("time dim: ", time.shape, "[batch]")
        print("after, time_hiddens dim: ", time_hiddens.shape, "[batch, 4xUNet:dim]")

    # time tokens feed cross attention, `t` feeds the resnet blocks
    time_tokens = self.to_time_tokens(time_hiddens)
    t = self.to_time_cond(time_hiddens)

    if debug:
        print("time_tokens dim: ", time_tokens.shape, "[batch, num_time_tokens,4xdim/num_time_tokens]")
        print("after to_time_cond t dim: ", t.shape, "[batch, 4xUNet:dim]")

    # low-res noise level is conditioned on exactly like time
    if self.lowres_cond:
        lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
        lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
        lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)

        t = t + lowres_t

        time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)

        if debug:
            print("After lowres_cond, time_tokens dim: ", time_tokens.shape)

    # text conditioning
    text_tokens = None

    if exists(text_embeds) and self.cond_on_text:
        # per-sample keep mask for classifier-free guidance dropout
        text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)

        text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
        text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')

        # project to cond_dim only when the widths differ
        if self.text_cond_linear:
            text_tokens = self.text_to_cond(text_embeds)
        else:
            text_tokens = text_embeds

        # truncate, then right-pad, to exactly max_text_len tokens
        text_tokens = text_tokens[:, :self.max_text_len]

        if exists(text_mask):
            text_mask = text_mask[:, :self.max_text_len]

        text_tokens_len = text_tokens.shape[1]
        remainder = self.max_text_len - text_tokens_len

        if remainder > 0:
            text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))

        if exists(text_mask):
            if remainder > 0:
                text_mask = F.pad(text_mask, (0, remainder), value = False)

            text_mask = rearrange(text_mask, 'b n -> b n 1')
            text_keep_mask_embed = text_mask & text_keep_mask_embed

        # masked / dropped positions get the learned null embedding
        null_text_embed = self.null_text_embed.to(text_tokens.dtype)
        text_tokens = torch.where(
            text_keep_mask_embed,
            text_tokens,
            null_text_embed
        )

        if exists(self.attn_pool):
            text_tokens = self.attn_pool(text_tokens)

        # mean-pooled text summary is added to the time conditioning vector
        mean_pooled_text_tokens = text_tokens.mean(dim = -2)

        text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)

        null_text_hidden = self.null_text_hidden.to(t.dtype)

        text_hiddens = torch.where(
            text_keep_mask_hidden,
            text_hiddens,
            null_text_hidden
        )

        t = t + text_hiddens

    # conditioning tokens for cross attention: time (+ text) tokens
    c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)

    if debug:
        print("Merge time and text tokens, c dim: ", c.shape)

    c = self.norm_cond(c)

    # initial resnet block (only present in the memory-efficient layout)
    if exists(self.init_resnet_block):
        x = self.init_resnet_block(x, t)

    if debug:
        print("Before unet, down and up, ")
        print("x dim: ", x.shape)
        print("t dim: ", t.shape)
        print("c dim: ", c.shape)

    # down path: every resnet block output and the stage output are stored
    # as skip connections
    hiddens = []

    for pre_downsample, init_block, resnet_blocks, attn_block, post_downsample in self.downs:
        if exists(pre_downsample):
            x = pre_downsample(x)
            trace("after pre_downsample", x)

        x = init_block(x, t, c)
        trace("after init_block(x,t,c)", x)

        for resnet_block in resnet_blocks:
            x = resnet_block(x, t)
            hiddens.append(x)
            trace("after resnet_block", x)

        x = attn_block(x, c)
        trace("after attn_block", x)

        hiddens.append(x)

        if exists(post_downsample):
            x = post_downsample(x)
            trace("after post_downsample", x)

    # middle
    x = self.mid_block1(x, t, c)
    trace("after mid_block_1", x)

    if exists(self.mid_attn):
        x = self.mid_attn(x)
        trace("after mid_attn", x)

    x = self.mid_block2(x, t, c)
    trace("after mid_block_2", x)

    def add_skip_connection(fmap):
        # consume the most recent skip (LIFO) and concatenate on channels
        return torch.cat((fmap, hiddens.pop() * self.skip_connect_scale), dim = 1)

    # up path
    up_hiddens = []

    for init_block, resnet_blocks, attn_block, upsample in self.ups:
        x = add_skip_connection(x)
        trace("after add_skip_connection", x)

        x = init_block(x, t, c)
        trace("after init_block(x,t,c)", x)

        for resnet_block in resnet_blocks:
            x = add_skip_connection(x)
            trace("after add_skip_connection", x)

            x = resnet_block(x, t)
            trace("after resnet_block(x,t)", x)

        x = attn_block(x, c)
        trace("after attn_block(x,c)", x)

        up_hiddens.append(x.contiguous())
        x = upsample(x)
        trace("after upsample(x)", x)

    # optionally combine the feature maps of all up stages
    x = self.upsample_combiner(x, up_hiddens)
    trace("after upsample_combiner(x,..)", x)

    # optional residual from the initial convolution
    if self.init_conv_to_final_conv_residual:
        x = torch.cat((x, init_conv_residual), dim = 1)
        trace("after cat_init_conv_resi", x)

    if exists(self.final_res_block):
        x = self.final_res_block(x, t)
        trace("after final_res_block(x,t)", x)

    # low-res conditioning is concatenated once more before the final conv
    if exists(lowres_cond_img):
        x = torch.cat((x, lowres_cond_img), dim = 1)
        trace("after cat_x_lowres_cond_img", x)

    if self.beginning_and_final_conv_present:
        x = self.final_conv(x)
        trace("after final_conv(x)", x)

    return x
| | |
| | |
| | |
| | |
| |
|
class NullUnet(nn.Module):
    """Placeholder U-Net: accepts any constructor arguments and returns its input unchanged."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        # Single zero-valued parameter so the module has at least one parameter.
        # NOTE(review): the attribute name looks like a search-and-replace
        # accident, but it is a state_dict key, so it is kept as is.
        self.dummy_paramcast_model_parameterseter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *args, **kwargs):
        """Ignore every requested setting and return this instance."""
        return self

    def forward(self, x, *args, **kwargs):
        """Return ``x`` untouched; all other arguments are ignored."""
        return x
| | |
class Unet(nn.Module):
    """Pass-through stand-in named ``Unet``: any constructor arguments are accepted and ignored."""

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.lowres_cond = False
        # single zero-valued parameter so the module has at least one parameter
        self.dummy_parameter = nn.Parameter(torch.zeros(1))

    def cast_model_parameters(self, *args, **kwargs):
        """Ignore every requested setting and return this instance."""
        return self

    def forward(self, x, *args, **kwargs):
        """Return ``x`` untouched; all other arguments are ignored."""
        return x
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from math import sqrt |
| |
|
# Field names of the per-unet sampling hyperparameter record, in the
# positional order in which ElucidatedImagen fills them.
Hparams_fields = [
    # sampling schedule
    'num_sample_steps', 'sigma_min', 'sigma_max', 'sigma_data', 'rho',
    # mean / std parameters (P_*)
    'P_mean', 'P_std',
    # S_* sampler parameters
    'S_churn', 'S_tmin', 'S_tmax', 'S_noise',
]

# Immutable record holding one value per field above.
Hparams = namedtuple('Hparams', Hparams_fields)
| |
|
| | |
| |
|
def log(t, eps = 1e-20):
    """Natural log of ``t`` after clamping it from below at ``eps`` (avoids ``log(0)``)."""
    floored = torch.clamp(t, min = eps)
    return floored.log()
| |
|
| | |
| |
|
| | class ElucidatedImagen(nn.Module): |
def __init__(
    self,
    unets,
    *,
    image_sizes,
    text_encoder_name = '',
    text_embed_dim = None,
    channels = 3,
    channels_out=3,
    cond_drop_prob = 0.1,
    random_crop_sizes = None,
    lowres_sample_noise_level = 0.2,
    per_sample_random_aug_noise_level = False,
    condition_on_text = True,
    auto_normalize_img = True,
    dynamic_thresholding = True,
    dynamic_thresholding_percentile = 0.95,
    only_train_unet_number = None,
    lowres_noise_schedule = 'linear',
    num_sample_steps = 32,
    sigma_min = 0.002,
    sigma_max = 80,
    sigma_data = 0.5,
    rho = 7,
    P_mean = -1.2,
    P_std = 1.2,
    S_churn = 80,
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
    loss_type=0,
    categorical_loss_ignore=None,
    CKeys=None,
    PKeys=None,
):
    """Wrap one or more U-Nets with per-unet sampling hyperparameters.

    unets        a single U-Net or a sequence of them; each is passed through
                 its own ``cast_model_parameters`` so that only the first one
                 is built without low-resolution conditioning.
    image_sizes  one entry per U-Net (its length is checked against the
                 number of U-Nets).
    loss_type    values ``> 0`` switch on the categorical loss and create a
                 ``LogSoftmax`` over dim 1.
    num_sample_steps ... S_noise
                 scalar or per-unet values collected into ``self.hparams``
                 (one ``Hparams`` record per U-Net).
    CKeys, PKeys stored as given; ``CKeys['Debug_ModelPack']`` is read by
                 ``preconditioned_network_forward``.

    ``text_encoder_name`` is accepted but not read here.
    """
    super().__init__()

    self.CKeys = CKeys
    self.PKeys = PKeys

    self.only_train_unet_number = only_train_unet_number

    # text conditioning switches
    self.condition_on_text = condition_on_text
    self.unconditional = not condition_on_text

    # loss selection: loss_type > 0 means a categorical loss
    self.loss_type = loss_type
    self.categorical_loss = bool(self.loss_type > 0)
    if self.categorical_loss:
        self.m = nn.LogSoftmax(dim=1)

    print("Loss type: ", self.loss_type)
    self.categorical_loss_ignore = categorical_loss_ignore

    # channel counts shared by all U-Nets
    self.channels = channels
    self.channels_out = channels_out

    unets = cast_tuple(unets)
    num_unets = len(unets)

    # random crops are only meant for the upsampling U-Nets, never the base one
    self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
    assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'

    # noise schedule used for the low-resolution conditioning
    self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)

    self.text_embed_dim = text_embed_dim

    # collect the (possibly re-created) U-Nets
    self.unets = nn.ModuleList([])
    self.unet_being_trained_index = -1

    print (f"Channels in={self.channels}, channels out={self.channels_out}")
    for position, candidate in enumerate(unets):
        needs_lowres = position != 0
        wanted_text_dim = self.text_embed_dim if self.condition_on_text else None

        print("Test on cast_model_parameters...")
        print(needs_lowres)
        print(self.condition_on_text)
        print(wanted_text_dim)
        print(self.channels)
        print(self.channels_out)

        candidate = candidate.cast_model_parameters(
            lowres_cond = needs_lowres,
            cond_on_text = self.condition_on_text,
            text_embed_dim = wanted_text_dim,
            channels = self.channels,
            channels_out = self.channels_out
        )

        self.unets.append(candidate)

    # data-type dependent helpers (video is hard-wired off here)
    is_video = False
    self.is_video = is_video

    pad_pattern = 'b -> b 1 1' if not is_video else 'b -> b 1 1 1'
    self.right_pad_dims_to_datatype = partial(rearrange, pattern = pad_pattern)
    self.resize_to = resize_video_to if is_video else resize_image_to

    # one target size per U-Net
    self.image_sizes = image_sizes
    assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'

    self.sample_channels = cast_tuple(self.channels, num_unets)

    # only the first U-Net may lack low-resolution conditioning
    lowres_conditions = tuple(unet.lowres_cond for unet in self.unets)
    expected_conditions = (False, *((True,) * (num_unets - 1)))
    assert lowres_conditions == expected_conditions, 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'

    self.lowres_sample_noise_level = lowres_sample_noise_level
    self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level

    # classifier-free guidance
    self.cond_drop_prob = cond_drop_prob
    self.can_classifier_guidance = cond_drop_prob > 0.

    # input normalisation
    if auto_normalize_img:
        self.normalize_img = normalize_neg_one_to_one
        self.unnormalize_img = unnormalize_zero_to_one
    else:
        self.normalize_img = identity
        self.unnormalize_img = identity
    self.input_image_range = (0. if auto_normalize_img else -1., 1.)

    # dynamic thresholding
    self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
    self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

    # per-unet hyperparameter records, in Hparams field order
    per_unet_values = [
        cast_tuple(value, num_unets)
        for value in (
            num_sample_steps,
            sigma_min,
            sigma_max,
            sigma_data,
            rho,
            P_mean,
            P_std,
            S_churn,
            S_tmin,
            S_tmax,
            S_noise,
        )
    ]
    self.hparams = [Hparams(*unet_values) for unet_values in zip(*per_unet_values)]

    # non-persistent buffer whose device is reported by the `device` property
    self.register_buffer('_temp', torch.tensor([0.]), persistent = False)

    # move everything to the device of the first U-Net parameter
    self.to(next(self.unets.parameters()).device)

    print(next(self.unets.parameters()).device)
| | |
| | |
| |
|
def force_unconditional_(self):
    """Switch text conditioning off, in place, on this wrapper and on every U-Net it holds."""
    self.unconditional = True
    self.condition_on_text = False

    for member in self.unets:
        member.cond_on_text = False
| |
|
@property
def device(self):
    """Device of the ``_temp`` buffer registered on this module."""
    marker = self._temp
    return marker.device
| | |
| |
|
def get_unet(self, unet_number):
    """Return U-Net number ``unet_number`` (1-based), moving it to this module's device.

    On first use the ``nn.ModuleList`` of U-Nets is replaced by a plain list.
    When the requested U-Net differs from the one selected last, the requested
    one is moved to ``self.device`` and all others to ``'cpu'``.
    """
    assert 0 < unet_number <= len(self.unets)
    index = unet_number - 1

    if isinstance(self.unets, nn.ModuleList):
        # swap the registered ModuleList for an ordinary python list
        plain_unets = list(self.unets)
        delattr(self, 'unets')
        self.unets = plain_unets

    if index != self.unet_being_trained_index:
        for position, member in enumerate(self.unets):
            target = self.device if position == index else 'cpu'
            member.to(target)

    self.unet_being_trained_index = index
    return self.unets[index]
| |
|
def reset_unets_all_one_device(self, device = None):
    """Re-wrap the U-Nets in an ``nn.ModuleList`` and move them all to one device.

    ``device`` defaults to ``self.device``. The "currently selected" U-Net
    index is reset to ``-1``.
    """
    target = default(device, self.device)

    regrouped = nn.ModuleList([*self.unets])
    self.unets = regrouped
    self.unets.to(target)

    self.unet_being_trained_index = -1
| |
|
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
    """Context manager that keeps only one U-Net on ``self.device``.

    Exactly one of ``unet_number`` (1-based) or ``unet`` must be given. Inside
    the ``with`` block the chosen U-Net lives on ``self.device`` and all
    others on the CPU. On exit - including when the body raises - every U-Net
    is moved back to the device it was on before.
    """
    assert exists(unet_number) ^ exists(unet)

    if exists(unet_number):
        unet = self.unets[unet_number - 1]

    # remember where every U-Net currently lives
    devices = [module_device(member) for member in self.unets]

    # Move one by one: `self.unets` may be a plain list (see `get_unet`),
    # which has no `.cpu()` method of its own.
    for member in self.unets:
        member.cpu()
    unet.to(self.device)

    try:
        yield
    finally:
        # restore the original placement even if the body failed
        for member, original_device in zip(self.unets, devices):
            member.to(original_device)
| |
|
| | |
| |
|
def state_dict(self, *args, **kwargs):
    """Regroup all U-Nets on one device (as an ``nn.ModuleList``), then return the regular state dict."""
    self.reset_unets_all_one_device()
    collected = super().state_dict(*args, **kwargs)
    return collected
| |
|
def load_state_dict(self, *args, **kwargs):
    """Regroup all U-Nets on one device (as an ``nn.ModuleList``), then load the given state dict."""
    self.reset_unets_all_one_device()
    load_result = super().load_state_dict(*args, **kwargs)
    return load_result
| |
|
| | |
| |
|
def threshold_x_start(self, x_start, dynamic_threshold = True):
    """Limit the range of a predicted ``x_start``.

    Without dynamic thresholding the values are clamped to ``[-1, 1]``. With
    it, a per-sample scale ``s`` is taken as the
    ``dynamic_thresholding_percentile`` quantile of the absolute values (never
    below 1); the sample is clamped to ``[-s, s]`` and divided by ``s``.
    """
    if dynamic_threshold:
        magnitudes = rearrange(x_start, 'b ... -> b (...)').abs()
        s = torch.quantile(magnitudes, self.dynamic_thresholding_percentile, dim = -1)

        # a scale below 1 would stretch values instead of shrinking them
        s.clamp_(min = 1.)
        s = right_pad_dims_to(x_start, s)
        return x_start.clamp(-s, s) / s

    return x_start.clamp(-1., 1.)
| |
|
| | |
| |
|
def c_skip(self, sigma_data, sigma):
    """Weight of the noised input in the preconditioned output: ``sigma_data^2 / (sigma^2 + sigma_data^2)``."""
    data_variance = sigma_data ** 2
    total_variance = sigma ** 2 + sigma_data ** 2
    return data_variance / total_variance
| |
|
def c_out(self, sigma_data, sigma):
    """Weight of the network output: ``sigma * sigma_data / sqrt(sigma_data^2 + sigma^2)``."""
    inverse_norm = (sigma_data ** 2 + sigma ** 2) ** -0.5
    return sigma * sigma_data * inverse_norm
| |
|
def c_in(self, sigma_data, sigma):
    """Scale applied to the network input: ``1 / sqrt(sigma^2 + sigma_data^2)``."""
    total_variance = sigma ** 2 + sigma_data ** 2
    return 1 * total_variance ** -0.5
| |
|
def c_noise(self, sigma):
    """Noise-level conditioning passed to the network: a quarter of the (clamped) log of ``sigma``."""
    log_sigma = log(sigma)
    return log_sigma * 0.25
| |
|
| | |
| | |
| |
|
def preconditioned_network_forward(
    self,
    unet_forward,
    noised_images,
    sigma,
    *,
    sigma_data,
    clamp = False,
    dynamic_threshold = True,
    **kwargs
):
    """Call the network with input/output preconditioning.

    The network receives ``c_in * noised_images`` and ``c_noise(sigma)`` (plus
    ``**kwargs``); the result is ``c_skip * noised_images + c_out * net_out``.
    A python-float ``sigma`` is expanded to one value per batch element. When
    ``clamp`` is set the result is passed through ``threshold_x_start``.
    """
    debug = self.CKeys['Debug_ModelPack']==1

    if debug:
        print("========================================")
        print("ElucidatedImagen: preconditioned_network_forward")

    batch, device = noised_images.shape[0], noised_images.device

    if isinstance(sigma, float):
        sigma = torch.full((batch,), sigma, device = device)

    # sigma with trailing singleton dims so it broadcasts against the data
    padded_sigma = self.right_pad_dims_to_datatype(sigma)

    if debug:
        print("unet_forward: ")
        print("arg_0: ", noised_images.shape)
        print("arg_1: ", sigma.shape)
        print("other arg: ", kwargs)

    scaled_input = self.c_in(sigma_data, padded_sigma) * noised_images
    noise_conditioning = self.c_noise(sigma)
    net_out = unet_forward(scaled_input, noise_conditioning, **kwargs)

    skip_term = self.c_skip(sigma_data, padded_sigma) * noised_images
    out = skip_term + self.c_out(sigma_data, padded_sigma) * net_out

    if clamp:
        return self.threshold_x_start(out, dynamic_threshold)

    return out
| |
|
| | |
| |
|
| | |
| | |
| |
|
def sample_schedule(
    self,
    num_sample_steps,
    rho,
    sigma_min,
    sigma_max
):
    """Build the decreasing sigma schedule used for sampling.

    For step ``i`` in ``0 .. N-1`` the schedule value is
    ``(sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho``,
    followed by one trailing ``0``.

    Args:
        num_sample_steps: Number of sampling steps ``N``.
        rho: Exponent controlling the spacing of the schedule.
        sigma_min: Smallest non-zero noise level.
        sigma_max: Largest noise level.

    Returns:
        A float32 tensor of length ``N + 1`` on ``self.device`` whose last
        entry is ``0``.
    """
    N = num_sample_steps
    inv_rho = 1 / rho

    steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)

    # With a single step, N - 1 is 0 and the original 0 / 0 produced NaN for
    # the whole schedule. Flooring the denominator at 1 yields sigma_max for
    # that one step and leaves every N >= 2 schedule unchanged.
    step_fraction = steps / max(N - 1, 1)
    sigmas = (sigma_max ** inv_rho + step_fraction * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho

    # Append the terminal sigma = 0.
    sigmas = F.pad(sigmas, (0, 1), value = 0.)
    return sigmas
| |
|
@torch.no_grad()
def one_unet_sample(
    self,
    unet,
    shape,
    *,
    unet_number,
    clamp = True,
    dynamic_threshold = True,
    cond_scale = 1.,
    use_tqdm = True,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    sigma_min = None,
    sigma_max = None,
    **kwargs
):
    """Sample from a single unet by stepping through the sigma schedule.

    Args:
        unet: The unet to sample from; ``unet.forward_with_cond_scale`` and
            ``unet.self_cond`` are read.
        shape: Shape of the tensor to generate.
        unet_number: 1-based index used to pick ``self.hparams[unet_number - 1]``.
        clamp, dynamic_threshold, cond_scale: Forwarded to
            ``preconditioned_network_forward`` through ``unet_kwargs``.
        use_tqdm: Progress bar is disabled when falsy.
        inpaint_images, inpaint_masks: Inpainting is only active when both
            are given.
        inpaint_resample_times: Number of inner resampling iterations per
            step when inpainting (otherwise 1).
        init_images: Optional tensor added to the initial noise.
        skip_steps: Number of leading schedule steps to drop (default 0).
        sigma_min, sigma_max: Overrides for the hyper-parameter values.
        **kwargs: Forwarded to the unet through ``unet_kwargs``.

    Returns:
        The sampled tensor (with the inpainting region pasted back in when
        inpainting is active).
    """
    # hyper-parameters of the selected unet
    hp = self.hparams[unet_number - 1]

    sigma_min = default(sigma_min, hp.sigma_min)
    sigma_max = default(sigma_max, hp.sigma_max)

    # full sigma schedule (ends with a trailing 0, see sample_schedule)
    sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max)

    # gamma is non-zero only for sigmas inside [S_tmin, S_tmax]
    gammas = torch.where(
        (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
        min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
        0.
    )

    # (current sigma, next sigma, gamma) triple for every step
    sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))

    # start from noise scaled by the first sigma
    init_sigma = sigmas[0]

    images = init_sigma * torch.randn(shape, device = self.device)

    # optionally add the initial images on top of the noise (in place)
    if exists(init_images):
        images += init_images

    # last model output, kept for self conditioning
    x_start = None

    # inpainting requires both the images and the masks
    has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
    resample_times = inpaint_resample_times if has_inpainting else 1

    if has_inpainting:
        inpaint_images = self.normalize_img(inpaint_images)
        inpaint_images = self.resize_to(inpaint_images, shape[-1])
        # masks get a singleton channel axis, are resized as floats, then made boolean again
        inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()

    # keyword arguments shared by every network call below
    unet_kwargs = dict(
        sigma_data = hp.sigma_data,
        clamp = clamp,
        dynamic_threshold = dynamic_threshold,
        cond_scale = cond_scale,
        **kwargs
    )

    # drop the first `skip_steps` steps of the schedule
    initial_step = default(skip_steps, 0)
    sigmas_and_gammas = sigmas_and_gammas[initial_step:]

    total_steps = len(sigmas_and_gammas)

    for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm):
        is_last_timestep = ind == (total_steps - 1)

        # convert the 0-dim tensors to Python numbers
        sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))

        for r in reversed(range(resample_times)):
            is_last_resample_step = r == 0

            # fresh noise for the stochastic part of this step
            eps = hp.S_noise * torch.randn(shape, device = self.device)

            # raise the noise level from sigma to sigma_hat
            sigma_hat = sigma + gamma * sigma
            added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps

            images_hat = images + added_noise

            self_cond = x_start if unet.self_cond else None

            # inside the mask, use the (noised) inpainting target instead
            if has_inpainting:
                images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks

            model_output = self.preconditioned_network_forward(
                unet.forward_with_cond_scale,
                images_hat,
                sigma_hat,
                self_cond = self_cond,
                **unet_kwargs
            )

            denoised_over_sigma = (images_hat - model_output) / sigma_hat

            # first-order step from sigma_hat to sigma_next
            images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma

            # second-order correction, skipped when the next sigma is 0
            if sigma_next != 0:
                self_cond = model_output if unet.self_cond else None

                model_output_next = self.preconditioned_network_forward(
                    unet.forward_with_cond_scale,
                    images_next,
                    sigma_next,
                    self_cond = self_cond,
                    **unet_kwargs
                )

                denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
                images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)

            images = images_next

            # when inpainting, re-noise before the next resampling iteration
            if has_inpainting and not (is_last_resample_step or is_last_timestep):
                repaint_noise = torch.randn(shape, device = self.device)
                images = images + (sigma - sigma_next) * repaint_noise

            # remember the model output for self conditioning
            x_start = model_output

    # paste the known region back in
    if has_inpainting:
        images = images * ~inpaint_masks + inpaint_images * inpaint_masks

    return images
| |
|
@torch.no_grad()
@eval_decorator
def sample(
    self,
    texts: List[str] = None,
    text_masks = None,
    text_embeds = None,
    cond_images = None,
    inpaint_images = None,
    inpaint_masks = None,
    inpaint_resample_times = 5,
    init_images = None,
    skip_steps = None,
    sigma_min = None,
    sigma_max = None,
    video_frames = None,
    batch_size = 1,
    cond_scale = 1.,
    lowres_sample_noise_level = None,
    start_at_unet_number = 1,
    start_image_or_video = None,
    stop_at_unet_number = None,
    return_all_unet_outputs = False,
    return_pil_images = False,
    use_tqdm = True,
    device = None,
):
    """Sample through the cascade of unets.

    Args:
        texts: Raw texts, encoded with ``self.encode_text`` when no
            ``text_embeds`` are given and the model is conditional.
        text_masks: Optional mask; derived from non-zero ``text_embeds``
            rows when omitted.
        text_embeds: Pre-computed conditioning embeddings.
        cond_images: Optional conditioning tensor forwarded to the unets.
        inpaint_images, inpaint_masks: Must be passed together.
        inpaint_resample_times: Forwarded to ``one_unet_sample``.
        init_images, skip_steps, sigma_min, sigma_max: Single value or one
            value per unet (expanded with ``cast_tuple``).
        video_frames: Required when ``self.is_video`` is set.
        batch_size: Used when nothing else determines the batch size.
        cond_scale: Single value or one value per unet.
        lowres_sample_noise_level: Defaults to ``self.lowres_sample_noise_level``.
        start_at_unet_number: 1-based unet to start with; values above 1
            require ``start_image_or_video``.
        start_image_or_video: Starting tensor when skipping earlier unets.
        stop_at_unet_number: 1-based unet after which sampling stops.
        return_all_unet_outputs: Return a list with every unet's output
            instead of only the last one.
        return_pil_images: Accepted for interface compatibility; not used in
            this method body.
        use_tqdm: Progress bars are disabled when falsy.
        device: Device the unets are gathered on; defaults to ``self.device``.

    Returns:
        The last unet's output tensor, or a list of tensors when
        ``return_all_unet_outputs`` is set. With ``self.categorical_loss``
        each returned tensor is the arg-max over dim 1 with that dim kept as
        size 1.
    """
    device = default(device, self.device)
    self.reset_unets_all_one_device(device = device)

    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))

    if not self.unconditional:
        assert exists(text_embeds), 'text must be passed in if the network was not trained without text `condition_on_text` must be set to `False` when training'

        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
        batch_size = text_embeds.shape[0]

    if exists(inpaint_images):
        if self.unconditional:
            # with the default batch size, follow the number of inpainting images
            if batch_size == 1:
                batch_size = inpaint_images.shape[0]

        assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the specified batch size on sample `sample(batch_size=<int>)``'
        assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'

    outputs = []

    # from here on the device of the model parameters is authoritative
    is_cuda = next(self.parameters()).is_cuda
    device = next(self.parameters()).device

    lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)

    num_unets = len(self.unets)
    cond_scale = cast_tuple(cond_scale, num_unets)

    # extra frame axis for video

    assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

    frame_dims = (video_frames,) if self.is_video else tuple()

    # per-unet settings

    init_images = cast_tuple(init_images, num_unets)
    init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]

    skip_steps = cast_tuple(skip_steps, num_unets)

    sigma_min = cast_tuple(sigma_min, num_unets)
    sigma_max = cast_tuple(sigma_max, num_unets)

    # starting from a later unet needs a starting image to upscale

    if start_at_unet_number > 1:
        assert start_at_unet_number <= num_unets, 'must start a unet that is less than the total number of unets'
        assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
        assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

        prev_image_size = self.image_sizes[start_at_unet_number - 2]
        img = self.resize_to(start_image_or_video, prev_image_size)

    per_unet_settings = zip(
        range(1, num_unets + 1),
        self.unets,
        self.sample_channels,
        self.image_sizes,
        self.hparams,
        self.dynamic_thresholding,
        cond_scale,
        init_images,
        skip_steps,
        sigma_min,
        sigma_max,
    )

    for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(per_unet_settings, disable = not use_tqdm):
        if unet_number < start_at_unet_number:
            continue

        assert not isinstance(unet, NullUnet), 'cannot sample from null unet'

        context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()

        with context:
            lowres_cond_img = lowres_noise_times = None

            if unet.lowres_cond:
                lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

                lowres_cond_img = self.resize_to(img, image_size)
                lowres_cond_img = self.normalize_img(lowres_cond_img.float())

                lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(
                    x_start = lowres_cond_img.float(),
                    t = lowres_noise_times,
                    noise = torch.randn_like(lowres_cond_img.float())
                )

            if exists(unet_init_images):
                unet_init_images = self.resize_to(unet_init_images, image_size)

            # The original also built a shape from `channel` first, but it was
            # overwritten by this one before any use, so only this one is kept.
            shape = (batch_size, self.channels, *frame_dims, image_size)

            img = self.one_unet_sample(
                unet,
                shape,
                unet_number = unet_number,
                text_embeds = text_embeds,
                text_mask =text_masks,
                cond_images = cond_images,
                inpaint_images = inpaint_images,
                inpaint_masks = inpaint_masks,
                inpaint_resample_times = inpaint_resample_times,
                init_images = unet_init_images,
                skip_steps = unet_skip_steps,
                sigma_min = unet_sigma_min,
                sigma_max = unet_sigma_max,
                cond_scale = unet_cond_scale,
                lowres_cond_img = lowres_cond_img,
                lowres_noise_times = lowres_noise_times,
                dynamic_threshold = dynamic_threshold,
                use_tqdm = use_tqdm
            )

            if self.categorical_loss:
                img=self.m(img)

            outputs.append(img)

        if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
            break

    # either the last output (a tensor) or all outputs (a list)
    output_index = -1 if not return_all_unet_outputs else slice(None)

    if not return_all_unet_outputs:
        outputs = outputs[-1:]

    assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'

    selected = outputs[output_index]

    if not self.categorical_loss:
        return selected

    # With return_all_unet_outputs, `selected` is a list of tensors; calling
    # torch.argmax on the list itself raised, so reduce each tensor on its own.
    if return_all_unet_outputs:
        return [torch.argmax(unet_output, dim=1).unsqueeze (1) for unet_output in selected]

    return torch.argmax(selected, dim=1).unsqueeze (1)
| |
|
| | |
| |
|
def loss_weight(self, sigma_data, sigma):
    """Return ``(sigma^2 + sigma_data^2) / (sigma * sigma_data)^2``.

    Used in ``forward`` to weight the per-sample loss by noise level.
    """
    variance_sum = sigma ** 2 + sigma_data ** 2
    inverse_squared_product = (sigma * sigma_data) ** -2
    return variance_sum * inverse_squared_product
| |
|
def noise_distribution(self, P_mean, P_std, batch_size):
    """Draw ``batch_size`` noise levels from a log-normal distribution.

    Each value is ``exp(P_mean + P_std * z)`` with ``z`` drawn from
    ``torch.randn`` on ``self.device``.

    Returns:
        A tensor of shape ``(batch_size,)``.
    """
    standard_normal = torch.randn((batch_size,), device = self.device)
    log_sigma = P_mean + P_std * standard_normal
    return log_sigma.exp()
| |
|
def forward(
    self,
    images,
    unet: Union[ NullUnet, DistributedDataParallel] = None,
    texts: List[str] = None,
    text_embeds = None,
    text_masks = None,
    unet_number = None,
    cond_images = None,
):
    """Compute the training loss for one unet of the cascade.

    Args:
        images: Target tensor; checked to have ``self.channels`` channels on
            axis 1 and a last axis of at least the target size.
        unet: Unet to train; looked up with ``self.get_unet(unet_number)``
            when omitted. May be wrapped in ``DistributedDataParallel``.
        texts: Raw texts, encoded with ``self.encode_text`` when no
            ``text_embeds`` are given and the model is conditional.
        text_embeds: Pre-computed conditioning embeddings.
        text_masks: Optional mask; derived from non-zero ``text_embeds``
            rows when omitted.
        unet_number: 1-based unet index; required when there is more than
            one unet, otherwise defaults to 1.
        cond_images: Optional conditioning tensor forwarded to the unet.

    Returns:
        Scalar tensor: the noise-level-weighted MSE averaged over the batch.

    Raises:
        NotImplementedError: If ``self.loss_type`` is not ``0``.
    """
    assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
    unet_number = default(unet_number, 1)
    # The message was missing its f-prefix, so the placeholder was printed literally.
    assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'

    debug_enabled = self.CKeys['Debug_ModelPack']==1

    if debug_enabled:
        if cond_images is not None:
            print("cond_images type: ", cond_images.dtype)
        else:
            print("cond_images type: None")

    cond_images = maybe(cast_uint8_images_to_float)(cond_images)

    if debug_enabled:
        if cond_images is not None:
            print("cond_images type: ", cond_images.dtype)
        else:
            print("cond_images type: None")

    if self.categorical_loss==False:
        assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'

    unet_index = unet_number - 1

    unet = default(unet, lambda: self.get_unet(unet_number))

    assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'

    target_image_size = self.image_sizes[unet_index]
    random_crop_size = self.random_crop_sizes[unet_index]
    prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
    hp = self.hparams[unet_index]

    if debug_enabled:
        print("target_image_size: ", target_image_size)
        print("prev_image_size: ", prev_image_size)
        print("random_crop_size: ", random_crop_size)

    # a 4-dimensional input is treated as video (an extra frame axis at index 2)
    batch_size, c, *_, h, device, is_video = *images.shape, images.device, (images.ndim == 4)

    frames = images.shape[2] if is_video else None

    if debug_enabled:
        print("frames: ", frames)

    check_shape(images, 'b c ...', c = self.channels)

    assert h >= target_image_size

    if exists(texts) and not exists(text_embeds) and not self.unconditional:
        assert all([*map(len, texts)]), 'text cannot be empty'
        assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

        with autocast(enabled = False):
            text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

        text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))

    if not self.unconditional:
        text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))

    assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
    assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'

    assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'

    # low resolution conditioning, only for unets after the first
    lowres_cond_img = lowres_aug_times = None
    if exists(prev_image_size):
        lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
        lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)

        if self.per_sample_random_aug_noise_level:
            lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
        else:
            # one shared noise level, repeated across the batch
            lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
            lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)

    # random cropping, applied identically to target and low-res conditioner
    if exists(random_crop_size):
        aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h -> (b f) c h')

        images = aug(images)
        lowres_cond_img = aug(lowres_cond_img, params = aug._params)

        if is_video:
            images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h -> b c f h', f = frames)

    # noise the low resolution conditioning image
    lowres_cond_img_noisy = None
    if exists(lowres_cond_img):
        lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(
            x_start = lowres_cond_img,
            t = lowres_aug_times,
            noise = torch.randn_like(lowres_cond_img.float())
        )

    # one noise level per sample
    sigmas = self.noise_distribution(
        hp.P_mean,
        hp.P_std,
        batch_size
    ).to(device)
    padded_sigmas = self.right_pad_dims_to_datatype(sigmas).to(device)

    if debug_enabled:
        print('sigmas dim: ', sigmas.shape, 'should = batch_size')
        print('sigmas[0..3]: ', sigmas[:3])
        print('padded_sigmas dim: ', padded_sigmas.shape)

    # noise the targets
    noise = torch.randn_like(images.float()).to(device)

    noised_images = images + padded_sigmas * noise

    unet_kwargs = dict(
        sigma_data = hp.sigma_data,
        text_embeds = text_embeds,
        text_mask =text_masks,
        cond_images = cond_images,
        lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
        lowres_cond_img = lowres_cond_img_noisy,
        cond_drop_prob = self.cond_drop_prob,
    )

    if debug_enabled:
        print("lowres_noise_times: ", unet_kwargs['lowres_noise_times'])
        print("lowres_cond_img_noisy: ", unet_kwargs['lowres_cond_img'])
        print("unet_kwargs: \n", unet_kwargs)

    # Self conditioning. A DistributedDataParallel wrapper keeps the real unet
    # under `.module`. The non-DDP branch previously evaluated to the unet
    # module itself, which is always truthy, so the extra no-grad pass ran
    # half the time even for unets without self conditioning. Read the flag in
    # both branches, as `one_unet_sample` already does.
    self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond

    if self_cond and random() < 0.5:
        if debug_enabled:
            print("=========added self_cond.============")
        with torch.no_grad():
            pred_x0 = self.preconditioned_network_forward(
                unet.forward,
                noised_images,
                sigmas,
                **unet_kwargs
            ).detach()

        unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}

        if debug_enabled:
            print("after self-condition, random < 0.5 .......")
            print("unet_kwargs: \n", unet_kwargs)

    # prediction
    denoised_images = self.preconditioned_network_forward(
        unet.forward,
        noised_images,
        sigmas,
        **unet_kwargs
    )

    if self.loss_type==0:
        # per-sample mean squared error
        losses = F.mse_loss(denoised_images, images, reduction = 'none')
        losses = reduce(losses, 'b ... -> b', 'mean')

        # weight by noise level, then average over the batch
        losses = losses * self.loss_weight(hp.sigma_data, sigmas)
        losses=losses.mean()

        return losses

    # Only MSE (loss_type == 0) is implemented; fail with a clear message
    # instead of an obscure error or a silent None.
    raise NotImplementedError(f'loss_type {self.loss_type} is not implemented (only 0 = MSE)')
| |
|
| | |
| | |
| | |
class ProteinDesigner_B(nn.Module):
    """Residue-based generative protein diffusion model built on ``ElucidatedImagen``.

    The conditioning vector ``x`` is zero-padded, embedded per position with a
    linear layer, concatenated with a learned positional embedding and passed
    to the wrapped ``ElucidatedImagen`` as ``text_embeds``. A conditioning
    tensor can be passed as ``cond_images`` instead.

    Args:
        unet: Unet handed to ``ElucidatedImagen``.
        CKeys: Control dictionary; ``CKeys['Debug_ModelPack'] == 1`` enables
            debug printing.
        PKeys: Parameter dictionary. Every key read in ``__init__`` must be
            present; the second argument of ``default`` is only a fallback
            (presumably for ``None`` values — confirm against ``default``).
    """

    def __init__(self,
        unet,
        CKeys=None,
        PKeys=None,
    ):
        super(ProteinDesigner_B, self).__init__()

        timesteps =default(PKeys['timesteps'], 10)
        # NOTE(review): dim, padding_idx, cond_dim, input_tokens and
        # sequence_embed are read but not used below; the lookups are kept so
        # that a missing key still fails exactly as before.
        dim =default(PKeys['dim'], 32)
        pred_dim =default(PKeys['pred_dim'], 25)
        loss_type =default(PKeys['loss_type'], 0)
        elucidated =default(PKeys['elucidated'], True)
        padding_idx =default(PKeys['padding_idx'], 0)
        cond_dim =default(PKeys['cond_dim'], 512)
        text_embed_dim =default(PKeys['text_embed_dim'], 512)
        input_tokens =default(PKeys['input_tokens'], 25)
        sequence_embed =default(PKeys['sequence_embed'], False)
        embed_dim_position =default(PKeys['embed_dim_position'], 32)
        max_text_len =default(PKeys['max_text_len'], 16)
        cond_images_channels=default(PKeys['cond_images_channels'], 0)

        max_length =default(PKeys['max_length'], 64)
        device =default(PKeys['device'], None)

        print ("Model B: Generative protein diffusion model, residue-based")
        print ("Using condition as the initial sequence")
        self.pred_dim=pred_dim
        self.loss_type=loss_type

        self.CKeys=CKeys
        self.PKeys=PKeys
        self.max_length = max_length

        assert loss_type == 0, "Losses other than MSE not implemented"

        self.fc_embed1 = nn.Linear( 8, max_length)
        self.fc_embed2 = nn.Linear( 1, text_embed_dim)
        self.max_text_len=max_text_len

        # positional embedding; its width is appended to the conditioning width
        self.pos_emb_x = nn.Embedding(max_text_len+1, embed_dim_position)
        text_embed_dim=text_embed_dim+embed_dim_position

        # position indices 1 .. max_text_len
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        # a conditioning image replaces the text-style conditioning
        condition_on_text=True
        self.cond_images_channels=cond_images_channels

        if self.cond_images_channels>0:
            condition_on_text = False
            print ("Use conditioning image during training....")

        assert elucidated , "Only elucidated model implemented...."
        self.is_elucidated=elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets = (unet),
                channels=self.pred_dim,
                channels_out=self.pred_dim ,
                loss_type=loss_type,
                condition_on_text = condition_on_text,
                text_embed_dim = text_embed_dim,
                image_sizes = ( [max_length ]),
                cond_drop_prob = 0.1,
                auto_normalize_img = False,
                num_sample_steps = timesteps,
                sigma_min = 0.002,
                sigma_max = 160,
                sigma_data = 0.5,
                rho = 7,
                P_mean = -1.2,
                P_std = 1.2,
                S_churn = 40,
                S_tmin = 0.05,
                S_tmax = 50,
                S_noise = 1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(device)
        else:
            print ("Not implemented.")

    def _embed_condition(self, x, max_length, device):
        """Zero-pad ``x`` to ``max_length`` and embed it with positions appended."""
        x_in=torch.zeros( (x.shape[0], max_length) ).to(device)
        x_in[:,:x.shape[1]]=x
        x=x_in.unsqueeze (2)

        x= self.fc_embed2(x)

        pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = self.pos_emb_x( pos_matrix_i_)
        pos_emb_x = torch.squeeze(pos_emb_x, 1)

        return torch.cat( (x, pos_emb_x ), 2)

    def forward(self,
        output,
        x=None,
        cond_images = None,
        unet_number=1,
    ):
        """Return the training loss of the wrapped ``ElucidatedImagen``.

        Args:
            output: Target tensor passed to ``self.imagen`` as the images.
            x: Optional conditioning vector, embedded and passed as ``text_embeds``.
            cond_images: Optional conditioning tensor, moved to ``device``.
            unet_number: 1-based unet index forwarded to ``self.imagen``.

        Note:
            ``device`` here is the module-level name, not the ``PKeys['device']``
            value read in ``__init__``.
        """
        debug_enabled = self.CKeys['Debug_ModelPack']==1

        if debug_enabled:
            print("on Model B:forward")
            print("inputs:")
            print(' output: ', output.shape)
            print(' x: ', x)
            # cond_images defaults to None; printing .shape unconditionally crashed here
            print(' cond_img: ', cond_images.shape if cond_images is not None else None)
            print(' unet_num: ', unet_number)

        if x is not None:
            x = self._embed_condition(x, self.max_length, device)

        if cond_images is not None:
            this_cond_images=cond_images.to(device)
        else:
            this_cond_images=cond_images

        if debug_enabled:
            print('x with pos_emb_x: ', x)
            print("into self.imagen...")

        loss = self.imagen(
            output,
            text_embeds = x,
            cond_images=this_cond_images,
            unet_number = unet_number,
        )

        return loss

    def sample (
        self,
        x=None,
        stop_at_unet_number=1 ,
        cond_scale=7.5,
        x_data=None,
        skip_steps=None,
        inpaint_images = None,
        inpaint_masks = None,
        inpaint_resample_times = 5,
        init_images = None,
        x_data_tokenized=None,
        device=None,
        tokenizer_X=None,
        Xnormfac=1.,
        max_length=1.,
    ):
        """Sample from the wrapped ``ElucidatedImagen``.

        Args:
            x: Optional conditioning vector, embedded as in ``forward``.
            stop_at_unet_number, cond_scale, skip_steps, inpaint_masks,
            inpaint_resample_times: Forwarded to ``self.imagen.sample``.
            x_data: Optional list of raw sequences; tokenized with
                ``tokenizer_X``, padded to ``max_length``, divided by
                ``Xnormfac`` and repeated over ``self.pred_dim`` channels.
            x_data_tokenized: Optional already-tokenized tensor; repeated over
                ``self.pred_dim`` channels. Overrides ``x_data`` when both are given.
            init_images, inpaint_images: Optional raw sequences, tokenized
                and padded before being forwarded.
            device: Device that the prepared tensors are moved to.
            tokenizer_X: Tokenizer used for ``x_data``.
            Xnormfac: Normalization divisor for tokenized ``x_data``.
            max_length: Padding length for all tokenized inputs.

        Returns:
            Whatever ``self.imagen.sample`` returns.
        """
        batch_size=1

        if x_data is not None:
            print ("Conditioning target sequence provided via ori x_data ...", x_data)

            x_data = tokenizer_X.texts_to_sequences(x_data)

            # prepend a 0 token to every sequence before padding
            x_data_0 = [[0]+this_x_data for this_x_data in x_data]
            x_data = sequence.pad_sequences(
                x_data_0, maxlen=max_length,
                padding='post', truncating='post',
            )

            x_data = torch.from_numpy(x_data).float().to(device)
            x_data = x_data/Xnormfac

            # one copy of the sequence per prediction channel
            x_data=x_data.unsqueeze(1).repeat(1,self.pred_dim,1)

            print ("After channel expansion, x_data from target sequence=", x_data, x_data.shape)
            batch_size=x_data.shape[0]

        if x_data_tokenized is not None:
            print ("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape)

            x_data=x_data_tokenized.unsqueeze(1).repeat(1,self.pred_dim,1) .to(device)

            print ("x_data.dim provided from x_data_tokenized: ", x_data.shape)
            batch_size=x_data.shape[0]

        # NOTE(review): tokenizer_y and ynormfac are neither parameters nor
        # locals; the two branches below only work if those names exist at
        # module level — confirm, or pass them in explicitly.
        if init_images is not None:
            print ("Init sequence provided...", init_images)
            init_images = tokenizer_y.texts_to_sequences(init_images)
            init_images= sequence.pad_sequences(init_images, maxlen=max_length, padding='post', truncating='post')
            init_images= torch.from_numpy(init_images).float().to(device)/ynormfac
            print ("init_images=", init_images)

        if inpaint_images is not None:
            print ("Inpaint sequence provided...", inpaint_images)
            print ("Mask: ", inpaint_masks)
            inpaint_images = tokenizer_y.texts_to_sequences(inpaint_images)
            inpaint_images= sequence.pad_sequences(inpaint_images, maxlen=max_length, padding='post', truncating='post')
            inpaint_images= torch.from_numpy(inpaint_images).float().to(device)/ynormfac
            print ("in_paint images=", inpaint_images)

        if x is not None:
            x = self._embed_condition(x, max_length, device)

            batch_size=x.shape[0]

        output = self.imagen.sample(
            text_embeds= x,
            cond_scale = cond_scale,
            stop_at_unet_number=stop_at_unet_number,
            cond_images=x_data,
            skip_steps=skip_steps,
            inpaint_images = inpaint_images,
            inpaint_masks = inpaint_masks,
            inpaint_resample_times = inpaint_resample_times,
            init_images = init_images,
            batch_size=batch_size,
            device=device,
        )

        return output
| | |
| | |
| | |
| | |
| | class ProteinPredictor_B(nn.Module): |
def __init__(self,
    unet,
    CKeys=None,
    PKeys=None,
):
    """Build the predictive residue-based protein diffusion model.

    Args:
        unet: Unet handed to ``ElucidatedImagen``.
        CKeys: Control dictionary, stored as ``self.CKeys``.
        PKeys: Parameter dictionary, stored as ``self.PKeys``. Every key read
            below must be present; the second argument of ``default`` is only
            a fallback (presumably for ``None`` values — confirm against
            ``default``).
    """
    super(ProteinPredictor_B, self).__init__()

    # hyper-parameters (several are read but not used further in this method)
    timesteps = default(PKeys['timesteps'], 10)
    dim = default(PKeys['dim'], 32)
    pred_dim = default(PKeys['pred_dim'], 25)
    loss_type = default(PKeys['loss_type'], 0)
    elucidated = default(PKeys['elucidated'], True)
    padding_idx = default(PKeys['padding_idx'], 0)
    cond_dim = default(PKeys['cond_dim'], 512)
    text_embed_dim = default(PKeys['text_embed_dim'], 512)
    input_tokens = default(PKeys['input_tokens'], 25)
    sequence_embed = default(PKeys['sequence_embed'], False)
    embed_dim_position = default(PKeys['embed_dim_position'], 32)
    max_text_len = default(PKeys['max_text_len'], 16)
    cond_images_channels = default(PKeys['cond_images_channels'], 0)

    max_length = default(PKeys['max_length'], 64)
    device = default(PKeys['device'], None)

    print ("Model B: Predictive protein diffusion model, residue-based")
    print ("Using condition as the initial sequence")

    self.pred_dim = pred_dim
    self.loss_type = loss_type
    self.CKeys = CKeys
    self.PKeys = PKeys
    self.max_length = max_length

    assert loss_type == 0, "Losses other than MSE not implemented"

    # layers are created in the same order as before
    self.fc_embed1 = nn.Linear(8, max_length)
    self.fc_embed2 = nn.Linear(1, text_embed_dim)
    self.max_text_len = max_text_len

    self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
    # conditioning width grows by the positional embedding width
    text_embed_dim = text_embed_dim + embed_dim_position

    # position indices 1 .. max_text_len
    self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

    self.cond_images_channels = cond_images_channels
    uses_cond_images = self.cond_images_channels > 0
    condition_on_text = not uses_cond_images
    if uses_cond_images:
        print ("Use conditioning image during training....")

    assert elucidated , "Only elucidated model implemented...."
    self.is_elucidated = elucidated

    if not elucidated:
        print ("Not implemented.")
        return

    imagen = ElucidatedImagen(
        unets = (unet),
        channels = self.pred_dim,
        channels_out = self.pred_dim,
        loss_type = loss_type,
        condition_on_text = condition_on_text,
        text_embed_dim = text_embed_dim,
        image_sizes = ([max_length]),
        cond_drop_prob = 0.1,
        auto_normalize_img = False,
        num_sample_steps = timesteps,
        sigma_min = 0.002,
        sigma_max = 160,
        sigma_data = 0.5,
        rho = 7,
        P_mean = -1.2,
        P_std = 1.2,
        S_churn = 40,
        S_tmin = 0.05,
        S_tmax = 50,
        S_noise = 1.003,
        CKeys = self.CKeys,
        PKeys = self.PKeys,
    )
    self.imagen = imagen.to(device)
| | |
def forward(self,
    output,
    x=None,
    cond_images = None,
    unet_number=1,
):
    """Return the training loss of the wrapped ``ElucidatedImagen``.

    Args:
        output: Target tensor passed to ``self.imagen`` as the images.
        x: Optional conditioning vector. It is zero-padded to
            ``self.max_length``, embedded per position with ``fc_embed2``,
            concatenated with the positional embedding and passed as
            ``text_embeds``.
        cond_images: Optional conditioning tensor, moved to ``device``.
        unet_number: 1-based unet index forwarded to ``self.imagen``.

    Note:
        ``device`` here is the module-level name, not the ``PKeys['device']``
        value read in ``__init__``.
    """
    debug_enabled = self.CKeys['Debug_ModelPack']==1

    if debug_enabled:
        print("on Model B:forward")
        print("inputs:")
        print(' output: ', output.shape)
        print(' x: ', x)
        # cond_images defaults to None; printing .shape unconditionally crashed here
        print(' cond_img: ', cond_images.shape if cond_images is not None else None)
        print(' unet_num: ', unet_number)

    if x is not None:
        # zero-pad the condition to the model length
        x_in=torch.zeros( (x.shape[0], self.max_length) ).to(device)
        x_in[:,:x.shape[1]]=x
        x=x_in

        x=x.unsqueeze (2)

        x= self.fc_embed2(x)

        # append the positional embedding along the feature axis
        pos_matrix_i_=self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = self.pos_emb_x( pos_matrix_i_)
        pos_emb_x = torch.squeeze(pos_emb_x, 1)

        x= torch.cat( (x, pos_emb_x ), 2)

    if cond_images is not None:
        this_cond_images=cond_images.to(device)
    else:
        this_cond_images=cond_images

    if debug_enabled:
        print('x with pos_emb_x: ', x)
        print("into self.imagen...")

    loss = self.imagen(
        output,
        text_embeds = x,
        cond_images=this_cond_images,
        unet_number = unet_number,
    )

    return loss
| | |
def sample(
    self,
    x=None,
    stop_at_unet_number=1,
    cond_scale=7.5,
    x_data=None,
    skip_steps=None,
    inpaint_images=None,
    inpaint_masks=None,
    inpaint_resample_times=5,
    init_images=None,
    x_data_tokenized=None,
    device=None,
    tokenizer_X=None,
    Xnormfac=1.,
    max_length=1.,
    pLM_Model_Name=None,
    pLM_Model=None,
    pLM_alphabet=None,
    esm_layer=None,
    tokenizer_y=None,
    ynormfac=1.,
):
    """Sample from Model B, building the conditioning from one of several inputs.

    Args:
        x: Optional numeric conditioning tensor; padded to ``max_length`` and
            embedded exactly as in ``forward`` to become ``text_embeds``.
        stop_at_unet_number, cond_scale, skip_steps, inpaint_masks,
        inpaint_resample_times: Forwarded to ``self.imagen.sample``.
        x_data: Optional list of raw sequences (strings) to condition on.
        x_data_tokenized: Optional pre-tokenized tensor to condition on; when
            both are given this one overrides ``x_data``.
        init_images, inpaint_images: Optional lists of raw sequences,
            tokenized with ``tokenizer_y`` and divided by ``ynormfac``.
        device: Device for tensors created here. NOTE(review): this parameter
            shadows the module-level ``device``; callers should pass it
            explicitly.
        tokenizer_X, Xnormfac: Tokenizer / scale used on the ``'trivial'`` path.
        max_length: Padded sequence length; must be an int in practice even
            though the default is a float.
        pLM_Model_Name: ``'trivial'`` selects the tokenizer path; any other
            value selects the language-model path.
        pLM_Model, pLM_alphabet, esm_layer: Language model, its alphabet and
            the representation layer to read.
        tokenizer_y, ynormfac: Tokenizer / scale for ``init_images`` and
            ``inpaint_images`` (previously undefined names in this method).

    Returns:
        The output of ``self.imagen.sample(...)``.

    Raises:
        ValueError: if ``init_images``/``inpaint_images`` is given without
            ``tokenizer_y``.
    """

    def _encode_target(seqs):
        # Tokenize, right-pad and normalize raw target-domain sequences.
        if tokenizer_y is None:
            raise ValueError(
                "tokenizer_y is required when init_images or inpaint_images is given"
            )
        tokens = tokenizer_y.texts_to_sequences(seqs)
        tokens = sequence.pad_sequences(
            tokens, maxlen=max_length, padding='post', truncating='post'
        )
        return torch.from_numpy(tokens).float().to(device) / ynormfac

    def _pLM_features(tokens):
        # Run the language model without gradients and return features as
        # (batch, channels, length).
        with torch.no_grad():
            results = pLM_Model(
                tokens,
                repr_layers=[esm_layer],
                return_contacts=False,
            )
        return rearrange(results["representations"][esm_layer], 'b l c -> b c l')

    batch_size = 1

    if x_data is not None:
        print("Conditioning target sequence provided via ori x_data ...", x_data)
        print(f"use pLM model {pLM_Model_Name}")

        if pLM_Model_Name == 'trivial':
            x_data = tokenizer_X.texts_to_sequences(x_data)
            # First pad/truncate on the right to max_length-1 ...
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length - 1,
                padding='post', truncating='post',
                value=0.0,
            )
            # ... then pad on the left, which prepends a single leading 0.
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length,
                padding='pre', truncating='pre',
                value=0.0,
            )
            # pad_sequences returns a NumPy array; convert before using
            # tensor-only methods such as unsqueeze/repeat.
            x_data = torch.from_numpy(x_data).float()
            x_data = x_data.unsqueeze(1).repeat(1, self.pred_dim, 1).to(device)
            x_data = x_data / Xnormfac
        else:
            print("pLM Model: ", pLM_Model_Name)

            # Two positions are left free by the truncation length
            # (presumably for the begin/end tokens -- TODO confirm).
            esm_batch_converter = pLM_alphabet.get_batch_converter(
                truncation_seq_length=max_length - 2
            )
            seqs_ext = [(" ", seq) for seq in x_data]
            _, _, x_data = esm_batch_converter(seqs_ext)

            current_seq_len = x_data.shape[1]
            print("current seq batch len: ", current_seq_len)
            missing_num_pad = max_length - current_seq_len
            if missing_num_pad > 0:
                print("extra padding is added to match the target seq input length...")
                x_data = F.pad(
                    x_data,
                    (0, missing_num_pad),
                    "constant", pLM_alphabet.padding_idx
                )
            else:
                print("No extra padding is needed")
            x_data = x_data.to(device)

            x_data = _pLM_features(x_data)

        print("x_data.dim: ", x_data.shape)
        print("x_data.type: ", x_data.type())
        batch_size = x_data.shape[0]

    if x_data_tokenized is not None:
        print(
            "Conditioning target output via provided AA tokens sequence...",
            x_data_tokenized,
            x_data_tokenized.shape,
        )

        if pLM_Model_Name == 'trivial':
            x_data = x_data_tokenized.unsqueeze(1).repeat(1, self.pred_dim, 1).to(device)
        else:
            x_data = _pLM_features(x_data_tokenized).to(device)

        batch_size = x_data.shape[0]
        print("x_data.dim provided from x_data_tokenized: ", x_data.shape)

    if init_images is not None:
        print("Init sequence provided...", init_images)
        init_images = _encode_target(init_images)
        print("init_images=", init_images)

    if inpaint_images is not None:
        print("Inpaint sequence provided...", inpaint_images)
        print("Mask: ", inpaint_masks)
        inpaint_images = _encode_target(inpaint_images)
        print("in_paint images=", inpaint_images)

    if x is not None:
        # Same embedding pipeline as forward(), padded to max_length.
        x_in = torch.zeros((x.shape[0], max_length)).to(device)
        x_in[:, :x.shape[1]] = x
        x = self.fc_embed2(x_in.unsqueeze(2))

        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
        pos_emb_x = torch.squeeze(self.pos_emb_x(pos_matrix_i_), 1)

        x = torch.cat((x, pos_emb_x), 2)
        batch_size = x.shape[0]

    output = self.imagen.sample(
        text_embeds=x,
        cond_scale=cond_scale,
        stop_at_unet_number=stop_at_unet_number,
        cond_images=x_data,
        skip_steps=skip_steps,
        inpaint_images=inpaint_images,
        inpaint_masks=inpaint_masks,
        inpaint_resample_times=inpaint_resample_times,
        init_images=init_images,
        batch_size=batch_size,
        device=device,
    )

    return output
| | |
| | |
| | |
| | |
| | |
class ProteinDesigner_B_Old(nn.Module):
    """Model B (legacy): residue-based generative protein diffusion model.

    Wraps an ``ElucidatedImagen`` around a caller-supplied UNet and adds a
    small embedding front-end (``fc_embed2`` + positional embedding) for an
    optional numeric conditioning vector ``x``.
    """

    def __init__(self,
                 unet,
                 timesteps=10,
                 dim=32,
                 pred_dim=25,
                 loss_type=0,
                 elucidated=True,
                 padding_idx=0,
                 cond_dim=512,
                 text_embed_dim=512,
                 input_tokens=25,
                 sequence_embed=False,
                 embed_dim_position=32,
                 max_text_len=16,
                 cond_images_channels=0,
                 max_length=1,
                 device=None,
                 CKeys=None,
                 PKeys=None,
                 ):
        """Build the embedding layers and the imagen wrapper.

        Args:
            unet: UNet instance passed to ``ElucidatedImagen``.
            timesteps: Number of sampling steps (``num_sample_steps``).
            pred_dim: Channel count used for both ``channels`` and
                ``channels_out``.
            loss_type: Must be 0 (MSE); anything else fails the assert.
            elucidated: Must be truthy; only the elucidated variant exists.
            text_embed_dim: Width of ``fc_embed2``; the imagen sees this plus
                ``embed_dim_position``.
            embed_dim_position: Width of the positional embedding.
            max_text_len: Number of positional indices (1..max_text_len).
            cond_images_channels: When > 0, text conditioning is turned off
                in the imagen.
            max_length: Sequence length given to the imagen as image size.
            device: Device the imagen is moved to.
            CKeys, PKeys: Control / parameter dicts stored and forwarded.
            dim, padding_idx, cond_dim, input_tokens, sequence_embed:
                Accepted for interface compatibility; not read here.
        """
        super(ProteinDesigner_B_Old, self).__init__()

        print("Model B: Generative protein diffusion model, residue-based")
        print("Using condition as the initial sequence")
        self.pred_dim = pred_dim
        self.loss_type = loss_type

        self.CKeys = CKeys
        self.PKeys = PKeys
        self.max_length = max_length

        assert loss_type == 0, "Losses other than MSE not implemented"

        # fc_embed1 is not used by forward/sample in this class; it is kept so
        # the module's parameter set stays unchanged.
        self.fc_embed1 = nn.Linear(8, max_length)
        self.fc_embed2 = nn.Linear(1, text_embed_dim)
        self.max_text_len = max_text_len

        self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
        text_embed_dim = text_embed_dim + embed_dim_position

        # Positional indices 1..max_text_len (index 0 is never emitted here).
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        self.cond_images_channels = cond_images_channels
        condition_on_text = True
        if self.cond_images_channels > 0:
            condition_on_text = False
            print("Use conditioning image during training....")

        assert elucidated, "Only elucidated model implemented...."
        self.is_elucidated = elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets=unet,
                channels=self.pred_dim,
                channels_out=self.pred_dim,
                loss_type=loss_type,
                condition_on_text=condition_on_text,
                text_embed_dim=text_embed_dim,
                image_sizes=[max_length],
                cond_drop_prob=0.1,
                auto_normalize_img=False,
                num_sample_steps=timesteps,
                sigma_min=0.002,
                sigma_max=160,
                sigma_data=0.5,
                rho=7,
                P_mean=-1.2,
                P_std=1.2,
                S_churn=40,
                S_tmin=0.05,
                S_tmax=50,
                S_noise=1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(device)
        else:
            # Only reachable when asserts are disabled.
            print("Not implemented.")

    def forward(self,
                output,
                x=None,
                cond_images=None,
                unet_number=1,
                ):
        """Return the imagen loss for one batch.

        Args:
            output: Target tensor passed to ``self.imagen``.
            x: Optional conditioning tensor; zero-padded to
                ``self.max_length`` and embedded (assumes (batch, n) with
                n <= max_length -- TODO confirm).
            cond_images: Optional conditioning tensor, moved to the
                module-level ``device``.
            unet_number: Forwarded unchanged.
        """
        debug = self.CKeys['Debug_ModelPack'] == 1
        if debug:
            print('output: ', output.shape)
            print('x: ', x)
            print('cond_img: ', cond_images)
            print('unet_num: ', unet_number)

        if x is not None:
            x_in = torch.zeros((x.shape[0], self.max_length)).to(device)
            x_in[:, :x.shape[1]] = x
            x = self.fc_embed2(x_in.unsqueeze(2))

            pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
            pos_emb_x = torch.squeeze(self.pos_emb_x(pos_matrix_i_), 1)

            x = torch.cat((x, pos_emb_x), 2)

        this_cond_images = None if cond_images is None else cond_images.to(device)

        if debug:
            print('x with pos_emb_x: ', x)

        loss = self.imagen(
            output,
            text_embeds=x,
            cond_images=this_cond_images,
            unet_number=unet_number,
        )

        return loss

    def sample(
        self,
        x=None,
        stop_at_unet_number=1,
        cond_scale=7.5,
        x_data=None,
        skip_steps=None,
        inpaint_images=None,
        inpaint_masks=None,
        inpaint_resample_times=5,
        init_images=None,
        x_data_tokenized=None,
        device=None,
        tokenizer_X=None,
        Xnormfac=1.,
        ynormfac=1.,
        max_length=1.,
        tokenizer_y=None,
    ):
        """Sample from the model with optional sequence conditioning.

        Args:
            x: Optional numeric conditioning tensor (embedded as in forward).
            x_data: Optional list of raw sequences; tokenized with
                ``tokenizer_X``, padded to ``max_length`` and divided by
                ``Xnormfac``.
            x_data_tokenized: Optional pre-tokenized tensor; overrides
                ``x_data`` when both are given.
            init_images, inpaint_images: Optional raw sequences, tokenized
                with ``tokenizer_y`` and divided by ``ynormfac``.
            device: Device for tensors created here (shadows the module-level
                ``device`` inside this method).
            max_length: Padded length; must be an int in practice.
            tokenizer_y: Tokenizer for init/inpaint sequences (previously an
                undefined name in this method).
            Remaining arguments are forwarded to ``self.imagen.sample``.

        Raises:
            ValueError: if init/inpaint sequences are given without
                ``tokenizer_y``.
        """

        def _encode_target(seqs):
            # Tokenize, right-pad and normalize raw target-domain sequences.
            if tokenizer_y is None:
                raise ValueError(
                    "tokenizer_y is required when init_images or inpaint_images is given"
                )
            tokens = tokenizer_y.texts_to_sequences(seqs)
            tokens = sequence.pad_sequences(
                tokens, maxlen=max_length, padding='post', truncating='post'
            )
            return torch.from_numpy(tokens).float().to(device) / ynormfac

        batch_size = 1

        if x_data is not None:
            print("Conditioning target sequence provided via x_data ...", x_data)
            x_data = tokenizer_X.texts_to_sequences(x_data)
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length, padding='post', truncating='post'
            )

            x_data = torch.from_numpy(x_data).float().to(device)
            x_data = x_data / Xnormfac
            # (batch, L) -> (batch, L, 1) -> (batch, 1, L)
            x_data = torch.permute(x_data.unsqueeze(2), (0, 2, 1))

            print("x_data from target sequence=", x_data, x_data.shape)
            batch_size = x_data.shape[0]

        if x_data_tokenized is not None:
            print("Conditioning target sequence provided via x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape)

            x_data = torch.permute(x_data_tokenized.unsqueeze(2), (0, 2, 1)).to(device)
            print("Data provided from x_data_tokenized: ", x_data.shape)
            batch_size = x_data.shape[0]

        if init_images is not None:
            print("Init sequence provided...", init_images)
            init_images = _encode_target(init_images)
            print("init_images=", init_images)

        if inpaint_images is not None:
            print("Inpaint sequence provided...", inpaint_images)
            print("Mask: ", inpaint_masks)
            inpaint_images = _encode_target(inpaint_images)
            print("in_paint images=", inpaint_images)

        if x is not None:
            x_in = torch.zeros((x.shape[0], max_length)).to(device)
            x_in[:, :x.shape[1]] = x
            x = self.fc_embed2(x_in.unsqueeze(2))

            pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(device)
            pos_emb_x = torch.squeeze(self.pos_emb_x(pos_matrix_i_), 1)

            x = torch.cat((x, pos_emb_x), 2)
            batch_size = x.shape[0]

        output = self.imagen.sample(
            text_embeds=x,
            cond_scale=cond_scale,
            stop_at_unet_number=stop_at_unet_number,
            cond_images=x_data,
            skip_steps=skip_steps,
            inpaint_images=inpaint_images,
            inpaint_masks=inpaint_masks,
            inpaint_resample_times=inpaint_resample_times,
            init_images=init_images,
            batch_size=batch_size,
            device=device,
        )

        return output
| | |
| | |
| | |
| | |
| | |
class ProteinDesigner_A_II(nn.Module):
    """Model A (key-driven variant) built around a caller-supplied UNet.

    Hyper-parameters are read from ``PKeys``; the conditioning vector ``x`` is
    embedded with ``fc_embed2`` plus a positional embedding and passed to an
    ``ElucidatedImagen`` as ``text_embeds``.
    """

    def __init__(
        self,
        unet1,
        CKeys=None,
        PKeys=None,
    ):
        """Create the embedding layers and the imagen wrapper.

        Args:
            unet1: UNet instance handed to ``ElucidatedImagen``.
            CKeys: Control dict; ``CKeys['Debug_ModelPack'] == 1`` enables
                verbose printing.
            PKeys: Parameter dict. Keys read here: ``timesteps``,
                ``pred_dim``, ``loss_type``, ``elucidated``,
                ``text_embed_dim``, ``embed_dim_position``, ``max_text_len``,
                ``max_length``, ``device``. A ``None`` value falls back to the
                default shown next to each lookup.
        """
        super(ProteinDesigner_A_II, self).__init__()

        timesteps = default(PKeys['timesteps'], 10)
        pred_dim = default(PKeys['pred_dim'], 25)
        loss_type = default(PKeys['loss_type'], 0)
        elucidated = default(PKeys['elucidated'], True)
        text_embed_dim = default(PKeys['text_embed_dim'], 512)
        embed_dim_position = default(PKeys['embed_dim_position'], 32)
        max_text_len = default(PKeys['max_text_len'], 16)
        max_length = default(PKeys['max_length'], 64)
        device = default(PKeys['device'], 'cuda:0')

        self.pred_dim = pred_dim
        self.loss_type = loss_type

        self.CKeys = CKeys
        self.PKeys = PKeys

        self.device = device

        assert (loss_type == 0), "Loss other than MSE not implemented"

        # fc_embed1 is not used by forward/sample in this class; it is kept so
        # the module's parameter set stays unchanged.
        self.fc_embed1 = nn.Linear(8, max_length)
        self.fc_embed2 = nn.Linear(1, text_embed_dim)
        self.max_text_len = max_text_len

        self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
        text_embed_dim = text_embed_dim + embed_dim_position

        # Positional indices 1..max_text_len.
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)
        if self.CKeys['Debug_ModelPack'] == 1:
            print("ModelA.pos_matrix_i: ", self.pos_matrix_i)

        assert elucidated, "Only elucidated model implemented...."
        self.is_elucidated = elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets=unet1,
                channels=self.pred_dim,
                channels_out=self.pred_dim,
                loss_type=loss_type,
                text_embed_dim=text_embed_dim,
                image_sizes=[max_length],
                cond_drop_prob=0.2,
                auto_normalize_img=False,
                num_sample_steps=timesteps,
                sigma_min=0.002,
                sigma_max=160,
                sigma_data=0.5,
                rho=7,
                P_mean=-1.2,
                P_std=1.2,
                S_churn=40,
                S_tmin=0.05,
                S_tmax=50,
                S_noise=1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(self.device)

            if CKeys['Debug_ModelPack'] == 1:
                print("Check on EImagen:")
                print("channels: ", self.pred_dim)
                print("loss_type: ", loss_type)
                print("text_embed_dim: ", text_embed_dim)
                print("image_sizes: ", max_length)
                print("num_sample_steps: ", timesteps)
                print("Measure imagen:")
                params(self.imagen)
                print("Measure fc_embed2")
                params(self.fc_embed2)
                print("Measure pos_emb_x")
                params(self.pos_emb_x)
        else:
            # Only reachable when asserts are disabled.
            print("Not implemented.")

    def forward(
        self,
        output,
        x=None,
        cond_images=None,
        unet_number=1,
    ):
        """Return the imagen loss for one batch.

        Args:
            output: Target tensor passed to ``self.imagen``.
            x: Conditioning tensor (required despite the ``None`` default,
                which only keeps the keyword order aligned with Model B).
                Assumes (batch, n) with n <= max_text_len -- TODO confirm.
            cond_images: Only reported in debug output; not passed on.
            unet_number: Forwarded unchanged.

        Raises:
            ValueError: if ``x`` is None.
        """
        if x is None:
            raise ValueError("ProteinDesigner_A_II.forward requires the conditioning tensor x")

        debug = self.CKeys['Debug_ModelPack'] == 1

        if debug:
            print("on Model A:forward")
            print("inputs:")
            print(' output.dim : ', output.shape)
            print(' x: ', x)
            print(' x.dim: ', x.shape)
            if cond_images is None:
                print(' cond_img: None ')
            else:
                print(' cond_img: ', cond_images.shape)
            print(' unet_num: ', unet_number)

        x = x.unsqueeze(2)
        if debug:
            print("After x.unsqueeze(2), x.dim: ", x.shape)

        x = self.fc_embed2(x)
        if debug:
            print("After fc_embed2(x), x.dim: ", x.shape)
            print()

        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
        if debug:
            print("pos_matrix_i_.dim: ", pos_matrix_i_.shape)
            print("pos_matrix_i_: ", pos_matrix_i_)
            print()

        pos_emb_x = self.pos_emb_x(pos_matrix_i_)
        if debug:
            print("After pos_emb_x(pos_matrix_i_), pos_emb_x.dim: ", pos_emb_x.shape)
            print("pos_emb_x: ", pos_emb_x)
            print()

        pos_emb_x = torch.squeeze(pos_emb_x, 1)
        # Keep only as many positions as x provides. (The former in-place
        # zeroing of the remaining rows was dropped: those rows are discarded
        # by this slice, so it had no effect on the result.)
        pos_emb_x = pos_emb_x[:, :x.shape[1], :]
        if debug:
            print("after operations, pos_emb_x.dim: ", pos_emb_x.shape)
            print("pos_emb_x: ", pos_emb_x)
            print()

        x = torch.cat((x, pos_emb_x), 2)
        if debug:
            print("after cat((x,pos_emb_x),2)=>x dim: ", x.shape, "Batch x max_text_len x (text_embed_dim+embed_dim_position)")
            print()
            print("Now, get into self.imagen part...")

        loss = self.imagen(
            output,
            text_embeds=x,
            unet_number=unet_number,
        )

        return loss

    def sample(
        self,
        x=None,
        stop_at_unet_number=1,
        cond_scale=7.5,
        x_data=None,
        skip_steps=None,
        inpaint_images=None,
        inpaint_masks=None,
        inpaint_resample_times=5,
        init_images=None,
        x_data_tokenized=None,
        tokenizer_X=None,
        Xnormfac=None,
        device=None,
        max_length=None,
        max_text_len=None,
        tokenizer_y=None,
        ynormfac=1.,
    ):
        """Sample from the model.

        Args:
            x: Optional numeric conditioning tensor; zero-padded to
                ``max_text_len`` and embedded as in ``forward``.
            x_data: Optional raw sequences; tokenized with ``tokenizer_X``,
                padded to ``max_length`` and divided by ``Xnormfac`` when that
                is not None.
            x_data_tokenized: Optional pre-tokenized tensor; overrides
                ``x_data`` when both are given.
            init_images, inpaint_images: Optional raw sequences, tokenized
                with ``tokenizer_y`` and divided by ``ynormfac``.
            device: Accepted for interface compatibility; ``self.device`` is
                what this method actually uses.
            max_text_len: Padded width for ``x``; falls back to
                ``self.max_text_len`` when None.
            tokenizer_y, ynormfac: Tokenizer / scale for init and inpaint
                sequences (previously undefined names in this method).
            Remaining arguments are forwarded to ``self.imagen.sample``.

        Raises:
            ValueError: if init/inpaint sequences are given without
                ``tokenizer_y``.
        """

        def _encode_target(seqs):
            # Tokenize, right-pad and normalize raw target-domain sequences.
            if tokenizer_y is None:
                raise ValueError(
                    "tokenizer_y is required when init_images or inpaint_images is given"
                )
            tokens = tokenizer_y.texts_to_sequences(seqs)
            tokens = sequence.pad_sequences(
                tokens, maxlen=max_length, padding='post', truncating='post'
            )
            return torch.from_numpy(tokens).float().to(self.device) / ynormfac

        # Defined up front so the final call works when no conditioning input
        # is supplied at all.
        batch_size = 1

        if max_text_len is None:
            max_text_len = self.max_text_len

        if x_data is not None:
            print("Conditioning target sequence provided via sequence/image in x_data ...", x_data)

            x_data = tokenizer_X.texts_to_sequences(x_data)
            x_data = sequence.pad_sequences(
                x_data, maxlen=max_length, padding='post', truncating='post'
            )

            x_data = torch.from_numpy(x_data).float().to(self.device)
            if Xnormfac is not None:
                x_data = x_data / Xnormfac
            # (batch, L) -> (batch, L, 1) -> (batch, 1, L)
            x_data = torch.permute(x_data.unsqueeze(2), (0, 2, 1))

            print("x_data from target sequence=", x_data, x_data.shape)
            batch_size = x_data.shape[0]

        if x_data_tokenized is not None:
            print("Conditioning target sequence provided via processed sequence/image in x_data_tokenized ...", x_data_tokenized, x_data_tokenized.shape)

            x_data = torch.permute(x_data_tokenized.unsqueeze(2), (0, 2, 1)).to(self.device)
            print("Data provided from x_data_tokenized: ", x_data.shape)
            batch_size = x_data.shape[0]

        if init_images is not None:
            print("Init sequence provided...", init_images)
            init_images = _encode_target(init_images)
            print("init_images=", init_images)

        if inpaint_images is not None:
            print("Inpaint sequence provided...", inpaint_images)
            print("Mask: ", inpaint_masks)
            inpaint_images = _encode_target(inpaint_images)
            print("in_paint images=", inpaint_images)

        if x is not None:
            print("Conditioning target sequence via tokenized text in x ...", x)

            if self.CKeys['Debug_ModelPack'] == 1:
                print("x.dim: ", x.shape)
                print("max_text_len: ", max_text_len)
                print("self.device: ", self.device)

            x_in = torch.zeros((x.shape[0], max_text_len)).to(self.device)
            x_in[:, :x.shape[1]] = x
            x = self.fc_embed2(x_in.unsqueeze(2))

            pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
            pos_emb_x = torch.squeeze(self.pos_emb_x(pos_matrix_i_), 1)
            # Rows beyond x's length are dropped by the slice; zeroing them
            # first (as the code used to) changed nothing.
            pos_emb_x = pos_emb_x[:, :x.shape[1], :]
            x = torch.cat((x, pos_emb_x), 2)

            batch_size = x.shape[0]

        output = self.imagen.sample(
            text_embeds=x,
            cond_scale=cond_scale,
            stop_at_unet_number=stop_at_unet_number,
            cond_images=x_data,
            skip_steps=skip_steps,
            inpaint_images=inpaint_images,
            inpaint_masks=inpaint_masks,
            inpaint_resample_times=inpaint_resample_times,
            init_images=init_images,
            batch_size=batch_size,
            device=self.device,
        )

        return output
| | |
| | |
| | |
| | |
| | |
class ProteinDesigner_A_I(nn.Module):
    """Model A (key-driven variant) that builds its own ``OneD_Unet``.

    Hyper-parameters come from ``PKeys``. The conditioning vector is embedded
    with ``fc_embed2`` plus a positional embedding and fed to an
    ``ElucidatedImagen`` as ``text_embeds``.
    """

    def __init__(
        self,
        CKeys=None,
        PKeys=None,
    ):
        """Create the embedding layers, the UNet and the imagen wrapper.

        Args:
            CKeys: Control dict; ``CKeys['Debug_ModelPack'] == 1`` enables
                verbose printing.
            PKeys: Parameter dict. Keys read here: ``timesteps``, ``dim``,
                ``pred_dim``, ``loss_type``, ``elucidated``, ``cond_dim``,
                ``text_embed_dim``, ``embed_dim_position``, ``max_text_len``,
                ``max_length``, ``device``. A ``None`` value falls back to the
                default shown next to each lookup.
        """
        super(ProteinDesigner_A_I, self).__init__()

        timesteps = default(PKeys['timesteps'], 10)
        dim = default(PKeys['dim'], 32)
        pred_dim = default(PKeys['pred_dim'], 25)
        loss_type = default(PKeys['loss_type'], 0)
        elucidated = default(PKeys['elucidated'], True)
        cond_dim = default(PKeys['cond_dim'], 512)
        text_embed_dim = default(PKeys['text_embed_dim'], 512)
        embed_dim_position = default(PKeys['embed_dim_position'], 32)
        max_text_len = default(PKeys['max_text_len'], 16)
        max_length = default(PKeys['max_length'], 64)
        device = default(PKeys['device'], 'cuda:0')

        self.CKeys = CKeys
        self.PKeys = PKeys

        self.device = device
        self.pred_dim = pred_dim
        self.loss_type = loss_type

        # fc_embed1 is not used by forward/sample in this class; it is kept so
        # the module's parameter set stays unchanged.
        self.fc_embed1 = nn.Linear(8, max_length)
        self.fc_embed2 = nn.Linear(1, text_embed_dim)
        self.max_text_len = max_text_len

        self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
        # From here on text_embed_dim is the concatenated width seen by the
        # UNet and the imagen.
        text_embed_dim = text_embed_dim + embed_dim_position
        # Positional indices 1..max_text_len.
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        assert (loss_type == 0), "Loss other than MSE not implemented"

        write_PK_UNet = {
            'dim': dim,
            'text_embed_dim': text_embed_dim,
            'cond_dim': cond_dim,
            'dim_mults': (1, 2, 4, 8),
            'num_resnet_blocks': 1,
            'layer_attns': (False, True, True, False),
            'layer_cross_attns': (False, True, True, False),
            'channels': pred_dim,
            'channels_out': pred_dim,
            'attn_dim_head': 64,
            'attn_heads': 8,
            'ff_mult': 2.,
            'lowres_cond': False,
            'layer_attns_depth': 1,
            'layer_attns_add_text_cond': True,
            'attend_at_middle': True,
            'use_linear_attn': False,
            'use_linear_cross_attn': False,
            'cond_on_text': True,
            'max_text_len': max_length,
            'init_dim': None,
            'resnet_groups': 8,
            'init_conv_kernel_size': 7,
            'init_cross_embed': False,
            'init_cross_embed_kernel_sizes': (3, 7, 15),
            'cross_embed_downsample': False,
            'cross_embed_downsample_kernel_sizes': (2, 4),
            'attn_pool_text': True,
            'attn_pool_num_latents': 32,
            'dropout': 0.,
            'memory_efficient': False,
            'init_conv_to_final_conv_residual': False,
            'use_global_context_attn': True,
            'scale_skip_connection': True,
            'final_resnet_block': True,
            'final_conv_kernel_size': 3,
            'cosine_sim_attn': True,
            'self_cond': False,
            'combine_upsample_fmaps': True,
            'pixel_shuffle_upsample': False,
        }

        Unet_PKeys = prepare_UNet_keys(write_PK_UNet)
        unet1 = OneD_Unet(
            CKeys=CKeys,
            PKeys=Unet_PKeys,
        ).to(self.device)
        if CKeys['Debug_ModelPack'] == 1:
            print('Check unet generated...')
            params(unet1)

        assert elucidated, "Only elucidated model implemented...."
        self.is_elucidated = elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets=unet1,
                channels=self.pred_dim,
                channels_out=self.pred_dim,
                loss_type=loss_type,
                text_embed_dim=text_embed_dim,
                image_sizes=[max_length],
                cond_drop_prob=0.2,
                auto_normalize_img=False,
                num_sample_steps=timesteps,
                sigma_min=0.002,
                sigma_max=160,
                sigma_data=0.5,
                rho=7,
                P_mean=-1.2,
                P_std=1.2,
                S_churn=40,
                S_tmin=0.05,
                S_tmax=50,
                S_noise=1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(self.device)
        else:
            # Only reachable when asserts are disabled.
            print("Not implemented.")

    def _embed_condition(self, x):
        """Embed x and append its positional embedding along the last dim.

        Assumes x is (batch, n) with n <= max_text_len -- TODO confirm.
        """
        x = self.fc_embed2(x.unsqueeze(2))

        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
        pos_emb_x = torch.squeeze(self.pos_emb_x(pos_matrix_i_), 1)
        # Keep only the positions that x provides. Zeroing the remaining rows
        # first, as the code used to, had no effect because they are dropped.
        pos_emb_x = pos_emb_x[:, :x.shape[1], :]
        return torch.cat((x, pos_emb_x), 2)

    def forward(self, x, output, unet_number=1):
        """Return the imagen loss for conditioning ``x`` and target ``output``."""
        loss = self.imagen(
            output,
            text_embeds=self._embed_condition(x),
            unet_number=unet_number,
        )

        return loss

    def sample(self, x, stop_at_unet_number=1, cond_scale=7.5,):
        """Sample from the imagen conditioned on ``x``."""
        output = self.imagen.sample(
            text_embeds=self._embed_condition(x),
            cond_scale=cond_scale,
            stop_at_unet_number=stop_at_unet_number,
        )

        return output
| | |
| | |
| | |
| | |
class ProteinDesigner_A_Old(nn.Module):
    """Model A (legacy, argument-driven) built around ``OneD_Unet_Old``.

    The conditioning vector is embedded with ``fc_embed2`` plus a positional
    embedding and fed to an ``ElucidatedImagen`` as ``text_embeds``.
    """

    def __init__(
        self,
        timesteps=10,
        dim=32,
        pred_dim=25,
        loss_type=0,
        elucidated=False,
        padding_idx=0,
        cond_dim=512,
        text_embed_dim=512,
        input_tokens=25,
        sequence_embed=False,
        embed_dim_position=32,
        max_text_len=16,
        device='cuda:0',
        max_length=64,
        CKeys=None,
        PKeys=None,
    ):
        """Create the embedding layers, the UNet and the imagen wrapper.

        Args:
            timesteps: Number of sampling steps (``num_sample_steps``).
            dim, cond_dim: Forwarded to ``OneD_Unet_Old``.
            pred_dim: Channel count for the UNet and the imagen.
            loss_type: Must be 0 (MSE); anything else fails the assert.
            elucidated: Must be passed as a truthy value. NOTE(review): the
                default ``False`` trips the assert below, so callers have to
                pass ``elucidated=True`` explicitly.
            text_embed_dim: Width of ``fc_embed2``; the UNet and imagen see
                this plus ``embed_dim_position``.
            embed_dim_position: Width of the positional embedding.
            max_text_len: Number of positional indices (1..max_text_len).
            device: Device the UNet and imagen are moved to.
            max_length: Sequence length (imagen image size and the UNet's
                ``max_text_len``).
            CKeys, PKeys: Control / parameter dicts stored and forwarded.
            padding_idx, input_tokens, sequence_embed: Accepted for interface
                compatibility; not read here.
        """
        super(ProteinDesigner_A_Old, self).__init__()

        self.CKeys = CKeys
        self.PKeys = PKeys

        self.device = device
        self.pred_dim = pred_dim
        self.loss_type = loss_type

        # fc_embed1 is not used by forward/sample in this class; it is kept so
        # the module's parameter set stays unchanged.
        self.fc_embed1 = nn.Linear(8, max_length)
        self.fc_embed2 = nn.Linear(1, text_embed_dim)
        self.max_text_len = max_text_len

        self.pos_emb_x = nn.Embedding(max_text_len + 1, embed_dim_position)
        text_embed_dim = text_embed_dim + embed_dim_position
        # Positional indices 1..max_text_len.
        self.pos_matrix_i = torch.arange(1, max_text_len + 1, dtype=torch.long)

        assert (loss_type == 0), "Loss other than MSE not implemented"

        unet1 = OneD_Unet_Old(
            dim=dim,
            text_embed_dim=text_embed_dim,
            cond_dim=cond_dim,
            dim_mults=(1, 2, 4, 8),
            num_resnet_blocks=1,
            layer_attns=(False, True, True, False),
            layer_cross_attns=(False, True, True, False),
            channels=self.pred_dim,
            channels_out=self.pred_dim,
            attn_dim_head=64,
            attn_heads=8,
            ff_mult=2.,
            lowres_cond=False,
            layer_attns_depth=1,
            layer_attns_add_text_cond=True,
            attend_at_middle=True,
            use_linear_attn=False,
            use_linear_cross_attn=False,
            cond_on_text=True,
            max_text_len=max_length,
            init_dim=None,
            resnet_groups=8,
            init_conv_kernel_size=7,
            init_cross_embed=False,
            init_cross_embed_kernel_sizes=(3, 7, 15),
            cross_embed_downsample=False,
            cross_embed_downsample_kernel_sizes=(2, 4),
            attn_pool_text=True,
            attn_pool_num_latents=32,
            dropout=0.,
            memory_efficient=False,
            init_conv_to_final_conv_residual=False,
            use_global_context_attn=True,
            scale_skip_connection=True,
            final_resnet_block=True,
            final_conv_kernel_size=3,
            cosine_sim_attn=True,
            self_cond=False,
            combine_upsample_fmaps=True,
            pixel_shuffle_upsample=False,
            CKeys=CKeys,
        ).to(self.device)

        if CKeys['Debug_ModelPack'] == 1:
            print("Check NUnet...")
            params(unet1)

        assert elucidated, "Only elucidated model implemented...."
        self.is_elucidated = elucidated
        if elucidated:
            self.imagen = ElucidatedImagen(
                unets=unet1,
                channels=self.pred_dim,
                channels_out=self.pred_dim,
                loss_type=loss_type,
                text_embed_dim=text_embed_dim,
                image_sizes=[max_length],
                cond_drop_prob=0.2,
                auto_normalize_img=False,
                num_sample_steps=timesteps,
                sigma_min=0.002,
                sigma_max=160,
                sigma_data=0.5,
                rho=7,
                P_mean=-1.2,
                P_std=1.2,
                S_churn=40,
                S_tmin=0.05,
                S_tmax=50,
                S_noise=1.003,
                CKeys=self.CKeys,
                PKeys=self.PKeys,
            ).to(self.device)
            if CKeys['Debug_ModelPack'] == 1:
                print("Check on EImagen:")
                print("channels: ", self.pred_dim)
                print("loss_type: ", loss_type)
                print("text_embed_dim: ", text_embed_dim)
                print("image_sizes: ", max_length)
                print("num_sample_steps: ", timesteps)
                print("Measure imagen:")
                params(self.imagen)
                print("Measure fc_embed2")
                params(self.fc_embed2)
                print("Measure pos_emb_x")
                params(self.pos_emb_x)
        else:
            # Only reachable when asserts are disabled.
            print("Not implemented.")

    def _embed_condition(self, x):
        """Embed x and append its positional embedding along the last dim.

        Assumes x is (batch, n) with n <= max_text_len -- TODO confirm.
        """
        x = self.fc_embed2(x.unsqueeze(2))

        pos_matrix_i_ = self.pos_matrix_i.repeat(x.shape[0], 1).to(self.device)
        pos_emb_x = torch.squeeze(self.pos_emb_x(pos_matrix_i_), 1)
        # Keep only the positions that x provides. Zeroing the remaining rows
        # first, as the code used to, had no effect because they are dropped.
        pos_emb_x = pos_emb_x[:, :x.shape[1], :]
        return torch.cat((x, pos_emb_x), 2)

    def forward(self, x, output, unet_number=1):
        """Return the imagen loss for conditioning ``x`` and target ``output``."""
        loss = self.imagen(
            output,
            text_embeds=self._embed_condition(x),
            unet_number=unet_number,
        )

        return loss

    def sample(self, x, stop_at_unet_number=1, cond_scale=7.5,):
        """Sample from the imagen conditioned on ``x``."""
        output = self.imagen.sample(
            text_embeds=self._embed_condition(x),
            cond_scale=cond_scale,
            stop_at_unet_number=stop_at_unet_number,
        )

        return output