File size: 13,220 Bytes
d520909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
"""
Chart Extraction Model Interface

Abstract interface for chart/graph understanding models.
Extracts data points, axes, legends, and interprets visualizations.
"""

from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from ..chunks.models import BoundingBox, ChartChunk, ChartDataPoint
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class ChartType(str, Enum):
    """
    Enumeration of the chart kinds a chart model can report.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase string value (for example ``ChartType.BAR == "bar"``) and can
    be looked up by value with ``ChartType("bar")``.
    """

    # --- Basic chart families ---
    BAR = "bar"
    LINE = "line"
    PIE = "pie"
    SCATTER = "scatter"
    AREA = "area"

    # --- Specialised visualisations ---
    HISTOGRAM = "histogram"
    BOX_PLOT = "box_plot"
    HEATMAP = "heatmap"
    TREEMAP = "treemap"
    RADAR = "radar"
    BUBBLE = "bubble"
    WATERFALL = "waterfall"
    FUNNEL = "funnel"
    GANTT = "gantt"

    # --- Multi-series / composite layouts ---
    STACKED_BAR = "stacked_bar"
    GROUPED_BAR = "grouped_bar"
    MULTI_LINE = "multi_line"
    COMBO = "combo"  # several chart types combined in one plot

    # --- Fallbacks ---
    DIAGRAM = "diagram"  # non-data graphics such as flowcharts or org charts
    UNKNOWN = "unknown"


@dataclass
class ChartConfig(ModelConfig):
    """
    Settings for chart extraction models.

    Adds chart-specific options on top of ``ModelConfig``. If no ``name``
    is set (falsy), it defaults to ``"chart_extractor"`` after the base
    class has finished its own post-init processing.
    """

    # NOTE(review): none of these options are read inside this module;
    # presumably concrete ChartModel implementations consult them.
    min_confidence: float = 0.5
    extract_data_points: bool = True
    extract_trends: bool = True
    max_data_points: int = 1000
    detect_chart_type: bool = True

    def __post_init__(self):
        # Run the base configuration's initialisation first.
        super().__post_init__()
        if self.name:
            return
        self.name = "chart_extractor"


@dataclass
class AxisInfo:
    """
    Metadata describing a single chart axis.

    The numeric range defaults to ``None`` (unknown). Tick labels and tick
    values are kept as separate lists; nothing here forces them to be the
    same length.
    """

    label: str = ""                   # axis title text
    unit: str = ""                    # unit string attached to the axis
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    scale: str = "linear"             # expected: "linear", "log" or "categorical"
    tick_labels: List[str] = field(default_factory=list)
    tick_values: List[float] = field(default_factory=list)
    is_datetime: bool = False         # flag for date/time-valued axes
    orientation: str = "horizontal"   # "horizontal" or "vertical"


@dataclass
class LegendItem:
    """One entry read from a chart legend."""

    label: str
    color: Optional[str] = None  # hex color string, when one was detected
    series_index: int = 0        # presumably an index into the chart's series list — confirm with callers


@dataclass
class DataSeries:
    """
    A named sequence of data points plotted in a chart.

    ``series_type`` lets an individual series carry its own chart type,
    which matters for combo charts that mix e.g. bars and lines.
    """

    name: str
    data_points: List[ChartDataPoint] = field(default_factory=list)
    color: Optional[str] = None
    series_type: Optional[ChartType] = None  # per-series type inside combo charts

    @property
    def x_values(self) -> List[Any]:
        """The ``x`` attribute of every point, in stored order."""
        xs = [point.x for point in self.data_points]
        return xs

    @property
    def y_values(self) -> List[Any]:
        """The ``y`` attribute of every point, in stored order."""
        ys = [point.y for point in self.data_points]
        return ys

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the series into plain dicts and lists."""
        kind = None if not self.series_type else self.series_type.value
        serialised_points = []
        for point in self.data_points:
            serialised_points.append(
                {"x": point.x, "y": point.y, "label": point.label, "value": point.value}
            )
        return {
            "name": self.name,
            "color": self.color,
            "series_type": kind,
            "data_points": serialised_points,
        }


@dataclass
class TrendInfo:
    """
    A trend observed in a chart's data.

    ``direction`` is expected to be one of "increasing", "decreasing",
    "stable" or "fluctuating"; note that its default, "neutral", is not
    one of those values.
    """

    description: str  # human-readable text, e.g. "Increasing trend from Q1 to Q4"
    direction: str = "neutral"
    start_point: Optional[ChartDataPoint] = None
    end_point: Optional[ChartDataPoint] = None
    change_percent: Optional[float] = None
    confidence: float = 0.0


@dataclass
class ChartStructure:
    """
    Complete extracted chart structure.

    Contains all detected elements of a chart including
    type, axes, data series, legends, and interpretations.

    When ``chart_id`` is not supplied, a 12-hex-character id is derived
    deterministically from the chart type and the bounding box.
    """

    bbox: BoundingBox
    chart_type: ChartType = ChartType.UNKNOWN
    confidence: float = 0.0

    # Title and labels
    title: str = ""
    subtitle: str = ""

    # Axes
    x_axis: Optional[AxisInfo] = None
    y_axis: Optional[AxisInfo] = None
    secondary_y_axis: Optional[AxisInfo] = None

    # Data
    series: List[DataSeries] = field(default_factory=list)
    legend_items: List[LegendItem] = field(default_factory=list)

    # Interpretation
    key_values: Dict[str, Any] = field(default_factory=dict)  # Notable values
    trends: List[TrendInfo] = field(default_factory=list)
    summary: str = ""  # Text description of the chart

    # Metadata
    chart_id: str = ""
    source_text: str = ""  # Any text extracted from the chart

    def __post_init__(self):
        if not self.chart_id:
            import hashlib
            # MD5 is used only as a short content fingerprint, not for security.
            # NOTE(review): FIPS-restricted builds may reject hashlib.md5 unless
            # usedforsecurity=False is passed (Python 3.9+) — confirm target runtime.
            content = f"chart_{self.chart_type.value}_{self.bbox.xyxy}"
            self.chart_id = hashlib.md5(content.encode()).hexdigest()[:12]

    @property
    def total_data_points(self) -> int:
        """Number of data points summed across all series."""
        return sum(len(s.data_points) for s in self.series)

    @property
    def all_data_points(self) -> List[ChartDataPoint]:
        """Get all data points from all series, series by series in order."""
        points: List[ChartDataPoint] = []
        for series in self.series:
            points.extend(series.data_points)
        return points

    def get_series_by_name(self, name: str) -> Optional[DataSeries]:
        """
        Find a series by name, case-insensitively.

        Returns the first matching series, or ``None`` when none matches.
        """
        wanted = name.lower()  # computed once instead of once per series
        return next(
            (series for series in self.series if series.name.lower() == wanted),
            None,
        )

    def to_text_description(self) -> str:
        """
        Generate a newline-separated text description of the chart.

        Includes the title (or the chart type when untitled), axis labels,
        named series, key values and non-empty trend descriptions. Sections
        with nothing to show are omitted.
        """
        parts = []

        if self.title:
            parts.append(f"Chart: {self.title}")
        else:
            parts.append(f"Chart Type: {self.chart_type.value}")

        if self.x_axis and self.x_axis.label:
            parts.append(f"X-Axis: {self.x_axis.label}")
        if self.y_axis and self.y_axis.label:
            parts.append(f"Y-Axis: {self.y_axis.label}")

        # Only emit the Series line when at least one series has a name;
        # otherwise the output would contain a dangling "Series: " label.
        series_names = [s.name for s in self.series if s.name]
        if series_names:
            parts.append(f"Series: {', '.join(series_names)}")

        if self.key_values:
            kv_str = ", ".join(f"{k}: {v}" for k, v in self.key_values.items())
            parts.append(f"Key Values: {kv_str}")

        trend_strs = [t.description for t in self.trends if t.description]
        if trend_strs:
            parts.append(f"Trends: {'; '.join(trend_strs)}")

        return "\n".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a structured dictionary.

        Missing axes are represented with empty ``label``/``unit`` strings.
        Subtitle, confidence, legend and the secondary axis are not included.
        """
        return {
            "chart_type": self.chart_type.value,
            "title": self.title,
            "x_axis": {
                "label": self.x_axis.label if self.x_axis else "",
                "unit": self.x_axis.unit if self.x_axis else "",
            },
            "y_axis": {
                "label": self.y_axis.label if self.y_axis else "",
                "unit": self.y_axis.unit if self.y_axis else "",
            },
            "series": [s.to_dict() for s in self.series],
            "key_values": self.key_values,
            "trends": [
                {"description": t.description, "direction": t.direction}
                for t in self.trends
            ],
            "summary": self.summary,
        }

    def to_chart_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> ChartChunk:
        """
        Convert to a ChartChunk for the chunks module.

        Args:
            doc_id: Identifier of the source document.
            page: Page the chart appears on.
            sequence_index: Position of the chunk in the document's sequence.

        Returns:
            A ChartChunk whose text is ``to_text_description()`` and whose
            data points are all series' points flattened in order.
        """
        all_points = self.all_data_points

        return ChartChunk(
            chunk_id=ChartChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="chart"
            ),
            doc_id=doc_id,
            text=self.to_text_description(),
            page=page,
            bbox=self.bbox,
            confidence=self.confidence,
            sequence_index=sequence_index,
            chart_type=self.chart_type.value,
            title=self.title,
            x_axis_label=self.x_axis.label if self.x_axis else None,
            y_axis_label=self.y_axis.label if self.y_axis else None,
            data_points=all_points,
            key_values=self.key_values,
            trends=[t.description for t in self.trends]
        )


@dataclass
class ChartExtractionResult:
    """Charts extracted from a single page, with timing and model metadata."""

    charts: List[ChartStructure] = field(default_factory=list)
    processing_time_ms: float = 0.0  # elapsed extraction time in milliseconds
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def chart_count(self) -> int:
        """How many charts are stored in ``charts``."""
        count = len(self.charts)
        return count


class ChartModel(BatchableModel):
    """
    Abstract base class for chart extraction models.

    Implementations should handle:
    - Chart type classification
    - Axis detection and labeling
    - Data point extraction
    - Legend parsing
    - Trend detection

    Subclasses must implement ``extract_chart`` and ``classify_chart_type``.
    ``extract_all_charts``, ``process_batch`` and ``detect_trends`` have
    default implementations built on top of them.
    """

    def __init__(self, config: Optional[ChartConfig] = None):
        super().__init__(config or ChartConfig(name="chart"))
        # Re-annotate with the narrower config type for type checkers.
        # Assumes the base class stores the passed config on self.config.
        self.config: ChartConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report that this model provides chart extraction."""
        return [ModelCapability.CHART_EXTRACTION]

    @abstractmethod
    def extract_chart(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> ChartStructure:
        """
        Extract chart structure from an image.

        Args:
            image: Input image containing a chart
            chart_region: Optional bounding box of the chart
            **kwargs: Additional parameters

        Returns:
            ChartStructure with extracted data
        """
        pass

    def extract_all_charts(
        self,
        image: ImageInput,
        chart_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> ChartExtractionResult:
        """
        Extract all charts from an image.

        With a non-empty ``chart_regions``, ``extract_chart`` is called once
        per region. A region whose extraction raises is logged and skipped
        so that one bad region does not abort the page. With ``None`` or an
        empty list, a single whole-image ``extract_chart`` call is made and
        its exceptions propagate. In both modes, charts classified as
        ``ChartType.UNKNOWN`` are dropped.

        Args:
            image: Input document image
            chart_regions: Optional list of chart bounding boxes
            **kwargs: Additional parameters forwarded to ``extract_chart``

        Returns:
            ChartExtractionResult with all detected charts and the elapsed
            processing time in milliseconds.
        """
        import logging
        import time

        # perf_counter is monotonic, unlike time.time(), so durations are
        # not distorted by system clock adjustments.
        start_time = time.perf_counter()

        charts: List[ChartStructure] = []

        if chart_regions:
            for region in chart_regions:
                try:
                    chart = self.extract_chart(image, region, **kwargs)
                    if chart.chart_type != ChartType.UNKNOWN:
                        charts.append(chart)
                except Exception:
                    # Best-effort per region, but never silently: record why.
                    logging.getLogger(__name__).warning(
                        "Chart extraction failed for region %r; skipping",
                        region,
                        exc_info=True,
                    )
                    continue
        else:
            chart = self.extract_chart(image, **kwargs)
            if chart.chart_type != ChartType.UNKNOWN:
                charts.append(chart)

        processing_time = (time.perf_counter() - start_time) * 1000

        return ChartExtractionResult(
            charts=charts,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[ChartExtractionResult]:
        """Run ``extract_all_charts`` on each image, preserving input order."""
        return [self.extract_all_charts(img, **kwargs) for img in inputs]

    @abstractmethod
    def classify_chart_type(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> Tuple[ChartType, float]:
        """
        Classify the type of chart in an image.

        Args:
            image: Input image
            chart_region: Optional bounding box
            **kwargs: Additional parameters

        Returns:
            Tuple of (ChartType, confidence)
        """
        pass

    @staticmethod
    def _numeric_y_values(series: DataSeries) -> List[float]:
        """Return the finite, float-convertible y-values of *series*, in order."""
        import math

        values: List[float] = []
        for dp in series.data_points:
            if dp.y is None:
                continue
            try:
                value = float(dp.y)
            except (ValueError, TypeError, OverflowError):
                # OverflowError: float() of an int too large for a double.
                continue
            # NaN/inf would poison the half averages below.
            if math.isfinite(value):
                values.append(value)
        return values

    def detect_trends(
        self,
        chart: ChartStructure,
        *,
        change_threshold: float = 0.1,
        **kwargs
    ) -> List[TrendInfo]:
        """
        Analyze chart data for trends.

        Default implementation provides basic trend detection.
        Override for more sophisticated analysis.

        For every series with at least two finite numeric y-values, the
        mean of the second half is compared with the mean of the first
        half. The series is "increasing" or "decreasing" when the change
        exceeds ``change_threshold`` times the magnitude of the first-half
        mean, and "stable" otherwise.

        Args:
            chart: Chart whose series are analysed.
            change_threshold: Relative change (0.1 == 10%) required to
                report a non-stable direction.
            **kwargs: Accepted for subclass compatibility; unused here.

        Returns:
            One TrendInfo per qualifying series, with a fixed confidence of 0.7.
        """
        trends: List[TrendInfo] = []

        for series in chart.series:
            if len(series.data_points) < 2:
                continue

            y_values = self._numeric_y_values(series)
            if len(y_values) < 2:
                continue

            mid = len(y_values) // 2  # >= 1 because len(y_values) >= 2
            first_half_avg = sum(y_values[:mid]) / mid
            second_half_avg = sum(y_values[mid:]) / (len(y_values) - mid)

            delta = second_half_avg - first_half_avg
            baseline = abs(first_half_avg)
            # Use |baseline| so negative-valued series are classified correctly
            # (multiplying a negative mean by 1.1 shifts the threshold the wrong
            # way). For a positive baseline this equals the 1.1x / 0.9x rule; for
            # a zero baseline any non-zero change counts.
            margin = change_threshold * baseline
            if delta > margin:
                direction = "increasing"
            elif delta < -margin:
                direction = "decreasing"
            else:
                direction = "stable"

            change_pct = delta / baseline * 100 if baseline != 0 else 0.0

            trend = TrendInfo(
                description=f"{series.name}: {direction} trend ({change_pct:+.1f}%)",
                direction=direction,
                start_point=series.data_points[0],
                end_point=series.data_points[-1],
                change_percent=change_pct,
                confidence=0.7
            )
            trends.append(trend)

        return trends