| | from typing import List, Optional, Tuple, Union
|
| | from addict import Dict
|
| | from dataclasses import dataclass
|
| | import torch.nn.functional as F
|
| | import numpy as np
|
| | import pickle
|
| | import torch
|
| | import math
|
| | import torch.nn as nn
|
| | from torch.nn import CrossEntropyLoss
|
| | from psalm.model.visual_prompt_module.context_cluster import region_pooling
|
| | from transformers import AutoConfig, AutoModelForCausalLM
|
| |
|
| | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
| |
|
| | from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, \
|
| | DEFAULT_IM_END_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX
|
| | from detectron2.structures import Boxes, ImageList, Instances, BitMasks
|
| | from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
|
| | from detectron2.modeling.postprocessing import sem_seg_postprocess
|
| | from detectron2.utils.memory import retry_if_cuda_oom
|
| | from ..mask_decoder.Mask2Former_Simplify.modeling.transformer_decoder.mask2former_transformer_decoder import \
|
| | MultiScaleMaskedTransformerDecoderForOPTPreTrain
|
| | from ..mask_decoder.Mask2Former_Simplify.modeling.pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder
|
| | from ..multimodal_projector.builder import build_vision_projector
|
| | from ..multimodal_encoder.swin_trans import build_swin_b, build_swin_l
|
| |
|
| |
|
| | from ..datasets_mapper.coco_instance_mapper_nullmask import COCOInstanceNewBaselineDatasetMapper as COCOnullmask
|
| | from ..datasets_mapper.coco_instance_mapper_exosize import COCOInstanceNewBaselineDatasetMapper as COCOresize
|
| | from ..datasets_mapper.coco_instance_mapper import COCOInstanceNewBaselineDatasetMapper
|
| | from ..datasets_mapper.coco_panoptic_mapper import COCOPanopticNewBaselineDatasetMapper
|
| | from ..datasets_mapper.coco_semantic_mapper import COCOSemanticNewBaselineDatasetMapper
|
| | from psalm.model.mask_decoder.mask_criterion.pretrain_criterion import PSALM_criterion, hungarian_matcher_PSALM
|
| | from transformers import PhiModel, PhiForCausalLM, PhiConfig
|
| |
|
| | import torch
|
| | import torch.nn.functional as F
|
| |
|
| |
|
def calculate_region_embedding_dis(embedding_list1, embedding_list2):
    """Return ``1 - mean cosine similarity`` between two lists of embeddings.

    The lists are compared pairwise. For each pair the cosine similarity is
    taken along the last dimension and averaged; the per-pair averages are
    then averaged again and subtracted from one.

    Args:
        embedding_list1: Sequence of tensors.
        embedding_list2: Sequence of tensors with the same length as
            ``embedding_list1`` and matching shapes element by element.

    Returns:
        A 0-dim tensor holding the loss.

    Raises:
        AssertionError: If the lists differ in length or a pair differs in shape.
        ValueError: If the lists are empty (there is nothing to average).
    """
    assert len(embedding_list1) == len(embedding_list2), "Embedding lists must have the same length."
    if len(embedding_list1) == 0:
        # torch.stack would otherwise fail with an opaque "non-empty TensorList" error.
        raise ValueError("Embedding lists must not be empty.")

    # Validate every pair up front so no work is done on mismatched input.
    for emb1, emb2 in zip(embedding_list1, embedding_list2):
        assert emb1.shape == emb2.shape

    similarity_scores = [
        F.cosine_similarity(emb1, emb2, dim=-1).mean()
        for emb1, emb2 in zip(embedding_list1, embedding_list2)
    ]
    avg_similarity = torch.mean(torch.stack(similarity_scores))
    return 1 - avg_similarity
|
| |
|
| |
|
class LlavaConfig(PhiConfig):
    """``PhiConfig`` subclass whose ``model_type`` is ``"llava_phi"``."""

    model_type = "llava_phi"
|
| |
|
@dataclass
class CausalOutputWithMask(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with individual loss terms.

    Every field defaults to ``None``.
    """

    # Language-model style output fields.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Individual loss terms carried next to the total ``loss``.
    loss_mask: Optional[torch.FloatTensor] = None
    loss_dice: Optional[torch.FloatTensor] = None
    loss_SEG_class: Optional[torch.FloatTensor] = None
    loss_class_name_class: Optional[torch.FloatTensor] = None
    loss_region_class: Optional[torch.FloatTensor] = None
    loss_llm: Optional[torch.FloatTensor] = None
|
| |
|
@dataclass
class CausalOutputWithMaskSSL(CausalLMOutputWithPast):
    """``CausalLMOutputWithPast`` extended with individual loss terms,
    including ``loss_region_emb_SSL``.

    Every field defaults to ``None``.
    """

    # Language-model style output fields.
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

    # Individual loss terms carried next to the total ``loss``.
    loss_mask: Optional[torch.FloatTensor] = None
    loss_dice: Optional[torch.FloatTensor] = None
    loss_SEG_class: Optional[torch.FloatTensor] = None
    loss_class_name_class: Optional[torch.FloatTensor] = None
    loss_region_class: Optional[torch.FloatTensor] = None
    loss_llm: Optional[torch.FloatTensor] = None
    loss_region_emb_SSL: Optional[torch.FloatTensor] = None
|
| |
|
| |
|
class PSALMModel(LlavaMetaModel, PhiModel):
    """Phi backbone that additionally owns a Swin vision tower and a
    multimodal projector."""

    config_class = LlavaConfig

    def __init__(self, config: PhiConfig, mask_decoder_cfg=None):
        """Build the backbone and, when ``config`` has ``mm_vision_tower``,
        the vision tower, projector and dataset mappers.

        Args:
            config: Model configuration; ``hidden_size`` becomes the projector
                output dimension.
            mask_decoder_cfg: Configuration handed to every dataset mapper.
        """
        super(PSALMModel, self).__init__(config)
        self.cfg = mask_decoder_cfg
        self.projector_outdim = config.hidden_size
        if hasattr(config, "mm_vision_tower"):
            swin_type = getattr(config, 'swin_type', 'base')
            if swin_type == 'base':
                self.vision_tower = build_swin_b(None)
            else:
                self.vision_tower = build_swin_l(None)
            self.mm_projector = build_vision_projector(config)
            self.vision_tower.image_processor = self._build_image_processors()

    def _build_image_processors(self):
        """Return the task-name -> dataset-mapper dict attached to the vision tower."""
        return {
            'null_mask': COCOnullmask(self.cfg),
            'panoptic': COCOPanopticNewBaselineDatasetMapper(self.cfg),
            'instance': COCOInstanceNewBaselineDatasetMapper(self.cfg),
            'semantic': COCOSemanticNewBaselineDatasetMapper(self.cfg),
            'instance_resize': COCOresize(self.cfg),
        }

    def get_vision_tower(self):
        """Return the vision tower, unwrapping the single-element list form.

        Returns ``None`` when no ``vision_tower`` attribute exists.
        """
        vision_tower = getattr(self, 'vision_tower', None)
        # initialize_vision_modules stores the tower inside a list when fsdp is set.
        if isinstance(vision_tower, list):
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Create the vision tower and projector from ``model_args``.

        Args:
            model_args: Object providing ``vision_tower`` (or ``mm_vision_tower``),
                ``with_norm``, ``with_layernorm`` and optionally ``swin_type``,
                ``mm_projector_type`` and ``pretrain_mm_mlp_adapter``.
            fsdp: When non-empty the tower is stored wrapped in a list.
        """
        vision_tower = model_args.vision_tower if hasattr(model_args, 'vision_tower') else model_args.mm_vision_tower
        with_norm = model_args.with_norm
        with_layernorm = model_args.with_layernorm
        pretrain_mm_mlp_adapter = getattr(model_args, 'pretrain_mm_mlp_adapter', None)
        projector_outdim = self.projector_outdim

        self.config.mm_vision_tower = vision_tower
        swin_type = getattr(model_args, 'swin_type', 'base')
        self.config.swin_type = swin_type
        if swin_type == 'base':
            vision_tower = build_swin_b(vision_tower)
        else:
            print('current visual encoder is swin large')
            vision_tower = build_swin_l(vision_tower)
            # NOTE(review): nesting reconstructed — the 1536 override is assumed
            # to apply to the large variant only; confirm against the original file.
            self.config.mm_input_embeds = 1536

        if fsdp is not None and len(fsdp) > 0:
            self.vision_tower = [vision_tower]
        else:
            self.vision_tower = vision_tower

        self.config.use_mm_proj = True
        vision_tower.hidden_size = 256
        # Register the same set of mappers as __init__ so both construction
        # paths expose identical image_processor keys.
        vision_tower.image_processor = self._build_image_processors()

        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'conv')
        print(f'current mm_project_type is {self.config.mm_projector_type}, the output dim is {projector_outdim}')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.with_norm = with_norm
        self.config.with_layernorm = with_layernorm
        self.config.projector_outdim = projector_outdim

        if not hasattr(self, "mm_projector"):
            self.mm_projector = build_vision_projector(self.config)
        else:
            print('exist mm_projector, skip init')

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                # Filter on the dotted prefix so a key that merely contains the
                # keyword (without the dot) cannot raise IndexError on split.
                prefix = keyword + '.'
                return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
            print('load mm_projector pth successfully')
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | class PSALM_SSL(PhiForCausalLM, LlavaMetaForCausalLM):
|
| | config_class = LlavaConfig
|
| |
|
def __init__(self, config, mask_decoder_cfg=None, add_cross_attn=True, cross_attn_index=None):
    """Assemble the language model, LM head and region/pooling helpers.

    ``add_cross_attn`` is accepted but not read in this method. When
    ``config.mask_decode_train`` is truthy the mask modules are built
    immediately through ``initial_mask_module``.
    """
    super(PSALM_SSL, self).__init__(config)

    self.model = PSALMModel(config, mask_decoder_cfg)
    self.init_config = config
    self.mask_decoder_cfg = mask_decoder_cfg
    self.cross_attn_index = cross_attn_index
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    self.is_train_mask_decode = getattr(config, 'mask_decode_train', False)
    self.refer_pooling = nn.AdaptiveAvgPool1d(output_size=1)
    self.class_name_pooling = nn.AdaptiveAvgPool1d(output_size=1)
    self.region_sampler = region_pooling(num_sample_point=256)
    decoder_hidden_dim = mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM
    self.region_projector = nn.Linear(config.hidden_size, decoder_hidden_dim)

    if self.is_train_mask_decode:
        print('Mask Decoder has been trained, init directly')
        self.initial_mask_module()
    self.post_init()
|
| |
|
| |
|
def initial_mask_module(self, pretrained_path=None, model_args=None):
    """Create the mask-decoder modules and optionally load pretrained weights.

    Args:
        pretrained_path: Optional checkpoint path. A ``.pkl`` file is read with
            ``pickle``; anything else with ``torch.load``. The checkpoint must
            hold a ``'model'`` mapping with ``sem_seg_head.pixel_decoder.*`` and
            ``sem_seg_head.predictor.*`` entries.
        model_args: Unused in this method.
    """
    if not self.is_train_mask_decode:
        print('Initialize mask modules...')
        self.config.mask_decode_train = True

    mask_former_cfg = self.mask_decoder_cfg.MODEL.MASK_FORMER
    self.seg_query = nn.Parameter(
        torch.zeros([mask_former_cfg.NUM_OBJECT_QUERIES, self.config.hidden_size]))
    self.num_queries = mask_former_cfg.NUM_OBJECT_QUERIES
    self.num_classes = self.mask_decoder_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
    self.test_topk_per_image = mask_former_cfg.NUM_OBJECT_QUERIES
    input_shape = self.output_shape()
    self.pixel_decoder = self.pixel_decoder_init(cfg=self.mask_decoder_cfg, input_shape=input_shape)
    self.predictor = self.predictor_init(cfg=self.mask_decoder_cfg)

    self.seg_query_projector = nn.Linear(self.config.hidden_size, mask_former_cfg.HIDDEN_DIM)
    self.SEG_token_projector = nn.Linear(self.config.hidden_size, mask_former_cfg.HIDDEN_DIM)
    self.class_name_projector = nn.Linear(self.config.hidden_size, mask_former_cfg.HIDDEN_DIM)

    self.mask_decoder_training_init(self.mask_decoder_cfg)
    if pretrained_path is None:
        return

    def get_w(weights, keyword):
        return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

    def change_w(weights, old_name, new_name):
        # Raises KeyError when old_name is absent, same as the two-step form.
        weights[new_name] = weights.pop(old_name)

    def to_tensor(value):
        # Copy existing tensors with clone() instead of torch.tensor(tensor),
        # which emits a copy-construct UserWarning; convert everything else.
        if torch.is_tensor(value):
            return value.detach().clone()
        return torch.tensor(value)

    if pretrained_path.endswith('.pkl'):
        # SECURITY NOTE: pickle executes arbitrary code while loading; only
        # point this at checkpoints from a trusted source.
        with open(pretrained_path, 'rb') as f:
            ckpt = pickle.load(f)
    else:
        # Load onto CPU so the checkpoint does not need the device it was
        # saved from; load_state_dict copies into the module parameters.
        ckpt = torch.load(pretrained_path, map_location='cpu')

    pixel_decoder_weights = {
        k: to_tensor(v) for k, v in get_w(ckpt['model'], 'sem_seg_head.pixel_decoder').items()}
    predictor_weights = {
        k: to_tensor(v) for k, v in get_w(ckpt['model'], 'sem_seg_head.predictor').items()}

    # Rename checkpoint keys to the names used by the pixel decoder here.
    pixel_decoder_renames = (
        ('adapter_1.weight', 'adapter_1.0.weight'),
        ('adapter_1.norm.weight', 'adapter_1.1.weight'),
        ('adapter_1.norm.bias', 'adapter_1.1.bias'),
        ('layer_1.weight', 'layer_1.0.weight'),
        ('layer_1.norm.weight', 'layer_1.1.weight'),
        ('layer_1.norm.bias', 'layer_1.1.bias'),
    )
    for old_name, new_name in pixel_decoder_renames:
        change_w(pixel_decoder_weights, old_name, new_name)

    if 'static_query.weight' in predictor_weights:
        change_w(predictor_weights, 'static_query.weight', 'query_feat.weight')
    # A 200-row query embedding is cut down to its first 100 rows.
    if predictor_weights['query_embed.weight'].shape[0] == 200:
        predictor_weights['query_embed.weight'] = predictor_weights['query_embed.weight'][:100, :]

    diff_pixel_msg = self.pixel_decoder.load_state_dict(pixel_decoder_weights, strict=False)
    diff_predictor_msg = self.predictor.load_state_dict(predictor_weights, strict=False)
    print(diff_predictor_msg)
    print(diff_pixel_msg)
|
| |
|
| |
|
def get_vision_tower_feature(self, images):
    """Run the vision tower and label its first four outputs ``res2``..``res5``."""
    tower_outputs = self.get_model().get_vision_tower()(images)
    stage_names = ('res2', 'res3', 'res4', 'res5')
    return {name: tower_outputs[position] for position, name in enumerate(stage_names)}
|
| |
|
def mask_decoder_training_init(self, cfg):
    """Create the matcher, loss weights, criterion and task flags from ``cfg``.

    Sets ``weight_dict``, ``criterion``, ``size_divisibility``, the five
    ``*_on`` task flags and ``sem_seg_postprocess_before_inference``.

    Raises:
        NotImplementedError: If ``cfg.MODEL.MASK_FORMER.SEG_TASK`` is not one
            of semantic / instance / panoptic / referring / region.
    """
    mask_former = cfg.MODEL.MASK_FORMER

    deep_supervision = mask_former.DEEP_SUPERVISION
    no_object_weight = mask_former.NO_OBJECT_WEIGHT  # read from cfg, not used below

    class_weight = mask_former.CLASS_WEIGHT
    dice_weight = mask_former.DICE_WEIGHT
    mask_weight = mask_former.MASK_WEIGHT

    matcher = hungarian_matcher_PSALM(
        cost_class=class_weight,
        cost_mask=mask_weight,
        cost_dice=dice_weight,
        num_points=mask_former.TRAIN_NUM_POINTS,
    )

    weight_dict = {
        "loss_SEG_class": class_weight,
        "loss_class_name_class": class_weight,
        "loss_mask": mask_weight,
        "loss_dice": dice_weight,
        "loss_region_class": class_weight,
    }
    # The attribute and the local name refer to the same dict, so the
    # auxiliary entries added below are visible through self.weight_dict too.
    self.weight_dict = weight_dict
    if deep_supervision:
        aux_weight_dict = {}
        for layer_idx in range(mask_former.DEC_LAYERS - 1):
            for loss_name, loss_weight in weight_dict.items():
                aux_weight_dict[loss_name + f"_{layer_idx}"] = loss_weight
        weight_dict.update(aux_weight_dict)

    self.criterion = PSALM_criterion(
        matcher=matcher,
        losses=["SEG_labels", "class_name_labels", "masks", "region_labels"],
        num_points=mask_former.TRAIN_NUM_POINTS,
        oversample_ratio=mask_former.OVERSAMPLE_RATIO,
        importance_sample_ratio=mask_former.IMPORTANCE_SAMPLE_RATIO,
        device=self.device
    )
    self.size_divisibility = 32

    # (semantic_on, instance_on, panoptic_on, referring_on, region_on) per task.
    flags_by_task = {
        'semantic': (True, False, False, False, False),
        'instance': (False, True, False, False, False),
        'panoptic': (True, True, True, False, False),
        'referring': (False, False, False, True, False),
        'region': (False, False, False, False, True),
    }
    seg_task = mask_former.SEG_TASK
    if seg_task not in flags_by_task:
        raise NotImplementedError
    (self.semantic_on, self.instance_on, self.panoptic_on,
     self.referring_on, self.region_on) = flags_by_task[seg_task]

    self.sem_seg_postprocess_before_inference = (
        self.instance_on or self.panoptic_on or self.referring_on or self.region_on)
|
def get_region_embedding(self, hidden_states, region_embedding_masks):
    """Select, per sample, the hidden states where the region mask is non-zero.

    Returns:
        A list with one tensor per sample holding the selected rows.
    """
    return [
        states[mask.bool()]
        for states, mask in zip(hidden_states, region_embedding_masks)
    ]
|
def SEG_instance_inference(self, SEG_cls, mask_pred):
    """Build an ``Instances`` result from the top-scoring SEG predictions.

    The ``self.test_topk_per_image`` highest sigmoid scores are kept (unsorted).
    Each kept score is multiplied by the mean mask probability inside the
    thresholded mask. ``pred_boxes`` is filled with zeros.
    """
    image_size = mask_pred.shape[-2:]

    flat_scores = F.sigmoid(SEG_cls).flatten(0, 1)
    scores_per_image, topk_indices = flat_scores.topk(self.test_topk_per_image, sorted=False)
    kept_masks = mask_pred[topk_indices]

    binary_masks = (kept_masks > 0).float()
    flat_binary = binary_masks.flatten(1)
    # Average mask probability over the pixels inside each binary mask.
    mask_scores_per_image = (kept_masks.sigmoid().flatten(1) * flat_binary).sum(1) / (
        flat_binary.sum(1) + 1e-6)

    result = Instances(image_size)
    result.pred_masks = binary_masks
    result.pred_boxes = Boxes(torch.zeros(kept_masks.size(0), 4))
    result.scores = scores_per_image * mask_scores_per_image
    return result
|
def class_name_panoptic_inference(self, SEG_cls, class_name_cls, mask_pred):
    """Merge per-query class and mask predictions into a panoptic map.

    ``SEG_cls`` is not read. Queries whose best class is the last index or
    whose softmax score is not above 0.8 are dropped. Every pixel goes to
    the remaining query with the highest ``score * mask probability``.
    Segments of the same non-thing class share a single segment id.

    Returns:
        ``(panoptic_seg, segments_info)``: an int32 id map and a list of
        dicts with ``id``, ``isthing`` and ``category_id``.
    """
    object_mask_threshold = 0.8
    overlap_threshold = 0.8

    scores, labels = F.softmax(class_name_cls, dim=-1).max(-1)
    last_class_index = class_name_cls.shape[-1] - 1
    mask_probs = mask_pred.sigmoid()

    keep = labels.ne(last_class_index) & (scores > object_mask_threshold)
    kept_scores = scores[keep]
    kept_classes = labels[keep]
    kept_masks = mask_probs[keep]

    weighted_masks = kept_scores.view(-1, 1, 1) * kept_masks

    height, width = kept_masks.shape[-2:]
    panoptic_seg = torch.zeros((height, width), dtype=torch.int32, device=kept_masks.device)
    segments_info = []

    # Nothing survived the filtering: return the empty map.
    if kept_masks.shape[0] == 0:
        return panoptic_seg, segments_info

    owner_ids = weighted_masks.argmax(0)
    stuff_segment_ids = {}
    current_segment_id = 0
    for k, class_tensor in enumerate(kept_classes):
        pred_class = class_tensor.item()
        isthing = self.is_thing_list[pred_class]

        owned = owner_ids == k
        confident = kept_masks[k] >= 0.5
        mask = owned & confident
        mask_area = owned.sum().item()
        original_area = confident.sum().item()

        if not (mask_area > 0 and original_area > 0 and mask.sum().item() > 0):
            continue
        if mask_area / original_area < overlap_threshold:
            continue

        if not isthing:
            class_key = int(pred_class)
            if class_key in stuff_segment_ids:
                # Reuse the id already given to this non-thing class.
                panoptic_seg[mask] = stuff_segment_ids[class_key]
                continue
            stuff_segment_ids[class_key] = current_segment_id + 1

        current_segment_id += 1
        panoptic_seg[mask] = current_segment_id
        segments_info.append(
            {
                "id": current_segment_id,
                "isthing": bool(isthing),
                "category_id": int(pred_class),
            }
        )

    return panoptic_seg, segments_info
|
def region_inference(self, region_cls, mask_pred):
    """Build an ``Instances`` result scoring every mask against every region.

    ``scores`` is the sigmoid of ``region_cls`` multiplied by each mask's
    mean in-mask probability, with its two axes swapped. ``pred_boxes`` is
    filled with zeros.
    """
    image_size = mask_pred.shape[-2:]
    region_scores = F.sigmoid(region_cls)

    binary_masks = (mask_pred > 0).float()
    flat_binary = binary_masks.flatten(1)
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * flat_binary).sum(1) / (
        flat_binary.sum(1) + 1e-6)
    # One copy of the per-mask quality for every row of region_scores.
    tiled_mask_scores = mask_scores_per_image[None, ...].repeat(region_scores.shape[0], 1)

    result = Instances(image_size)
    result.pred_masks = binary_masks
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
    result.scores = (region_scores * tiled_mask_scores).transpose(1, 0)
    return result
|
| |
|
def class_name_semantic_inference(self, SEG_cls, class_name_cls, mask_pred):
    """Combine class probabilities and mask probabilities into per-class maps.

    ``SEG_cls`` is not read. The last class column is dropped after the
    softmax; the result sums ``class prob * mask prob`` over queries.
    """
    class_probs = F.softmax(class_name_cls, dim=-1)
    class_probs = class_probs[:, :-1]
    mask_probs = mask_pred.sigmoid()
    return torch.einsum("qc,qhw->chw", class_probs, mask_probs)
|
def class_name_instance_inference(self, SEG_cls, class_name_cls, mask_pred):
    """Build an ``Instances`` result from the top (query, class) pairs.

    ``SEG_cls`` is not read. The last class column is dropped after the
    softmax, the ``self.test_topk_per_image`` best (query, class) scores are
    kept and, when ``self.panoptic_on`` is set, only classes flagged in
    ``self.is_thing_list`` remain. ``pred_boxes`` is filled with zeros.
    """
    image_size = mask_pred.shape[-2:]

    scores = F.softmax(class_name_cls, dim=-1)[:, :-1]
    num_classes = scores.shape[-1]

    # Class id of every flattened (query, class) position.
    class_ids = torch.arange(num_classes, device=self.device)
    labels = class_ids.unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1)

    scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
    labels_per_image = labels[topk_indices]

    # Map flattened positions back to their query index.
    query_indices = topk_indices // num_classes
    mask_pred = mask_pred[query_indices]

    if self.panoptic_on:
        keep = torch.zeros_like(scores_per_image).bool()
        for position, label in enumerate(labels_per_image):
            keep[position] = self.is_thing_list[label]
        scores_per_image = scores_per_image[keep]
        labels_per_image = labels_per_image[keep]
        mask_pred = mask_pred[keep]

    binary_masks = (mask_pred > 0).float()
    flat_binary = binary_masks.flatten(1)
    mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * flat_binary).sum(1) / (
        flat_binary.sum(1) + 1e-6)

    result = Instances(image_size)
    result.pred_masks = binary_masks
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
    result.scores = scores_per_image * mask_scores_per_image
    result.pred_classes = labels_per_image
    return result
|
def encode_images(self, images):
    """Project the last vision-tower output through ``mm_projector``."""
    tower_outputs = self.get_model().get_vision_tower()(images)
    last_output = tower_outputs[-1]
    return self.get_model().mm_projector(last_output)
|
| |
|
def predictor_init(self, cfg):
    """Instantiate the masked transformer decoder described by ``cfg``.

    ``enforce_input_project`` and ``seg_concat`` are fixed to ``False``, and
    the decoder is built with ``DEC_LAYERS - 1`` layers.
    """
    mask_former = cfg.MODEL.MASK_FORMER
    seg_head = cfg.MODEL.SEM_SEG_HEAD

    seg_norm = mask_former.SEG_NORM
    seg_proj = mask_former.SEG_PROJ
    seg_fuse_score = mask_former.FUSE_SCORE
    seg_concat = False

    # Positional arguments in the order the decoder constructor is called with.
    decoder_args = (
        seg_head.CONVS_DIM,            # in_channels
        mask_former.HIDDEN_DIM,        # hidden_dim
        mask_former.NUM_OBJECT_QUERIES,  # num_queries
        mask_former.NHEADS,            # nheads
        mask_former.DIM_FEEDFORWARD,   # dim_feedforward
        mask_former.DEC_LAYERS - 1,    # dec_layers
        mask_former.PRE_NORM,          # pre_norm
        seg_head.MASK_DIM,             # mask_dim
        False,                         # enforce_input_project
        seg_norm,
        seg_concat,
        seg_proj,
        seg_fuse_score,
    )
    print(f'current seg concat mode: {seg_concat}, seg_norm: {seg_norm}, seg_proj: {seg_proj}, seg_fuse_score: {seg_fuse_score}')
    return MultiScaleMaskedTransformerDecoderForOPTPreTrain(*decoder_args)
|
| |
|
| |
|
def get_model(self):
    """Return the object stored in ``self.model``."""
    wrapped_model = self.model
    return wrapped_model
|
def output_shape(self):
    """Describe channel count and stride of each configured Swin output.

    Returns:
        A dict mapping every name in ``MODEL.SWIN.OUT_FEATURES`` to a ``Dict``
        with ``channel`` and ``stride`` entries.
    """
    swin_cfg = self.mask_decoder_cfg.MODEL.SWIN
    out_features = swin_cfg.OUT_FEATURES
    strides = {"res2": 4, "res3": 8, "res4": 16, "res5": 32}

    # Channel width doubles at every stage, starting from EMBED_DIM.
    stage_widths = [int(swin_cfg.EMBED_DIM * 2 ** stage) for stage in range(len(swin_cfg.DEPTHS))]
    channels = {
        "res2": stage_widths[0],
        "res3": stage_widths[1],
        "res4": stage_widths[2],
        "res5": stage_widths[3],
    }

    return {
        name: Dict({'channel': channels[name], 'stride': strides[name]})
        for name in out_features
    }
|
| |
|
def get_encoder_image(self, images):
    """Return the raw vision-tower output for ``images``."""
    vision_tower = self.get_model().get_vision_tower()
    return vision_tower(images)
|
| |
|
def pixel_decoder_init(self, cfg, input_shape):
    """Instantiate the deformable-attention pixel decoder described by ``cfg``.

    The transformer feed-forward width is fixed to 1024.
    """
    mask_former = cfg.MODEL.MASK_FORMER
    seg_head = cfg.MODEL.SEM_SEG_HEAD

    common_stride = seg_head.COMMON_STRIDE
    # Positional arguments in the order the decoder constructor is called with.
    decoder_args = (
        input_shape,
        mask_former.DROPOUT,                 # transformer_dropout
        mask_former.NHEADS,                  # transformer_nheads
        1024,                                # transformer_dim_feedforward
        seg_head.TRANSFORMER_ENC_LAYERS,     # transformer_enc_layers
        seg_head.CONVS_DIM,                  # conv_dim
        seg_head.MASK_DIM,                   # mask_dim
        seg_head.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES,  # transformer_in_features
        common_stride,
    )
    return MSDeformAttnPixelDecoder(*decoder_args)
|
def prepare_targets(self, targets, images):
    """Zero-pad every target's masks to the spatial size of ``images``.

    Returns:
        One dict per target with ``labels`` (the target's ``gt_classes``) and
        ``masks`` (``gt_masks`` copied into the top-left corner of a zero
        tensor whose last two dimensions match ``images``).
    """
    h_pad, w_pad = images.shape[-2:]

    def pad_masks(gt_masks):
        # Same dtype/device as the source masks; the padded area stays zero.
        canvas = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device)
        canvas[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks
        return canvas

    return [
        {
            "labels": per_image.gt_classes,
            "masks": pad_masks(per_image.gt_masks),
        }
        for per_image in targets
    ]
|
| |
|
def get_special_token(self, SEG, EOS):
    """Store ``SEG`` and ``EOS`` on the instance as ``SEG_id`` and ``EOS_id``."""
    self.SEG_id, self.EOS_id = SEG, EOS
|
| |
|
def get_class_name_embedding(self, hidden_states, cls_token_indices):
    """Pool the hidden states of every class name into a single vector.

    For each sample, positions are grouped by their non-zero value in
    ``cls_token_indices``. Each group is reduced with
    ``self.class_name_pooling`` and the pooled vectors are concatenated.
    The per-sample results are stacked, so every sample must yield the same
    number of groups.
    """
    pooled_per_sample = []
    for sample_states, sample_indices in zip(hidden_states, cls_token_indices):
        present_ids = torch.unique(sample_indices)
        present_ids = present_ids[present_ids != 0]  # index 0 marks "no class"

        class_states = [sample_states[sample_indices == class_id] for class_id in present_ids]
        # Pooling runs over the token axis, hence the transpose in and out.
        pooled = [
            self.class_name_pooling(states.transpose(-2, -1)).transpose(-2, -1)
            for states in class_states
        ]
        pooled_per_sample.append(torch.cat(pooled, dim=0))
    return torch.stack(pooled_per_sample, dim=0)
|
def embed_class_ids(self, class_name_ids, cls_indices):
    """Embed the token ids of each class name separately.

    Returns ``None`` when ``class_name_ids`` is ``None``. Otherwise returns a
    list with one embedding per consecutive non-negative value found in
    ``cls_indices``.
    """
    if class_name_ids is None:
        return None

    class_markers = cls_indices.unique_consecutive()
    class_markers = class_markers[class_markers >= 0]  # negative values are skipped

    ids_per_class = [class_name_ids[cls_indices == marker] for marker in class_markers]
    return [self.get_model().embed_tokens(token_ids) for token_ids in ids_per_class]
|
| |
|
def embed_refer_ids(self, refer_ids):
    """Embed ``refer_ids`` with the model's token embedding; ``None`` passes through."""
    if refer_ids is not None:
        return self.get_model().embed_tokens(refer_ids)
    return None
|
| |
|
| |
|
| |
|
| | def concat_image_seg_cls_embeds(self, input_id, img_feature, label, seg_query, seg_query_mask, class_embed,
|
| | class_name_embedding_indices,region_embedding_mask=None, region_feature_list=None, refer_embedding_indices=None,
|
| | refer_embedding=None):
|
| | image_token_indices = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
|
| | seg_query_indices = torch.where(input_id == SEG_TOKEN_INDEX)[0]
|
| | cls_token_indices = torch.where(input_id == CLS_TOKEN_INDEX)[0]
|
| | region_token_indices = torch.where(input_id == REGION_TOKEN_INDEX)[0]
|
| | assert len(image_token_indices) == 1, 'not supporting multi image index'
|
| | assert len(seg_query_indices) == 1, 'not supporting multi seg index'
|
| | if class_name_embedding_indices is not None:
|
| | assert len(cls_token_indices) == len(class_embed), 'the number of <cls> tokens and class_embed needs to be same'
|
| | if region_feature_list is not None:
|
| | assert len(region_feature_list) == len(
|
| | region_token_indices), 'the munber of <region> tokens and regions needs to be same'
|
| | cur_new_input_embeds = []
|
| | cur_new_seg_query_mask = []
|
| | if label is not None:
|
| | cur_new_label = []
|
| | assert label.shape == input_id.shape
|
| | else:
|
| | cur_new_label = None
|
| | cur_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None
|
| | cur_refer_embedding_indices = [] if refer_embedding_indices is not None else None
|
| |
|
| | if region_embedding_mask is not None:
|
| | enable_region_mask = True
|
| | cur_new_region_embedding_mask = []
|
| | else:
|
| | enable_region_mask = False
|
| | cur_new_region_embedding_mask = None
|
| | chunks = []
|
| | current_chunk = []
|
| |
|
| | for id in input_id:
|
| | if id >= 0:
|
| | current_chunk.append(id.item())
|
| | else:
|
| | if current_chunk:
|
| | chunks.append(torch.tensor(current_chunk, device=input_id.device))
|
| | current_chunk = []
|
| | chunks.append([id])
|
| | if current_chunk:
|
| | chunks.append(torch.tensor(current_chunk, device=input_id.device))
|
| |
|
| | cls_idx = 0
|
| | region_idx = 0
|
| | for chunk in chunks:
|
| | chunk_len = len(chunk)
|
| | if chunk_len == 1 and chunk[0] == IMAGE_TOKEN_INDEX:
|
| | cur_new_input_embeds.append(img_feature)
|
| | cur_new_seg_query_mask.append(torch.zeros(img_feature.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((img_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((img_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((img_feature.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(img_feature.shape[0]))
|
| | elif chunk_len == 1 and chunk[0] == SEG_TOKEN_INDEX:
|
| | cur_new_input_embeds.append(seg_query)
|
| | cur_new_seg_query_mask.append(torch.ones(seg_query.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device,
|
| | dtype=label.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device,
|
| | dtype=label.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((seg_query.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype))
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(seg_query.shape[0]))
|
| | elif chunk_len == 1 and chunk[0] == CLS_TOKEN_INDEX:
|
| | cls_embed = class_embed[cls_idx]
|
| | if len(cls_embed.shape) == 1:
|
| | cls_embed = cls_embed.unsqueeze(0)
|
| | cls_idx += 1
|
| | cur_new_input_embeds.append(cls_embed)
|
| | cur_new_seg_query_mask.append(torch.zeros(cls_embed.shape[0]))
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(cls_embed.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((cls_embed.shape[0],), cls_idx, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((cls_embed.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((cls_embed.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | elif chunk_len == 1 and chunk[0] == REGION_TOKEN_INDEX:
|
| | region_feature = region_feature_list[region_idx]
|
| | region_idx += 1
|
| | cur_new_input_embeds.append(region_feature)
|
| | cur_new_seg_query_mask.append(torch.zeros(region_feature.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((region_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((region_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((region_feature.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.ones(region_feature.shape[0]))
|
| | elif chunk_len == 1 and chunk[0] == REFER_TOKEN_INDEX:
|
| | refer_embed = refer_embedding
|
| | if len(refer_embed.shape) == 1:
|
| | refer_embed = refer_embed.unsqueeze(0)
|
| | cur_new_input_embeds.append(refer_embed)
|
| | cur_new_seg_query_mask.append(torch.zeros(refer_embed.shape[0]))
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(refer_embed.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((refer_embed.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((refer_embed.shape[0],), 1, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((refer_embed.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | else:
|
| | cur_new_input_embeds.append(self.get_model().embed_tokens(input_id[:chunk_len]))
|
| | cur_new_seg_query_mask.append(seg_query_mask[:chunk_len])
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(class_name_embedding_indices[:chunk_len])
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(refer_embedding_indices[:chunk_len])
|
| | if label is not None:
|
| | cur_new_label.append(label[:chunk_len])
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(region_embedding_mask[:chunk_len])
|
| |
|
| | input_id = input_id[chunk_len:]
|
| | seg_query_mask = seg_query_mask[chunk_len:]
|
| | if class_name_embedding_indices is not None:
|
| | class_name_embedding_indices = class_name_embedding_indices[chunk_len:]
|
| | if refer_embedding_indices is not None:
|
| | refer_embedding_indices = refer_embedding_indices[chunk_len:]
|
| | if label is not None:
|
| | label = label[chunk_len:]
|
| | if enable_region_mask:
|
| | region_embedding_mask = region_embedding_mask[chunk_len:]
|
| |
|
| | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
|
| | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
| | if label is not None:
|
| | cur_new_label = [x.to(device=self.device) for x in cur_new_label]
|
| | cur_new_label = torch.cat(cur_new_label, dim=0)
|
| | cur_new_seg_query_mask = [x.to(device=self.device) for x in cur_new_seg_query_mask]
|
| | cur_new_seg_query_mask = torch.cat(cur_new_seg_query_mask, dim=0)
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices = [x.to(device=self.device) for x in cur_class_name_embedding_indices]
|
| | cur_class_name_embedding_indices = torch.cat(cur_class_name_embedding_indices, dim=0)
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices = [x.to(device=self.device) for x in cur_refer_embedding_indices]
|
| | cur_refer_embedding_indices = torch.cat(cur_refer_embedding_indices, dim=0)
|
| |
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask = [x.to(device=self.device) for x in cur_new_region_embedding_mask]
|
| | cur_new_region_embedding_mask = torch.cat(cur_new_region_embedding_mask, dim=0)
|
| |
|
| | return cur_new_input_embeds, cur_new_label, cur_new_seg_query_mask, cur_class_name_embedding_indices, cur_new_region_embedding_mask, cur_refer_embedding_indices
|
| |
|
| |
|
| |
|
| |
|
| | def concat_image_seg_cls_embeds_SSL(self, input_id, img_feature, label, seg_query, seg_query_mask, class_embed,
|
| | class_name_embedding_indices,region_embedding_mask=None, region_embedding_mask_exo=None, region_feature_list=None, region_feature_list_exo=None, refer_embedding_indices=None,
|
| | refer_embedding=None):
|
| | image_token_indices = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
|
| | seg_query_indices = torch.where(input_id == SEG_TOKEN_INDEX)[0]
|
| | cls_token_indices = torch.where(input_id == CLS_TOKEN_INDEX)[0]
|
| | region_token_indices = torch.where(input_id == REGION_TOKEN_INDEX)[0]
|
| | assert len(image_token_indices) == 1, 'not supporting multi image index'
|
| | assert len(seg_query_indices) == 1, 'not supporting multi seg index'
|
| | if class_name_embedding_indices is not None:
|
| | assert len(cls_token_indices) == len(class_embed), 'the number of <cls> tokens and class_embed needs to be same'
|
| | if region_feature_list is not None:
|
| | assert len(region_feature_list) == len(
|
| | region_token_indices), 'the munber of <region> tokens and regions needs to be same'
|
| | cur_new_input_embeds = []
|
| | cur_new_seg_query_mask = []
|
| | if label is not None:
|
| | cur_new_label = []
|
| | assert label.shape == input_id.shape
|
| | else:
|
| | cur_new_label = None
|
| | cur_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None
|
| | cur_refer_embedding_indices = [] if refer_embedding_indices is not None else None
|
| |
|
| | if region_embedding_mask is not None:
|
| | enable_region_mask = True
|
| | cur_new_region_embedding_mask = []
|
| | cur_new_region_embedding_mask_exo = []
|
| | else:
|
| | enable_region_mask = False
|
| | cur_new_region_embedding_mask = None
|
| | cur_new_region_embedding_mask_exo = None
|
| | chunks = []
|
| | current_chunk = []
|
| |
|
| | for id in input_id:
|
| | if id >= 0:
|
| | current_chunk.append(id.item())
|
| | else:
|
| | if current_chunk:
|
| | chunks.append(torch.tensor(current_chunk, device=input_id.device))
|
| | current_chunk = []
|
| | chunks.append([id])
|
| | if current_chunk:
|
| | chunks.append(torch.tensor(current_chunk, device=input_id.device))
|
| |
|
| | cls_idx = 0
|
| | region_idx = 0
|
| | for chunk in chunks:
|
| | chunk_len = len(chunk)
|
| | if chunk_len == 1 and chunk[0] == IMAGE_TOKEN_INDEX:
|
| | cur_new_input_embeds.append(img_feature)
|
| | cur_new_seg_query_mask.append(torch.zeros(img_feature.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((img_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((img_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((img_feature.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(img_feature.shape[0]))
|
| | cur_new_region_embedding_mask_exo.append(torch.zeros(img_feature.shape[0]))
|
| | elif chunk_len == 1 and chunk[0] == SEG_TOKEN_INDEX:
|
| | cur_new_input_embeds.append(seg_query)
|
| | cur_new_seg_query_mask.append(torch.ones(seg_query.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device,
|
| | dtype=label.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device,
|
| | dtype=label.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((seg_query.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype))
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(seg_query.shape[0]))
|
| | cur_new_region_embedding_mask_exo.append(torch.zeros(seg_query.shape[0]))
|
| | elif chunk_len == 1 and chunk[0] == CLS_TOKEN_INDEX:
|
| | cls_embed = class_embed[cls_idx]
|
| | if len(cls_embed.shape) == 1:
|
| | cls_embed = cls_embed.unsqueeze(0)
|
| | cls_idx += 1
|
| | cur_new_input_embeds.append(cls_embed)
|
| | cur_new_seg_query_mask.append(torch.zeros(cls_embed.shape[0]))
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(cls_embed.shape[0]))
|
| | cur_new_region_embedding_mask_exo.append(torch.zeros(cls_embed.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((cls_embed.shape[0],), cls_idx, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((cls_embed.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((cls_embed.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | elif chunk_len == 1 and chunk[0] == REGION_TOKEN_INDEX:
|
| | region_feature = region_feature_list[region_idx]
|
| | region_feature_exo = region_feature_list_exo[region_idx]
|
| | region_idx += 1
|
| | cur_new_input_embeds.append(region_feature)
|
| | cur_new_seg_query_mask.append(torch.zeros(region_feature.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((region_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((region_feature.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((region_feature.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.ones(region_feature.shape[0]))
|
| |
|
| | cur_new_region_embedding_mask_exo.append(torch.ones(region_feature_exo.shape[0]))
|
| | elif chunk_len == 1 and chunk[0] == REFER_TOKEN_INDEX:
|
| | refer_embed = refer_embedding
|
| | if len(refer_embed.shape) == 1:
|
| | refer_embed = refer_embed.unsqueeze(0)
|
| | cur_new_input_embeds.append(refer_embed)
|
| | cur_new_seg_query_mask.append(torch.zeros(refer_embed.shape[0]))
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(torch.zeros(refer_embed.shape[0]))
|
| | cur_new_region_embedding_mask_exo.append(torch.zeros(refer_embed.shape[0]))
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(
|
| | torch.full((refer_embed.shape[0],), 0, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(
|
| | torch.full((refer_embed.shape[0],), 1, device=input_id.device,
|
| | dtype=input_id.dtype))
|
| | if label is not None:
|
| | cur_new_label.append(
|
| | torch.full((refer_embed.shape[0],), IGNORE_INDEX, device=label.device,
|
| | dtype=label.dtype)
|
| | )
|
| | else:
|
| | cur_new_input_embeds.append(self.get_model().embed_tokens(input_id[:chunk_len]))
|
| | cur_new_seg_query_mask.append(seg_query_mask[:chunk_len])
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices.append(class_name_embedding_indices[:chunk_len])
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices.append(refer_embedding_indices[:chunk_len])
|
| | if label is not None:
|
| | cur_new_label.append(label[:chunk_len])
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask.append(region_embedding_mask[:chunk_len])
|
| | cur_new_region_embedding_mask_exo.append(region_embedding_mask_exo[:chunk_len])
|
| |
|
| | input_id = input_id[chunk_len:]
|
| | seg_query_mask = seg_query_mask[chunk_len:]
|
| | if class_name_embedding_indices is not None:
|
| | class_name_embedding_indices = class_name_embedding_indices[chunk_len:]
|
| | if refer_embedding_indices is not None:
|
| | refer_embedding_indices = refer_embedding_indices[chunk_len:]
|
| | if label is not None:
|
| | label = label[chunk_len:]
|
| | if enable_region_mask:
|
| | region_embedding_mask = region_embedding_mask[chunk_len:]
|
| | region_embedding_mask_exo = region_embedding_mask_exo[chunk_len:]
|
| |
|
| | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
|
| | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
|
| | if label is not None:
|
| | cur_new_label = [x.to(device=self.device) for x in cur_new_label]
|
| | cur_new_label = torch.cat(cur_new_label, dim=0)
|
| | cur_new_seg_query_mask = [x.to(device=self.device) for x in cur_new_seg_query_mask]
|
| | cur_new_seg_query_mask = torch.cat(cur_new_seg_query_mask, dim=0)
|
| | if class_name_embedding_indices is not None:
|
| | cur_class_name_embedding_indices = [x.to(device=self.device) for x in cur_class_name_embedding_indices]
|
| | cur_class_name_embedding_indices = torch.cat(cur_class_name_embedding_indices, dim=0)
|
| | if refer_embedding_indices is not None:
|
| | cur_refer_embedding_indices = [x.to(device=self.device) for x in cur_refer_embedding_indices]
|
| | cur_refer_embedding_indices = torch.cat(cur_refer_embedding_indices, dim=0)
|
| |
|
| | if enable_region_mask:
|
| | cur_new_region_embedding_mask = [x.to(device=self.device) for x in cur_new_region_embedding_mask]
|
| | cur_new_region_embedding_mask = torch.cat(cur_new_region_embedding_mask, dim=0)
|
| |
|
| | cur_new_region_embedding_mask_exo = [x.to(device=self.device) for x in cur_new_region_embedding_mask_exo]
|
| | cur_new_region_embedding_mask_exo = torch.cat(cur_new_region_embedding_mask_exo, dim=0)
|
| |
|
| | return cur_new_input_embeds, cur_new_label, cur_new_seg_query_mask, cur_class_name_embedding_indices, cur_new_region_embedding_mask, cur_new_region_embedding_mask_exo, cur_refer_embedding_indices
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | def prepare_inputs_labels_for_multimodal_SSL(
|
| | self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None,
|
| | class_name_embedding_indices=None,
|
| | class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None
|
| | ):
|
| | vision_tower = self.get_vision_tower()
|
| |
|
| | seg_query_mask = torch.zeros_like(input_ids)
|
| |
|
| | if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
| | if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
|
| | 1] == 1:
|
| | attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
|
| | dtype=attention_mask.dtype, device=attention_mask.device)
|
| | return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask
|
| |
|
| | if type(images) is list or images.ndim == 5:
|
| | concat_images = torch.cat([image for image in images], dim=0)
|
| | image_features = self.encode_images(concat_images)
|
| | split_sizes = [image.shape[0] for image in images]
|
| | image_features = torch.split(image_features, split_sizes, dim=0)
|
| | image_features = [x.flatten(0, 1) for x in image_features]
|
| | else:
|
| | image_features = self.encode_images(images)
|
| |
|
| |
|
| | expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1)
|
| |
|
| |
|
| | if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None:
|
| | region_masks_list = [instance.vp_region_masks.tensor for instance in instances]
|
| |
|
| | vp_image_features = self.encode_images(vp_images)
|
| |
|
| |
|
| | region_features = self.region_sampler(vp_image_features, region_masks_list,
|
| | original_dtype=vp_image_features.dtype,
|
| | return_dtype=vp_image_features.dtype)
|
| | region_embedding_masks = torch.zeros_like(input_ids)
|
| | print('llava_phi_SSL: PSALM')
|
| |
|
| |
|
| | for instance in instances:
|
| | print('llava_phi_SSL: instance.vp_region_masks:', instance.vp_region_masks.tensor.shape)
|
| | print('llava_phi_SSL: instance.gt_masks:', instance.gt_masks.shape)
|
| |
|
| |
|
| | print('llava_phi_SSL: vp_image_fea:', vp_image_features.shape)
|
| | print('llava_phi_SSL: image_fea:', image_features.shape)
|
| |
|
| |
|
| | region_masks_list_exo = [instance.gt_masks for instance in instances]
|
| |
|
| | image_features_exo = image_features.detach().clone().requires_grad_(True)
|
| | region_features_exo = self.region_sampler(image_features_exo, region_masks_list_exo,
|
| | original_dtype=image_features.dtype,
|
| | return_dtype=image_features.dtype)
|
| |
|
| | print('region_features:', len(region_features))
|
| | for rf in region_features:
|
| | print(rf.shape)
|
| | print('region_features exo:', len(region_features_exo))
|
| | for rf in region_features_exo:
|
| | print(rf.shape)
|
| | print(region_features[0]==region_features_exo[0], region_features[1] == region_features_exo[1], region_features[2]==region_features_exo[2], region_features[3]==region_features_exo[3])
|
| |
|
| | region_embedding_masks_exo = torch.zeros_like(input_ids)
|
| | print('input_ids:', input_ids)
|
| | print('region_embedding_masks', region_embedding_masks.shape, region_embedding_masks)
|
| | print('region_embedding_masks_exo', region_embedding_masks_exo.shape, region_embedding_masks_exo)
|
| |
|
| |
|
| |
|
| |
|
| | else:
|
| | region_features = None
|
| | region_embedding_masks = None
|
| |
|
| | region_features_exo = None
|
| | region_embedding_masks_exo = None
|
| |
|
| | new_input_embeds = []
|
| | new_input_embeds_exo = []
|
| | new_labels = [] if labels is not None else None
|
| | new_seg_query_masks = []
|
| | new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None
|
| | new_refer_embedding_indices = [] if refer_embedding_indices is not None else None
|
| | new_region_embedding_masks = [] if region_features is not None else None
|
| |
|
| |
|
| | new_region_embedding_masks_exo = [] if region_features_exo is not None else None
|
| |
|
| | for batch_idx, cur_input_ids in enumerate(input_ids):
|
| | cur_seg_query_mask = seg_query_mask[batch_idx]
|
| | cur_seg_query = expanded_seg_query[batch_idx]
|
| | cur_image_feature = image_features[batch_idx]
|
| | cur_class_name_embedding_indices = class_name_embedding_indices[
|
| | batch_idx] if class_name_embedding_indices is not None else None
|
| | cur_refer_embedding_indices = refer_embedding_indices[
|
| | batch_idx] if refer_embedding_indices is not None else None
|
| | cur_region_feature_list = region_features[batch_idx] if region_features is not None else None
|
| | cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None
|
| |
|
| | cur_region_feature_list_exo = region_features_exo[batch_idx] if region_features_exo is not None else None
|
| | cur_region_embedding_mask_exo = region_embedding_masks_exo[batch_idx] if region_features_exo is not None else None
|
| |
|
| | if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
|
| |
|
| | cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
|
| |
|
| | cur_input_embeds = cur_input_embeds + (
|
| | 0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
|
| | new_input_embeds.append(cur_input_embeds)
|
| | new_input_embeds_exo.append(cur_input_embeds.detach().clone())
|
| | if labels is not None:
|
| | new_labels.append(labels[batch_idx])
|
| | new_seg_query_masks.append(cur_seg_query_mask)
|
| |
|
| | continue
|
| |
|
| | if labels is not None:
|
| | cur_label = labels[batch_idx]
|
| | else:
|
| | cur_label = None
|
| |
|
| | if class_name_ids is not None:
|
| | cur_class_name_ids = class_name_ids[batch_idx]
|
| | cur_cls_indices = cls_indices[batch_idx]
|
| | else:
|
| | cur_class_name_ids = None
|
| | cur_cls_indices = None
|
| | if token_refer_id is not None:
|
| | cur_token_refer_id = token_refer_id[batch_idx]
|
| | else:
|
| | cur_token_refer_id = None
|
| |
|
| | cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices)
|
| | cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id)
|
| |
|
| |
|
| |
|
| | '''
|
| | cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds(
|
| | input_id=cur_input_ids,
|
| | img_feature=cur_image_feature,
|
| | label=cur_label,
|
| | seg_query=cur_seg_query,
|
| | seg_query_mask=cur_seg_query_mask,
|
| | class_embed=cur_class_name_embedding,
|
| | class_name_embedding_indices=cur_class_name_embedding_indices,
|
| | region_embedding_mask=cur_region_embedding_mask,
|
| | region_feature_list=cur_region_feature_list,
|
| | refer_embedding_indices=cur_refer_embedding_indices,
|
| | refer_embedding=cur_refer_embedding
|
| | )
|
| | '''
|
| | print('llava_phi_SSL: PSALM_SSL, prepare_inputs_for_multimodaal_SSL:', 'cur_region_embedding_mask', cur_region_embedding_mask.shape, cur_region_embedding_mask)
|
| | print('llava_phi_SSL: PSALM_SSL, prepare_inputs_for_multimodaal_SSL:', 'cur_region_feature_list', len(cur_region_feature_list), cur_region_feature_list[0].shape, cur_region_feature_list[0])
|
| | print('llava_phi_SSL: PSALM_SSL, prepare_inputs_for_multimodaal_SSL:', 'cur_region_embedding_mask_exo', cur_region_embedding_mask_exo.shape, cur_region_embedding_mask_exo)
|
| | print('llava_phi_SSL: PSALM_SSL, prepare_inputs_for_multimodaal_SSL:', 'cur_region_feature_list_exo', len(cur_region_feature_list_exo), cur_region_feature_list_exo[0].shape, cur_region_feature_list_exo[0])
|
| |
|
| |
|
| |
|
| | '''
|
| | cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_region_embedding_mask_exo, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds_SSL(
|
| | input_id=cur_input_ids,
|
| | img_feature=cur_image_feature,
|
| | label=cur_label,
|
| | seg_query=cur_seg_query,
|
| | seg_query_mask=cur_seg_query_mask,
|
| | class_embed=cur_class_name_embedding,
|
| | class_name_embedding_indices=cur_class_name_embedding_indices,
|
| | region_embedding_mask=cur_region_embedding_mask,
|
| | region_embedding_mask_exo=cur_region_embedding_mask_exo,
|
| | region_feature_list=cur_region_feature_list,
|
| | region_feature_list_exo=cur_region_feature_list_exo,
|
| | refer_embedding_indices=cur_refer_embedding_indices,
|
| | refer_embedding=cur_refer_embedding
|
| | )
|
| | '''
|
| |
|
| |
|
| | Init_cur_input_ids = cur_input_ids.clone()
|
| | Init_cur_image_feature = cur_image_feature.clone()
|
| | Init_cur_label = cur_label.clone()
|
| | Init_cur_seg_query = cur_seg_query.clone()
|
| | Init_cur_seg_query_mask = cur_seg_query_mask.clone()
|
| | if cur_class_name_embedding is not None:
|
| | Init_cur_class_name_embedding = cur_class_name_embedding.clone()
|
| | else:
|
| | Init_cur_class_name_embedding = cur_class_name_embedding
|
| | if cur_class_name_embedding_indices is not None:
|
| | Init_cur_class_name_embedding_indices = cur_class_name_embedding_indices.clone()
|
| | else:
|
| | Init_cur_class_name_embedding_indices = cur_class_name_embedding_indices
|
| | if cur_refer_embedding_indices is not None:
|
| | Init_cur_refer_embedding_indices = cur_refer_embedding_indices.clone()
|
| | else:
|
| | Init_cur_refer_embedding_indices = cur_refer_embedding_indices
|
| | if cur_refer_embedding is not None:
|
| | Init_cur_refer_embedding = cur_refer_embedding.clone()
|
| | else:
|
| | Init_cur_refer_embedding = cur_refer_embedding
|
| |
|
| | cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds(
|
| | input_id=cur_input_ids,
|
| | img_feature=cur_image_feature,
|
| | label=cur_label,
|
| | seg_query=cur_seg_query,
|
| | seg_query_mask=cur_seg_query_mask,
|
| | class_embed=cur_class_name_embedding,
|
| | class_name_embedding_indices=cur_class_name_embedding_indices,
|
| | region_embedding_mask=cur_region_embedding_mask,
|
| | region_feature_list=cur_region_feature_list,
|
| | refer_embedding_indices=cur_refer_embedding_indices,
|
| | refer_embedding=cur_refer_embedding
|
| | )
|
| |
|
| | cur_input_embeds_exo, cur_label_exo, cur_seg_query_mask_exo, cur_class_name_embedding_indices_exo, cur_region_embedding_mask_exo, cur_refer_embedding_indices_exo = self.concat_image_seg_cls_embeds(
|
| | input_id=Init_cur_input_ids,
|
| | img_feature=Init_cur_image_feature,
|
| | label=Init_cur_label,
|
| | seg_query=Init_cur_seg_query,
|
| | seg_query_mask=Init_cur_seg_query_mask,
|
| | class_embed=Init_cur_class_name_embedding,
|
| | class_name_embedding_indices=Init_cur_class_name_embedding_indices,
|
| | region_embedding_mask=cur_region_embedding_mask_exo,
|
| | region_feature_list=cur_region_feature_list_exo,
|
| | refer_embedding_indices=Init_cur_refer_embedding_indices,
|
| | refer_embedding=Init_cur_refer_embedding
|
| | )
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0]
|
| |
|
| | print('compare data after teh concat: cur_input_embeds', torch.mean(cur_input_embeds_exo-cur_input_embeds))
|
| | print('compare data after teh concat: cur_label', torch.mean((cur_label-cur_label_exo).float()))
|
| | print('compare data after teh concat: cur_seg_query_mask', torch.mean((cur_seg_query_mask - cur_seg_query_mask_exo).float()))
|
| | print('compare data after teh concat: cur_class_name_embedding_indices', cur_class_name_embedding_indices == cur_class_name_embedding_indices_exo)
|
| | print('compare data after teh concat: cur_region_embedding_mask', torch.mean((cur_region_embedding_mask - cur_region_embedding_mask_exo).float()))
|
| | print('compare data after teh concat: cur_refer_embedding_indices', cur_refer_embedding_indices == cur_refer_embedding_indices_exo)
|
| |
|
| |
|
| | '''
|
| | # for exo
|
| | new_input_embeds_exo = new_input_embeds #BUG: the new_input_embeds will be changed beacuse of new_input_embeds_exo
|
| | '''
|
| |
|
| |
|
| |
|
| |
|
| | new_input_embeds.append(cur_input_embeds)
|
| | print('HERE:', len(new_input_embeds), len(new_input_embeds_exo))
|
| | new_input_embeds_exo.append(cur_input_embeds_exo)
|
| | print('HERE 2:', len(new_input_embeds), len(new_input_embeds_exo))
|
| |
|
| | if labels is not None:
|
| | new_labels.append(cur_label)
|
| | new_seg_query_masks.append(cur_seg_query_mask)
|
| | if class_name_embedding_indices is not None:
|
| | new_class_name_embedding_indices.append(cur_class_name_embedding_indices)
|
| | if refer_embedding_indices is not None:
|
| | new_refer_embedding_indices.append(cur_refer_embedding_indices)
|
| | if new_region_embedding_masks is not None:
|
| | new_region_embedding_masks.append(cur_region_embedding_mask)
|
| | new_region_embedding_masks_exo.append(cur_region_embedding_mask_exo)
|
| |
|
| |
|
| | if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
|
| | max_len = max(x.shape[0] for x in new_input_embeds)
|
| |
|
| | new_input_embeds_align = []
|
| | for cur_new_embed in new_input_embeds:
|
| | cur_new_embed = torch.cat((cur_new_embed,
|
| | torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
|
| | dtype=cur_new_embed.dtype, device=cur_new_embed.device)),
|
| | dim=0)
|
| | new_input_embeds_align.append(cur_new_embed)
|
| | new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
|
| |
|
| |
|
| | new_input_embeds_align_exo = []
|
| | for cur_new_embed in new_input_embeds_exo:
|
| | cur_new_embed = torch.cat((cur_new_embed,
|
| | torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
|
| | dtype=cur_new_embed.dtype, device=cur_new_embed.device)),
|
| | dim=0)
|
| | new_input_embeds_align_exo.append(cur_new_embed)
|
| | new_input_embeds_exo = torch.stack(new_input_embeds_align_exo, dim=0)
|
| |
|
| |
|
| | if labels is not None:
|
| | new_labels_align = []
|
| | _new_labels = new_labels
|
| | for cur_new_label in new_labels:
|
| | cur_new_label = torch.cat((cur_new_label,
|
| | torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
|
| | dtype=cur_new_label.dtype, device=cur_new_label.device)),
|
| | dim=0)
|
| | new_labels_align.append(cur_new_label)
|
| | new_labels = torch.stack(new_labels_align, dim=0)
|
| |
|
| | new_seg_query_masks_align = []
|
| | for new_seg_query_mask in new_seg_query_masks:
|
| | new_seg_query_mask = torch.cat(
|
| | (new_seg_query_mask,
|
| | torch.zeros((max_len - new_seg_query_mask.shape[0]), dtype=new_seg_query_mask.dtype,
|
| | device=new_seg_query_mask.device)),
|
| | dim=0)
|
| | new_seg_query_masks_align.append(new_seg_query_mask)
|
| | new_seg_query_masks = torch.stack(new_seg_query_masks_align, dim=0)
|
| |
|
| | new_class_name_embedding_indices_align = []
|
| |
|
| | if class_name_embedding_indices is not None:
|
| | for new_class_name_embedding_indice in new_class_name_embedding_indices:
|
| | new_class_name_embedding_indice = torch.cat(
|
| | (new_class_name_embedding_indice,
|
| | torch.zeros((max_len - new_class_name_embedding_indice.shape[0]),
|
| | dtype=new_class_name_embedding_indice.dtype,
|
| | device=new_class_name_embedding_indice.device)),
|
| | dim=0)
|
| | new_class_name_embedding_indices_align.append(new_class_name_embedding_indice)
|
| | new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices_align, dim=0)
|
| |
|
| | if refer_embedding_indices is not None:
|
| | new_refer_embedding_indices_align = []
|
| | for new_refer_embedding_indice in new_refer_embedding_indices:
|
| | new_refer_embedding_indice = torch.cat(
|
| | (new_refer_embedding_indice,
|
| | torch.zeros((max_len - new_refer_embedding_indice.shape[0]),
|
| | dtype=new_refer_embedding_indice.dtype,
|
| | device=new_refer_embedding_indice.device)),
|
| | dim=0)
|
| | new_refer_embedding_indices_align.append(new_refer_embedding_indice)
|
| | new_refer_embedding_indices = torch.stack(new_refer_embedding_indices_align, dim=0)
|
| |
|
| | if new_region_embedding_masks is not None:
|
| | new_region_embedding_masks_align = []
|
| | for new_region_embedding_mask in new_region_embedding_masks:
|
| | new_region_embedding_mask = torch.cat(
|
| | (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]),
|
| | dtype=new_region_embedding_mask.dtype,
|
| | device=new_region_embedding_mask.device)),
|
| | dim=0)
|
| | new_region_embedding_masks_align.append(new_region_embedding_mask)
|
| | new_region_embedding_masks = torch.stack(new_region_embedding_masks_align, dim=0)
|
| |
|
| |
|
| | new_region_embedding_masks_align_exo = []
|
| | for new_region_embedding_mask in new_region_embedding_masks_exo:
|
| | new_region_embedding_mask = torch.cat(
|
| | (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]),
|
| | dtype=new_region_embedding_mask.dtype,
|
| | device=new_region_embedding_mask.device)),
|
| | dim=0)
|
| | new_region_embedding_masks_align_exo.append(new_region_embedding_mask)
|
| | new_region_embedding_masks_exo = torch.stack(new_region_embedding_masks_align_exo, dim=0)
|
| |
|
| |
|
| | if attention_mask is not None:
|
| | new_attention_mask = []
|
| | for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels,
|
| | new_labels):
|
| | new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True,
|
| | dtype=attention_mask.dtype, device=attention_mask.device)
|
| | new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],),
|
| | False, dtype=attention_mask.dtype,
|
| | device=attention_mask.device)
|
| | cur_new_attention_mask = torch.cat(
|
| | (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
|
| | new_attention_mask.append(cur_new_attention_mask)
|
| | attention_mask = torch.stack(new_attention_mask, dim=0)
|
| | assert attention_mask.shape == new_labels.shape
|
| |
|
| | else:
|
| | new_input_embeds = torch.stack(new_input_embeds, dim=0)
|
| | new_input_embeds_exo = torch.stack(new_input_embeds_exo, dim=0)
|
| |
|
| | if labels is not None:
|
| | new_labels = torch.stack(new_labels, dim=0)
|
| |
|
| | new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0)
|
| | if class_name_embedding_indices is not None:
|
| | new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0)
|
| | if refer_embedding_indices is not None:
|
| | new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0)
|
| |
|
| | if new_region_embedding_masks is not None:
|
| | new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0)
|
| | new_region_embedding_masks_exo = torch.stack(new_region_embedding_masks_exo, dim=0)
|
| |
|
| | if attention_mask is not None:
|
| | new_attn_mask_pad_left = torch.full(
|
| | (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
|
| | dtype=attention_mask.dtype, device=attention_mask.device)
|
| | attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
|
| | print('attention_mask:', attention_mask.shape)
|
| | print('new_input_embeds:', new_input_embeds.shape)
|
| | assert attention_mask.shape == new_input_embeds.shape[:2]
|
| |
|
| |
|
| |
|
| | return None, attention_mask, past_key_values, new_input_embeds, new_input_embeds_exo, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_region_embedding_masks_exo, new_refer_embedding_indices
|
| |
|
| |
|
def get_SEG_embedding(self, hidden_states, refer_embedding_indices):
    """Pool the hidden states flagged as referring tokens, one result per sample.

    Args:
        hidden_states: Per-sample hidden states, iterated along the first
            dimension and indexed with a boolean mask along the next one.
        refer_embedding_indices: Per-sample indicator tensors; ``.bool()`` is
            applied to them to select the rows of the matching hidden state.

    Returns:
        The per-sample pooled states stacked along a new leading dimension.
    """
    pooled_per_sample = []
    for sample_states, sample_flags in zip(hidden_states, refer_embedding_indices):
        selected_states = sample_states[sample_flags.bool()]
        # The pooling layer is applied with the token axis last, then the axes
        # are swapped back (presumably an AdaptiveAvgPool1d -- confirm against
        # the owning class' __init__).
        tokens_last = selected_states.transpose(-2, -1)
        pooled = self.refer_pooling(tokens_last)
        pooled_per_sample.append(pooled.transpose(-2, -1))
    return torch.stack(pooled_per_sample, dim=0)
|
| |
|
| |
|
| |
|
def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        vp_images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        seg_info=None,
        class_name_ids=None,
        class_name_embedding_indices=None,
        cls_indices=None,
        random_idx=None,
        token_refer_id=None,
        refer_embedding_indices=None,
        dataset_type=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the language model and compute either the LLM loss or the mask losses.

    When ``input_ids`` contains ``SEG_TOKEN_INDEX`` the inputs are expanded by
    ``prepare_inputs_labels_for_multimodal_SSL``, which also yields a second
    ("exo") set of input embeddings and region masks; the model is then run on
    both sets and ``calculate_region_embedding_dis`` is applied to the two
    projected region-embedding lists. Otherwise the inputs are expanded by
    ``mm_conv_prepare_inputs_labels_for_multimodal`` and no exo pass is run.

    Args:
        dataset_type: Optional per-sample dataset tags; all entries must be
            equal. The first entry decides which embeddings are dropped
            ('referring' / 'region' drop the class-name embedding,
            'panoptic' / 'region' drop the SEG embedding).
        random_idx: Index tensor used to gather the class-name embeddings
            along dim 1.
        Remaining arguments are forwarded to the input-preparation helpers
        and to ``self.model``.

    Returns:
        * ``CausalOutputWithMaskSSL`` when segmentation queries are present
          and ``seg_info`` is given (``loss`` is the summed mask loss if
          ``labels`` is given, else ``None``).
        * ``CausalOutputWithMask`` carrying the LLM cross-entropy loss when
          ``labels`` is given and there are no segmentation queries.
        * ``CausalOutputWithMask`` with only loss/logits/cache fields in all
          other cases.
    """
    if dataset_type is not None:
        assert all(item == dataset_type[0] for item in dataset_type), f'this batch contain different dataset_type: {dataset_type}'
        batch_dataset_type = dataset_type[0]
    else:
        batch_dataset_type = []
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids == SEG_TOKEN_INDEX).sum() != 0:
        if (input_ids == REGION_TOKEN_INDEX).sum() != 0:
            instances = [i['instances'] for i in seg_info]
        else:
            instances = None

        (input_ids, attention_mask, past_key_values, inputs_embeds, inputs_embeds_exo, labels,
         seg_query_mask, class_name_embedding_indices, region_embedding_masks,
         region_embedding_masks_exo, refer_embedding_indices) = self.prepare_inputs_labels_for_multimodal_SSL(
            input_ids, attention_mask, past_key_values, labels, images, vp_images,
            class_name_embedding_indices, class_name_ids, cls_indices, instances, token_refer_id,
            refer_embedding_indices)
    else:
        seg_query_mask = None
        class_name_embedding_indices = None
        region_embedding_masks = None
        region_embedding_masks_exo = None
        # The plain conversation path produces no exo embeddings; marking them
        # as absent lets the exo pass below be skipped instead of failing on
        # an undefined name.
        inputs_embeds_exo = None
        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.mm_conv_prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images)

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )

    hidden_states_exo = None
    if inputs_embeds_exo is not None:
        # Only tensors are cloned. ``use_cache`` is a plain bool and
        # ``past_key_values`` a (nested) sequence of tensors; neither has a
        # ``.clone()`` method, so they are passed through unchanged.
        outputs_exo = self.model(
            input_ids=input_ids.clone() if input_ids is not None else None,
            attention_mask=attention_mask.clone() if attention_mask is not None else None,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds_exo,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        hidden_states_exo = outputs_exo.last_hidden_state

    hidden_states = outputs.last_hidden_state
    logits = self.lm_head(hidden_states)

    if class_name_embedding_indices is not None:
        class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices)
        class_name_embedding = self.class_name_projector(class_name_embedding)
    else:
        class_name_embedding = None

    if class_name_embedding is not None:
        class_name_embedding = torch.gather(
            class_name_embedding, dim=1,
            index=random_idx.unsqueeze(-1).repeat(1, 1, class_name_embedding.shape[-1]))

    region_embedding_list = None
    region_embedding_list_exo = None
    if region_embedding_masks is not None:
        region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks)
        region_embedding_list = [self.region_projector(region_embedding) for region_embedding in
                                 region_embedding_list]
        if hidden_states_exo is not None and region_embedding_masks_exo is not None:
            region_embedding_list_exo = self.get_region_embedding(hidden_states_exo, region_embedding_masks_exo)
            region_embedding_list_exo = [self.region_projector(region_embedding) for region_embedding in
                                         region_embedding_list_exo]
    if 'referring' in batch_dataset_type or 'region' in batch_dataset_type:
        class_name_embedding = None

    # The ego/exo consistency term is only defined when both region lists exist.
    loss_region_emb_SSL = None
    if region_embedding_list is not None and region_embedding_list_exo is not None:
        loss_region_emb_SSL = calculate_region_embedding_dis(region_embedding_list, region_embedding_list_exo)

    loss = None
    if labels is not None and seg_query_mask is None:
        # Shift so that tokens < n predict token n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)

        shift_labels = shift_labels.to(shift_logits.device)
        llm_loss = loss_fct(shift_logits, shift_labels)

    if seg_query_mask is not None:
        seg_query = self.get_seg_query(hidden_states, seg_query_mask)
        seg_query = self.seg_query_projector(seg_query)

        image_features = self.get_vision_tower_feature(images)
        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(
            image_features)

        if refer_embedding_indices is not None:
            SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices)
            SEG_embedding = self.SEG_token_projector(SEG_embedding)
        else:
            SEG_embedding = None
        if 'panoptic' in batch_dataset_type or 'region' in batch_dataset_type:
            SEG_embedding = None

        mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding,
                                      class_name_embedding, region_embedding_list)
        if seg_info is not None:
            if "instances" in seg_info[0]:
                gt_instances = [x["instances"].to(self.device) for x in seg_info]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            mask_losses = self.criterion(mask_outputs, targets)
            weight_dict = self.weight_dict

            loss_mask = 0.0
            loss_dice = 0.0
            loss_SEG_class = 0.0
            loss_class_name_class = 0.0
            loss_region_class = 0.0
            for k in list(mask_losses.keys()):
                if k not in weight_dict:
                    mask_losses.pop(k)
                    continue
                if mask_losses[k] is not None:
                    mask_losses[k] *= weight_dict[k]
                if '_SEG' in k and mask_losses[k] is not None:
                    loss_SEG_class += mask_losses[k]
                elif '_name' in k and mask_losses[k] is not None:
                    loss_class_name_class += mask_losses[k]
                elif '_mask' in k:
                    loss_mask += mask_losses[k]
                elif '_dice' in k:
                    loss_dice += mask_losses[k]
                elif '_region' in k and mask_losses[k] is not None:
                    loss_region_class += mask_losses[k]

            mask_loss = loss_mask + loss_dice + loss_SEG_class + loss_class_name_class + loss_region_class
            if loss_region_emb_SSL is not None:
                mask_loss = mask_loss + loss_region_emb_SSL

            # Any accumulator that never received a tensor is still a Python
            # float; promote it so ``.detach()`` works on every reported loss.
            loss_device = mask_loss.device if torch.is_tensor(mask_loss) else logits.device

            def _as_tensor(value):
                return value if torch.is_tensor(value) else torch.tensor(value, device=loss_device)

            mask_loss = _as_tensor(mask_loss)
            loss_mask = _as_tensor(loss_mask)
            loss_dice = _as_tensor(loss_dice)
            loss_SEG_class = _as_tensor(loss_SEG_class)
            loss_class_name_class = _as_tensor(loss_class_name_class)
            loss_region_class = _as_tensor(loss_region_class)
            loss_region_emb_SSL = _as_tensor(0.0 if loss_region_emb_SSL is None else loss_region_emb_SSL)
            llm = torch.tensor(0.0, device=loss_device)
            if labels is not None:
                # Only the mask loss is optimised on segmentation batches.
                loss = mask_loss

            return CausalOutputWithMaskSSL(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                loss_mask=loss_mask.detach(),
                loss_dice=loss_dice.detach(),
                loss_SEG_class=loss_SEG_class.detach(),
                loss_class_name_class=loss_class_name_class.detach(),
                loss_region_class=loss_region_class.detach(),
                loss_llm=llm.detach(),
                loss_region_emb_SSL=loss_region_emb_SSL.detach(),
            )

    if labels is None or seg_query_mask is not None:
        # No LLM loss was computed: report only the generic fields.
        return CausalOutputWithMask(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Pure language-modelling batch: all segmentation losses are reported as zero.
    loss_mask = torch.tensor(0.0, device=llm_loss.device)
    loss_dice = torch.tensor(0.0, device=llm_loss.device)
    loss_SEG_class = torch.tensor(0.0, device=llm_loss.device)
    loss_class_name_class = torch.tensor(0.0, device=llm_loss.device)
    loss_region_class = torch.tensor(0.0, device=llm_loss.device)
    loss = llm_loss

    return CausalOutputWithMask(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        loss_mask=loss_mask.detach(),
        loss_dice=loss_dice.detach(),
        loss_SEG_class=loss_SEG_class.detach(),
        loss_class_name_class=loss_class_name_class.detach(),
        loss_region_class=loss_region_class.detach(),
        loss_llm=llm_loss.detach(),
    )
|
| |
|
| |
|
def mm_conv_prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images
):
    """Replace image placeholder tokens with image features and realign labels/mask.

    For every sample, each ``IMAGE_TOKEN_INDEX`` position in ``input_ids`` is
    replaced by the matching entry of the encoded image features; the labels
    receive ``IGNORE_INDEX`` for those positions. Samples are then right-padded
    to a common length (zeros for embeddings, ``IGNORE_INDEX`` for labels,
    ``False`` for the attention mask).

    Returns:
        ``(input_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
        On the early-exit path (no vision tower, no images, or a single-token
        step) ``input_ids`` and ``labels`` are returned as given and
        ``inputs_embeds`` is ``None``; otherwise ``input_ids`` is ``None`` and
        ``inputs_embeds`` holds the merged embeddings.
    """
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
            # Single-token step with a cache: the mask must span cache length + 1.
            attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                                        dtype=attention_mask.dtype, device=attention_mask.device)
        return input_ids, attention_mask, past_key_values, None, labels

    if isinstance(images, list) or images.ndim == 5:
        concat_images = torch.cat(list(images), dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        image_features = [x.flatten(0, 1) for x in image_features]
    else:
        image_features = self.encode_images(images)

    # Loop-invariant lookups.
    use_im_start_end = (getattr(self.config, 'tune_mm_mlp_adapter', False)
                        and getattr(self.config, 'mm_use_im_start_end', False))
    embed_tokens = self.get_model().embed_tokens

    new_input_embeds = []
    new_labels = [] if labels is not None else None
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
            # Text-only sample: add a zero-valued term that still routes through
            # mm_projector so the projector takes part in the graph.
            cur_input_embeds = embed_tokens(cur_input_ids)
            cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
            new_input_embeds.append(cur_input_embeds)
            if labels is not None:
                new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue
        image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
        cur_new_input_embeds = []
        if labels is not None:
            cur_labels = labels[batch_idx]
            cur_new_labels = []
            assert cur_labels.shape == cur_input_ids.shape

        while image_token_indices.numel() > 0:
            cur_image_features = image_features[cur_image_idx]
            image_token_start = image_token_indices[0]
            if use_im_start_end:
                cur_new_input_embeds.append(embed_tokens(cur_input_ids[:image_token_start - 1]).detach())
                cur_new_input_embeds.append(embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                cur_new_input_embeds.append(embed_tokens(cur_input_ids[image_token_start + 1:image_token_start + 2]))
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX,
                                                     device=labels.device, dtype=labels.dtype))
                    cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
                    cur_labels = cur_labels[image_token_start + 2:]
            else:
                cur_new_input_embeds.append(embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX,
                                                     device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start + 1:]
            cur_image_idx += 1
            if use_im_start_end:
                cur_input_ids = cur_input_ids[image_token_start + 2:]
            else:
                cur_input_ids = cur_input_ids[image_token_start + 1:]
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
        if cur_input_ids.numel() > 0:
            if use_im_start_end:
                cur_new_input_embeds.append(embed_tokens(cur_input_ids).detach())
            else:
                cur_new_input_embeds.append(embed_tokens(cur_input_ids))
            if labels is not None:
                cur_new_labels.append(cur_labels)
        cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
        new_input_embeds.append(cur_new_input_embeds)
        if labels is not None:
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            new_labels.append(cur_new_labels)

    if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
        max_len = max(x.shape[0] for x in new_input_embeds)

        # Per-sample lengths before padding. With labels these come from the
        # label sequences (as before); without labels they come from the
        # embeddings, so the attention mask can still be rebuilt.
        if labels is not None:
            unpadded_lens = [cur_new_label.shape[0] for cur_new_label in new_labels]
            original_len = labels.shape[1]
        else:
            unpadded_lens = [cur_new_embed.shape[0] for cur_new_embed in new_input_embeds]
            original_len = input_ids.shape[1]

        new_input_embeds_align = []
        for cur_new_embed in new_input_embeds:
            embed_pad = torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                                    dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            new_input_embeds_align.append(torch.cat((cur_new_embed, embed_pad), dim=0))
        new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

        if labels is not None:
            new_labels_align = []
            for cur_new_label in new_labels:
                label_pad = torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
                                       dtype=cur_new_label.dtype, device=cur_new_label.device)
                new_labels_align.append(torch.cat((cur_new_label, label_pad), dim=0))
            new_labels = torch.stack(new_labels_align, dim=0)

        if attention_mask is not None:
            new_attention_mask = []
            for cur_attention_mask, cur_len in zip(attention_mask, unpadded_lens):
                # Inserted image positions are attended to; right padding is not.
                new_attn_mask_pad_left = torch.full((cur_len - original_len,), True,
                                                    dtype=attention_mask.dtype, device=attention_mask.device)
                new_attn_mask_pad_right = torch.full((max_len - cur_len,), False,
                                                     dtype=attention_mask.dtype, device=attention_mask.device)
                new_attention_mask.append(
                    torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
            attention_mask = torch.stack(new_attention_mask, dim=0)
            if labels is not None:
                assert attention_mask.shape == new_labels.shape
            else:
                assert attention_mask.shape == new_input_embeds.shape[:2]
    else:
        new_input_embeds = torch.stack(new_input_embeds, dim=0)
        if labels is not None:
            new_labels = torch.stack(new_labels, dim=0)

        if attention_mask is not None:
            new_attn_mask_pad_left = torch.full(
                (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
                dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
            assert attention_mask.shape == new_input_embeds.shape[:2]

    return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
| |
|
def get_seg_query(self, hidden_states, seg_query_masks):
    """Gather the hidden states of every segmentation-query group.

    Each sample's mask labels positions with a group id, where ``0`` means
    "not a query". Samples whose mask sums to zero are skipped. For every
    other sample, the hidden states of each non-zero group id (in the sorted
    order returned by ``torch.unique``) form one entry of the result.

    Returns:
        All collected groups stacked along a new leading dimension.
    """
    collected = []
    for sample_states, sample_ids in zip(hidden_states, seg_query_masks):
        if torch.sum(sample_ids) == 0:
            continue
        group_ids = torch.unique(sample_ids)
        for group_id in group_ids[group_ids != 0]:
            collected.append(sample_states[sample_ids == group_id])
    return torch.stack(collected, dim=0)
|
| |
|
def eval_seg(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        vp_images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        seg_info=None,
        class_name_ids=None,
        class_name_embedding_indices=None,
        cls_indices=None,
        token_refer_id=None,
        refer_embedding_indices=None,
        is_thing_list=None
):
    """Run segmentation inference and post-process the masks per image.

    The inputs are expanded with ``prepare_inputs_labels_for_multimodal``, the
    language model is run once, and the mask predictor output is resized to
    the padded image size. Each image is then post-processed according to the
    task flags (``semantic_on``, ``instance_on``, ``panoptic_on``,
    ``referring_on``, ``region_on``).

    Args:
        seg_info: One dict per image. ``"padding_mask"`` is read to recover
            the un-padded extent, ``"height"`` / ``"width"`` (optional) give
            the output size, and ``"instances"`` is read when region tokens
            are present or ``region_on`` is set.
        is_thing_list: Required when ``self.panoptic_on`` is set; stored on
            ``self.is_thing_list``.

    Returns:
        A list with one dict per image, containing whichever of ``"sem_seg"``,
        ``"instances"``, ``"panoptic_seg"`` and ``"gt"`` the enabled tasks
        produce.
    """
    if self.panoptic_on:
        assert is_thing_list is not None, 'is_thing_list need to be given'
        self.is_thing_list = is_thing_list
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids == REGION_TOKEN_INDEX).sum() != 0:
        instances = [i['instances'] for i in seg_info]
    else:
        instances = None
    (input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask,
     class_name_embedding_indices, region_embedding_masks,
     refer_embedding_indices) = self.prepare_inputs_labels_for_multimodal(
        input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices,
        class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices)

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )

    hidden_states = outputs.last_hidden_state
    seg_query = self.get_seg_query(hidden_states, seg_query_mask)
    seg_query = self.seg_query_projector(seg_query)

    image_features = self.get_vision_tower_feature(images)
    mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(
        image_features)

    if refer_embedding_indices is not None:
        SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices)
        SEG_embedding = self.SEG_token_projector(SEG_embedding)
    else:
        SEG_embedding = None

    if class_name_embedding_indices is not None:
        class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices)
        class_name_embedding = self.class_name_projector(class_name_embedding)
    else:
        class_name_embedding = None

    if region_embedding_masks is not None:
        region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks)
        region_embedding_list = [self.region_projector(region_embedding) for region_embedding in
                                 region_embedding_list]
    else:
        region_embedding_list = None

    mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding,
                                  class_name_embedding, region_embedding_list)

    SEG_cls_results = mask_outputs['pred_SEG_logits']
    class_name_cls_results = mask_outputs['pred_class_name_logits']
    mask_pred_results = mask_outputs["pred_masks"]
    region_cls_results = mask_outputs['pred_region_logits']
    images = [x for x in images]
    images = ImageList.from_tensors(images, self.size_divisibility)
    mask_pred_results = F.interpolate(
        mask_pred_results,
        size=(images.tensor.shape[-2], images.tensor.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )
    del mask_outputs
    processed_results = []
    # One placeholder per image: a single-element list would make zip() stop
    # after the first image of a larger batch.
    if SEG_cls_results is None:
        SEG_cls_results = [None] * len(seg_info)
    if class_name_cls_results is None:
        class_name_cls_results = [None] * len(seg_info)
    for sample_idx, (input_per_image, SEG_cls_result, class_name_cls_result, mask_pred_result,
                     image_size) in enumerate(zip(seg_info, SEG_cls_results, class_name_cls_results,
                                                  mask_pred_results, images.image_sizes)):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        # Recover the un-padded extent from the padding mask.
        padding_mask = input_per_image.get("padding_mask")
        non_padding_indices = np.where(~ np.array(padding_mask))
        min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0])
        min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1])
        original_height = max_y - min_y + 1
        original_width = max_x - min_x + 1
        processed_results.append({})

        if self.sem_seg_postprocess_before_inference:
            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, [original_height, original_width], height, width
            )
            if SEG_cls_result is not None:
                SEG_cls_result = SEG_cls_result.to(mask_pred_result)

        if self.semantic_on:
            semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            if not self.sem_seg_postprocess_before_inference:
                semantic_r = retry_if_cuda_oom(sem_seg_postprocess)(
                    semantic_r, [original_height, original_width], height, width
                )
            processed_results[-1]["sem_seg"] = semantic_r

        if self.instance_on:
            instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            processed_results[-1]["instances"] = instance_r
        if self.panoptic_on:
            panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            processed_results[-1]["panoptic_seg"] = panoptic_r
        if self.referring_on:
            instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(),
                                                                        mask_pred_result.float())
            processed_results[-1]["instances"] = instance_r
        if self.region_on:
            gt = input_per_image['instances'].gt_masks
            gt_result = retry_if_cuda_oom(sem_seg_postprocess)(
                gt, [original_height, original_width], height, width
            )
            # Use a per-image local so the batch-level container is not
            # overwritten. Assumes the first dimension of the region logits
            # indexes images -- identical to the previous behaviour for a
            # batch of one.
            region_cls_result = region_cls_results[sample_idx].to(mask_pred_result)
            instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_result.float(),
                                                                  mask_pred_result.float())
            processed_results[-1]["instances"] = instance_r
            processed_results[-1]["gt"] = gt_result
    return processed_results
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | class PSALM(PhiForCausalLM, LlavaMetaForCausalLM):
|
| | config_class = LlavaConfig
|
| |
|
def __init__(self, config, mask_decoder_cfg=None, add_cross_attn=True, cross_attn_index=None):
    """Build the language backbone, LM head and the always-present pooling/projection layers.

    Args:
        config: Model configuration; ``hidden_size`` and ``vocab_size`` are
            read, plus the optional ``mask_decode_train`` flag.
        mask_decoder_cfg: Mask-decoder configuration. Despite the ``None``
            default, ``MODEL.MASK_FORMER.HIDDEN_DIM`` is read from it
            unconditionally.
        add_cross_attn: Accepted but not used in this constructor.
        cross_attn_index: Stored on ``self.cross_attn_index``.

    When ``config.mask_decode_train`` is truthy the mask modules are created
    immediately through ``initial_mask_module()``.
    """
    super(PSALM, self).__init__(config)

    self.model = PSALMModel(config, mask_decoder_cfg)
    self.init_config = config
    self.mask_decoder_cfg = mask_decoder_cfg
    self.cross_attn_index = cross_attn_index

    hidden_size = config.hidden_size
    self.lm_head = nn.Linear(hidden_size, config.vocab_size, bias=False)

    self.is_train_mask_decode = getattr(config, 'mask_decode_train', False)

    # Layers are created in a fixed order so parameter initialisation stays
    # reproducible under a fixed seed.
    self.refer_pooling = nn.AdaptiveAvgPool1d(output_size=1)
    self.class_name_pooling = nn.AdaptiveAvgPool1d(output_size=1)
    self.region_sampler = region_pooling(num_sample_point=256)
    decoder_dim = mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM
    self.region_projector = nn.Linear(hidden_size, decoder_dim)

    if self.is_train_mask_decode:
        print('Mask Decoder has been trained, init directly')
        self.initial_mask_module()
    self.post_init()
|
| |
|
def initial_mask_module(self, pretrained_path=None, model_args=None):
    """Create the mask-decoding modules and optionally load pretrained weights.

    Builds the learnable segmentation queries, the pixel decoder, the
    predictor and the three projection layers, then calls
    ``mask_decoder_training_init``. If ``pretrained_path`` is given, the
    ``sem_seg_head.pixel_decoder`` and ``sem_seg_head.predictor`` entries of
    the checkpoint's ``'model'`` dict are renamed where needed and loaded
    non-strictly.

    Args:
        pretrained_path: Optional checkpoint path. A ``.pkl`` file is read
            with ``pickle``; anything else with ``torch.load``.
        model_args: Accepted but not used here.
    """
    if not self.is_train_mask_decode:
        print('Initialize mask modules...')
        self.config.mask_decode_train = True
    mask_former_cfg = self.mask_decoder_cfg.MODEL.MASK_FORMER
    self.seg_query = nn.Parameter(
        torch.zeros([mask_former_cfg.NUM_OBJECT_QUERIES, self.config.hidden_size]))
    self.num_queries = mask_former_cfg.NUM_OBJECT_QUERIES
    self.num_classes = self.mask_decoder_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
    self.test_topk_per_image = mask_former_cfg.NUM_OBJECT_QUERIES
    input_shape = self.output_shape()
    self.pixel_decoder = self.pixel_decoder_init(cfg=self.mask_decoder_cfg, input_shape=input_shape)
    self.predictor = self.predictor_init(cfg=self.mask_decoder_cfg)

    self.seg_query_projector = nn.Linear(self.config.hidden_size, mask_former_cfg.HIDDEN_DIM)
    self.SEG_token_projector = nn.Linear(self.config.hidden_size, mask_former_cfg.HIDDEN_DIM)
    self.class_name_projector = nn.Linear(self.config.hidden_size, mask_former_cfg.HIDDEN_DIM)

    self.mask_decoder_training_init(self.mask_decoder_cfg)
    if pretrained_path is None:
        return

    def get_w(weights, keyword):
        # Filter and split on the same dotted prefix so a key that merely
        # contains ``keyword`` (without the trailing dot) cannot raise.
        prefix = keyword + '.'
        return {k.split(prefix, 1)[1]: v for k, v in weights.items() if prefix in k}

    def change_w(weights, old_name, new_name):
        weights[new_name] = weights.pop(old_name)

    if pretrained_path.endswith('.pkl'):
        # SECURITY: pickle.load can execute arbitrary code; only load
        # checkpoints from a trusted source.
        with open(pretrained_path, 'rb') as f:
            ckpt = pickle.load(f)
    else:
        # Load onto CPU so checkpoints saved from a GPU also open on
        # CPU-only hosts; load_state_dict copies into the module's own device.
        ckpt = torch.load(pretrained_path, map_location='cpu')
    pixel_decoder_weights = get_w(ckpt['model'], 'sem_seg_head.pixel_decoder')
    predictor_weights = get_w(ckpt['model'], 'sem_seg_head.predictor')
    # as_tensor accepts both numpy arrays (.pkl) and tensors without the
    # copy-construct warning that torch.tensor(tensor) emits.
    pixel_decoder_weights = {k: torch.as_tensor(v) for k, v in pixel_decoder_weights.items()}
    predictor_weights = {k: torch.as_tensor(v) for k, v in predictor_weights.items()}

    # Rename checkpoint keys whose layout differs from this implementation.
    change_w(pixel_decoder_weights, 'adapter_1.weight', 'adapter_1.0.weight')
    change_w(pixel_decoder_weights, 'adapter_1.norm.weight', 'adapter_1.1.weight')
    change_w(pixel_decoder_weights, 'adapter_1.norm.bias', 'adapter_1.1.bias')
    change_w(pixel_decoder_weights, 'layer_1.weight', 'layer_1.0.weight')
    change_w(pixel_decoder_weights, 'layer_1.norm.weight', 'layer_1.1.weight')
    change_w(pixel_decoder_weights, 'layer_1.norm.bias', 'layer_1.1.bias')
    if 'static_query.weight' in predictor_weights:
        change_w(predictor_weights, 'static_query.weight', 'query_feat.weight')
    if predictor_weights['query_embed.weight'].shape[0] == 200:
        # Keep only the first 100 query embeddings of a 200-query checkpoint.
        predictor_weights['query_embed.weight'] = predictor_weights['query_embed.weight'][:100, :]
    diff_pixel_msg = self.pixel_decoder.load_state_dict(pixel_decoder_weights, strict=False)
    diff_predictor_msg = self.predictor.load_state_dict(predictor_weights, strict=False)
    print(diff_predictor_msg)
    print(diff_pixel_msg)
|
| |
|
| |
|
def get_vision_tower_feature(self, images):
    """Run the vision tower on ``images`` and name its first four outputs.

    Returns:
        A dict mapping ``'res2'``, ``'res3'``, ``'res4'`` and ``'res5'`` to
        the tower outputs at indices 0, 1, 2 and 3 respectively.
    """
    vision_tower = self.get_model().get_vision_tower()
    stage_outputs = vision_tower(images)
    return {f'res{stage + 2}': stage_outputs[stage] for stage in range(4)}
|
def mask_decoder_training_init(self, cfg):
    """Build the matcher, loss weights, criterion and task flags from ``cfg``.

    Sets ``self.weight_dict``, ``self.criterion``, ``self.size_divisibility``,
    the five ``*_on`` task flags and
    ``self.sem_seg_postprocess_before_inference``.

    Args:
        cfg: Config object exposing ``cfg.MODEL.MASK_FORMER.*``.

    Raises:
        NotImplementedError: If ``SEG_TASK`` is not one of 'semantic',
            'instance', 'panoptic', 'referring' or 'region'.
    """
    mask_former_cfg = cfg.MODEL.MASK_FORMER

    deep_supervision = mask_former_cfg.DEEP_SUPERVISION
    # Read from the config but not used anywhere below.
    no_object_weight = mask_former_cfg.NO_OBJECT_WEIGHT

    class_weight = mask_former_cfg.CLASS_WEIGHT
    dice_weight = mask_former_cfg.DICE_WEIGHT
    mask_weight = mask_former_cfg.MASK_WEIGHT

    matcher = hungarian_matcher_PSALM(
        cost_class=class_weight,
        cost_mask=mask_weight,
        cost_dice=dice_weight,
        num_points=mask_former_cfg.TRAIN_NUM_POINTS,
    )

    weight_dict = {
        "loss_SEG_class": class_weight,
        "loss_class_name_class": class_weight,
        "loss_mask": mask_weight,
        "loss_dice": dice_weight,
        "loss_region_class": class_weight,
    }
    # self.weight_dict aliases this dict, so the auxiliary entries added
    # below are visible through the attribute as well.
    self.weight_dict = weight_dict
    if deep_supervision:
        dec_layers = mask_former_cfg.DEC_LAYERS
        # One suffixed copy of every base weight per auxiliary decoder layer.
        aux_weight_dict = {
            key + f"_{layer}": weight
            for layer in range(dec_layers - 1)
            for key, weight in weight_dict.items()
        }
        weight_dict.update(aux_weight_dict)

    losses = ["SEG_labels", "class_name_labels", "masks", "region_labels"]
    self.criterion = PSALM_criterion(
        matcher=matcher,
        losses=losses,
        num_points=mask_former_cfg.TRAIN_NUM_POINTS,
        oversample_ratio=mask_former_cfg.OVERSAMPLE_RATIO,
        importance_sample_ratio=mask_former_cfg.IMPORTANCE_SAMPLE_RATIO,
        device=self.device
    )
    self.size_divisibility = 32

    # (semantic_on, instance_on, panoptic_on, referring_on, region_on)
    flag_names = ("semantic_on", "instance_on", "panoptic_on", "referring_on", "region_on")
    flags_by_task = {
        'semantic': (True, False, False, False, False),
        'instance': (False, True, False, False, False),
        'panoptic': (True, True, True, False, False),
        'referring': (False, False, False, True, False),
        'region': (False, False, False, False, True),
    }
    seg_task = mask_former_cfg.SEG_TASK
    if seg_task not in flags_by_task:
        raise NotImplementedError
    for flag_name, flag_value in zip(flag_names, flags_by_task[seg_task]):
        setattr(self, flag_name, flag_value)

    self.sem_seg_postprocess_before_inference = (
        self.instance_on or self.panoptic_on or self.referring_on or self.region_on
    )
|
def get_region_embedding(self, hidden_states, region_embedding_masks):
    """Select, per sample, the hidden states flagged by the region mask.

    Args:
        hidden_states: Iterable of per-sample hidden-state tensors.
        region_embedding_masks: Iterable of per-sample masks; each is cast to
            bool and used to index the matching hidden states.

    Returns:
        list with one tensor of selected rows per sample.
    """
    return [
        sample_states[sample_mask.bool()]
        for sample_states, sample_mask in zip(hidden_states, region_embedding_masks)
    ]
|
def SEG_instance_inference(self, SEG_cls, mask_pred):
    """Keep the top-scoring queries and package them as ``Instances``.

    Args:
        SEG_cls: Per-query logits; sigmoid is applied and the first two
            dimensions are flattened before the top-k selection.
        mask_pred: Per-query mask logits; the last two dims are the image size.

    Returns:
        Instances with ``pred_masks`` (logits > 0 as float), all-zero
        ``pred_boxes`` and ``scores`` (query score times mean foreground
        mask probability).
    """
    height_width = mask_pred.shape[-2:]

    seg_probs = F.sigmoid(SEG_cls)
    top_scores, top_queries = seg_probs.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
    top_masks = mask_pred[top_queries]

    result = Instances(height_width)
    result.pred_masks = (top_masks > 0).float()
    # Boxes are placeholders: every box is all zeros.
    result.pred_boxes = Boxes(torch.zeros(top_masks.size(0), 4))

    # Average mask probability over the predicted foreground pixels;
    # the epsilon guards against empty masks.
    flat_binary = result.pred_masks.flatten(1)
    foreground_prob_sum = (top_masks.sigmoid().flatten(1) * flat_binary).sum(1)
    mask_quality = foreground_prob_sum / (flat_binary.sum(1) + 1e-6)
    result.scores = top_scores * mask_quality
    return result
|
def class_name_panoptic_inference(self, SEG_cls, class_name_cls, mask_pred,
                                  object_mask_threshold=0.8, overlap_threshold=0.8):
    """Merge per-query class/mask predictions into one panoptic map.

    Args:
        SEG_cls: Unused; kept for signature compatibility with the sibling
            ``class_name_*_inference`` methods.
        class_name_cls: Per-query class logits whose last column is treated
            as the "no object" class.
        mask_pred: Per-query mask logits; the last two dims are (H, W).
        object_mask_threshold: Minimum softmax score for a query to be kept.
            Defaults to the previously hard-coded 0.8.
        overlap_threshold: Minimum ratio between the pixels a query wins in
            the argmax and its own >= 0.5 area. Defaults to the previously
            hard-coded 0.8.

    Returns:
        (panoptic_seg, segments_info): an int32 (H, W) tensor of segment ids
        (0 = unassigned) and a list of dicts with keys "id", "isthing" and
        "category_id".
    """
    scores, labels = F.softmax(class_name_cls, dim=-1).max(-1)
    void_label = class_name_cls.shape[-1] - 1
    mask_probs = mask_pred.sigmoid()

    keep = labels.ne(void_label) & (scores > object_mask_threshold)
    cur_scores = scores[keep]
    cur_classes = labels[keep]
    cur_masks = mask_probs[keep]

    h, w = cur_masks.shape[-2:]
    panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device)
    segments_info = []

    if cur_masks.shape[0] == 0:
        # No query survived the filtering: return the empty map.
        return panoptic_seg, segments_info

    # Each pixel goes to the kept query with the largest score-weighted
    # mask probability.
    cur_mask_ids = (cur_scores.view(-1, 1, 1) * cur_masks).argmax(0)

    stuff_segment_ids = {}  # stuff class id -> segment id already assigned
    current_segment_id = 0
    for k, pred_class in enumerate(cur_classes.tolist()):
        isthing = self.is_thing_list[pred_class]
        owned = cur_mask_ids == k
        confident = cur_masks[k] >= 0.5
        mask = owned & confident

        mask_area = owned.sum().item()
        original_area = confident.sum().item()
        if mask_area == 0 or original_area == 0 or mask.sum().item() == 0:
            continue
        if mask_area / original_area < overlap_threshold:
            continue

        if not isthing:
            # All regions of one stuff class share a single segment id.
            if pred_class in stuff_segment_ids:
                panoptic_seg[mask] = stuff_segment_ids[pred_class]
                continue
            stuff_segment_ids[pred_class] = current_segment_id + 1

        current_segment_id += 1
        panoptic_seg[mask] = current_segment_id
        segments_info.append(
            {
                "id": current_segment_id,
                "isthing": bool(isthing),
                "category_id": int(pred_class),
            }
        )

    return panoptic_seg, segments_info
|
def region_inference(self, region_cls, mask_pred):
    """Package every predicted mask with its per-region scores.

    Args:
        region_cls: Region logits; sigmoid is applied. Its first dimension
            is used as the repeat count for the mask-quality vector
            (presumably regions x queries; confirm against the caller).
        mask_pred: Per-query mask logits; the last two dims are the image size.

    Returns:
        Instances with ``pred_masks`` (logits > 0 as float), all-zero
        ``pred_boxes`` and ``scores`` (region probability times mask quality,
        transposed so the mask dimension comes first).
    """
    height_width = mask_pred.shape[-2:]
    region_probs = F.sigmoid(region_cls)

    result = Instances(height_width)
    result.pred_masks = (mask_pred > 0).float()
    # Boxes are placeholders: every box is all zeros.
    result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))

    # Mean mask probability over the predicted foreground of each mask.
    flat_binary = result.pred_masks.flatten(1)
    foreground_prob_sum = (mask_pred.sigmoid().flatten(1) * flat_binary).sum(1)
    mask_quality = foreground_prob_sum / (flat_binary.sum(1) + 1e-6)

    tiled_quality = mask_quality[None, ...].repeat(region_probs.shape[0], 1)
    result.scores = (region_probs * tiled_quality).transpose(1, 0)
    return result
|
| |
|
def class_name_semantic_inference(self, SEG_cls, class_name_cls, mask_pred):
    """Combine per-query class probabilities and masks into class maps.

    Args:
        SEG_cls: Unused.
        class_name_cls: (Q, C + 1) class logits; the last column is dropped
            after the softmax.
        mask_pred: (Q, H, W) mask logits; sigmoid is applied.

    Returns:
        (C, H, W) tensor: sum over queries of class probability times mask
        probability.
    """
    class_probs = F.softmax(class_name_cls, dim=-1)
    class_probs = class_probs[:, :-1]
    mask_probs = mask_pred.sigmoid()
    return torch.einsum("qc,qhw->chw", class_probs, mask_probs)
|
def class_name_instance_inference(self, SEG_cls, class_name_cls, mask_pred):
    """Select the top (query, class) pairs and package them as ``Instances``.

    Args:
        SEG_cls: Unused.
        class_name_cls: Per-query class logits; the last column is dropped
            after the softmax.
        mask_pred: Per-query mask logits; the last two dims are the image size.

    Returns:
        Instances with ``pred_masks``, all-zero ``pred_boxes``, ``scores``
        (class probability times mean foreground mask probability) and
        ``pred_classes``. When ``self.panoptic_on`` is set, only entries whose
        class is truthy in ``self.is_thing_list`` are kept.
    """
    height_width = mask_pred.shape[-2:]

    class_probs = F.softmax(class_name_cls, dim=-1)[:, :-1]
    class_count = class_probs.shape[-1]

    # Class id of every (query, class) slot in the flattened score vector.
    label_grid = torch.arange(class_count, device=self.device).unsqueeze(0).repeat(self.num_queries, 1)
    flat_labels = label_grid.flatten(0, 1)

    top_scores, flat_indices = class_probs.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
    top_labels = flat_labels[flat_indices]

    # Map the flat (query, class) index back to its query to fetch the mask.
    query_indices = flat_indices // class_count
    top_masks = mask_pred[query_indices]

    if self.panoptic_on:
        thing_only = torch.zeros_like(top_scores).bool()
        for position, label in enumerate(top_labels):
            thing_only[position] = self.is_thing_list[label]

        top_scores = top_scores[thing_only]
        top_labels = top_labels[thing_only]
        top_masks = top_masks[thing_only]

    result = Instances(height_width)
    result.pred_masks = (top_masks > 0).float()
    # Boxes are placeholders: every box is all zeros.
    result.pred_boxes = Boxes(torch.zeros(top_masks.size(0), 4))

    # Mean mask probability over the predicted foreground of each mask.
    flat_binary = result.pred_masks.flatten(1)
    foreground_prob_sum = (top_masks.sigmoid().flatten(1) * flat_binary).sum(1)
    mask_quality = foreground_prob_sum / (flat_binary.sum(1) + 1e-6)

    result.scores = top_scores * mask_quality
    result.pred_classes = top_labels
    return result
|
def encode_images(self, images):
    """Encode images and project the last vision-tower output.

    Args:
        images: Input forwarded unchanged to the vision tower.

    Returns:
        The result of ``mm_projector`` applied to the last element of the
        vision tower output.
    """
    tower_outputs = self.get_model().get_vision_tower()(images)
    last_stage = tower_outputs[-1]
    projector = self.get_model().mm_projector
    return projector(last_stage)
|
| |
|
def predictor_init(self, cfg):
    """Construct the transformer mask decoder from ``cfg``.

    ``enforce_input_project`` and ``seg_concat`` are fixed to False; the
    number of decoder layers passed on is ``DEC_LAYERS - 1``.

    Args:
        cfg: Config exposing ``cfg.MODEL.SEM_SEG_HEAD.*`` and
            ``cfg.MODEL.MASK_FORMER.*``.

    Returns:
        A ``MultiScaleMaskedTransformerDecoderForOPTPreTrain`` instance.
    """
    head_cfg = cfg.MODEL.SEM_SEG_HEAD
    mask_former_cfg = cfg.MODEL.MASK_FORMER

    seg_concat = False
    seg_norm = mask_former_cfg.SEG_NORM
    seg_proj = mask_former_cfg.SEG_PROJ
    seg_fuse_score = mask_former_cfg.FUSE_SCORE

    # Positional order must match the decoder's constructor signature.
    decoder_args = (
        head_cfg.CONVS_DIM,                  # in_channels
        mask_former_cfg.HIDDEN_DIM,          # hidden_dim
        mask_former_cfg.NUM_OBJECT_QUERIES,  # num_queries
        mask_former_cfg.NHEADS,              # nheads
        mask_former_cfg.DIM_FEEDFORWARD,     # dim_feedforward
        mask_former_cfg.DEC_LAYERS - 1,      # dec_layers
        mask_former_cfg.PRE_NORM,            # pre_norm
        head_cfg.MASK_DIM,                   # mask_dim
        False,                               # enforce_input_project
        seg_norm,
        seg_concat,
        seg_proj,
        seg_fuse_score,
    )
    print(f'current seg concat mode: {seg_concat}, seg_norm: {seg_norm}, seg_proj: {seg_proj}, seg_fuse_score: {seg_fuse_score}')
    return MultiScaleMaskedTransformerDecoderForOPTPreTrain(*decoder_args)
|
| |
|
| |
|
def get_model(self):
    """Return the wrapped model stored on ``self.model``."""
    wrapped_model = self.model
    return wrapped_model
|
def output_shape(self):
    """Describe channel count and stride of each configured backbone feature.

    Channel widths are ``EMBED_DIM * 2**i`` for each entry of ``DEPTHS`` in
    ``self.mask_decoder_cfg.MODEL.SWIN``; strides are fixed at 4/8/16/32 for
    res2..res5.

    Returns:
        dict mapping each name in ``OUT_FEATURES`` to a ``Dict`` with keys
        'channel' and 'stride'.
    """
    swin_cfg = self.mask_decoder_cfg.MODEL.SWIN
    requested_features = swin_cfg.OUT_FEATURES

    stage_names = ("res2", "res3", "res4", "res5")
    stride_by_stage = dict(zip(stage_names, (4, 8, 16, 32)))

    stage_widths = [int(swin_cfg.EMBED_DIM * 2 ** level) for level in range(len(swin_cfg.DEPTHS))]
    # Index explicitly so a config with fewer than four stages still raises.
    channel_by_stage = {name: stage_widths[position] for position, name in enumerate(stage_names)}

    return {
        name: Dict({'channel': channel_by_stage[name], 'stride': stride_by_stage[name]})
        for name in requested_features
    }
|
| |
|
def get_encoder_image(self, images):
    """Return the raw vision-tower output for ``images`` (no projection)."""
    vision_tower = self.get_model().get_vision_tower()
    return vision_tower(images)
|
| |
|
def pixel_decoder_init(self, cfg, input_shape):
    """Construct the multi-scale deformable-attention pixel decoder.

    The transformer feed-forward width is fixed at 1024; everything else is
    read from ``cfg``.

    Args:
        cfg: Config exposing ``cfg.MODEL.SEM_SEG_HEAD.*`` and
            ``cfg.MODEL.MASK_FORMER.*``.
        input_shape: Backbone feature description passed through unchanged.

    Returns:
        An ``MSDeformAttnPixelDecoder`` instance.
    """
    head_cfg = cfg.MODEL.SEM_SEG_HEAD
    mask_former_cfg = cfg.MODEL.MASK_FORMER

    # Positional order must match the pixel decoder's constructor signature.
    decoder_args = (
        input_shape,
        mask_former_cfg.DROPOUT,                              # transformer_dropout
        mask_former_cfg.NHEADS,                               # transformer_nheads
        1024,                                                 # transformer_dim_feedforward
        head_cfg.TRANSFORMER_ENC_LAYERS,                      # transformer_enc_layers
        head_cfg.CONVS_DIM,                                   # conv_dim
        head_cfg.MASK_DIM,                                    # mask_dim
        head_cfg.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES,  # transformer_in_features
        head_cfg.COMMON_STRIDE,                               # common_stride
    )
    return MSDeformAttnPixelDecoder(*decoder_args)
|
def prepare_targets(self, targets, images):
    """Zero-pad each target's masks to the spatial size of ``images``.

    Args:
        targets: Iterable of objects exposing ``gt_masks`` (N, h, w) and
            ``gt_classes``.
        images: Tensor whose last two dims give the padded height and width.

    Returns:
        list of dicts with "labels" (the target's ``gt_classes``) and
        "masks" ((N, H_pad, W_pad), same dtype/device as ``gt_masks``, with
        the original masks in the top-left corner).
    """
    padded_height, padded_width = images.shape[-2:]

    def pad_single(per_image):
        # Place the masks in the top-left corner of a zero canvas.
        masks = per_image.gt_masks
        canvas = torch.zeros(
            (masks.shape[0], padded_height, padded_width), dtype=masks.dtype, device=masks.device
        )
        canvas[:, : masks.shape[1], : masks.shape[2]] = masks
        return {"labels": per_image.gt_classes, "masks": canvas}

    return [pad_single(per_image) for per_image in targets]
|
| |
|
def get_special_token(self, SEG, EOS):
    """Store the given SEG and EOS values on ``self.SEG_id`` / ``self.EOS_id``.

    Despite the ``get_`` prefix this method only assigns; it returns None.
    """
    self.SEG_id, self.EOS_id = SEG, EOS
|
| |
|
def get_class_name_embedding(self, hidden_states, cls_token_indices):
    """Pool the hidden states of every class-name token group.

    For each sample, tokens sharing the same non-zero index in
    ``cls_token_indices`` form one group; each group is pooled with
    ``self.class_name_pooling`` over the token axis.

    Args:
        hidden_states: Per-sample hidden states (tokens x features).
        cls_token_indices: Per-sample integer index per token; 0 marks
            tokens that belong to no class name.

    Returns:
        Tensor stacking, per sample, the concatenated pooled group
        embeddings. Stacking requires every sample to have the same number
        of groups.
    """
    per_sample_embeddings = []
    for sample_states, sample_indices in zip(hidden_states, cls_token_indices):
        class_ids = torch.unique(sample_indices)
        class_ids = class_ids[class_ids != 0]

        token_groups = [sample_states[sample_indices == class_id] for class_id in class_ids]
        # The pooling module acts on the last axis, so move tokens there
        # and back again.
        pooled_groups = [
            self.class_name_pooling(group.transpose(-2, -1)).transpose(-2, -1)
            for group in token_groups
        ]
        per_sample_embeddings.append(torch.cat(pooled_groups, dim=0))
    return torch.stack(per_sample_embeddings, dim=0)
|
def embed_class_ids(self, class_name_ids, cls_indices):
    """Embed the token ids of each class name separately.

    Args:
        class_name_ids: Token ids of all class names, or None.
        cls_indices: Per-token group id aligned with ``class_name_ids``;
            negative ids are ignored.

    Returns:
        None when ``class_name_ids`` is None; otherwise a list with one
        embedded tensor per consecutive non-negative group id.
    """
    if class_name_ids is None:
        return None

    group_ids = cls_indices.unique_consecutive()
    group_ids = group_ids[group_ids >= 0]

    tokens_per_class = [class_name_ids[cls_indices == group_id] for group_id in group_ids]
    return [self.get_model().embed_tokens(tokens) for tokens in tokens_per_class]
|
| |
|
def embed_refer_ids(self, refer_ids):
    """Embed referring-expression token ids; pass None through unchanged."""
    if refer_ids is None:
        return None
    return self.get_model().embed_tokens(refer_ids)
|
| |
|
def concat_image_seg_cls_embeds(self, input_id, img_feature, label, seg_query, seg_query_mask, class_embed,
                                class_name_embedding_indices, region_embedding_mask=None,
                                region_feature_list=None, refer_embedding_indices=None,
                                refer_embedding=None):
    """Expand one sample's placeholder tokens into embeddings.

    ``input_id`` is split into runs of ordinary (non-negative) token ids and
    single negative placeholder tokens. Ordinary runs are embedded with
    ``embed_tokens``; each placeholder is replaced by the matching embedding
    (image features, seg queries, the next class-name embedding, the next
    region feature, or the referring embedding). All bookkeeping tensors are
    expanded in lockstep so they stay aligned with the new sequence.

    Args:
        input_id: 1-D token ids of one sample.
        img_feature: Embeddings substituted for the image token.
        label: 1-D labels aligned with ``input_id``, or None.
        seg_query: Embeddings substituted for the seg token.
        seg_query_mask: 1-D mask aligned with ``input_id``.
        class_embed: Sequence of class-name embeddings, consumed in order by
            the class tokens.
        class_name_embedding_indices: 1-D indices aligned with ``input_id``,
            or None.
        region_embedding_mask: 1-D mask aligned with ``input_id``, or None.
        region_feature_list: Sequence of region features, consumed in order
            by the region tokens.
        refer_embedding_indices: 1-D indices aligned with ``input_id``, or
            None.
        refer_embedding: Embedding substituted for the refer token.

    Returns:
        Tuple ``(input_embeds, label, seg_query_mask,
        class_name_embedding_indices, region_embedding_mask,
        refer_embedding_indices)``; optional entries are None when the
        corresponding input was None. In the outputs, seg-query positions are
        1 in the seg mask, region positions are 1 in the region mask, the
        i-th class token's positions hold ``i`` (1-based) in the class
        indices, refer positions hold 1 in the refer indices, and inserted
        label positions hold ``IGNORE_INDEX``.

    Raises:
        AssertionError: If the sample does not contain exactly one image and
            one seg token, or placeholder counts do not match the supplied
            embeddings, or ``label`` and ``input_id`` differ in shape.
    """
    image_token_indices = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
    seg_query_indices = torch.where(input_id == SEG_TOKEN_INDEX)[0]
    cls_token_indices = torch.where(input_id == CLS_TOKEN_INDEX)[0]
    region_token_indices = torch.where(input_id == REGION_TOKEN_INDEX)[0]
    assert len(image_token_indices) == 1, 'not supporting multi image index'
    assert len(seg_query_indices) == 1, 'not supporting multi seg index'
    if class_name_embedding_indices is not None:
        assert len(cls_token_indices) == len(class_embed), 'the number of <cls> tokens and class_embed needs to be same'
    if region_feature_list is not None:
        assert len(region_feature_list) == len(
            region_token_indices), 'the munber of <region> tokens and regions needs to be same'
    if label is not None:
        assert label.shape == input_id.shape

    has_label = label is not None
    has_cls_indices = class_name_embedding_indices is not None
    has_refer_indices = refer_embedding_indices is not None
    enable_region_mask = region_embedding_mask is not None

    new_embeds = []
    new_seg_masks = []
    new_labels = [] if has_label else None
    new_cls_indices = [] if has_cls_indices else None
    new_refer_indices = [] if has_refer_indices else None
    new_region_masks = [] if enable_region_mask else None

    def add_placeholder(embed, is_seg_query=False, cls_value=0, refer_value=0, is_region=False):
        # Append one substituted embedding plus aligned bookkeeping rows.
        length = embed.shape[0]
        new_embeds.append(embed)
        new_seg_masks.append(torch.ones(length) if is_seg_query else torch.zeros(length))
        if has_cls_indices:
            # Always derive device/dtype from input_id: label may be None
            # (the seg-token branch used label.device and crashed then).
            new_cls_indices.append(
                torch.full((length,), cls_value, device=input_id.device, dtype=input_id.dtype))
        if has_refer_indices:
            new_refer_indices.append(
                torch.full((length,), refer_value, device=input_id.device, dtype=input_id.dtype))
        if has_label:
            new_labels.append(
                torch.full((length,), IGNORE_INDEX, device=label.device, dtype=label.dtype))
        if enable_region_mask:
            new_region_masks.append(torch.ones(length) if is_region else torch.zeros(length))

    # Split the ids into runs of ordinary tokens and single negative
    # placeholder tokens. One tolist() avoids a tensor comparison and an
    # .item() call per token.
    chunks = []
    text_run = []
    for token in input_id.tolist():
        if token >= 0:
            text_run.append(token)
            continue
        if text_run:
            chunks.append(text_run)
            text_run = []
        chunks.append([token])
    if text_run:
        chunks.append(text_run)

    cls_idx = 0
    region_idx = 0
    offset = 0  # start of the current chunk inside the original sequence
    for chunk in chunks:
        chunk_len = len(chunk)
        lone_token = chunk[0] if chunk_len == 1 else None

        if lone_token == IMAGE_TOKEN_INDEX:
            add_placeholder(img_feature)
        elif lone_token == SEG_TOKEN_INDEX:
            add_placeholder(seg_query, is_seg_query=True)
        elif lone_token == CLS_TOKEN_INDEX:
            cls_embed = class_embed[cls_idx]
            if len(cls_embed.shape) == 1:
                cls_embed = cls_embed.unsqueeze(0)
            cls_idx += 1
            # Class positions are tagged with the 1-based class counter.
            add_placeholder(cls_embed, cls_value=cls_idx)
        elif lone_token == REGION_TOKEN_INDEX:
            region_feature = region_feature_list[region_idx]
            region_idx += 1
            add_placeholder(region_feature, is_region=True)
        elif lone_token == REFER_TOKEN_INDEX:
            refer_embed = refer_embedding
            if len(refer_embed.shape) == 1:
                refer_embed = refer_embed.unsqueeze(0)
            add_placeholder(refer_embed, refer_value=1)
        else:
            # Ordinary text: embed it and copy the bookkeeping slices as-is.
            window = slice(offset, offset + chunk_len)
            new_embeds.append(self.get_model().embed_tokens(input_id[window]))
            new_seg_masks.append(seg_query_mask[window])
            if has_cls_indices:
                new_cls_indices.append(class_name_embedding_indices[window])
            if has_refer_indices:
                new_refer_indices.append(refer_embedding_indices[window])
            if has_label:
                new_labels.append(label[window])
            if enable_region_mask:
                new_region_masks.append(region_embedding_mask[window])

        offset += chunk_len

    def gather(pieces):
        # Move every piece to the model device, then join along the sequence.
        return torch.cat([piece.to(device=self.device) for piece in pieces], dim=0)

    cur_new_input_embeds = gather(new_embeds)
    cur_new_label = gather(new_labels) if has_label else None
    cur_new_seg_query_mask = gather(new_seg_masks)
    cur_class_name_embedding_indices = gather(new_cls_indices) if has_cls_indices else None
    cur_refer_embedding_indices = gather(new_refer_indices) if has_refer_indices else None
    cur_new_region_embedding_mask = gather(new_region_masks) if enable_region_mask else None

    return cur_new_input_embeds, cur_new_label, cur_new_seg_query_mask, cur_class_name_embedding_indices, cur_new_region_embedding_mask, cur_refer_embedding_indices
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None,
        class_name_embedding_indices=None,
        class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None
):
    """Turn token ids plus images into batched input embeddings.

    Each sample's placeholder tokens are expanded through
    ``concat_image_seg_cls_embeds``; the per-sample results are then padded
    to a common length (when needed) and stacked, and the attention mask is
    extended to match.

    Args:
        input_ids: (B, L) token ids.
        attention_mask: (B, L) mask or None.
        past_key_values: Cached key/values or None; returned unchanged.
        labels: (B, L) labels or None.
        images: Image batch, a list of image tensors, or None.
        vp_images: Images encoded for region sampling when region tokens and
            ``instances`` are present.
        class_name_embedding_indices: (B, L) indices or None.
        class_name_ids: Per-sample class-name token ids or None.
        cls_indices: Per-sample group ids aligned with ``class_name_ids``.
        instances: Per-sample objects exposing ``vp_region_masks.tensor``.
        token_refer_id: Per-sample referring-expression token ids or None.
        refer_embedding_indices: (B, L) indices or None.

    Returns:
        On the early-exit path (no vision tower, no images, or a single input
        token) a 6-tuple ``(input_ids, attention_mask, past_key_values, None,
        labels, seg_query_mask)``. Otherwise a 9-tuple ``(None,
        attention_mask, past_key_values, input_embeds, labels,
        seg_query_masks, class_name_embedding_indices,
        region_embedding_masks, refer_embedding_indices)``.
    """
    vision_tower = self.get_vision_tower()
    seg_query_mask = torch.zeros_like(input_ids)

    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if past_key_values is not None and vision_tower is not None and images is not None \
                and input_ids.shape[1] == 1:
            # Single-token step with a cache: attend to the whole cached
            # sequence plus the new token.
            attention_mask = torch.ones(
                (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                dtype=attention_mask.dtype, device=attention_mask.device)
        return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask

    if isinstance(images, list) or images.ndim == 5:
        concat_images = torch.cat(list(images), dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        image_features = [x.flatten(0, 1) for x in image_features]
    else:
        image_features = self.encode_images(images)

    expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1)

    if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None:
        region_masks_list = [instance.vp_region_masks.tensor for instance in instances]
        vp_image_features = self.encode_images(vp_images)
        region_features = self.region_sampler(vp_image_features, region_masks_list,
                                              original_dtype=vp_image_features.dtype,
                                              return_dtype=vp_image_features.dtype)
        region_embedding_masks = torch.zeros_like(input_ids)
    else:
        region_features = None
        region_embedding_masks = None

    has_labels = labels is not None
    has_cls_indices = class_name_embedding_indices is not None
    has_refer_indices = refer_embedding_indices is not None
    has_regions = region_features is not None

    new_input_embeds = []
    new_labels = [] if has_labels else None
    new_seg_query_masks = []
    new_class_name_embedding_indices = [] if has_cls_indices else None
    new_refer_embedding_indices = [] if has_refer_indices else None
    new_region_embedding_masks = [] if has_regions else None

    for batch_idx, cur_input_ids in enumerate(input_ids):
        cur_seg_query_mask = seg_query_mask[batch_idx]
        cur_seg_query = expanded_seg_query[batch_idx]
        cur_image_feature = image_features[batch_idx]
        cur_class_name_embedding_indices = class_name_embedding_indices[batch_idx] if has_cls_indices else None
        cur_refer_embedding_indices = refer_embedding_indices[batch_idx] if has_refer_indices else None
        cur_region_feature_list = region_features[batch_idx] if has_regions else None
        cur_region_embedding_mask = region_embedding_masks[batch_idx] if has_regions else None

        if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
            # Sample without an image token: embed the text only. The
            # zero-weighted projector term leaves the values unchanged;
            # presumably it keeps the projector in the autograd graph.
            cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = cur_input_embeds + (
                0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
            new_input_embeds.append(cur_input_embeds)
            if has_labels:
                new_labels.append(labels[batch_idx])
            new_seg_query_masks.append(cur_seg_query_mask)
            # NOTE(review): class/refer/region bookkeeping is not recorded
            # for such samples, so those outputs end up with fewer rows than
            # the batch. Confirm that image-free samples never reach here.
            continue

        cur_label = labels[batch_idx] if has_labels else None

        if class_name_ids is not None:
            cur_class_name_ids = class_name_ids[batch_idx]
            cur_cls_indices = cls_indices[batch_idx]
        else:
            cur_class_name_ids = None
            cur_cls_indices = None
        cur_token_refer_id = token_refer_id[batch_idx] if token_refer_id is not None else None

        cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices)
        cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id)

        (cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices,
         cur_region_embedding_mask, cur_refer_embedding_indices) = self.concat_image_seg_cls_embeds(
            input_id=cur_input_ids,
            img_feature=cur_image_feature,
            label=cur_label,
            seg_query=cur_seg_query,
            seg_query_mask=cur_seg_query_mask,
            class_embed=cur_class_name_embedding,
            class_name_embedding_indices=cur_class_name_embedding_indices,
            region_embedding_mask=cur_region_embedding_mask,
            region_feature_list=cur_region_feature_list,
            refer_embedding_indices=cur_refer_embedding_indices,
            refer_embedding=cur_refer_embedding
        )
        assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0]

        new_input_embeds.append(cur_input_embeds)
        if has_labels:
            new_labels.append(cur_label)
        new_seg_query_masks.append(cur_seg_query_mask)
        if has_cls_indices:
            new_class_name_embedding_indices.append(cur_class_name_embedding_indices)
        if has_refer_indices:
            new_refer_embedding_indices.append(cur_refer_embedding_indices)
        if new_region_embedding_masks is not None:
            new_region_embedding_masks.append(cur_region_embedding_mask)

    if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
        # Ragged batch: right-pad every per-sample tensor to the longest one.
        max_len = max(x.shape[0] for x in new_input_embeds)
        sample_lengths = [x.shape[0] for x in new_input_embeds]

        def pad_and_stack(tensors, pad_value):
            # Right-pad along dim 0 with pad_value, then stack into a batch.
            padded = []
            for tensor in tensors:
                filler = torch.full((max_len - tensor.shape[0],) + tuple(tensor.shape[1:]), pad_value,
                                    dtype=tensor.dtype, device=tensor.device)
                padded.append(torch.cat((tensor, filler), dim=0))
            return torch.stack(padded, dim=0)

        new_input_embeds = pad_and_stack(new_input_embeds, 0)
        if has_labels:
            new_labels = pad_and_stack(new_labels, IGNORE_INDEX)
        new_seg_query_masks = pad_and_stack(new_seg_query_masks, 0)
        if has_cls_indices:
            new_class_name_embedding_indices = pad_and_stack(new_class_name_embedding_indices, 0)
        if has_refer_indices:
            new_refer_embedding_indices = pad_and_stack(new_refer_embedding_indices, 0)
        if new_region_embedding_masks is not None:
            new_region_embedding_masks = pad_and_stack(new_region_embedding_masks, 0)

        if attention_mask is not None:
            # Lengths come from the embeddings rather than the labels, so
            # this also works when labels is None (it used to raise there).
            # Each sample gets a True left pad for the tokens its
            # placeholders expanded into and a False right pad for padding.
            original_len = input_ids.shape[1]
            new_attention_mask = []
            for cur_attention_mask, cur_len in zip(attention_mask, sample_lengths):
                new_attn_mask_pad_left = torch.full((cur_len - original_len,), True,
                                                    dtype=attention_mask.dtype, device=attention_mask.device)
                new_attn_mask_pad_right = torch.full((max_len - cur_len,), False,
                                                     dtype=attention_mask.dtype, device=attention_mask.device)
                new_attention_mask.append(torch.cat(
                    (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
            attention_mask = torch.stack(new_attention_mask, dim=0)
            assert attention_mask.shape == new_input_embeds.shape[:2]
    else:
        new_input_embeds = torch.stack(new_input_embeds, dim=0)
        if has_labels:
            new_labels = torch.stack(new_labels, dim=0)
        new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0)
        if has_cls_indices:
            new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0)
        if has_refer_indices:
            new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0)
        if new_region_embedding_masks is not None:
            new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0)

        if attention_mask is not None:
            # Uniform lengths: one shared True left pad covers the growth.
            new_attn_mask_pad_left = torch.full(
                (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
                dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
            assert attention_mask.shape == new_input_embeds.shape[:2]

    return None, attention_mask, past_key_values, new_input_embeds, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_refer_embedding_indices
|
| |
|
| |
|
def get_SEG_embedding(self, hidden_states, refer_embedding_indices):
    """Pool the hidden states selected by each sample's refer-token mask.

    Args:
        hidden_states: iterable of per-sample hidden-state tensors.
        refer_embedding_indices: iterable of per-sample index masks; entries
            that are truthy after ``.bool()`` mark the rows to keep.

    Returns:
        The per-sample pooled tensors stacked along a new leading dimension.
    """

    def pool_sample(sample_states, sample_indices):
        selected = sample_states[sample_indices.bool()]
        # The last two axes are swapped for the pooling call and swapped back
        # afterwards, so ``self.refer_pooling`` sees the selected rows last.
        return self.refer_pooling(selected.transpose(-2, -1)).transpose(-2, -1)

    pooled_states = [
        pool_sample(sample_states, sample_indices)
        for sample_states, sample_indices in zip(hidden_states, refer_embedding_indices)
    ]
    return torch.stack(pooled_states, dim=0)
|
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    vp_images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
    seg_info=None,
    class_name_ids=None,
    class_name_embedding_indices=None,
    cls_indices=None,
    random_idx=None,
    token_refer_id=None,
    refer_embedding_indices=None,
    dataset_type=None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Run the language model and, for segmentation prompts, the mask decoder.

    Two paths are taken depending on whether ``input_ids`` contains
    ``SEG_TOKEN_INDEX``:

    * Segmentation path: inputs are expanded with
      ``prepare_inputs_labels_for_multimodal``, the hidden states feed the
      mask predictor, and (when ``seg_info`` is given) the criterion losses
      are weighted by ``self.weight_dict`` and summed. The returned ``loss``
      is the summed mask loss when ``labels`` is not None.
    * Conversation path: inputs are expanded with
      ``mm_conv_prepare_inputs_labels_for_multimodal`` and, when ``labels``
      is given, a shifted next-token cross-entropy loss is returned.

    Args:
        dataset_type: optional per-sample dataset tags; all entries must be
            equal. The shared tag decides which auxiliary embeddings are
            dropped before the predictor is called.
        random_idx: optional index tensor used to gather class-name
            embeddings along dim 1; the gather is skipped when it is None.
        seg_info: optional per-sample dicts; the ``"instances"`` entry is
            used to build criterion targets.

    Returns:
        A ``CausalOutputWithMask``. Loss components are always detached;
        on the conversation path without labels only the base fields are set.

    Raises:
        AssertionError: if ``dataset_type`` mixes different values.
    """
    if dataset_type is not None:
        assert all(item == dataset_type[0] for item in dataset_type), f'this batch contain different dataset_type: {dataset_type}'
        batch_dataset_type = dataset_type[0]
    else:
        # An empty list makes every `'...' in batch_dataset_type` test below False.
        batch_dataset_type = []
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids == SEG_TOKEN_INDEX).sum() != 0:
        if (input_ids == REGION_TOKEN_INDEX).sum() != 0:
            instances = [i['instances'] for i in seg_info]
        else:
            instances = None
        input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices,
            class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices)
    else:
        # Plain conversation batch: none of the segmentation side inputs apply.
        seg_query_mask = None
        class_name_embedding_indices = None
        region_embedding_masks = None
        input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.mm_conv_prepare_inputs_labels_for_multimodal(
            input_ids, attention_mask, past_key_values, labels, images)

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )

    hidden_states = outputs.last_hidden_state
    logits = self.lm_head(hidden_states)

    if class_name_embedding_indices is not None:
        class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices)
        class_name_embedding = self.class_name_projector(class_name_embedding)
    else:
        class_name_embedding = None

    # random_idx defaults to None; only reorder/select when the caller supplied it.
    if class_name_embedding is not None and random_idx is not None:
        class_name_embedding = torch.gather(
            class_name_embedding,
            dim=1,
            index=random_idx.unsqueeze(-1).repeat(1, 1, class_name_embedding.shape[-1]),
        )

    if region_embedding_masks is not None:
        region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks)
        region_embedding_list = [self.region_projector(region_embedding) for region_embedding in
                                 region_embedding_list]
    else:
        region_embedding_list = None
    if 'referring' in batch_dataset_type or 'region' in batch_dataset_type:
        class_name_embedding = None

    loss = None

    if seg_query_mask is not None:
        seg_query = self.get_seg_query(hidden_states, seg_query_mask)
        seg_query = self.seg_query_projector(seg_query)

        image_features = self.get_vision_tower_feature(images)
        mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(
            image_features)

        if refer_embedding_indices is not None:
            SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices)
            SEG_embedding = self.SEG_token_projector(SEG_embedding)
        else:
            SEG_embedding = None
        if 'panoptic' in batch_dataset_type or 'region' in batch_dataset_type:
            SEG_embedding = None

        mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding,
                                      class_name_embedding, region_embedding_list)

        # Accumulators start as Python floats so the first added loss tensor
        # keeps its own dtype; they are converted to tensors further down.
        loss_mask = 0.0
        loss_dice = 0.0
        loss_SEG_class = 0.0
        loss_class_name_class = 0.0
        loss_region_class = 0.0
        mask_loss = None

        if seg_info is not None:
            if "instances" in seg_info[0]:
                gt_instances = [x["instances"].to(self.device) for x in seg_info]
                targets = self.prepare_targets(gt_instances, images)
            else:
                targets = None

            mask_losses = self.criterion(mask_outputs, targets)
            weight_dict = self.weight_dict

            for k in list(mask_losses.keys()):
                if k not in weight_dict:
                    # Losses without a configured weight are dropped.
                    mask_losses.pop(k)
                    continue
                if mask_losses[k] is not None:
                    mask_losses[k] *= weight_dict[k]
                if '_SEG' in k and mask_losses[k] is not None:
                    loss_SEG_class += mask_losses[k]
                elif '_name' in k and mask_losses[k] is not None:
                    loss_class_name_class += mask_losses[k]
                elif '_mask' in k:
                    loss_mask += mask_losses[k]
                elif '_dice' in k:
                    loss_dice += mask_losses[k]
                elif '_region' in k and mask_losses[k] is not None:
                    loss_region_class += mask_losses[k]
            mask_loss = loss_mask + loss_dice + loss_SEG_class + loss_class_name_class + loss_region_class

        loss_device = mask_loss.device if torch.is_tensor(mask_loss) else hidden_states.device

        def _as_tensor(value):
            # Any accumulator that never received a tensor is still a float
            # and has no .detach()/.device; promote it uniformly.
            return value if torch.is_tensor(value) else torch.tensor(value, device=loss_device)

        loss_mask = _as_tensor(loss_mask)
        loss_dice = _as_tensor(loss_dice)
        loss_SEG_class = _as_tensor(loss_SEG_class)
        loss_class_name_class = _as_tensor(loss_class_name_class)
        loss_region_class = _as_tensor(loss_region_class)
        llm = torch.tensor(0.0, device=loss_device)

        # Without seg_info no mask loss was computed, so loss stays None.
        if labels is not None and mask_loss is not None:
            loss = _as_tensor(mask_loss)

        return CausalOutputWithMask(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            loss_mask=loss_mask.detach(),
            loss_dice=loss_dice.detach(),
            loss_SEG_class=loss_SEG_class.detach(),
            loss_class_name_class=loss_class_name_class.detach(),
            loss_region_class=loss_region_class.detach(),
            loss_llm=llm.detach(),
        )

    # Conversation path from here on (seg_query_mask is None).
    if labels is None:
        return CausalOutputWithMask(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Next-token objective: position t predicts token t + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)

    shift_labels = shift_labels.to(shift_logits.device)
    llm_loss = loss_fct(shift_logits, shift_labels)

    def _zero():
        return torch.tensor(0.0, device=llm_loss.device)

    loss = llm_loss
    return CausalOutputWithMask(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        loss_mask=_zero().detach(),
        loss_dice=_zero().detach(),
        loss_SEG_class=_zero().detach(),
        loss_class_name_class=_zero().detach(),
        loss_region_class=_zero().detach(),
        loss_llm=llm_loss.detach(),
    )
|
| |
|
def mm_conv_prepare_inputs_labels_for_multimodal(
    self, input_ids, attention_mask, past_key_values, labels, images
):
    """Replace image placeholder tokens with image features for conversation batches.

    Each ``IMAGE_TOKEN_INDEX`` position in ``input_ids`` is replaced by the
    encoded image features; the matching label positions are filled with
    ``IGNORE_INDEX``. Samples that end up with different lengths are
    right-padded (zeros for embeddings, ``IGNORE_INDEX`` for labels, False
    for the attention mask).

    Returns:
        ``(input_ids, attention_mask, past_key_values, inputs_embeds, labels)``.
        When no expansion is performed (no vision tower, no images, or a
        single-token step) ``input_ids`` is returned unchanged and
        ``inputs_embeds`` is None; otherwise ``input_ids`` is None and
        ``inputs_embeds`` holds the stacked embeddings.
    """
    vision_tower = self.get_vision_tower()
    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
            # Single-token step with a cache: the mask covers the cached length plus the new token.
            attention_mask = torch.ones(
                (attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                dtype=attention_mask.dtype, device=attention_mask.device)
        return input_ids, attention_mask, past_key_values, None, labels

    if isinstance(images, list) or images.ndim == 5:
        concat_images = torch.cat([image for image in images], dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        image_features = [x.flatten(0, 1) for x in image_features]
    else:
        image_features = self.encode_images(images)

    model = self.get_model()
    # Loop-invariant: selects the variant where the tokens surrounding the
    # image placeholder are embedded separately and the remaining text is detached.
    use_im_start_end = (getattr(self.config, 'tune_mm_mlp_adapter', False)
                        and getattr(self.config, 'mm_use_im_start_end', False))

    new_input_embeds = []
    new_labels = [] if labels is not None else None
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
            cur_input_embeds = model.embed_tokens(cur_input_ids)
            # Adds a zero-valued term that still routes through mm_projector.
            # NOTE(review): presumably keeps the projector in the autograd graph
            # for text-only samples — confirm.
            cur_input_embeds = cur_input_embeds + (0. * model.mm_projector(vision_tower.dummy_feature)).sum()
            new_input_embeds.append(cur_input_embeds)
            if labels is not None:
                new_labels.append(labels[batch_idx])
            cur_image_idx += 1
            continue

        image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
        cur_new_input_embeds = []
        if labels is not None:
            cur_labels = labels[batch_idx]
            cur_new_labels = []
            assert cur_labels.shape == cur_input_ids.shape

        # Consume the sequence left to right, one image placeholder at a time.
        while image_token_indices.numel() > 0:
            cur_image_features = image_features[cur_image_idx]
            image_token_start = image_token_indices[0]
            if use_im_start_end:
                cur_new_input_embeds.append(model.embed_tokens(cur_input_ids[:image_token_start-1]).detach())
                cur_new_input_embeds.append(model.embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                cur_new_input_embeds.append(model.embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
                    cur_labels = cur_labels[image_token_start+2:]
                cur_input_ids = cur_input_ids[image_token_start+2:]
            else:
                cur_new_input_embeds.append(model.embed_tokens(cur_input_ids[:image_token_start]))
                cur_new_input_embeds.append(cur_image_features)
                if labels is not None:
                    cur_new_labels.append(cur_labels[:image_token_start])
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
                    cur_labels = cur_labels[image_token_start+1:]
                cur_input_ids = cur_input_ids[image_token_start+1:]
            cur_image_idx += 1
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]

        # Remaining text after the last image placeholder.
        if cur_input_ids.numel() > 0:
            if use_im_start_end:
                cur_new_input_embeds.append(model.embed_tokens(cur_input_ids).detach())
            else:
                cur_new_input_embeds.append(model.embed_tokens(cur_input_ids))
            if labels is not None:
                cur_new_labels.append(cur_labels)
        cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
        cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
        new_input_embeds.append(cur_new_input_embeds)
        if labels is not None:
            cur_new_labels = torch.cat(cur_new_labels, dim=0)
            new_labels.append(cur_new_labels)

    if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
        # Ragged batch: right-pad every sample to the longest one.
        sample_lengths = [x.shape[0] for x in new_input_embeds]
        max_len = max(sample_lengths)

        new_input_embeds = torch.stack([
            torch.cat((cur_new_embed,
                       torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                                   dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
            for cur_new_embed in new_input_embeds
        ], dim=0)

        if labels is not None:
            new_labels = torch.stack([
                torch.cat((cur_new_label,
                           torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
                                      dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                for cur_new_label in new_labels
            ], dim=0)

        if attention_mask is not None:
            # Lengths come from the embeddings rather than the labels so that
            # this also works when labels is None (e.g. batched inference).
            original_len = input_ids.shape[1]
            new_attention_mask = []
            for cur_attention_mask, cur_len in zip(attention_mask, sample_lengths):
                new_attn_mask_pad_left = torch.full((cur_len - original_len,), True,
                                                    dtype=attention_mask.dtype, device=attention_mask.device)
                new_attn_mask_pad_right = torch.full((max_len - cur_len,), False,
                                                     dtype=attention_mask.dtype, device=attention_mask.device)
                new_attention_mask.append(
                    torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0))
            attention_mask = torch.stack(new_attention_mask, dim=0)
            assert attention_mask.shape == new_input_embeds.shape[:2]
    else:
        new_input_embeds = torch.stack(new_input_embeds, dim=0)
        if labels is not None:
            new_labels = torch.stack(new_labels, dim=0)

        if attention_mask is not None:
            # All samples grew by the same amount; prepend that many True entries.
            new_attn_mask_pad_left = torch.full(
                (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
                dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
            assert attention_mask.shape == new_input_embeds.shape[:2]

    return None, attention_mask, past_key_values, new_input_embeds, new_labels
|
| |
|
def get_seg_query(self, hidden_states, seg_query_masks):
    """Collect the hidden states of every segmentation-query group in the batch.

    For each sample whose mask does not sum to zero, every distinct non-zero
    mask value defines one group; the hidden-state rows carrying that value
    are extracted as one entry. All entries are stacked along a new leading
    dimension, in sample order and then in ascending mask-value order.

    Args:
        hidden_states: iterable of per-sample hidden-state tensors.
        seg_query_masks: iterable of per-sample integer masks aligned with
            the first axis of the corresponding hidden state.

    Returns:
        The stacked query groups.
    """
    collected_queries = []
    for sample_states, sample_mask in zip(hidden_states, seg_query_masks):
        if torch.sum(sample_mask) != 0:
            group_ids = torch.unique(sample_mask)
            for group_id in group_ids[group_ids != 0]:
                collected_queries.append(sample_states[sample_mask == group_id])
    return torch.stack(collected_queries, dim=0)
|
def eval_seg(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    vp_images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
    seg_info=None,
    class_name_ids=None,
    class_name_embedding_indices=None,
    cls_indices=None,
    token_refer_id=None,
    refer_embedding_indices=None,
    is_thing_list=None
):
    """Run segmentation inference and post-process the predictions per image.

    The inputs are expanded with ``prepare_inputs_labels_for_multimodal``,
    passed through the language model, and the resulting query / SEG /
    class-name / region embeddings are fed to the mask predictor. Predicted
    masks are resized to the padded image size and then post-processed for
    each image according to the enabled task flags (``semantic_on``,
    ``instance_on``, ``panoptic_on``, ``referring_on``, ``region_on``).

    Args:
        seg_info: per-image dicts. ``"padding_mask"`` is read for every
            image; ``"height"`` / ``"width"`` are optional; ``"instances"``
            is read when region tokens are present or ``region_on`` is set.
        is_thing_list: required when ``self.panoptic_on`` is set; stored on
            ``self.is_thing_list``.

    Returns:
        A list with one dict per image, containing whichever of
        ``"sem_seg"``, ``"instances"``, ``"panoptic_seg"`` and ``"gt"`` the
        enabled task flags produce.

    Raises:
        AssertionError: if ``panoptic_on`` is set and ``is_thing_list`` is None.
    """
    if self.panoptic_on:
        assert is_thing_list is not None, 'is_thing_list need to be given'
        self.is_thing_list = is_thing_list
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if (input_ids == REGION_TOKEN_INDEX).sum() != 0:
        instances = [i['instances'] for i in seg_info]
    else:
        instances = None
    input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal(
        input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices,
        class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices)

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )

    hidden_states = outputs.last_hidden_state
    seg_query = self.get_seg_query(hidden_states, seg_query_mask)
    seg_query = self.seg_query_projector(seg_query)

    image_features = self.get_vision_tower_feature(images)
    mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(
        image_features)

    if refer_embedding_indices is not None:
        SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices)
        SEG_embedding = self.SEG_token_projector(SEG_embedding)
    else:
        SEG_embedding = None

    if class_name_embedding_indices is not None:
        class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices)
        class_name_embedding = self.class_name_projector(class_name_embedding)
    else:
        class_name_embedding = None

    if region_embedding_masks is not None:
        region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks)
        region_embedding_list = [self.region_projector(region_embedding) for region_embedding in
                                 region_embedding_list]
    else:
        region_embedding_list = None

    mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding,
                                  class_name_embedding, region_embedding_list)

    SEG_cls_results = mask_outputs['pred_SEG_logits']
    class_name_cls_results = mask_outputs['pred_class_name_logits']
    mask_pred_results = mask_outputs["pred_masks"]
    region_cls_results = mask_outputs['pred_region_logits']
    images = list(images)
    images = ImageList.from_tensors(images, self.size_divisibility)
    # Bring the mask logits up to the padded image resolution.
    mask_pred_results = F.interpolate(
        mask_pred_results,
        size=(images.tensor.shape[-2], images.tensor.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )
    del mask_outputs

    # Missing logits are padded with one None per image; a single-element
    # placeholder would make zip() stop after the first image.
    if SEG_cls_results is None:
        SEG_cls_results = [None] * len(seg_info)
    if class_name_cls_results is None:
        class_name_cls_results = [None] * len(seg_info)

    processed_results = []
    for image_idx, (input_per_image, SEG_cls_result, class_name_cls_result, mask_pred_result, image_size) in enumerate(zip(
            seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, images.image_sizes
    )):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        padding_mask = input_per_image.get("padding_mask")
        # The un-padded extent is the bounding box of the False entries of padding_mask.
        non_padding_indices = np.where(~ np.array(padding_mask))
        min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0])
        min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1])
        original_height = max_y - min_y + 1
        original_width = max_x - min_x + 1
        result = {}
        processed_results.append(result)

        if self.sem_seg_postprocess_before_inference:
            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, [original_height, original_width], height, width
            )
            if SEG_cls_result is not None:
                # Match the dtype/device of the (possibly relocated) mask tensor.
                SEG_cls_result = SEG_cls_result.to(mask_pred_result)

        if self.semantic_on:
            semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            if not self.sem_seg_postprocess_before_inference:
                semantic_r = retry_if_cuda_oom(sem_seg_postprocess)(
                    semantic_r, [original_height, original_width], height, width
                )
            result["sem_seg"] = semantic_r

        if self.instance_on:
            instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            result["instances"] = instance_r
        if self.panoptic_on:
            panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            result["panoptic_seg"] = panoptic_r
        if self.referring_on:
            instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(),
                                                                        mask_pred_result.float())
            result["instances"] = instance_r
        if self.region_on:
            gt = input_per_image['instances'].gt_masks
            gt_result = retry_if_cuda_oom(sem_seg_postprocess)(
                gt, [original_height, original_width], height, width
            )
            # Use a per-image local instead of rebinding the batch-level
            # variable, which would be re-indexed on the next iteration.
            # NOTE(review): assumes pred_region_logits is indexed per image;
            # for single-image batches this is the same entry 0 as before — confirm.
            region_cls_result = region_cls_results[image_idx].to(mask_pred_result)
            instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_result.float(),
                                                                  mask_pred_result.float())
            result["instances"] = instance_r
            result["gt"] = gt_result

    return processed_results
|
| |
|
| |
|
| |
|
| |
|
| | class PSALMForDAVISEval(PSALM):
|
def eval_seg(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    images: Optional[torch.FloatTensor] = None,
    return_dict: Optional[bool] = None,
    seg_info=None,
    class_name_ids=None,
    class_name_embedding_indices=None,
    cls_indices=None,
    token_refer_id=None,
    refer_embedding_indices=None,
    is_thing_list=None,
    vp_images=None
):
    """Run segmentation inference for this evaluation variant.

    The inference and post-processing steps are the same as in the parent
    class's ``eval_seg``, so this method forwards to it instead of carrying
    a second copy. Calls made through ``self`` inside the parent (notably
    ``self.prepare_inputs_labels_for_multimodal``) still resolve to this
    class's overrides.

    The positional order of this signature differs from the parent's
    (``vp_images`` comes last here), so every argument is forwarded by
    keyword.

    Returns:
        The list of per-image result dicts produced by the parent
        implementation.
    """
    return super().eval_seg(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        labels=labels,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        images=images,
        vp_images=vp_images,
        return_dict=return_dict,
        seg_info=seg_info,
        class_name_ids=class_name_ids,
        class_name_embedding_indices=class_name_embedding_indices,
        cls_indices=cls_indices,
        token_refer_id=token_refer_id,
        refer_embedding_indices=refer_embedding_indices,
        is_thing_list=is_thing_list,
    )
|
| |
|
| |
|
| |
|
def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None,
        class_name_embedding_indices=None,
        class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None
):
    """Build the embedded input sequence with image / query / class / region / refer slots.

    For every sample, ``input_ids`` is turned into embeddings. Samples that
    contain ``IMAGE_TOKEN_INDEX`` are expanded by
    ``self.concat_image_seg_cls_embeds``; samples without it are embedded
    as-is. When the resulting per-sample lengths differ, every per-sample
    sequence is right-padded to the longest one before stacking.

    Args:
        input_ids: token ids, indexed as ``[batch, seq]``.
        attention_mask: optional mask aligned with ``input_ids``; it is
            extended to the new sequence length.
        past_key_values: passed through unchanged.
        labels: optional labels aligned with ``input_ids``; padded with
            ``IGNORE_INDEX``.
        images: image batch, either a list / 5-D tensor of per-sample image
            groups or a single tensor passed straight to ``encode_images``.
        vp_images: images encoded for region sampling; only read when
            ``input_ids`` contains ``REGION_TOKEN_INDEX`` and ``instances``
            is given.
        class_name_embedding_indices, class_name_ids, cls_indices,
        token_refer_id, refer_embedding_indices: optional per-sample inputs
            forwarded to ``embed_class_ids`` / ``embed_refer_ids`` /
            ``concat_image_seg_cls_embeds``.
        instances: per-sample objects exposing ``vp_region_masks.tensor``.

    Returns:
        Normally a 9-tuple ``(None, attention_mask, past_key_values,
        inputs_embeds, labels, seg_query_masks,
        class_name_embedding_indices, region_embedding_masks,
        refer_embedding_indices)``.

        When there is no vision tower, no images, or ``input_ids`` holds a
        single token, the historical 6-tuple ``(input_ids, attention_mask,
        past_key_values, None, labels, seg_query_mask)`` is returned.
    """
    vision_tower = self.get_vision_tower()

    seg_query_mask = torch.zeros_like(input_ids)

    if vision_tower is None or images is None or input_ids.shape[1] == 1:
        if (past_key_values is not None and vision_tower is not None and images is not None
                and input_ids.shape[1] == 1):
            # Single-token step with a cache: the mask must span the cached
            # positions plus the new token.
            attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                                        dtype=attention_mask.dtype, device=attention_mask.device)
        # NOTE(review): this path returns 6 values while the main path returns 9;
        # a caller that unpacks 9 names will fail here. Left unchanged because
        # not every caller is visible from this block - confirm before aligning.
        return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask

    if isinstance(images, list) or images.ndim == 5:
        # Encode all per-sample image groups in one call, then split the
        # features back per sample and merge each group's leading dims.
        concat_images = torch.cat(list(images), dim=0)
        image_features = self.encode_images(concat_images)
        split_sizes = [image.shape[0] for image in images]
        image_features = torch.split(image_features, split_sizes, dim=0)
        image_features = [x.flatten(0, 1) for x in image_features]
    else:
        image_features = self.encode_images(images)

    # One view of the learned segmentation queries per sample.
    expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1)

    if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None:
        region_masks_list = [instance.vp_region_masks.tensor for instance in instances]
        vp_image_features = self.encode_images(vp_images)
        region_features = self.region_sampler(vp_image_features, region_masks_list,
                                              original_dtype=vp_image_features.dtype,
                                              return_dtype=vp_image_features.dtype)
        region_embedding_masks = torch.zeros_like(input_ids)
    else:
        region_features = None
        region_embedding_masks = None

    new_input_embeds = []
    new_labels = [] if labels is not None else None
    new_seg_query_masks = []
    new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None
    new_refer_embedding_indices = [] if refer_embedding_indices is not None else None
    new_region_embedding_masks = [] if region_features is not None else None

    for batch_idx, cur_input_ids in enumerate(input_ids):
        cur_seg_query_mask = seg_query_mask[batch_idx]
        cur_seg_query = expanded_seg_query[batch_idx]
        cur_image_feature = image_features[batch_idx]
        cur_class_name_embedding_indices = (
            class_name_embedding_indices[batch_idx] if class_name_embedding_indices is not None else None)
        cur_refer_embedding_indices = (
            refer_embedding_indices[batch_idx] if refer_embedding_indices is not None else None)
        cur_region_feature_list = region_features[batch_idx] if region_features is not None else None
        cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None

        if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
            # Sample without an image token: embed the text only.
            cur_input_embeds = self.get_model().embed_tokens(cur_input_ids)
            # The added term is multiplied by 0, so values are unchanged;
            # presumably it keeps mm_projector in the autograd graph.
            cur_input_embeds = cur_input_embeds + (
                0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum()
            new_input_embeds.append(cur_input_embeds)
            if labels is not None:
                new_labels.append(labels[batch_idx])
            new_seg_query_masks.append(cur_seg_query_mask)
            # NOTE(review): nothing is appended to the class-name / refer /
            # region lists on this path, so those outputs can end up shorter
            # than the batch - confirm that mixed batches never occur.
            continue

        cur_label = labels[batch_idx] if labels is not None else None

        if class_name_ids is not None:
            cur_class_name_ids = class_name_ids[batch_idx]
            cur_cls_indices = cls_indices[batch_idx]
        else:
            cur_class_name_ids = None
            cur_cls_indices = None

        cur_token_refer_id = token_refer_id[batch_idx] if token_refer_id is not None else None

        cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices)
        cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id)

        (cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices,
         cur_region_embedding_mask, cur_refer_embedding_indices) = self.concat_image_seg_cls_embeds(
            input_id=cur_input_ids,
            img_feature=cur_image_feature,
            label=cur_label,
            seg_query=cur_seg_query,
            seg_query_mask=cur_seg_query_mask,
            class_embed=cur_class_name_embedding,
            class_name_embedding_indices=cur_class_name_embedding_indices,
            region_embedding_mask=cur_region_embedding_mask,
            region_feature_list=cur_region_feature_list,
            refer_embedding_indices=cur_refer_embedding_indices,
            refer_embedding=cur_refer_embedding
        )
        assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0]

        new_input_embeds.append(cur_input_embeds)
        if labels is not None:
            new_labels.append(cur_label)
        new_seg_query_masks.append(cur_seg_query_mask)
        if class_name_embedding_indices is not None:
            new_class_name_embedding_indices.append(cur_class_name_embedding_indices)
        if refer_embedding_indices is not None:
            new_refer_embedding_indices.append(cur_refer_embedding_indices)
        if new_region_embedding_masks is not None:
            new_region_embedding_masks.append(cur_region_embedding_mask)

    if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
        max_len = max(x.shape[0] for x in new_input_embeds)

        # Per-sample lengths before padding, needed to extend the attention
        # mask. Labels are used when present (as before); otherwise the
        # embeddings supply the lengths so the labels=None case works too.
        length_source = new_labels if labels is not None else new_input_embeds
        unpadded_lens = [t.shape[0] for t in length_source]

        def _pad_and_stack(tensors, pad_value):
            # Right-pad each tensor along dim 0 to max_len, then stack on a new batch dim.
            padded = []
            for t in tensors:
                filler = torch.full((max_len - t.shape[0],) + tuple(t.shape[1:]), pad_value,
                                    dtype=t.dtype, device=t.device)
                padded.append(torch.cat((t, filler), dim=0))
            return torch.stack(padded, dim=0)

        new_input_embeds = _pad_and_stack(new_input_embeds, 0)
        if labels is not None:
            new_labels = _pad_and_stack(new_labels, IGNORE_INDEX)
        new_seg_query_masks = _pad_and_stack(new_seg_query_masks, 0)
        if class_name_embedding_indices is not None:
            new_class_name_embedding_indices = _pad_and_stack(new_class_name_embedding_indices, 0)
        if refer_embedding_indices is not None:
            new_refer_embedding_indices = _pad_and_stack(new_refer_embedding_indices, 0)
        if new_region_embedding_masks is not None:
            new_region_embedding_masks = _pad_and_stack(new_region_embedding_masks, 0)

        if attention_mask is not None:
            # Sequence length before the image / query expansion.
            base_len = labels.shape[1] if labels is not None else input_ids.shape[1]
            new_attention_mask = []
            for cur_attention_mask, cur_len in zip(attention_mask, unpadded_lens):
                # Inserted positions are attended to (True on the left);
                # batch padding is masked out (False on the right).
                new_attn_mask_pad_left = torch.full((cur_len - base_len,), True,
                                                    dtype=attention_mask.dtype, device=attention_mask.device)
                new_attn_mask_pad_right = torch.full((max_len - cur_len,), False,
                                                     dtype=attention_mask.dtype, device=attention_mask.device)
                cur_new_attention_mask = torch.cat(
                    (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                new_attention_mask.append(cur_new_attention_mask)
            attention_mask = torch.stack(new_attention_mask, dim=0)
            expected_shape = new_labels.shape if labels is not None else new_input_embeds.shape[:2]
            assert attention_mask.shape == expected_shape
    else:
        new_input_embeds = torch.stack(new_input_embeds, dim=0)
        if labels is not None:
            new_labels = torch.stack(new_labels, dim=0)

        new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0)
        if class_name_embedding_indices is not None:
            new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0)
        if refer_embedding_indices is not None:
            new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0)
        if new_region_embedding_masks is not None:
            new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0)

        if attention_mask is not None:
            # All samples grew by the same amount; prepend that many True columns.
            new_attn_mask_pad_left = torch.full(
                (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True,
                dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
            assert attention_mask.shape == new_input_embeds.shape[:2]

    return (None, attention_mask, past_key_values, new_input_embeds, new_labels, new_seg_query_masks,
            new_class_name_embedding_indices, new_region_embedding_masks, new_refer_embedding_indices)
|
| |
|
def eval_video(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        vp_images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        seg_info=None,
        class_name_ids=None,
        class_name_embedding_indices=None,
        cls_indices=None,
        token_refer_id=None,
        refer_embedding_indices=None,
        is_thing_list=None
):
    """Run the language model and mask predictor, then post-process per image.

    The multimodal inputs are assembled with
    ``prepare_inputs_labels_for_multimodal``, passed through ``self.model``,
    and the resulting hidden states are projected into segmentation
    queries and optional SEG / class-name / region embeddings for
    ``self.predictor``. Predicted masks are resized to the padded image
    size and each image is post-processed according to the enabled task
    flags.

    Args:
        seg_info: per-image dicts; ``"padding_mask"`` is read for every
            image, ``"height"`` / ``"width"`` are optional, and
            ``"instances"`` is read for region inputs and ground truth.
        is_thing_list: required when ``self.panoptic_on`` is set; stored on
            ``self.is_thing_list``.
        Remaining arguments are forwarded to
        ``prepare_inputs_labels_for_multimodal`` and ``self.model``; the
        ``inputs_embeds`` argument is replaced by the prepared embeddings.

    Returns:
        A list with one dict per processed image. Depending on the task
        flags it holds ``"sem_seg"``, ``"instances"``, ``"panoptic_seg"``
        and, for region mode, ``"gt"``.
    """
    if self.panoptic_on:
        assert is_thing_list is not None, 'is_thing_list need to be given'
        self.is_thing_list = is_thing_list
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Region prompts need the per-image instances (for their region masks).
    if (input_ids == REGION_TOKEN_INDEX).sum() != 0:
        instances = [i['instances'] for i in seg_info]
    else:
        instances = None

    (input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask,
     class_name_embedding_indices, region_embedding_masks,
     refer_embedding_indices) = self.prepare_inputs_labels_for_multimodal(
        input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices,
        class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices)

    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict
    )

    hidden_states = outputs.last_hidden_state
    seg_query = self.get_seg_query(hidden_states, seg_query_mask)
    seg_query = self.seg_query_projector(seg_query)

    image_features = self.get_vision_tower_feature(images)
    mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features(
        image_features)

    if refer_embedding_indices is not None:
        SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices)
        SEG_embedding = self.SEG_token_projector(SEG_embedding)
    else:
        SEG_embedding = None

    if class_name_embedding_indices is not None:
        class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices)
        class_name_embedding = self.class_name_projector(class_name_embedding)
    else:
        class_name_embedding = None

    if region_embedding_masks is not None:
        region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks)
        region_embedding_list = [self.region_projector(region_embedding) for region_embedding in
                                 region_embedding_list]
    else:
        region_embedding_list = None

    mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding,
                                  class_name_embedding, region_embedding_list)

    SEG_cls_results = mask_outputs['pred_SEG_logits']
    class_name_cls_results = mask_outputs['pred_class_name_logits']
    mask_pred_results = mask_outputs["pred_masks"]
    region_cls_results = mask_outputs['pred_region_logits']
    images = list(images)
    images = ImageList.from_tensors(images, self.size_divisibility)
    # Bring the predicted masks up to the padded input resolution.
    mask_pred_results = F.interpolate(
        mask_pred_results,
        size=(images.tensor.shape[-2], images.tensor.shape[-1]),
        mode="bilinear",
        align_corners=False,
    )
    del mask_outputs

    processed_results = []
    # Missing heads get one placeholder per image; a single-element list
    # would make zip() stop after the first image of a larger batch.
    num_images = len(seg_info)
    if SEG_cls_results is None:
        SEG_cls_results = [None] * num_images
    if class_name_cls_results is None:
        class_name_cls_results = [None] * num_images

    per_image = zip(seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, images.image_sizes)
    for image_idx, (input_per_image, SEG_cls_result, class_name_cls_result, mask_pred_result,
                    image_size) in enumerate(per_image):
        height = input_per_image.get("height", image_size[0])
        width = input_per_image.get("width", image_size[1])
        # Recover the un-padded extent from the bounding box of the
        # non-padded pixels (assumes a 2-D boolean mask, True = padding -
        # TODO confirm against the dataset mapper).
        padding_mask = input_per_image.get("padding_mask")
        non_padding_indices = np.where(~ np.array(padding_mask))
        min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0])
        min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1])
        original_height = max_y - min_y + 1
        original_width = max_x - min_x + 1
        processed_results.append({})

        if self.sem_seg_postprocess_before_inference:
            mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                mask_pred_result, [original_height, original_width], height, width
            )
            if SEG_cls_result is not None:
                # Keep the logits on the same device / dtype as the masks.
                SEG_cls_result = SEG_cls_result.to(mask_pred_result)

        if self.semantic_on:
            semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            if not self.sem_seg_postprocess_before_inference:
                semantic_r = retry_if_cuda_oom(sem_seg_postprocess)(
                    semantic_r, [original_height, original_width], height, width
                )
            processed_results[-1]["sem_seg"] = semantic_r

        if self.instance_on:
            instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            processed_results[-1]["instances"] = instance_r

        if self.panoptic_on:
            panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None,
                                                                               class_name_cls_result.float(),
                                                                               mask_pred_result.float())
            processed_results[-1]["panoptic_seg"] = panoptic_r

        if self.referring_on:
            instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(),
                                                                        mask_pred_result.float())
            processed_results[-1]["instances"] = instance_r

        if self.region_on:
            gt = input_per_image['instances'].gt_masks
            gt_result = retry_if_cuda_oom(sem_seg_postprocess)(
                gt, [original_height, original_width], height, width
            )
            # Use a per-image local: rebinding region_cls_results here would
            # make later iterations index an already-sliced tensor.
            # Assumes pred_region_logits is indexed by image - TODO confirm.
            region_cls_result = region_cls_results[image_idx].to(mask_pred_result)
            instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_result.float(),
                                                                  mask_pred_result.float())
            processed_results[-1]["instances"] = instance_r
            processed_results[-1]["gt"] = gt_result

    return processed_results
|
| |
|
| |
|
# Register the "llava_phi" model type with the transformers Auto* classes:
# the name is bound to LlavaConfig, and LlavaConfig is bound to PSALMModel
# for AutoModelForCausalLM.
AutoConfig.register("llava_phi", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, PSALMModel)
|
| |
|