""" Chart Extraction Model Interface Abstract interface for chart/graph understanding models. Extracts data points, axes, legends, and interprets visualizations. """ from abc import abstractmethod from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union from ..chunks.models import BoundingBox, ChartChunk, ChartDataPoint from .base import ( BaseModel, BatchableModel, ImageInput, ModelCapability, ModelConfig, ) class ChartType(str, Enum): """Types of charts that can be detected.""" # Common charts BAR = "bar" LINE = "line" PIE = "pie" SCATTER = "scatter" AREA = "area" # Advanced charts HISTOGRAM = "histogram" BOX_PLOT = "box_plot" HEATMAP = "heatmap" TREEMAP = "treemap" RADAR = "radar" BUBBLE = "bubble" WATERFALL = "waterfall" FUNNEL = "funnel" GANTT = "gantt" # Composite STACKED_BAR = "stacked_bar" GROUPED_BAR = "grouped_bar" MULTI_LINE = "multi_line" COMBO = "combo" # Mixed chart types # Other DIAGRAM = "diagram" # Flowcharts, org charts, etc. UNKNOWN = "unknown" @dataclass class ChartConfig(ModelConfig): """Configuration for chart extraction models.""" min_confidence: float = 0.5 extract_data_points: bool = True extract_trends: bool = True max_data_points: int = 1000 detect_chart_type: bool = True def __post_init__(self): super().__post_init__() if not self.name: self.name = "chart_extractor" @dataclass class AxisInfo: """Information about a chart axis.""" label: str = "" unit: str = "" min_value: Optional[float] = None max_value: Optional[float] = None scale: str = "linear" # "linear", "log", "categorical" tick_labels: List[str] = field(default_factory=list) tick_values: List[float] = field(default_factory=list) is_datetime: bool = False orientation: str = "horizontal" # "horizontal" or "vertical" @dataclass class LegendItem: """A single legend entry.""" label: str color: Optional[str] = None # Hex color if detected series_index: int = 0 @dataclass class DataSeries: """A data series in a chart.""" name: str data_points: List[ChartDataPoint] = field(default_factory=list) color: Optional[str] = None series_type: Optional[ChartType] = None # For combo charts @property def x_values(self) -> List[Any]: return [p.x for p in self.data_points] @property def y_values(self) -> List[Any]: return [p.y for p in self.data_points] def to_dict(self) -> Dict[str, Any]: """Convert to dictionary.""" return { "name": self.name, "color": self.color, "series_type": self.series_type.value if self.series_type else None, "data_points": [ {"x": p.x, "y": p.y, "label": p.label, "value": p.value} for p in self.data_points ] } @dataclass class TrendInfo: """Detected trend in the data.""" description: str # e.g., "Increasing trend from Q1 to Q4" direction: str = "neutral" # "increasing", "decreasing", "stable", "fluctuating" start_point: Optional[ChartDataPoint] = None end_point: Optional[ChartDataPoint] = None change_percent: Optional[float] = None confidence: float = 0.0 @dataclass class ChartStructure: """ Complete extracted chart structure. Contains all detected elements of a chart including type, axes, data series, legends, and interpretations. """ bbox: BoundingBox chart_type: ChartType = ChartType.UNKNOWN confidence: float = 0.0 # Title and labels title: str = "" subtitle: str = "" # Axes x_axis: Optional[AxisInfo] = None y_axis: Optional[AxisInfo] = None secondary_y_axis: Optional[AxisInfo] = None # Data series: List[DataSeries] = field(default_factory=list) legend_items: List[LegendItem] = field(default_factory=list) # Interpretation key_values: Dict[str, Any] = field(default_factory=dict) # Notable values trends: List[TrendInfo] = field(default_factory=list) summary: str = "" # Text description of the chart # Metadata chart_id: str = "" source_text: str = "" # Any text extracted from the chart def __post_init__(self): if not self.chart_id: import hashlib content = f"chart_{self.chart_type.value}_{self.bbox.xyxy}" self.chart_id = hashlib.md5(content.encode()).hexdigest()[:12] @property def total_data_points(self) -> int: return sum(len(s.data_points) for s in self.series) @property def all_data_points(self) -> List[ChartDataPoint]: """Get all data points from all series.""" points = [] for series in self.series: points.extend(series.data_points) return points def get_series_by_name(self, name: str) -> Optional[DataSeries]: """Find a series by name.""" for series in self.series: if series.name.lower() == name.lower(): return series return None def to_text_description(self) -> str: """Generate a text description of the chart.""" parts = [] if self.title: parts.append(f"Chart: {self.title}") else: parts.append(f"Chart Type: {self.chart_type.value}") if self.x_axis and self.x_axis.label: parts.append(f"X-Axis: {self.x_axis.label}") if self.y_axis and self.y_axis.label: parts.append(f"Y-Axis: {self.y_axis.label}") if self.series: parts.append(f"Series: {', '.join(s.name for s in self.series if s.name)}") if self.key_values: kv_str = ", ".join(f"{k}: {v}" for k, v in self.key_values.items()) parts.append(f"Key Values: {kv_str}") if self.trends: trend_strs = [t.description for t in self.trends if t.description] if trend_strs: parts.append(f"Trends: {'; '.join(trend_strs)}") return "\n".join(parts) def to_dict(self) -> Dict[str, Any]: """Convert to structured dictionary.""" return { "chart_type": self.chart_type.value, "title": self.title, "x_axis": { "label": self.x_axis.label if self.x_axis else "", "unit": self.x_axis.unit if self.x_axis else "", }, "y_axis": { "label": self.y_axis.label if self.y_axis else "", "unit": self.y_axis.unit if self.y_axis else "", }, "series": [s.to_dict() for s in self.series], "key_values": self.key_values, "trends": [ {"description": t.description, "direction": t.direction} for t in self.trends ], "summary": self.summary } def to_chart_chunk( self, doc_id: str, page: int, sequence_index: int ) -> ChartChunk: """Convert to ChartChunk for the chunks module.""" # Flatten all data points all_points = self.all_data_points return ChartChunk( chunk_id=ChartChunk.generate_chunk_id( doc_id=doc_id, page=page, bbox=self.bbox, chunk_type_str="chart" ), doc_id=doc_id, text=self.to_text_description(), page=page, bbox=self.bbox, confidence=self.confidence, sequence_index=sequence_index, chart_type=self.chart_type.value, title=self.title, x_axis_label=self.x_axis.label if self.x_axis else None, y_axis_label=self.y_axis.label if self.y_axis else None, data_points=all_points, key_values=self.key_values, trends=[t.description for t in self.trends] ) @dataclass class ChartExtractionResult: """Result of chart extraction from a page.""" charts: List[ChartStructure] = field(default_factory=list) processing_time_ms: float = 0.0 model_metadata: Dict[str, Any] = field(default_factory=dict) @property def chart_count(self) -> int: return len(self.charts) class ChartModel(BatchableModel): """ Abstract base class for chart extraction models. Implementations should handle: - Chart type classification - Axis detection and labeling - Data point extraction - Legend parsing - Trend detection """ def __init__(self, config: Optional[ChartConfig] = None): super().__init__(config or ChartConfig(name="chart")) self.config: ChartConfig = self.config def get_capabilities(self) -> List[ModelCapability]: return [ModelCapability.CHART_EXTRACTION] @abstractmethod def extract_chart( self, image: ImageInput, chart_region: Optional[BoundingBox] = None, **kwargs ) -> ChartStructure: """ Extract chart structure from an image. Args: image: Input image containing a chart chart_region: Optional bounding box of the chart **kwargs: Additional parameters Returns: ChartStructure with extracted data """ pass def extract_all_charts( self, image: ImageInput, chart_regions: Optional[List[BoundingBox]] = None, **kwargs ) -> ChartExtractionResult: """ Extract all charts from an image. Args: image: Input document image chart_regions: Optional list of chart bounding boxes **kwargs: Additional parameters Returns: ChartExtractionResult with all detected charts """ import time start_time = time.time() charts = [] if chart_regions: for region in chart_regions: try: chart = self.extract_chart(image, region, **kwargs) if chart.chart_type != ChartType.UNKNOWN: charts.append(chart) except Exception: continue else: chart = self.extract_chart(image, **kwargs) if chart.chart_type != ChartType.UNKNOWN: charts.append(chart) processing_time = (time.time() - start_time) * 1000 return ChartExtractionResult( charts=charts, processing_time_ms=processing_time ) def process_batch( self, inputs: List[ImageInput], **kwargs ) -> List[ChartExtractionResult]: """Process multiple images.""" return [self.extract_all_charts(img, **kwargs) for img in inputs] @abstractmethod def classify_chart_type( self, image: ImageInput, chart_region: Optional[BoundingBox] = None, **kwargs ) -> Tuple[ChartType, float]: """ Classify the type of chart in an image. Args: image: Input image chart_region: Optional bounding box **kwargs: Additional parameters Returns: Tuple of (ChartType, confidence) """ pass def detect_trends( self, chart: ChartStructure, **kwargs ) -> List[TrendInfo]: """ Analyze chart data for trends. Default implementation provides basic trend detection. Override for more sophisticated analysis. """ trends = [] for series in chart.series: if len(series.data_points) < 2: continue # Get numeric y-values y_values = [] for dp in series.data_points: if dp.y is not None: try: y_values.append(float(dp.y)) except (ValueError, TypeError): continue if len(y_values) < 2: continue # Simple trend detection first_half_avg = sum(y_values[:len(y_values)//2]) / (len(y_values)//2) second_half_avg = sum(y_values[len(y_values)//2:]) / (len(y_values) - len(y_values)//2) if second_half_avg > first_half_avg * 1.1: direction = "increasing" elif second_half_avg < first_half_avg * 0.9: direction = "decreasing" else: direction = "stable" change_pct = ((second_half_avg - first_half_avg) / first_half_avg * 100 if first_half_avg != 0 else 0) trend = TrendInfo( description=f"{series.name}: {direction} trend ({change_pct:+.1f}%)", direction=direction, start_point=series.data_points[0], end_point=series.data_points[-1], change_percent=change_pct, confidence=0.7 ) trends.append(trend) return trends