| | import torch |
| | import torchvision |
| | import torch.nn as nn |
| |
|
| |
|
class resnet18(nn.Module):
    """ResNet-18 image encoder that returns pooled backbone features.

    Wraps ``torchvision.models.resnet18`` with its final child (the fully
    connected classifier) removed. The output is therefore the flattened
    output of the remaining layers, which end in the global average pool.

    ``forward`` applies ImageNet mean/std normalization itself. Callers
    should therefore pass images that have not been normalized yet.
    NOTE(review): presumably float tensors scaled to [0, 1]; confirm against
    callers.

    Args:
        pretrained: If True, load torchvision's ImageNet weights for
            ResNet-18. Otherwise the backbone is randomly initialized.
        output_dim: Currently unused. The feature size is fixed by the
            backbone and no projection layer is applied. The parameter is
            kept only for interface compatibility.
        unit_norm: If True, L2-normalize each output feature vector.
    """

    def __init__(
        self,
        pretrained: bool = True,
        output_dim: int = 512,
        unit_norm: bool = False,
    ):
        super().__init__()
        resnet = self._build_backbone(pretrained)
        # Drop the last child (the fc classifier); keep everything up to and
        # including the average-pool layer.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        self.flatten = nn.Flatten()
        self.pretrained = pretrained
        # ImageNet channel statistics matching torchvision's pretrained weights.
        self.normalize = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )
        self.unit_norm = unit_norm

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Build torchvision's ResNet-18, preferring the ``weights=`` API.

        Passing ``pretrained=`` is deprecated since torchvision 0.13. It is
        only used as a fallback when the ``ResNet18_Weights`` enum is
        unavailable (older torchvision).
        """
        weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
        if weights_enum is None:
            # Legacy torchvision: only the boolean flag is understood.
            return torchvision.models.resnet18(pretrained=pretrained)
        # IMAGENET1K_V1 is what the legacy ``pretrained=True`` flag maps to.
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        return torchvision.models.resnet18(weights=weights)

    def forward(self, x):
        """Encode images into feature vectors.

        Accepted input shapes and the corresponding outputs:

        - ``(C, H, W)`` (single image) -> ``(D,)``
        - ``(N, C, H, W)`` -> ``(N, D)``
        - ``(*batch, C, H, W)`` with several leading axes -> ``(*batch, D)``

        ``D`` is the flattened width of the pooled backbone output.
        """
        orig_shape = x.shape
        dims = len(orig_shape)
        if dims == 3:
            # Single unbatched image: add a batch axis for the conv stack.
            x = x.unsqueeze(0)
        elif dims > 4:
            # Several leading batch axes: fold them into one batch axis.
            x = x.reshape(-1, *orig_shape[-3:])
        out = self.flatten(self.resnet(self.normalize(x)))
        if self.unit_norm:
            out = torch.nn.functional.normalize(out, p=2, dim=-1)
        # Restore the caller's leading axes.
        if dims == 3:
            out = out.squeeze(0)
        elif dims > 4:
            out = out.reshape(*orig_shape[:-3], -1)
        return out
| |
|