"""
SageMaker Multi-Model Endpoint (MME) inference script for GLiNER2.

This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint.
Models are loaded dynamically based on the TargetModel header in the request.

Key differences from single-model inference:
- model_fn() receives the full path to the model directory (including the model name)
- Models are cached automatically by SageMaker MME
- Multiple models can be loaded in memory simultaneously
- Models are evicted (LRU) when instance memory is full
"""
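
# Example client invocation (a sketch, not part of the serving code). The endpoint
# name, TargetModel archive name, and schema below are illustrative placeholders;
# only the JSON payload shape is defined by predict_fn() in this script.
#
#   import json
#
#   import boto3
#
#   runtime = boto3.client("sagemaker-runtime")
#   response = runtime.invoke_endpoint(
#       EndpointName="gliner2-mme",                # hypothetical endpoint name
#       TargetModel="gliner2-base-v1.tar.gz",      # selects which model MME loads/caches
#       ContentType="application/json",
#       Body=json.dumps(
#           {
#               "task": "extract_entities",
#               "text": "Apple opened a new office in Berlin.",
#               "schema": ["company", "city"],     # schema format follows the gliner2 API
#               "threshold": 0.5,
#           }
#       ),
#   )
#   print(json.loads(response["Body"].read()))
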
import json
import os
import subprocess
import sys


def _ensure_gliner2_installed():
    """
    Ensure gliner2 is installed, installing it dynamically if missing.

    This is a workaround for SageMaker MME, where requirements.txt
    might not be installed automatically.
    """
    try:
        import gliner2

        print(f"[MME] gliner2 version {gliner2.__version__} already installed")
        return True
    except ImportError:
        print("[MME] gliner2 not found, installing...")
        try:
            # Install a pinned gliner2 and a compatible transformers version.
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--quiet",
                    "--no-cache-dir",
                    "gliner2==1.0.1",
                    "transformers>=4.30.0,<4.46.0",
                ]
            )
            print("[MME] ✓ gliner2 installed successfully")
            return True
        except subprocess.CalledProcessError as e:
            print(f"[MME] ERROR: Failed to install gliner2: {e}")
            return False


# Ensure gliner2 is available before model_fn tries to import it.
_ensure_gliner2_installed()

import torch

# Add the parent of this script's directory to sys.path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class DummyModel:
    """Placeholder model for MME container initialization."""

    def __call__(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def extract_entities(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def classify_text(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def extract_json(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")


def model_fn(model_dir):
    """
    Load the GLiNER2 model from the model directory.

    For Multi-Model Endpoints, SageMaker passes the full path to the specific
    model being loaded, e.g., /opt/ml/models/<model_name>/

    Args:
        model_dir: The directory where model artifacts are extracted

    Returns:
        The loaded GLiNER2 model
    """
    print(f"[MME] Loading model from: {model_dir}")
    try:
        print(f"[MME] Contents: {os.listdir(model_dir)}")
    except Exception as e:
        print(f"[MME] Could not list directory contents: {e}")

    try:
        from gliner2 import GLiNER2
    except ImportError as e:
        print(f"[MME] ERROR: gliner2 import failed: {e}")
        print("[MME] Attempting to install gliner2...")
        if _ensure_gliner2_installed():
            from gliner2 import GLiNER2
        else:
            GLiNER2 = None

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[MME] Using device: {device}")

    if torch.cuda.is_available():
        print(f"[MME] GPU: {torch.cuda.get_device_name(0)}")
        print(f"[MME] CUDA version: {torch.version.cuda}")
        mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"[MME] GPU memory: {mem_gb:.2f} GB")

    hf_token = os.environ.get("HF_TOKEN")

    # A marker file indicates this is the container's placeholder model, not a real one.
    if os.path.exists(os.path.join(model_dir, "mme_container.txt")):
        print("[MME] Container model detected - returning dummy model")
        return DummyModel()

    if GLiNER2 is None:
        raise ImportError("gliner2 package required but not found")

    # Three loading paths: packaged artifacts, a download-at-runtime marker,
    # or an otherwise empty model directory.
    if os.path.exists(os.path.join(model_dir, "config.json")):
        print("[MME] Loading model from extracted artifacts...")
        model = GLiNER2.from_pretrained(model_dir, token=hf_token)
    elif os.path.exists(os.path.join(model_dir, "download_at_runtime.txt")):
        print("[MME] Model not in archive, downloading from HuggingFace...")
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Downloading model: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)
    else:
        model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        print(f"[MME] Model directory empty, downloading: {model_name}")
        model = GLiNER2.from_pretrained(model_name, token=hf_token)

    print(f"[MME] Moving model to {device}...")
    model = model.to(device)

    if torch.cuda.is_available():
        print("[MME] Converting to fp16...")
        model = model.half()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()
        # Cap this process's GPU allocation at 85% of total memory.
        torch.cuda.set_per_process_memory_fraction(0.85)
        print("[MME] GPU memory optimizations enabled")

    print(f"[MME] ✓ Model loaded successfully on {device}")
    return model
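
# Note on packaging (an assumed layout, inferred from the branches in model_fn above):
# each MME model is a separate .tar.gz in the endpoint's S3 prefix, and its extracted
# contents become model_dir. The loading path is chosen by what is present:
#
#   config.json + weights        -> load directly via GLiNER2.from_pretrained(model_dir)
#   download_at_runtime.txt      -> marker file: download GLINER_MODEL from HuggingFace
#   mme_container.txt            -> marker file: container placeholder, return DummyModel
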


def input_fn(request_body, request_content_type):
    """
    Deserialize and prepare the input data for prediction.

    Args:
        request_body: The request body
        request_content_type: The content type of the request

    Returns:
        Parsed input data as a dictionary
    """
    if request_content_type == "application/json":
        return json.loads(request_body)
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")


def predict_fn(input_data, model):
    """
    Run prediction on the input data using the loaded model.

    Args:
        input_data: Dictionary containing:
            - task: One of 'extract_entities', 'classify_text', or 'extract_json'
            - text: Text to process (string) or list of texts (for batch processing)
            - schema: Schema for extraction (format depends on task)
            - threshold: Optional confidence threshold (default: 0.5)
        model: The loaded GLiNER2 model

    Returns:
        Task-specific results (single result or list of results for batch)
    """
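    # Example payloads (illustrative only; the schema values are placeholders and the
    # exact schema format is defined by the gliner2 library, not by this script):
    #   single: {"task": "extract_entities", "text": "...", "schema": ["person", "date"]}
    #   batch:  {"task": "extract_entities", "text": ["...", "..."],
    #            "schema": ["person", "date"], "threshold": 0.3}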
    # Free cached GPU memory before running inference.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    text = input_data.get("text")
    task = input_data.get("task", "extract_entities")
    schema = input_data.get("schema")
    threshold = input_data.get("threshold", 0.5)

    if text is None or text == "":
        raise ValueError("'text' field is required")
    if not schema:
        raise ValueError("'schema' field is required")

    # A list of texts triggers batch processing.
    is_batch = isinstance(text, list)

    if is_batch and len(text) == 0:
        raise ValueError("'text' list cannot be empty")

    with torch.inference_mode():
        if task == "extract_entities":
            if is_batch:
                if hasattr(model, "batch_extract_entities"):
                    result = model.batch_extract_entities(
                        text, schema, threshold=threshold
                    )
                elif hasattr(model, "batch_predict_entities"):
                    result = model.batch_predict_entities(
                        text, schema, threshold=threshold
                    )
                else:
                    # Fall back to per-text calls if no batch API is available.
                    result = [
                        model.extract_entities(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.extract_entities(text, schema, threshold=threshold)
            return result

        elif task == "classify_text":
            if is_batch:
                if hasattr(model, "batch_classify_text"):
                    result = model.batch_classify_text(
                        text, schema, threshold=threshold
                    )
                else:
                    result = [
                        model.classify_text(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.classify_text(text, schema, threshold=threshold)
            return result

        elif task == "extract_json":
            if is_batch:
                if hasattr(model, "batch_extract_json"):
                    result = model.batch_extract_json(text, schema, threshold=threshold)
                else:
                    result = [
                        model.extract_json(t, schema, threshold=threshold) for t in text
                    ]
            else:
                result = model.extract_json(text, schema, threshold=threshold)
            return result

        else:
            raise ValueError(
                f"Unsupported task: {task}. "
                "Must be one of: extract_entities, classify_text, extract_json"
            )


def output_fn(prediction, response_content_type):
    """
    Serialize the prediction output.

    Args:
        prediction: The prediction result
        response_content_type: The desired response content type

    Returns:
        Serialized prediction
    """
    if response_content_type == "application/json":
        return json.dumps(prediction)
    else:
        raise ValueError(f"Unsupported response content type: {response_content_type}")