| | import torch |
| | from torch import nn |
| | from torch.utils.data import Dataset, DataLoader |
| | from torchvision.models import resnet50 |
| | from torchvision import transforms |
| | from PIL import Image |
| | import matplotlib.pyplot as plt |
| | from transformers import BertTokenizer, BertModel |
| | import os |
| | import json |
| | import numpy as np |
| | from collections import defaultdict |
| | import random |
| | from tqdm.notebook import tqdm |
| | from torchvision import models |
| | from torch.nn.utils.rnn import pad_sequence |
| | import matplotlib.patches as patches |
| |
|
| | import math |
| | import time |
| | import requests |
| | import nltk |
| |
|
| | import cv2 |
| | import colorsys |
| | from numpy import asarray |
| |
|
| |
|
| | from transformers import GPT2LMHeadModel, GPT2Config |
| |
|
| | from scipy.optimize import linear_sum_assignment |
| |
|
| | import sys |
| | sys.path.append("../src") |
| |
|
| | from utils import * |
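| | # The star import above is expected to provide the DETR-style helpers used further down: |
| | # NestedTensor, nested_tensor_from_tensor_list, is_main_process and box_ops. |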
| |
|
| | NUM_QUERIES = 40 |
| | feature_size = 256 |
| | token_size = 256 |
| |
|
| | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
|
| | |
| | |
| |
|
| | """ |
| | Various positional encodings for the transformer. |
| | """ |
| |
|
| | class PositionEmbeddingSine(nn.Module): |
| | """ |
| | This is a more standard version of the position embedding, very similar to the one |
| | used by the Attention is all you need paper, generalized to work on images. |
| | """ |
| | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): |
| | super().__init__() |
| | self.num_pos_feats = num_pos_feats |
| | self.temperature = temperature |
| | self.normalize = normalize |
| | if scale is not None and normalize is False: |
| | raise ValueError("normalize should be True if scale is passed") |
| | if scale is None: |
| | scale = 2 * math.pi |
| | self.scale = scale |
| |
|
| | def forward(self, tensor_list: NestedTensor): |
| | x = tensor_list.tensors |
| | mask = tensor_list.mask |
| | assert mask is not None |
| | not_mask = ~mask |
| | y_embed = not_mask.cumsum(1, dtype=torch.float32) |
| | x_embed = not_mask.cumsum(2, dtype=torch.float32) |
| | if self.normalize: |
| | eps = 1e-6 |
| | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale |
| | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale |
| |
|
| | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) |
| | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) |
| |
|
| | pos_x = x_embed[:, :, :, None] / dim_t |
| | pos_y = y_embed[:, :, :, None] / dim_t |
| | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) |
| | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) |
| | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) |
| | return pos |
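| | |
| | # Quick shape check (a sketch; assumes NestedTensor is available from the utils import above): |
| | # two 3x64x80 images with an all-valid (no padding) mask give a [2, 2*num_pos_feats, 64, 80] encoding. |
| | _imgs = torch.zeros(2, 3, 64, 80) |
| | _mask = torch.zeros(2, 64, 80, dtype=torch.bool) |
| | _pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(NestedTensor(_imgs, _mask)) |
| | _pos.shape  # torch.Size([2, 256, 64, 80]) |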
| |
|
| |
|
| | class PositionEmbeddingLearned(nn.Module): |
| | """ |
| | Absolute pos embedding, learned. |
| | """ |
| | def __init__(self, num_pos_feats=256): |
| | super().__init__() |
| | self.row_embed = nn.Embedding(50, num_pos_feats) |
| | self.col_embed = nn.Embedding(50, num_pos_feats) |
| | self.reset_parameters() |
| |
|
| | def reset_parameters(self): |
| | nn.init.uniform_(self.row_embed.weight) |
| | nn.init.uniform_(self.col_embed.weight) |
| |
|
| | def forward(self, tensor_list: NestedTensor): |
| | x = tensor_list.tensors |
| | h, w = x.shape[-2:] |
| | i = torch.arange(w, device=x.device) |
| | j = torch.arange(h, device=x.device) |
| | x_emb = self.col_embed(i) |
| | y_emb = self.row_embed(j) |
| | pos = torch.cat([ |
| | x_emb.unsqueeze(0).repeat(h, 1, 1), |
| | y_emb.unsqueeze(1).repeat(1, w, 1), |
| | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) |
| | return pos |
| |
|
| |
|
| | def build_position_encoding(args): |
| | N_steps = args.hidden_dim // 2 |
| | if args.position_embedding in ('v2', 'sine'): |
| | |
| | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) |
| | elif args.position_embedding in ('v3', 'learned'): |
| | position_embedding = PositionEmbeddingLearned(N_steps) |
| | else: |
| | raise ValueError(f"not supported {args.position_embedding}") |
| |
|
| | return position_embedding |
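| | |
| | # With the Parameters defined later in this notebook (hidden_dim=256, position_embedding='sine'), |
| | # build_position_encoding returns a PositionEmbeddingSine with N_steps = 128 channels per axis, |
| | # i.e. a 256-channel positional map matching the transformer's d_model. |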
| |
|
| | from collections import OrderedDict |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | import torchvision |
| | from torch import nn |
| | from torchvision.models._utils import IntermediateLayerGetter |
| | from typing import Dict, List |
| |
|
| |
|
| | class FrozenBatchNorm2d(torch.nn.Module): |
| | """ |
| | BatchNorm2d where the batch statistics and the affine parameters are fixed. |
| | |
| | Copy-paste from torchvision.misc.ops with added eps before rqsrt, |
| | without which any other models than torchvision.models.resnet[18,34,50,101] |
| | produce nans. |
| | """ |
| |
|
| | def __init__(self, n): |
| | super(FrozenBatchNorm2d, self).__init__() |
| | self.register_buffer("weight", torch.ones(n)) |
| | self.register_buffer("bias", torch.zeros(n)) |
| | self.register_buffer("running_mean", torch.zeros(n)) |
| | self.register_buffer("running_var", torch.ones(n)) |
| |
|
| | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, |
| | missing_keys, unexpected_keys, error_msgs): |
| | num_batches_tracked_key = prefix + 'num_batches_tracked' |
| | if num_batches_tracked_key in state_dict: |
| | del state_dict[num_batches_tracked_key] |
| |
|
| | super(FrozenBatchNorm2d, self)._load_from_state_dict( |
| | state_dict, prefix, local_metadata, strict, |
| | missing_keys, unexpected_keys, error_msgs) |
| |
|
| | def forward(self, x): |
| | |
| | |
| | w = self.weight.reshape(1, -1, 1, 1) |
| | b = self.bias.reshape(1, -1, 1, 1) |
| | rv = self.running_var.reshape(1, -1, 1, 1) |
| | rm = self.running_mean.reshape(1, -1, 1, 1) |
| | eps = 1e-5 |
| | scale = w * (rv + eps).rsqrt() |
| | bias = b - rm * scale |
| | return x * scale + bias |
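| | |
| | # Sanity-check sketch: with its default buffers, FrozenBatchNorm2d matches a freshly initialised |
| | # nn.BatchNorm2d run in eval mode (zero running mean, unit variance, weight=1, bias=0). |
| | _x = torch.randn(2, 8, 4, 4) |
| | torch.allclose(FrozenBatchNorm2d(8)(_x), nn.BatchNorm2d(8).eval()(_x), atol=1e-6)  # True |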
| |
|
| |
|
| | class BackboneBase(nn.Module): |
| |
|
| | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): |
| | super().__init__() |
| | for name, parameter in backbone.named_parameters(): |
| | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: |
| | parameter.requires_grad_(False) |
| | if return_interm_layers: |
| | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} |
| | else: |
| | return_layers = {'layer4': "0"} |
| | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) |
| | self.num_channels = num_channels |
| |
|
| | def forward(self, tensor_list: NestedTensor): |
| | xs = self.body(tensor_list.tensors) |
| | out: Dict[str, NestedTensor] = {} |
| | for name, x in xs.items(): |
| | m = tensor_list.mask |
| | assert m is not None |
| | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] |
| | out[name] = NestedTensor(x, mask) |
| | return out |
| |
|
| | ''' |
| | The line mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] resizes the padding |
| | mask so that it lines up with each feature map returned by the backbone. The mask marks which pixels of |
| | the padded input batch are real image content and which are padding: it is False on valid pixels and |
| | True on padded ones. |
| | |
| | m has shape [batch_size, H, W]. m[None].float() adds a leading dimension and casts to float, giving a |
| | [1, batch_size, H, W] tensor that F.interpolate() can resize; the interpolation shrinks it to the spatial |
| | size of the feature map x. to(torch.bool) turns the resized mask back into a binary tensor, and the |
| | trailing [0] drops the leading dimension again, leaving a [batch_size, h, w] mask aligned with the |
| | features. |
| | |
| | For a ResNet-50 backbone this per-level mask tells the rest of the model, in particular the |
| | transformer's key_padding_mask, which feature positions correspond to padding and should be ignored. |
| | ''' |
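| | |
| | # Minimal illustration of the resize described above (standalone; no backbone needed): one 256x256 |
| | # image whose rightmost 64 columns are padding, with the mask downsampled to an 8x8 feature map. |
| | _m = torch.zeros(1, 256, 256, dtype=torch.bool)  # [batch, H, W]; True marks padded pixels |
| | _m[:, :, 192:] = True |
| | _small = F.interpolate(_m[None].float(), size=(8, 8)).to(torch.bool)[0] |
| | _small.shape  # torch.Size([1, 8, 8]); the last 2 of its 8 columns stay True (padding) |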
| |
|
| | class Backbone(BackboneBase): |
| | """ResNet backbone with frozen BatchNorm.""" |
| | def __init__(self, name: str, |
| | train_backbone: bool, |
| | return_interm_layers: bool, |
| | dilation: bool): |
| | backbone = getattr(torchvision.models, name)( |
| | replace_stride_with_dilation=[False, False, dilation], |
| | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) |
| | |
| | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 |
| | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) |
| |
|
| |
|
| | class Joiner(nn.Sequential): |
| | def __init__(self, backbone, position_embedding): |
| | super().__init__(backbone, position_embedding) |
| |
|
| | def forward(self, tensor_list: NestedTensor): |
| | xs = self[0](tensor_list) |
| | out: List[NestedTensor] = [] |
| | pos = [] |
| | for name, x in xs.items(): |
| | out.append(x) |
| | |
| | pos.append(self[1](x).to(x.tensors.dtype)) |
| |
|
| | return out, pos |
| |
|
| |
|
| | def build_backbone(args): |
| | position_embedding = build_position_encoding(args) |
| | train_backbone = args.lr_backbone > 0 |
| | return_interm_layers = args.masks |
| | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) |
| | model = Joiner(backbone, position_embedding) |
| | model.num_channels = backbone.num_channels |
| | return model |
| |
|
| | def get_sinusoid_encoding_table(n_position, d_hid): |
| | def cal_angle(position, hid_idx): |
| | return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) |
| |
|
| | def get_posi_angle_vec(position): |
| | return [cal_angle(position, hid_j) for hid_j in range(d_hid)] |
| |
|
| | sinusoid_table = np.array([get_posi_angle_vec(pos_i) for pos_i in range(n_position)]) |
| | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) |
| | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) |
| | return torch.FloatTensor(sinusoid_table) |
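| | |
| | # Quick check of the helper above: the table is [n_position, d_hid] with sines in the even columns |
| | # and cosines in the odd ones, so row 0 alternates 0, 1, 0, 1, ... |
| | get_sinusoid_encoding_table(6, 8).shape  # torch.Size([6, 8]) |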
| |
|
| | class PostProcess(nn.Module): |
| | """ This module converts the model's output into the format expected by the coco api""" |
| | @torch.no_grad() |
| | def forward(self, outputs, target_sizes): |
| | """ Perform the computation |
| | Parameters: |
| | outputs: raw outputs of the model |
| | target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch |
| | For evaluation, this must be the original image size (before any data augmentation) |
| | For visualization, this should be the image size after data augmentation, but before padding |
| | """ |
| | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] |
| |
|
| | assert len(out_logits) == len(target_sizes) |
| | assert target_sizes.shape[1] == 2 |
| |
|
| | prob = F.softmax(out_logits, -1) |
| | scores, labels = prob[..., :-1].max(-1) |
| |
|
| | |
| | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) |
| | |
| | img_h, img_w = target_sizes.unbind(1) |
| | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) |
| | boxes = boxes * scale_fct[:, None, :] |
| |
|
| | results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] |
| |
|
| | return results |
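| | |
| | # box_ops above refers to DETR's util.box_ops (assumed to come in via the utils import); its |
| | # box_cxcywh_to_xyxy conversion amounts to the following, sketched on one normalized box: |
| | _cx, _cy, _w, _h = torch.tensor([0.5, 0.5, 0.2, 0.4]).unbind(-1) |
| | torch.stack([_cx - 0.5 * _w, _cy - 0.5 * _h, _cx + 0.5 * _w, _cy + 0.5 * _h], dim=-1)  # tensor([0.4000, 0.3000, 0.6000, 0.7000]) |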
| |
|
| |
|
| | class MLP(nn.Module): |
| | """ Very simple multi-layer perceptron (also called FFN)""" |
| |
|
| | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): |
| | super().__init__() |
| | self.num_layers = num_layers |
| | h = [hidden_dim] * (num_layers - 1) |
| | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) |
| |
|
| | def forward(self, x): |
| | for i, layer in enumerate(self.layers): |
| | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) |
| | return x |
| |
|
| |
|
| | def build(args): |
| | # Adapted from DETR's original model factory and kept here for reference. It references classes |
| | # from the DETR code base (DETR, DETRsegm, build_matcher, SetCriterion, PostProcessSegm, |
| | # PostProcessPanoptic) that are not defined in this section. |
| | # num_classes is the maximum class id + 1 (91 for COCO detection). |
| | num_classes = 20 if args.dataset_file != 'coco' else 91 |
| | if args.dataset_file == "coco_panoptic": |
| | |
| | |
| | num_classes = 250 |
| | device = torch.device(args.device) |
| |
|
| | backbone = build_backbone(args) |
| |
|
| | transformer = build_transformer(args) |
| |
|
| | model = DETR( |
| | backbone, |
| | transformer, |
| | num_classes=num_classes, |
| | num_queries=args.num_queries, |
| | aux_loss=args.aux_loss, |
| | ) |
| | if args.masks: |
| | model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) |
| | matcher = build_matcher(args) |
| | weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} |
| | weight_dict['loss_giou'] = args.giou_loss_coef |
| | if args.masks: |
| | weight_dict["loss_mask"] = args.mask_loss_coef |
| | weight_dict["loss_dice"] = args.dice_loss_coef |
| | |
| | if args.aux_loss: |
| | aux_weight_dict = {} |
| | for i in range(args.dec_layers - 1): |
| | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) |
| | weight_dict.update(aux_weight_dict) |
| |
|
| | losses = ['labels', 'boxes', 'cardinality'] |
| | if args.masks: |
| | losses += ["masks"] |
| | criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, |
| | eos_coef=args.eos_coef, losses=losses) |
| | criterion.to(device) |
| | postprocessors = {'bbox': PostProcess()} |
| | if args.masks: |
| | postprocessors['segm'] = PostProcessSegm() |
| | if args.dataset_file == "coco_panoptic": |
| | is_thing_map = {i: i <= 90 for i in range(201)} |
| | postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) |
| |
|
| | return model, criterion, postprocessors |
| |
|
| | class Parameters: |
| | def __init__(self): |
| | self.lr = 1e-4 |
| | self.lr_backbone = 1e-5 |
| | self.batch_size = 2 |
| | self.weight_decay = 1e-4 |
| | self.epochs = 300 |
| | self.lr_drop = 200 |
| | self.clip_max_norm = 0.1 |
| |
|
| | args = Parameters() |
| |
|
| | args.lr=1e-4 |
| | args.lr_backbone=1e-5 |
| | args.batch_size=32 |
| | args.weight_decay=1e-4 |
| | args.epochs=300 |
| | args.lr_drop=200 |
| | args.clip_max_norm=0.1 |
| |
|
| | |
| | args.frozen_weights=None   # build() tests "is not None", so leave this None when no weights are frozen |
| |
|
| | |
| | args.backbone='resnet50' |
| | args.dilation=False |
| | args.position_embedding='sine' |
| |
|
| | |
| | args.enc_layers=6 |
| | args.dec_layers=6 |
| | args.dim_feedforward=2048 |
| | args.hidden_dim=256 |
| | args.dropout=0.1 |
| | args.nheads=8 |
| | args.num_queries=40 |
| | args.pre_norm=True |
| |
|
| | |
| | args.masks=False |
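| | |
| | # Note: pre_norm=True makes the transformer encoder/decoder layers defined below use their |
| | # forward_pre (pre-LayerNorm) path rather than forward_post. |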
| |
|
| |
|
| | """ |
| | LLMEyeCap Transformer class. |
| | |
| | A DETR (FaceBook) Copy-paste from torch.nn.Transformer with modifications: |
| | * positional encodings are passed in MHattention |
| | * extra LN at the end of encoder is removed |
| | * decoder returns a stack of activations from all decoding layers |
| | |
| | """ |
| | import copy |
| | from typing import Optional, List |
| | from torch import Tensor  # used in the Optional[Tensor] annotations below |
| |
|
| |
|
| | class Transformer(nn.Module): |
| |
|
| | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, |
| | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, |
| | activation="relu", normalize_before=False, |
| | return_intermediate_dec=False): |
| | super().__init__() |
| |
|
| | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, |
| | dropout, activation, normalize_before) |
| | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None |
| | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) |
| |
|
| | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, |
| | dropout, activation, normalize_before) |
| | decoder_norm = nn.LayerNorm(d_model) |
| | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, |
| | return_intermediate=return_intermediate_dec) |
| |
|
| | self._reset_parameters() |
| |
|
| | self.d_model = d_model |
| | self.nhead = nhead |
| |
|
| | def _reset_parameters(self): |
| | for p in self.parameters(): |
| | if p.dim() > 1: |
| | nn.init.xavier_uniform_(p) |
| |
|
| | def forward(self, src, mask, query_embed, pos_embed): |
| | |
| | # flatten the backbone features from [bs, c, h, w] to [h*w, bs, c] for the encoder |
| | bs, c, h, w = src.shape |
| | src = src.flatten(2).permute(2, 0, 1) |
| | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) |
| | # one learned query embedding per object slot, repeated across the batch: [num_queries, bs, d_model] |
| | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) |
| | mask = mask.flatten(1) |
| |
|
| | tgt = torch.zeros_like(query_embed) |
| | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) |
| | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, |
| | pos=pos_embed, query_pos=query_embed) |
| | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) |
| |
|
| |
|
| | class TransformerEncoder(nn.Module): |
| |
|
| | def __init__(self, encoder_layer, num_layers, norm=None): |
| | super().__init__() |
| | self.layers = _get_clones(encoder_layer, num_layers) |
| | self.num_layers = num_layers |
| | self.norm = norm |
| |
|
| | def forward(self, src, |
| | mask: Optional[Tensor] = None, |
| | src_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None): |
| | output = src |
| |
|
| | for layer in self.layers: |
| | output = layer(output, src_mask=mask, |
| | src_key_padding_mask=src_key_padding_mask, pos=pos) |
| |
|
| | if self.norm is not None: |
| | output = self.norm(output) |
| |
|
| | return output |
| |
|
| |
|
| | class TransformerDecoder(nn.Module): |
| |
|
| | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): |
| | super().__init__() |
| | self.layers = _get_clones(decoder_layer, num_layers) |
| | self.num_layers = num_layers |
| | self.norm = norm |
| | self.return_intermediate = return_intermediate |
| |
|
| | def forward(self, tgt, memory, |
| | tgt_mask: Optional[Tensor] = None, |
| | memory_mask: Optional[Tensor] = None, |
| | tgt_key_padding_mask: Optional[Tensor] = None, |
| | memory_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None, |
| | query_pos: Optional[Tensor] = None): |
| | output = tgt |
| |
|
| | intermediate = [] |
| |
|
| | for layer in self.layers: |
| | output = layer(output, memory, tgt_mask=tgt_mask, |
| | memory_mask=memory_mask, |
| | tgt_key_padding_mask=tgt_key_padding_mask, |
| | memory_key_padding_mask=memory_key_padding_mask, |
| | pos=pos, query_pos=query_pos) |
| | if self.return_intermediate: |
| | intermediate.append(self.norm(output)) |
| |
|
| | if self.norm is not None: |
| | output = self.norm(output) |
| | if self.return_intermediate: |
| | intermediate.pop() |
| | intermediate.append(output) |
| |
|
| | if self.return_intermediate: |
| | return torch.stack(intermediate) |
| |
|
| | return output.unsqueeze(0) |
| |
|
| |
|
| | class TransformerEncoderLayer(nn.Module): |
| |
|
| | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, |
| | activation="relu", normalize_before=False): |
| | super().__init__() |
| | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| | |
| | self.linear1 = nn.Linear(d_model, dim_feedforward) |
| | self.dropout = nn.Dropout(dropout) |
| | self.linear2 = nn.Linear(dim_feedforward, d_model) |
| |
|
| | self.norm1 = nn.LayerNorm(d_model) |
| | self.norm2 = nn.LayerNorm(d_model) |
| | self.dropout1 = nn.Dropout(dropout) |
| | self.dropout2 = nn.Dropout(dropout) |
| |
|
| | self.activation = _get_activation_fn(activation) |
| | self.normalize_before = normalize_before |
| |
|
| | def with_pos_embed(self, tensor, pos: Optional[Tensor]): |
| | return tensor if pos is None else tensor + pos |
| |
|
| | def forward_post(self, |
| | src, |
| | src_mask: Optional[Tensor] = None, |
| | src_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None): |
| | q = k = self.with_pos_embed(src, pos) |
| | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, |
| | key_padding_mask=src_key_padding_mask)[0] |
| | src = src + self.dropout1(src2) |
| | src = self.norm1(src) |
| | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) |
| | src = src + self.dropout2(src2) |
| | src = self.norm2(src) |
| | return src |
| |
|
| | def forward_pre(self, src, |
| | src_mask: Optional[Tensor] = None, |
| | src_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None): |
| | src2 = self.norm1(src) |
| | q = k = self.with_pos_embed(src2, pos) |
| | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, |
| | key_padding_mask=src_key_padding_mask)[0] |
| | src = src + self.dropout1(src2) |
| | src2 = self.norm2(src) |
| | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) |
| | src = src + self.dropout2(src2) |
| | return src |
| |
|
| | def forward(self, src, |
| | src_mask: Optional[Tensor] = None, |
| | src_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None): |
| | if self.normalize_before: |
| | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) |
| | return self.forward_post(src, src_mask, src_key_padding_mask, pos) |
| |
|
| |
|
| | class TransformerDecoderLayer(nn.Module): |
| |
|
| | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, |
| | activation="relu", normalize_before=False): |
| | super().__init__() |
| | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) |
| | |
| | self.linear1 = nn.Linear(d_model, dim_feedforward) |
| | self.dropout = nn.Dropout(dropout) |
| | self.linear2 = nn.Linear(dim_feedforward, d_model) |
| |
|
| | self.norm1 = nn.LayerNorm(d_model) |
| | self.norm2 = nn.LayerNorm(d_model) |
| | self.norm3 = nn.LayerNorm(d_model) |
| | self.dropout1 = nn.Dropout(dropout) |
| | self.dropout2 = nn.Dropout(dropout) |
| | self.dropout3 = nn.Dropout(dropout) |
| |
|
| | self.activation = _get_activation_fn(activation) |
| | self.normalize_before = normalize_before |
| |
|
| | def with_pos_embed(self, tensor, pos: Optional[Tensor]): |
| | return tensor if pos is None else tensor + pos |
| |
|
| | def forward_post(self, tgt, memory, |
| | tgt_mask: Optional[Tensor] = None, |
| | memory_mask: Optional[Tensor] = None, |
| | tgt_key_padding_mask: Optional[Tensor] = None, |
| | memory_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None, |
| | query_pos: Optional[Tensor] = None): |
| | q = k = self.with_pos_embed(tgt, query_pos) |
| | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, |
| | key_padding_mask=tgt_key_padding_mask)[0] |
| | tgt = tgt + self.dropout1(tgt2) |
| | tgt = self.norm1(tgt) |
| | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), |
| | key=self.with_pos_embed(memory, pos), |
| | value=memory, attn_mask=memory_mask, |
| | key_padding_mask=memory_key_padding_mask)[0] |
| | tgt = tgt + self.dropout2(tgt2) |
| | tgt = self.norm2(tgt) |
| | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) |
| | tgt = tgt + self.dropout3(tgt2) |
| | tgt = self.norm3(tgt) |
| | return tgt |
| |
|
| | def forward_pre(self, tgt, memory, |
| | tgt_mask: Optional[Tensor] = None, |
| | memory_mask: Optional[Tensor] = None, |
| | tgt_key_padding_mask: Optional[Tensor] = None, |
| | memory_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None, |
| | query_pos: Optional[Tensor] = None): |
| | tgt2 = self.norm1(tgt) |
| | q = k = self.with_pos_embed(tgt2, query_pos) |
| | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, |
| | key_padding_mask=tgt_key_padding_mask)[0] |
| | tgt = tgt + self.dropout1(tgt2) |
| | tgt2 = self.norm2(tgt) |
| | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), |
| | key=self.with_pos_embed(memory, pos), |
| | value=memory, attn_mask=memory_mask, |
| | key_padding_mask=memory_key_padding_mask)[0] |
| | tgt = tgt + self.dropout2(tgt2) |
| | tgt2 = self.norm3(tgt) |
| | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) |
| | tgt = tgt + self.dropout3(tgt2) |
| | return tgt |
| |
|
| | def forward(self, tgt, memory, |
| | tgt_mask: Optional[Tensor] = None, |
| | memory_mask: Optional[Tensor] = None, |
| | tgt_key_padding_mask: Optional[Tensor] = None, |
| | memory_key_padding_mask: Optional[Tensor] = None, |
| | pos: Optional[Tensor] = None, |
| | query_pos: Optional[Tensor] = None): |
| | if self.normalize_before: |
| | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, |
| | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) |
| | return self.forward_post(tgt, memory, tgt_mask, memory_mask, |
| | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) |
| |
|
| |
|
| | def _get_clones(module, N): |
| | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) |
| |
|
| |
|
| | def build_transformer(args): |
| | return Transformer( |
| | d_model=args.hidden_dim, |
| | dropout=args.dropout, |
| | nhead=args.nheads, |
| | dim_feedforward=args.dim_feedforward, |
| | num_encoder_layers=args.enc_layers, |
| | num_decoder_layers=args.dec_layers, |
| | normalize_before=args.pre_norm, |
| | return_intermediate_dec=True, |
| | ) |
| |
|
| |
|
| | def _get_activation_fn(activation): |
| | """Return an activation function given a string""" |
| | if activation == "relu": |
| | return F.relu |
| | if activation == "gelu": |
| | return F.gelu |
| | if activation == "glu": |
| | return F.glu |
| | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") |
| |
|
| |
|
| | class LLMEyeCap(nn.Module): |
| | |
| | def __init__(self, backbone, transformer, num_queries, vocab_size, pad_token): |
| | |
| | super().__init__() |
| | self.num_queries = num_queries |
| | self.transformer = transformer |
| | self.hidden_dim = transformer.d_model |
| | |
| | self.caption_embed = nn.Linear(self.hidden_dim, vocab_size) |
| | self.bbox_embed = MLP(self.hidden_dim, self.hidden_dim, 4, 3) |
| | |
| | self.query_embed = nn.Embedding(num_queries, self.hidden_dim) |
| | self.input_proj = nn.Conv2d(backbone.num_channels, self.hidden_dim, kernel_size=1) |
| | self.backbone = backbone |
| | ''' |
| | self.capdecoder = CaptioningDecoder(detr_decoder_dim=transformer.d_model, token_embedding_dim=transformer.d_model, |
| | vocab_size=vocab_size, num_queries=num_queries, num_layers=6) |
| | ''' |
| | # Caption head: a GPT-2 based decoder conditioned on the DETR query embeddings. |
| | self.capdecoder = CaptionDecoder(feature_size, token_size, vocab_size, num_queries, pad_token).to(device) |
| | |
| |
|
| | def forward(self, samples: NestedTensor, captions): |
| | |
| | if isinstance(samples, (list, torch.Tensor)): |
| | samples = nested_tensor_from_tensor_list(samples) |
| | |
| | features, pos = self.backbone(samples) |
| | src, mask = features[-1].decompose() |
| | assert mask is not None |
| | hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] |
| | outputs_coord = self.bbox_embed(hs).sigmoid() |
| | |
| | outputs_captions = self.capdecoder(hs, captions)  # per-query caption logits |
| | |
| |
|
| | out = {'pred_logits': outputs_captions, 'pred_boxes': outputs_coord[-1]} |
| | return out |
| | |
| | def generate_caption(self, image_path, tokenizer, max_length, pad_sos): |
| | |
| | image = Image.open(image_path).convert('RGB') |
| | transform = transforms.Compose([ |
| | transforms.Resize((256, 256)), |
| | transforms.ToTensor(), |
| | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
| | ]) |
| | |
| | image = transform(image).unsqueeze(0).to(device) |
| | |
| | if isinstance(image, (list, torch.Tensor)): |
| | image = nested_tensor_from_tensor_list(image) |
| | |
| | with torch.no_grad(): |
| | features, pos = self.backbone(image) |
| | src, mask = features[-1].decompose() |
| | assert mask is not None |
| | |
| | hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] |
| | outputs_coord = self.bbox_embed(hs).sigmoid() |
| | |
| | # One start token per query slot: [1, num_queries, 1] filled with the SOS/pad id. |
| | input_ids = torch.full((1, self.num_queries, 1), pad_sos, dtype=torch.long, device=device) |
| |
|
| | |
| | # Greedy decoding: at each step, append the most likely next token to every query's sequence. |
| | for i in range(max_length): |
| | outputs_captions = self.capdecoder(hs, input_ids) |
| | predicted_sequences = torch.argmax(outputs_captions, dim=-1) |
| | next_token = predicted_sequences[:, :, -1:] |
| | input_ids = torch.cat((input_ids, next_token), dim=-1) |
| |
|
| | |
| |
|
| | return outputs_coord[-1], input_ids |
| |
|
| | class LLMEyeCapModel(nn.Module): |
| | def __init__(self, num_queries, vocab_size, pad_token): |
| | super(LLMEyeCapModel, self).__init__() |
| | self.num_queries = num_queries |
| | self.vocab_size = vocab_size |
| | self.backbone = build_backbone(args) |
| | self.transformer = build_transformer(args) |
| |
|
| | self.model = LLMEyeCap( |
| | self.backbone, |
| | self.transformer, |
| | num_queries=self.num_queries, |
| | vocab_size=self.vocab_size, |
| | pad_token=pad_token |
| | ) |
| | |
| | |
| | |
| | |
| | |
| | self.model.num_queries = self.num_queries |
| | |
| | def forward(self, images, captions): |
| | return self.model(images, captions) |
| | |
| | def generate_caption(self, image_path, tokenizer, max_length=20, pad_sos=0): |
| | return self.model.generate_caption(image_path, tokenizer, max_length, pad_sos) |
| |
|
| | class CaptionDecoder(nn.Module): |
| | def __init__(self, detr_decoder_dim, token_embedding_dim, vocab_size, num_queries, pad_token, num_layers=6): |
| | super(CaptionDecoder, self).__init__() |
| | |
| | self.detr_decoder_dim = detr_decoder_dim |
| | self.token_embedding_dim = token_embedding_dim |
| | self.vocab_size = vocab_size |
| | self.num_queries = num_queries |
| | self.pad_token = pad_token |
| |
|
| | |
| | self.token_embedding = nn.Embedding(vocab_size, token_embedding_dim) |
| | |
| | |
| | config = GPT2Config(vocab_size=vocab_size, n_embd=detr_decoder_dim + token_embedding_dim, n_head=8 ) |
| | self.gpt2 = GPT2LMHeadModel(config) |
| | |
| | self.target_projection = nn.Linear(token_embedding_dim, detr_decoder_dim + token_embedding_dim) |
| | |
| | def forward(self, detr_output, captions): |
| | # detr_output: DETR decoder activations, [num_layers, bs, num_queries, d_model] |
| | # captions: token ids, [bs, num_queries, seq_len] |
| | # Mask out padding tokens when attending over each caption. |
| | attention_mask = (captions != self.pad_token).float().to(captions.device) |
| |
|
| |
|
| | seq_length = captions.size(2) |
| | pos_encoding = get_sinusoid_encoding_table(seq_length, self.token_embedding_dim).to(captions.device) |
| | pos_encoding = pos_encoding.unsqueeze(0).repeat(captions.size(0) * self.num_queries, 1, 1) |
| | # Use the last decoder layer's output as the per-query spatial embedding: [bs, num_queries, d_model] |
| | spatial_embedding = detr_output[-1] |
| | |
| | |
| | token_embeddings = self.token_embedding(captions) |
| | |
| | |
| | # Broadcast the spatial embedding over the caption length and concatenate it to every token embedding, |
| | # giving [bs, num_queries, seq_len, d_model + token_embedding_dim]. |
| | spatial_embedding = spatial_embedding.unsqueeze(2) |
| | combined_embedding = torch.cat([spatial_embedding.repeat(1, 1, token_embeddings.size(2), 1), token_embeddings], dim=-1) |
| | |
| | |
| | |
| | memory = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim) |
| | |
| | |
| | |
| | target = token_embeddings.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.token_embedding_dim) |
| | |
| | |
| | |
| | pos_encoding = pos_encoding.permute(1, 0, 2) |
| | target += pos_encoding |
| |
|
| |
|
| | |
| | |
| | target = self.target_projection(target) |
| | |
| | attention_mask = attention_mask.permute(2, 0, 1).reshape(captions.size(2), -1) |
| | tgt_key_padding_mask = (attention_mask == 0).permute(1,0) |
| | |
| | |
| | inputs_embeds = combined_embedding.permute(2, 0, 1, 3).reshape(captions.size(2), -1, self.detr_decoder_dim + self.token_embedding_dim) |
| | |
| | |
| | attention_mask = attention_mask.reshape(-1, captions.size(2)) |
| | |
| | |
| | # Run GPT-2 directly on the pre-computed embeddings instead of token ids. |
| | outputs = self.gpt2(inputs_embeds=inputs_embeds, attention_mask=attention_mask) |
| | logits = outputs.logits |
| | |
| | # Reshape the logits back to [bs, num_queries, seq_len, vocab_size]. |
| | logits = logits.view(captions.size(2), captions.size(0), self.num_queries, self.vocab_size).permute(1, 2, 0, 3) |
| | |
| | return logits |
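| | |
| | # Instantiation sketch (illustrative; assumes the BERT tokenizer imported above supplies the caption |
| | # vocabulary and special-token ids, and that `utils` provides the NestedTensor helpers the model uses): |
| | _tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") |
| | _caption_model = LLMEyeCapModel(num_queries=NUM_QUERIES, |
| | vocab_size=_tokenizer.vocab_size, |
| | pad_token=_tokenizer.pad_token_id).to(device) |
| | # _boxes, _token_ids = _caption_model.generate_caption("path/to/image.jpg", _tokenizer, |
| | #                                                      max_length=20, pad_sos=_tokenizer.cls_token_id) |
| | # _texts = [_tokenizer.decode(q, skip_special_tokens=True) for q in _token_ids[0]] |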
| |
|
| |
|