| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Callable, List, Optional |
| import torch |
| from einops import rearrange |
| from torch import nn |
|
|
| from common.cache import Cache |
| from common.distributed.ops import slice_inputs |
|
|
| |
# Factory signature shared by all adaptive-layer implementations:
# (dim, emb_dim) -> nn.Module.
ada_layer_type = Callable[[int, int], nn.Module]


def get_ada_layer(ada_layer: str) -> ada_layer_type:
    """Resolve an adaptive-layer class from its registry name."""
    if ada_layer != "single":
        raise NotImplementedError(f"{ada_layer} is not supported")
    return AdaSingle
|
|
|
|
def expand_dims(x: torch.Tensor, dim: int, ndim: int):
    """
    Return "x" reshaped to "ndim" dimensions by inserting singleton axes at "dim".
    Example: x is (b d), target ndim is 5, insert at dim 1, result is (b 1 1 1 d).
    """
    head, tail = x.shape[:dim], x.shape[dim:]
    padding = (1,) * (ndim - x.ndim)
    return x.reshape(head + padding + tail)
|
|
|
|
class AdaSingle(nn.Module):
    """
    Adaptive modulation layer: combines a per-sample embedding with learned
    per-layer parameters to shift/scale ("in" mode) or gate ("out" mode) the
    hidden states of each registered sub-layer.
    """

    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
    ):
        # forward() splits the embedding into len(layers) triples of
        # (shift, scale, gate), each of size `dim`, so the embedding must carry
        # exactly 3 * dim * len(layers) features. The previous hard-coded
        # `emb_dim == 6 * dim` check was only correct for exactly two layers.
        assert emb_dim == 3 * dim * len(layers), (
            f"AdaSingle requires emb_dim == 3 * dim * len(layers), "
            f"got emb_dim={emb_dim}, dim={dim}, layers={layers}"
        )
        super().__init__()
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers
        for l in layers:
            # Learned per-layer offsets added to the embedding-derived values in
            # forward(). `scale` is initialized around 1 so the modulation
            # starts near identity.
            self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
            self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1))
            self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))

    def forward(
        self,
        hid: torch.FloatTensor,
        emb: torch.FloatTensor,
        # NOTE(review): default is evaluated once at import time and shared
        # across calls; assumed harmless because the cache is disabled — confirm
        # Cache(disable=True) holds no state.
        cache: Cache = Cache(disable=True),
        layer: str = "",
        mode: str = "",
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None,
    ) -> torch.FloatTensor:
        """
        Modulate `hid` in place using the embedding chunk belonging to `layer`.

        Args:
            hid: hidden states to modulate; mutated in place via mul_/add_ and
                also returned.
            emb: per-sample embedding with 3 * dim * len(layers) features in its
                last dimension.
            layer: which registered layer's parameters/embedding chunk to use;
                must be an element of `self.layers`.
            mode: "in" applies shift and scale; "out" applies the gate.
            cache: memoizes the repeated/sliced embedding across calls.
            branch_tag: disambiguates cache keys between call sites.
            hid_len: optional per-sample repeat counts; when given, each
                sample's embedding is repeated hid_len[i] times along dim 0 and
                then sliced for the local rank (presumably one entry per token
                of a packed sequence — verify against callers).

        Returns:
            The modulated hidden states (the same tensor object as `hid`).

        Raises:
            NotImplementedError: if `mode` is neither "in" nor "out".
        """
        idx = self.layers.index(layer)
        # Split the embedding into per-layer (shift, scale, gate) triples and
        # select this layer's triple: result is (b, d, 3).
        emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
        # Insert singleton dims after the batch axis so emb broadcasts against
        # hid (the +1 accounts for the trailing shift/scale/gate axis).
        emb = expand_dims(emb, 1, hid.ndim + 1)

        if hid_len is not None:
            # Repeat each sample's embedding hid_len[i] times along dim 0, then
            # keep only this rank's shard; cached so repeated calls reuse it.
            emb = cache(
                f"emb_repeat_{idx}_{branch_tag}",
                lambda: slice_inputs(
                    torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
                    dim=0,
                ),
            )

        shiftA, scaleA, gateA = emb.unbind(-1)
        shiftB, scaleB, gateB = (
            getattr(self, f"{layer}_shift"),
            getattr(self, f"{layer}_scale"),
            getattr(self, f"{layer}_gate"),
        )

        if mode == "in":
            # In place: hid = hid * (scaleA + scaleB) + (shiftA + shiftB).
            return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
        if mode == "out":
            return hid.mul_(gateA + gateB)
        raise NotImplementedError

    def extra_repr(self) -> str:
        """Summarize constructor configuration in the printed module repr."""
        return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"