| | import torch |
| | import torch.nn.functional as F |
| | from .sam2_implementation.modeling.sam2_base import SAM2Base as _SAM2Base |
| |
|
class SAM2Base(_SAM2Base):
    """SAM2 base model whose SAM heads accept an optional language embedding.

    Both overridden methods take an extra ``language_embd`` argument;
    ``track_step`` forwards it to ``_forward_sam_heads``, which appends it to
    the sparse prompt embeddings handed to the SAM mask decoder.
    """

    def track_step(
        self,
        frame_idx,
        is_init_cond_frame,
        current_vision_feats,
        current_vision_pos_embeds,
        feat_sizes,
        point_inputs,
        mask_inputs,
        output_dict,
        num_frames,
        track_in_reverse=False,
        run_mem_encoder=True,
        prev_sam_mask_logits=None,
        language_embd=None,
    ):
        """Run one tracking step on a frame and collect its predictions.

        Inputs:
        - frame_idx, is_init_cond_frame, output_dict, num_frames, track_in_reverse:
          forwarded unchanged to `_prepare_memory_conditioned_features`.
        - current_vision_feats, current_vision_pos_embeds, feat_sizes: per-level
          sequences; the last entry feeds the main segmentation path and, when more
          than one level is given, the earlier entries become high-resolution
          feature maps for the SAM heads.
        - point_inputs, mask_inputs: the prompts for this frame; both are stored
          as-is in the returned dictionary.
        - run_mem_encoder: when True (and `self.num_maskmem > 0`), the predicted
          high-resolution mask is passed to `_encode_new_memory`.
        - prev_sam_mask_logits: optional mask logits used as the mask prompt of the
          SAM heads; requires `point_inputs` to be given and `mask_inputs` to be None.
        - language_embd: optional embedding forwarded to `_forward_sam_heads`; it is
          not used when the mask input is taken directly as the output.

        Outputs:
        - a dictionary with keys "point_inputs", "mask_inputs", "pred_masks",
          "pred_masks_high_res", "obj_ptr", "maskmem_features" and
          "maskmem_pos_enc" (the last two are None when the memory encoder is
          not run).
        """
        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}

        # Every level except the last is reshaped into a spatial map for the SAM
        # heads (presumably (HW, B, C) -> (B, C, H, W); confirm against the encoder).
        high_res_features = None
        if len(current_vision_feats) > 1:
            high_res_features = [
                feat.permute(1, 2, 0).view(feat.size(1), feat.size(2), *size)
                for feat, size in zip(current_vision_feats[:-1], feat_sizes[:-1])
            ]

        use_mask_directly = (
            mask_inputs is not None and self.use_mask_input_as_output_without_sam
        )
        if not use_mask_directly:
            # Condition the last-level features on memory, then run the SAM heads.
            pix_feat_with_mem = self._prepare_memory_conditioned_features(
                frame_idx=frame_idx,
                is_init_cond_frame=is_init_cond_frame,
                current_vision_feats=current_vision_feats[-1:],
                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
                feat_sizes=feat_sizes[-1:],
                output_dict=output_dict,
                num_frames=num_frames,
                track_in_reverse=track_in_reverse,
            )
            # Previously predicted logits, when given, take the place of the mask
            # prompt; they are only valid together with point prompts.
            mask_prompt = mask_inputs
            if prev_sam_mask_logits is not None:
                assert point_inputs is not None and mask_inputs is None
                mask_prompt = prev_sam_mask_logits
            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
            sam_outputs = self._forward_sam_heads(
                backbone_features=pix_feat_with_mem,
                point_inputs=point_inputs,
                mask_inputs=mask_prompt,
                high_res_features=high_res_features,
                multimask_output=multimask_output,
                language_embd=language_embd,
            )
        else:
            # The mask input is turned into the output directly, bypassing the
            # prompt encoder and mask decoder.
            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
            sam_outputs = self._use_mask_as_output(
                pix_feat, high_res_features, mask_inputs
            )

        # Only the selected masks and the object pointer are kept from the 7-tuple.
        _, _, _, low_res_masks, high_res_masks, obj_ptr, _ = sam_outputs
        current_out.update(
            pred_masks=low_res_masks,
            pred_masks_high_res=high_res_masks,
            obj_ptr=obj_ptr,
        )

        # Optionally encode the prediction into memory features.
        maskmem_features = None
        maskmem_pos_enc = None
        if run_mem_encoder and self.num_maskmem > 0:
            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
                current_vision_feats=current_vision_feats,
                feat_sizes=feat_sizes,
                pred_masks_high_res=high_res_masks,
                is_mask_from_pts=(point_inputs is not None),
            )
        current_out["maskmem_features"] = maskmem_features
        current_out["maskmem_pos_enc"] = maskmem_pos_enc

        return current_out

    def _forward_sam_heads(
        self,
        backbone_features,
        point_inputs=None,
        mask_inputs=None,
        high_res_features=None,
        multimask_output=False,
        language_embd=None,
    ):
        """
        Run the SAM prompt encoder and mask decoder on one batch of features.

        Inputs:
        - backbone_features: image features of [B, C, H, W] shape, where C must equal
          `self.sam_prompt_embed_dim` and H, W must equal `self.sam_image_embedding_size`.
        - point_inputs: None, or a dictionary with "point_coords" and "point_labels":
          1) "point_coords" has [B, P, 2] shape and float32 dtype and holds the absolute
             pixel-unit (x, y) coordinates of the P input points
          2) "point_labels" has [B, P] shape and int32 dtype, where 1 is a positive
             click, 0 is a negative click, and -1 is padding
        - mask_inputs: None, or a mask of [B, 1, H*16, W*16] shape, float or bool, with
          the same spatial size as the image; it is resized when it does not already
          match the prompt encoder's `mask_input_size`.
        - high_res_features: either None or a list of length 2 holding two feature maps
          of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, used as
          high-resolution feature maps by the SAM decoder.
        - multimask_output: if True, 3 candidate masks and their 3 IoU estimates are
          produced; if False, only 1 mask and its IoU estimate.
        - language_embd: None, or a 3-D tensor whose dim 0 and dim 2 match those of the
          sparse prompt embeddings; it is concatenated to them along dim 1.

        Outputs (a 7-tuple, in this order):
        - low_res_multimasks: [B, M, H*4, W*4] shape (M = 3 if `multimask_output=True`,
          else M = 1), the SAM output mask logits (before sigmoid) at 4x the resolution
          (1/4 stride) of the input backbone_features.
        - high_res_multimasks: [B, M, H*16, W*16] shape, the low-resolution masks
          upsampled to the image size (stride is 1 pixel).
        - ious: [B, M] shape, the estimated IoU of each output mask.
        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`:
          the one with the highest IoU estimate if `multimask_output=True`, otherwise
          `low_res_multimasks` itself.
        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in
          `high_res_multimasks`, selected the same way.
        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, projected
          from the output token of the SAM mask decoder.
        - object_score_logits: the object-score output of the SAM mask decoder; a value
          above 0 is treated as "object present".
        """
        batch_size = backbone_features.size(0)
        device = backbone_features.device
        assert backbone_features.size(1) == self.sam_prompt_embed_dim
        assert backbone_features.size(2) == self.sam_image_embedding_size
        assert backbone_features.size(3) == self.sam_image_embedding_size

        # Point prompt: without clicks, feed a single padding point (label -1).
        if point_inputs is None:
            sam_point_coords = torch.zeros(batch_size, 1, 2, device=device)
            sam_point_labels = -torch.ones(
                batch_size, 1, dtype=torch.int32, device=device
            )
        else:
            sam_point_coords = point_inputs["point_coords"]
            sam_point_labels = point_inputs["point_labels"]
            assert (
                sam_point_coords.size(0) == batch_size
                and sam_point_labels.size(0) == batch_size
            )

        # Mask prompt: resize to the prompt encoder's expected size when needed.
        sam_mask_prompt = None
        if mask_inputs is not None:
            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (
                batch_size,
                1,
            )
            expected_size = self.sam_prompt_encoder.mask_input_size
            if mask_inputs.shape[-2:] != expected_size:
                sam_mask_prompt = F.interpolate(
                    mask_inputs.float(),
                    size=expected_size,
                    align_corners=False,
                    mode="bilinear",
                    antialias=True,
                )
            else:
                sam_mask_prompt = mask_inputs

        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
            points=(sam_point_coords, sam_point_labels),
            boxes=None,
            masks=sam_mask_prompt,
        )

        # Language prompt: appended as additional sparse prompt entries along dim 1.
        if language_embd is not None:
            assert sparse_embeddings.size(0) == language_embd.size(0)
            assert sparse_embeddings.size(2) == language_embd.size(2)
            sparse_embeddings = torch.cat(
                (sparse_embeddings, language_embd), dim=1
            )

        decoder_outputs = self.sam_mask_decoder(
            image_embeddings=backbone_features,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=multimask_output,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        (
            low_res_multimasks,
            ious,
            sam_output_tokens,
            object_score_logits,
        ) = decoder_outputs
        if self.pred_obj_scores:
            # Hard object / no-object decision, used below for the object pointer.
            is_obj_appearing = object_score_logits > 0

        # Cast to float32, then upsample the mask logits to the image size.
        low_res_multimasks = low_res_multimasks.float()
        high_res_multimasks = F.interpolate(
            low_res_multimasks,
            size=(self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )

        sam_output_token = sam_output_tokens[:, 0]
        if not multimask_output:
            low_res_masks = low_res_multimasks
            high_res_masks = high_res_multimasks
        else:
            # Keep, per batch element, the candidate with the highest predicted IoU.
            best_iou_inds = torch.argmax(ious, dim=-1)
            batch_inds = torch.arange(batch_size, device=device)
            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds]
            low_res_masks = low_res_masks.unsqueeze(1)
            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds]
            high_res_masks = high_res_masks.unsqueeze(1)
            if sam_output_tokens.size(1) > 1:
                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]

        # Project the decoder output token into the object pointer.
        obj_ptr = self.obj_ptr_proj(sam_output_token)
        if self.pred_obj_scores:
            # Weight of the "object present" case: soft (sigmoid) or hard (0 / 1).
            if self.soft_no_obj_ptr:
                assert not self.teacher_force_obj_scores_for_mem
                obj_weight = object_score_logits.sigmoid()
            else:
                obj_weight = is_obj_appearing.float()

            if self.fixed_no_obj_ptr:
                obj_ptr = obj_weight * obj_ptr
            obj_ptr = obj_ptr + (1 - obj_weight) * self.no_obj_ptr

        return (
            low_res_multimasks,
            high_res_multimasks,
            ious,
            low_res_masks,
            high_res_masks,
            obj_ptr,
            object_score_logits,
        )
| |
|