import math

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_ as __call_trunc_normal_

from vlmo.torchscale.model.BEiT3 import BEiT3
from vlmo.torchscale.architecture.config import EncoderConfig


def trunc_normal_(tensor, mean=0.0, std=1.0):
    # timm's a/b arguments are absolute cutoffs, so this truncates the
    # distribution to [-std, std] (one sigma around the default zero mean).
    __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)

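# For example, trunc_normal_(w, std=0.02) fills w from N(0, 0.02**2)
# truncated to [-0.02, 0.02], the initialization used in _init_weights below.

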
def _get_base_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=12,
    encoder_embed_dim=768,
    encoder_attention_heads=12,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )

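
# Usage sketch (illustrative values, not from this file): each factory
# forwards its keyword arguments into EncoderConfig and derives only the
# FFN width, e.g.
#
#     config = _get_base_config(img_size=384, drop_path_rate=0.1)
#     # config.encoder_ffn_embed_dim == 768 * 4 (embed_dim * mlp_ratio)
#
# _get_large_config and _get_huge_config below follow the same pattern with
# larger defaults; the huge variant also switches to a 30522-token
# (BERT-style) vocabulary.

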
def _get_large_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=64010,
    encoder_layers=24,
    encoder_embed_dim=1024,
    encoder_attention_heads=16,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )

def _get_huge_config(
    img_size=224,
    patch_size=16,
    drop_path_rate=0,
    checkpoint_activations=None,
    mlp_ratio=4,
    vocab_size=30522,
    encoder_layers=32,
    encoder_embed_dim=4096,
    encoder_attention_heads=32,
    share_layer=False,
    share_attn=False,
    deepnorm=False,
    mask_ratio=0,
    max_text_len=52,
    one_attn=False,
    **kwargs
):
    return EncoderConfig(
        img_size=img_size,
        patch_size=patch_size,
        vocab_size=vocab_size,
        multiway=True,
        layernorm_embedding=False,
        normalize_output=True,
        no_output_layer=True,
        drop_path_rate=drop_path_rate,
        encoder_embed_dim=encoder_embed_dim,
        encoder_attention_heads=encoder_attention_heads,
        encoder_layers=encoder_layers,
        encoder_ffn_embed_dim=int(encoder_embed_dim * mlp_ratio),
        checkpoint_activations=checkpoint_activations,
        share_layer=share_layer,
        share_attn=share_attn,
        deepnorm=deepnorm,
        mask_ratio=mask_ratio,
        max_text_len=max_text_len,
        one_attn=one_attn,
    )

class BEiT3Wrapper(nn.Module):
    def __init__(self, args, **kwargs):
        super().__init__()
        self.args = args
        self.beit3 = BEiT3(args)
        self.apply(self._init_weights)

    def fix_init_weight(self):
        # Rescale the residual-branch output projections by 1/sqrt(2 * depth),
        # as in the BEiT initialization scheme. Note: this assumes a
        # `self.blocks` attribute provided by a subclass; BEiT3Wrapper itself
        # does not define one.
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def get_num_layers(self):
        return self.beit3.encoder.num_layers

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names to exclude from weight decay (position embeddings,
        # class tokens, and the logit scale).
        return {
            "pos_embed",
            "cls_token",
            "beit3.encoder.embed_positions.A.weight",
            "beit3.vision_embed.cls_token",
            "logit_scale",
        }

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

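
# Minimal smoke-test sketch, executed only when this module is run directly.
# The hyperparameters below are illustrative, and the optimizer grouping is
# one common (timm-style) way to consume no_weight_decay(); neither comes
# from the original file.
if __name__ == "__main__":
    args = _get_base_config(drop_path_rate=0.1)
    model = BEiT3Wrapper(args)
    print("encoder layers:", model.get_num_layers())  # 12 for the base config

    # Route parameters named in no_weight_decay() (plus biases and other
    # one-dimensional tensors) into a zero-weight-decay optimizer group.
    skip = model.no_weight_decay()
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if name in skip or name.endswith(".bias") or param.ndim == 1:
            no_decay.append(param)
        else:
            decay.append(param)
    optimizer = torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": 0.05},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=1.5e-4,
    )
    print("decay / no-decay params:", len(decay), len(no_decay))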