"""
Base Model Interfaces for Document Intelligence

Abstract base classes defining the contract for all model components.
All models are pluggable and can be swapped without changing the pipeline.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
from PIL import Image


class ModelCapability(str, Enum):
    """Capabilities that a model may support.

    Members are also ``str`` instances, so each member compares equal to its
    string value (e.g. ``ModelCapability.OCR == "ocr"``).
    """

    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"


@dataclass
class ModelConfig:
    """Base configuration for all models.

    Only ``name`` is required; every other field has a default.  A
    ``cache_dir`` given as a string is converted to ``Path`` after
    initialisation.
    """

    name: str
    version: str = "1.0.0"
    device: str = "auto"  # "auto", "cpu", "cuda", "cuda:0", etc.
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    # default_factory gives every instance its own dict instead of a shared one.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Coerce ``cache_dir`` to a ``Path`` when one was supplied."""
        if self.cache_dir is not None:
            self.cache_dir = Path(self.cache_dir)


@dataclass
class ModelMetadata:
    """Metadata about a loaded model."""

    name: str
    version: str
    capabilities: List[ModelCapability]
    device: str
    memory_usage_mb: float = 0.0
    is_loaded: bool = False
    supports_batching: bool = False
    max_batch_size: int = 1
    # Free-form dictionaries; their schema is not defined in this module.
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)


class BaseModel(ABC):
    """
    Abstract base class for all document intelligence models.

    All model implementations must inherit from this class and implement
    the required abstract methods (``load``, ``unload`` and
    ``get_capabilities``).

    Instances can be used as context managers: entering loads the model if
    it is not loaded yet, and leaving always calls ``unload()``.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Initialise the model wrapper without loading any weights.

        Args:
            config: Model configuration.  When omitted, a default
                ``ModelConfig`` named after the concrete class is created.
        """
        if config is None:
            config = ModelConfig(name=self.__class__.__name__)
        self.config = config
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """Get model metadata (``None`` until a subclass populates it)."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Load the model into memory.

        Should set self._is_loaded = True upon successful loading.
        Should populate self._metadata with model information.
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """
        Unload the model from memory.

        Should set self._is_loaded = False.
        Should free GPU/CPU memory.
        """
        pass

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return list of capabilities this model provides."""
        pass

    def validate_input(self, input_data: Any) -> bool:
        """
        Validate input data before processing.

        Override in subclasses for specific validation.  The base
        implementation accepts everything and returns ``True``.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Preprocess input data before model inference.

        Override in subclasses for specific preprocessing.  The base
        implementation returns ``input_data`` unchanged.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Postprocess model output.

        Override in subclasses for specific postprocessing.  The base
        implementation returns ``output_data`` unchanged.
        """
        return output_data

    def __enter__(self):
        """Context manager entry: load the model if needed and return it."""
        if not self.is_loaded:
            self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: unload the model.

        ``unload()`` is called unconditionally, including when the model was
        already loaded before the ``with`` block was entered.  Returns
        ``False`` so an exception raised inside the block propagates.
        """
        self.unload()
        return False


class BatchableModel(BaseModel):
    """
    Base class for models that support batch processing.

    Provides infrastructure for processing multiple inputs efficiently.
    Subclasses implement ``process_batch``; ``process_single`` is derived
    from it.
    """

    @abstractmethod
    def process_batch(
        self,
        inputs: List[Any],
        **kwargs
    ) -> List[Any]:
        """
        Process a batch of inputs.

        Args:
            inputs: List of input items to process
            **kwargs: Additional processing parameters

        Returns:
            List of outputs, one per input
        """
        pass

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """
        Process a single input by wrapping it in a one-element batch.

        Args:
            input_data: The item to process
            **kwargs: Forwarded unchanged to ``process_batch``

        Returns:
            The first element of the batch result, or ``None`` when
            ``process_batch`` returns an empty (falsy) result.
        """
        results = self.process_batch([input_data], **kwargs)
        return results[0] if results else None


ImageInput = Union[np.ndarray, Image.Image, Path, str]


def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Normalize various image input formats to numpy array.

    Args:
        image: Image as numpy array, PIL Image, or path

    Returns:
        Image as numpy array.  PIL images and paths are converted to RGB
        before the array is built.  A numpy array input is returned as-is:
        the same object, with no copy and no channel or layout conversion.

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, np.ndarray):
        return image

    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    if isinstance(image, (str, Path)):
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once the converted copy has been built.
        with Image.open(image) as img:
            return np.array(img.convert("RGB"))

    raise ValueError(f"Unsupported image input type: {type(image)}")


def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Ensure input is a PIL Image.

    Args:
        image: Image as numpy array, PIL Image, or path

    Returns:
        PIL Image in RGB mode

    Raises:
        ValueError: If ``image`` is none of the supported types.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")

    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")

    if isinstance(image, (str, Path)):
        # The RGB copy is built inside the ``with`` block, so the source
        # file can be closed deterministically on exit.
        with Image.open(image) as img:
            return img.convert("RGB")

    raise ValueError(f"Unsupported image input type: {type(image)}")