""" SageMaker Multi-Model Endpoint inference script for GLiNER2. This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint. Models are loaded dynamically based on the TargetModel header in the request. Key differences from single-model inference: - model_fn() receives the full path to the model directory (including model name) - Models are cached automatically by SageMaker MME - Multiple models can be loaded in memory simultaneously - LRU eviction when memory is full """ import json import os import sys import subprocess def _ensure_gliner2_installed(): """ Ensure gliner2 is installed. Install it dynamically if missing. This is a workaround for SageMaker MME where requirements.txt might not be installed automatically. """ try: import gliner2 # noqa: PLC0415 print(f"[MME] gliner2 version {gliner2.__version__} already installed") return True except ImportError: print("[MME] gliner2 not found, installing...") try: # IMPORTANT: Use transformers<4.46 for compatibility with PyTorch 2.1.0 # (transformers 4.46+ requires PyTorch 2.3+ for torch.utils._pytree.register_pytree_node) subprocess.check_call( [ sys.executable, "-m", "pip", "install", "--quiet", "--no-cache-dir", "gliner2==1.0.1", "transformers>=4.30.0,<4.46.0", ] ) print("[MME] ✓ gliner2 installed successfully") return True except subprocess.CalledProcessError as e: print(f"[MME] ERROR: Failed to install gliner2: {e}") return False # Ensure gliner2 is installed before importing torch (to avoid conflicts) _ensure_gliner2_installed() import torch # noqa: E402 # Add parent directory to path to potentially import from gliner_2_inference sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) class DummyModel: """Placeholder model for MME container initialization""" def __call__(self, *args, **kwargs): raise ValueError("Container model invoked directly. Use TargetModel header.") def extract_entities(self, *args, **kwargs): raise ValueError("Container model invoked directly. Use TargetModel header.") def classify_text(self, *args, **kwargs): raise ValueError("Container model invoked directly. Use TargetModel header.") def extract_json(self, *args, **kwargs): raise ValueError("Container model invoked directly. Use TargetModel header.") def model_fn(model_dir): """ Load the GLiNER2 model from the model directory. 
def model_fn(model_dir):
    """
    Load the GLiNER2 model from the model directory.

    For Multi-Model Endpoints, SageMaker passes the full path to the specific
    model being loaded, e.g., /opt/ml/models/<model_name>/

    Args:
        model_dir: The directory where model artifacts are extracted

    Returns:
        The loaded GLiNER2 model
    """
    print(f"[MME] Loading model from: {model_dir}")
    try:
        print(f"[MME] Contents: {os.listdir(model_dir)}")
    except Exception as e:
        print(f"[MME] Could not list directory contents: {e}")

    # Import GLiNER2 here (should be installed by _ensure_gliner2_installed)
    try:
        from gliner2 import GLiNER2  # noqa: PLC0415
    except ImportError as e:
        print(f"[MME] ERROR: gliner2 import failed: {e}")
        print("[MME] Attempting to install gliner2...")
        if _ensure_gliner2_installed():
            from gliner2 import GLiNER2  # noqa: PLC0415
        else:
            GLiNER2 = None

    # Detect device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[MME] Using device: {device}")
    if torch.cuda.is_available():
        print(f"[MME] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[MME] CUDA version: {torch.version.cuda}")
        mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"[MME] GPU memory: {mem_gb:.2f} GB")

    # Get HuggingFace token if available
    hf_token = os.environ.get("HF_TOKEN")

    # Check if this is the container model (placeholder)
    if os.path.exists(os.path.join(model_dir, "mme_container.txt")):
        print("[MME] Container model detected - returning dummy model")
        return DummyModel()

    if GLiNER2 is None:
        raise ImportError("gliner2 package required but not found")

    # Check if model is already extracted in model_dir
    if os.path.exists(os.path.join(model_dir, "config.json")):
        print("[MME] Loading model from extracted artifacts...")
        model = GLiNER2.from_pretrained(model_dir, token=hf_token)
    elif os.path.exists(os.path.join(model_dir, "download_at_runtime.txt")):
        # Fallback: download from HuggingFace
        print("[MME] Model not in archive, downloading from HuggingFace...")
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Downloading model: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)
    else:
        # Final fallback
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Model directory empty, downloading: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)

    # Move model to GPU if available
    print(f"[MME] Moving model to {device}...")
    model = model.to(device)

    # Enable half precision on GPU for memory efficiency
    if torch.cuda.is_available():
        print("[MME] Converting to fp16...")
        model = model.half()

    # Memory optimizations for GPU
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()
        # Reserve memory for multiple models in MME
        torch.cuda.set_per_process_memory_fraction(0.85)
        print("[MME] GPU memory optimizations enabled")

    print(f"[MME] ✓ Model loaded successfully on {device}")
    return model


def input_fn(request_body, request_content_type):
    """
    Deserialize and prepare the input data for prediction.

    Args:
        request_body: The request body
        request_content_type: The content type of the request

    Returns:
        Parsed input data as a dictionary
    """
    if request_content_type == "application/json":
        input_data = json.loads(request_body)
        return input_data
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")
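
# ---------------------------------------------------------------------------
# Illustrative request bodies for predict_fn below (not referenced at runtime).
# The schema values are assumptions about typical GLiNER2 schema shapes; adapt
# them to the format the packaged model actually expects.
# ---------------------------------------------------------------------------
_EXAMPLE_REQUESTS = {
    "extract_entities": {
        "task": "extract_entities",
        "text": "Tim Cook announced the new iPhone in Cupertino.",
        "schema": ["person", "product", "location"],  # assumed label-list schema
        "threshold": 0.5,
    },
    "classify_text_batch": {
        "task": "classify_text",
        "text": ["Great product, works perfectly.", "Support never replied."],  # batch input
        "schema": {"sentiment": ["positive", "negative"]},  # assumed label map
        "threshold": 0.5,
    },
    "extract_json": {
        "task": "extract_json",
        "text": "Invoice 1042 was issued to Acme Corp on 2024-03-01.",
        "schema": {"invoice": ["invoice_number", "customer", "date"]},  # assumed structure schema
        "threshold": 0.5,
    },
}
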
def predict_fn(input_data, model):
    """
    Run prediction on the input data using the loaded model.

    Args:
        input_data: Dictionary containing:
            - task: One of 'extract_entities', 'classify_text', or 'extract_json'
            - text: Text to process (string) or list of texts (for batch processing)
            - schema: Schema for extraction (format depends on task)
            - threshold: Optional confidence threshold (default: 0.5)
        model: The loaded GLiNER2 model

    Returns:
        Task-specific results (single result or list of results for batch)
    """
    # Clear CUDA cache before processing
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    text = input_data.get("text")
    task = input_data.get("task", "extract_entities")
    schema = input_data.get("schema")
    threshold = input_data.get("threshold", 0.5)

    if not text:
        raise ValueError("'text' field is required")
    if not schema:
        raise ValueError("'schema' field is required")

    # Detect batch mode
    is_batch = isinstance(text, list)
    if is_batch and len(text) == 0:
        raise ValueError("'text' list cannot be empty")

    # Use inference_mode for faster inference
    with torch.inference_mode():
        if task == "extract_entities":
            if is_batch:
                if hasattr(model, "batch_extract_entities"):
                    result = model.batch_extract_entities(
                        text, schema, threshold=threshold
                    )
                elif hasattr(model, "batch_predict_entities"):
                    result = model.batch_predict_entities(
                        text, schema, threshold=threshold
                    )
                else:
                    result = [
                        model.extract_entities(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.extract_entities(text, schema, threshold=threshold)
            return result
        elif task == "classify_text":
            if is_batch:
                if hasattr(model, "batch_classify_text"):
                    result = model.batch_classify_text(
                        text, schema, threshold=threshold
                    )
                else:
                    result = [
                        model.classify_text(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.classify_text(text, schema, threshold=threshold)
            return result
        elif task == "extract_json":
            if is_batch:
                if hasattr(model, "batch_extract_json"):
                    result = model.batch_extract_json(text, schema, threshold=threshold)
                else:
                    result = [
                        model.extract_json(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.extract_json(text, schema, threshold=threshold)
            return result
        else:
            raise ValueError(
                f"Unsupported task: {task}. "
                "Must be one of: extract_entities, classify_text, extract_json"
            )


def output_fn(prediction, response_content_type):
    """
    Serialize the prediction output.

    Args:
        prediction: The prediction result
        response_content_type: The desired response content type

    Returns:
        Serialized prediction
    """
    if response_content_type == "application/json":
        return json.dumps(prediction)
    else:
        raise ValueError(f"Unsupported response content type: {response_content_type}")
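
# ---------------------------------------------------------------------------
# Minimal local smoke test: a sketch, never executed by SageMaker, which
# imports this module and calls the handler functions directly. It assumes an
# extracted GLiNER2 model is available under LOCAL_MODEL_DIR (an assumed
# environment variable), and the example text/schema are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    local_model_dir = os.environ.get("LOCAL_MODEL_DIR", "/opt/ml/model")
    loaded_model = model_fn(local_model_dir)
    request_body = json.dumps(
        {
            "task": "extract_entities",
            "text": "Marie Curie won the Nobel Prize in Physics in 1903.",
            "schema": ["person", "award", "date"],  # assumed label-list schema
            "threshold": 0.5,
        }
    )
    parsed = input_fn(request_body, "application/json")
    prediction = predict_fn(parsed, loaded_model)
    print(output_fn(prediction, "application/json"))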