File size: 5,930 Bytes
d520909 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
"""
Base Model Interfaces for Document Intelligence
Abstract base classes defining the contract for all model components.
All models are pluggable and can be swapped without changing the pipeline.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
from PIL import Image
class ModelCapability(str, Enum):
    """
    Enumerate the kinds of work a model can advertise.

    Because members also subclass ``str``, each one compares equal to its
    lowercase string value (for example ``ModelCapability.OCR == "ocr"``).
    """

    # Text recognition and page structure
    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"

    # Structured-content extraction
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"

    # General-purpose model families
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"
@dataclass
class ModelConfig:
    """
    Settings shared by every model implementation.

    ``cache_dir`` may be given as a string; after initialization it is
    always either ``None`` or a ``pathlib.Path``. ``extra_params`` holds
    implementation-specific options, and every instance gets its own dict.
    """

    name: str
    version: str = "1.0.0"
    # Target device, e.g. "auto", "cpu", "cuda", "cuda:0".
    device: str = "auto"
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Keep None as-is; coerce any other value (str or Path) to Path.
        self.cache_dir = None if self.cache_dir is None else Path(self.cache_dir)
@dataclass
class ModelMetadata:
    """
    Runtime description of a model instance.

    Only the identity fields (name, version, capabilities, device) are
    required; everything else has a conservative default. The two dict
    fields are free-form and get a fresh dict per instance.
    """

    name: str
    version: str
    capabilities: List[ModelCapability]
    device: str
    # Resource / state information
    memory_usage_mb: float = 0.0
    is_loaded: bool = False
    # Batching support
    supports_batching: bool = False
    max_batch_size: int = 1
    # Free-form descriptions of expected input and produced output
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)
class BaseModel(ABC):
    """
    Abstract base class for all document intelligence models.

    All model implementations must inherit from this class and implement
    ``load``, ``unload`` and ``get_capabilities``. The ``validate_input``,
    ``preprocess`` and ``postprocess`` hooks default to pass-throughs and
    may be overridden.

    Instances can be used as context managers. Entering a ``with`` block
    loads the model if it is not loaded yet. Exiting unloads it only when
    that same block performed the load. A model that was already loaded
    beforehand, or that is shared by nested ``with`` blocks, therefore
    stays loaded until the block that loaded it exits.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Args:
            config: Model configuration. When ``None``, a default
                ``ModelConfig`` named after the concrete class is created.
        """
        self.config = config or ModelConfig(name=self.__class__.__name__)
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None
        # One flag per active ``with`` block: True when that block's
        # ``__enter__`` called ``load()`` and must therefore unload on exit.
        self._context_loads: List[bool] = []

    @property
    def is_loaded(self) -> bool:
        """Check if the model is loaded and ready for inference."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """Get model metadata (``None`` until an implementation sets it)."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Load the model into memory.

        Should set self._is_loaded = True upon successful loading.
        Should populate self._metadata with model information.
        """
        pass

    @abstractmethod
    def unload(self) -> None:
        """
        Unload the model from memory.

        Should set self._is_loaded = False.
        Should free GPU/CPU memory.
        """
        pass

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return list of capabilities this model provides."""
        pass

    def validate_input(self, input_data: Any) -> bool:
        """
        Validate input data before processing.

        The default accepts everything; override in subclasses for
        specific validation.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Preprocess input data before model inference.

        The default returns ``input_data`` unchanged; override in
        subclasses for specific preprocessing.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Postprocess model output.

        The default returns ``output_data`` unchanged; override in
        subclasses for specific postprocessing.
        """
        return output_data

    def __enter__(self):
        """Load the model if it is not loaded yet, then return ``self``."""
        loaded_here = not self.is_loaded
        if loaded_here:
            self.load()
        # Record ownership only after a successful load. If load() raises,
        # Python never calls __exit__ for this block.
        stack = getattr(self, "_context_loads", None)
        if stack is None:
            # A subclass skipped BaseModel.__init__; create the stack lazily.
            stack = self._context_loads = []
        stack.append(loaded_here)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Unload the model, but only if this ``with`` block loaded it."""
        stack = getattr(self, "_context_loads", None)
        # Without a matching __enter__ record (e.g. __exit__ called
        # directly), keep the historical behaviour and unload.
        loaded_here = stack.pop() if stack else True
        if loaded_here:
            self.unload()
        # Never suppress exceptions raised inside the ``with`` body.
        return False
class BatchableModel(BaseModel):
    """
    Base class for models whose inference operates on lists of inputs.

    Subclasses implement ``process_batch``. ``process_single`` is a
    convenience wrapper that sends one item through the batch path.
    """

    @abstractmethod
    def process_batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        Run inference over several inputs in one call.

        Args:
            inputs: Items to process.
            **kwargs: Implementation-specific processing options.

        Returns:
            One output per entry of ``inputs``, in the same order.
        """
        pass

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """
        Run inference on one input by delegating to ``process_batch``.

        Returns the first element of the batch result, or ``None`` when
        ``process_batch`` returns an empty result.
        """
        outputs = self.process_batch([input_data], **kwargs)
        if not outputs:
            return None
        return outputs[0]
# Image formats accepted by the helpers below: an in-memory numpy array,
# a PIL image, or a filesystem path given as str or Path.
ImageInput = Union[np.ndarray, Image.Image, Path, str]


def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Normalize various image input formats to a numpy array.

    Args:
        image: Image as a numpy array, PIL Image, or filesystem path
            (``str`` or ``Path``).

    Returns:
        Image as a numpy array (RGB, HWC format).

        - PIL images and image files are converted to RGB by PIL.
        - numpy arrays keep their dtype. A 2-D ``(H, W)`` array or an
          ``(H, W, 1)`` array is replicated to three channels. An
          ``(H, W, 4)`` array has its fourth channel dropped. This mirrors
          PIL's ``convert("RGB")`` for "L" and "RGBA" images. Any other
          array shape (including ``(H, W, 3)``) is returned unchanged
          without copying.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
    """
    if isinstance(image, np.ndarray):
        # Coerce common non-RGB channels-last layouts to three channels so
        # the result honours the documented RGB/HWC contract.
        # NOTE(review): assumes channels-last arrays, with RGBA ordering
        # for 4 channels; other layouts are passed through untouched.
        if image.ndim == 2:
            return np.stack((image, image, image), axis=-1)
        if image.ndim == 3 and image.shape[2] == 1:
            return np.repeat(image, 3, axis=2)
        if image.ndim == 3 and image.shape[2] == 4:
            return np.ascontiguousarray(image[..., :3])
        return image
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))
    if isinstance(image, (str, Path)):
        # Use a context manager so the file handle is released even for
        # formats PIL keeps open after loading (e.g. multi-frame images).
        # convert() loads the pixels into a new in-memory image first.
        with Image.open(image) as img:
            return np.array(img.convert("RGB"))
    raise ValueError(f"Unsupported image input type: {type(image)}")
def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Ensure input is a PIL Image in RGB mode.

    Args:
        image: Image as a numpy array, PIL Image, or filesystem path
            (``str`` or ``Path``). An ``(H, W, 1)`` array is treated as a
            2-D grayscale ``(H, W)`` array.

    Returns:
        PIL Image in RGB mode.

    Raises:
        ValueError: If ``image`` is not one of the supported types.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, np.ndarray):
        # Pillow's fromarray rejects a trailing singleton channel; drop it
        # so the array is interpreted as single-channel grayscale.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, (str, Path)):
        # Context manager releases the file handle, which PIL may keep open
        # after loading (e.g. multi-frame images). convert() loads the
        # pixels into a new in-memory image before the file is closed.
        with Image.open(image) as img:
            return img.convert("RGB")
    raise ValueError(f"Unsupported image input type: {type(image)}")
|