|
|
""" |
|
|
Base Model Interfaces for Document Intelligence |
|
|
|
|
|
Abstract base classes defining the contract for all model components. |
|
|
All models are pluggable and can be swapped without changing the pipeline. |
|
|
""" |
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
class ModelCapability(str, Enum):
    """Kinds of functionality a model can declare that it provides.

    Members also subclass ``str``, so each one compares equal to its
    plain string value (``ModelCapability.OCR == "ocr"``).
    """

    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Settings shared by every model implementation.

    Only ``name`` is required; all other fields carry defaults. A
    ``cache_dir`` that is not ``None`` is passed through ``Path`` after
    construction, so it is always stored as a ``Path`` instance.
    """

    name: str
    version: str = "1.0.0"
    device: str = "auto"
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    # default_factory gives each instance its own dict rather than a shared one.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Coerce a supplied ``cache_dir`` into a ``Path``."""
        if self.cache_dir is None:
            return
        self.cache_dir = Path(self.cache_dir)
|
|
|
|
|
|
|
|
@dataclass
class ModelMetadata:
    """Descriptive record for a model.

    ``name``, ``version``, ``capabilities`` and ``device`` must be given;
    the remaining fields default to "not loaded, no batching, batch size
    of one" with empty requirement/format mappings.
    """

    name: str
    version: str
    capabilities: List[ModelCapability]
    device: str
    memory_usage_mb: float = 0.0
    is_loaded: bool = False
    supports_batching: bool = False
    max_batch_size: int = 1
    # default_factory gives each instance its own dict rather than a shared one.
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
class BaseModel(ABC):
    """
    Common interface implemented by every document intelligence model.

    Concrete subclasses must provide ``load``, ``unload`` and
    ``get_capabilities``. The validation and pre/post-processing hooks
    are pass-through defaults that subclasses may override. Instances
    work as context managers: entering loads the model when it is not
    already loaded, and exiting always calls ``unload``.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """Store the configuration and start in the "not loaded" state.

        Args:
            config: Model configuration. When omitted (or falsy), a
                default ``ModelConfig`` named after the concrete class
                is created.
        """
        if not config:
            config = ModelConfig(name=self.__class__.__name__)
        self.config = config
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None

    @property
    def is_loaded(self) -> bool:
        """Whether the model is currently flagged as loaded."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """The stored model metadata, or ``None`` if none has been set."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Bring the model into memory.

        Implementations should set ``self._is_loaded = True`` once
        loading succeeds and should populate ``self._metadata`` with
        information about the model.
        """

    @abstractmethod
    def unload(self) -> None:
        """
        Release the model from memory.

        Implementations should set ``self._is_loaded = False`` and free
        any GPU/CPU memory held by the model.
        """

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model provides."""

    def validate_input(self, input_data: Any) -> bool:
        """
        Check input data before it is processed.

        The base implementation accepts everything; override in
        subclasses for specific validation.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Prepare input data for model inference.

        The base implementation returns the input unchanged; override
        in subclasses for specific preprocessing.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Transform raw model output.

        The base implementation returns the output unchanged; override
        in subclasses for specific postprocessing.
        """
        return output_data

    def __enter__(self):
        """Enter the context, loading the model unless already loaded."""
        if self.is_loaded:
            return self
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the context by unloading the model."""
        self.unload()
        # Returning False lets any in-flight exception propagate.
        return False
|
|
|
|
|
|
|
|
class BatchableModel(BaseModel):
    """
    Base class for models that process their inputs in batches.

    Subclasses implement ``process_batch``; ``process_single`` is
    provided on top of it for one-off inputs.
    """

    @abstractmethod
    def process_batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        Process a batch of inputs.

        Args:
            inputs: List of input items to process
            **kwargs: Additional processing parameters

        Returns:
            List of outputs, one per input
        """

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """Process one input by running it as a single-item batch.

        Returns the first element of the batch output, or ``None`` when
        the batch output is empty.
        """
        batch_output = self.process_batch([input_data], **kwargs)
        if not batch_output:
            return None
        return batch_output[0]
|
|
|
|
|
|
|
|
# Every form in which an image may be handed to the helpers below: an
# in-memory numpy array, a PIL image, or a file location given as a
# ``Path`` or ``str``.
ImageInput = Union[np.ndarray, Image.Image, Path, str]
|
|
|
|
|
|
|
|
def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Normalize various image input formats to a numpy array.

    Args:
        image: Image as numpy array, PIL Image, or path (``str`` or ``Path``)

    Returns:
        For PIL images and paths, a new array built from the image after
        conversion to RGB (HWC format). A numpy array input is returned
        as-is: the same object, with no copy and no channel conversion.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
    """
    if isinstance(image, np.ndarray):
        return image

    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    if isinstance(image, (str, Path)):
        # Image.open can keep the underlying file open; the context manager
        # closes it once the pixel data has been converted and copied out.
        with Image.open(image) as img:
            return np.array(img.convert("RGB"))

    raise ValueError(f"Unsupported image input type: {type(image)}")
|
|
|
|
|
|
|
|
def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Ensure input is a PIL Image.

    Args:
        image: Image as numpy array, PIL Image, or path (``str`` or ``Path``)

    Returns:
        PIL Image in RGB mode. The result always comes from ``convert("RGB")``,
        including when the input is already a PIL image.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")

    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")

    if isinstance(image, (str, Path)):
        # Image.open can keep the underlying file open; convert() yields an
        # independent in-memory image, so the file can be closed on exit.
        with Image.open(image) as img:
            return img.convert("RGB")

    raise ValueError(f"Unsupported image input type: {type(image)}")
|
|
|