| | |
| | |
| |
|
| | |
| | |
| |
|
| | import torch |
| | import torch.distributed |
| | import torch.nn.functional as F |
| |
|
| | from torch.nn.init import trunc_normal_ |
| |
|
| | from sam2.modeling.sam.mask_decoder import MaskDecoder |
| | from sam2.modeling.sam.prompt_encoder import PromptEncoder |
| | from sam2.modeling.sam.transformer import TwoWayTransformer |
| | from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames |
| |
|
| | |
| | NO_OBJ_SCORE = -1024.0 |
| |
|
| |
|
| | class SAM2Base(torch.nn.Module): |
| | def __init__( |
| | self, |
| | image_encoder, |
| | memory_attention, |
| | memory_encoder, |
| | num_maskmem=7, |
| | image_size=512, |
| | backbone_stride=16, |
| | sigmoid_scale_for_mem_enc=1.0, |
| | sigmoid_bias_for_mem_enc=0.0, |
| | |
| | binarize_mask_from_pts_for_mem_enc=False, |
| | use_mask_input_as_output_without_sam=False, |
| | |
| | |
| | |
| | max_cond_frames_in_attn=-1, |
| | |
| | |
| | directly_add_no_mem_embed=False, |
| | |
| | use_high_res_features_in_sam=False, |
| | |
| | multimask_output_in_sam=False, |
| | |
| | |
| | multimask_min_pt_num=1, |
| | multimask_max_pt_num=1, |
| | |
| | multimask_output_for_tracking=False, |
| | |
| | |
| | use_multimask_token_for_obj_ptr: bool = False, |
| | |
| | iou_prediction_use_sigmoid=False, |
| | |
| | |
| | |
| | memory_temporal_stride_for_eval=1, |
| | |
| | |
| | add_all_frames_to_correct_as_cond=False, |
| | |
| | non_overlap_masks_for_mem_enc=False, |
| | |
| | use_obj_ptrs_in_encoder=False, |
| | |
| | max_obj_ptrs_in_encoder=16, |
| | |
| | add_tpos_enc_to_obj_ptrs=True, |
| | |
| | |
| | proj_tpos_enc_in_obj_ptrs=False, |
| | |
| | |
| | only_obj_ptrs_in_the_past_for_eval=False, |
| | |
| | pred_obj_scores: bool = False, |
| | |
| | pred_obj_scores_mlp: bool = False, |
| | |
| | |
| | |
| | fixed_no_obj_ptr: bool = False, |
| | |
| | |
| | soft_no_obj_ptr: bool = False, |
| | use_mlp_for_obj_ptr_proj: bool = False, |
| | |
| | sam_mask_decoder_extra_args=None, |
| | compile_image_encoder: bool = False, |
| | ): |
| | super().__init__() |
| |
|
| | |
| | self.image_encoder = image_encoder |
| | |
| | self.use_high_res_features_in_sam = use_high_res_features_in_sam |
| | self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 |
| | self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder |
| | self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder |
| | if use_obj_ptrs_in_encoder: |
| | |
| | |
| | |
| | self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) |
| | self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs |
| | if proj_tpos_enc_in_obj_ptrs: |
| | assert add_tpos_enc_to_obj_ptrs |
| | self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs |
| | self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval |
| |
|
| | |
| | |
| | self.memory_attention = memory_attention |
| | self.hidden_dim = memory_attention.d_model |
| |
|
| | |
| | self.memory_encoder = memory_encoder |
| | self.mem_dim = self.hidden_dim |
| | if hasattr(self.memory_encoder, "out_proj") and hasattr( |
| | self.memory_encoder.out_proj, "weight" |
| | ): |
| | |
| | self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] |
| | self.num_maskmem = num_maskmem |
| | |
| | self.maskmem_tpos_enc = torch.nn.Parameter( |
| | torch.zeros(num_maskmem, 1, 1, self.mem_dim) |
| | ) |
| | trunc_normal_(self.maskmem_tpos_enc, std=0.02) |
| | |
| | self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) |
| | self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) |
| | trunc_normal_(self.no_mem_embed, std=0.02) |
| | trunc_normal_(self.no_mem_pos_enc, std=0.02) |
| | self.directly_add_no_mem_embed = directly_add_no_mem_embed |
| | |
| | |
| | self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc |
| | self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc |
| | self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc |
| | self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc |
| | self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval |
| | |
| | |
| | self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam |
| | self.multimask_output_in_sam = multimask_output_in_sam |
| | self.multimask_min_pt_num = multimask_min_pt_num |
| | self.multimask_max_pt_num = multimask_max_pt_num |
| | self.multimask_output_for_tracking = multimask_output_for_tracking |
| | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr |
| | self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid |
| |
|
| | |
| | |
| | self.image_size = image_size |
| | self.backbone_stride = backbone_stride |
| | self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args |
| | self.pred_obj_scores = pred_obj_scores |
| | self.pred_obj_scores_mlp = pred_obj_scores_mlp |
| | self.fixed_no_obj_ptr = fixed_no_obj_ptr |
| | self.soft_no_obj_ptr = soft_no_obj_ptr |
| | if self.fixed_no_obj_ptr: |
| | assert self.pred_obj_scores |
| | assert self.use_obj_ptrs_in_encoder |
| | if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: |
| | self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) |
| | trunc_normal_(self.no_obj_ptr, std=0.02) |
| | self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj |
| |
|
| | self._build_sam_heads() |
| | self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond |
| | self.max_cond_frames_in_attn = max_cond_frames_in_attn |
| |
|
| | |
| | if compile_image_encoder: |
| | |
| | print( |
| | "Image encoder compilation is enabled. First forward pass will be slow." |
| | ) |
| | self.image_encoder.forward = torch.compile( |
| | self.image_encoder.forward, |
| | mode="max-autotune", |
| | fullgraph=True, |
| | dynamic=False, |
| | ) |
| |
|
| | @property |
| | def device(self): |
| | return next(self.parameters()).device |
| |
|
| | def forward(self, *args, **kwargs): |
| | raise NotImplementedError( |
| | "Please use the corresponding methods in SAM2VideoPredictor for inference." |
| | "See notebooks/video_predictor_example.ipynb for an example." |
| | ) |
| |
|
| | def _build_sam_heads(self): |
| | """Build SAM-style prompt encoder and mask decoder.""" |
| | self.sam_prompt_embed_dim = self.hidden_dim |
| | self.sam_image_embedding_size = self.image_size // self.backbone_stride |
| |
|
| | |
| | |
| | self.sam_prompt_encoder = PromptEncoder( |
| | embed_dim=self.sam_prompt_embed_dim, |
| | image_embedding_size=( |
| | self.sam_image_embedding_size, |
| | self.sam_image_embedding_size, |
| | ), |
| | input_image_size=(self.image_size, self.image_size), |
| | mask_in_chans=16, |
| | ) |
| | self.sam_mask_decoder = MaskDecoder( |
| | num_multimask_outputs=3, |
| | transformer=TwoWayTransformer( |
| | depth=2, |
| | embedding_dim=self.sam_prompt_embed_dim, |
| | mlp_dim=2048, |
| | num_heads=8, |
| | ), |
| | transformer_dim=self.sam_prompt_embed_dim, |
| | iou_head_depth=3, |
| | iou_head_hidden_dim=256, |
| | use_high_res_features=self.use_high_res_features_in_sam, |
| | iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, |
| | pred_obj_scores=self.pred_obj_scores, |
| | pred_obj_scores_mlp=self.pred_obj_scores_mlp, |
| | use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, |
| | **(self.sam_mask_decoder_extra_args or {}), |
| | ) |
| | if self.use_obj_ptrs_in_encoder: |
| | |
| | self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) |
| | if self.use_mlp_for_obj_ptr_proj: |
| | self.obj_ptr_proj = MLP( |
| | self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 |
| | ) |
| | else: |
| | self.obj_ptr_proj = torch.nn.Identity() |
| | if self.proj_tpos_enc_in_obj_ptrs: |
| | |
| | |
| | self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) |
| | else: |
| | self.obj_ptr_tpos_proj = torch.nn.Identity() |
| |
|
| | def _forward_sam_heads( |
| | self, |
| | backbone_features, |
| | point_inputs=None, |
| | mask_inputs=None, |
| | high_res_features=None, |
| | multimask_output=False, |
| | ): |
| | """ |
| | Forward SAM prompt encoders and mask heads. |
| | |
| | Inputs: |
| | - backbone_features: image features of [B, C, H, W] shape |
| | - point_inputs: a dictionary with "point_coords" and "point_labels", where |
| | 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the |
| | absolute pixel-unit coordinate in (x, y) format of the P input points |
| | 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means |
| | positive clicks, 0 means negative clicks, and -1 means padding |
| | - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the |
| | same spatial size as the image. |
| | - high_res_features: either 1) None or 2) or a list of length 2 containing |
| | two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, |
| | which will be used as high-resolution feature maps for SAM decoder. |
| | - multimask_output: if it's True, we output 3 candidate masks and their 3 |
| | corresponding IoU estimates, and if it's False, we output only 1 mask and |
| | its corresponding IoU estimate. |
| | |
| | Outputs: |
| | - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if |
| | `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM |
| | output mask logits (before sigmoid) for the low-resolution masks, with 4x |
| | the resolution (1/4 stride) of the input backbone_features. |
| | - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 |
| | if `multimask_output=True` and M = 1 if `multimask_output=False`), |
| | upsampled from the low-resolution masks, with shape size as the image |
| | (stride is 1 pixel). |
| | - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 |
| | if `multimask_output=False`), the estimated IoU of each output mask. |
| | - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. |
| | If `multimask_output=True`, it's the mask with the highest IoU estimate. |
| | If `multimask_output=False`, it's the same as `low_res_multimasks`. |
| | - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. |
| | If `multimask_output=True`, it's the mask with the highest IoU estimate. |
| | If `multimask_output=False`, it's the same as `high_res_multimasks`. |
| | - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted |
| | based on the output token from the SAM mask decoder. |
| | """ |
| | B = backbone_features.size(0) |
| | device = backbone_features.device |
| | assert backbone_features.size(1) == self.sam_prompt_embed_dim |
| | assert backbone_features.size(2) == self.sam_image_embedding_size |
| | assert backbone_features.size(3) == self.sam_image_embedding_size |
| |
|
| | |
| | if point_inputs is not None: |
| | sam_point_coords = point_inputs["point_coords"] |
| | sam_point_labels = point_inputs["point_labels"] |
| | assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B |
| | else: |
| | |
| | sam_point_coords = torch.zeros(B, 1, 2, device=device) |
| | sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) |
| |
|
| | |
| | if mask_inputs is not None: |
| | |
| | |
| | assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) |
| | if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: |
| | sam_mask_prompt = F.interpolate( |
| | mask_inputs, |
| | size=self.sam_prompt_encoder.mask_input_size, |
| | align_corners=False, |
| | mode="bilinear", |
| | antialias=True, |
| | ) |
| | else: |
| | sam_mask_prompt = mask_inputs |
| | else: |
| | |
| | |
| | sam_mask_prompt = None |
| |
|
| | sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( |
| | points=(sam_point_coords, sam_point_labels), |
| | boxes=None, |
| | masks=sam_mask_prompt, |
| | ) |
| | ( |
| | low_res_multimasks, |
| | ious, |
| | sam_output_tokens, |
| | object_score_logits, |
| | ) = self.sam_mask_decoder( |
| | image_embeddings=backbone_features, |
| | image_pe=self.sam_prompt_encoder.get_dense_pe(), |
| | sparse_prompt_embeddings=sparse_embeddings, |
| | dense_prompt_embeddings=dense_embeddings, |
| | multimask_output=multimask_output, |
| | repeat_image=False, |
| | high_res_features=high_res_features, |
| | ) |
| | |
| | |
| | |
| |
|
| | if self.pred_obj_scores: |
| | is_obj_appearing = object_score_logits > 0 |
| |
|
| | |
| | |
| | low_res_multimasks = torch.where( |
| | is_obj_appearing[:, None, None], |
| | low_res_multimasks, |
| | NO_OBJ_SCORE, |
| | ) |
| |
|
| | |
| | |
| | low_res_multimasks = low_res_multimasks |
| | high_res_multimasks = F.interpolate( |
| | low_res_multimasks, |
| | size=(self.image_size, self.image_size), |
| | mode="bilinear", |
| | align_corners=False, |
| | ) |
| |
|
| | sam_output_token = sam_output_tokens[:, 0] |
| | if multimask_output: |
| | |
| | best_iou_inds = torch.argmax(ious, dim=-1) |
| | batch_inds = torch.arange(B, device=device) |
| | low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) |
| | high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) |
| | if sam_output_tokens.size(1) > 1: |
| | sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] |
| | else: |
| | low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks |
| |
|
| | |
| | obj_ptr = self.obj_ptr_proj(sam_output_token) |
| | if self.pred_obj_scores: |
| | |
| | if self.soft_no_obj_ptr: |
| | |
| | assert not self.teacher_force_obj_scores_for_mem |
| | lambda_is_obj_appearing = object_score_logits.sigmoid() |
| | else: |
| | lambda_is_obj_appearing = is_obj_appearing.float() |
| |
|
| | if self.fixed_no_obj_ptr: |
| | obj_ptr = lambda_is_obj_appearing * obj_ptr |
| | obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr |
| |
|
| | return ( |
| | low_res_multimasks, |
| | high_res_multimasks, |
| | ious, |
| | low_res_masks, |
| | high_res_masks, |
| | obj_ptr, |
| | object_score_logits, |
| | ) |
| |
|
| | def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): |
| | """ |
| | Directly turn binary `mask_inputs` into a output mask logits without using SAM. |
| | (same input and output shapes as in _forward_sam_heads above). |
| | """ |
| | |
| | out_scale, out_bias = 20.0, -10.0 |
| | |
| | high_res_masks = mask_inputs * out_scale + out_bias |
| | low_res_masks = F.interpolate( |
| | high_res_masks, |
| | size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), |
| | align_corners=False, |
| | mode="bilinear", |
| | antialias=True, |
| | ) |
| | |
| | ious = mask_inputs.new_ones(mask_inputs.size(0), 1) |
| | if not self.use_obj_ptrs_in_encoder: |
| | |
| | obj_ptr = torch.zeros( |
| | mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device |
| | ) |
| | else: |
| | |
| | _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( |
| | backbone_features=backbone_features, |
| | mask_inputs=self.mask_downsample(mask_inputs), |
| | high_res_features=high_res_features, |
| | ) |
| | |
| | |
| | |
| | is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1) |
| | is_obj_appearing = is_obj_appearing[..., None] |
| | lambda_is_obj_appearing = is_obj_appearing |
| | object_score_logits = out_scale * lambda_is_obj_appearing + out_bias |
| | if self.pred_obj_scores: |
| | if self.fixed_no_obj_ptr: |
| | obj_ptr = lambda_is_obj_appearing * obj_ptr |
| | obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr |
| |
|
| | return ( |
| | low_res_masks, |
| | high_res_masks, |
| | ious, |
| | low_res_masks, |
| | high_res_masks, |
| | obj_ptr, |
| | object_score_logits, |
| | ) |
| |
|
| | def forward_image(self, img_batch: torch.Tensor): |
| | """Get the image feature on the input batch.""" |
| | backbone_out = self.image_encoder(img_batch) |
| | if self.use_high_res_features_in_sam: |
| | |
| | |
| | backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( |
| | backbone_out["backbone_fpn"][0] |
| | ) |
| | backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( |
| | backbone_out["backbone_fpn"][1] |
| | ) |
| | return backbone_out |
| |
|
| | def _prepare_backbone_features(self, backbone_out): |
| | """Prepare and flatten visual features.""" |
| | backbone_out = backbone_out.copy() |
| | assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) |
| | assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels |
| |
|
| | feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] |
| | vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] |
| |
|
| | feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] |
| | |
| | vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] |
| | vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] |
| |
|
| | return backbone_out, vision_feats, vision_pos_embeds, feat_sizes |
| |
|
| | def _prepare_memory_conditioned_features( |
| | self, |
| | frame_idx, |
| | is_init_cond_frame, |
| | current_vision_feats, |
| | current_vision_pos_embeds, |
| | feat_sizes, |
| | output_dict, |
| | num_frames, |
| | track_in_reverse=False, |
| | ): |
| | """Fuse the current frame's visual feature map with previous memory.""" |
| | B = current_vision_feats[-1].size(1) |
| | C = self.hidden_dim |
| | H, W = feat_sizes[-1] |
| | device = current_vision_feats[-1].device |
| | |
| | |
| | if self.num_maskmem == 0: |
| | pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) |
| | return pix_feat |
| |
|
| | num_obj_ptr_tokens = 0 |
| | |
| | if not is_init_cond_frame: |
| | |
| | to_cat_memory, to_cat_memory_pos_embed = [], [] |
| | |
| | |
| | assert len(output_dict["cond_frame_outputs"]) > 0 |
| | |
| | cond_outputs = output_dict["cond_frame_outputs"] |
| | selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( |
| | frame_idx, cond_outputs, self.max_cond_frames_in_attn |
| | ) |
| | t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] |
| | |
| | |
| | |
| | |
| | r = self.memory_temporal_stride_for_eval |
| | for t_pos in range(1, self.num_maskmem): |
| | t_rel = self.num_maskmem - t_pos |
| | if t_rel == 1: |
| | |
| | if not track_in_reverse: |
| | |
| | prev_frame_idx = frame_idx - t_rel |
| | else: |
| | |
| | prev_frame_idx = frame_idx + t_rel |
| | else: |
| | |
| | if not track_in_reverse: |
| | |
| | |
| | prev_frame_idx = ((frame_idx - 2) // r) * r |
| | |
| | prev_frame_idx = prev_frame_idx - (t_rel - 2) * r |
| | else: |
| | |
| | |
| | prev_frame_idx = -(-(frame_idx + 2) // r) * r |
| | |
| | prev_frame_idx = prev_frame_idx + (t_rel - 2) * r |
| | out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) |
| | if out is None: |
| | |
| | |
| | out = unselected_cond_outputs.get(prev_frame_idx, None) |
| | t_pos_and_prevs.append((t_pos, out)) |
| |
|
| | for t_pos, prev in t_pos_and_prevs: |
| | if prev is None: |
| | continue |
| | |
| | |
| | feats = prev["maskmem_features"].to(device, non_blocking=True) |
| | to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) |
| | |
| | maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) |
| | maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) |
| | |
| | maskmem_enc = ( |
| | maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] |
| | ) |
| | to_cat_memory_pos_embed.append(maskmem_enc) |
| |
|
| | |
| | if self.use_obj_ptrs_in_encoder: |
| | max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) |
| | |
| | |
| | if not self.training and self.only_obj_ptrs_in_the_past_for_eval: |
| | ptr_cond_outputs = { |
| | t: out |
| | for t, out in selected_cond_outputs.items() |
| | if (t >= frame_idx if track_in_reverse else t <= frame_idx) |
| | } |
| | else: |
| | ptr_cond_outputs = selected_cond_outputs |
| | pos_and_ptrs = [ |
| | |
| | (abs(frame_idx - t), out["obj_ptr"]) |
| | for t, out in ptr_cond_outputs.items() |
| | ] |
| | |
| | for t_diff in range(1, max_obj_ptrs_in_encoder): |
| | t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff |
| | if t < 0 or (num_frames is not None and t >= num_frames): |
| | break |
| | out = output_dict["non_cond_frame_outputs"].get( |
| | t, unselected_cond_outputs.get(t, None) |
| | ) |
| | if out is not None: |
| | pos_and_ptrs.append((t_diff, out["obj_ptr"])) |
| | |
| | if len(pos_and_ptrs) > 0: |
| | pos_list, ptrs_list = zip(*pos_and_ptrs) |
| | |
| | obj_ptrs = torch.stack(ptrs_list, dim=0) |
| | |
| | |
| | if self.add_tpos_enc_to_obj_ptrs: |
| | t_diff_max = max_obj_ptrs_in_encoder - 1 |
| | tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim |
| | obj_pos = torch.tensor(pos_list, device=device) |
| | obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) |
| | obj_pos = self.obj_ptr_tpos_proj(obj_pos) |
| | obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) |
| | else: |
| | obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) |
| | if self.mem_dim < C: |
| | |
| | obj_ptrs = obj_ptrs.reshape( |
| | -1, B, C // self.mem_dim, self.mem_dim |
| | ) |
| | obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) |
| | obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) |
| | to_cat_memory.append(obj_ptrs) |
| | to_cat_memory_pos_embed.append(obj_pos) |
| | num_obj_ptr_tokens = obj_ptrs.shape[0] |
| | else: |
| | num_obj_ptr_tokens = 0 |
| | else: |
| | |
| | if self.directly_add_no_mem_embed: |
| | |
| | pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed |
| | pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) |
| | return pix_feat_with_mem |
| |
|
| | |
| | to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] |
| | to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] |
| |
|
| | |
| | memory = torch.cat(to_cat_memory, dim=0) |
| | memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) |
| |
|
| | pix_feat_with_mem = self.memory_attention( |
| | curr=current_vision_feats, |
| | curr_pos=current_vision_pos_embeds, |
| | memory=memory, |
| | memory_pos=memory_pos_embed, |
| | num_obj_ptr_tokens=num_obj_ptr_tokens, |
| | ) |
| | |
| | pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) |
| | return pix_feat_with_mem |
| |
|
| | def _encode_new_memory( |
| | self, |
| | current_vision_feats, |
| | feat_sizes, |
| | pred_masks_high_res, |
| | is_mask_from_pts, |
| | ): |
| | """Encode the current image and its prediction into a memory feature.""" |
| | B = current_vision_feats[-1].size(1) |
| | C = self.hidden_dim |
| | H, W = feat_sizes[-1] |
| | |
| | pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) |
| | if self.non_overlap_masks_for_mem_enc and not self.training: |
| | |
| | |
| | |
| | pred_masks_high_res = self._apply_non_overlapping_constraints( |
| | pred_masks_high_res |
| | ) |
| | |
| | binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts |
| | if binarize and not self.training: |
| | mask_for_mem = (pred_masks_high_res > 0) |
| | else: |
| | |
| | mask_for_mem = torch.sigmoid(pred_masks_high_res) |
| | |
| | if self.sigmoid_scale_for_mem_enc != 1.0: |
| | mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc |
| | if self.sigmoid_bias_for_mem_enc != 0.0: |
| | mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc |
| | maskmem_out = self.memory_encoder( |
| | pix_feat, mask_for_mem, skip_mask_sigmoid=True |
| | ) |
| | maskmem_features = maskmem_out["vision_features"] |
| | maskmem_pos_enc = maskmem_out["vision_pos_enc"] |
| |
|
| | return maskmem_features, maskmem_pos_enc |
| |
|
| | def track_step( |
| | self, |
| | frame_idx, |
| | is_init_cond_frame, |
| | current_vision_feats, |
| | current_vision_pos_embeds, |
| | feat_sizes, |
| | point_inputs, |
| | mask_inputs, |
| | output_dict, |
| | num_frames, |
| | track_in_reverse=False, |
| | |
| | |
| | |
| | |
| | |
| | run_mem_encoder=True, |
| | |
| | prev_sam_mask_logits=None, |
| | ): |
| | current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} |
| | |
| | if len(current_vision_feats) > 1: |
| | high_res_features = [ |
| | x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) |
| | for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) |
| | ] |
| | else: |
| | high_res_features = None |
| | if mask_inputs is not None and self.use_mask_input_as_output_without_sam: |
| | |
| | |
| | pix_feat = current_vision_feats[-1].permute(1, 2, 0) |
| | pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) |
| | sam_outputs = self._use_mask_as_output( |
| | pix_feat, high_res_features, mask_inputs |
| | ) |
| | else: |
| | |
| | pix_feat_with_mem = self._prepare_memory_conditioned_features( |
| | frame_idx=frame_idx, |
| | is_init_cond_frame=is_init_cond_frame, |
| | current_vision_feats=current_vision_feats[-1:], |
| | current_vision_pos_embeds=current_vision_pos_embeds[-1:], |
| | feat_sizes=feat_sizes[-1:], |
| | output_dict=output_dict, |
| | num_frames=num_frames, |
| | track_in_reverse=track_in_reverse, |
| | ) |
| | |
| | |
| | |
| | |
| | if prev_sam_mask_logits is not None: |
| | assert point_inputs is not None and mask_inputs is None |
| | mask_inputs = prev_sam_mask_logits |
| | multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) |
| | sam_outputs = self._forward_sam_heads( |
| | backbone_features=pix_feat_with_mem, |
| | point_inputs=point_inputs, |
| | mask_inputs=mask_inputs, |
| | high_res_features=high_res_features, |
| | multimask_output=multimask_output, |
| | ) |
| | ( |
| | _, |
| | _, |
| | _, |
| | low_res_masks, |
| | high_res_masks, |
| | obj_ptr, |
| | _, |
| | ) = sam_outputs |
| |
|
| | current_out["pred_masks"] = low_res_masks |
| | current_out["pred_masks_high_res"] = high_res_masks |
| | current_out["obj_ptr"] = obj_ptr |
| |
|
| | |
| | |
| | if run_mem_encoder and self.num_maskmem > 0: |
| | high_res_masks_for_mem_enc = high_res_masks |
| | maskmem_features, maskmem_pos_enc = self._encode_new_memory( |
| | current_vision_feats=current_vision_feats, |
| | feat_sizes=feat_sizes, |
| | pred_masks_high_res=high_res_masks_for_mem_enc, |
| | is_mask_from_pts=(point_inputs is not None), |
| | ) |
| | current_out["maskmem_features"] = maskmem_features |
| | current_out["maskmem_pos_enc"] = maskmem_pos_enc |
| | else: |
| | current_out["maskmem_features"] = None |
| | current_out["maskmem_pos_enc"] = None |
| |
|
| | return current_out |
| |
|
| | def track_step_embed( |
| | self, |
| | frame_idx, |
| | is_init_cond_frame, |
| | current_vision_feats, |
| | current_vision_pos_embeds, |
| | feat_sizes, |
| | box_embed, |
| | point_inputs, |
| | mask_inputs, |
| | output_dict, |
| | num_frames, |
| | track_in_reverse=False, |
| | |
| | |
| | |
| | |
| | |
| | run_mem_encoder=True, |
| | |
| | prev_sam_mask_logits=None, |
| | ): |
| | current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} |
| | |
| | if len(current_vision_feats) > 1: |
| | high_res_features = [ |
| | x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) |
| | for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) |
| | ] |
| | else: |
| | high_res_features = None |
| | if mask_inputs is not None and self.use_mask_input_as_output_without_sam: |
| | |
| | |
| | pix_feat = current_vision_feats[-1].permute(1, 2, 0) |
| | pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) |
| | sam_outputs = self._use_mask_as_output( |
| | pix_feat, high_res_features, mask_inputs |
| | ) |
| | else: |
| | |
| | pix_feat_with_mem = self._prepare_memory_conditioned_features( |
| | frame_idx=frame_idx, |
| | is_init_cond_frame=is_init_cond_frame, |
| | current_vision_feats=current_vision_feats[-1:], |
| | current_vision_pos_embeds=current_vision_pos_embeds[-1:], |
| | feat_sizes=feat_sizes[-1:], |
| | output_dict=output_dict, |
| | num_frames=num_frames, |
| | track_in_reverse=track_in_reverse, |
| | ) |
| | |
| | |
| | |
| | |
| | if prev_sam_mask_logits is not None: |
| | assert point_inputs is not None and mask_inputs is None |
| | mask_inputs = prev_sam_mask_logits |
| | multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) |
| | sam_outputs = self._forward_sam_heads_embed( |
| | backbone_features=pix_feat_with_mem, |
| | point_inputs=point_inputs, |
| | box_embed=box_embed, |
| | mask_inputs=mask_inputs, |
| | high_res_features=high_res_features, |
| | multimask_output=multimask_output, |
| | ) |
| | ( |
| | _, |
| | _, |
| | _, |
| | low_res_masks, |
| | high_res_masks, |
| | obj_ptr, |
| | _, |
| | ) = sam_outputs |
| |
|
| | current_out["pred_masks"] = low_res_masks |
| | current_out["pred_masks_high_res"] = high_res_masks |
| | current_out["obj_ptr"] = obj_ptr |
| |
|
| | |
| | |
| | if run_mem_encoder and self.num_maskmem > 0: |
| | high_res_masks_for_mem_enc = high_res_masks |
| | maskmem_features, maskmem_pos_enc = self._encode_new_memory( |
| | current_vision_feats=current_vision_feats, |
| | feat_sizes=feat_sizes, |
| | pred_masks_high_res=high_res_masks_for_mem_enc, |
| | is_mask_from_pts=(point_inputs is not None), |
| | ) |
| | current_out["maskmem_features"] = maskmem_features |
| | current_out["maskmem_pos_enc"] = maskmem_pos_enc |
| | else: |
| | current_out["maskmem_features"] = None |
| | current_out["maskmem_pos_enc"] = None |
| |
|
| | return current_out |
| |
|
| |
|
| | def _forward_sam_heads_embed( |
| | self, |
| | backbone_features, |
| | point_inputs=None, |
| | box_embed = None, |
| | mask_inputs=None, |
| | high_res_features=None, |
| | multimask_output=False, |
| | ): |
| | """ |
| | Forward SAM prompt encoders and mask heads. |
| | |
| | Inputs: |
| | - backbone_features: image features of [B, C, H, W] shape |
| | - point_inputs: a dictionary with "point_coords" and "point_labels", where |
| | 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the |
| | absolute pixel-unit coordinate in (x, y) format of the P input points |
| | 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means |
| | positive clicks, 0 means negative clicks, and -1 means padding |
| | - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the |
| | same spatial size as the image. |
| | - high_res_features: either 1) None or 2) or a list of length 2 containing |
| | two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, |
| | which will be used as high-resolution feature maps for SAM decoder. |
| | - multimask_output: if it's True, we output 3 candidate masks and their 3 |
| | corresponding IoU estimates, and if it's False, we output only 1 mask and |
| | its corresponding IoU estimate. |
| | |
| | Outputs: |
| | - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if |
| | `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM |
| | output mask logits (before sigmoid) for the low-resolution masks, with 4x |
| | the resolution (1/4 stride) of the input backbone_features. |
| | - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 |
| | if `multimask_output=True` and M = 1 if `multimask_output=False`), |
| | upsampled from the low-resolution masks, with shape size as the image |
| | (stride is 1 pixel). |
| | - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 |
| | if `multimask_output=False`), the estimated IoU of each output mask. |
| | - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. |
| | If `multimask_output=True`, it's the mask with the highest IoU estimate. |
| | If `multimask_output=False`, it's the same as `low_res_multimasks`. |
| | - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. |
| | If `multimask_output=True`, it's the mask with the highest IoU estimate. |
| | If `multimask_output=False`, it's the same as `high_res_multimasks`. |
| | - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted |
| | based on the output token from the SAM mask decoder. |
| | """ |
| | B = backbone_features.size(0) |
| | device = backbone_features.device |
| | |
| | |
| | assert backbone_features.size(1) == self.sam_prompt_embed_dim |
| | assert backbone_features.size(2) == self.sam_image_embedding_size |
| | assert backbone_features.size(3) == self.sam_image_embedding_size |
| |
|
| | |
| | if point_inputs is not None: |
| | sam_point_coords = point_inputs["point_coords"] |
| | sam_point_labels = point_inputs["point_labels"] |
| | assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B |
| | else: |
| | |
| | sam_point_coords = torch.zeros(B, 1, 2, dtype=torch.bfloat16 , device=device) |
| | sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) |
| |
|
| | |
| | if mask_inputs is not None: |
| | |
| | |
| | assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) |
| | if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: |
| | sam_mask_prompt = F.interpolate( |
| | |
| | mask_inputs, |
| | size=self.sam_prompt_encoder.mask_input_size, |
| | align_corners=False, |
| | mode="bilinear", |
| | antialias=True, |
| | ) |
| | else: |
| | sam_mask_prompt = mask_inputs |
| | else: |
| | |
| | |
| | sam_mask_prompt = None |
| | sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( |
| | points=(sam_point_coords, sam_point_labels), |
| | boxes=None, |
| | masks=sam_mask_prompt, |
| | ) |
| | if sparse_embeddings is not None and point_inputs is not None: |
| | sparse_embeddings = box_embed |
| | |
| | ( |
| | low_res_multimasks, |
| | ious, |
| | sam_output_tokens, |
| | object_score_logits, |
| | ) = self.sam_mask_decoder( |
| | image_embeddings=backbone_features, |
| | image_pe=self.sam_prompt_encoder.get_dense_pe(), |
| | sparse_prompt_embeddings=sparse_embeddings, |
| | dense_prompt_embeddings=dense_embeddings, |
| | multimask_output=multimask_output, |
| | repeat_image=False, |
| | high_res_features=high_res_features, |
| | ) |
| | |
| | |
| | |
| |
|
| | if self.pred_obj_scores: |
| | is_obj_appearing = object_score_logits > 0 |
| |
|
| | |
| | |
| | low_res_multimasks = torch.where( |
| | is_obj_appearing[:, None, None], |
| | low_res_multimasks, |
| | NO_OBJ_SCORE, |
| | ) |
| |
|
| | |
| | |
| | low_res_multimasks = low_res_multimasks |
| | high_res_multimasks = F.interpolate( |
| | low_res_multimasks, |
| | size=(self.image_size, self.image_size), |
| | mode="bilinear", |
| | align_corners=False, |
| | ) |
| |
|
| | sam_output_token = sam_output_tokens[:, 0] |
| | if multimask_output: |
| | |
| | best_iou_inds = torch.argmax(ious, dim=-1) |
| | batch_inds = torch.arange(B, device=device) |
| | low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) |
| | high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) |
| | if sam_output_tokens.size(1) > 1: |
| | sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] |
| | else: |
| | low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks |
| |
|
| | |
| | obj_ptr = self.obj_ptr_proj(sam_output_token) |
| | if self.pred_obj_scores: |
| | |
| | if self.soft_no_obj_ptr: |
| | |
| | assert not self.teacher_force_obj_scores_for_mem |
| | lambda_is_obj_appearing = object_score_logits.sigmoid() |
| | else: |
| | lambda_is_obj_appearing = is_obj_appearing.float() |
| |
|
| | if self.fixed_no_obj_ptr: |
| | obj_ptr = lambda_is_obj_appearing * obj_ptr |
| | |
| | obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr |
| |
|
| | return ( |
| | low_res_multimasks, |
| | high_res_multimasks, |
| | ious, |
| | low_res_masks, |
| | high_res_masks, |
| | obj_ptr, |
| | object_score_logits, |
| | ) |
| |
|
| |
|
| | def _use_multimask(self, is_init_cond_frame, point_inputs): |
| | """Whether to use multimask output in the SAM head.""" |
| | num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) |
| | multimask_output = ( |
| | self.multimask_output_in_sam |
| | and (is_init_cond_frame or self.multimask_output_for_tracking) |
| | and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) |
| | ) |
| | return multimask_output |
| |
|
| | def _apply_non_overlapping_constraints(self, pred_masks): |
| | """ |
| | Apply non-overlapping constraints to the object scores in pred_masks. Here we |
| | keep only the highest scoring object at each spatial location in pred_masks. |
| | """ |
| | batch_size = pred_masks.size(0) |
| | if batch_size == 1: |
| | return pred_masks |
| |
|
| | device = pred_masks.device |
| | |
| | max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) |
| | |
| | batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] |
| | keep = max_obj_inds == batch_obj_inds |
| | |
| | |
| | pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) |
| | return pred_masks |
| |
|