File size: 10,422 Bytes
b674efb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
"""
SageMaker Multi-Model Endpoint inference script for GLiNER2.

This script handles model loading and inference for the GLiNER2 Multi-Model Endpoint.
Models are loaded dynamically based on the TargetModel header in the request.

Key differences from single-model inference:
- model_fn() receives the full path to the model directory (including model name)
- Models are cached automatically by SageMaker MME
- Multiple models can be loaded in memory simultaneously
- LRU eviction when memory is full
"""

import json
import os
import sys
import subprocess


def _ensure_gliner2_installed():
    """
    Ensure gliner2 is installed. Install it dynamically if missing.

    This is a workaround for SageMaker MME where requirements.txt
    might not be installed automatically.
    """
    try:
        import gliner2  # noqa: PLC0415

        print(f"[MME] gliner2 version {gliner2.__version__} already installed")
        return True
    except ImportError:
        print("[MME] gliner2 not found, installing...")
        try:
            # IMPORTANT: Use transformers<4.46 for compatibility with PyTorch 2.1.0
            # (transformers 4.46+ requires PyTorch 2.3+ for torch.utils._pytree.register_pytree_node)
            subprocess.check_call(
                [
                    sys.executable,
                    "-m",
                    "pip",
                    "install",
                    "--quiet",
                    "--no-cache-dir",
                    "gliner2==1.0.1",
                    "transformers>=4.30.0,<4.46.0",
                ]
            )
            print("[MME] ✓ gliner2 installed successfully")
            return True
        except subprocess.CalledProcessError as e:
            print(f"[MME] ERROR: Failed to install gliner2: {e}")
            return False


# Ensure gliner2 is installed before importing torch (to avoid conflicts).
# This runs as an import-time side effect. The boolean result is not checked
# here; model_fn() attempts the import (and installation) again on each load.
_ensure_gliner2_installed()

import torch  # noqa: E402

# Add the directory above the one containing this file to the front of
# sys.path, to potentially import from gliner_2_inference.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


class DummyModel:
    """
    Stand-in object returned for the MME container placeholder.

    It exposes the same inference entry points as a real model
    (``__call__``, ``extract_entities``, ``classify_text``, ``extract_json``),
    but each one raises ValueError whatever arguments it is given.
    """

    @staticmethod
    def _reject():
        """Raise the error shared by every inference entry point."""
        raise ValueError("Container model invoked directly. Use TargetModel header.")

    def __call__(self, *args, **kwargs):
        self._reject()

    def extract_entities(self, *args, **kwargs):
        self._reject()

    def classify_text(self, *args, **kwargs):
        self._reject()

    def extract_json(self, *args, **kwargs):
        self._reject()


def model_fn(model_dir):
    """
    Build and return the model stored under ``model_dir``.

    For Multi-Model Endpoints, SageMaker passes the full path to the specific
    model being loaded, e.g., /opt/ml/models/<model_name>/

    The directory contents decide what is returned:
      * ``mme_container.txt`` present: a DummyModel placeholder.
      * ``config.json`` present: GLiNER2 loaded from ``model_dir`` itself.
      * otherwise: GLiNER2 loaded by name, taken from the GLINER_MODEL
        environment variable (default "fastino/gliner2-base-v1").

    The HF_TOKEN environment variable, if set, is forwarded to
    ``from_pretrained``. When CUDA is available the model is moved to the GPU,
    converted to half precision, and TF32 plus a per-process memory cap are
    enabled.

    Args:
        model_dir: The directory where model artifacts are extracted

    Returns:
        The loaded GLiNER2 model, or a DummyModel for the container placeholder

    Raises:
        ImportError: if gliner2 is unavailable and a real model is requested
    """
    print(f"[MME] Loading model from: {model_dir}")
    try:
        dir_listing = os.listdir(model_dir)
        print(f"[MME] Contents: {dir_listing}")
    except Exception as list_err:
        # Listing is purely diagnostic; never let it abort the load.
        print(f"[MME] Could not list directory contents: {list_err}")

    # Resolve the GLiNER2 class, retrying once after an on-the-fly install.
    try:
        from gliner2 import GLiNER2 as gliner_cls  # noqa: PLC0415
    except ImportError as import_err:
        print(f"[MME] ERROR: gliner2 import failed: {import_err}")
        print("[MME] Attempting to install gliner2...")
        gliner_cls = None
        if _ensure_gliner2_installed():
            from gliner2 import GLiNER2 as gliner_cls  # noqa: PLC0415

    # Pick the device and log the GPU details when one is present.
    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    print(f"[MME] Using device: {device}")

    if has_cuda:
        gpu_name = torch.cuda.get_device_name(0)
        print(f"[MME] GPU: {gpu_name}")
        print(f"[MME] CUDA version: {torch.version.cuda}")
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"[MME] GPU memory: {total_gb:.2f} GB")

    hf_token = os.environ.get("HF_TOKEN")

    def _contains(filename):
        return os.path.exists(os.path.join(model_dir, filename))

    # The container placeholder is served even when gliner2 is unavailable,
    # so this check comes before the ImportError below.
    if _contains("mme_container.txt"):
        print("[MME] Container model detected - returning dummy model")
        return DummyModel()

    if gliner_cls is None:
        raise ImportError("gliner2 package required but not found")

    if _contains("config.json"):
        print("[MME] Loading model from extracted artifacts...")
        source = model_dir
    else:
        # No local artifacts: load by model name instead. The marker file
        # only changes which log messages are printed.
        source = os.environ.get("GLINER_MODEL", "fastino/gliner2-base-v1")
        if _contains("download_at_runtime.txt"):
            print("[MME] Model not in archive, downloading from HuggingFace...")
            print(f"[MME] Downloading model: {source}")
        else:
            print(f"[MME] Model directory empty, downloading: {source}")
    model = gliner_cls.from_pretrained(source, token=hf_token)

    print(f"[MME] Moving model to {device}...")
    model = model.to(device)

    if has_cuda:
        # Half precision for memory efficiency.
        print("[MME] Converting to fp16...")
        model = model.half()

        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.cuda.empty_cache()
        # Reserve memory for multiple models in MME
        torch.cuda.set_per_process_memory_fraction(0.85)
        print("[MME] GPU memory optimizations enabled")

    print(f"[MME] ✓ Model loaded successfully on {device}")
    return model


def input_fn(request_body, request_content_type):
    """
    Deserialize and prepare the input data for prediction.

    Only JSON is accepted. The content type is matched on its media type
    alone: parameters such as ``; charset=utf-8`` and letter case are
    ignored, so ``application/json; charset=utf-8`` is treated the same as
    ``application/json``.

    Args:
        request_body: The request body (JSON text or bytes)
        request_content_type: The content type of the request

    Returns:
        The decoded JSON payload

    Raises:
        ValueError: if the content type is not JSON. A malformed body raises
            json.JSONDecodeError, which is a ValueError subclass.
    """
    # Keep only the "type/subtype" part; str() guards against None or other
    # non-string values, which are then rejected like any unknown type.
    media_type = str(request_content_type or "").split(";", 1)[0].strip().lower()
    if media_type != "application/json":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return json.loads(request_body)


# For each supported task: the single-text method name, followed by the batch
# method names to try (in order) before falling back to a per-text loop.
_TASK_METHODS = {
    "extract_entities": (
        "extract_entities",
        ("batch_extract_entities", "batch_predict_entities"),
    ),
    "classify_text": ("classify_text", ("batch_classify_text",)),
    "extract_json": ("extract_json", ("batch_extract_json",)),
}


def predict_fn(input_data, model):
    """
    Run prediction on the input data using the loaded model.

    A list in ``text`` selects batch mode. In batch mode the model's batch
    method for the task is used when it has one; otherwise the single-text
    method is called once per item and the results are returned as a list.

    Args:
        input_data: Dictionary containing:
            - task: One of 'extract_entities', 'classify_text', or 'extract_json'
              (default: 'extract_entities')
            - text: Text to process (string) or list of texts (for batch processing)
            - schema: Schema for extraction (format depends on task)
            - threshold: Optional confidence threshold (default: 0.5)
        model: The loaded GLiNER2 model

    Returns:
        Task-specific results (single result or list of results for batch)

    Raises:
        ValueError: if 'text' is missing or empty, 'text' is an empty list,
            'schema' is missing or empty, or 'task' is not supported
    """
    # Clear CUDA cache before processing
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    text = input_data.get("text")
    task = input_data.get("task", "extract_entities")
    schema = input_data.get("schema")
    threshold = input_data.get("threshold", 0.5)

    # Detect batch mode
    is_batch = isinstance(text, list)

    # The empty-batch check must come before the generic falsy check: an
    # empty list is falsy and would otherwise be reported as a missing field.
    if is_batch and not text:
        raise ValueError("'text' list cannot be empty")
    if not text:
        raise ValueError("'text' field is required")
    if not schema:
        raise ValueError("'schema' field is required")

    # A non-string task (e.g. a JSON list) may be unhashable, so it is
    # rejected without being used as a dictionary key.
    methods = _TASK_METHODS.get(task) if isinstance(task, str) else None
    if methods is None:
        raise ValueError(
            f"Unsupported task: {task}. "
            "Must be one of: extract_entities, classify_text, extract_json"
        )
    single_name, batch_names = methods

    # Use inference_mode for faster inference
    with torch.inference_mode():
        if not is_batch:
            return getattr(model, single_name)(text, schema, threshold=threshold)

        for batch_name in batch_names:
            if hasattr(model, batch_name):
                return getattr(model, batch_name)(text, schema, threshold=threshold)

        # No batch method on this model: process the texts one at a time.
        run_single = getattr(model, single_name)
        return [run_single(item, schema, threshold=threshold) for item in text]


def output_fn(prediction, response_content_type):
    """
    Turn a prediction into the response payload.

    Args:
        prediction: The prediction result
        response_content_type: The desired response content type; only
            "application/json" is supported

    Returns:
        The prediction encoded as a JSON string

    Raises:
        ValueError: if any other response content type is requested
    """
    # Guard clause: reject everything that is not JSON up front.
    if response_content_type != "application/json":
        raise ValueError(f"Unsupported response content type: {response_content_type}")

    serialized = json.dumps(prediction)
    return serialized