| """ |
| SAM2 Base Segmenter |
| Adapted from MatAnyone demo |
| """ |
|
|
| import sys |
| sys.path.append("/home/cvlab19/project/samuel/CVPR/sam2") |
|
|
| import torch |
| import numpy as np |
| from sam2.build_sam import build_sam2_video_predictor |
|
|
|
|
class BaseSegmenter:
    """Thin wrapper that owns a SAM2 video predictor and the current image.

    The predictor is built once in ``__init__`` via
    ``build_sam2_video_predictor``. The image to segment is supplied with
    ``set_image`` and cleared with ``reset_image``.
    """

    def __init__(self, SAM_checkpoint, model_type, device):
        """
        Initialize SAM2 segmenter

        Args:
            SAM_checkpoint: Path to SAM2 checkpoint
            model_type: SAM2 model config file
            device: Device to run on
        """
        self.device = device
        self.model_type = model_type

        self.sam_predictor = build_sam2_video_predictor(
            config_file=model_type,
            ckpt_path=SAM_checkpoint,
            device=device,
        )

        # NOTE: "orignal" is a misspelling of "original". The name is kept
        # unchanged so that any existing code reading this attribute keeps
        # working.
        self.orignal_image = None
        # Starts as None and is cleared by reset_image(); no method in this
        # class assigns it a value.
        self.inference_state = None

    def set_image(self, image: np.ndarray):
        """Set the current image for segmentation"""
        self.orignal_image = image

    def reset_image(self):
        """Reset the current image and the stored inference state"""
        self.orignal_image = None
        self.inference_state = None

    def predict(self, prompts, prompt_type, multimask=True):
        """
        Predict mask from prompts

        NOTE: this is currently a placeholder. ``prompts``, ``prompt_type``
        and ``multimask`` are accepted but not used, the SAM2 predictor is
        not invoked, and a single empty mask sized to the current image is
        returned.

        Args:
            prompts: Dictionary with point_coords, point_labels, mask_input
            prompt_type: 'point' or 'both'
            multimask: Whether to return multiple masks

        Returns:
            masks, scores, logits:
                masks  - bool array of shape (1, H, W), all False
                scores - float array of shape (1,), containing 1.0
                logits - float32 array of shape (1, H, W), all zeros
            where H and W are the first two dimensions of the current image.

        Raises:
            RuntimeError: If no image has been set with ``set_image``.
        """
        # Fail with an actionable message instead of an AttributeError on
        # ``None.shape`` when predict() is called before set_image() or
        # after reset_image().
        if self.orignal_image is None:
            raise RuntimeError(
                "An image must be set with set_image() before calling predict()."
            )

        h, w = self.orignal_image.shape[:2]
        dummy_mask = np.zeros((h, w), dtype=bool)
        dummy_score = np.array([1.0])
        dummy_logit = np.zeros((h, w), dtype=np.float32)

        return np.array([dummy_mask]), dummy_score, np.array([dummy_logit])
|
|