from contextlib import nullcontext
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Linear, Module
from transformers import PreTrainedModel

from .config import MarlinConfig
from .decoder import MarlinDecoder
from .encoder import MarlinEncoder
| |
|
| |
|
class Marlin(Module):
    """MARLIN masked video autoencoder.

    Wraps a ``MarlinEncoder`` and, unless used purely as a feature
    extractor, a ``MarlinDecoder`` plus a linear projection that maps
    encoder features into the decoder's embedding space.

    Args:
        img_size: Input frame height/width in pixels.
        patch_size: Spatial patch size used by the tokenizer.
        n_frames: Number of frames per video clip.
        encoder_embed_dim / encoder_depth / encoder_num_heads:
            Encoder transformer dimensions.
        decoder_embed_dim / decoder_depth / decoder_num_heads:
            Decoder transformer dimensions (ignored when
            ``as_feature_extractor`` is True).
        mlp_ratio: MLP hidden dim as a multiple of the embed dim.
        qkv_bias: Whether QKV projections carry a bias term.
        qk_scale: Optional override for the attention scale factor.
        drop_rate: Dropout probability.
        attn_drop_rate: Attention dropout probability.
        norm_layer: Name of the normalization layer to use.
        init_values: Layer-scale init value.
        tubelet_size: Temporal extent of each tokenized tube.
        as_feature_extractor: When True (default), no decoder or
            projection is built and only ``extract_features`` is usable.
    """

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        n_frames: int,
        encoder_embed_dim: int,
        encoder_depth: int,
        encoder_num_heads: int,
        decoder_embed_dim: int,
        decoder_depth: int,
        decoder_num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool,
        qk_scale: Optional[float],
        drop_rate: float,
        attn_drop_rate: float,
        norm_layer: str,
        init_values: float,
        tubelet_size: int,
        as_feature_extractor: bool = True,
    ):
        super().__init__()
        self.encoder = MarlinEncoder(
            img_size=img_size,
            patch_size=patch_size,
            n_frames=n_frames,
            embed_dim=encoder_embed_dim,
            depth=encoder_depth,
            num_heads=encoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            norm_layer=norm_layer,
            init_values=init_values,
            tubelet_size=tubelet_size,
        )
        self.as_feature_extractor = as_feature_extractor
        self.clip_frames = n_frames
        if as_feature_extractor:
            # Feature-extraction mode: no reconstruction path is needed.
            self.enc_dec_proj = None
            self.decoder = None
        else:
            self.decoder = MarlinDecoder(
                img_size=img_size,
                patch_size=patch_size,
                embed_dim=decoder_embed_dim,
                depth=decoder_depth,
                num_heads=decoder_num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop_rate=drop_rate,
                attn_drop_rate=attn_drop_rate,
                norm_layer=norm_layer,
                init_values=init_values,
                tubelet_size=tubelet_size,
            )
            # BUGFIX: build the projection only in autoencoder mode.
            # Placed unconditionally it would overwrite the `None` set in
            # the feature-extractor branch and allocate a useless layer.
            self.enc_dec_proj = Linear(encoder_embed_dim, decoder_embed_dim, bias=False)

    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        """Masked-autoencoding forward pass: encode, project, decode.

        Args:
            x: Input video clip tensor.
            mask: Token mask; required in autoencoder mode.

        Raises:
            RuntimeError: If the model was built as a feature extractor.
            AssertionError: If ``mask`` is None in autoencoder mode.
        """
        if self.as_feature_extractor:
            raise RuntimeError(
                "For feature extraction, please use `extract_features` or `extract_video`."
            )
        # BUGFIX (annotation): default of None requires Optional[Tensor].
        assert mask is not None
        x = self.encoder(x, mask)
        x = self.enc_dec_proj(x)
        x = self.decoder(x, mask)
        return x

    @property
    def device(self):
        """Device the model parameters live on (read off an encoder weight)."""
        return self.encoder.norm.weight.device

    def extract_features(self, x: Tensor, keep_seq: bool = True):
        """Extract features for one video clip (v).

        Gradients are tracked only in training mode; in eval mode the
        encoder runs under ``torch.no_grad()``.

        Args:
            x: Input video clip tensor.
            keep_seq: If True, return the per-token sequence; otherwise
                mean-pool over the sequence dimension.
        """
        # Single call site: pick the grad context instead of duplicating
        # the encoder call in both branches.
        ctx = nullcontext() if self.training else torch.no_grad()
        with ctx:
            return self.encoder.extract_features(x, seq_mean_pool=not keep_seq)
| |
|
| |
|
class MarlinModel(PreTrainedModel):
    """HuggingFace ``PreTrainedModel`` wrapper around :class:`Marlin`.

    Builds the inner model from a :class:`MarlinConfig` and exposes
    feature extraction as the forward pass.
    """

    config_class = MarlinConfig

    # Config attributes forwarded verbatim into the Marlin constructor.
    _MARLIN_FIELDS = (
        "img_size",
        "patch_size",
        "n_frames",
        "encoder_embed_dim",
        "encoder_depth",
        "encoder_num_heads",
        "decoder_embed_dim",
        "decoder_depth",
        "decoder_num_heads",
        "mlp_ratio",
        "qkv_bias",
        "qk_scale",
        "drop_rate",
        "attn_drop_rate",
        "norm_layer",
        "init_values",
        "tubelet_size",
    )

    def __init__(self, config: MarlinConfig):
        super().__init__(config)
        self.config = config
        kwargs = {name: getattr(config, name) for name in self._MARLIN_FIELDS}
        self.marlin = Marlin(**kwargs)

    def forward(self, x: Tensor, keep_seq: bool = True):
        """Return features for clip ``x`` (pooled when ``keep_seq`` is False)."""
        return self.marlin.extract_features(x, keep_seq=keep_seq)
| |
|