"""
SageMaker Multi-Model Endpoint inference script for GLiNER2.
This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint.
Models are loaded dynamically based on the TargetModel header in the request.
Key differences from single-model inference:
- model_fn() receives the full path to the model directory (including model name)
- Models are cached automatically by SageMaker MME
- Multiple models can be loaded in memory simultaneously
- LRU eviction when memory is full
"""
import json
import os
import sys
import subprocess

def _ensure_gliner2_installed():
    """
    Ensure gliner2 is installed. Install it dynamically if missing.

    This is a workaround for SageMaker MME where requirements.txt
    might not be installed automatically.
    """
    try:
        import gliner2  # noqa: PLC0415

        version = getattr(gliner2, "__version__", "unknown")
        print(f"[MME] gliner2 version {version} already installed")
        return True
    except ImportError:
        print("[MME] gliner2 not found, installing...")
        try:
            # IMPORTANT: Use transformers<4.46 for compatibility with PyTorch 2.1.0
            # (transformers 4.46+ requires PyTorch 2.3+ for torch.utils._pytree.register_pytree_node)
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--quiet",
                    "--no-cache-dir",
                    "gliner2==1.0.1",
                    "transformers>=4.30.0,<4.46.0",
                ]
            )
            print("[MME] ✓ gliner2 installed successfully")
            return True
        except subprocess.CalledProcessError as e:
            print(f"[MME] ERROR: Failed to install gliner2: {e}")
            return False

# Ensure gliner2 is installed before importing torch (to avoid conflicts)
_ensure_gliner2_installed()

import torch  # noqa: E402

# Add the parent directory to sys.path so shared code from gliner_2_inference
# can be imported if needed
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

class DummyModel:
    """Placeholder model for MME container initialization."""

    def __call__(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def extract_entities(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def classify_text(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def extract_json(self, *args, **kwargs):
        raise ValueError("Container model invoked directly. Use TargetModel header.")

def model_fn(model_dir):
    """
    Load the GLiNER2 model from the model directory.

    For Multi-Model Endpoints, SageMaker passes the full path to the specific
    model being loaded, e.g., /opt/ml/models/<model_name>/

    Args:
        model_dir: The directory where model artifacts are extracted

    Returns:
        The loaded GLiNER2 model
    """
print(f"[MME] Loading model from: {model_dir}")
try:
print(f"[MME] Contents: {os.listdir(model_dir)}")
except Exception as e:
print(f"[MME] Could not list directory contents: {e}")
# Import GLiNER2 here (should be installed by _ensure_gliner2_installed)
try:
from gliner2 import GLiNER2 # noqa: PLC0415
except ImportError as e:
print(f"[MME] ERROR: gliner2 import failed: {e}")
print("[MME] Attempting to install gliner2...")
if _ensure_gliner2_installed():
from gliner2 import GLiNER2 # noqa: PLC0415
else:
GLiNER2 = None
# Detect device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[MME] Using device: {device}")
if torch.cuda.is_available():
print(f"[MME] GPU: {torch.cuda.get_device_name(0)}")
print(f"[MME] CUDA version: {torch.version.cuda}")
mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"[MME] GPU memory: {mem_gb:.2f} GB")
# Get HuggingFace token if available
hf_token = os.environ.get("HF_TOKEN")
# Check if this is the container model (placeholder)
if os.path.exists(os.path.join(model_dir, "mme_container.txt")):
print("[MME] Container model detected - returning dummy model")
return DummyModel()
if GLiNER2 is None:
raise ImportError("gliner2 package required but not found")
# Check if model is already extracted in model_dir
if os.path.exists(os.path.join(model_dir, "config.json")):
print("[MME] Loading model from extracted artifacts...")
model = GLiNER2.from_pretrained(model_dir, token=hf_token)
elif os.path.exists(os.path.join(model_dir, "download_at_runtime.txt")):
# Fallback: download from HuggingFace
print("[MME] Model not in archive, downloading from HuggingFace...")
model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
print(f"[MME] Downloading model: {model_name}")
model = GLiNER2.from_pretrained(model_name, token=hf_token)
else:
# Final fallback
model_name = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
print(f"[MME] Model directory empty, downloading: {model_name}")
model = GLiNER2.from_pretrained(model_name, token=hf_token)
# Move model to GPU if available
print(f"[MME] Moving model to {device}...")
model = model.to(device)
# Enable half precision on GPU for memory efficiency
if torch.cuda.is_available():
print("[MME] Converting to fp16...")
model = model.half()
# Memory optimizations for GPU
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.cuda.empty_cache()
# Reserve memory for multiple models in MME
torch.cuda.set_per_process_memory_fraction(0.85)
print("[MME] GPU memory optimizations enabled")
print(f"[MME] ✓ Model loaded successfully on {device}")
return model

def input_fn(request_body, request_content_type):
    """
    Deserialize and prepare the input data for prediction.

    Args:
        request_body: The request body
        request_content_type: The content type of the request

    Returns:
        Parsed input data as a dictionary
    """
    if request_content_type == "application/json":
        input_data = json.loads(request_body)
        return input_data
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")

def predict_fn(input_data, model):
    """
    Run prediction on the input data using the loaded model.

    Args:
        input_data: Dictionary containing:
            - task: One of 'extract_entities', 'classify_text', or 'extract_json'
            - text: Text to process (string) or list of texts (for batch processing)
            - schema: Schema for extraction (format depends on task)
            - threshold: Optional confidence threshold (default: 0.5)
        model: The loaded GLiNER2 model

    Returns:
        Task-specific results (single result or list of results for batch)
    """
    # Clear CUDA cache before processing
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    text = input_data.get("text")
    task = input_data.get("task", "extract_entities")
    schema = input_data.get("schema")
    threshold = input_data.get("threshold", 0.5)

    # Empty batch lists are rejected with a more specific error below
    if not text and not isinstance(text, list):
        raise ValueError("'text' field is required")
    if not schema:
        raise ValueError("'schema' field is required")

    # Detect batch mode
    is_batch = isinstance(text, list)
    if is_batch and len(text) == 0:
        raise ValueError("'text' list cannot be empty")

    # Use inference_mode for faster inference
    with torch.inference_mode():
        if task == "extract_entities":
            if is_batch:
                if hasattr(model, "batch_extract_entities"):
                    result = model.batch_extract_entities(
                        text, schema, threshold=threshold
                    )
                elif hasattr(model, "batch_predict_entities"):
                    result = model.batch_predict_entities(
                        text, schema, threshold=threshold
                    )
                else:
                    result = [
                        model.extract_entities(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.extract_entities(text, schema, threshold=threshold)
            return result
        elif task == "classify_text":
            if is_batch:
                if hasattr(model, "batch_classify_text"):
                    result = model.batch_classify_text(
                        text, schema, threshold=threshold
                    )
                else:
                    result = [
                        model.classify_text(t, schema, threshold=threshold)
                        for t in text
                    ]
            else:
                result = model.classify_text(text, schema, threshold=threshold)
            return result
        elif task == "extract_json":
            if is_batch:
                if hasattr(model, "batch_extract_json"):
                    result = model.batch_extract_json(text, schema, threshold=threshold)
                else:
                    result = [
                        model.extract_json(t, schema, threshold=threshold) for t in text
                    ]
            else:
                result = model.extract_json(text, schema, threshold=threshold)
            return result
        else:
            raise ValueError(
                f"Unsupported task: {task}. "
                "Must be one of: extract_entities, classify_text, extract_json"
            )

def output_fn(prediction, response_content_type):
    """
    Serialize the prediction output.

    Args:
        prediction: The prediction result
        response_content_type: The desired response content type

    Returns:
        Serialized prediction
    """
    if response_content_type == "application/json":
        return json.dumps(prediction)
    else:
        raise ValueError(f"Unsupported response content type: {response_content_type}")