# NOTE(review): the three lines below were Hugging Face Spaces page residue
# ("Spaces: Sleeping") captured during extraction — commented out so the file parses.
"""
Secure Model Server - Protects model weights from extraction

Never expose:
- File paths to checkpoints
- Model architecture details
- Debug routes
"""
import os
import sys
import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
# Secure path resolution (not hardcoded)
def get_model_checkpoint_path():
    """Return the path to the base SAM2 checkpoint, server-side only.

    The path is resolved relative to this file and is never sent to clients.

    Returns:
        str: Filesystem path to ``sam2.1_hiera_small.pt``.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    base_dir = Path(__file__).parent
    checkpoint = base_dir / "segment-anything-2" / "checkpoints" / "sam2.1_hiera_small.pt"
    if not checkpoint.exists():
        # Message deliberately omits the path so it cannot leak to clients.
        # (Fixed: was an f-string with no placeholders.)
        raise FileNotFoundError("Model checkpoint not found")
    return str(checkpoint)
def get_finetuned_weights_path():
    """Return the path to the fine-tuned VREyeSAM weights, server-side only.

    Resolution order:
      1. A local copy under ``segment-anything-2/checkpoints``.
      2. Download from Hugging Face when ``HF_TOKEN`` is set.
      3. Re-check the local path (it may have been uploaded out of band).

    Returns:
        str: Filesystem path to the weights file.

    Raises:
        FileNotFoundError: If the weights are absent and cannot be downloaded.
    """
    base_dir = Path(__file__).parent
    checkpoint_dir = base_dir / "segment-anything-2" / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    weights = checkpoint_dir / "VREyeSAM_uncertainity_best.torch"

    # Fast path: weights already exist locally.
    if weights.exists():
        return str(weights)

    # Try to download from Hugging Face using HF_TOKEN.
    hf_token = os.getenv('HF_TOKEN', '')
    if hf_token:
        try:
            from huggingface_hub import hf_hub_download
            print("Downloading VREyeSAM weights from Hugging Face...")
            # NOTE: hf_hub_download returns a path inside its own cache layout
            # under cache_dir, not necessarily ``weights`` itself.
            checkpoint_path = hf_hub_download(
                repo_id='devnagaich/VREyeSAM',
                filename='VREyeSAM_uncertainity_best.torch',
                token=hf_token,
                cache_dir=str(checkpoint_dir)
            )
            # (Fixed: was an f-string with no placeholders.)
            print("Successfully downloaded VREyeSAM weights")
            return checkpoint_path
        except Exception as e:
            # Best-effort: a failed download falls through to the local re-check.
            print(f"Warning: Could not download VREyeSAM weights: {e}")

    # If download fails or no token, the file may still exist from a manual upload.
    if weights.exists():
        return str(weights)

    # Last resort - raise error (message intentionally omits the path).
    # (Fixed: was an f-string with no placeholders.)
    raise FileNotFoundError("VREyeSAM weights not found and could not download")
def get_model_config_path():
    """Return the SAM2 config name, server-side only.

    The value is a package-relative resource name resolved by SAM2's own
    config loader, so no filesystem check is needed here.
    """
    return "configs/sam2.1/sam2.1_hiera_s.yaml"
class ProtectedModelServer:
    """Singleton wrapper around the SAM2 predictor.

    Encapsulates model loading and inference; only the inference API is
    exposed, never raw weights or filesystem paths.
    """

    _instance = None   # Singleton instance (class-level, process-wide)
    _model = None      # Loaded SAM2 model, shared by all "instances"
    _predictor = None  # SAM2ImagePredictor wrapping _model

    def __new__(cls):
        # Singleton: only one instance ever exists per process.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize model (only once; subsequent constructions are no-ops)."""
        if self._predictor is None:
            self._load_model()

    def _load_model(self):
        """Load model weights securely - never called from frontend.

        Raises:
            RuntimeError: If any step of model initialization fails.
        """
        try:
            # Add segment-anything-2 to sys.path (internally only).
            base_dir = Path(__file__).parent
            sam2_path = base_dir / "segment-anything-2"
            if not sam2_path.exists():
                raise FileNotFoundError(f"SAM2 installation not found at {sam2_path}")
            sys.path.insert(0, str(sam2_path))

            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError as e:
                raise ImportError("SAM2 not properly installed. Check build logs.") from e

            # Paths are resolved internally - NEVER sent to the client.
            model_cfg = get_model_config_path()
            sam2_checkpoint = get_model_checkpoint_path()

            # Pick device: GPU if available, else CPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Loading model on device: {device}")

            # Load the base SAM2 model first.
            print(f"Loading SAM2 from {sam2_checkpoint}")
            self._model = build_sam2(model_cfg, sam2_checkpoint, device=device)
            self._predictor = SAM2ImagePredictor(self._model)

            # Overlay fine-tuned weights when available; fall back to base model.
            try:
                fine_tuned_weights = get_finetuned_weights_path()
                print(f"Loading fine-tuned weights from {fine_tuned_weights}")
                # SECURITY NOTE: torch.load unpickles arbitrary objects; only
                # load checkpoints from trusted sources (consider weights_only=True).
                state_dict = torch.load(fine_tuned_weights, map_location=device)
                self._predictor.model.load_state_dict(state_dict)
                print("Fine-tuned weights loaded successfully")
            except FileNotFoundError:
                print("Warning: Fine-tuned weights not found. Using base SAM2 model.")
                print("To use fine-tuned model, upload VREyeSAM_uncertainity_best.torch to Space Files")
            except Exception as e:
                print(f"Warning: Could not load fine-tuned weights: {e}")
                print("Continuing with base SAM2 model")

            # Inference mode only - weights are NOT accessible to clients.
            self._predictor.model.eval()
            print("Model loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Model initialization failed: {str(e)}") from e

    def predict(self, image: np.ndarray, num_samples: int = 30) -> Tuple[np.ndarray, np.ndarray]:
        """Perform iris segmentation.

        Args:
            image: Input image (numpy array, H x W x C).
            num_samples: Number of random prompt points for inference.

        Returns:
            binary_mask: Binary segmentation mask (uint8).
            prob_mask: Probability map in [0, 1] (float).

        Raises:
            RuntimeError: If the model is uninitialized or inference fails.
        """
        if self._predictor is None:
            raise RuntimeError("Model not initialized")
        try:
            # Random prompt points bounded by the smaller image dimension so
            # they always land inside the frame. NOTE(review): for non-square
            # images this samples only a square sub-region — confirm intended.
            input_points = np.random.randint(0, min(image.shape[:2]), (num_samples, 1, 2))

            # Inference (no gradients needed).
            with torch.no_grad():
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=input_points,
                    point_labels=np.ones([input_points.shape[0], 1])
                )

            # Take the first mask/score per prompt point and convert to numpy.
            np_masks = np.array(masks[:, 0]).astype(np.float32)
            np_scores = scores[:, 0]

            # Normalize scores to a probability distribution; uniform fallback
            # guards against an all-zero score vector (division by zero).
            score_sum = np.sum(np_scores)
            if score_sum > 0:
                normalized_scores = np_scores / score_sum
            else:
                normalized_scores = np.ones_like(np_scores) / len(np_scores)

            # Score-weighted average of per-point masks -> probability map.
            prob_mask = np.sum(np_masks * normalized_scores[:, None, None], axis=0)
            prob_mask = np.clip(prob_mask, 0, 1)

            # Threshold to get binary mask.
            binary_mask = (prob_mask > 0.2).astype(np.uint8)
            return binary_mask, prob_mask
        except Exception as e:
            # Message intentionally generic so internals never leak to clients.
            # (Fixed: was an f-string with no placeholders.)
            raise RuntimeError("Inference failed") from e
def get_predictor() -> ProtectedModelServer:
    """Return the process-wide ProtectedModelServer singleton.

    Construction is idempotent: the class's __new__/__init__ guarantee the
    model is loaded at most once per process.
    """
    return ProtectedModelServer()