"""
SAM2 loader with Hugging Face Hub integration.

Provides the SAM2Predictor class with memory management and optimization
features. Uses Hugging Face Hub models instead of direct downloads.
(Enhanced logging and exception safety.)
"""
| |
|
| | import os |
| | import gc |
| | import torch |
| | import logging |
| | import numpy as np |
| | from pathlib import Path |
| | from typing import Optional, Any, Dict, List, Tuple |
| |
|
# Module-level logging setup.
# NOTE(review): calling basicConfig() at import time is a side effect that can
# override the host application's logging configuration — consider leaving
# configuration to the program entry point. Left in place to preserve behavior.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
class SAM2Predictor:
    """
    T4-optimized SAM2 video predictor wrapper with memory management.

    Downloads model checkpoints from the Hugging Face Hub (with a fallback to
    local ``./checkpoints`` files), builds the SAM2 video predictor, applies
    fp16 + channels_last optimizations suited to a T4 GPU, and exposes thin
    wrappers around the predictor's video-segmentation API plus CUDA memory
    helpers.
    """

    # Lookup tables for the supported model sizes. Hoisted to class level so
    # they are built once instead of on every call.
    _HF_REPOS = {
        "small": "facebook/sam2-hiera-small",
        "base": "facebook/sam2-hiera-base-plus",
        "large": "facebook/sam2-hiera-large",
    }
    _HF_FILENAMES = {
        "small": "sam2_hiera_small.pt",
        "base": "sam2_hiera_base_plus.pt",
        "large": "sam2_hiera_large.pt",
    }
    _MODEL_CONFIGS = {
        "small": "sam2_hiera_s.yaml",
        "base": "sam2_hiera_b+.yaml",
        "large": "sam2_hiera_l.yaml",
    }

    def __init__(self, device: torch.device, model_size: str = "small"):
        """Create the predictor and eagerly load the SAM2 model.

        Args:
            device: torch device the model runs on (CUDA expected for T4 paths).
            model_size: one of "small", "base", "large".

        Raises:
            RuntimeError: if sam2 is not installed or no checkpoint is available.
        """
        logger.info("[SAM2Predictor.__init__] device=%s, model_size=%s", device, model_size)
        self.device = device
        self.model_size = model_size
        self.predictor = None  # sam2 video predictor, set by _load_predictor()
        self.model = None      # underlying nn.Module, set by _optimize_for_t4()
        self._load_predictor()

    def _load_predictor(self):
        """Load SAM2 predictor with Hugging Face Hub integration.

        Raises:
            RuntimeError: if the sam2 package is missing or no checkpoint
                could be obtained from the Hub or a local fallback.
        """
        try:
            logger.info("[SAM2Predictor._load_predictor] Loading SAM2 predictor...")
            # Imported lazily so this module can be imported without sam2 installed.
            from sam2.build_sam import build_sam2_video_predictor

            checkpoint_path = self._get_hf_checkpoint()
            if not checkpoint_path:
                logger.error("Failed to get SAM2 %s checkpoint from HF Hub", self.model_size)
                raise RuntimeError(f"Failed to get SAM2 {self.model_size} checkpoint from HF Hub")

            model_cfg = self._get_model_config()
            logger.info("[SAM2Predictor._load_predictor] Using model_cfg: %s", model_cfg)

            self.predictor = build_sam2_video_predictor(model_cfg, checkpoint_path, device=self.device)
            self._optimize_for_t4()
            logger.info("SAM2 %s predictor loaded successfully from HF Hub", self.model_size)
        except ImportError as e:
            logger.error("SAM2 import failed: %s", e)
            # Chain the original ImportError so the root cause stays visible.
            raise RuntimeError("SAM2 not available - check sam2 installation") from e
        except Exception as e:
            logger.error("SAM2 loading failed: %s", e, exc_info=True)
            raise

    def _get_hf_checkpoint(self) -> Optional[str]:
        """Download the checkpoint from the Hugging Face Hub.

        Returns:
            Local filesystem path of the checkpoint, or None if neither the
            Hub download nor the local fallback produced a file.
        """
        try:
            logger.info("[SAM2Predictor._get_hf_checkpoint] Downloading checkpoint...")
            from huggingface_hub import hf_hub_download

            if self.model_size not in self._HF_REPOS:
                logger.error("Unknown model size: %s", self.model_size)
                return None
            repo_id = self._HF_REPOS[self.model_size]
            filename = self._HF_FILENAMES[self.model_size]
            logger.info("Downloading SAM2 %s from HF Hub: %s", self.model_size, repo_id)
            checkpoint_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                cache_dir=None,        # use the default HF cache directory
                force_download=False,  # reuse a previously cached file
                token=None,            # public repos need no auth token
            )
            logger.info("SAM2 checkpoint downloaded to: %s", checkpoint_path)
            return checkpoint_path
        except Exception as e:
            # Any Hub failure (network, auth, missing file) falls back to disk.
            logger.error("HF Hub download failed: %s", e)
            return self._fallback_local_checkpoint()

    def _fallback_local_checkpoint(self) -> Optional[str]:
        """Fallback to a local checkpoint file under ./checkpoints.

        Returns:
            The local path if it exists, otherwise None.
        """
        try:
            checkpoint_path = f"./checkpoints/sam2_hiera_{self.model_size}.pt"
            if Path(checkpoint_path).exists():
                logger.info("Using local checkpoint: %s", checkpoint_path)
                return checkpoint_path
            logger.error("Local checkpoint not found: %s", checkpoint_path)
            return None
        except Exception as e:
            logger.error("Local checkpoint fallback failed: %s", e)
            return None

    def _get_model_config(self) -> str:
        """Return the SAM2 config filename for the configured model size.

        Unknown sizes fall back to the "small" config.
        """
        cfg = self._MODEL_CONFIGS.get(self.model_size, "sam2_hiera_s.yaml")
        logger.info("[SAM2Predictor._get_model_config] Returning config: %s", cfg)
        return cfg

    def _optimize_for_t4(self):
        """Apply T4-specific optimizations (fp16 + channels_last), best-effort.

        Failures are logged as warnings and never propagated: the predictor
        remains usable in its unoptimized form.
        """
        try:
            logger.info("[SAM2Predictor._optimize_for_t4] Optimizing for T4...")
            if hasattr(self.predictor, "model") and self.predictor.model is not None:
                self.model = self.predictor.model
                # fp16 halves memory use; channels_last improves tensor-core
                # throughput on T4. NOTE(review): full .half() assumes the
                # model tolerates fp16 end to end — confirm against sam2.
                self.model = self.model.half().to(self.device)
                self.model = self.model.to(memory_format=torch.channels_last)
                logger.info("SAM2: fp16 + channels_last applied for T4 optimization")
        except Exception as e:
            logger.warning("SAM2 T4 optimization warning: %s", e, exc_info=True)

    def init_state(self, video_path: str):
        """Initialize SAM2 inference state for a video.

        Args:
            video_path: path to the video (or frame directory) to segment.

        Returns:
            The predictor's opaque inference state object.

        Raises:
            RuntimeError: if the predictor was never loaded.
        """
        logger.info("[SAM2Predictor.init_state] Initializing video state for: %s", video_path)
        if self.predictor is None:
            logger.error("Predictor not loaded in init_state")
            raise RuntimeError("Predictor not loaded")
        try:
            state = self.predictor.init_state(video_path=video_path)
            logger.info("[SAM2Predictor.init_state] Video state initialized OK")
            return state
        except Exception as e:
            logger.error("Failed to initialize video state: %s", e, exc_info=True)
            raise

    def add_new_points(self, inference_state, frame_idx: int, obj_id: int,
                       points: np.ndarray, labels: np.ndarray):
        """Add click prompts for one object on one frame.

        Args:
            inference_state: state returned by init_state().
            frame_idx: frame index to annotate.
            obj_id: object identifier the points belong to.
            points: point coordinates array.
            labels: per-point labels array (foreground/background).

        Raises:
            RuntimeError: if the predictor was never loaded.
        """
        logger.info("[SAM2Predictor.add_new_points] Adding points for frame %s, obj %s", frame_idx, obj_id)
        if self.predictor is None:
            logger.error("Predictor not loaded in add_new_points")
            raise RuntimeError("Predictor not loaded")
        try:
            out = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
            )
            logger.info("[SAM2Predictor.add_new_points] Points added OK")
            return out
        except Exception as e:
            logger.error("Failed to add new points: %s", e, exc_info=True)
            raise

    def add_new_points_or_box(self, inference_state, frame_idx: int, obj_id: int,
                              points: np.ndarray, labels: np.ndarray, clear_old_points: bool = True):
        """Add point/box prompts, preferring the newer SAM2 API when present.

        Falls back to add_new_points() on older sam2 versions that lack
        add_new_points_or_box (the clear_old_points flag is then ignored).

        Raises:
            RuntimeError: if the predictor was never loaded.
        """
        logger.info("[SAM2Predictor.add_new_points_or_box] Adding points/box for frame %s, obj %s", frame_idx, obj_id)
        if self.predictor is None:
            logger.error("Predictor not loaded in add_new_points_or_box")
            raise RuntimeError("Predictor not loaded")
        try:
            if hasattr(self.predictor, 'add_new_points_or_box'):
                out = self.predictor.add_new_points_or_box(
                    inference_state=inference_state,
                    frame_idx=frame_idx,
                    obj_id=obj_id,
                    points=points,
                    labels=labels,
                    clear_old_points=clear_old_points,
                )
                logger.info("[SAM2Predictor.add_new_points_or_box] Used new API, points/box added OK")
                return out
            out = self.predictor.add_new_points(
                inference_state=inference_state,
                frame_idx=frame_idx,
                obj_id=obj_id,
                points=points,
                labels=labels,
            )
            logger.info("[SAM2Predictor.add_new_points_or_box] Used fallback, points added OK")
            return out
        except Exception as e:
            logger.error("Failed to add new points or box: %s", e, exc_info=True)
            raise

    def propagate_in_video(self, inference_state, scale: float = 1.0, **kwargs):
        """Propagate the segmentation through the video.

        Args:
            inference_state: state returned by init_state().
            scale: accepted for backward compatibility with existing callers;
                currently unused (it is not forwarded to the predictor).
            **kwargs: forwarded verbatim to predictor.propagate_in_video().

        Raises:
            RuntimeError: if the predictor was never loaded.
        """
        logger.info("[SAM2Predictor.propagate_in_video] Propagating in video...")
        if self.predictor is None:
            logger.error("Predictor not loaded in propagate_in_video")
            raise RuntimeError("Predictor not loaded")
        try:
            out = self.predictor.propagate_in_video(inference_state, **kwargs)
            logger.info("[SAM2Predictor.propagate_in_video] Propagation OK")
            return out
        except Exception as e:
            logger.error("Failed to propagate in video: %s", e, exc_info=True)
            raise

    def prune_state(self, inference_state, keep: int):
        """Best-effort pruning of cached per-frame data to bound memory use.

        Keeps only the most recent `keep` cached features and per-object point
        inputs, then releases CUDA cache. Never raises: failures are logged at
        debug level.

        NOTE(review): this assumes an attribute-style inference state
        (inference_state.cached_features / .point_inputs_per_obj). If the
        installed sam2 version stores state as a plain dict, the hasattr
        checks fail and pruning becomes a no-op — confirm against sam2.
        """
        logger.info("[SAM2Predictor.prune_state] Pruning state to keep %s frames...", keep)
        try:
            if hasattr(inference_state, 'cached_features'):
                cached_keys = list(inference_state.cached_features.keys())
                if len(cached_keys) > keep:
                    # Drop the oldest entries (insertion order), keep the last `keep`.
                    keys_to_remove = cached_keys[:-keep]
                    for key in keys_to_remove:
                        if key in inference_state.cached_features:
                            del inference_state.cached_features[key]
                    logger.debug("Pruned %s old cached features", len(keys_to_remove))
            if hasattr(inference_state, 'point_inputs_per_obj'):
                for obj_id in list(inference_state.point_inputs_per_obj.keys()):
                    obj_inputs = inference_state.point_inputs_per_obj[obj_id]
                    if len(obj_inputs) > keep:
                        # Keep the `keep` highest (most recent) frame keys.
                        recent_keys = sorted(obj_inputs.keys())[-keep:]
                        new_inputs = {k: obj_inputs[k] for k in recent_keys}
                        inference_state.point_inputs_per_obj[obj_id] = new_inputs
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
        except Exception as e:
            logger.debug("State pruning warning: %s", e, exc_info=True)

    def clear_memory(self):
        """Release cached GPU memory and run the Python garbage collector.

        Best-effort: errors are logged as warnings, never raised.
        """
        logger.info("[SAM2Predictor.clear_memory] Clearing GPU memory")
        try:
            if self.device.type == 'cuda':
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.ipc_collect()
            gc.collect()
        except Exception as e:
            logger.warning("Memory clearing warning: %s", e, exc_info=True)

    def get_memory_usage(self) -> Dict[str, float]:
        """Return current CUDA memory statistics in GiB.

        Returns:
            Dict with keys "allocated_gb", "reserved_gb", "free_gb" and
            "total_gb". All values are 0.0 on non-CUDA devices or on error.
            (Previously "total_gb" was missing from those paths, giving
            callers an inconsistent schema.)
        """
        logger.info("[SAM2Predictor.get_memory_usage] Checking memory usage")
        # Uniform fallback so every path returns the same key set.
        empty = {"allocated_gb": 0.0, "reserved_gb": 0.0, "free_gb": 0.0, "total_gb": 0.0}
        if self.device.type != 'cuda':
            return dict(empty)
        try:
            gib = 1024 ** 3
            allocated = torch.cuda.memory_allocated(self.device) / gib
            reserved = torch.cuda.memory_reserved(self.device) / gib
            free, total = torch.cuda.mem_get_info(self.device)
            return {
                "allocated_gb": allocated,
                "reserved_gb": reserved,
                "free_gb": free / gib,
                "total_gb": total / gib,
            }
        except Exception as e:
            logger.warning("Error checking memory usage: %s", e, exc_info=True)
            return dict(empty)

    def __del__(self):
        """Drop predictor/model references and clear GPU memory on finalization.

        Everything (including logging) is guarded: during interpreter shutdown
        module globals such as `logger` may already be torn down.
        """
        try:
            logger.info("[SAM2Predictor.__del__] Cleaning up...")
            if hasattr(self, 'predictor') and self.predictor is not None:
                del self.predictor
            if hasattr(self, 'model') and self.model is not None:
                del self.model
            self.clear_memory()
        except Exception as e:
            try:
                logger.warning("Error in __del__: %s", e, exc_info=True)
            except Exception:
                pass  # logging itself can fail at shutdown; nothing left to do
| |
|