| """ |
| Vision-Language Model Interface |
| |
| Abstract interface for multimodal models that understand both |
| images and text. Used for document understanding, VQA, and |
| complex reasoning over visual content. |
| """ |
|
|
| from abc import abstractmethod |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
| from ..chunks.models import BoundingBox |
| from .base import ( |
| BaseModel, |
| BatchableModel, |
| ImageInput, |
| ModelCapability, |
| ModelConfig, |
| ) |
|
|
|
|
class VLMTask(str, Enum):
    """Tasks that VLM models can perform."""

    # Document-level understanding tasks.
    DOCUMENT_QA = "document_qa"
    DOCUMENT_SUMMARY = "document_summary"
    DOCUMENT_CLASSIFICATION = "document_classification"

    # General image understanding tasks.
    IMAGE_CAPTION = "image_caption"
    IMAGE_QA = "image_qa"
    VISUAL_GROUNDING = "visual_grounding"

    # Structured-content extraction tasks.
    FIELD_EXTRACTION = "field_extraction"
    TABLE_UNDERSTANDING = "table_understanding"
    CHART_UNDERSTANDING = "chart_understanding"

    # Text-centric tasks.
    OCR_CORRECTION = "ocr_correction"
    TEXT_GENERATION = "text_generation"

    # Catch-all for requests that fit no specific task.
    GENERAL = "general"
|
|
|
|
@dataclass
class VLMConfig(ModelConfig):
    """Configuration for vision-language models."""

    max_tokens: int = 2048  # upper bound on generated tokens per request
    temperature: float = 0.1  # low default biases toward deterministic output
    top_p: float = 0.9  # nucleus-sampling cutoff
    max_image_size: int = 1024  # presumably the longest image edge in pixels -- confirm against backend
    image_detail: str = "high"  # detail hint (e.g. provider "low"/"high" modes) -- verify per provider
    system_prompt: Optional[str] = None  # optional system message prepended to requests

    def __post_init__(self):
        # Run base-class post-init first, then default the model name
        # so an unnamed config is still identifiable.
        super().__post_init__()
        if not self.name:
            self.name = "vlm"
|
|
|
|
@dataclass
class VLMMessage:
    """A message in a VLM conversation."""

    role: str  # conversational role, e.g. "user"/"assistant" -- convention; not enforced here
    content: str  # text content of the message
    images: List[ImageInput] = field(default_factory=list)  # images attached to this turn
    # Optional regions of interest -- presumably index-aligned with `images`,
    # with None meaning the whole image; TODO confirm against implementations.
    image_regions: List[Optional[BoundingBox]] = field(default_factory=list)
|
|
|
|
@dataclass
class VLMResponse:
    """Response from a VLM model."""

    text: str  # generated text
    confidence: float = 0.0  # model/backend confidence in the response
    tokens_used: int = 0  # tokens consumed by the request
    finish_reason: str = "stop"  # why generation ended, e.g. "stop"/"length" -- backend-defined

    # Visual grounding output: regions the response refers to, with
    # presumably index-aligned labels -- confirm against implementations.
    grounded_regions: List[BoundingBox] = field(default_factory=list)
    region_labels: List[str] = field(default_factory=list)

    # Optional machine-readable payload (e.g. parsed JSON), when the task
    # produced one.
    structured_data: Optional[Dict[str, Any]] = None

    # Timing and backend-specific details.
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
@dataclass
class DocumentQAResult:
    """Result of document question answering."""

    question: str  # the question that was asked
    answer: str  # the model's answer text
    confidence: float = 0.0  # confidence in the answer (0-1)

    # Supporting evidence: regions in the page images, quoted text
    # snippets, and (where known) page indices.
    evidence_regions: List[BoundingBox] = field(default_factory=list)
    evidence_text: List[str] = field(default_factory=list)
    page_references: List[int] = field(default_factory=list)

    # Abstention: True when the model declined to answer (e.g. answer
    # not found in the document), with an optional explanation.
    abstained: bool = False
    abstention_reason: Optional[str] = None
|
|
|
|
@dataclass
class FieldExtractionVLMResult:
    """Result of field extraction using VLM."""

    fields: Dict[str, Any] = field(default_factory=dict)  # extracted values keyed by field name
    confidence_scores: Dict[str, float] = field(default_factory=dict)  # per-field confidence (0-1)

    # Provenance: where each field was found (region) and the supporting
    # text snippet, keyed by field name.
    field_regions: Dict[str, BoundingBox] = field(default_factory=dict)
    field_evidence: Dict[str, str] = field(default_factory=dict)

    # Fields the model declined to extract, with per-field reasons.
    abstained_fields: List[str] = field(default_factory=list)
    abstention_reasons: Dict[str, str] = field(default_factory=dict)

    # Aggregate confidence over the extracted fields.
    overall_confidence: float = 0.0
|
|
|
|
class VisionLanguageModel(BatchableModel):
    """
    Abstract base class for Vision-Language Models.

    These models combine visual understanding with language
    capabilities for tasks like document QA, field extraction,
    and visual reasoning.

    Subclasses must implement :meth:`generate` and :meth:`chat`; the
    higher-level helpers (QA, extraction, summarization, classification)
    are all built on top of :meth:`generate`.
    """

    def __init__(self, config: Optional[VLMConfig] = None):
        """
        Initialize the model.

        Args:
            config: Optional VLM configuration; a default named "vlm"
                is created when omitted.
        """
        super().__init__(config or VLMConfig(name="vlm"))
        # Re-annotation only: narrows the inherited `config` attribute's
        # static type to VLMConfig for type checkers; no runtime change.
        self.config: VLMConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        """Return the capabilities this model class advertises."""
        return [ModelCapability.VISION_LANGUAGE]

    @abstractmethod
    def generate(
        self,
        prompt: str,
        images: List[ImageInput],
        **kwargs
    ) -> VLMResponse:
        """
        Generate a response given text prompt and images.

        Args:
            prompt: Text prompt/question
            images: List of images for context
            **kwargs: Additional generation parameters

        Returns:
            VLMResponse with generated text
        """
        pass

    def process_batch(
        self,
        inputs: List[Tuple[str, List[ImageInput]]],
        **kwargs
    ) -> List[VLMResponse]:
        """
        Process multiple prompt-image pairs.

        Sequential default implementation; backends with true batch
        support should override this.

        Args:
            inputs: List of (prompt, images) tuples
            **kwargs: Additional parameters

        Returns:
            List of VLMResponses, one per input pair
        """
        return [
            self.generate(prompt, images, **kwargs)
            for prompt, images in inputs
        ]

    @abstractmethod
    def chat(
        self,
        messages: List[VLMMessage],
        **kwargs
    ) -> VLMResponse:
        """
        Multi-turn conversation with images.

        Args:
            messages: Conversation history
            **kwargs: Additional parameters

        Returns:
            VLMResponse for the conversation
        """
        pass

    def answer_question(
        self,
        question: str,
        document_images: List[ImageInput],
        context: Optional[str] = None,
        **kwargs
    ) -> DocumentQAResult:
        """
        Answer a question about document images.

        Args:
            question: Question to answer
            document_images: Document page images
            context: Optional additional context
            **kwargs: Additional parameters

        Returns:
            DocumentQAResult with answer and evidence
        """
        prompt = self._build_qa_prompt(question, context)
        response = self.generate(prompt, document_images, **kwargs)

        answer, confidence, abstained, reason, evidence = self._parse_qa_response(
            response.text
        )

        return DocumentQAResult(
            question=question,
            answer=answer,
            confidence=confidence,
            evidence_regions=response.grounded_regions,
            # Fix: EVIDENCE lines the prompt explicitly requests were
            # previously parsed out but discarded.
            evidence_text=evidence,
            abstained=abstained,
            abstention_reason=reason,
        )

    def extract_fields(
        self,
        images: List[ImageInput],
        schema: Dict[str, Any],
        **kwargs
    ) -> FieldExtractionVLMResult:
        """
        Extract fields from document images according to a schema.

        Args:
            images: Document page images
            schema: Field schema (JSON Schema or Pydantic-like)
            **kwargs: Additional parameters

        Returns:
            FieldExtractionVLMResult with extracted values
        """
        prompt = self._build_extraction_prompt(schema)
        response = self.generate(prompt, images, **kwargs)
        return self._parse_extraction_response(response, schema)

    def summarize_document(
        self,
        images: List[ImageInput],
        max_length: int = 500,
        **kwargs
    ) -> str:
        """
        Generate a summary of document images.

        Args:
            images: Document page images
            max_length: Maximum summary length in characters (a prompt
                instruction only -- the model is not hard-truncated)
            **kwargs: Additional parameters

        Returns:
            Document summary text
        """
        prompt = f"""Summarize this document in at most {max_length} characters.
Focus on the main points and key information.
Be concise and factual."""

        response = self.generate(prompt, images, **kwargs)
        return response.text

    def classify_document(
        self,
        images: List[ImageInput],
        categories: List[str],
        **kwargs
    ) -> Tuple[str, float]:
        """
        Classify document into predefined categories.

        Args:
            images: Document page images
            categories: List of possible categories
            **kwargs: Additional parameters

        Returns:
            Tuple of (category, confidence). ("UNKNOWN", 0.0) when the
            response does not name a known category.
        """
        categories_str = ", ".join(categories)
        prompt = f"""Classify this document into one of these categories: {categories_str}

Respond with just the category name and confidence (0-1).
Format: CATEGORY: confidence

If you cannot confidently classify, respond with: UNKNOWN: 0.0"""

        response = self.generate(prompt, images, **kwargs)

        text = response.text.strip()
        if not text:
            return "UNKNOWN", 0.0

        # Fix: parse only the first line. Splitting the whole response on
        # ":" used to break float() on any multi-line answer and force
        # UNKNOWN even when the category itself was valid.
        label, sep, conf_str = text.splitlines()[0].partition(":")
        category = label.strip().upper()

        confidence = 0.5  # neutral default when no confidence is given
        if sep:
            try:
                # Take the first whitespace-delimited token after ":" so
                # trailing commentary does not break parsing.
                confidence = float(conf_str.strip().split()[0])
            except (ValueError, IndexError):
                confidence = 0.5
        # Clamp to the 0-1 range the prompt documents.
        confidence = max(0.0, min(1.0, confidence))

        # Map case-insensitively back to the caller's original spelling.
        by_upper = {c.upper(): c for c in categories}
        if category in by_upper:
            return by_upper[category], confidence
        return "UNKNOWN", 0.0

    def _build_qa_prompt(
        self,
        question: str,
        context: Optional[str] = None
    ) -> str:
        """Build the prompt for document QA, embedding the answer format."""
        prompt_parts = [
            "You are analyzing a document image. Answer the following question based only on what you can see in the document.",
            "",
            "IMPORTANT RULES:",
            "- Only use information visible in the document",
            "- If the answer is not found, say 'NOT FOUND' and explain why",
            "- Be precise and quote exact values when possible",
            "- Indicate your confidence level (HIGH, MEDIUM, LOW)",
            ""
        ]

        if context:
            prompt_parts.extend([
                "Additional context:",
                context,
                ""
            ])

        prompt_parts.extend([
            f"Question: {question}",
            "",
            "Provide your answer in this format:",
            "ANSWER: [your answer]",
            "CONFIDENCE: [HIGH/MEDIUM/LOW]",
            "EVIDENCE: [quote or describe where you found this information]"
        ])

        return "\n".join(prompt_parts)

    def _parse_qa_response(
        self,
        response_text: str
    ) -> Tuple[str, float, bool, Optional[str], List[str]]:
        """
        Parse a QA response for answer, confidence, abstention, and evidence.

        Args:
            response_text: Raw model output in the ANSWER/CONFIDENCE/
                EVIDENCE format requested by :meth:`_build_qa_prompt`.

        Returns:
            Tuple of (answer, confidence, abstained, abstention_reason,
            evidence_text snippets).
        """
        answer = ""
        confidence = 0.5  # default when no CONFIDENCE line is present
        evidence: List[str] = []

        for line in response_text.strip().split("\n"):
            line_lower = line.lower()
            if line_lower.startswith("answer:"):
                answer = line.split(":", 1)[1].strip()
            elif line_lower.startswith("confidence:"):
                conf_str = line.split(":", 1)[1].strip().upper()
                confidence = {"HIGH": 0.9, "MEDIUM": 0.6, "LOW": 0.3}.get(conf_str, 0.5)
            elif line_lower.startswith("evidence:"):
                # Fix: the EVIDENCE line requested by the prompt was ignored.
                snippet = line.split(":", 1)[1].strip()
                if snippet:
                    evidence.append(snippet)

        # "NOT FOUND"-style answers count as abstention, per the prompt rules.
        abstained = "not found" in answer.lower() or "cannot find" in answer.lower()
        reason = answer if abstained else None

        return answer, confidence, abstained, reason, evidence

    def _build_extraction_prompt(self, schema: Dict[str, Any]) -> str:
        """Build the field-extraction prompt embedding the JSON schema."""
        import json

        schema_str = json.dumps(schema, indent=2)

        prompt = f"""Extract the following fields from this document image.

SCHEMA:
{schema_str}

RULES:
- Only extract values that are clearly visible in the document
- For each field, provide the exact value and its location
- If a field is not found, mark it as null with confidence 0
- Be precise with numbers, dates, and proper nouns

Respond in valid JSON format matching the schema.
Include a "_confidence" object with confidence scores (0-1) for each field.
Include a "_evidence" object with the text snippet where each value was found.
"""
        return prompt

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """
        Return the contents of the first markdown code fence in *text*,
        or *text* unchanged when it contains no fence.

        Fix: an unterminated fence used to produce end == -1, so the
        slice text[start:-1] silently dropped the final character of the
        JSON payload.
        """
        for fence in ("```json", "```"):
            if fence in text:
                start = text.find(fence) + len(fence)
                end = text.find("```", start)
                if end == -1:  # unterminated fence: take the rest
                    end = len(text)
                return text[start:end].strip()
        return text

    def _parse_extraction_response(
        self,
        response: VLMResponse,
        schema: Dict[str, Any]
    ) -> FieldExtractionVLMResult:
        """
        Parse an extraction response into a structured result.

        Schema fields that are missing or null in the response are
        recorded as abstentions; an unparseable response abstains on
        every schema field.
        """
        import json

        result = FieldExtractionVLMResult()

        try:
            text = self._strip_code_fences(response.text.strip())
            data = json.loads(text)
            if not isinstance(data, dict):
                # Fix: a top-level JSON list/scalar used to reach
                # data.items() and raise an uncaught AttributeError.
                raise ValueError("extraction response is not a JSON object")

            # Metadata keys (leading underscore) are handled separately below.
            for key, value in data.items():
                if not key.startswith("_"):
                    result.fields[key] = value

            conf = data.get("_confidence")
            if isinstance(conf, dict):
                # Fix: keep only numeric scores so the average below cannot
                # raise an uncaught TypeError on malformed model output.
                result.confidence_scores = {
                    k: float(v) for k, v in conf.items()
                    if isinstance(v, (int, float))
                }

            evidence = data.get("_evidence")
            if isinstance(evidence, dict):
                result.field_evidence = {k: str(v) for k, v in evidence.items()}

            # Every schema field that is absent or null is an abstention.
            for field_name in schema.get("properties", {}).keys():
                if result.fields.get(field_name) is None:
                    result.abstained_fields.append(field_name)
                    result.abstention_reasons[field_name] = "Field not found in document"

            if result.confidence_scores:
                result.overall_confidence = (
                    sum(result.confidence_scores.values())
                    / len(result.confidence_scores)
                )

        except ValueError:
            # json.JSONDecodeError subclasses ValueError; abstain on all fields.
            for field_name in schema.get("properties", {}).keys():
                result.abstained_fields.append(field_name)
                result.abstention_reasons[field_name] = "Failed to parse extraction response"

        return result
|
|