import torch
import torch.nn as nn
import torch.nn.functional as F


class VisionLanguageEmbedding(nn.Module):
    """Combine a text embedding and a vision embedding by concatenating their outputs along the sequence dimension."""

    def __init__(self, text_embed, vision_embed):
        super().__init__()
        self.text_embed = text_embed
        self.vision_embed = vision_embed

    def forward(self, textual_tokens, visual_tokens, **kwargs):
        if textual_tokens is None:
            return self.vision_embed(visual_tokens)

        if visual_tokens is None:
            return self.text_embed(textual_tokens)

        x1 = self.vision_embed(visual_tokens)
        x2 = self.text_embed(textual_tokens)

        # Visual tokens come first, followed by textual tokens.
        return torch.cat([x1, x2], dim=1)

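# Illustrative usage sketch (added here, not part of the original code): either modality
# may be omitted, in which case the wrapper degrades to a single-modality embedding.
# Any pair of modules producing (batch, seq_len, embed_dim) works; the VisionEmbedding
# and TextEmbedding classes defined below are compatible choices.
#
#   vl = VisionLanguageEmbedding(text_embed=..., vision_embed=...)
#   vl(textual_tokens=tokens, visual_tokens=None)   # text-only: returns text_embed(tokens)
#   vl(textual_tokens=None, visual_tokens=images)   # image-only: returns vision_embed(images)
#   vl(tokens, images)                              # both: image patches, then text, along dim=1
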
class VisionEmbedding(nn.Module):
    """Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        contain_mask_token=False,
        prepend_cls_token=False,
    ):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Non-overlapping patch projection: kernel size equals stride.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        if contain_mask_token:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.mask_token = None

        if prepend_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        else:
            self.cls_token = None

    def num_position_embeddings(self):
        if self.cls_token is None:
            return self.num_patches
        else:
            return self.num_patches + 1

    def forward(self, x, masked_position=None, **kwargs):
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        x = self.proj(x).flatten(2).transpose(1, 2)

        batch_size, seq_len, _ = x.size()

        if masked_position is not None:
            assert self.mask_token is not None
            # Replace the embeddings of masked patches with the learnable mask token.
            mask_token = self.mask_token.expand(batch_size, seq_len, -1)
            w = masked_position.unsqueeze(-1).type_as(mask_token)
            x = x * (1 - w) + mask_token * w

        if self.cls_token is not None:
            # Prepend a learnable [CLS] token to the patch sequence.
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        return x

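# Illustrative shape walk-through (an added sketch, not part of the original code):
# with the defaults (224x224 images, 16x16 patches) there are (224 // 16) ** 2 = 196 patches.
#
#   embed = VisionEmbedding(contain_mask_token=True, prepend_cls_token=True)
#   imgs = torch.randn(2, 3, 224, 224)
#   masked = torch.zeros(2, 196, dtype=torch.bool)   # which patches to replace with mask_token
#   masked[:, :10] = True
#   out = embed(imgs, masked_position=masked)        # (2, 197, 768): [CLS] + 196 patches
#   embed.num_position_embeddings()                  # 197 when a CLS token is prepended
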
class TextEmbedding(nn.Embedding):
    # Token embedding with scaled-normal initialization; the padding_idx row (if any) is zeroed.
    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0, std=self.embedding_dim**-0.5)
        self._fill_padding_idx_with_zero()

class PositionalEmbedding(nn.Embedding):
    def forward(
        self,
        x,
        positions=None,
        **kwargs,
    ):
        if positions is None:
            # Positions start at 2 (indices 0 and 1 are reserved, following Fairseq's convention),
            # so the embedding table must hold at least seq_len + 2 rows.
            positions = torch.arange(2, x.size(1) + 2, device=x.device).long().unsqueeze(0)
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )

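if __name__ == "__main__":
    # Illustrative smoke test (added, not part of the original module); the vocabulary size,
    # sequence lengths, and dimensions below are arbitrary assumptions for demonstration only.
    vocab_size, text_len, dim = 64000, 32, 768

    text_embed = TextEmbedding(vocab_size, dim)
    vision_embed = VisionEmbedding(img_size=224, patch_size=16, embed_dim=dim)
    vl_embed = VisionLanguageEmbedding(text_embed, vision_embed)

    tokens = torch.randint(0, vocab_size, (2, text_len))  # (batch, text_len)
    images = torch.randn(2, 3, 224, 224)                  # (batch, channels, height, width)

    x = vl_embed(tokens, images)                          # image patches first, then text tokens
    print(x.shape)                                        # torch.Size([2, 228, 768]) = 196 + 32

    # The positional table needs seq_len + 2 rows because positions start at index 2.
    pos_embed = PositionalEmbedding(x.size(1) + 2, dim)
    x = x + pos_embed(x)
    print(x.shape)                                        # torch.Size([2, 228, 768])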