| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Tuple, Union |
| import torch |
| from einops import rearrange |
| from torch import nn |
| from torch.nn.modules.utils import _triple |
|
|
| from common.cache import Cache |
| from common.distributed.ops import gather_outputs, slice_inputs |
|
|
| from . import na |
|
|
|
|
class PatchIn(nn.Module):
    """Patchify a dense video tensor and embed each patch.

    Splits the (time, height, width) axes into non-overlapping patches of
    size ``patch_size`` and linearly projects every flattened patch to a
    ``dim``-sized token.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        pt, ph, pw = _triple(patch_size)
        self.patch_size = (pt, ph, pw)
        # One patch holds in_channels * pt * ph * pw values before projection.
        self.proj = nn.Linear(in_channels * pt * ph * pw, dim)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        """Fold patches into the channel axis, then project.

        Input is (b, c, T*t, H*h, W*w); output is (b, T, H, W, dim).
        """
        pt, ph, pw = self.patch_size
        patched = rearrange(
            vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=pt, h=ph, w=pw
        )
        return self.proj(patched)
|
|
|
|
class PatchOut(nn.Module):
    """Project patch tokens back to pixel channels and unfold the patches.

    Inverse of ``PatchIn``: each ``dim``-sized token is mapped back to
    ``out_channels`` values per patch cell, and the patch axes are expanded
    back into the (time, height, width) axes.
    """

    def __init__(
        self,
        out_channels: int,
        patch_size: Union[int, Tuple[int, int, int]],
        dim: int,
    ):
        super().__init__()
        pt, ph, pw = _triple(patch_size)
        self.patch_size = (pt, ph, pw)
        # Each token expands back to out_channels * pt * ph * pw values.
        self.proj = nn.Linear(dim, out_channels * pt * ph * pw)

    def forward(
        self,
        vid: torch.Tensor,
    ) -> torch.Tensor:
        """Project then unfold: (b, T, H, W, dim) -> (b, c, T*t, H*h, W*w)."""
        pt, ph, pw = self.patch_size
        out = self.proj(vid)
        return rearrange(
            out, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=pt, h=ph, w=pw
        )
|
|
|
|
class NaPatchIn(PatchIn):
    """PatchIn variant for flattened ("na") variable-shape video sequences."""

    def forward(
        self,
        vid: torch.Tensor,
        vid_shape: torch.LongTensor,
    ) -> Tuple[
        torch.Tensor,
        torch.LongTensor,
    ]:
        """Patchify a flattened sequence, shard it, and project.

        Args:
            vid: Flattened video tokens with layout "(T t) (H h) (W w) c".
            vid_shape: Per-sample shape bookkeeping consumed/updated by
                ``na.rearrange``.

        Returns:
            Tuple of the projected (and sharded) tokens and the updated
            ``vid_shape``.

        Fix: the original annotated the return as ``torch.Tensor`` although
        the method returns ``(vid, vid_shape)``; the annotation now matches
        the actual tuple return, consistent with ``NaPatchOut.forward``.
        """
        t, h, w = self.patch_size
        # Skip the repack when patchification is the identity (1x1x1 patches).
        if not (t == h == w == 1):
            vid, vid_shape = na.rearrange(
                vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w
            )
        # NOTE(review): presumably shards the sequence dim across ranks
        # (sequence parallelism) before the projection — confirm against
        # common.distributed.ops.slice_inputs.
        vid = slice_inputs(vid, dim=0)
        vid = self.proj(vid)
        return vid, vid_shape
|
|
|
|
class NaPatchOut(PatchOut):
    """PatchOut variant for flattened ("na") variable-shape video sequences:
    project back to pixel channels, gather sharded outputs, then unfold."""

    def forward(
        self,
        vid: torch.FloatTensor,
        vid_shape: torch.LongTensor,
        cache: Cache = Cache(disable=True),
    ) -> Tuple[
        torch.FloatTensor,
        torch.LongTensor,
    ]:
        # NOTE(review): the default `cache` is one shared instance created at
        # def time — confirm Cache(disable=True) carries no mutable state.
        pt, ph, pw = self.patch_size
        out = self.proj(vid)
        # Reassemble the full (unpadded) sequence from per-rank shards.
        out = gather_outputs(
            out,
            gather_dim=0,
            padding_dim=0,
            unpad_shape=vid_shape,
            cache=cache.namespace("vid"),
        )
        # Unfold patches unless patchification is the identity (1x1x1).
        if (pt, ph, pw) != (1, 1, 1):
            out, vid_shape = na.rearrange(
                out, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c",
                t=pt, h=ph, w=pw,
            )
        return out, vid_shape
|
|