import math

import torch
from torch import nn

import models_dinov2
from models_IB import IF_Module


class MetaArch(nn.Module):
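    """Distill a frozen DINOv2 teacher into a compact student through an
    information-bottleneck module, combining class-token, whole-feature, and
    masked-patch distillation losses with a rate (bits-per-pixel) penalty."""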

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        student_model_dict = dict()
        teacher_model_dict = dict()
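
        # Build the student backbone from the local DINOv2 model zoo.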
        import_student = getattr(models_dinov2, cfg.target_model)
        student = import_student(img_size=224,
                                 patch_size=cfg.patch_size,
                                 init_values=1.0,
                                 ffn_layer='mlp',
                                 block_chunks=0,
                                 num_register_tokens=0,
                                 interpolate_antialias=False,
                                 interpolate_offset=0.1)

        embed_dim = student.embed_dim
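
        # Load the pretrained DINOv2 teacher (linear-classifier variant) from torch.hub.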
        if cfg.teacher_model == 'vit_base':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
        elif cfg.teacher_model == 'vit_small':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
        elif cfg.teacher_model == 'vit_large':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
        elif cfg.teacher_model == 'vit_giant':
            teacher_backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')
        else:
            raise ValueError(f"unsupported teacher_model: {cfg.teacher_model}")
        teacher_backbone.eval()

        student_model_dict['backbone'] = student
        teacher_model_dict['backbone'] = teacher_backbone.backbone

        self.embed_dim = embed_dim
        self.total_n_global_crops = cfg.batch_size

        self.student = nn.ModuleDict(student_model_dict)
        self.teacher = nn.ModuleDict(teacher_model_dict)
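
        # Projection heads mapping student features to the teacher's embedding width.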
        teacher_embed_dim = teacher_backbone.backbone.embed_dim
        self.ibot_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))
        self.token_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))
        self.fea_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, teacher_embed_dim))

        self.soft_criterion = torch.nn.MSELoss()
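
        # Information-bottleneck transformer applied to the student patch tokens;
        # its likelihood outputs drive the rate (bpp) term.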
        self.info_bottleneck = IF_Module(embed_dim=embed_dim, num_heads=12, mlp_ratio=4, depth=4)
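
        # Freeze the teacher; only the student, heads, and bottleneck are trained.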
        for param in self.teacher.backbone.parameters():
            param.requires_grad = False

    def cal_bpp(self, image, unmask_likelihood, mask_likelihood):
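        # Bits per pixel of the bottleneck codes: the log-likelihoods from the
        # masked and unmasked branches are summed, converted from nats to bits
        # via -math.log(2), and normalized by 1.5x the pixel count.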
        b, _, h, w = image.size()
        num_pixels = b * h * w
        log_unmask_likelihoods = torch.log(unmask_likelihood)
        log_mask_likelihoods = torch.log(mask_likelihood)
        bpp = (log_unmask_likelihoods.sum() + log_mask_likelihoods.sum()) / (-math.log(2) * num_pixels * 1.5)
        return bpp

    def forward(self, inputs):
        global_crops = inputs["collated_global_crops"]

        masks = inputs["collated_masks"]
        mask_indices_list = inputs["mask_indices_list"]
        n_masked_patches = mask_indices_list.shape[0]
        upperbound = inputs["upperbound"]
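
        # Teacher forward pass (frozen, no gradients): extract the class token,
        # all patch tokens, and the patch tokens at the masked positions.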
        def compute_teacher_output():
            with torch.no_grad():
                teacher_backbone_output_dict = self.teacher.backbone(global_crops, is_training=True)
                teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
                teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
                _dim = teacher_patch_tokens.shape[-1]
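
                # Gather the teacher tokens at the masked positions into a
                # preallocated buffer of size `upperbound`.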
                buffer_tensor_teacher = teacher_patch_tokens.new_zeros(upperbound, _dim)
                torch.index_select(
                    teacher_patch_tokens.flatten(0, 1),
                    dim=0,
                    index=mask_indices_list,
                    out=buffer_tensor_teacher[:n_masked_patches],
                )
                teacher_patch_tokens_masked = buffer_tensor_teacher[:n_masked_patches]

            return teacher_cls_tokens, teacher_patch_tokens, teacher_patch_tokens_masked

        (
            teacher_cls_tokens,
            teacher_patch_tokens,
            teacher_patch_tokens_masked,
        ) = compute_teacher_output()

        cur_masks = masks if self.cfg.mask_probability > 0 else None
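
        # Student forward pass on two copies of the same crops: one masked, one unmasked.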
        student_backbone_output_dict, student_backbone_output_dict_unmask = self.student.backbone(
            [global_crops, global_crops], masks=[cur_masks, None], is_training=True
        )

        student_cls_token_unmask = student_backbone_output_dict_unmask["x_norm_clstoken"]
        student_patch_tokens_unmask = student_backbone_output_dict_unmask["x_norm_patchtokens"]
        student_patch_tokens = student_backbone_output_dict["x_norm_patchtokens"]
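
        # Pass both branches through the information bottleneck; the returned
        # likelihoods feed the rate (bpp) estimate.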
        student_patch_tokens_unmask, unmask_likelihood = self.info_bottleneck(student_patch_tokens_unmask, is_training=True)
        student_patch_tokens, mask_likelihood = self.info_bottleneck(student_patch_tokens, is_training=True)
        bpp = self.cal_bpp(global_crops, unmask_likelihood, mask_likelihood)
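
        # Gather the student's masked patch tokens, mirroring the teacher-side selection.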
        _dim = student_patch_tokens.shape[-1]
        buffer_tensor_student = student_patch_tokens.new_zeros(upperbound, _dim)
        buffer_tensor_student[:n_masked_patches].copy_(
            torch.index_select(student_patch_tokens.flatten(0, 1),
                               dim=0,
                               index=mask_indices_list)
        )
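
        # Project the student features to the teacher's embedding width.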
        student_patch_tokens_unmask = self.fea_head(student_patch_tokens_unmask)
        student_cls_token_unmask = self.token_head(student_cls_token_unmask)
        tokens_after_head = self.ibot_head(buffer_tensor_student)
        student_patch_tokens_masked = tokens_after_head[:n_masked_patches]
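
        # Class-token distillation loss.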
        distillation_loss_token = self.soft_criterion(student_cls_token_unmask, teacher_cls_tokens)
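
        # Whole-feature distillation: match the concatenated [CLS] + patch tokens.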
        student_whole_fea = torch.cat((student_cls_token_unmask.unsqueeze(1), student_patch_tokens_unmask), dim=1)
        teacher_whole_fea = torch.cat((teacher_cls_tokens.unsqueeze(1), teacher_patch_tokens), dim=1)
        distillation_loss_fea = self.soft_criterion(student_whole_fea, teacher_whole_fea)
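
        # Masked-patch distillation loss (iBOT-style).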
        patch_loss = self.soft_criterion(student_patch_tokens_masked, teacher_patch_tokens_masked)
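
        # Weight the individual losses by their config coefficients.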
        token_loss = self.cfg.lambda_token * distillation_loss_token
        fea_loss = self.cfg.lambda_fea * distillation_loss_fea
        patch_loss_weighted = self.cfg.lambda_patch * patch_loss
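
        # Total objective: weighted distillation losses plus the rate penalty;
        # `task_loss` is the unweighted sum, kept for logging.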
        total_loss = patch_loss_weighted + fea_loss + token_loss + 0.48 * bpp
        task_loss = patch_loss + distillation_loss_fea + distillation_loss_token
| | |
| | loss_dict = {"bpp_loss": bpp, |
| | "patch_loss": patch_loss, |
| | "fea_loss": distillation_loss_fea, |
| | "token_loss": token_loss, |
| | "loss": total_loss, |
| | "task_loss": task_loss, |
| | } |
| | |
| | return loss_dict |
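

# Minimal usage sketch (hypothetical values; `cfg` only needs the attributes
# MetaArch actually reads, and `inputs` must provide the collated crops and
# mask metadata produced by the masking collate function in the data pipeline):
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(target_model='vit_small', patch_size=14,
#                         teacher_model='vit_base', batch_size=64,
#                         mask_probability=0.5, lambda_token=1.0,
#                         lambda_fea=1.0, lambda_patch=1.0)
#   model = MetaArch(cfg).cuda()
#   loss_dict = model(inputs)
#   loss_dict["loss"].backward()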