import math

import torch
import torch.nn as nn
from torch.nn.functional import interpolate
from tqdm import tqdm

from modules.feature_extactor import Extractor
from modules.half_warper import HalfWarper
from modules.cupy_module.nedt import NEDT
from modules.flow_models.flow_models import (
    RAFTFineFlow,
    PWCFineFlow,
)
from modules.synthesizer import Synthesis


class FeatureWarper(nn.Module):
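    """Pyramidal forward-warper for a pair of frames.

    Each frame is augmented with an edge-distance channel (NEDT is assumed
    to return a single-channel map, hence ``in_channels + 1``), then the
    images and every feature-pyramid level are half-warped toward the
    intermediate time ``tau`` using z-metric weights from ``HalfWarper``.
    """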
    def __init__(
        self,
        in_channels: int = 3,
        channels: list[int] = [32, 64, 128, 256],
    ):
        super().__init__()
        # The NEDT edge channel is concatenated to the input, so the first
        # extractor stage sees in_channels + 1 channels.
        channels = [in_channels + 1] + channels

        self.half_warper = HalfWarper()
        self.feature_extractor = Extractor(channels)
        self.nedt = NEDT()

    def forward(
        self,
        I0: torch.Tensor,
        I1: torch.Tensor,
        flow0to1: torch.Tensor,
        flow1to0: torch.Tensor,
        tau: torch.Tensor | None = None,
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
        assert tau is not None and tau.shape == (I0.shape[0], 2), \
            "tau must have shape (batch, 2)"

        # Scale the full inter-frame flows down to the intermediate time
        # (linear-motion assumption).
        flow0tot = tau[:, 0][:, None, None, None] * flow0to1
        flow1tot = tau[:, 1][:, None, None, None] * flow1to0

        # Append the edge-distance channel to each frame.
        I0 = torch.cat([I0, self.nedt(I0)], dim=1)
        I1 = torch.cat([I1, self.nedt(I1)], dim=1)

        z0to1, z1to0 = HalfWarper.z_metric(I0, I1, flow0to1, flow1to0)
        base0, base1 = self.half_warper(I0, I1, flow0tot, flow1tot, z0to1, z1to0)
        warped0, warped1 = [base0], [base1]

        features0 = self.feature_extractor(I0)
        features1 = self.feature_extractor(I1)

        # Warp every pyramid level, resizing the flows and z-metrics to the
        # level's spatial size.
        for feat0, feat1 in zip(features0, features1):
            f0 = interpolate(flow0tot, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            f1 = interpolate(flow1tot, size=feat1.shape[2:], mode='bilinear', align_corners=False)
            z0 = interpolate(z0to1, size=feat0.shape[2:], mode='bilinear', align_corners=False)
            z1 = interpolate(z1to0, size=feat1.shape[2:], mode='bilinear', align_corners=False)
            w0, w1 = self.half_warper(feat0, feat1, f0, f1, z0, z1)
            warped0.append(w0)
            warped1.append(w1)
        return warped0, warped1


class MultiInputResShift(nn.Module):
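    """Multi-input ResShift-style diffusion model for frame interpolation.

    Instead of shifting the diffusion state toward a single degraded input,
    the state is shifted toward a tau-weighted combination of the two
    half-warped neighbour frames; a synthesis network predicts the
    intermediate frame from the noisy state and the warped feature pyramids.
    """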
    def __init__(
        self,
        kappa: float = 2.0,
        p: float = 0.3,
        min_noise_level: float = 0.04,
        etas_end: float = 0.99,
        timesteps: int = 15,
        flow_model: str = 'raft',
        flow_kwargs: dict | None = None,
        warping_kwargs: dict | None = None,
        synthesis_kwargs: dict | None = None,
    ):
        super().__init__()
        # Avoid mutable default arguments.
        flow_kwargs = flow_kwargs or {}
        warping_kwargs = warping_kwargs or {}
        synthesis_kwargs = synthesis_kwargs or {}

        self.timesteps = timesteps
        self.kappa = kappa
        self.eta_partition = None

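        # Geometric noise schedule in the style of ResShift:
        #   sqrt(eta_t) = sqrt(eta_1) * b0**beta_t,
        #   beta_t = (t / (T - 1))**p * (T - 1),  t = 0, ..., T - 1,
        # so sqrt(eta) grows from sqrt(eta_1) to etas_end, with p controlling
        # the curvature of the growth.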
        sqrt_eta_1 = min(min_noise_level / kappa, min_noise_level, math.sqrt(0.001))
        b0 = math.exp(1 / float(timesteps - 1) * math.log(etas_end / sqrt_eta_1))
        base = torch.ones(timesteps) * b0
        beta = (torch.linspace(0, 1, timesteps) ** p) * (timesteps - 1)
        sqrt_eta = torch.pow(base, beta) * sqrt_eta_1

        self.register_buffer("sqrt_sum_eta", sqrt_eta)
        self.register_buffer("sum_eta", sqrt_eta ** 2)

        # eta_{t-1}, with eta_{-1} = 0 for the first step.
        sum_prev_eta = torch.roll(self.sum_eta, 1)
        sum_prev_eta[0] = 0
        self.register_buffer("sum_prev_eta", sum_prev_eta)

        # alpha_t = eta_t - eta_{t-1}.
        self.register_buffer("sum_alpha", self.sum_eta - self.sum_prev_eta)

        # Coefficients of the reverse-process posterior p(x_{t-1} | x_t, x_0).
        self.register_buffer("backward_mean_c1", self.sum_prev_eta / self.sum_eta)
        self.register_buffer("backward_mean_c2", self.sum_alpha / self.sum_eta)
        self.register_buffer("backward_std", self.kappa * torch.sqrt(self.sum_prev_eta * self.sum_alpha / self.sum_eta))

        if flow_model == 'raft':
            self.flow_model = RAFTFineFlow(**flow_kwargs)
        elif flow_model == 'pwc':
            self.flow_model = PWCFineFlow(**flow_kwargs)
        else:
            raise ValueError(f"Flow model {flow_model} not supported")

        self.feature_warper = FeatureWarper(**warping_kwargs)
        self.synthesis = Synthesis(**synthesis_kwargs)

    def forward_process(
        self,
        x: torch.Tensor | None,
        Y: list[torch.Tensor],
        tau: torch.Tensor | float | None,
        t: torch.Tensor | int
    ) -> torch.Tensor:
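        """Sample x_t from the forward (noising) process.

        Following the multi-input ResShift formulation, the marginal is
            q(x_t | x_0, Y) = N(x_0 + eta_t * sum_i tau_i * (y_i - x_0),
                                kappa^2 * eta_t * I),
        with eta_t = ``sum_eta[t]``. ``x=None`` is treated as x_0 = 0, and a
        scalar ``tau`` is expanded to the pair (tau, 1 - tau).
        """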
        # Resolve x first: the tau/t defaults below read its shape and device.
        if x is None:
            x = torch.zeros_like(Y[0])
        if tau is None:
            tau = torch.full((x.shape[0], len(Y)), 0.5, device=x.device, dtype=x.dtype)
        elif isinstance(tau, float):
            assert 0 <= tau <= 1, "tau must be between 0 and 1"
            tau = torch.cat([
                torch.full((x.shape[0], 1), tau, device=x.device, dtype=x.dtype),
                torch.full((x.shape[0], 1), 1 - tau, device=x.device, dtype=x.dtype)
            ], dim=1)
        if not torch.is_tensor(t):
            t = torch.tensor([t], device=x.device, dtype=torch.long)

        # Per-input shift weights: eta_t scaled by tau, broadcast to NCHW.
        eta = self.sum_eta[t][:, None] * tau
        eta = eta[:, :, None, None, None].transpose(0, 1)

        # Shift x_0 toward each conditioning input y_i.
        e_i = torch.stack([y - x for y in Y])
        mean = x + (eta * e_i).sum(dim=0)

        sqrt_sum_eta = self.sqrt_sum_eta[t][:, None, None, None]
        std = self.kappa * sqrt_sum_eta
        epsilon = torch.randn_like(x)

        return mean + std * epsilon

    @torch.inference_mode()
    def reverse_process(
        self,
        Y: list[torch.Tensor],
        tau: torch.Tensor | float,
        flows: list[torch.Tensor] | None = None,
    ) -> torch.Tensor:
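        """Generate the frame at time ``tau`` from the two neighbour frames.

        Estimates bidirectional flows (unless ``flows`` is given), warps both
        frames and their features to time ``tau``, draws x_T as a noisy blend
        of the warped frames, then iterates the reverse diffusion process,
        predicting x_0 with the synthesis network at every step.
        """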
        y = Y[0]
        batch, device, dtype = y.shape[0], y.device, y.dtype

        if isinstance(tau, float):
            assert 0 <= tau <= 1, "tau must be between 0 and 1"
            tau = torch.cat([
                torch.full((batch, 1), tau, device=device, dtype=dtype),
                torch.full((batch, 1), 1 - tau, device=device, dtype=dtype)
            ], dim=1)
        if flows is None:
            flow0to1, flow1to0 = self.flow_model(Y[0], Y[1])
        else:
            flow0to1, flow1to0 = flows
        warp0to1, warp1to0 = self.feature_warper(Y[0], Y[1], flow0to1, flow1to0, tau)

        # Initialize x_T as a noisy blend of the two warped frames (RGB only).
        T = torch.full((batch,), self.timesteps - 1, device=device, dtype=torch.long)
        x = self.forward_process(torch.zeros_like(Y[0]), [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, T)

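        # Each iteration samples the ResShift-style posterior
        # p(x_{t-1} | x_t, x_0), whose mean combines x_t and the predicted
        # x_0 with coefficients eta_{t-1}/eta_t and alpha_t/eta_t plus the
        # tau-weighted residual shift terms, and whose std is
        # kappa * sqrt(eta_{t-1} * alpha_t / eta_t).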
        pbar = tqdm(total=self.timesteps, desc="Reversing Process")
        for i in reversed(range(self.timesteps)):
            t = torch.full((batch,), i, device=device, dtype=torch.long)

            predicted_x0 = self.synthesis(x, warp0to1, warp1to0, t)

            mean_c1 = self.backward_mean_c1[t][:, None, None, None]
            mean_c2 = self.backward_mean_c2[t][:, None, None, None]
            std = self.backward_std[t][:, None, None, None]

            eta = self.sum_eta[t][:, None] * tau
            prev_eta = self.sum_prev_eta[t][:, None] * tau
            eta = eta[:, :, None, None, None].transpose(0, 1)
            prev_eta = prev_eta[:, :, None, None, None].transpose(0, 1)
            e_i = torch.stack([y - predicted_x0 for y in Y])

            mean = (
                mean_c1 * (x + (eta * e_i).sum(dim=0))
                + mean_c2 * predicted_x0
                - (prev_eta * e_i).sum(dim=0)
            )

            # At t=0, sum_prev_eta is 0, so the final step is noise-free.
            x = mean + std * torch.randn_like(x)
            pbar.update(1)
        pbar.close()
        return x

    def forward(
        self,
        I0: torch.Tensor,
        It: torch.Tensor,
        I1: torch.Tensor,
        flow1to0: torch.Tensor | None = None,
        flow0to1: torch.Tensor | None = None,
        tau: torch.Tensor | None = None,
        t: torch.Tensor | None = None
    ) -> torch.Tensor:
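        """Training forward pass.

        Samples a timestep t (uniformly in [1, timesteps) when not given),
        diffuses the ground-truth frame ``It`` toward the warped inputs via
        the forward process, and returns the synthesis network's prediction
        of ``It`` at that step.
        """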
        if tau is None:
            tau = torch.full((It.shape[0], 2), 0.5, device=It.device, dtype=It.dtype)

        if flow0to1 is None or flow1to0 is None:
            flow0to1, flow1to0 = self.flow_model(I0, I1)

        if t is None:
            t = torch.randint(low=1, high=self.timesteps, size=(It.shape[0],), device=It.device, dtype=torch.long)

        warp0to1, warp1to0 = self.feature_warper(I0, I1, flow0to1, flow1to0, tau)
        # Diffuse the ground-truth frame toward the warped inputs (RGB only).
        x_t = self.forward_process(It, [warp0to1[0][:, :3], warp1to0[0][:, :3]], tau, t)

        predicted_It = self.synthesis(x_t, warp0to1, warp1to0, t)
        return predicted_It
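

# A minimal smoke-test sketch, not part of the training pipeline. It assumes
# 256x256 RGB inputs, the default constructor arguments, a CUDA device (the
# cupy-based NEDT module likely requires one), and that the flow/synthesis
# weights resolve locally.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MultiInputResShift().to(device).eval()

    I0 = torch.rand(1, 3, 256, 256, device=device)
    I1 = torch.rand(1, 3, 256, 256, device=device)

    # Interpolate the midpoint frame (tau = 0.5).
    It = model.reverse_process([I0, I1], tau=0.5)
    print(It.shape)  # expected: (1, 3, 256, 256)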