| | from .base import BaseModel |
| | from .schema import DINOConfiguration |
| | import logging |
| | import torch |
| | import torch.nn as nn |
| |
|
| | import sys |
| | import re |
| | import os |
| |
|
| | from .dinov2.eval.depth.ops.wrappers import resize |
| | from .dinov2.hub.backbones import dinov2_vitb14_reg |
| |
|
# Make this module's directory importable as a top-level path — presumably so
# the vendored `dinov2` package's internal absolute imports resolve when this
# file is imported from elsewhere (TODO confirm against dinov2 internals).
module_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_dir)

# Module-level logger following the standard `logging.getLogger(__name__)` convention.
logger = logging.getLogger(__name__)
| |
|
| |
|
class FeatureExtractor(BaseModel):
    """Dense patch-feature extractor built on a DINOv2 ViT-B/14 backbone.

    Images are ImageNet-normalized, resized to a multiple of the ViT patch
    size, passed through DINOv2, and the resulting patch tokens are projected
    to ``conf.output_dim`` channels with a 1x1 convolution.
    """

    # ImageNet normalization statistics (RGB), applied in _forward.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def build_encoder(self, conf: DINOConfiguration):
        """Instantiate the DINOv2 backbone and the output projection.

        Args:
            conf: configuration providing ``pretrained``, ``frozen`` and
                ``output_dim``.

        Returns:
            The DINOv2 backbone module (also stored as ``self.backbone_model``).
        """
        # The model below is ViT-B/14 with registers; the trailing number in
        # the arch name is the patch size — one feature token per 14x14 patch.
        # (The previous backbone-size selection table was dead code: it chose
        # "vits14" but the base model was loaded regardless.)
        backbone_arch = "vitb14"
        self.crop_size = int(re.search(r"\d+", backbone_arch).group())

        self.backbone_model = dinov2_vitb14_reg(
            pretrained=conf.pretrained, drop_path_rate=0.1)

        if conf.frozen:
            # Freeze the patch embedding and the first 10 transformer blocks;
            # only the remaining blocks stay trainable.
            for param in self.backbone_model.patch_embed.parameters():
                param.requires_grad = False

            for i in range(10):
                for param in self.backbone_model.blocks[i].parameters():
                    param.requires_grad = False
                # Frozen blocks need no stochastic depth — disable drop-path.
                self.backbone_model.blocks[i].drop_path1 = nn.Identity()
                self.backbone_model.blocks[i].drop_path2 = nn.Identity()

        # Project the 768-dim ViT-B patch features to the configured dim.
        self.feat_projection = nn.Conv2d(768, conf.output_dim, kernel_size=1)

        return self.backbone_model

    def _init(self, conf: DINOConfiguration):
        """Register normalization buffers and build the encoder."""
        # Non-persistent buffers: they follow the module's device/dtype but
        # are not saved in state_dict (they are constants).
        self.register_buffer("mean_", torch.tensor(
            self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        self.build_encoder(conf)

    def _forward(self, data):
        """Extract projected patch features and a rescaled camera.

        Args:
            data: dict with ``"image"`` (B, 3, H, W) and ``"camera"``
                (a project camera object supporting ``.to`` and ``.scale``).

        Returns:
            Tuple of (features of shape (B, output_dim, H//14, W//14),
            camera scaled to the feature-map resolution).
        """
        _, _, h, w = data["image"].shape

        # Resize so both sides are exact multiples of the patch size.
        h_num_patches = h // self.crop_size
        w_num_patches = w // self.crop_size
        h_dino = h_num_patches * self.crop_size
        w_dino = w_num_patches * self.crop_size

        image = resize(data["image"], (h_dino, w_dino))

        # Per-channel ImageNet normalization (buffers broadcast over H, W).
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        # (B, N_patches, C) patch tokens -> (B, C, h_patches, w_patches).
        output = self.backbone_model.forward_features(
            image)['x_norm_patchtokens']
        output = output.reshape(-1, h_num_patches,
                                w_num_patches, output.shape[-1])
        output = output.permute(0, 3, 1, 2)
        output = self.feat_projection(output)

        # Rescale camera intrinsics to the feature-map resolution (ratio of
        # feature width to original image width).
        camera = data['camera'].to(data["image"].device, non_blocking=True)
        camera = camera.scale(output.shape[-1] / data["image"].shape[-1])

        return output, camera
| |
|