|
|
""" |
|
|
Chart Extraction Model Interface |
|
|
|
|
|
Abstract interface for chart/graph understanding models. |
|
|
Extracts data points, axes, and legends, and interprets visualizations.
|
|
""" |
|
|
|
|
|
from abc import abstractmethod |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
|
|
from ..chunks.models import BoundingBox, ChartChunk, ChartDataPoint |
|
|
from .base import ( |
|
|
BaseModel, |
|
|
BatchableModel, |
|
|
ImageInput, |
|
|
ModelCapability, |
|
|
ModelConfig, |
|
|
) |
|
|
|
|
|
|
|
|
class ChartType(str, Enum):
    """Types of charts that can be detected.

    Members also subclass ``str``, so each compares equal to its string
    value (e.g. ``ChartType.BAR == "bar"``).
    """

    # Basic chart types
    BAR = "bar"
    LINE = "line"
    PIE = "pie"
    SCATTER = "scatter"
    AREA = "area"

    # Statistical / specialised chart types
    HISTOGRAM = "histogram"
    BOX_PLOT = "box_plot"
    HEATMAP = "heatmap"
    TREEMAP = "treemap"
    RADAR = "radar"
    BUBBLE = "bubble"
    WATERFALL = "waterfall"
    FUNNEL = "funnel"
    GANTT = "gantt"

    # Multi-series / combined variants
    STACKED_BAR = "stacked_bar"
    GROUPED_BAR = "grouped_bar"
    MULTI_LINE = "multi_line"
    COMBO = "combo"

    # Other
    DIAGRAM = "diagram"  # NOTE(review): presumably non-plot figures (flowcharts etc.) — confirm
    # Default for ChartStructure.chart_type; ChartModel.extract_all_charts
    # discards charts that still carry this type.
    UNKNOWN = "unknown"
|
|
|
|
|
|
|
|
@dataclass
class ChartConfig(ModelConfig):
    """Settings for chart extraction models.

    Extends ``ModelConfig`` with chart-specific options. If the base config
    leaves ``name`` empty, ``__post_init__`` sets it to ``"chart_extractor"``.
    None of the options below are read by the default methods of
    ``ChartModel`` in this module; concrete models are expected to honor them.
    """

    # Confidence threshold (presumably for discarding low-confidence results).
    min_confidence: float = 0.5
    # Feature flags for the individual extraction stages.
    extract_data_points: bool = True
    extract_trends: bool = True
    # Cap on the number of data points an implementation should extract.
    max_data_points: int = 1000
    detect_chart_type: bool = True

    def __post_init__(self):
        super().__post_init__()
        # Fill in a default identifier only when none was provided.
        name_missing = not self.name
        if name_missing:
            self.name = "chart_extractor"
|
|
|
|
|
|
|
|
@dataclass
class AxisInfo:
    """Information about a chart axis: label, unit, range, scale, and ticks."""

    # Axis title text and its unit (empty string when absent).
    label: str = ""
    unit: str = ""
    # Numeric range of the axis; None means unset.
    min_value: Optional[float] = None
    max_value: Optional[float] = None
    # Scale kind; defaults to "linear". NOTE(review): other accepted values
    # (e.g. "log") are not defined here — confirm with implementations.
    scale: str = "linear"
    # Tick text as rendered and numeric tick positions.
    # NOTE(review): presumably parallel lists of equal length — confirm.
    tick_labels: List[str] = field(default_factory=list)
    tick_values: List[float] = field(default_factory=list)
    # Flag for date/time axes.
    is_datetime: bool = False
    # Axis orientation; defaults to "horizontal".
    orientation: str = "horizontal"
|
|
|
|
|
|
|
|
@dataclass
class LegendItem:
    """A single legend entry."""

    # Text shown for this entry in the legend.
    label: str
    # Entry color as a string; NOTE(review): format (hex, name, ...) not fixed here.
    color: Optional[str] = None
    # NOTE(review): presumably an index into ChartStructure.series — confirm.
    series_index: int = 0
|
|
|
|
|
|
|
|
@dataclass
class DataSeries:
    """One named series of data points within a chart."""

    name: str
    data_points: List[ChartDataPoint] = field(default_factory=list)
    color: Optional[str] = None
    series_type: Optional[ChartType] = None

    @property
    def x_values(self) -> List[Any]:
        """The ``x`` of every data point, in stored order."""
        return [point.x for point in self.data_points]

    @property
    def y_values(self) -> List[Any]:
        """The ``y`` of every data point, in stored order."""
        return [point.y for point in self.data_points]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the series into plain dicts and lists.

        ``series_type`` is emitted as its string value, or None when unset.
        """
        kind = None
        if self.series_type:
            kind = self.series_type.value

        serialized_points = [
            dict(x=pt.x, y=pt.y, label=pt.label, value=pt.value)
            for pt in self.data_points
        ]

        return {
            "name": self.name,
            "color": self.color,
            "series_type": kind,
            "data_points": serialized_points,
        }
|
|
|
|
|
|
|
|
@dataclass
class TrendInfo:
    """Detected trend in the data."""

    # Human-readable summary, e.g. "Revenue: increasing trend (+12.0%)".
    description: str
    # ChartModel.detect_trends emits "increasing", "decreasing" or "stable".
    # NOTE(review): the "neutral" default is not one of those values — confirm intent.
    direction: str = "neutral"
    # Data points bounding the analysed range, when known.
    start_point: Optional[ChartDataPoint] = None
    end_point: Optional[ChartDataPoint] = None
    # Percentage change across the analysed range; None when not computed.
    change_percent: Optional[float] = None
    # Confidence score for the trend (ChartModel.detect_trends uses 0.7).
    confidence: float = 0.0
|
|
|
|
|
|
|
|
@dataclass
class ChartStructure:
    """
    Complete extracted chart structure.

    Contains all detected elements of a chart including
    type, axes, data series, legends, and interpretations. The structure can
    be rendered as plain text (``to_text_description``), as a plain dict
    (``to_dict``), or converted into a ``ChartChunk`` (``to_chart_chunk``).
    """

    # Location of the chart in the source image/page.
    bbox: BoundingBox
    chart_type: ChartType = ChartType.UNKNOWN
    confidence: float = 0.0

    # Titles
    title: str = ""
    subtitle: str = ""

    # Axes
    x_axis: Optional[AxisInfo] = None
    y_axis: Optional[AxisInfo] = None
    secondary_y_axis: Optional[AxisInfo] = None

    # Data series and legend
    series: List[DataSeries] = field(default_factory=list)
    legend_items: List[LegendItem] = field(default_factory=list)

    # Interpretation
    key_values: Dict[str, Any] = field(default_factory=dict)
    trends: List[TrendInfo] = field(default_factory=list)
    summary: str = ""

    # Identification / provenance
    chart_id: str = ""
    source_text: str = ""

    def __post_init__(self):
        # Derive a short deterministic id from chart type + bbox when none was
        # supplied. Charts with identical type and bbox get the same id.
        if not self.chart_id:
            import hashlib

            content = f"chart_{self.chart_type.value}_{self.bbox.xyxy}"
            self.chart_id = hashlib.md5(content.encode()).hexdigest()[:12]

    @property
    def total_data_points(self) -> int:
        """Number of data points summed over all series."""
        return sum(len(s.data_points) for s in self.series)

    @property
    def all_data_points(self) -> List[ChartDataPoint]:
        """Get all data points from all series, concatenated in series order."""
        points: List[ChartDataPoint] = []
        for series in self.series:
            points.extend(series.data_points)
        return points

    def get_series_by_name(self, name: str) -> Optional[DataSeries]:
        """Find a series by name (case-insensitive).

        Returns:
            The first matching series, or None if no series matches.
        """
        target = name.lower()
        for series in self.series:
            if series.name.lower() == target:
                return series
        return None

    def to_text_description(self) -> str:
        """Generate a newline-separated text description of the chart.

        Sections whose data is missing (axis labels, series names, key values,
        trend descriptions) are omitted rather than emitted empty.
        """
        parts = []

        if self.title:
            parts.append(f"Chart: {self.title}")
        else:
            parts.append(f"Chart Type: {self.chart_type.value}")

        if self.x_axis and self.x_axis.label:
            parts.append(f"X-Axis: {self.x_axis.label}")
        if self.y_axis and self.y_axis.label:
            parts.append(f"Y-Axis: {self.y_axis.label}")

        # Only emit the line when at least one series is named; otherwise the
        # output would contain a dangling "Series: ".
        series_names = [s.name for s in self.series if s.name]
        if series_names:
            parts.append(f"Series: {', '.join(series_names)}")

        if self.key_values:
            kv_str = ", ".join(f"{k}: {v}" for k, v in self.key_values.items())
            parts.append(f"Key Values: {kv_str}")

        if self.trends:
            trend_strs = [t.description for t in self.trends if t.description]
            if trend_strs:
                parts.append(f"Trends: {'; '.join(trend_strs)}")

        return "\n".join(parts)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to a structured dictionary.

        Missing axes serialize as empty label/unit strings.
        """
        return {
            "chart_type": self.chart_type.value,
            "title": self.title,
            "x_axis": {
                "label": self.x_axis.label if self.x_axis else "",
                "unit": self.x_axis.unit if self.x_axis else "",
            },
            "y_axis": {
                "label": self.y_axis.label if self.y_axis else "",
                "unit": self.y_axis.unit if self.y_axis else "",
            },
            "series": [s.to_dict() for s in self.series],
            "key_values": self.key_values,
            "trends": [
                {"description": t.description, "direction": t.direction}
                for t in self.trends
            ],
            "summary": self.summary,
        }

    def to_chart_chunk(
        self,
        doc_id: str,
        page: int,
        sequence_index: int
    ) -> ChartChunk:
        """Convert to ChartChunk for the chunks module.

        Args:
            doc_id: Identifier of the source document.
            page: Page the chart appears on.
            sequence_index: Position of the chunk in the document sequence.

        Returns:
            A ChartChunk whose text is ``to_text_description()`` and whose
            data points are all points from every series.
        """
        all_points = self.all_data_points

        return ChartChunk(
            chunk_id=ChartChunk.generate_chunk_id(
                doc_id=doc_id,
                page=page,
                bbox=self.bbox,
                chunk_type_str="chart"
            ),
            doc_id=doc_id,
            text=self.to_text_description(),
            page=page,
            bbox=self.bbox,
            confidence=self.confidence,
            sequence_index=sequence_index,
            chart_type=self.chart_type.value,
            title=self.title,
            x_axis_label=self.x_axis.label if self.x_axis else None,
            y_axis_label=self.y_axis.label if self.y_axis else None,
            data_points=all_points,
            key_values=self.key_values,
            trends=[t.description for t in self.trends]
        )
|
|
|
|
|
|
|
|
@dataclass
class ChartExtractionResult:
    """Result of chart extraction from a page."""

    # Charts retained by the extractor.
    charts: List[ChartStructure] = field(default_factory=list)
    # Elapsed extraction time in milliseconds.
    processing_time_ms: float = 0.0
    # Free-form model metadata (ChartModel.extract_all_charts leaves it empty).
    model_metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def chart_count(self) -> int:
        """Number of charts in ``charts``."""
        return len(self.charts)
|
|
|
|
|
|
|
|
class ChartModel(BatchableModel):
    """
    Abstract base class for chart extraction models.

    Implementations should handle:
    - Chart type classification
    - Axis detection and labeling
    - Data point extraction
    - Legend parsing
    - Trend detection

    Subclasses must implement ``extract_chart`` and ``classify_chart_type``.
    ``extract_all_charts``, ``process_batch`` and ``detect_trends`` have
    default implementations built on top of them.
    """

    def __init__(self, config: Optional[ChartConfig] = None):
        # Substitute the default only when no config was passed at all.
        super().__init__(config if config is not None else ChartConfig(name="chart"))
        # Re-annotate so type checkers see the chart-specific config type.
        self.config: ChartConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Report that this model provides chart extraction."""
        return [ModelCapability.CHART_EXTRACTION]

    @abstractmethod
    def extract_chart(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> ChartStructure:
        """
        Extract chart structure from an image.

        Args:
            image: Input image containing a chart
            chart_region: Optional bounding box of the chart
            **kwargs: Additional parameters

        Returns:
            ChartStructure with extracted data
        """
        pass

    def extract_all_charts(
        self,
        image: ImageInput,
        chart_regions: Optional[List[BoundingBox]] = None,
        **kwargs
    ) -> ChartExtractionResult:
        """
        Extract all charts from an image.

        With a non-empty ``chart_regions``, ``extract_chart`` runs once per
        region. A region whose extraction raises is logged (with traceback)
        and skipped. With ``None`` or an empty list, ``extract_chart`` runs
        once on the whole image and any exception propagates. Charts
        classified as ``ChartType.UNKNOWN`` are dropped in both cases.

        Args:
            image: Input document image
            chart_regions: Optional list of chart bounding boxes
            **kwargs: Additional parameters forwarded to ``extract_chart``

        Returns:
            ChartExtractionResult with all detected charts
        """
        import logging
        import time

        logger = logging.getLogger(__name__)
        # perf_counter is monotonic, so the elapsed time is immune to clock changes.
        start_time = time.perf_counter()

        charts = []

        if chart_regions:
            for index, region in enumerate(chart_regions):
                try:
                    chart = self.extract_chart(image, region, **kwargs)
                    if chart.chart_type != ChartType.UNKNOWN:
                        charts.append(chart)
                except Exception:
                    # Best-effort per region: one failure must not abort the
                    # whole page, but it must not disappear silently either.
                    logger.warning(
                        "Chart extraction failed for region %d (%r); skipping",
                        index,
                        region,
                        exc_info=True,
                    )
                    continue
        else:
            chart = self.extract_chart(image, **kwargs)
            if chart.chart_type != ChartType.UNKNOWN:
                charts.append(chart)

        processing_time = (time.perf_counter() - start_time) * 1000

        return ChartExtractionResult(
            charts=charts,
            processing_time_ms=processing_time
        )

    def process_batch(
        self,
        inputs: List[ImageInput],
        **kwargs
    ) -> List[ChartExtractionResult]:
        """Process multiple images, one ``extract_all_charts`` call per image."""
        return [self.extract_all_charts(img, **kwargs) for img in inputs]

    @abstractmethod
    def classify_chart_type(
        self,
        image: ImageInput,
        chart_region: Optional[BoundingBox] = None,
        **kwargs
    ) -> Tuple[ChartType, float]:
        """
        Classify the type of chart in an image.

        Args:
            image: Input image
            chart_region: Optional bounding box
            **kwargs: Additional parameters

        Returns:
            Tuple of (ChartType, confidence)
        """
        pass

    def detect_trends(
        self,
        chart: ChartStructure,
        *,
        change_threshold: float = 0.1,
        **kwargs
    ) -> List[TrendInfo]:
        """
        Analyze chart data for trends.

        Default implementation provides basic trend detection: for each series,
        the mean of the first half of its finite numeric y-values is compared
        with the mean of the second half. Override for more sophisticated
        analysis.

        Args:
            chart: Chart whose series are analysed.
            change_threshold: Relative change, as a fraction of the magnitude
                of the first-half mean, beyond which a series is reported as
                "increasing" or "decreasing" rather than "stable". The default
                is 0.1 (10%).
            **kwargs: Ignored by the default implementation.

        Returns:
            One TrendInfo per series that has at least two finite numeric
            y-values. ``change_percent`` is None when the first-half mean is 0,
            because a percentage change from zero is undefined.
        """
        import math

        trends = []

        for series in chart.series:
            # Pair each usable point with its parsed value, so the reported
            # start/end points are the ones the trend was computed from.
            numeric = []
            for dp in series.data_points:
                if dp.y is None:
                    continue
                try:
                    value = float(dp.y)
                except (ValueError, TypeError):
                    continue
                # NaN/inf would poison the averages below.
                if math.isfinite(value):
                    numeric.append((dp, value))

            if len(numeric) < 2:
                continue

            values = [v for _, v in numeric]
            mid = len(values) // 2  # >= 1 because len(values) >= 2
            first_half_avg = sum(values[:mid]) / mid
            second_half_avg = sum(values[mid:]) / (len(values) - mid)
            delta = second_half_avg - first_half_avg

            change_pct: Optional[float]
            if first_half_avg != 0:
                # Normalise by the baseline's magnitude so a negative baseline
                # does not invert the direction or the sign of the percentage.
                relative = delta / abs(first_half_avg)
                if relative > change_threshold:
                    direction = "increasing"
                elif relative < -change_threshold:
                    direction = "decreasing"
                else:
                    direction = "stable"
                change_pct = relative * 100
                description = f"{series.name}: {direction} trend ({change_pct:+.1f}%)"
            else:
                # Zero baseline: only the sign of the change is meaningful.
                if delta > 0:
                    direction = "increasing"
                elif delta < 0:
                    direction = "decreasing"
                else:
                    direction = "stable"
                change_pct = None
                description = f"{series.name}: {direction} trend"

            trends.append(
                TrendInfo(
                    description=description,
                    direction=direction,
                    start_point=numeric[0][0],
                    end_point=numeric[-1][0],
                    change_percent=change_pct,
                    confidence=0.7
                )
            )

        return trends
|
|
|