| | import logging |
| | import json |
| | import torch |
| | from torch import nn |
| | from .config import InternVideo2Config, EasyDict |
| | from .internvideo2 import pretrain_internvideo2_1b_patch14_224, pretrain_internvideo2_6b_patch14_224 |
| | from transformers.utils import logging |
| | import warnings |
| |
|
# Install a blanket "ignore" filter at import time: from here on, warnings of
# every category are suppressed for the whole process, not just this module.
warnings.filterwarnings(action="ignore")
| |
|
class InternVideo2_Stage2(nn.Module):
    """Vision side of InternVideo2 stage 2.

    Holds a vision encoder selected by ``config.model.vision_encoder.name``,
    a linear projection ``vision_proj`` (``vision_width`` -> ``embed_dim``)
    and a learnable scalar ``temp``. Within this class only the encoder is
    exercised (see ``encode_vision``); ``vision_proj`` and ``temp`` are
    created here for callers to use.
    """

    def __init__(self, config, is_pretrain=True):
        """Build the encoder, optionally freeze it, and create the heads.

        Args:
            config: Nested configuration with attribute access. Reads
                ``config.model.vision_encoder.clip_embed_dim``,
                ``config.model.vision_encoder.name``,
                ``config.model.embed_dim``, ``config.model.temp``, the
                optional key ``freeze_vision`` of ``config.model`` and the
                optional key ``uta_image_only`` of ``config.criterion``
                (both default to ``False``).
            is_pretrain (bool): Stored as ``self.is_pretrain``; not used
                anywhere else in this class.

        Raises:
            ValueError: If the configured encoder name is not supported.
        """
        super().__init__()

        self.config = config
        self.is_pretrain = is_pretrain
        self.vision_width = config.model.vision_encoder.clip_embed_dim
        self.embed_dim = config.model.embed_dim

        # Sub-modules are registered in this order (encoder, projection,
        # temperature); keep it stable so parameter/state-dict ordering
        # does not change.
        self.vision_encoder = self.build_vision_encoder()
        if config.model.get("freeze_vision", False):
            self.freeze_vision()

        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)

        # 0-dim learnable tensor initialised to config.model.temp.
        self.temp = nn.parameter.Parameter(torch.ones([]) * config.model.temp)
        self.uta_image_only = config.criterion.get('uta_image_only', False)

    def freeze_vision(self):
        """Disable gradients for every parameter of the vision encoder."""
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def no_weight_decay(self):
        """Return parameter names that should be excluded from weight decay.

        Returns:
            set: ``"temp"`` plus every name reported by the encoder's own
            ``no_weight_decay()``, prefixed with ``"vision_encoder."``.
        """
        encoder_names = {
            "vision_encoder." + key
            for key in self.vision_encoder.no_weight_decay()
        }
        return {"temp"} | encoder_names

    @property
    def dtype(self):
        """dtype of the encoder's patch-embedding projection weight."""
        return self.vision_encoder.patch_embed.proj.weight.dtype

    def encode_vision(self, image):
        """Encode images / videos with the vision encoder.

        Args:
            image (torch.Tensor): 5-D input whose dimension 1 is the frame
                count; an input with exactly one frame is treated as an
                image. The original docstring gives the layout as
                (B, N, C, H, W).

        Returns:
            tuple: ``(vision_embeds, pooled_vision_embeds)`` -- the first
            two of the four values produced by the encoder; the remaining
            two are discarded. Per the original docstring the shapes are
            [B, N, C] and [B, 1, C] -- TODO confirm against the encoder.
        """
        num_frames = image.shape[1]
        use_image = num_frames == 1
        # Swap dimensions 1 and 2 before handing the tensor to the encoder.
        image = image.permute(0, 2, 1, 3, 4)

        # The second positional argument is always None here (presumably
        # the encoder's mask input -- verify against the encoder signature).
        vision_embeds, pooled_vision_embeds, _, _ = self.vision_encoder(
            image, None, use_image)
        return vision_embeds, pooled_vision_embeds

    def build_vision_encoder(self):
        """Instantiate the vision encoder named in the configuration.

        Returns:
            The object returned by the matching factory, called with
            ``self.config.model``.

        Raises:
            ValueError: If ``config.model.vision_encoder.name`` is not one
                of the supported encoder names.
        """
        encoder_name = self.config.model.vision_encoder.name

        if encoder_name == 'pretrain_internvideo2_1b_patch14_224':
            return pretrain_internvideo2_1b_patch14_224(self.config.model)
        if encoder_name == 'pretrain_internvideo2_6b_patch14_224':
            return pretrain_internvideo2_6b_patch14_224(self.config.model)
        raise ValueError(f"Not implemented: {encoder_name}")
| |
|
| |
|