File size: 5,930 Bytes
d520909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""
Base Model Interfaces for Document Intelligence

Abstract base classes defining the contract for all model components.
All models are pluggable and can be swapped without changing the pipeline.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
from PIL import Image


class ModelCapability(str, Enum):
    """Capabilities that a model may support.

    The enum mixes in ``str``, so every member is itself a string and
    compares equal to its value (e.g. ``ModelCapability.OCR == "ocr"``).
    """

    OCR = "ocr"
    LAYOUT_DETECTION = "layout_detection"
    TABLE_EXTRACTION = "table_extraction"
    CHART_EXTRACTION = "chart_extraction"
    READING_ORDER = "reading_order"
    VISION_LANGUAGE = "vision_language"
    EMBEDDING = "embedding"
    CLASSIFICATION = "classification"


@dataclass
class ModelConfig:
    """Configuration shared by every model implementation.

    Only ``name`` is required. ``cache_dir`` may be given as a plain string;
    it is coerced to :class:`pathlib.Path` right after construction.
    """

    name: str
    version: str = "1.0.0"
    # Device selector string: "auto", "cpu", "cuda", "cuda:0", etc.
    device: str = "auto"
    batch_size: int = 1
    max_workers: int = 4
    cache_enabled: bool = True
    cache_dir: Optional[Path] = None
    timeout_seconds: float = 300.0
    # Implementation-specific options; each instance gets its own fresh dict.
    extra_params: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        raw_dir = self.cache_dir
        if raw_dir is None:
            # Leave "no cache directory" as None rather than Path("None").
            return
        self.cache_dir = Path(raw_dir)


@dataclass
class ModelMetadata:
    """Metadata about a loaded model.

    ``name``, ``version``, ``capabilities`` and ``device`` are required;
    every other field has a default.
    """

    name: str
    version: str
    capabilities: List[ModelCapability]
    # Device string for the loaded model (same string form as ModelConfig.device).
    device: str
    # Memory footprint in megabytes; defaults to 0.0 (presumably "not measured").
    memory_usage_mb: float = 0.0
    is_loaded: bool = False
    # Batching hints; NOTE(review): max_batch_size is presumably only
    # meaningful when supports_batching is True — confirm with implementations.
    supports_batching: bool = False
    max_batch_size: int = 1
    # Free-form, implementation-defined descriptions; no schema is enforced
    # here. default_factory gives each instance its own empty dict.
    input_requirements: Dict[str, Any] = field(default_factory=dict)
    output_format: Dict[str, Any] = field(default_factory=dict)


class BaseModel(ABC):
    """
    Abstract interface that every document intelligence model implements.

    Subclasses must provide ``load``, ``unload`` and ``get_capabilities``.
    The ``validate_input``, ``preprocess`` and ``postprocess`` hooks are
    pass-through defaults that subclasses may override. A model can be used
    in a ``with`` statement: entering loads it if needed, leaving unloads it.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        # Without an explicit config, build a default one named after the
        # concrete subclass.
        if not config:
            config = ModelConfig(name=self.__class__.__name__)
        self.config = config
        self._is_loaded = False
        self._metadata: Optional[ModelMetadata] = None

    @property
    def is_loaded(self) -> bool:
        """Whether the model is currently loaded and ready for inference."""
        return self._is_loaded

    @property
    def metadata(self) -> Optional[ModelMetadata]:
        """The model's metadata, or ``None`` until an implementation sets it."""
        return self._metadata

    @abstractmethod
    def load(self) -> None:
        """
        Bring the model into memory.

        Implementations should set ``self._is_loaded = True`` once loading
        succeeds and populate ``self._metadata`` with model information.
        """

    @abstractmethod
    def unload(self) -> None:
        """
        Release the model from memory.

        Implementations should set ``self._is_loaded = False`` and free the
        GPU/CPU memory the model holds.
        """

    @abstractmethod
    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model provides."""

    def validate_input(self, input_data: Any) -> bool:
        """
        Check input data before processing.

        The default accepts everything; subclasses override for real checks.
        """
        return True

    def preprocess(self, input_data: Any) -> Any:
        """
        Transform input data ahead of inference.

        The default returns ``input_data`` unchanged.
        """
        return input_data

    def postprocess(self, output_data: Any) -> Any:
        """
        Transform raw model output.

        The default returns ``output_data`` unchanged.
        """
        return output_data

    def __enter__(self):
        """Load the model unless it is already loaded, then return ``self``."""
        if self.is_loaded:
            return self
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Unload the model when leaving the ``with`` block.

        ``unload`` is called unconditionally, even if the model was already
        loaded before ``__enter__``. Returning ``False`` lets any exception
        raised inside the block propagate.
        """
        self.unload()
        return False


class BatchableModel(BaseModel):
    """
    Base class for models that can process several inputs in one call.

    Subclasses implement ``process_batch``; ``process_single`` is a
    convenience wrapper that sends one item through the batch path.
    """

    @abstractmethod
    def process_batch(self, inputs: List[Any], **kwargs) -> List[Any]:
        """
        Process a batch of inputs.

        Args:
            inputs: Items to process.
            **kwargs: Additional, implementation-specific processing options.

        Returns:
            A list of outputs, one per input.
        """

    def process_single(self, input_data: Any, **kwargs) -> Any:
        """
        Process one input by wrapping it in a single-item batch.

        Returns the first element of the batch result, or ``None`` when
        ``process_batch`` returns an empty (or otherwise falsy) result.
        """
        batch_output = self.process_batch([input_data], **kwargs)
        if not batch_output:
            return None
        return batch_output[0]


# Image inputs accepted by normalize_image_input / ensure_pil_image: an
# in-memory numpy array or PIL image, or a filesystem path as Path or str.
ImageInput = Union[np.ndarray, Image.Image, Path, str]


def normalize_image_input(image: ImageInput) -> np.ndarray:
    """
    Normalize various image input formats to a numpy array.

    Args:
        image: Image as numpy array, PIL Image, or filesystem path
            (``str`` or ``Path``).

    Returns:
        For PIL images and paths, the image converted to RGB as a numpy
        array (HWC layout). A numpy array input is returned unchanged (the
        same object, with no copy and no colour/layout conversion), so the
        caller is responsible for its format.

    Raises:
        ValueError: If ``image`` is of an unsupported type.
    """
    if isinstance(image, np.ndarray):
        # Pass-through: arrays are trusted to already be in the desired format.
        return image

    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    if isinstance(image, (str, Path)):
        # Image.open is lazy and leaves the file handle open for multi-frame
        # formats (e.g. TIFF, GIF); the context manager closes it once the
        # pixels have been decoded by convert().
        with Image.open(image) as img:
            return np.array(img.convert("RGB"))

    raise ValueError(f"Unsupported image input type: {type(image)}")


def ensure_pil_image(image: ImageInput) -> Image.Image:
    """
    Ensure input is a PIL Image.

    Args:
        image: Image as numpy array, PIL Image, or filesystem path
            (``str`` or ``Path``).

    Returns:
        PIL Image in RGB mode. For path inputs the file is opened and
        closed within this call.

    Raises:
        ValueError: If ``image`` is of an unsupported type.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")

    if isinstance(image, np.ndarray):
        return Image.fromarray(image).convert("RGB")

    if isinstance(image, (str, Path)):
        # Close the underlying file deterministically; Image.open keeps it
        # open for multi-frame formats (e.g. multi-page TIFF documents).
        with Image.open(image) as img:
            return img.convert("RGB")

    raise ValueError(f"Unsupported image input type: {type(image)}")