| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Union, Tuple |
| |
|
| | import torch |
| | from torch import nn |
| |
|
| |
|
# Accepted forms for normalization statistics (the mean / std arguments of
# InputConditioner): a 3-tuple of floats or an existing tensor.
norm_t = Union[Tuple[float, float, float], torch.Tensor]
| |
|
class InputConditioner(nn.Module):
    """Normalize an input tensor as ``(x - mean) / std``, with an optional dtype cast.

    ``norm_mean`` and ``norm_std`` are converted by ``_to_tensor`` (float32,
    shape ``(N, 1, 1)``) and each divided by ``input_scale`` before being
    registered as buffers. Because both statistics are divided by the same
    factor, ``(x - mean / s) / (std / s)`` equals ``(s * x - mean) / std``:
    the module behaves as if the input were multiplied by ``input_scale`` and
    then normalized with the unscaled statistics.

    Args:
        input_scale: Divisor applied to both statistics at construction.
        norm_mean: Mean values, as a tuple of floats or a tensor.
        norm_std: Standard-deviation values, as a tuple of floats or a tensor.
        dtype: If not ``None``, the normalized output is cast to this dtype.
    """

    def __init__(self,
                 input_scale: float,
                 norm_mean: norm_t,
                 norm_std: norm_t,
                 dtype: torch.dtype = None,
                 ):
        super().__init__()

        self.dtype = dtype

        # Register the pre-scaled statistics as buffers so they follow the
        # module across .to()/.cuda() calls and are stored in its state dict.
        for buffer_name, stat in (("norm_mean", norm_mean), ("norm_std", norm_std)):
            self.register_buffer(buffer_name, _to_tensor(stat) / input_scale)

    def forward(self, x: torch.Tensor):
        """Return ``x`` normalized by the stored buffers, cast to ``self.dtype`` if set.

        The ``(N, 1, 1)`` buffers broadcast against ``x``; presumably ``x`` is
        channel-first image data ``(..., N, H, W)`` -- TODO confirm with callers.
        """
        centered = x - self.norm_mean
        normalized = centered / self.norm_std
        if self.dtype is None:
            return normalized
        return normalized.to(self.dtype)
| |
|
| |
|
def get_default_conditioner():
    """Build an ``InputConditioner`` from timm's OpenAI CLIP mean/std constants.

    The conditioner uses ``input_scale=1.0`` and leaves ``dtype`` at its
    default, so no output cast is applied.
    """
    # Function-scope import: timm is only needed when this helper is called.
    from timm.data.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

    clip_stats = {
        "norm_mean": OPENAI_CLIP_MEAN,
        "norm_std": OPENAI_CLIP_STD,
    }
    return InputConditioner(input_scale=1.0, **clip_stats)
| |
|
| |
|
def _to_tensor(v: "norm_t") -> torch.Tensor:
    """Convert normalization statistics to a float32 tensor of shape ``(N, 1, 1)``.

    Args:
        v: A tuple of floats or a tensor; any input is flattened to ``N``
            values along the first dimension.

    Returns:
        A float32 tensor with two trailing singleton dimensions, so it
        broadcasts over the last two dimensions of a larger tensor.
    """
    # reshape (not view): view raises on non-contiguous tensor inputs such as
    # a transposed tensor, while reshape returns the same view when one is
    # possible and copies otherwise.
    return torch.as_tensor(v, dtype=torch.float32).reshape(-1, 1, 1)
| |
|