| | import torch |
| | import torch.nn.functional as F |
| | from torchvision import transforms |
| |
|
def calc_mean_std(feat, eps=1e-5):
    """Compute per-sample, per-channel mean and standard deviation.

    Statistics are taken over every dimension after the first two, i.e. over
    all positions of an ``(N, C, *spatial)`` tensor. This is the usual input
    for AdaIN-style re-normalization.

    Args:
        feat: Tensor with at least 3 dimensions, laid out as ``(N, C, ...)``.
            Non-contiguous tensors are accepted.
        eps: Value added to the (unbiased) variance before the square root so
            the returned std is strictly positive.

    Returns:
        ``(mean, std)``. Each has shape ``(N, C, 1, ..., 1)`` with the same
        number of dimensions as *feat*, so both broadcast directly against it.

    Raises:
        ValueError: If *feat* has fewer than 3 dimensions (there are no
            positions to take statistics over).
    """
    if feat.dim() < 3:
        raise ValueError(
            "calc_mean_std expects a tensor of shape (N, C, ...) with at least "
            f"3 dims, got shape {tuple(feat.shape)}"
        )
    N, C = feat.shape[:2]
    # reshape (not view) so non-contiguous inputs work as well
    flat = feat.reshape(N, C, -1)
    # one trailing singleton per spatial dim -> broadcastable against feat
    stat_shape = (N, C) + (1,) * (feat.dim() - 2)
    feat_var = flat.var(dim=2) + eps
    feat_std = feat_var.sqrt().view(stat_shape)
    feat_mean = flat.mean(dim=2).view(stat_shape)
    return feat_mean, feat_std
| |
|
| |
|
def get_img(img, resolution=512):
    """Turn an image into a normalized, batched tensor.

    The image is resized to exactly ``resolution x resolution`` (aspect ratio
    is not preserved) and converted with ``ToTensor``. It is then normalized
    per channel as ``(x - 0.5) / 0.5``, and a leading batch axis is added.

    NOTE(review): the 3-element mean/std lists assume a 3-channel (RGB)
    input; for PIL/uint8 inputs the output should lie in [-1, 1] — confirm
    against callers.

    Args:
        img: Image accepted by the torchvision transforms below
            (presumably a PIL image).
        resolution: Side length of the square output.

    Returns:
        Tensor with a batch dimension of size 1 prepended.
    """
    half = [0.5] * 3
    pipeline = transforms.Compose(
        [
            transforms.Resize((resolution, resolution)),
            transforms.ToTensor(),
            transforms.Normalize(half, half),
        ]
    )
    tensor = pipeline(img)
    batched = tensor[None]
    return batched
| |
|
@torch.no_grad()
def slerp(p0, p1, fract_mixing: float, adain=True):
    r"""Spherically interpolate between two tensors.

    Copied from lunarring/latentblending. This helper mixes two random
    variables correctly using spherical interpolation. Computation is always
    done in float64 for the sake of extra precision, then cast back down.

    Args:
        p0: First tensor for interpolation.
        p1: Second tensor for interpolation (shape-compatible with *p0*).
        fract_mixing: Mixing coefficient in the interval [0, 1].
            0 returns *p0*, 1 returns *p1*, and 0.x returns a mix of both
            that preserves angular velocity (approximately so when *adain*
            re-normalizes the result).
        adain: If True, re-normalize the result so its per-channel mean and
            std equal the linear mix of those of *p0* and *p1*. This uses
            ``calc_mean_std`` and ``F.instance_norm``, so the inputs need at
            least 3 dimensions.

    Returns:
        The interpolated tensor: float16 if *p0* is float16, otherwise float32.
        If either input is all zeros, the angle is undefined and plain linear
        interpolation is used instead of slerp.
    """
    # Output precision is decided before upcasting. This matches upstream:
    # fp16 stays fp16; everything else (including bf16 / fp64) becomes fp32.
    out_dtype = torch.float16 if p0.dtype == torch.float16 else torch.float32

    p0 = p0.double()
    p1 = p1.double()

    if adain:
        mean1, std1 = calc_mean_std(p0)
        mean2, std2 = calc_mean_std(p1)
        mean = mean1 * (1 - fract_mixing) + mean2 * fract_mixing
        std = std1 * (1 - fract_mixing) + std2 * fract_mixing

    norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    if norm == 0:
        # The angle to a zero tensor is undefined (dot / norm would be 0/0 = NaN
        # and poison the whole result); fall back to linear interpolation.
        interp = p0 * (1 - fract_mixing) + p1 * fract_mixing
    else:
        epsilon = 1e-7
        # Cosine of the angle between the flattened tensors, clamped away from
        # +/-1 so that sin(theta_0) below never becomes zero.
        dot = (torch.sum(p0 * p1) / norm).clamp(-1 + epsilon, 1 - epsilon)
        theta_0 = torch.arccos(dot)
        sin_theta_0 = torch.sin(theta_0)
        theta_t = theta_0 * fract_mixing
        s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
        s1 = torch.sin(theta_t) / sin_theta_0
        interp = p0 * s0 + p1 * s1

    if adain:
        interp = F.instance_norm(interp) * std + mean

    return interp.to(out_dtype)
| |
|
| |
|
def do_replace_attn(key: str):
    """Return whether the layer named *key* is selected for attention replacement.

    Only keys beginning with ``'up'`` are selected. These are presumably the
    up-sampling blocks of a U-Net; confirm against the caller.
    """
    selected_prefixes = ('up',)
    return key.startswith(selected_prefixes)
| |
|