"""
SageMaker Multi-Model Endpoint inference script for GLiNER2.
This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint.
Models are loaded dynamically based on the TargetModel header in the request.
Key differences from single-model inference:
- model_fn() receives the full path to the model directory (including model name)
- Models are cached automatically by SageMaker MME
- Multiple models can be loaded in memory simultaneously
- LRU eviction when memory is full
"""
import json
import os
import sys
import subprocess
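
# The module docstring above describes how MME routes each request to a model
# via the TargetModel header. The helper below is a minimal client-side sketch
# of that flow; the endpoint name and artifact key are hypothetical, and the
# serving container never calls this function itself.
def _example_invoke_mme(payload):
    """Illustrative only: send a request to one model hosted on the MME."""
    import boto3  # noqa: PLC0415 -- client-side dependency, not needed in the container

    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName="gliner2-mme",             # hypothetical endpoint name
        TargetModel="my-gliner2-model.tar.gz",  # which model artifact MME should load and serve
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    return json.loads(response["Body"].read())
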
def _ensure_gliner2_installed():
"""
Ensure gliner2 is installed. Install it dynamically if missing.
This is a workaround for SageMaker MME where requirements.txt
might not be installed automatically.
"""
try:
import gliner2 # noqa: PLC0415
print(f"[MME] gliner2 version {gliner2.__version__} already installed")
return True
except ImportError:
print("[MME] gliner2 not found, installing...")
try:
# IMPORTANT: Use transformers<4.46 for compatibility with PyTorch 2.1.0
# (transformers 4.46+ requires PyTorch 2.3+ for torch.utils._pytree.register_pytree_node)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"--quiet",
"--no-cache-dir",
"gliner2==1.0.1",
"transformers>=4.30.0,<4.46.0",
]
)
print("[MME] ✓ gliner2 installed successfully")
return True
except subprocess.CalledProcessError as e:
print(f"[MME] ERROR: Failed to install gliner2: {e}")
return False
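# If the serving container does install code/requirements.txt itself, the
# equivalent pins would be (a sketch mirroring the constraints used above):
#
#   gliner2==1.0.1
#   transformers>=4.30.0,<4.46.0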
# Ensure gliner2 is installed before importing torch (to avoid conflicts)
_ensure_gliner2_installed()
import torch # noqa: E402
# Add the parent directory to the path so shared code from gliner_2_inference can be imported if present
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
class DummyModel:
"""Placeholder model for MME container initialization"""
def __call__(self, *args, **kwargs):
raise ValueError("Container model invoked directly. Use TargetModel header.")
def extract_entities(self, *args, **kwargs):
raise ValueError("Container model invoked directly. Use TargetModel header.")
def classify_text(self, *args, **kwargs):
raise ValueError("Container model invoked directly. Use TargetModel header.")
def extract_json(self, *args, **kwargs):
raise ValueError("Container model invoked directly. Use TargetModel header.")
def model_fn(model_dir):
"""
Load the GLiNER2 model from the model directory.
For Multi-Model Endpoints, SageMaker passes the full path to the specific
model being loaded, e.g., /opt/ml/models/<model_name>/
Args:
model_dir: The directory where model artifacts are extracted
Returns:
The loaded GLiNER2 model
"""
print(f"[MME] Loading model from: {model_dir}")
try:
print(f"[MME] Contents: {os.listdir(model_dir)}")
except Exception as e:
print(f"[MME] Could not list directory contents: {e}")
# Import GLiNER2 here (should be installed by _ensure_gliner2_installed)
try:
from gliner2 import GLiNER2 # noqa: PLC0415
except ImportError as e:
print(f"[MME] ERROR: gliner2 import failed: {e}")
print("[MME] Attempting to install gliner2...")
if _ensure_gliner2_installed():
from gliner2 import GLiNER2 # noqa: PLC0415
else:
GLiNER2 = None
# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[MME] Using device: {device}")
if torch.cuda.is_available():
print(f"[MME] GPU: {torch.cuda.get_device_name(0)}")
print(f"[MME] CUDA version: {torch.version.cuda}")
mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"[MME] GPU memory: {mem_gb:.2f} GB")
# Get HuggingFace token if available
hf_token = os.environ.get("HF_TOKEN")
# Check if this is the container model (placeholder)
if os.path.exists(os.path.join(model_dir, "mme_container.txt")):
print("[MME] Container model detected - returning dummy model")
return DummyModel()
if GLiNER2 is None:
raise ImportError("gliner2 package required but not found")
# Check if model is already extracted in model_dir
if os.path.exists(os.path.join(model_dir, "config.json")):
print("[MME] Loading model from extracted artifacts...")
model = GLiNER2.from_pretrained(model_dir, token=hf_token)
elif os.path.exists(os.path.join(model_dir, "download_at_runtime.txt")):
# Fallback: download from HuggingFace
print("[MME] Model not in archive, downloading from HuggingFace...")
model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
print(f"[MME] Downloading model: {model_name}")
model = GLiNER2.from_pretrained(model_name, token=hf_token)
    else:
        # Final fallback: no extracted artifacts or marker file were found,
        # so download the default model from HuggingFace
model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
print(f"[MME] Model directory empty, downloading: {model_name}")
model = GLiNER2.from_pretrained(model_name, token=hf_token)
# Move model to GPU if available
print(f"[MME] Moving model to {device}...")
model = model.to(device)
# Enable half precision on GPU for memory efficiency
if torch.cuda.is_available():
print("[MME] Converting to fp16...")
model = model.half()
# Memory optimizations for GPU
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
# Reserve memory for multiple models in MME
torch.cuda.set_per_process_memory_fraction(0.85)
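        # With the 0.85 fraction above, a 16 GB GPU (hypothetical size) caps this
        # worker at roughly 13.6 GB, leaving about 2.4 GB of headroom before MME
        # has to evict a least-recently-used model.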
print("[MME] GPU memory optimizations enabled")
print(f"[MME] ✓ Model loaded successfully on {device}")
return model
def input_fn(request_body, request_content_type):
"""
Deserialize and prepare the input data for prediction.
Args:
request_body: The request body
request_content_type: The content type of the request
Returns:
Parsed input data as a dictionary
"""
if request_content_type == "application/json":
input_data = json.loads(request_body)
return input_data
else:
raise ValueError(f"Unsupported content type: {request_content_type}")
def predict_fn(input_data, model):
"""
Run prediction on the input data using the loaded model.
Args:
input_data: Dictionary containing:
- task: One of 'extract_entities', 'classify_text', or 'extract_json'
- text: Text to process (string) or list of texts (for batch processing)
- schema: Schema for extraction (format depends on task)
- threshold: Optional confidence threshold (default: 0.5)
model: The loaded GLiNER2 model
Returns:
Task-specific results (single result or list of results for batch)
"""
# Clear CUDA cache before processing
if torch.cuda.is_available():
torch.cuda.empty_cache()
text = input_data.get("text")
task = input_data.get("task", "extract_entities")
schema = input_data.get("schema")
threshold = input_data.get("threshold", 0.5)
if not text:
raise ValueError("'text' field is required")
if not schema:
raise ValueError("'schema' field is required")
# Detect batch mode
is_batch = isinstance(text, list)
if is_batch and len(text) == 0:
raise ValueError("'text' list cannot be empty")
# Use inference_mode for faster inference
with torch.inference_mode():
if task == "extract_entities":
if is_batch:
if hasattr(model, "batch_extract_entities"):
result = model.batch_extract_entities(
text, schema, threshold=threshold
)
elif hasattr(model, "batch_predict_entities"):
result = model.batch_predict_entities(
text, schema, threshold=threshold
)
else:
result = [
model.extract_entities(t, schema, threshold=threshold)
for t in text
]
else:
result = model.extract_entities(text, schema, threshold=threshold)
return result
elif task == "classify_text":
if is_batch:
if hasattr(model, "batch_classify_text"):
result = model.batch_classify_text(
text, schema, threshold=threshold
)
else:
result = [
model.classify_text(t, schema, threshold=threshold)
for t in text
]
else:
result = model.classify_text(text, schema, threshold=threshold)
return result
elif task == "extract_json":
if is_batch:
if hasattr(model, "batch_extract_json"):
result = model.batch_extract_json(text, schema, threshold=threshold)
else:
result = [
model.extract_json(t, schema, threshold=threshold) for t in text
]
else:
result = model.extract_json(text, schema, threshold=threshold)
return result
else:
raise ValueError(
f"Unsupported task: {task}. "
"Must be one of: extract_entities, classify_text, extract_json"
)
def output_fn(prediction, response_content_type):
"""
Serialize the prediction output.
Args:
prediction: The prediction result
response_content_type: The desired response content type
Returns:
Serialized prediction
"""
if response_content_type == "application/json":
return json.dumps(prediction)
else:
raise ValueError(f"Unsupported response content type: {response_content_type}")