"""
Vision-Language Model Interface

Abstract interface for multimodal models that understand both
images and text. Used for document understanding, VQA, and
complex reasoning over visual content.
"""

from abc import abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Union

from ..chunks.models import BoundingBox
from .base import (
    BaseModel,
    BatchableModel,
    ImageInput,
    ModelCapability,
    ModelConfig,
)


class VLMTask(str, Enum):
    """Tasks that VLM models can perform."""

    # Document-level tasks
    DOCUMENT_QA = "document_qa"
    DOCUMENT_SUMMARY = "document_summary"
    DOCUMENT_CLASSIFICATION = "document_classification"

    # Image-level tasks
    IMAGE_CAPTION = "image_caption"
    IMAGE_QA = "image_qa"
    VISUAL_GROUNDING = "visual_grounding"

    # Structured-content understanding
    FIELD_EXTRACTION = "field_extraction"
    TABLE_UNDERSTANDING = "table_understanding"
    CHART_UNDERSTANDING = "chart_understanding"

    # Text-centric tasks
    OCR_CORRECTION = "ocr_correction"
    TEXT_GENERATION = "text_generation"

    # Fallback
    GENERAL = "general"


@dataclass
class VLMConfig(ModelConfig):
    """Configuration for vision-language models."""

    max_tokens: int = 2048
    temperature: float = 0.1
    top_p: float = 0.9
    max_image_size: int = 1024
    image_detail: str = "high"
    system_prompt: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()
        if not self.name:
            self.name = "vlm"


@dataclass
class VLMMessage:
    """A message in a VLM conversation."""

    role: str
    content: str
    images: List[ImageInput] = field(default_factory=list)
    image_regions: List[Optional[BoundingBox]] = field(default_factory=list)


@dataclass
class VLMResponse:
    """Response from a VLM model."""

    text: str
    confidence: float = 0.0
    tokens_used: int = 0
    finish_reason: str = "stop"

    # Visual grounding
    grounded_regions: List[BoundingBox] = field(default_factory=list)
    region_labels: List[str] = field(default_factory=list)

    # Structured output (e.g. parsed JSON), if any
    structured_data: Optional[Dict[str, Any]] = None

    # Diagnostics
    processing_time_ms: float = 0.0
    model_metadata: Dict[str, Any] = field(default_factory=dict)


@dataclass
class DocumentQAResult:
    """Result of document question answering."""

    question: str
    answer: str
    confidence: float = 0.0

    # Supporting evidence
    evidence_regions: List[BoundingBox] = field(default_factory=list)
    evidence_text: List[str] = field(default_factory=list)
    page_references: List[int] = field(default_factory=list)

    # Abstention
    abstained: bool = False
    abstention_reason: Optional[str] = None


@dataclass
class FieldExtractionVLMResult:
    """Result of field extraction using VLM."""

    fields: Dict[str, Any] = field(default_factory=dict)
    confidence_scores: Dict[str, float] = field(default_factory=dict)

    # Per-field grounding and evidence
    field_regions: Dict[str, BoundingBox] = field(default_factory=dict)
    field_evidence: Dict[str, str] = field(default_factory=dict)

    # Per-field abstention
    abstained_fields: List[str] = field(default_factory=list)
    abstention_reasons: Dict[str, str] = field(default_factory=dict)

    overall_confidence: float = 0.0


class VisionLanguageModel(BatchableModel):
    """
    Abstract base class for Vision-Language Models.

    These models combine visual understanding with language
    capabilities for tasks like document QA, field extraction,
    and visual reasoning.
    """

    def __init__(self, config: Optional[VLMConfig] = None):
        super().__init__(config or VLMConfig(name="vlm"))
        # Re-annotate so type checkers know the config is a VLMConfig.
        self.config: VLMConfig = self.config

    def get_capabilities(self) -> List[ModelCapability]:
        return [ModelCapability.VISION_LANGUAGE]

    @abstractmethod
    def generate(
        self,
        prompt: str,
        images: List[ImageInput],
        **kwargs
    ) -> VLMResponse:
        """
        Generate a response given a text prompt and images.

        Args:
            prompt: Text prompt/question
            images: List of images for context
            **kwargs: Additional generation parameters

        Returns:
            VLMResponse with generated text
        """
        pass

    def process_batch(
        self,
        inputs: List[Tuple[str, List[ImageInput]]],
        **kwargs
    ) -> List[VLMResponse]:
        """
        Process multiple prompt-image pairs.

        Args:
            inputs: List of (prompt, images) tuples
            **kwargs: Additional parameters

        Returns:
            List of VLMResponses
        """
        return [
            self.generate(prompt, images, **kwargs)
            for prompt, images in inputs
        ]

    @abstractmethod
    def chat(
        self,
        messages: List[VLMMessage],
        **kwargs
    ) -> VLMResponse:
        """
        Multi-turn conversation with images.

        Args:
            messages: Conversation history
            **kwargs: Additional parameters

        Returns:
            VLMResponse for the conversation
        """
        pass

    def answer_question(
        self,
        question: str,
        document_images: List[ImageInput],
        context: Optional[str] = None,
        **kwargs
    ) -> DocumentQAResult:
        """
        Answer a question about document images.

        Args:
            question: Question to answer
            document_images: Document page images
            context: Optional additional context
            **kwargs: Additional parameters

        Returns:
            DocumentQAResult with answer and evidence
        """
        prompt = self._build_qa_prompt(question, context)
        response = self.generate(prompt, document_images, **kwargs)

        # Pull answer, confidence, and abstention out of the structured reply
        answer, confidence, abstained, reason = self._parse_qa_response(response.text)

        return DocumentQAResult(
            question=question,
            answer=answer,
            confidence=confidence,
            evidence_regions=response.grounded_regions,
            abstained=abstained,
            abstention_reason=reason
        )

    def extract_fields(
        self,
        images: List[ImageInput],
        schema: Dict[str, Any],
        **kwargs
    ) -> FieldExtractionVLMResult:
        """
        Extract fields from document images according to a schema.

        Args:
            images: Document page images
            schema: Field schema (JSON Schema or Pydantic-like)
            **kwargs: Additional parameters

        Returns:
            FieldExtractionVLMResult with extracted values
        """
        prompt = self._build_extraction_prompt(schema)
        response = self.generate(prompt, images, **kwargs)

        # Turn the model's JSON reply into a structured extraction result
        result = self._parse_extraction_response(response, schema)
        return result

    def summarize_document(
        self,
        images: List[ImageInput],
        max_length: int = 500,
        **kwargs
    ) -> str:
        """
        Generate a summary of document images.

        Args:
            images: Document page images
            max_length: Maximum summary length
            **kwargs: Additional parameters

        Returns:
            Document summary text
        """
        prompt = f"""Summarize this document in at most {max_length} characters.
Focus on the main points and key information.
Be concise and factual."""

        response = self.generate(prompt, images, **kwargs)
        return response.text

    def classify_document(
        self,
        images: List[ImageInput],
        categories: List[str],
        **kwargs
    ) -> Tuple[str, float]:
        """
        Classify a document into predefined categories.

        Args:
            images: Document page images
            categories: List of possible categories
            **kwargs: Additional parameters

        Returns:
            Tuple of (category, confidence)
        """
        categories_str = ", ".join(categories)
        prompt = f"""Classify this document into one of these categories: {categories_str}

Respond with just the category name and confidence (0-1).
Format: CATEGORY: confidence

If you cannot confidently classify, respond with: UNKNOWN: 0.0"""

        response = self.generate(prompt, images, **kwargs)

        # Parse the "CATEGORY: confidence" reply and map it back to the
        # caller's original category names (case-insensitively).
        try:
            parts = response.text.strip().split(":")
            category = parts[0].strip().upper()
            confidence = float(parts[1].strip()) if len(parts) > 1 else 0.5

            category_upper = {c.upper(): c for c in categories}
            if category in category_upper:
                return category_upper[category], confidence
            return "UNKNOWN", 0.0
        except Exception:
            return "UNKNOWN", 0.0

    def _build_qa_prompt(
        self,
        question: str,
        context: Optional[str] = None
    ) -> str:
        """Build prompt for document QA."""
        prompt_parts = [
            "You are analyzing a document image. Answer the following question based only on what you can see in the document.",
            "",
            "IMPORTANT RULES:",
            "- Only use information visible in the document",
            "- If the answer is not found, say 'NOT FOUND' and explain why",
            "- Be precise and quote exact values when possible",
            "- Indicate your confidence level (HIGH, MEDIUM, LOW)",
            ""
        ]

        if context:
            prompt_parts.extend([
                "Additional context:",
                context,
                ""
            ])

        prompt_parts.extend([
            f"Question: {question}",
            "",
            "Provide your answer in this format:",
            "ANSWER: [your answer]",
            "CONFIDENCE: [HIGH/MEDIUM/LOW]",
            "EVIDENCE: [quote or describe where you found this information]"
        ])

        return "\n".join(prompt_parts)

    def _parse_qa_response(
        self,
        response_text: str
    ) -> Tuple[str, float, bool, Optional[str]]:
        """Parse QA response for answer, confidence, and abstention."""
        lines = response_text.strip().split("\n")

        answer = ""
        confidence = 0.5
        abstained = False
        reason = None

        for line in lines:
            line_lower = line.lower()
            if line_lower.startswith("answer:"):
                answer = line.split(":", 1)[1].strip()
            elif line_lower.startswith("confidence:"):
                conf_str = line.split(":", 1)[1].strip().upper()
                confidence = {"HIGH": 0.9, "MEDIUM": 0.6, "LOW": 0.3}.get(conf_str, 0.5)

        # Treat "not found" style answers as abstentions
        if "not found" in answer.lower() or "cannot find" in answer.lower():
            abstained = True
            reason = answer

        return answer, confidence, abstained, reason

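    # Illustrative reply that _parse_qa_response can handle, following the
    # format requested by _build_qa_prompt (the example values are made up):
    #
    #   ANSWER: $118.50
    #   CONFIDENCE: HIGH
    #   EVIDENCE: "Total due: $118.50" in the payment summary box
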
    def _build_extraction_prompt(self, schema: Dict[str, Any]) -> str:
        """Build prompt for field extraction."""
        import json

        schema_str = json.dumps(schema, indent=2)

        prompt = f"""Extract the following fields from this document image.

SCHEMA:
{schema_str}

RULES:
- Only extract values that are clearly visible in the document
- For each field, provide the exact value and its location
- If a field is not found, mark it as null with confidence 0
- Be precise with numbers, dates, and proper nouns

Respond in valid JSON format matching the schema.
Include a "_confidence" object with confidence scores (0-1) for each field.
Include a "_evidence" object with the text snippet where each value was found.
"""
        return prompt

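    # Illustrative sketch of the JSON shape the prompt above requests and
    # _parse_extraction_response expects; the field names are made up and
    # depend entirely on the caller's schema:
    #
    #   {
    #       "invoice_number": "INV-1042",
    #       "total_amount": 118.50,
    #       "_confidence": {"invoice_number": 0.95, "total_amount": 0.8},
    #       "_evidence": {"invoice_number": "Invoice No. INV-1042",
    #                     "total_amount": "Total due: $118.50"}
    #   }
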
    def _parse_extraction_response(
        self,
        response: VLMResponse,
        schema: Dict[str, Any]
    ) -> FieldExtractionVLMResult:
        """Parse extraction response into structured result."""
        import json

        result = FieldExtractionVLMResult()

        try:
            text = response.text.strip()

            # Strip markdown code fences if the model wrapped its JSON in one
            if "```json" in text:
                start = text.find("```json") + 7
                end = text.find("```", start)
                text = text[start:end].strip()
            elif "```" in text:
                start = text.find("```") + 3
                end = text.find("```", start)
                text = text[start:end].strip()

            data = json.loads(text)

            # Copy extracted values, skipping metadata keys such as "_confidence"
            for key, value in data.items():
                if key.startswith("_"):
                    continue
                result.fields[key] = value

            if "_confidence" in data:
                result.confidence_scores = data["_confidence"]

            if "_evidence" in data:
                result.field_evidence = data["_evidence"]

            # Mark schema fields that are missing or null as abstained
            for field_name in schema.get("properties", {}).keys():
                if field_name not in result.fields or result.fields[field_name] is None:
                    result.abstained_fields.append(field_name)
                    result.abstention_reasons[field_name] = "Field not found in document"

            # Overall confidence is the mean of the per-field scores
            if result.confidence_scores:
                result.overall_confidence = sum(result.confidence_scores.values()) / len(result.confidence_scores)

        except json.JSONDecodeError:
            # The reply was not valid JSON; abstain on every schema field
            for field_name in schema.get("properties", {}).keys():
                result.abstained_fields.append(field_name)
                result.abstention_reasons[field_name] = "Failed to parse extraction response"

        return result
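

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `MyVLM` and `page_images` are hypothetical:
# a concrete subclass must implement `generate` and `chat` against an actual
# multimodal backend, and `page_images` would be a List[ImageInput] of
# rendered document pages.
#
#   class MyVLM(VisionLanguageModel):
#       def generate(self, prompt, images, **kwargs) -> VLMResponse:
#           ...  # call the underlying multimodal model and wrap its output
#
#       def chat(self, messages, **kwargs) -> VLMResponse:
#           ...
#
#   vlm = MyVLM(VLMConfig(name="my-vlm", temperature=0.0))
#   qa = vlm.answer_question("What is the invoice total?", page_images)
#   if not qa.abstained:
#       print(qa.answer, qa.confidence)
#
#   extraction = vlm.extract_fields(page_images, schema={"properties": {...}})
#   print(extraction.fields, extraction.abstained_fields)
# ---------------------------------------------------------------------------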