# VREyeSAM / model_server.py
# Author: Dev Nagaich
# Change: Add better error handling and logging for model loading (commit 35351a6)
"""
Secure Model Server - Protects model weights from extraction
Never expose:
- File paths to checkpoints
- Model architecture details
- Debug routes
"""
import os
import sys
import torch
import numpy as np
from pathlib import Path
from typing import Tuple, Optional
# Secure path resolution (not hardcoded)
def get_model_checkpoint_path():
    """Return the local path to the base SAM2 checkpoint.

    The path is resolved relative to this file and is internal only —
    it must never be exposed to clients.

    Returns:
        str: Filesystem path to sam2.1_hiera_small.pt.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    base_dir = Path(__file__).parent
    checkpoint = base_dir / "segment-anything-2" / "checkpoints" / "sam2.1_hiera_small.pt"
    if not checkpoint.exists():
        # Deliberately omit the path from the message (never exposed to clients)
        raise FileNotFoundError("Model checkpoint not found")
    return str(checkpoint)
def get_finetuned_weights_path():
    """Return the local path to the fine-tuned VREyeSAM weights.

    Resolution order:
    1. Use the local copy under segment-anything-2/checkpoints if present.
    2. Otherwise, if HF_TOKEN is set, try to download from Hugging Face.
    3. Otherwise raise FileNotFoundError.

    The resolved path is internal only and must never be exposed to clients.

    Returns:
        str: Filesystem path to VREyeSAM_uncertainity_best.torch.

    Raises:
        FileNotFoundError: If the weights are missing and cannot be downloaded.
    """
    base_dir = Path(__file__).parent
    checkpoint_dir = base_dir / "segment-anything-2" / "checkpoints"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    weights = checkpoint_dir / "VREyeSAM_uncertainity_best.torch"
    # Fast path: weights already present locally (e.g. uploaded manually)
    if weights.exists():
        return str(weights)
    # Try to download from Hugging Face using HF_TOKEN
    hf_token = os.getenv('HF_TOKEN', '')
    if hf_token:
        try:
            # Imported lazily so the server still runs without huggingface_hub
            from huggingface_hub import hf_hub_download
            print("Downloading VREyeSAM weights from Hugging Face...")
            checkpoint_path = hf_hub_download(
                repo_id='devnagaich/VREyeSAM',
                filename='VREyeSAM_uncertainity_best.torch',
                token=hf_token,
                cache_dir=str(checkpoint_dir)
            )
            print("Successfully downloaded VREyeSAM weights")
            return checkpoint_path
        except Exception as e:
            # Best-effort: fall through to the local re-check below
            print(f"Warning: Could not download VREyeSAM weights: {e}")
    # Re-check: file may have appeared in the meantime (e.g. from an upload)
    if weights.exists():
        return str(weights)
    # Last resort - raise error
    raise FileNotFoundError("VREyeSAM weights not found and could not download")
def get_model_config_path():
    """Return the SAM2 model config path (internal only, never sent to clients)."""
    config_path = "configs/sam2.1/sam2.1_hiera_s.yaml"
    return config_path
class ProtectedModelServer:
    """
    Encapsulates model loading and inference.

    Only the inference API is exposed; raw weights, checkpoint paths and
    architecture details stay server-side. Implemented as a singleton so
    the model is loaded at most once per process.
    """
    _instance = None  # Singleton: the one shared instance
    _model = None  # Loaded SAM2 model (server-side only)
    _predictor = None  # SAM2ImagePredictor wrapping the model

    def __new__(cls):
        # Singleton: only one instance ever
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialize model (only once; no-op on repeat construction)."""
        if self._predictor is None:
            self._load_model()

    def _load_model(self):
        """Load model weights securely - never called from frontend.

        Puts the bundled segment-anything-2 package on sys.path, builds the
        base SAM2 model, then overlays the fine-tuned VREyeSAM weights when
        available (falling back to the base model otherwise).

        Raises:
            RuntimeError: If the base model cannot be loaded.
        """
        try:
            # Add segment-anything-2 to path (internally only)
            base_dir = Path(__file__).parent
            sam2_path = base_dir / "segment-anything-2"
            if not sam2_path.exists():
                raise FileNotFoundError(f"SAM2 installation not found at {sam2_path}")
            sys.path.insert(0, str(sam2_path))
            try:
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError as e:
                raise ImportError("SAM2 not properly installed. Check build logs.") from e
            # Get paths internally - NEVER sent to client
            model_cfg = get_model_config_path()
            sam2_checkpoint = get_model_checkpoint_path()
            # Prefer GPU when one is available
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print(f"Loading model on device: {device}")
            # Load base SAM2 model
            print(f"Loading SAM2 from {sam2_checkpoint}")
            self._model = build_sam2(model_cfg, sam2_checkpoint, device=device)
            self._predictor = SAM2ImagePredictor(self._model)
            # Overlay fine-tuned weights if available; base model is the fallback
            try:
                fine_tuned_weights = get_finetuned_weights_path()
                print(f"Loading fine-tuned weights from {fine_tuned_weights}")
                state_dict = torch.load(fine_tuned_weights, map_location=device)
                self._predictor.model.load_state_dict(state_dict)
                print("Fine-tuned weights loaded successfully")
            except FileNotFoundError:
                print("Warning: Fine-tuned weights not found. Using base SAM2 model.")
                print("To use fine-tuned model, upload VREyeSAM_uncertainity_best.torch to Space Files")
            except Exception as e:
                # Best-effort overlay: a bad weights file must not kill the server
                print(f"Warning: Could not load fine-tuned weights: {e}")
                print("Continuing with base SAM2 model")
            # Model is now loaded - weights are NOT accessible to clients
            self._predictor.model.eval()
            print("Model loaded successfully")
            return True
        except Exception as e:
            print(f"Error loading model: {e}")
            import traceback
            traceback.print_exc()
            raise RuntimeError(f"Model initialization failed: {str(e)}") from e

    def predict(self, image: np.ndarray, num_samples: int = 30) -> Tuple[np.ndarray, np.ndarray]:
        """
        Perform iris segmentation.

        Args:
            image: Input image (numpy array, H x W x C).
            num_samples: Number of random prompt points for inference.

        Returns:
            binary_mask: Binary segmentation mask (uint8, H x W).
            prob_mask: Probability map in [0, 1] (float32, H x W).

        Raises:
            RuntimeError: If the model is not initialized or inference fails.
        """
        if self._predictor is None:
            raise RuntimeError("Model not initialized")
        try:
            # Sample prompt points uniformly over the full image. SAM point
            # coordinates are (x, y), so sample each axis against its own
            # extent — the previous min(H, W) bound confined points to a
            # square corner region on non-square images.
            height, width = image.shape[:2]
            input_points = np.stack(
                [np.random.randint(0, width, (num_samples, 1)),
                 np.random.randint(0, height, (num_samples, 1))],
                axis=-1,
            )
            # Inference
            with torch.no_grad():
                self._predictor.set_image(image)
                masks, scores, _ = self._predictor.predict(
                    point_coords=input_points,
                    point_labels=np.ones([input_points.shape[0], 1])
                )
            # Convert to numpy
            np_masks = np.array(masks[:, 0]).astype(np.float32)
            np_scores = scores[:, 0]
            # Normalize scores so they form a convex combination of masks
            score_sum = np.sum(np_scores)
            if score_sum > 0:
                normalized_scores = np_scores / score_sum
            else:
                # Degenerate case: all-zero scores -> uniform weights
                normalized_scores = np.ones_like(np_scores) / len(np_scores)
            # Score-weighted average of candidate masks = probability map
            prob_mask = np.sum(np_masks * normalized_scores[:, None, None], axis=0)
            prob_mask = np.clip(prob_mask, 0, 1)
            # Threshold to get binary mask
            binary_mask = (prob_mask > 0.2).astype(np.uint8)
            return binary_mask, prob_mask
        except Exception as e:
            # Deliberately opaque message: never leak internals to clients
            raise RuntimeError("Inference failed") from e
def get_predictor() -> ProtectedModelServer:
    """Return the process-wide singleton model server instance."""
    server = ProtectedModelServer()
    return server