"""
Secure Model Server - Protects model weights from extraction

Never expose:
- File paths to checkpoints
- Model architecture details
- Debug routes
"""

import os
import sys
import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional


# Secure path resolution (not hardcoded)
def get_model_checkpoint_path() -> str:
    """Return the local path of the base SAM2.1 Hiera-S checkpoint.

    The path is for server-side use only and must never be sent to a client.

    Raises:
        FileNotFoundError: if the checkpoint file is missing. The message
            deliberately omits the path.
    """
    base_dir = Path(__file__).parent
    checkpoint = base_dir / "segment-anything-2" / "checkpoints" / "sam2.1_hiera_small.pt"
    if not checkpoint.exists():
        raise FileNotFoundError("Model checkpoint not found")
    return str(checkpoint)


def get_finetuned_weights_path() -> str:
    """Return the local path of the fine-tuned VREyeSAM weights.

    Resolution order:
      1. an existing local copy under ``segment-anything-2/checkpoints``;
      2. if ``HF_TOKEN`` is set, a download from the Hugging Face Hub into
         that same directory (best effort; failures are logged, not raised);
      3. the local copy again (it may have been uploaded meanwhile).

    The path is for server-side use only and must never be sent to a client.

    Raises:
        FileNotFoundError: if no local copy exists and the download did not
            succeed. The message deliberately omits the path.
    """
    base_dir = Path(__file__).parent
    checkpoint_dir = base_dir / "segment-anything-2" / "checkpoints"
    weights = checkpoint_dir / "VREyeSAM_uncertainity_best.torch"

    # If weights already exist locally, return path
    if weights.exists():
        return str(weights)

    # Try to download from Hugging Face using HF_TOKEN
    hf_token = os.getenv('HF_TOKEN', '')
    if hf_token:
        try:
            from huggingface_hub import hf_hub_download

            # Only create the directory when we actually need to write into it,
            # so a missing SAM2 tree is not silently "created" as a side effect.
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            print("Downloading VREyeSAM weights from Hugging Face...")
            # local_dir (rather than cache_dir) puts the file at exactly
            # `weights`, so the exists() short-circuit above works on the
            # next start-up instead of re-resolving through the hub cache.
            downloaded_path = hf_hub_download(
                repo_id='devnagaich/VREyeSAM',
                filename='VREyeSAM_uncertainity_best.torch',
                token=hf_token,
                local_dir=str(checkpoint_dir),
            )
            print("Successfully downloaded VREyeSAM weights")
            return downloaded_path
        except Exception as e:
            # Deliberate best effort: fall through to the local checks below.
            print(f"Warning: Could not download VREyeSAM weights: {e}")

    # If download fails or no token, return path anyway (may exist from upload)
    if weights.exists():
        return str(weights)

    # Last resort - raise error
    raise FileNotFoundError("VREyeSAM weights not found and could not download")


def get_model_config_path() -> str:
    """Return the SAM2 config name passed to ``build_sam2`` (server-side only)."""
    return "configs/sam2.1/sam2.1_hiera_s.yaml"


class ProtectedModelServer:
    """
    Encapsulates model loading and inference.

    Only exposes an inference API, never raw weights or paths. A single
    shared instance is created (see ``__new__``); the model is loaded on the
    first construction and reused afterwards.
    """

    _instance = None  # Singleton pattern
    _model = None
    _predictor = None

    def __new__(cls):
        # Singleton: only one instance ever
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize model (only once; later constructions reuse it)."""
        if self._predictor is None:
            self._load_model()

    def _load_model(self):
        """Load the base SAM2 model and, if available, the fine-tuned weights.

        Never called from the frontend.

        Returns:
            True once the model is loaded and in eval mode.

        Raises:
            RuntimeError: on any failure. Its message contains only the
                underlying exception type, never file paths; full details are
                printed to the server log and kept as ``__cause__``.
        """
        try:
            # Add segment-anything-2 to path (internally only)
            base_dir = Path(__file__).parent
            sam2_path = base_dir / "segment-anything-2"
            if not sam2_path.exists():
                # Path intentionally omitted from the message.
                raise FileNotFoundError("SAM2 installation not found")
            sam2_str = str(sam2_path)
            if sam2_str not in sys.path:
                sys.path.insert(0, sam2_str)

            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError as e:
                raise ImportError("SAM2 not properly installed. Check build logs.") from e

            # Get paths internally - NEVER sent to client (nor printed)
            model_cfg = get_model_config_path()
            sam2_checkpoint = get_model_checkpoint_path()

            # Load device
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Loading model on device: {device}")

            # Load base SAM2 model
            print("Loading SAM2 base checkpoint")
            self._model = build_sam2(model_cfg, sam2_checkpoint, device=device)
            self._predictor = SAM2ImagePredictor(self._model)

            # Try to load fine-tuned weights if available
            try:
                fine_tuned_weights = get_finetuned_weights_path()
                print("Loading fine-tuned weights")
                # NOTE(review): torch.load unpickles; the file may come from the
                # Hugging Face Hub. If it is a plain tensor state_dict, consider
                # weights_only=True (requires a torch version that supports it).
                state_dict = torch.load(fine_tuned_weights, map_location=device)
                self._predictor.model.load_state_dict(state_dict)
                print("Fine-tuned weights loaded successfully")
            except FileNotFoundError:
                print("Warning: Fine-tuned weights not found. Using base SAM2 model.")
                print("To use fine-tuned model, upload VREyeSAM_uncertainity_best.torch to Space Files")
            except Exception as e:
                print(f"Warning: Could not load fine-tuned weights: {e}")
                print("Continuing with base SAM2 model")

            # Model is now loaded - weights are NOT accessible to clients
            self._predictor.model.eval()
            print("Model loaded successfully")
            return True

        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            # Only the exception type is surfaced to callers; str(e) may contain
            # server paths and stays in the server log / exception chain.
            raise RuntimeError(f"Model initialization failed: {type(e).__name__}") from e

    def predict(self, image: np.ndarray, num_samples: int = 30) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perform iris segmentation.

        Each of ``num_samples`` random points is used as a separate
        single-point positive prompt; the first mask of each prompt is
        weighted by its normalised score and summed into a probability map.

        Args:
            image: Input image (numpy array); only ``image.shape[:2]`` is used
                to bound the sampled points.
            num_samples: Number of random points for inference.

        Returns:
            binary_mask: uint8 mask, 1 where ``prob_mask > 0.2``.
            prob_mask: Score-weighted average of the masks, clipped to [0, 1].

        Raises:
            RuntimeError: if the model is not initialized or inference fails.
        """
        if self._predictor is None:
            raise RuntimeError("Model not initialized")

        try:
            # Generate random points for inference: shape (num_samples, 1, 2).
            # NOTE(review): both coordinates are drawn from [0, min(H, W)), so
            # on non-square images only the top-left square is sampled --
            # confirm this is intended.
            input_points = np.random.randint(0, min(image.shape[:2]), (num_samples, 1, 2))

            # Inference
            with torch.no_grad():
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=input_points,
                    point_labels=np.ones([input_points.shape[0], 1])
                )

            # Keep the first mask/score of each prompt
            np_masks = np.array(masks[:, 0]).astype(np.float32)
            np_scores = scores[:, 0]

            # Normalize scores (uniform weights if they sum to zero or less)
            score_sum = np.sum(np_scores)
            if score_sum > 0:
                normalized_scores = np_scores / score_sum
            else:
                normalized_scores = np.ones_like(np_scores) / len(np_scores)

            # Generate probabilistic mask
            prob_mask = np.sum(np_masks * normalized_scores[:, None, None], axis=0)
            prob_mask = np.clip(prob_mask, 0, 1)

            # Threshold to get binary mask
            binary_mask = (prob_mask > 0.2).astype(np.uint8)

            return binary_mask, prob_mask

        except Exception as e:
            raise RuntimeError("Inference failed") from e


def get_predictor() -> ProtectedModelServer:
    """Get singleton model instance"""
    return ProtectedModelServer()