# Commit d520909 (MHamdan) - Initial commit: SPARKNET framework
"""
Base Model Interfaces for Document Intelligence
Abstract base classes defining the contract for all model components.
All models are pluggable and can be swapped without changing the pipeline.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
from PIL import Image
class ModelCapability(str, Enum):
    """
    Enumeration of the capabilities a model can advertise.

    Every member is also a ``str``, so a member compares equal to its
    plain string value (e.g. ``ModelCapability.OCR == "ocr"``) and can be
    looked up from that value (``ModelCapability("ocr")``).
    """

    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"
@dataclass
class ModelConfig:
    """
    Configuration values common to all models.

    Only ``name`` is required; every other field has a default.
    ``cache_dir`` may be passed as anything ``pathlib.Path`` accepts and is
    converted to a ``Path`` once the instance has been initialised.
    """

    name: str
    version: str = "1.0.0"
    # Device selector string, e.g. "auto", "cpu", "cuda", "cuda:0".
    device: str = "auto"
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    # default_factory gives each instance its own dict.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Coerce a supplied ``cache_dir`` to ``Path``; keep ``None`` as is."""
        if self.cache_dir is None:
            return
        self.cache_dir = Path(self.cache_dir)
@dataclass
class ModelMetadata:
    """
    Descriptive record for a loaded model.

    ``name``, ``version``, ``capabilities`` and ``device`` must be supplied;
    the remaining fields fall back to the defaults declared below.
    """

    # Required fields.
    name: str
    version: str
    capabilities: List[ModelCapability]
    device: str

    # Optional fields.
    memory_usage_mb: float = 0.0
    is_loaded: bool = False
    supports_batching: bool = False
    max_batch_size: int = 1
    # default_factory gives each instance its own dict for these two fields.
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)
class BaseModel(ABC):
    """
    Abstract foundation for every document intelligence model.

    Concrete models subclass this and implement ``load``, ``unload`` and
    ``get_capabilities``. The validation and pre/post-processing hooks are
    pass-through defaults that subclasses may override. Instances can be
    used as context managers: entering loads the model if it is not already
    loaded, and leaving always calls ``unload``.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        # Without a usable config, build a default one named after the
        # concrete class.
        if not config:
            config = ModelConfig(name=self.__class__.__name__)
        self.config = config
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None

    @property
    def is_loaded(self) -> bool:
        """Whether the model is currently flagged as loaded."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """The stored model metadata, or ``None`` if none has been set."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Bring the model into memory.

        Implementations should set ``self._is_loaded = True`` once loading
        succeeds and fill in ``self._metadata``.
        """

    @abstractmethod
    def unload(self) -> None:
        """
        Remove the model from memory.

        Implementations should set ``self._is_loaded = False`` and release
        GPU/CPU memory.
        """

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model provides."""

    def validate_input(self, input_data: Any) -> bool:
        """
        Check input data before it is processed.

        The base implementation accepts everything; subclasses override it
        to add real validation.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Prepare input data for inference.

        The base implementation returns the input unchanged.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Transform raw model output.

        The base implementation returns the output unchanged.
        """
        return output_data

    def __enter__(self):
        """Enter the context, loading the model first if necessary."""
        if self.is_loaded:
            return self
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Leave the context: unload the model and let exceptions propagate."""
        self.unload()
        # A falsy return value tells Python not to suppress the exception.
        return False
class BatchableModel(BaseModel):
    """
    Base class for models that process inputs in batches.

    Subclasses implement ``process_batch``; ``process_single`` is derived
    from it by submitting a one-element batch.
    """

    @abstractmethod
    def process_batch(
        self,
        inputs: List[Any],
        **kwargs
    ) -> List[Any]:
        """
        Process a batch of inputs.

        Args:
            inputs: The items to process.
            **kwargs: Extra processing parameters.

        Returns:
            A list of outputs, one per input.
        """

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """
        Process one item by running it through ``process_batch``.

        Returns:
            The first element of the batch result, or ``None`` when the
            batch result is empty/falsy.
        """
        outputs = self.process_batch([input_data], **kwargs)
        if not outputs:
            return None
        return outputs[0]
# Type alias for the image representations accepted by the image helper
# functions in this module: a numpy array, a PIL image, or a filesystem
# path given as ``Path`` or ``str``.
ImageInput = Union[np.ndarray, Image.Image, Path, str]
def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Convert a supported image input into a numpy array.

    Args:
        image: A numpy array, a PIL Image, or a path (``str`` or ``Path``)
            to an image file.

    Returns:
        For PIL Images and paths, the image converted to RGB as a numpy
        array. A numpy array input is returned as the same object: it is
        not copied, and its channel order and layout are not checked.

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, np.ndarray):
        return image
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))
    if isinstance(image, (str, Path)):
        # Open inside a context manager so the file handle is closed
        # deterministically instead of lingering until garbage collection.
        # The array is built before the block exits, so it does not depend
        # on the closed file.
        with Image.open(image) as img:
            return np.array(img.convert("RGB"))
    raise ValueError(f"Unsupported image input type: {type(image)}")
def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Convert a supported image input into a PIL Image in RGB mode.

    Args:
        image: A numpy array, a PIL Image, or a path (``str`` or ``Path``)
            to an image file.

    Returns:
        The result of ``convert("RGB")`` on the PIL Image obtained from
        the input.

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, (str, Path)):
        # Open inside a context manager so the file handle is closed
        # deterministically instead of lingering until garbage collection.
        # convert() loads the pixel data into a new image, so the returned
        # object stays usable after the source file is closed.
        with Image.open(image) as img:
            return img.convert("RGB")
    raise ValueError(f"Unsupported image input type: {type(image)}")