diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000000000000000000000000000000000..5b1c3ac0de8096376f74ef34bf08eaa86f651e3a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,59 @@ +# Ignore development artifacts +__pycache__/ +*.pyc +*.pyo +*.pyd +.Python +*.so +*.dylib +*.log +.venv/ +venv/ +ENV/ +env/ +.git/ +.gitignore +.gitlab-ci.yml +*.md +!README.md +.pytest_cache/ +*.swp +*.swo +*~ +.DS_Store + +# Ignore data directories (too large for Docker context) +data/ +!data/prompt_templates/ +!data/visual_element_prefabs/ + +# Ignore build artifacts +*.egg-info/ +dist/ +build/ +*.whl + +# Ignore handwriting service (separate deployment) +handwriting_service/ + +# Ignore WordStylist (not needed for API) +WordStylist/ + +# Ignore scripts (not needed for API runtime) +scripts/ + +# Ignore documentation and deployment files +ARCHITECTURE.md +DEPLOYMENT.md +*.sh +!start.sh +!start_worker.sh +docker-compose.yml +railway.json +railway_setup_vars.sh + +# Keep only essential code +!docgenie/ +!api/ +!setup.py +!pyproject.toml diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..aec80c52e8a5408669d9e12fe7376fc1488be89b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,7 @@ +*.svg filter=lfs diff=lfs merge=lfs -text +*.webp filter=lfs diff=lfs merge=lfs -text +*.ico filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100755 index 0000000000000000000000000000000000000000..92b93970145852b19d0dab127119aab54824a7ad --- /dev/null +++ b/.gitignore @@ -0,0 +1,172 @@ +# Project +data/clusters/ +data/embeddings/ +data/temp/ +wandb/ +data/models/ +data/webapp_cache/ +data/analyzation/ +data/cherrypicks/ +data/hw_imgs/ +/data/seed-images/* +/docgenie/playground/test.py +/docgenie/playground/handwritten_text/doc_vqa_handwriting_text_images +/docgenie/playground/handwritten_text/handwriting_raw_tokens +/docgenie/playground/handwritten_text/temp +data/datasets +data/models +data/cluster_plots +data/syn_dataset_statistics_plots +data/gt_embeddings +data/wandb_downloads +data/wandb_project_csvs +data/folders.txt +cache +runs +visualizations +.venv +**/**.__pycache__ +/docgenie/playground/handwritten_text/doc_vqa_handwriting_text_images +/docgenie/playground/handwritten_text/temp +data/datasets +data/models + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +*.log + +# Virtual environments +venv/ +env/ +ENV/ +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Jupyter Notebook +.ipynb_checkpoints +*.ipynb_checkpoints/ + +# Model artifacts - download separately +inference/ +inference_new/ +inference_hf/ +model/experiments/hf_conditional_latent/cached_vae/ +*.zip + + +# Datasets - download separately +docvqa-handwritten-sizes4/ +syn_docvqa/ +iam_dataset/ +iam_dataset_processed/ +iam_dataset_processed_partial/ +docvqa-test/ +docvqa-viselems/ +docvqa-viselems2/ +temp/ +generations/ + +# Generated outputs +output/ + +# Backup files +*.bak +*.backup +*.tmp + +# Testing +.pytest_cache/ +.coverage +htmlcov/ + +# OS +./data/clusters_old/ +Thumbs.db + + +# Training +training/ +vae_evaluation/ + + +# Logs and checkpoints +*.pt +# But allow the inference model for 
handwriting service +!handwriting_service/WordStylist/models/ema_ckpt.pt +*.ckpt +*.pth +*.safetensors + +.env + +# Playwright +node_modules/ +/test-results/ +/playwright-report/ +/blob-report/ +/playwright/.cache/ +/playwright/.auth/ + + +!data/models/ +!data/models/handwriting/ +!data/models/handwriting/char_vocab.json +!data/models/handwriting/config.yaml +!data/models/handwriting/writer_id_map.json +!data/models/handwriting/cached_vae/config.json +data/models/.locks* +data/models/baseline +data/models/legacy +data/models/models* +data/models/pretrained +test_run.py +test_vlm.ipynb +test.ipynb +test2.ipynb +test3.py +test4.py +test5.py +test6.py +data/results +data/results_old/ +data/tmp/ +docgenie/playground/extract_02_eval_metrics_from_wandb.py +docgenie/playground/extract_metrics_from_wandb.py +data/cached_subsets +data/mixed_datasets +data/results_backup_v1 +data/results_v1 +data/old-results/ +data/embeddings +data/mixed_datasets +data/results_backup_v1 +sync_datasets.sh +data/results_latest +data/results_latest copy diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml new file mode 100755 index 0000000000000000000000000000000000000000..6cd5d7481bede501a691eac5043403cd029d7eec --- /dev/null +++ b/.gitlab-ci.yml @@ -0,0 +1,16 @@ +# You can override the included template(s) by including variable overrides +# SAST customization: https://docs.gitlab.com/ee/user/application_security/sast/#customizing-the-sast-settings +# Secret Detection customization: https://docs.gitlab.com/user/application_security/secret_detection/pipeline/configure +# Dependency Scanning customization: https://docs.gitlab.com/ee/user/application_security/dependency_scanning/#customizing-the-dependency-scanning-settings +# Container Scanning customization: https://docs.gitlab.com/ee/user/application_security/container_scanning/#customizing-the-container-scanning-settings +# Note that environment variables can be set in several places +# See https://docs.gitlab.com/ee/ci/variables/#cicd-variable-precedence +stages: +- test +- secret-detection +variables: + SECRET_DETECTION_ENABLED: 'true' +secret_detection: + stage: secret-detection +include: +- template: Security/Secret-Detection.gitlab-ci.yml diff --git a/.python-version b/.python-version new file mode 100755 index 0000000000000000000000000000000000000000..efbce23a0e1b1eed58654641085f009d5233a0fb --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.12 diff --git a/API_FLOW_DOCUMENTATION.md b/API_FLOW_DOCUMENTATION.md new file mode 100644 index 0000000000000000000000000000000000000000..b8a7767b0b10d8485dfad38c82284ade137eb4d3 --- /dev/null +++ b/API_FLOW_DOCUMENTATION.md @@ -0,0 +1,1024 @@ +# Complete API Flow Documentation + +## Overview +The DocGenie API provides three endpoints for synthetic document generation, implementing a 19-stage pipeline that transforms seed images and prompts into complete datasets with OCR, ground truth, and optional handwriting/visual elements. + +**Base URL**: `http://localhost:8000` (development) or Railway deployment +**Documentation**: `/docs` (FastAPI auto-generated Swagger UI) + +--- + +## API Endpoints + +### 1. `/generate` - Legacy JSON Response (POST) +**Purpose**: Generate documents and return complete JSON metadata +**Response**: JSON with HTML, PDF (base64), bounding boxes, optional handwriting/visual elements +**Use Case**: Testing, development, full metadata inspection +**Pipeline Stages**: 1-19 (configurable via parameters) + +### 2. 
`/generate/pdf` - Sync PDF+Dataset ZIP (POST) +**Purpose**: Generate documents and return ZIP file with all artifacts +**Response**: ZIP file containing: +- `*.pdf` - Generated document PDFs +- `*_final.pdf` - PDFs with handwriting/visual elements (if enabled) +- `*.msgpack` - Dataset format (if export enabled) +- `metadata.json` - Complete generation metadata +- `handwriting/` - Individual handwriting images +- `visual_elements/` - Individual visual element images + +**Use Case**: Production dataset generation, batch processing +**Pipeline Stages**: 1-19 (all features available) + +### 3. `/generate/async` - Async Batch Processing (POST) +**Purpose**: Queue large batch jobs via background worker (Redis Queue) +**Response**: Task ID for status polling +**Status Check**: `GET /generate/async/status/{task_id}` +**Result Download**: `GET /generate/async/result/{task_id}` (returns ZIP) +**Use Case**: Large-scale dataset generation (100+ documents) +**Pipeline Stages**: 1-19 (via worker.py) + +--- + +## Request Parameters + +```python +class GenerateDocumentRequest: + seed_images: List[HttpUrl] # 1-8 seed images from web URLs + prompt_params: PromptParameters # Generation configuration + +class PromptParameters: + # Core Parameters + language: str = "english" # Document language + doc_type: str = "invoice" # Document type (invoice, receipt, form, etc.) + gt_type: str = "qa" # Ground truth format (qa, kie) + gt_format: str = "json" # GT encoding (json, annotation) + num_solutions: int = 1 # Documents per seed set + + # Feature Toggles (Stages 07-19) + enable_handwriting: bool = False # Stage 07-09, 12 + handwriting_ratio: float = 0.2 # Probabilistic filter (0.0-1.0) + enable_visual_elements: bool = False # Stage 08, 10, 13 + visual_element_types: List[str] = [] # Filter types: logo, photo, figure, barcode, etc. + enable_ocr: bool = True # Stage 15 + enable_bbox_normalization: bool = True # Stage 16 + enable_gt_verification: bool = False # Stage 17 + enable_analysis: bool = False # Stage 18 + enable_debug_visualization: bool = False # Stage 19 + enable_dataset_export: bool = False # Stage 19 (msgpack format) + dataset_export_format: str = "msgpack" # Currently only msgpack supported + + # Reproducibility + seed: Optional[int] = None # Random seed (null = random, int = reproducible) +``` + +--- + +## Pipeline Architecture: The 19 Stages + +The API implements all 19 stages of the original batch pipeline in `docgenie/generation/`. Each stage is mapped to corresponding functions in `api/utils.py`. + +### **Phase 1: Core Pipeline (Stages 01-06)** +Generate base documents from seed images and LLM prompts. + +#### **Stage 01: Seed Selection & Download** +- **Original**: `pipeline_01_select_seeds.py` +- **API**: `download_seed_images()` in `api/utils.py:117-161` +- **Process**: + 1. Accept user-provided seed image URLs (1-8 images) + 2. Download with retry logic (3 attempts, exponential backoff) + 3. Handle transient HTTP errors (502, 503, 504, 429) + 4. Convert to base64 for LLM input +- **Error Handling**: Retry with 2s, 4s, 8s delays; raise HTTPException on failure + +#### **Stage 02: Prompt LLM** +- **Original**: `pipeline_02_prompt_llm.py` +- **API**: `call_claude_api_direct()` in `api/utils.py:550-600` +- **Process**: + 1. Load prompt template: `data/prompt_templates/ClaudeRefined12/seed-based-json.txt` + 2. Build prompt with parameters: language, doc_type, gt_type, num_solutions + 3. 
Call Claude API (Anthropic Messages API v1) + - Model: `claude-3-5-sonnet-20241022` (configurable) + - Max tokens: 16,000 + - Temperature: 1.0 + - Vision: Send base64-encoded seed images + 4. Receive HTML documents with embedded ground truth +- **LLM Output Format**: Multiple `...` blocks with: + - CSS styling with page dimensions + - HTML elements with semantic classes + - Handwriting markers: `class="handwritten author1"` (author1, author2, etc.) + - Visual element placeholders: `data-placeholder="logo"`, `data-content="company-logo"` + - Ground truth: `` + +#### **Stage 03: Process Response & Extract HTML** +- **Original**: `pipeline_03_process_response.py` +- **API**: `extract_html_documents_from_response()` in `api/utils.py:605-635` +- **Process**: + 1. Parse LLM response for `...` blocks (regex) + 2. Prettify HTML with BeautifulSoup + 3. Validate HTML structure + 4. Extract ground truth JSON from ` in the following format: {gt example} +Notes: +• Pay close attention to cultural/regional differences seen +in the seed images (e.g., language, format, disclaimers). +• Feel free to creatively adapt or combine stylistic cues +from the seeds, as long as the end result looks authentic +for that cultural context. +• Do NOT directly copy-paste text or entire code blocks +from any single seed image or across these new solutions. +Now please generate the {num solutions} distinct +{doc type} documents. diff --git a/data/prompt_templates/Adaptation_GT/seed-free.txt b/data/prompt_templates/Adaptation_GT/seed-free.txt new file mode 100755 index 0000000000000000000000000000000000000000..634698f7e5ffd52c3e87fbf2a7844c3079998758 --- /dev/null +++ b/data/prompt_templates/Adaptation_GT/seed-free.txt @@ -0,0 +1,25 @@ +You are an AI specialized in generating multiple unique +HTML documents in one response. Please create +{num solutions} unique HTML documents representing +{doc type}. +Each solution must: +1. Include all mandatory fields: {sections}. +2. Be formatted so it could print on A4 (e.g., use @page +{{ size: A4; }} in your CSS). +3. Show a significantly different layout, styling, and textual content from every other solution. +4. Maintain a {background requirements}. +5. Avoid copy-pasting or reusing large chunks of HTML, +CSS, or disclaimers—each document must be at least +70% different in code and text than the others. +6. Wrap each complete document between +and tags, labeled as: +1. ...Solution #1... +2. ...Solution #2... +... +{num solutions}. ...Solution +#{num solutions}... +Include the {gt type} as JSON in the document via in the following format: {gt example} +Do not provide additional commentary or references to the +other solutions within each HTML. +Now generate the {num solutions} distinct {doc type} +documents. diff --git a/data/prompt_templates/ClaudeRefined1/seed-based.txt b/data/prompt_templates/ClaudeRefined1/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..83b816d620eed5f01c6328df4c1f5d02f3f4bd13 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined1/seed-based.txt @@ -0,0 +1,78 @@ +# HTML Document Generation Prompt (Refined) + +You are an AI specialized in creating culturally authentic HTML documents based on visual analysis of real-world examples. You have been provided with {num_seed_images} seed images of **{doc_type}** documents from different cultural and regional contexts. 
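A minimal sketch of the Stage 03 parsing described above (regex extraction of the `<html>...</html>` blocks followed by BeautifulSoup prettifying); the actual `extract_html_documents_from_response()` in `api/utils.py` is not shown in this diff, so the helper below is an illustrative approximation rather than the real implementation:

```python
import re
from bs4 import BeautifulSoup  # requires beautifulsoup4


def extract_html_documents(response_text: str) -> list[str]:
    """Illustrative sketch of Stage 03: pull each <html>...</html> block out of
    the LLM response and prettify it. The production helper in api/utils.py may
    differ in name, validation, and error handling."""
    blocks = re.findall(r"<html\b.*?</html>", response_text, flags=re.DOTALL | re.IGNORECASE)
    return [BeautifulSoup(block, "html.parser").prettify() for block in blocks]
```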
+ +## Cultural Variations (If Present) +The seed images may demonstrate regional differences such as: +- Language variations and local terminology +- Date formatting conventions (DD/MM/YYYY, MM/DD/YYYY, etc.) +- Currency symbols and number formatting +- Layout preferences (field positioning, official elements, cultural design patterns) +- Regional legal disclaimers and regulatory requirements +- Typography and visual hierarchy standards + +## Task Requirements +Generate **{num_solutions}** unique HTML documents that meet these specifications: + +### Core Requirements +1. **Cultural Authenticity**: If cultural/regional variations are present in the seed images, reflect those stylistic elements without directly copying any text, disclaimers, or layouts verbatim +2. **Required Content**: Include all essential fields: {required_sections} +3. **Single Page Format**: Design as single-page documents with dimensions appropriate to the document type (receipts: narrow format, forms: standard width, etc.) +4. **Language**: Generate all content in {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: Each document must be at least 70% different in code structure, styling, and content from others + +## Ground Truth Generation +Generate appropriate ground truth data for each document: {gt_type}. +Include the ground truth as JSON inside each document in a `` tag. +The ground truth must follow the format: {gt_format} + +### Technical Specifications +- Wrap each solution in `...` tags numbered sequentially +- Include the ground truth JSON in `` as specified above +- Implement static CSS appropriate for the document type and single-page layout (no animations, transitions, or dynamic effects) + +## Additional Requirements +{user_descriptions} + +### Content Guidelines +- **DO**: Adapt any cultural/regional stylistic elements present in the seed images +- **DO**: Create authentic-feeling content appropriate to each cultural context +- **DO**: Vary layout structures, color schemes, and typographic choices +- **DO**: Use static styling only (no animations, hover effects, or transitions) +- **DON'T**: Copy-paste text, code blocks, or entire sections between solutions +- **DON'T**: Reuse identical disclaimers, headers, or formatting patterns +- **DON'T**: Include any dynamic effects, animations, or interactive elements + +## Additional Requirements +{user_descriptions} + +## Output Format +Structure your response as: + +``` +1. + + ...complete HTML document... + + +2. + + ...complete HTML document... + + +...continue for all {num_solutions} solutions +``` + +## Quality Checklist +Before generating, ensure each document: +- [ ] Reflects any authentic cultural/regional characteristics present in seed images +- [ ] Contains all required sections: {required_sections} +- [ ] Uses static styling only (no animations or dynamic effects) +- [ ] Uses appropriate single-page formatting for the document type +- [ ] All content is in English +- [ ] Includes the specified ground truth in proper JSON format +- [ ] Maintains 70%+ uniqueness from other solutions +- [ ] Follows semantic HTML best practices + +Now generate the **{num_solutions}** distinct **{doc_type}** documents. 
\ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined10/seed-based.txt b/data/prompt_templates/ClaudeRefined10/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..21fb12f4689eae2f5eb109d684594b5b67a21c42 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined10/seed-based.txt @@ -0,0 +1,57 @@ +You are an AI creating authentic HTML representations of documents based on seed images. +Analyze the seed images for structural and semantic content and generate authentic variations. +The generated documents will be printed. + +## Requirements +1. **Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Format**: Single-page documents with dimensions appropriate to the document type +3. **Language**: {language} +4. **Static Only**: No animations, transitions, or dynamic effects + +## Technical +- Wrap each document in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Generate only minified CSS, HTML, JS. + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' +- Apply generously increased size to 'handwritten', in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- Never include signatures as handwriting + +## Visual Placeholders (if document type requires) +- Insert `
` for non-text elements at appropriate positions +- Valid types are: signature, stamp, logo, barcode, photo, chart +- Add data-content attribute with actual content description +- For signatures, add author class ('author1', 'author2', etc.) to distinguish different people and ensure the author is semantically coherent with the document content +- For stamps, use `position:absolute;z-index:10;` and specify 'top' and 'right' +- Dimensions in mm/cm, e.g. `width:30mm;height:20mm;` +- Example: `
` +- Example: `
` +- Example: `
` + +## Output Format +Generate minified HTML like this: +``` +1. +2. +... +``` +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Quality Checklist +- [ ] Authentic variations without verbatim copying from seed images +- [ ] Static styling only (no animations or dynamic effects) +- [ ] Single-page format with minified HTML/CSS/JS +- [ ] Content in {language} +- [ ] GT JSON present and correctly formatted +- [ ] Visual elements are semantically coherent + +Generate {num_solutions} distinct {doc_type} documents based on {num_seed_images} seed images. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined11/seed-based.txt b/data/prompt_templates/ClaudeRefined11/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..ea526e5508632fca75840f26cc364944daa15015 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined11/seed-based.txt @@ -0,0 +1,55 @@ +You are an AI creating authentic HTML representations of documents based on seed images. +Analyze the seed images for structural and semantic content and generate authentic variations. +The generated documents will be printed. + +## Requirements +1. **Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Format**: Single-page documents with dimensions appropriate to the document type +3. **Language**: {language} +4. **Static Only**: No animations, transitions, or dynamic effects + +## Technical +- Wrap each document in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Generate only minified CSS, HTML, JS. + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' and use regular text +- Apply no special styles to 'handwritten', except generously increased size, in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- If the handwriting represents a signature mark it additionally with class 'signature' + +## Visual Placeholders (if document type requires) +- Insert `
` for non-text elements at appropriate positions +- Valid types are: stamp, logo, barcode, photo, chart +- Add data-content attribute with actual content description +- For stamps, use `position:absolute;z-index:10;` and specify 'top' and 'right' +- Always provide dimensions in mm/cm, e.g. `width:30mm;height:20mm;` +- Example: `
` +- Example: `
` + +## Output Format +Generate minified HTML like this: +``` +1. +2. +... +``` +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Quality Checklist +- [ ] Authentic variations without verbatim copying from seed images +- [ ] Static styling only (no animations or dynamic effects) +- [ ] Single-page format with minified HTML/CSS/JS +- [ ] Content in {language} +- [ ] GT JSON present and correctly formatted +- [ ] Visual elements are semantically coherent + +Generate {num_solutions} distinct {doc_type} documents based on {num_seed_images} seed images. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined12/seed-based-annotation.txt b/data/prompt_templates/ClaudeRefined12/seed-based-annotation.txt new file mode 100755 index 0000000000000000000000000000000000000000..166162e89a5670ee549ee9a0eb66df9642bdaa2d --- /dev/null +++ b/data/prompt_templates/ClaudeRefined12/seed-based-annotation.txt @@ -0,0 +1,55 @@ +You are an AI creating authentic HTML representations of documents based on seed images. +Analyze the seed images for structural and semantic content and generate authentic variations. +The generated documents will be printed. + +## Requirements +1. **Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Format**: Single-page documents with dimensions appropriate to the document type +3. **Language**: {language} +4. **Static Only**: No animations, transitions, or dynamic effects + +## Technical +- Wrap each document in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Generate only minified CSS, HTML, JS. + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' and use regular text +- Apply no special styles to 'handwritten', except generously increased size, in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- If the handwriting represents a signature mark it additionally with class 'signature' + +## Visual Placeholders (if document type requires) +- Insert `
` for non-text elements at appropriate positions +- Valid types are: stamp, logo, figure, barcode, photo +- Add data-content attribute with actual content description +- For stamps, use `position:absolute;z-index:10;` and specify 'top' and 'right' +- Always provide appropriate dimensions +- Example: `
` +- Example: `
` + +## Output Format +Generate minified HTML like this: +``` +1. +2. +... +``` +## Ground Truth +Generate ground truth by assigning each applicable element in HTML a class from the list below to uniquely identify its label: +{gt_type} +{gt_format} + +## Quality Checklist +- [ ] Authentic variations without verbatim copying from seed images +- [ ] Static styling only (no animations or dynamic effects) +- [ ] Single-page format with minified HTML/CSS +- [ ] Content in {language} +- [ ] GT labels via class annotations are present and assigned to correct elements +- [ ] Visual elements are semantically coherent + +Generate {num_solutions} distinct {doc_type} documents based on {num_seed_images} seed images. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined12/seed-based-json.txt b/data/prompt_templates/ClaudeRefined12/seed-based-json.txt new file mode 100755 index 0000000000000000000000000000000000000000..6dbac5efd21eb7a8365ac553b11817d6defbb395 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined12/seed-based-json.txt @@ -0,0 +1,55 @@ +You are an AI creating authentic HTML representations of documents based on seed images. +Analyze the seed images for structural and semantic content and generate authentic variations. +The generated documents will be printed. + +## Requirements +1. **Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Format**: Single-page documents with dimensions appropriate to the document type +3. **Language**: {language} +4. **Static Only**: No animations, transitions, or dynamic effects + +## Technical +- Wrap each document in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Generate only minified CSS, HTML, JS. + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' and use regular text +- Apply no special styles to 'handwritten', except generously increased size, in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- If the handwriting represents a signature mark it additionally with class 'signature' + +## Visual Placeholders (if document type requires) +- Insert `
` for non-text elements at appropriate positions +- Valid types are: stamp, logo, figure, barcode, photo +- Add data-content attribute with actual content description +- For stamps, use `position:absolute;z-index:10;` and specify 'top' and 'right' +- Always provide appropriate dimensions +- Example: `
` +- Example: `
` + +## Output Format +Generate minified HTML like this: +``` +1. +2. +... +``` +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Quality Checklist +- [ ] Authentic variations without verbatim copying from seed images +- [ ] Static styling only (no animations or dynamic effects) +- [ ] Single-page format with minified HTML/CSS +- [ ] Content in {language} +- [ ] GT JSON present, correctly formatted and semantically coherent +- [ ] Visual elements are semantically coherent + +Generate {num_solutions} distinct {doc_type} documents based on {num_seed_images} seed images. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined2/seed-based.txt b/data/prompt_templates/ClaudeRefined2/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..107eeb8425d485b835ec96c4f36573cdb54d1d70 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined2/seed-based.txt @@ -0,0 +1,70 @@ +You are an AI creating culturally authentic HTML documents based on {num_seed_images} seed images of **{doc_type}** documents. + +# Cultural Variations +Seed images may show regional differences: language/terminology, date/number/currency formats, layout preferences, legal disclaimers, typography standards. + +# Task: Generate {num_solutions} unique HTML documents + +## Requirements +1. **Cultural Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Required Fields**: {required_sections} +3. **Format**: Single-page, dimensions appropriate to document type +4. **Language**: {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: 70%+ different in code, styling, content +7. **Static Only**: No animations, transitions, or dynamic effects + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' (no special styling/fonts, treat as regular text) +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people's handwriting on the same document + +## Visual Placeholders (if document type requires) +- Include placeholders for non-text visual elements using HTML class 'visual-element' +- Add data attributes: data-type (signature/logo/stamp/barcode/photo/chart/etc.) and data-content (actual content) +- Give each placeholder appropriate dimensions via inline styles +- Examples: `
`, `
`, `
` + +## Structural Elements (analyze seed images for) +Headers/titles, content organization (tables/lists/paragraphs), data hierarchies, labels/captions, numerical data/dates/references, visual elements (charts/diagrams), footers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... +``` + +## Quality Checklist +- [ ] Authentic cultural characteristics +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] 70%+ unique + +Generate {num_solutions} distinct {doc_type} documents. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined3/seed-based.txt b/data/prompt_templates/ClaudeRefined3/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..1ef1aeafbc5d5865dc065b64c5ff0520bcfc5a03 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined3/seed-based.txt @@ -0,0 +1,70 @@ +You are an AI creating culturally authentic HTML documents based on {num_seed_images} seed images of **{doc_type}** documents. + +# Cultural Variations +Seed images may show regional differences: language/terminology, date/number/currency formats, layout preferences, legal disclaimers, typography standards. + +# Task: Generate {num_solutions} unique HTML documents + +## Requirements +1. **Cultural Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Required Fields**: {required_sections} +3. **Format**: Single-page, dimensions appropriate to document type +4. **Language**: {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: 70%+ different in code, styling, content +7. **Static Only**: No animations, transitions, or dynamic effects + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' (no special styling/fonts, treat as regular text) +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people's handwriting on the same document + +## Visual Placeholders (if document type requires) +- Include placeholders for non-text visual elements as JSON in `` tag. +- Describe type (signature/logo/stamp/barcode/photo/chart/etc.) and content (actual content) +- Describe placement of each visual element with appropriate dimensions and y-rotation +- Examples: `[{"type": "signature", "content": "John Doe", "x0": 105, "x1": 116, "y0": 82, "y1": 102, "rotation": -4}, ...]` + +## Structural Elements (analyze seed images for) +Headers/titles, content organization (tables/lists/paragraphs), data hierarchies, labels/captions, numerical data/dates/references, visual elements (charts/diagrams), footers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... 
+``` + +## Quality Checklist +- [ ] Authentic cultural characteristics +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] 70%+ unique + +Generate {num_solutions} distinct {doc_type} documents. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined3CloneDoc/seed-based.txt b/data/prompt_templates/ClaudeRefined3CloneDoc/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..b968eed4d2f93b252244cc9afb1b18d4b6556a5d --- /dev/null +++ b/data/prompt_templates/ClaudeRefined3CloneDoc/seed-based.txt @@ -0,0 +1,97 @@ +You are an AI creating HTML documents that **clone the style and structure** of {num_seed_images} seed images of **{doc_type}** documents. + +# Task: Generate {num_solutions} cloned HTML documents + +## Core Objective +**CLONE the visual design, layout, and structure** of the seed images while using **completely different data**. Think of this as creating blank template instances filled with new information. + +## Critical Requirements +1. **Visual Fidelity**: Replicate styling elements from seed images: + - Exact layout structure (positioning, spacing, alignment) + - Typography (fonts, sizes, weights, colors) + - Visual hierarchy and sectioning + - Color schemes and backgrounds + - Border styles, dividers, and decorative elements + - Logo/header/footer placement and styling + +2. **Data Uniqueness**: Generate completely new content: + - **NEVER copy**: names, addresses, phone numbers, emails, IBANs, account numbers, license numbers, ID numbers, dates, amounts, prices, or any other specific data points + - Generate realistic but fictional alternatives for all data fields + - Maintain data type appropriateness (valid formats for phones, IBANs, dates, etc.) + - Ensure cultural/regional authenticity for generated data + +3. **Required Fields**: {required_sections} + +4. **Format**: Single-page, dimensions matching seed documents + +5. **Language**: {language} + +6. **Background**: {background_requirements} + +7. **Static Only**: No animations, transitions, or dynamic effects + +## Cloning Strategy +- **DO**: Match layout grids, spacing, font choices, color palettes, sectioning patterns, table structures, visual element placement +- **DON'T**: Copy any actual text content, numerical data, personal information, or business-specific details +- **Think**: "Same template, different instance" + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Replicate CSS styling patterns from seed documents + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' (no special styling/fonts, treat as regular text) +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people's handwriting on the same document +- Generate different handwritten content than seed documents + +## Visual Placeholders (if document type requires) +- Include placeholders for non-text visual elements as JSON in `` tag. +- Describe type (signature/logo/stamp/barcode/photo/chart/etc.) 
and content (actual content - must be different from seed) +- Match placement patterns from seed documents with appropriate dimensions and y-rotation +- Examples: `[{"type": "signature", "content": "Jane Smith", "x0": 105, "x1": 116, "y0": 82, "y1": 102, "rotation": -4}, ...]` + +## Data Generation Guidelines +- Names: Generate culturally appropriate fictional names +- Addresses: Create realistic but non-existent addresses +- Phone/Fax: Use valid formats with fictional numbers +- IBANs/Account numbers: Generate format-compliant fictional numbers +- Dates: Use different dates maintaining logical consistency +- Amounts: Generate different values appropriate to context +- IDs/References: Create format-matching fictional identifiers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... +``` + +## Quality Checklist +- [ ] Layout/structure matches seed documents +- [ ] Typography and colors replicated +- [ ] ALL data is different from seed (no copied info) +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] Data formats are culturally appropriate + +Generate {num_solutions} cloned {doc_type} documents with new data. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined4/seed-based.txt b/data/prompt_templates/ClaudeRefined4/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..985e165a782f5848155ec49288544652368759cd --- /dev/null +++ b/data/prompt_templates/ClaudeRefined4/seed-based.txt @@ -0,0 +1,71 @@ +You are an AI creating culturally authentic HTML documents based on {num_seed_images} seed images of **{doc_type}** documents. + +# Cultural Variations +Seed images may show regional differences: language/terminology, date/number/currency formats, layout preferences, legal disclaimers, typography standards. + +# Task: Generate {num_solutions} unique HTML documents + +## Requirements +1. **Cultural Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Required Fields**: {required_sections} +3. **Format**: Single-page, dimensions appropriate to document type +4. **Language**: {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: 70%+ different in code, styling, content +7. **Static Only**: No animations, transitions, or dynamic effects + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' and apply no styles to this class +- Distinguish between different sizes of handwriting using classes 'hw-size1', 'hw-size2' which are in line with realistic handwriting and dependent on the context +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people's handwriting on the same document + +## Visual Placeholders (if document type requires) +- Include placeholders for non-text visual elements as JSON in `` tag. 
+- Describe type (signature/logo/stamp/barcode/photo/chart/etc.) and content (actual content) +- Describe placement of each visual element with appropriate dimensions and y-rotation +- Examples: `[{"type": "signature", "content": "John Doe", "x0": 105, "x1": 116, "y0": 82, "y1": 102, "rotation": -4}, ...]` + +## Structural Elements (analyze seed images for) +Headers/titles, content organization (tables/lists/paragraphs), data hierarchies, labels/captions, numerical data/dates/references, visual elements (charts/diagrams), footers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... +``` + +## Quality Checklist +- [ ] Authentic cultural characteristics +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] 70%+ unique + +Generate {num_solutions} distinct {doc_type} documents. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined5/seed-based.txt b/data/prompt_templates/ClaudeRefined5/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..0d87781ea0b394f199ee0b01c8d3a1705c29c166 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined5/seed-based.txt @@ -0,0 +1,77 @@ +You are an AI creating culturally authentic HTML documents based on {num_seed_images} seed images of **{doc_type}** documents. + +# Cultural Variations +Seed images may show regional differences: language/terminology, date/number/currency formats, layout preferences, legal disclaimers, typography standards. + +# Task: Generate {num_solutions} unique HTML documents + +## Requirements +1. **Cultural Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Required Fields**: {required_sections} +3. **Format**: Single-page, dimensions appropriate to document type +4. **Language**: {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: 70%+ different in code, styling, content +7. **Static Only**: No animations, transitions, or dynamic effects + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' +- Apply increased size to 'handwritten', in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people's handwriting on the same document +- Never include signatures as handwriting + +## Visual Placeholders (if document type requires) +- Use invisible placeholder divs with class 'visual-placeholder' +- Specify type via data-type attribute ('signature', 'stamp', 'logo', 'barcode', 'photo', 'chart', etc.) +- Add data-content attribute with actual content description +- For signatures/handwriting, add author class ('author1', 'author2', etc.) 
to distinguish different people +- Position naturally in document flow or use CSS positioning (absolute/relative) as appropriate +- Specify dimensions in mm/cm and rotation via inline style transform +- For overlapping elements (stamps over text), use CSS z-index and absolute positioning +- Example: `
` +- Example: `
` + +## Structural Elements (analyze seed images for) +Headers/titles, content organization (tables/lists/paragraphs), data hierarchies, labels/captions, numerical data/dates/references, visual elements (charts/diagrams), footers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... +``` + +## Quality Checklist +- [ ] Authentic cultural characteristics +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] 70%+ unique + +Generate {num_solutions} distinct {doc_type} documents. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined6/seed-based.txt b/data/prompt_templates/ClaudeRefined6/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..34abb8cfa0c6676642a2a3dae57dc4ed613dab8d --- /dev/null +++ b/data/prompt_templates/ClaudeRefined6/seed-based.txt @@ -0,0 +1,77 @@ +You are an AI creating culturally authentic HTML documents based on {num_seed_images} seed images of **{doc_type}** documents. + +# Cultural Variations +Seed images may show regional differences: language/terminology, date/number/currency formats, layout preferences, legal disclaimers, typography standards. + +# Task: Generate {num_solutions} unique HTML documents + +## Requirements +1. **Cultural Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Required Fields**: {required_sections} +3. **Format**: Single-page, dimensions appropriate to document type +4. **Language**: {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: 70%+ different in code, styling, content +7. **Static Only**: No animations, transitions, or dynamic effects + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' +- Apply increased size to 'handwritten', in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- Never include signatures as handwriting + +## Visual Placeholders (if document type requires) +- Use invisible placeholder divs with class 'visual-placeholder' +- Specify type via data-type attribute (signature, stamp, logo, barcode, photo, chart, etc.) +- Add data-content attribute with actual content description +- For signatures, add author class ('author1', 'author2', etc.) to distinguish different people +- Position naturally in document flow or use CSS positioning (absolute/relative) as appropriate +- Specify dimensions in mm/cm and rotation via **inline** style transform +- For overlapping elements (stamps over text), use CSS z-index and absolute positioning +- Example: `
` +- Example: `
` + +## Structural Elements (analyze seed images for) +Headers/titles, content organization (tables/lists/paragraphs), data hierarchies, labels/captions, numerical data/dates/references, visual elements (charts/diagrams), footers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... +``` + +## Quality Checklist +- [ ] Authentic cultural characteristics +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] 70%+ unique + +Generate {num_solutions} distinct {doc_type} documents. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined7/seed-based.txt b/data/prompt_templates/ClaudeRefined7/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..9016855e8c2b98941ccbee9dd1a1619d64daf01d --- /dev/null +++ b/data/prompt_templates/ClaudeRefined7/seed-based.txt @@ -0,0 +1,78 @@ +You are an AI creating culturally authentic HTML documents based on {num_seed_images} seed images of **{doc_type}** documents. + +# Cultural Variations +Seed images may show regional differences: language/terminology, date/number/currency formats, layout preferences, legal disclaimers, typography standards. + +# Task: Generate {num_solutions} unique HTML documents + +## Requirements +1. **Cultural Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Required Fields**: {required_sections} +3. **Format**: Single-page, dimensions appropriate to document type +4. **Language**: {language} +5. **Background**: {background_requirements} +6. **Uniqueness**: 70%+ different in code, styling, content +7. **Static Only**: No animations, transitions, or dynamic effects + +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Technical +- Wrap each in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Specify page size via `@media print { @page { size: ... } }` in CSS and use standard sizes when appropiate + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' +- Apply generously increased size to 'handwritten', in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- Never include signatures as handwriting + +## Visual Placeholders (if document type requires) +- Use invisible placeholder divs with class 'visual-placeholder' +- Specify type via data-type attribute (signature, stamp, logo, barcode, photo, chart, etc.) +- Add data-content attribute with actual content description +- For signatures, add author class ('author1', 'author2', etc.) to distinguish different people +- Position naturally in document flow or use CSS positioning (absolute/relative) as appropriate +- Specify dimensions in mm/cm +- For overlapping elements (e.g. stamps over text), use CSS z-index and absolute positioning +- Example: `
` +- Example: `
` + +## Structural Elements (analyze seed images for) +Headers/titles, content organization (tables/lists/paragraphs), data hierarchies, labels/captions, numerical data/dates/references, visual elements (charts/diagrams), footers + +## Additional Requirements +{user_descriptions} + +## Output Format +``` +1. + + ...complete document... + + +2. + + ...complete document... + + +... +``` + +## Quality Checklist +- [ ] Authentic cultural characteristics +- [ ] All required sections: {required_sections} +- [ ] Static styling only +- [ ] Single-page format +- [ ] {language} language +- [ ] Ground truth JSON included +- [ ] 70%+ unique + +Generate {num_solutions} distinct {doc_type} documents. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined8/seed-based.txt b/data/prompt_templates/ClaudeRefined8/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..2c92de2b2134d09db081d31545dc2511579008d8 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined8/seed-based.txt @@ -0,0 +1,60 @@ +You are an AI creating authentic HTML representations of documents based on seed images. +Analyze the seed images for structural and semantic content and generate authentic variations. +The generated documents will be printed. + +## Requirements +1. **Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Format**: Single-page documents with dimensions appropriate to the document type +3. **Language**: {language} +4. **Static Only**: No animations, transitions, or dynamic effects + +## Technical +- Wrap each document in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Specify page size via `@media print { @page { size: ... } }` and also `body` such that the content looks the same in browser and when printed +- In CSS use standard sizes when appropriate +- Generate only minified CSS, HTML, JS. + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' +- Apply generously increased size to 'handwritten', in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- Never include signatures as handwriting + +## Visual Placeholders +- Use `
` for non-text elements (signature, stamp, logo, barcode, photo, chart) +- Add data-content attribute with actual content description +- For signatures, add author class ('author1', 'author2', etc.) to distinguish different people +- Dimensions in mm/cm: `width:30mm;height:20mm;` +- Positioning: `position:absolute;top:50mm;right:20mm;` with `z-index` for overlays +- Example: `
` +- Example: `
` + +## Output Format +Generate minified HTML like this: +``` +1. +2. +... +``` +## Ground Truth +- Generate ground truth as JSON in `` tag. +- For each GT entry, insert the key of the entry as the `id` attribute with the corresponding HTML element. +- Individual values MUST BE visible and found in the DOM as elements because we want to get the geometries of the values before printing. +- Example: `
Name:
` +- Example: `
Corp XY LLC
` +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Quality Checklist +- [ ] Authentic variations without verbatim copying from seed images +- [ ] Static styling only (no animations or dynamic effects) +- [ ] Single-page format with correct dimensions and minified HTML/CSS +- [ ] Content in {language} +- [ ] GT ids present HTML and GT JSON present and correctly formatted + +Generate {num_solutions} distinct {doc_type} documents based on {num_seed_images} seed images in {language}. \ No newline at end of file diff --git a/data/prompt_templates/ClaudeRefined9/seed-based.txt b/data/prompt_templates/ClaudeRefined9/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..4b77f2ba4e08dadcf342bce1510e7286ed80df38 --- /dev/null +++ b/data/prompt_templates/ClaudeRefined9/seed-based.txt @@ -0,0 +1,54 @@ +You are an AI creating authentic HTML representations of documents based on seed images. +Analyze the seed images for structural and semantic content and generate authentic variations. +The generated documents will be printed. + +## Requirements +1. **Authenticity**: Reflect stylistic elements from seed images without copying text/layouts verbatim +2. **Format**: Single-page documents with dimensions appropriate to the document type +3. **Language**: {language} +4. **Static Only**: No animations, transitions, or dynamic effects + +## Technical +- Wrap each document in `...` tags, numbered sequentially +- Static CSS only for single-page layout +- Generate only minified CSS, HTML, JS. + +## Content Guidelines +**DO**: Adapt cultural elements, vary layouts/colors/typography, use static styling +**DON'T**: Copy text/code blocks, reuse identical sections, include dynamic effects + +## Handwritten Fields (if document type requires) +- Mark with class 'handwritten' +- Apply generously increased size to 'handwritten', in line with realistic handwriting +- Assign author ID via class ('author1', 'author2', etc.) to distinguish different people +- Never include signatures as handwriting + +## Visual Placeholders +- Use `
` for non-text elements (signature, stamp, logo, barcode, photo, chart) +- Add data-content attribute with actual content description +- For signatures, add author class ('author1', 'author2', etc.) to distinguish different people +- Dimensions in mm/cm: `width:30mm;height:20mm;` +- Positioning: `position:absolute;top:50mm;right:20mm;` with `z-index` for overlays +- Example: `
` +- Example: `
` + +## Output Format +Generate minified HTML like this: +``` +1. +2. +... +``` +## Ground Truth +Generate ground truth as JSON in `` tag. +Ground truth specification: {gt_type} +Ground truth must follow the format: {gt_format} + +## Quality Checklist +- [ ] Authentic variations without verbatim copying from seed images +- [ ] Static styling only (no animations or dynamic effects) +- [ ] Single-page format with minified HTML/CSS/JS +- [ ] Content in {language} +- [ ] GT JSON present and correctly formatted + +Generate {num_solutions} distinct {doc_type} documents based on {num_seed_images} seed images. \ No newline at end of file diff --git a/data/prompt_templates/DocGenie/seed-based.txt b/data/prompt_templates/DocGenie/seed-based.txt new file mode 100755 index 0000000000000000000000000000000000000000..8c0911a092f031d8914d4831e9a7dfd33944a1d0 --- /dev/null +++ b/data/prompt_templates/DocGenie/seed-based.txt @@ -0,0 +1,39 @@ +You are an AI specialized in generating unique HTML +documents based on multiple scanned images of realworld examples. You have been provided with distinct +sample images, each from a different cultural or regional +background. You have been provided seed images of +{doc type}, each originating from different cultural or regional contexts. For example, some might feature: +• Local languages or regional disclaimers +• Different date formats (e.g., dd/mm/yyyy vs. mm/dd/yyyy) +• Unique currency or numbering formats +• Varying layout norms (positions of key fields, disclaimers, official stamps, etc.) +Now, please generate {num solutions} unique HTML +documents that: +1. Strictly reflect the overall style, layout, and cultural +cues found in these samples, but do NOT copy any text, +disclaimers, or layout verbatim from the samples. +2. Include any essential mandatory fields: {sections}. +3. Maintain an A4 size format for printing (using @page +{{ size: A4; }} or similar CSS). +4. Maintain a {background requirements}. +5. Avoid copy-pasting or reusing large chunks of HTML, +CSS, or disclaimers—each document must be at least +70% different in code and text than the others. +6. Strictly wrap each new document in +... tags, for example: +1. ...Solution #1... +2. ...Solution #2... +... +{num solutions}. ...Solution +#{num solutions}... +Additional Requirements: {user descriptions} +Notes: +• Pay close attention to cultural/regional differences seen +in the seed images (e.g., language, format, disclaimers). +• Feel free to creatively adapt or combine stylistic cues +from the seeds, as long as the end result looks authentic +for that cultural context. +• Do NOT directly copy-paste text or entire code blocks +from any single seed image or across these new solutions. +Now please generate the {num solutions} distinct +{doc type} documents. diff --git a/data/prompt_templates/DocGenie/seed-free.txt b/data/prompt_templates/DocGenie/seed-free.txt new file mode 100755 index 0000000000000000000000000000000000000000..fbdc60bdb2ca1ee2b1904774c481d90d0fe02966 --- /dev/null +++ b/data/prompt_templates/DocGenie/seed-free.txt @@ -0,0 +1,24 @@ +You are an AI specialized in generating multiple unique +HTML documents in one response. Please create +{num solutions} unique HTML documents representing +{doc type}. +Each solution must: +1. Include all mandatory fields: {sections}. +2. Be formatted so it could print on A4 (e.g., use @page +{{ size: A4; }} in your CSS). +3. Show a significantly different layout, styling, and textual content from every other solution. +4. 
Maintain a {background requirements}. +5. Avoid copy-pasting or reusing large chunks of HTML, +CSS, or disclaimers—each document must be at least +70% different in code and text than the others. +6. Wrap each complete document between +and tags, labeled as: +1. ...Solution #1... +2. ...Solution #2... +... +{num solutions}. ...Solution +#{num solutions}... +Do not provide additional commentary or references to the +other solutions within each HTML. +Now generate the {num solutions} distinct {doc type} +documents. diff --git a/data/syn_dataset_definitions/cord_alpha=0.5.yaml b/data/syn_dataset_definitions/cord_alpha=0.5.yaml new file mode 100755 index 0000000000000000000000000000000000000000..56477fc1cb5e7283ad9e4c755a9d66e190947ade --- /dev/null +++ b/data/syn_dataset_definitions/cord_alpha=0.5.yaml @@ -0,0 +1,235 @@ +name: "cord_alpha=0.5" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "cord" +documents_count: 1200 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT + +label_mapping: + MENU_NM: MENU.NM + MENU_NUM: MENU.NUM + MENU_UNITPRICE: MENU.UNITPRICE + MENU_CNT: MENU.CNT + MENU_DISCOUNTPRICE: MENU.DISCOUNTPRICE + MENU_PRICE: MENU.PRICE + MENU_ITEMSUBTOTAL: MENU.ITEMSUBTOTAL + MENU_VATYN: MENU.VATYN + MENU_ETC: MENU.ETC + MENU_SUB_NM: MENU.SUB_NM + MENU_SUB_UNITPRICE: MENU.SUB_UNITPRICE + MENU_SUB_CNT: MENU.SUB_CNT + MENU_SUB_PRICE: MENU.SUB_PRICE + MENU_SUB_ETC: MENU.SUB_ETC + VOID_MENU_NM: VOID_MENU.NM + VOID_MENU_PRICE: VOID_MENU.PRICE + SUB_TOTAL_SUBTOTAL_PRICE: SUB_TOTAL.SUBTOTAL_PRICE + SUB_TOTAL_DISCOUNT_PRICE: SUB_TOTAL.DISCOUNT_PRICE + SUB_TOTAL_SERVICE_PRICE: SUB_TOTAL.SERVICE_PRICE + SUB_TOTAL_OTHERSVC_PRICE: SUB_TOTAL.OTHERSVC_PRICE + SUB_TOTAL_TAX_PRICE: SUB_TOTAL.TAX_PRICE + SUB_TOTAL_ETC: SUB_TOTAL.ETC + TOTAL_TOTAL_PRICE: TOTAL.TOTAL_PRICE + TOTAL_TOTAL_ETC: TOTAL.TOTAL_ETC + TOTAL_CASHPRICE: TOTAL.CASHPRICE + TOTAL_CHANGEPRICE: TOTAL.CHANGEPRICE + TOTAL_CREDITCARDPRICE: TOTAL.CREDITCARDPRICE + TOTAL_EMONEYPRICE: TOTAL.EMONEYPRICE + TOTAL_MENUTYPE_CNT: TOTAL.MENUTYPE_CNT + TOTAL_MENUQTY_CNT: TOTAL.MENUQTY_CNT + +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 
+ - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - VOID_MENU_1 # the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". 
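The gt_format above encodes CORD-style key information directly in HTML class attributes: one enumerator class (MENU_<n>, VOID_MENU, GENERIC) plus one field class per span. As a rough illustration only (not code from this diff; BeautifulSoup and the inline snippet are assumptions), grouped ground truth could be read back out of a generated document like this:

```python
# Sketch only: recover grouped KIE ground truth from the class scheme described
# in gt_format above. bs4 is an assumed helper, not a dependency shown in this diff.
from collections import defaultdict
from bs4 import BeautifulSoup

html = (
    '<span class="MENU_1 MENU_NM">Iced Latte</span>'
    '<span class="MENU_1 MENU_PRICE">4.50</span>'
    '<span class="GENERIC TOTAL_TOTAL_PRICE">4.50</span>'
)

def enumerator(classes):
    """Pick the grouping class: MENU_<n>, VOID_MENU, or GENERIC."""
    for c in classes:
        if c in ("VOID_MENU", "GENERIC") or (c.startswith("MENU_") and c[5:].isdigit()):
            return c
    return None

groups = defaultdict(dict)
for el in BeautifulSoup(html, "html.parser").find_all(class_=True):
    enum = enumerator(el["class"])
    fields = [c for c in el["class"] if c != enum]
    if enum and fields:
        groups[enum][fields[0]] = el.get_text(strip=True)

# groups == {"MENU_1": {"MENU_NM": "Iced Latte", "MENU_PRICE": "4.50"},
#            "GENERIC": {"TOTAL_TOTAL_PRICE": "4.50"}}
```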
+ +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/cord_alpha=0.5_v1.yaml b/data/syn_dataset_definitions/cord_alpha=0.5_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..97c8b163be1e4465a3d47ade8cd74a2de3a812cb --- /dev/null +++ b/data/syn_dataset_definitions/cord_alpha=0.5_v1.yaml @@ -0,0 +1,236 @@ +name: "cord_alpha=0.5_v1" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "cord" +documents_count: 1200 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT + +label_mapping: + MENU_NM: MENU.NM + MENU_NUM: MENU.NUM + MENU_UNITPRICE: MENU.UNITPRICE + MENU_CNT: MENU.CNT + MENU_DISCOUNTPRICE: MENU.DISCOUNTPRICE + MENU_PRICE: MENU.PRICE + MENU_ITEMSUBTOTAL: MENU.ITEMSUBTOTAL + MENU_VATYN: MENU.VATYN + MENU_ETC: MENU.ETC + MENU_SUB_NM: MENU.SUB_NM + MENU_SUB_UNITPRICE: MENU.SUB_UNITPRICE + MENU_SUB_CNT: MENU.SUB_CNT + MENU_SUB_PRICE: MENU.SUB_PRICE + MENU_SUB_ETC: MENU.SUB_ETC + VOID_MENU_NM: VOID_MENU.NM + VOID_MENU_PRICE: VOID_MENU.PRICE + SUB_TOTAL_SUBTOTAL_PRICE: SUB_TOTAL.SUBTOTAL_PRICE + SUB_TOTAL_DISCOUNT_PRICE: SUB_TOTAL.DISCOUNT_PRICE + SUB_TOTAL_SERVICE_PRICE: SUB_TOTAL.SERVICE_PRICE + SUB_TOTAL_OTHERSVC_PRICE: SUB_TOTAL.OTHERSVC_PRICE + SUB_TOTAL_TAX_PRICE: SUB_TOTAL.TAX_PRICE + SUB_TOTAL_ETC: SUB_TOTAL.ETC + TOTAL_TOTAL_PRICE: TOTAL.TOTAL_PRICE + TOTAL_TOTAL_ETC: TOTAL.TOTAL_ETC + TOTAL_CASHPRICE: TOTAL.CASHPRICE + TOTAL_CHANGEPRICE: TOTAL.CHANGEPRICE + TOTAL_CREDITCARDPRICE: TOTAL.CREDITCARDPRICE + TOTAL_EMONEYPRICE: TOTAL.EMONEYPRICE + TOTAL_MENUTYPE_CNT: TOTAL.MENUTYPE_CNT + TOTAL_MENUQTY_CNT: TOTAL.MENUQTY_CNT + + +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 + - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - 
VOID_MENU_1 # the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". 
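Each definition names a prompt_template and a prompt_params block; the placeholders seen in the template text earlier ({num_solutions}, {doc_type}, {gt_type}, {gt_format}, ...) are filled from exactly these params. A minimal sketch of that substitution, assuming PyYAML and plain str.format semantics (the repo's actual prompt builder is not shown in this diff):

```python
# Assumed mechanics: load a definition and fill a template excerpt with its prompt_params.
from pathlib import Path
import yaml  # PyYAML, assumed to be available

definition = yaml.safe_load(
    Path("data/syn_dataset_definitions/cord_alpha=0.5_v1.yaml").read_text()
)
params = definition["prompt_params"]

# Stand-in excerpt; the full ClaudeRefined12 template presumably lives under data/prompt_templates/.
template_excerpt = (
    "Generate {num_solutions} distinct {doc_type} documents "
    "based on {num_seed_images} seed images.\nGround truth specification: {gt_type}"
)
prompt = template_excerpt.format(
    num_seed_images=definition["seed_images_count"],  # assumed mapping
    **params,
)
print(prompt)
```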
+ +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/cord_alpha=0.75.yaml b/data/syn_dataset_definitions/cord_alpha=0.75.yaml new file mode 100755 index 0000000000000000000000000000000000000000..17a90c6b18a4d805e8a7e45dfdb226434d95ca46 --- /dev/null +++ b/data/syn_dataset_definitions/cord_alpha=0.75.yaml @@ -0,0 +1,235 @@ +name: "cord_alpha=0.75" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "cord" +documents_count: 1200 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT + +label_mapping: + MENU_NM: MENU.NM + MENU_NUM: MENU.NUM + MENU_UNITPRICE: MENU.UNITPRICE + MENU_CNT: MENU.CNT + MENU_DISCOUNTPRICE: MENU.DISCOUNTPRICE + MENU_PRICE: MENU.PRICE + MENU_ITEMSUBTOTAL: MENU.ITEMSUBTOTAL + MENU_VATYN: MENU.VATYN + MENU_ETC: MENU.ETC + MENU_SUB_NM: MENU.SUB_NM + MENU_SUB_UNITPRICE: MENU.SUB_UNITPRICE + MENU_SUB_CNT: MENU.SUB_CNT + MENU_SUB_PRICE: MENU.SUB_PRICE + MENU_SUB_ETC: MENU.SUB_ETC + VOID_MENU_NM: VOID_MENU.NM + VOID_MENU_PRICE: VOID_MENU.PRICE + SUB_TOTAL_SUBTOTAL_PRICE: SUB_TOTAL.SUBTOTAL_PRICE + SUB_TOTAL_DISCOUNT_PRICE: SUB_TOTAL.DISCOUNT_PRICE + SUB_TOTAL_SERVICE_PRICE: SUB_TOTAL.SERVICE_PRICE + SUB_TOTAL_OTHERSVC_PRICE: SUB_TOTAL.OTHERSVC_PRICE + SUB_TOTAL_TAX_PRICE: SUB_TOTAL.TAX_PRICE + SUB_TOTAL_ETC: SUB_TOTAL.ETC + TOTAL_TOTAL_PRICE: TOTAL.TOTAL_PRICE + TOTAL_TOTAL_ETC: TOTAL.TOTAL_ETC + TOTAL_CASHPRICE: TOTAL.CASHPRICE + TOTAL_CHANGEPRICE: TOTAL.CHANGEPRICE + TOTAL_CREDITCARDPRICE: TOTAL.CREDITCARDPRICE + TOTAL_EMONEYPRICE: TOTAL.EMONEYPRICE + TOTAL_MENUTYPE_CNT: TOTAL.MENUTYPE_CNT + TOTAL_MENUQTY_CNT: TOTAL.MENUQTY_CNT + +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 + - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - VOID_MENU_1 
# the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". 
+ +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/cord_alpha=0.75_v1.yaml b/data/syn_dataset_definitions/cord_alpha=0.75_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..8c8349967f803443cb53f336fbb11aed6fdb7e01 --- /dev/null +++ b/data/syn_dataset_definitions/cord_alpha=0.75_v1.yaml @@ -0,0 +1,235 @@ +name: "cord_alpha=0.75_v1" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "cord" +documents_count: 1200 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT + +label_mapping: + MENU_NM: MENU.NM + MENU_NUM: MENU.NUM + MENU_UNITPRICE: MENU.UNITPRICE + MENU_CNT: MENU.CNT + MENU_DISCOUNTPRICE: MENU.DISCOUNTPRICE + MENU_PRICE: MENU.PRICE + MENU_ITEMSUBTOTAL: MENU.ITEMSUBTOTAL + MENU_VATYN: MENU.VATYN + MENU_ETC: MENU.ETC + MENU_SUB_NM: MENU.SUB_NM + MENU_SUB_UNITPRICE: MENU.SUB_UNITPRICE + MENU_SUB_CNT: MENU.SUB_CNT + MENU_SUB_PRICE: MENU.SUB_PRICE + MENU_SUB_ETC: MENU.SUB_ETC + VOID_MENU_NM: VOID_MENU.NM + VOID_MENU_PRICE: VOID_MENU.PRICE + SUB_TOTAL_SUBTOTAL_PRICE: SUB_TOTAL.SUBTOTAL_PRICE + SUB_TOTAL_DISCOUNT_PRICE: SUB_TOTAL.DISCOUNT_PRICE + SUB_TOTAL_SERVICE_PRICE: SUB_TOTAL.SERVICE_PRICE + SUB_TOTAL_OTHERSVC_PRICE: SUB_TOTAL.OTHERSVC_PRICE + SUB_TOTAL_TAX_PRICE: SUB_TOTAL.TAX_PRICE + SUB_TOTAL_ETC: SUB_TOTAL.ETC + TOTAL_TOTAL_PRICE: TOTAL.TOTAL_PRICE + TOTAL_TOTAL_ETC: TOTAL.TOTAL_ETC + TOTAL_CASHPRICE: TOTAL.CASHPRICE + TOTAL_CHANGEPRICE: TOTAL.CHANGEPRICE + TOTAL_CREDITCARDPRICE: TOTAL.CREDITCARDPRICE + TOTAL_EMONEYPRICE: TOTAL.EMONEYPRICE + TOTAL_MENUTYPE_CNT: TOTAL.MENUTYPE_CNT + TOTAL_MENUQTY_CNT: TOTAL.MENUQTY_CNT + +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 + - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - 
VOID_MENU_1 # the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". 
+ +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/cord_alpha=1.0.yaml b/data/syn_dataset_definitions/cord_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..564c6ef6081edeebaf80a4a78c0ed17f8fa98f14 --- /dev/null +++ b/data/syn_dataset_definitions/cord_alpha=1.0.yaml @@ -0,0 +1,235 @@ +name: "cord_alpha=1.0" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "cord" +documents_count: 1200 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT + +label_mapping: + MENU_NM: MENU.NM + MENU_NUM: MENU.NUM + MENU_UNITPRICE: MENU.UNITPRICE + MENU_CNT: MENU.CNT + MENU_DISCOUNTPRICE: MENU.DISCOUNTPRICE + MENU_PRICE: MENU.PRICE + MENU_ITEMSUBTOTAL: MENU.ITEMSUBTOTAL + MENU_VATYN: MENU.VATYN + MENU_ETC: MENU.ETC + MENU_SUB_NM: MENU.SUB.NM #MENU.SUB_NM + MENU_SUB_UNITPRICE: MENU.SUB.UNITPRICE #MENU.SUB_UNITPRICE + MENU_SUB_CNT: MENU.SUB.CNT # MENU.SUB_CNT + MENU_SUB_PRICE: MENU.SUB.PRICE #MENU.SUB_PRICE + MENU_SUB_ETC: MENU.SUB_ETC + VOID_MENU_NM: VOID_MENU.NM + VOID_MENU_PRICE: VOID_MENU.PRICE + SUB_TOTAL_SUBTOTAL_PRICE: SUB_TOTAL.SUBTOTAL_PRICE + SUB_TOTAL_DISCOUNT_PRICE: SUB_TOTAL.DISCOUNT_PRICE + SUB_TOTAL_SERVICE_PRICE: SUB_TOTAL.SERVICE_PRICE + SUB_TOTAL_OTHERSVC_PRICE: SUB_TOTAL.OTHERSVC_PRICE + SUB_TOTAL_TAX_PRICE: SUB_TOTAL.TAX_PRICE + SUB_TOTAL_ETC: SUB_TOTAL.ETC + TOTAL_TOTAL_PRICE: TOTAL.TOTAL_PRICE + TOTAL_TOTAL_ETC: TOTAL.TOTAL_ETC + TOTAL_CASHPRICE: TOTAL.CASHPRICE + TOTAL_CHANGEPRICE: TOTAL.CHANGEPRICE + TOTAL_CREDITCARDPRICE: TOTAL.CREDITCARDPRICE + TOTAL_EMONEYPRICE: TOTAL.EMONEYPRICE + TOTAL_MENUTYPE_CNT: TOTAL.MENUTYPE_CNT + TOTAL_MENUQTY_CNT: TOTAL.MENUQTY_CNT + +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 + - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - 
MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - VOID_MENU_1 # the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". 
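Note how this alpha=1.0 definition remaps the sub-item labels to the dotted CORD ontology (MENU_SUB_NM -> MENU.SUB.NM, and so on) via label_mapping. A small sketch of how such a mapping could be applied downstream, assuming PyYAML; the actual conversion code is not part of this diff:

```python
# Sketch: project generated label names onto the original CORD ontology.
from pathlib import Path
import yaml  # assumed dependency

definition = yaml.safe_load(
    Path("data/syn_dataset_definitions/cord_alpha=1.0.yaml").read_text()
)
mapping = definition.get("label_mapping") or {}  # some definitions leave this empty

def to_base_label(label: str) -> str:
    # Unmapped labels (e.g. the GENERIC grouping class) pass through unchanged.
    return mapping.get(label, label)

assert to_base_label("MENU_SUB_NM") == "MENU.SUB.NM"
assert to_base_label("MENU_NM") == "MENU.NM"
assert to_base_label("GENERIC") == "GENERIC"
```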
+ +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/cord_alpha=1.0_v1.yaml b/data/syn_dataset_definitions/cord_alpha=1.0_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..f4c0bc3fef2d8b5c777b2b280aeb2891a5a0899f --- /dev/null +++ b/data/syn_dataset_definitions/cord_alpha=1.0_v1.yaml @@ -0,0 +1,235 @@ +name: "cord_alpha=1.0_v1" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "cord" +documents_count: 1200 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT + +label_mapping: + MENU_NM: MENU.NM + MENU_NUM: MENU.NUM + MENU_UNITPRICE: MENU.UNITPRICE + MENU_CNT: MENU.CNT + MENU_DISCOUNTPRICE: MENU.DISCOUNTPRICE + MENU_PRICE: MENU.PRICE + MENU_ITEMSUBTOTAL: MENU.ITEMSUBTOTAL + MENU_VATYN: MENU.VATYN + MENU_ETC: MENU.ETC + MENU_SUB_NM: MENU.SUB_NM + MENU_SUB_UNITPRICE: MENU.SUB_UNITPRICE + MENU_SUB_CNT: MENU.SUB_CNT + MENU_SUB_PRICE: MENU.SUB_PRICE + MENU_SUB_ETC: MENU.SUB_ETC + VOID_MENU_NM: VOID_MENU.NM + VOID_MENU_PRICE: VOID_MENU.PRICE + SUB_TOTAL_SUBTOTAL_PRICE: SUB_TOTAL.SUBTOTAL_PRICE + SUB_TOTAL_DISCOUNT_PRICE: SUB_TOTAL.DISCOUNT_PRICE + SUB_TOTAL_SERVICE_PRICE: SUB_TOTAL.SERVICE_PRICE + SUB_TOTAL_OTHERSVC_PRICE: SUB_TOTAL.OTHERSVC_PRICE + SUB_TOTAL_TAX_PRICE: SUB_TOTAL.TAX_PRICE + SUB_TOTAL_ETC: SUB_TOTAL.ETC + TOTAL_TOTAL_PRICE: TOTAL.TOTAL_PRICE + TOTAL_TOTAL_ETC: TOTAL.TOTAL_ETC + TOTAL_CASHPRICE: TOTAL.CASHPRICE + TOTAL_CHANGEPRICE: TOTAL.CHANGEPRICE + TOTAL_CREDITCARDPRICE: TOTAL.CREDITCARDPRICE + TOTAL_EMONEYPRICE: TOTAL.EMONEYPRICE + TOTAL_MENUTYPE_CNT: TOTAL.MENUTYPE_CNT + TOTAL_MENUQTY_CNT: TOTAL.MENUQTY_CNT + +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 + - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - 
VOID_MENU_1 # the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". 
+ +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/doclaynet4k_alpha=1.0_CLS.yaml b/data/syn_dataset_definitions/doclaynet4k_alpha=1.0_CLS.yaml new file mode 100755 index 0000000000000000000000000000000000000000..a584e090f06d98708333198de6d6102016ff0ff3 --- /dev/null +++ b/data/syn_dataset_definitions/doclaynet4k_alpha=1.0_CLS.yaml @@ -0,0 +1,40 @@ +name: "doclaynet4k_alpha=1.0_CLS" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "doclaynet_4k_cls" +documents_count: 4500 +valid_labels: + - financial_reports + - scientific_articles + - laws_and_regulations + - government_tenders + - manuals + - patents +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "single A4 pages out of diverse business and technical" + language: "English" + gt_type: | + document class label + * financial_reports + * scientific_articles + * laws_and_regulations + * government_tenders + * manuals + * patents + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# Issues: +# TODO: \ No newline at end of file diff --git a/data/syn_dataset_definitions/doclaynet4k_alpha=1.0_DLA.yaml b/data/syn_dataset_definitions/doclaynet4k_alpha=1.0_DLA.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d672f79881d2d4d10c146a48319a43ce4dc533b2 --- /dev/null +++ b/data/syn_dataset_definitions/doclaynet4k_alpha=1.0_DLA.yaml @@ -0,0 +1,60 @@ +name: "doclaynet4k_alpha=1.0_DLA" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "doclaynet_4k_dla" +documents_count: 4500 +valid_labels: + - LE-CAPTION + - LE-FOOTNOTE + - LE-FORMULA + - LE-LIST-ITEM + - LE-PAGE-FOOTER + - LE-PAGE-HEADER + - LE-PICTURE + - LE-SECTION-HEADER + - LE-TABLE + - LE-TEXT + - LE-TITLE +label_mapping: + LE-CAPTION: Caption + LE-FOOTNOTE: Footnote + LE-FORMULA: Formula + LE-LIST-ITEM: List-item + LE-PAGE-FOOTER: Page-footer + LE-PAGE-HEADER: Page-header + LE-PICTURE: Picture + LE-SECTION-HEADER: Section-header + LE-TABLE: Table + LE-TEXT: Text + LE-TITLE: "Title " +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of diverse business and technical" + language: "English" + gt_type: | + * "LE-CAPTION": Text that accompanies and explains figures, tables, or other visual elements, typically appearing above or below the referenced element. + * "LE-FOOTNOTE": Supplementary notes or citations placed at the bottom of a page, providing additional context or references to the main text, distinct from footers. + * "LE-FORMULA": Mathematical equations, chemical formulas, or symbolic expressions, whether displayed inline or as standalone elements. + * "LE-LIST-ITEM": Individual items within enumerated, bulleted, or definition lists, with each list item annotated separately rather than as a unified list structure. + * "LE-PAGE-FOOTER": Recurring content at the bottom of pages such as page numbers, copyright notices, document identifiers, or footer text. + * "LE-PAGE-HEADER": Recurring content at the top of pages including running headers, document titles, chapter names. 
+ * "LE-PICTURE": Photographs, diagrams, charts, graphs, illustrations, and other visual content excluding tables. + * "LE-SECTION-HEADER": Section and subsection headings. + * "LE-TABLE": Complete table structure including grid content, inline captions, and column/row headers as a unified element. + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, definitions, descriptions, and other primary textual content. + * "LE-TITLE": The main document title appearing prominently at the beginning of the document, distinct from section headers. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# Issues: +# TODO: \ No newline at end of file diff --git a/data/syn_dataset_definitions/doclaynet_alpha=1.0_CLS.yaml b/data/syn_dataset_definitions/doclaynet_alpha=1.0_CLS.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e741acc279eca821ab3759ef3bc290468e91d052 --- /dev/null +++ b/data/syn_dataset_definitions/doclaynet_alpha=1.0_CLS.yaml @@ -0,0 +1,40 @@ +name: "doclaynet_alpha=1.0_CLS" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "doclaynet" +documents_count: 4500 +valid_labels: + - financial_reports + - scientific_articles + - laws_and_regulations + - government_tenders + - manuals + - patents +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "single A4 pages out of diverse business and technical" + language: "English" + gt_type: | + document class label + * financial_reports + * scientific_articles + * laws_and_regulations + * government_tenders + * manuals + * patents + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# Issues: +# TODO: \ No newline at end of file diff --git a/data/syn_dataset_definitions/doclaynet_alpha=1.0_DLA.yaml b/data/syn_dataset_definitions/doclaynet_alpha=1.0_DLA.yaml new file mode 100755 index 0000000000000000000000000000000000000000..704cc415d61b48f0ac721ea49b2249a46036ac95 --- /dev/null +++ b/data/syn_dataset_definitions/doclaynet_alpha=1.0_DLA.yaml @@ -0,0 +1,49 @@ +name: "doclaynet_alpha=1.0_DLA" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "doclaynet" +documents_count: 4500 +valid_labels: + - LE-CAPTION + - LE-FOOTNOTE + - LE-FORMULA + - LE-LIST-ITEM + - LE-PAGE-FOOTER + - LE-PAGE-HEADER + - LE-PICTURE + - LE-SECTION-HEADER + - LE-TABLE + - LE-TEXT + - LE-TITLE +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of diverse business and technical" + language: "English" + gt_type: | + * "LE-CAPTION": Text that accompanies and explains figures, tables, or other visual elements, typically appearing above or below the referenced element. + * "LE-FOOTNOTE": Supplementary notes or citations placed at the bottom of a page, providing additional context or references to the main text, distinct from footers. + * "LE-FORMULA": Mathematical equations, chemical formulas, or symbolic expressions, whether displayed inline or as standalone elements. + * "LE-LIST-ITEM": Individual items within enumerated, bulleted, or definition lists, with each list item annotated separately rather than as a unified list structure. 
+ * "LE-PAGE-FOOTER": Recurring content at the bottom of pages such as page numbers, copyright notices, document identifiers, or footer text. + * "LE-PAGE-HEADER": Recurring content at the top of pages including running headers, document titles, chapter names. + * "LE-PICTURE": Photographs, diagrams, charts, graphs, illustrations, and other visual content excluding tables. + * "LE-SECTION-HEADER": Section and subsection headings. + * "LE-TABLE": Complete table structure including grid content, inline captions, and column/row headers as a unified element. + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, definitions, descriptions, and other primary textual content. + * "LE-TITLE": The main document title appearing prominently at the beginning of the document, distinct from section headers. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# Issues: +# TODO: \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa.yaml b/data/syn_dataset_definitions/docvqa.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d01d9234a697f4925263966fe1a164452da20ca3 --- /dev/null +++ b/data/syn_dataset_definitions/docvqa.yaml @@ -0,0 +1,24 @@ +name: "docvqa" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa_alpha=0.5.yaml b/data/syn_dataset_definitions/docvqa_alpha=0.5.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e85ff874ab12aa07a541c08464fba387c9efce65 --- /dev/null +++ b/data/syn_dataset_definitions/docvqa_alpha=0.5.yaml @@ -0,0 +1,24 @@ +name: "docvqa_alpha=0.5" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 10000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa_alpha=0.5_v1.yaml b/data/syn_dataset_definitions/docvqa_alpha=0.5_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..b4969951fe24990db65740d49e1e5aeb40bd79af --- /dev/null +++ b/data/syn_dataset_definitions/docvqa_alpha=0.5_v1.yaml @@ -0,0 +1,24 @@ +name: "docvqa_alpha=0.5_v1" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 10000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa_alpha=0.75.yaml b/data/syn_dataset_definitions/docvqa_alpha=0.75.yaml new file mode 100755 index 0000000000000000000000000000000000000000..df32b8ec9bedd893272cf1c2dc07abc0a7efc45d --- /dev/null +++ b/data/syn_dataset_definitions/docvqa_alpha=0.75.yaml @@ -0,0 +1,24 @@ +name: "docvqa_alpha=0.75" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 10000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa_alpha=0.75_v1.yaml b/data/syn_dataset_definitions/docvqa_alpha=0.75_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..0faec31c7ec5b5ddae500e8abfaff2dbbc284de1 --- /dev/null +++ b/data/syn_dataset_definitions/docvqa_alpha=0.75_v1.yaml @@ -0,0 +1,24 @@ +name: "docvqa_alpha=0.75_v1" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 10000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa_alpha=1.0.yaml b/data/syn_dataset_definitions/docvqa_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..a300707ba8462a2b661629dd466a475c170a5011 --- /dev/null +++ b/data/syn_dataset_definitions/docvqa_alpha=1.0.yaml @@ -0,0 +1,24 @@ +name: "docvqa_alpha=1.0" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 10000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/docvqa_alpha=1.0_v1.yaml b/data/syn_dataset_definitions/docvqa_alpha=1.0_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..c4600f0d5c0430203a3f34883c18c4d09d18350a --- /dev/null +++ b/data/syn_dataset_definitions/docvqa_alpha=1.0_v1.yaml @@ -0,0 +1,24 @@ +name: "docvqa_alpha=1.0_v1" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_docvqa" +documents_count: 10000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Multiple questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/funsd_alpha=1.0.yaml b/data/syn_dataset_definitions/funsd_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..08368cca7e58a7a63bdf5010f265e0dc704d213b --- /dev/null +++ b/data/syn_dataset_definitions/funsd_alpha=1.0.yaml @@ -0,0 +1,133 @@ +name: "funsd_alpha=1.0" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "funsd" +documents_count: 300 +valid_labels: + - HEADER + - QUESTION + - ANSWER +label_mapping: +valid_secondary_labels: + - PAIR_1 + - PAIR_2 + - PAIR_3 + - PAIR_4 + - PAIR_5 + - PAIR_6 + - PAIR_7 + - PAIR_8 + - PAIR_9 + - PAIR_10 + - PAIR_11 + - PAIR_12 + - PAIR_13 + - PAIR_14 + - PAIR_15 + - PAIR_16 + - PAIR_17 + - PAIR_18 + - PAIR_19 + - PAIR_20 + - PAIR_21 + - PAIR_22 + - PAIR_23 + - PAIR_24 + - PAIR_25 + - PAIR_26 + - PAIR_27 + - PAIR_28 + - PAIR_29 + - PAIR_30 + - PAIR_31 + - PAIR_32 + - PAIR_33 + - PAIR_34 + - PAIR_35 + - PAIR_36 + - PAIR_37 + - PAIR_38 + - PAIR_39 + - PAIR_40 + - PAIR_41 + - PAIR_42 + - PAIR_43 + - PAIR_44 + - PAIR_45 + - PAIR_46 + - PAIR_47 + - PAIR_48 + - PAIR_49 + - PAIR_50 + - PAIR_51 + - PAIR_52 + - PAIR_53 + - PAIR_54 + - PAIR_55 + - PAIR_56 + - PAIR_57 + - PAIR_58 + - PAIR_59 + - PAIR_60 + - PAIR_61 + - PAIR_62 + - PAIR_63 + - PAIR_64 + - PAIR_65 + - PAIR_66 + - PAIR_67 + - PAIR_68 + - PAIR_69 + - PAIR_70 + - PAIR_71 + - PAIR_72 + - PAIR_73 + - PAIR_74 + - PAIR_75 + - PAIR_76 + - PAIR_77 + - PAIR_78 + - PAIR_79 + - PAIR_80 + - PAIR_81 + - PAIR_82 + - PAIR_83 + - PAIR_84 + - PAIR_85 + - PAIR_86 + - PAIR_87 + - PAIR_88 + - PAIR_89 + - PAIR_90 + - PAIR_91 + - PAIR_92 + - PAIR_93 + - PAIR_94 + - PAIR_95 + - PAIR_96 + - PAIR_97 + - PAIR_98 + - PAIR_99 + - PAIR_100 + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "form" + language: "English" + gt_type: | + keys and their values structured as QA pairs + * "HEADER": The header of the question answer pair. + * "QUESTION": The question i.e. a key. + * "ANSWER": The answer i.e a value. + gt_format: | + Group individual annotations in groups using the enumerator class PAIR_ and a annotation class from the list above (e.g. "PAIR_1 QUESTION", "PAIR_1 ANSWER", "PAIR_2 HEADER", ...). + Ensure to annotate exact using spans, i.e. "QUESTION" element should not contain "ANSWER". 
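For the FUNSD-style definition above, the same enumerator idea groups spans per PAIR_<n>; linking QUESTION and ANSWER inside each group then yields conventional QA tuples. A sketch with invented example data (the repo's real post-processing is not shown here):

```python
# Example input: spans already grouped by their PAIR_<n> class, as in the
# CORD sketch earlier; the concrete values below are made up for illustration.
grouped = {
    "PAIR_1": {"QUESTION": "Date:", "ANSWER": "03/14/1995"},
    "PAIR_2": {"HEADER": "Shipping", "QUESTION": "Carrier:", "ANSWER": "UPS"},
}

qa_pairs = [
    {
        "header": fields.get("HEADER"),
        "question": fields.get("QUESTION"),
        "answer": fields.get("ANSWER"),
    }
    for _, fields in sorted(grouped.items())
]
# -> [{'header': None, 'question': 'Date:', 'answer': '03/14/1995'}, ...]
```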
+ +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/icdar2019_alpha=1.0.yaml b/data/syn_dataset_definitions/icdar2019_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..268048251ed1767b29dbdaa91e4f91356d05e9bc --- /dev/null +++ b/data/syn_dataset_definitions/icdar2019_alpha=1.0.yaml @@ -0,0 +1,27 @@ +name: "icdar2019_alpha=1.0" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "icdar2019" +documents_count: 1600 +valid_labels: + - LE-TABLE +label_mapping: + LE-TABLE: table +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of diverse modern digital-born and historical archival scanned" + language: "English" + gt_type: | + * "LE-TABLE": Any tabular structure containing data organized in rows and columns. Include the complete table region from border to border. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/kleister_alpha=1.0.yaml b/data/syn_dataset_definitions/kleister_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..3c43337a33cae3f3d4477805b7b685140063027b --- /dev/null +++ b/data/syn_dataset_definitions/kleister_alpha=1.0.yaml @@ -0,0 +1,41 @@ +name: "kleister_alpha=1.0" +task: "KIE" +dataloader_model_task_as: "QA" +base_dataset_name: "ex_klc" +documents_count: 4000 +valid_labels: + - address__post_town + - address__postcode + - address__street_line + - charity_name + - charity_number + - income_annually_in_british_pounds + - report_date + - spending_annually_in_british_pounds +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "UK charity annual financial report" + language: "English" + gt_type: | + keys and their values (if applicable, provide as plaintext values from the document): + * "address__post_town": Post town of the address of the charitable organization. + * "address__postcode": Postcode of the address of the charitable organization. + * "address__street_line": Street line of the address of the charitable organization. + * "charity_name": The name of the charitable organization. + * "charity_number": The registered number of the charitable organization. + * "income_annually_in_british_pounds": The annual income in British Pounds of the charitable organization. + * "report_date": The reporting date of the annual document of the charitable organization. + * "spending_annually_in_british_pounds": The annual spending in British Pounds of the charitable organization. 
+ gt_format: '{"address__post_town": "", "spending_annually_in_british_pounds": "", ...}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/cord.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/cord.yaml new file mode 100755 index 0000000000000000000000000000000000000000..163946dbc2158f3fff6154388ffe18514685689e --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/cord.yaml @@ -0,0 +1,89 @@ +name: "cord" +task: "KIE" +base_dataset_name: "cord" +documents_count: 1000 +valid_labels: + - MENU.NM + - MENU.NUM + - MENU.UNITPRICE + - MENU.CNT + - MENU.DISCOUNTPRICE + - MENU.PRICE + - MENU.ITEMSUBTOTAL + - MENU.VATYN + - MENU.ETC + - MENU.SUB.NM + - MENU.SUB.UNITPRICE + - MENU.SUB.CNT + - MENU.SUB.PRICE + - MENU.SUB.ETC + - VOID_MENU.NM + - VOID_MENU.PRICE + - SUB_TOTAL.SUBTOTAL_PRICE + - SUB_TOTAL.DISCOUNT_PRICE + - SUB_TOTAL.SERVICE_PRICE + - SUB_TOTAL.OTHERSVC_PRICE + - SUB_TOTAL.TAX_PRICE + - SUB_TOTAL.ETC + - TOTAL.TOTAL_PRICE + - TOTAL.TOTAL_ETC + - TOTAL.CASHPRICE + - TOTAL.CHANGEPRICE + - TOTAL.CREDITCARDPRICE + - TOTAL.EMONEYPRICE + - TOTAL.MENUTYPE_CNT + - TOTAL.MENUQTY_CNT + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + keys and their values (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU.NM": The menu item name. + * "MENU.NUM": The menu item number or identifier. + * "MENU.UNITPRICE": The price per unit of the menu item. + * "MENU.CNT": The quantity or count of the menu item. + * "MENU.DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU.PRICE": The final price of the menu item. + * "MENU.ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU.VATYN": The VAT indicator (yes/no) for the menu item. + * "MENU.ETC": Other miscellaneous menu item information. + * "MENU.SUB.NM": The name of a sub-item or modifier. + * "MENU.SUB.UNITPRICE": The price per unit of the sub-item. + * "MENU.SUB.CNT": The quantity of the sub-item. + * "MENU.SUB.PRICE": The price of the sub-item. + * "MENU.SUB.ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU.NM": The name of a cancelled or voided item. + * "VOID_MENU.PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL.SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL.DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL.SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL.OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL.TAX_PRICE": The tax amount. + * "SUB_TOTAL.ETC": Other subtotal information. + * "TOTAL.TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL.TOTAL_ETC": Other total-related information. + * "TOTAL.CASHPRICE": The amount paid in cash. + * "TOTAL.CHANGEPRICE": The change given back to the customer. + * "TOTAL.CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL.EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL.MENUTYPE_CNT": The count of different menu item types. + * "TOTAL.MENUQTY_CNT": The total quantity of all items ordered. 
+ gt_format: | + Up to 8 menu items and the receipt data as a JSON object { + "MENU_1": {"MENU.NM": "", "MENU.NUM": "", ...}, + "MENU_2": {"MENU.NM": "", "MENU.NUM": "", ...}, + ..., + "VOID_MENU": {"VOID_MENU.NM": "", "VOID_MENU.PRICE": ""}, + "GENERIC": {"SUB_TOTAL.SUBTOTAL_PRICE": "", ..., "TOTAL.TOTAL_PRICE": ...} + } +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/doclaynet.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/doclaynet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..13ea3e5f1ca88611e3e79b965ee9bc549a50050c --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/doclaynet.yaml @@ -0,0 +1,45 @@ +name: "doclaynet" +task: "DLA" +base_dataset_name: "doclaynet" +documents_count: 10 +valid_labels: + - LE-CAPTION + - LE-FOOTNOTE + - LE-FORMULA + - LE-LIST-ITEM + - LE-PAGE-FOOTER + - LE-PAGE-HEADER + - LE-PICTURE + - LE-SECTION-HEADER + - LE-TABLE + - LE-TEXT + - LE-TITLE + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of diverse business and technical" + language: "English" + gt_type: | + Give each applicable element in HTML a layout class from the list below to uniquely identify its label: + * "LE-CAPTION": Text that accompanies and explains figures, tables, or other visual elements, typically appearing above or below the referenced element. + * "LE-FOOTNOTE": Supplementary notes or citations placed at the bottom of a page, providing additional context or references to the main text, distinct from footers. + * "LE-FORMULA": Mathematical equations, chemical formulas, or symbolic expressions, whether displayed inline or as standalone elements. + * "LE-LIST-ITEM": Individual items within enumerated, bulleted, or definition lists, with each list item annotated separately rather than as a unified list structure. + * "LE-PAGE-FOOTER": Recurring content at the bottom of pages such as page numbers, copyright notices, document identifiers, or footer text. + * "LE-PAGE-HEADER": Recurring content at the top of pages including running headers, document titles, chapter names. + * "LE-PICTURE": Photographs, diagrams, charts, graphs, illustrations, and other visual content excluding tables. + * "LE-SECTION-HEADER": Section and subsection headings. + * "LE-TABLE": Complete table structure including grid content, inline captions, and column/row headers as a unified element. + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, definitions, descriptions, and other primary textual content. + * "LE-TITLE": The main document title appearing prominently at the beginning of the document, distinct from section headers. 
+ gt_format: 'Empty JSON object: {}' + +seed_images_count: 4 +hdbscan_min_cluster_size: 10 +embedding_type: image +alpha: 1 +max_seed_pool: -1 + +# Issues: +# TODO: \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/docvqa.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/docvqa.yaml new file mode 100755 index 0000000000000000000000000000000000000000..87ba0d4e8a76c87555378271680dbf0d6b1ce488 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/docvqa.yaml @@ -0,0 +1,19 @@ +name: "docvqa" +task: "QA" +base_dataset_name: "ex_docvqa" +documents_count: 1000 # 10.194 Documents in DocVQA train, 39,461 QA pairs +valid_labels: + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/funsd.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/funsd.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7b08ace6c32232d726b3b2800dbe420ee373c7d8 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/funsd.yaml @@ -0,0 +1,28 @@ +name: "funsd" +task: "QA" +base_dataset_name: "funsd" +documents_count: 300 +valid_labels: +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "form" + language: "English" + gt_type: | + keys and their values structured as QA pairs + * "HEADER": The header of the question answer pair. + * "QUESTION": The question i.e. a key. + * "ANSWER": The answer i.e. a value. + gt_format: | + Up to 8 pairs as a JSON object { + "PAIR_1": {"header": "
", "question": "", "answer": ""}, + "PAIR_2": {"header": "
", "question": "", "answer": ""}, + ... + } + + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/icdar2019.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/icdar2019.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7e8b12404d1fb0551366803399633fe5784e726c --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/icdar2019.yaml @@ -0,0 +1,25 @@ +name: "icdar2019" +task: "DLA" +base_dataset_name: "icdar2019" +documents_count: 10 +valid_labels: + - LE-TABLE + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of diverse modern digital-born and historical archival scanned" + language: "English" + gt_type: | + Give each applicable element in HTML a layout class from the list below to uniquely identify its label: + * "LE-TABLE": Any tabular structure containing data organized in rows and columns. Include the complete table region from border to border. + gt_format: 'Empty JSON object: {}' + +seed_images_count: 4 +hdbscan_min_cluster_size: 10 +embedding_type: image +alpha: 1 +max_seed_pool: -1 + +# Issues: +# TODO: \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/publaynet.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/publaynet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..6fdccad3871b4bdc9cc0c72c5be35882cb3109cc --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/publaynet.yaml @@ -0,0 +1,30 @@ +name: "publaynet" +task: "DLA" +base_dataset_name: "publaynet" +documents_count: 10 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + Give each applicable element in HTML a layout class from the list below to uniquely identify its label: + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises article titles and standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: 'Empty JSON object: {}' + +seed_images_count: 4 +hdbscan_min_cluster_size: 10 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/rvlcdip.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/rvlcdip.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e5756d2d6451afc201a9cb5a74bf4c7b3e03cd87 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/rvlcdip.yaml @@ -0,0 +1,52 @@ +name: "rvlcdip" +task: "CLASSIFICATION" +base_dataset_name: "rvlcdip" +documents_count: 10 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/sroie.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/sroie.yaml new file mode 100755 index 0000000000000000000000000000000000000000..18680f4cd72e6fd69399e291f9b77adc3443ce5c --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/sroie.yaml @@ -0,0 +1,32 @@ +name: "sroie" +task: "KIE" +base_dataset_name: "sroie" +documents_count: 1000 +valid_labels: + - COMPANY + - DATE + - ADDRESS + - TOTAL + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + keys and their values + * "COMPANY": The company name. + * "DATE": The date on the receipt. + * "ADDRESS": The address of the company. + * "TOTAL": The total amount. 
+ gt_format: 'JSON object {"COMPANY": "", "DATE": "", "ADDRESS": "", "TOTAL": ""}' + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: 22.10.2025 | 11.21 USD +# ICVPR: 23.10.2025 | 38.61 USD +# 1950 samples @ 27.4 USD => 1.4 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/ClaudeRefined11/tobacco3482.yaml b/data/syn_dataset_definitions/legacy/ClaudeRefined11/tobacco3482.yaml new file mode 100755 index 0000000000000000000000000000000000000000..88fc8ea9390ea663bcbf21e5b4ecd5a66a64f89b --- /dev/null +++ b/data/syn_dataset_definitions/legacy/ClaudeRefined11/tobacco3482.yaml @@ -0,0 +1,44 @@ +name: "tobacco3482" +task: "CLASSIFICATION" +base_dataset_name: "tobacco3482" +documents_count: 1000 +valid_labels: + - ADVERTISEMENT + - EMAIL + - FORM + - LETTER + - MEMO + - NEWS_ARTICLE + - NOTE + - REPORT + - RESUME + - SCIENTIFIC + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "legal and corporate" + language: "English" + gt_type: | + document class labels: + * ADVERTISEMENT: Advertisement + * EMAIL: Email + * FORM: Form + * LETTER: Letter + * MEMO: Memo + * NEWS_ARTICLE: News article + * NOTE: Note/handwritten note + * REPORT: Report + * RESUME: Resume/CV + * SCIENTIFIC: Scientific publication + gt_format: 'JSON object {"label": ""}' + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: start | 38.61 USD +# ICVPR: end | 50.37 USD +# 936 samples @ 11.76 USD => 1.25 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-handwritten-sizes4.yaml b/data/syn_dataset_definitions/legacy/docvqa-handwritten-sizes4.yaml new file mode 100755 index 0000000000000000000000000000000000000000..896cd126f1e76400132082ab4923e7480408057c --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-handwritten-sizes4.yaml @@ -0,0 +1,20 @@ +name: "docvqa-handwritten-sizes4" +documents_count: 10 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_images_folder: "docvqa-handwritten-examples" +seed_images_count: 1 +seed_image_max_width: 512 +seed_image_quality: 80 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-pipelinetest.yaml b/data/syn_dataset_definitions/legacy/docvqa-pipelinetest.yaml new file mode 100755 index 0000000000000000000000000000000000000000..b30d37b86b73d7c270de3c1a0ec3d68b00a943f4 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-pipelinetest.yaml @@ -0,0 +1,21 @@ +name: "docvqa-pipelinetest" +base_dataset_name: "ex_docvqa" +documents_count: 100 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +sampling_strategy: "proportional_cluster_size_sampling" \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-test-alpha=-1.yaml b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=-1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..8cf71906cc43e233c6c83624a69433bc2f0416d6 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=-1.yaml @@ -0,0 +1,22 @@ +name: "docvqa-test-alpha=-1" +base_dataset_name: "ex_docvqa" +documents_count: 1 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: -1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.5.yaml b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.5.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7f5d1ad9d170f9d172ab5ad7213be22453c67f88 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.5.yaml @@ -0,0 +1,22 @@ +name: "docvqa-test-alpha=0.5" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.75.yaml b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.75.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4dd7d39c594f76a82c01b9fd5f19cd05c419815e --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.75.yaml @@ -0,0 +1,22 @@ +name: "docvqa-test-alpha=0.75" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.yaml b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..830a206982556a4dbf750e0ebd95e84f423e5a1d --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=0.yaml @@ -0,0 +1,22 @@ +name: "docvqa-test-alpha=0" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-test-alpha=1.yaml b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..435ac48602e6c965d48be81c6c18eddb2028ae16 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-test-alpha=1.yaml @@ -0,0 +1,22 @@ +name: "docvqa-test-alpha=1" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-test.yaml b/data/syn_dataset_definitions/legacy/docvqa-test.yaml new file mode 100755 index 0000000000000000000000000000000000000000..c241265205f64c3e77f1773e02273d7ff3c49b85 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-test.yaml @@ -0,0 +1,21 @@ +name: "docvqa-test" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +prompt_template: "ClaudeRefined7" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-viselems.yaml b/data/syn_dataset_definitions/legacy/docvqa-viselems.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7c29e7c4bb4bfcbe802f90bad2e5a0f0d1fc9dbe --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-viselems.yaml @@ -0,0 +1,21 @@ +name: "docvqa-viselems" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +prompt_template: "ClaudeRefined10" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/docvqa-viselems2.yaml b/data/syn_dataset_definitions/legacy/docvqa-viselems2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9f8248735bf2f1929786613922dfdf31657d79ba --- /dev/null +++ b/data/syn_dataset_definitions/legacy/docvqa-viselems2.yaml @@ -0,0 +1,18 @@ +name: "docvqa-viselems2" +task: "QA" +base_dataset_name: "ex_docvqa" +documents_count: 50 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_images_count: 10 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/sroie-test.yaml b/data/syn_dataset_definitions/legacy/sroie-test.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9538dbc43640bde0d5c5158dd747e7fab6a7bce2 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/sroie-test.yaml @@ -0,0 +1,27 @@ +name: "sroie-test" +task: "KIE" +base_dataset_name: "sroie" +documents_count: 100 + +prompt_template: "ClaudeRefined11" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + keys and their values + * "COMPANY": The company name. + * "DATE": The date on the receipt. + * "ADDRESS": The address of the company. + * "TOTAL": The total amount. + gt_format: 'JSON object {"COMPANY": "", "DATE": "", "ADDRESS": "", "TOTAL": ""}' + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: 22.10.2025 | 11.21 USD +# ICVPR: 23.10.2025 | 38.61 USD +# 1950 samples @ 27.4 USD => 1.4 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/sroie_as_annotation.yaml b/data/syn_dataset_definitions/legacy/sroie_as_annotation.yaml new file mode 100755 index 0000000000000000000000000000000000000000..41cbee70ab5b361bbc639a9bfcf5a7c078c1c817 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/sroie_as_annotation.yaml @@ -0,0 +1,34 @@ +name: "sroie" +task: "KIE" +base_dataset_name: "sroie" +documents_count: 50 +valid_labels: + - COMPANY + - DATE + - ADDRESS + - TOTAL +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + * "COMPANY": The company name. + * "DATE": The date on the receipt. + * "ADDRESS": The address of the company. + * "TOTAL": The total amount. + gt_format: | + Ensure every label is only present once and to annotate exact using spans, e.g. "ADDRESS" element should not contain other contact info. + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: 22.10.2025 | 11.21 USD +# ICVPR: 23.10.2025 | 38.61 USD +# 1950 samples @ 27.4 USD => 1.4 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/syn_docvqa-handwritten-authors-visual_elements-examples_seed_based.yaml b/data/syn_dataset_definitions/legacy/syn_docvqa-handwritten-authors-visual_elements-examples_seed_based.yaml new file mode 100755 index 0000000000000000000000000000000000000000..a0761ea7cda99c190ae82ab93954b00eac58c881 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/syn_docvqa-handwritten-authors-visual_elements-examples_seed_based.yaml @@ -0,0 +1,20 @@ +name: "syn_docvqa-handwritten-authors-visual_elements-examples_seed_based" +documents_count: 100 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined2" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_images_folder: "docvqa-handwritten-examples" +seed_images_count: 1 +seed_image_max_width: 512 +seed_image_quality: 80 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/syn_docvqa-handwritten-examples_seed_based.yaml b/data/syn_dataset_definitions/legacy/syn_docvqa-handwritten-examples_seed_based.yaml new file mode 100755 index 0000000000000000000000000000000000000000..93c5fd84c5013b4e2277d28e46c8696efbfa5d1d --- /dev/null +++ b/data/syn_dataset_definitions/legacy/syn_docvqa-handwritten-examples_seed_based.yaml @@ -0,0 +1,29 @@ +name: "syn-docvqa-handwritten-examples-seed-based" +documents_count: 100 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined1" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "Also include **handwritten textfields**, if the type of document demands it: mark these simply with the HTML class 'handwritten', otherwise apply no specific styles or fonts and treat them as usual text spans. + Analyze the seed images to identify and replicate the primary structural elements, which may include: + * Headers, titles, and document identification + * Main content organization (tables, paragraphs, lists, visual elements) + * Data relationships and hierarchical information + * Labels, captions, and descriptive text + * Numerical data, dates, and reference information + * Visual elements like charts, diagrams, or structured layouts + * Footer information, signatures, or supplementary details + * Any other document-specific organizational patterns observed" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." 
+ gt_format: '{"": "", "": "", ...}' + +seed_images_folder: "docvqa-handwritten-examples" +seed_images_count: 1 +seed_image_max_width: 500 +seed_image_quality: 80 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/syn_docvqa_seed_based.yaml b/data/syn_dataset_definitions/legacy/syn_docvqa_seed_based.yaml new file mode 100755 index 0000000000000000000000000000000000000000..cf894d8a287cf033f6a104801e84785090a81587 --- /dev/null +++ b/data/syn_dataset_definitions/legacy/syn_docvqa_seed_based.yaml @@ -0,0 +1,28 @@ +name: "syn-docvqa-seed-based" +documents_count: 15000 # 10.194 Documents in DocVQA train, 39,461 QA pairs + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined1" +prompt_params: + num_solutions: 3 + doc_type: "business and administrative" + language: "English" + sections: + - "N/A - replicate structural elements observed in seed images" + background_requirements: "white background" + additional_requirements: "Analyze the seed images to identify and replicate the primary structural elements, which may include: + * Headers, titles, and document identification + * Main content organization (tables, paragraphs, lists, visual elements) + * Data relationships and hierarchical information + * Labels, captions, and descriptive text + * Numerical data, dates, and reference information + * Visual elements like charts, diagrams, or structured layouts + * Footer information, signatures, or supplementary details + * Any other document-specific organizational patterns observed" + gt_type: "Up to 4 questions about each document, with their answers taken **verbatim** from the document." + gt_format: '{"": "", "": "", ...}' + +seed_images_folder: "docvqa" +seed_images_count: 10 +seed_image_max_width: 500 +seed_image_quality: 80 \ No newline at end of file diff --git a/data/syn_dataset_definitions/legacy/syn_sroie_seed_based.yaml b/data/syn_dataset_definitions/legacy/syn_sroie_seed_based.yaml new file mode 100755 index 0000000000000000000000000000000000000000..10cd8bdc1b5e8d45448e7e165089d3e3d5f09cab --- /dev/null +++ b/data/syn_dataset_definitions/legacy/syn_sroie_seed_based.yaml @@ -0,0 +1,23 @@ +name: "syn-sroie-seed-based" +documents_count: 600 + +seed_type: "seed-based" # or "seed-free" +prompt_template: "ClaudeRefined1" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + sections: + - "company" + - "date" + - "address" + - "total" + background_requirements: "white background" + additional_requirements: "None" + gt_type: "keys and their values" + gt_format: '{"company": "company value", "date": "date value", "address": "address value", "total": "total value"}' + +seed_images_folder: "sroie" +seed_images_count: 10 +seed_image_max_width: 500 +seed_image_quality: 80 \ No newline at end of file diff --git a/data/syn_dataset_definitions/publaynet_alpha=0.5.yaml b/data/syn_dataset_definitions/publaynet_alpha=0.5.yaml new file mode 100755 index 0000000000000000000000000000000000000000..4fe80d9ab6b792c2aa65c084788007cf63702c54 --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_alpha=0.5.yaml @@ -0,0 +1,33 @@ +name: "publaynet_alpha=0.5" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: 
"English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 # Should have been 5 +embedding_type: image +alpha: 0.5 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_alpha=0.5_v1.yaml b/data/syn_dataset_definitions/publaynet_alpha=0.5_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..c86b6bd8bd993711c1385a8dd4293dc3d80ad83d --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_alpha=0.5_v1.yaml @@ -0,0 +1,33 @@ +name: "publaynet_alpha=0.5_v1" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v1" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 # Should have been 5 +embedding_type: image +alpha: 0.5 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_alpha=0.75.yaml b/data/syn_dataset_definitions/publaynet_alpha=0.75.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7f363a06a904e9e7d2ecad5b8677c07b886b000d --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_alpha=0.75.yaml @@ -0,0 +1,33 @@ +name: "publaynet_alpha=0.75" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 # Should have been 5 +embedding_type: image +alpha: 0.75 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_alpha=0.75_v1.yaml b/data/syn_dataset_definitions/publaynet_alpha=0.75_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..760d1911dd20a864b8d1d5a6d6c961f5440552b9 --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_alpha=0.75_v1.yaml @@ -0,0 +1,33 @@ +name: "publaynet_alpha=0.75_v1" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v1" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 # Should have been 5 +embedding_type: image +alpha: 0.75 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_alpha=1.0.yaml b/data/syn_dataset_definitions/publaynet_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..0840bd9ac7b3e771cc622821f4e03a035262c5ee --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_alpha=1.0.yaml @@ -0,0 +1,39 @@ +name: "publaynet_alpha=1.0" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +label_mapping: + LE-TEXT: text + LE-TITLE: title + LE-TABLE: table + LE-FIGURE: figure + LE-LIST: list +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 # Should have been 5 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_alpha=1.0_v1.yaml b/data/syn_dataset_definitions/publaynet_alpha=1.0_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..25e71a4f983d90d6529e6984dbfa00bb4e5ce6c6 --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_alpha=1.0_v1.yaml @@ -0,0 +1,33 @@ +name: "publaynet_alpha=1.0_v1" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v1" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 # Should have been 5 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.5.yaml b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.5.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d016b2d8976c16a500fbecd8134112a551ca861a --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.5.yaml @@ -0,0 +1,33 @@ +name: "publaynet_correct-sampling_alpha=0.5" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 0.5 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.5_v1.yaml b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.5_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..39002a7736e3235bf95bebdfa4a18d87db2286db --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.5_v1.yaml @@ -0,0 +1,33 @@ +name: "publaynet_correct-sampling_alpha=0.5_v1" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v1" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 0.5 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.75.yaml b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.75.yaml new file mode 100755 index 0000000000000000000000000000000000000000..eb43943862eae3251bfd507029df532e43731234 --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.75.yaml @@ -0,0 +1,33 @@ +name: "publaynet_correct-sampling_alpha=0.75" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 0.75 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.75_v1.yaml b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.75_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..7864fc37f1a1fa59d02c47a1f4ab7bbe41c35c0e --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=0.75_v1.yaml @@ -0,0 +1,33 @@ +name: "publaynet_correct-sampling_alpha=0.75_v1" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v1" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 0.75 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=1.0.yaml b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..95f91c0efdcbc04256798d4faf3a14451d8cfe80 --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=1.0.yaml @@ -0,0 +1,33 @@ +name: "publaynet_correct-sampling_alpha=1.0" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. + gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=1.0_v1.yaml b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=1.0_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..14e9cf44dcacec4178f58fad1b2a5cc2ae27caee --- /dev/null +++ b/data/syn_dataset_definitions/publaynet_correct-sampling_alpha=1.0_v1.yaml @@ -0,0 +1,33 @@ +name: "publaynet_correct-sampling_alpha=1.0_v1" +task: "DLA" +dataloader_model_task_as: +base_dataset_name: "publaynet" +documents_count: 4500 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v1" +seed_images_count: 4 +hdbscan_min_cluster_size: 5 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/rvlcdip.yaml b/data/syn_dataset_definitions/rvlcdip.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e0897cbfe4dc63644d48d20d3429076c5909c1a3 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 10 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/rvlcdip_alpha=0.5.yaml b/data/syn_dataset_definitions/rvlcdip_alpha=0.5.yaml new file mode 100755 index 0000000000000000000000000000000000000000..099934af2e96f6783e6c62786b2607c635bce8c6 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip_alpha=0.5.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip_alpha=0.5" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 4500 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/rvlcdip_alpha=0.5_v1.yaml b/data/syn_dataset_definitions/rvlcdip_alpha=0.5_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..5a9c73a4508b804dbd661dcbc02112494d74f3f6 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip_alpha=0.5_v1.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip_alpha=0.5_v1" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 4500 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + 
- budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.5 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/rvlcdip_alpha=0.75.yaml b/data/syn_dataset_definitions/rvlcdip_alpha=0.75.yaml new file mode 100755 index 0000000000000000000000000000000000000000..306524e3ec99a18c72bd8a23370a2d2e117e50f3 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip_alpha=0.75.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip_alpha=0.75" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 4500 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/rvlcdip_alpha=0.75_v1.yaml b/data/syn_dataset_definitions/rvlcdip_alpha=0.75_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..595c21385ab79a6724b4e54a04eb7b5d5b85e9f9 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip_alpha=0.75_v1.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip_alpha=0.75_v1" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 4500 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 
0.75 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/rvlcdip_alpha=1.0.yaml b/data/syn_dataset_definitions/rvlcdip_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..3b32ab6f79eed8a14e2d624c629d99c2bf12cfc3 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip_alpha=1.0.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip_alpha=1.0" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 4500 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/rvlcdip_alpha=1.0_v1.yaml b/data/syn_dataset_definitions/rvlcdip_alpha=1.0_v1.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d78111012868203f5be22fe285b3aa56d2cf7d56 --- /dev/null +++ b/data/syn_dataset_definitions/rvlcdip_alpha=1.0_v1.yaml @@ -0,0 +1,57 @@ +name: "rvlcdip_alpha=1.0_v1" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "rvlcdip" +documents_count: 4500 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v1" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/sroie.yaml b/data/syn_dataset_definitions/sroie.yaml new file mode 100755 index 0000000000000000000000000000000000000000..0ad1e348176c843296dffc5931517b4aa2d4c9fd --- /dev/null +++ b/data/syn_dataset_definitions/sroie.yaml @@ -0,0 +1,37 @@ +name: "sroie" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "sroie" +documents_count: 50 +valid_labels: + - COMPANY + - DATE + - ADDRESS + - TOTAL +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + keys and their values + * "COMPANY": The company name. + * "DATE": The date on the receipt. 
+ * "ADDRESS": The address of the company. + * "TOTAL": The total amount. + gt_format: '{"COMPANY": "", "DATE": "", "ADDRESS": "", "TOTAL": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: 22.10.2025 | 11.21 USD +# ICVPR: 23.10.2025 | 38.61 USD +# 1950 samples @ 27.4 USD => 1.4 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/sroie_alpha=1.0.yaml b/data/syn_dataset_definitions/sroie_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..1122fc39ac9d8f51714077ce5b866ce536b4985e --- /dev/null +++ b/data/syn_dataset_definitions/sroie_alpha=1.0.yaml @@ -0,0 +1,37 @@ +name: "sroie_alpha=1.0" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "sroie" +documents_count: 1000 +valid_labels: + - COMPANY + - DATE + - ADDRESS + - TOTAL +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + keys and their values + * "COMPANY": The company name. + * "DATE": The date on the receipt. + * "ADDRESS": The address of the company. + * "TOTAL": The total amount. + gt_format: '{"COMPANY": "", "DATE": "", "ADDRESS": "", "TOTAL": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: 22.10.2025 | 11.21 USD +# ICVPR: 23.10.2025 | 38.61 USD +# 1950 samples @ 27.4 USD => 1.4 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/sroie_test.yaml b/data/syn_dataset_definitions/sroie_test.yaml new file mode 100755 index 0000000000000000000000000000000000000000..3b50a0e0bc5fd826fa1f4f45205a9773b78eaddf --- /dev/null +++ b/data/syn_dataset_definitions/sroie_test.yaml @@ -0,0 +1,37 @@ +name: "sroie_test" +task: "KIE" +dataloader_model_task_as: +base_dataset_name: "sroie" +documents_count: 10 +valid_labels: + - COMPANY + - DATE + - ADDRESS + - TOTAL +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + keys and their values + * "COMPANY": The company name. + * "DATE": The date on the receipt. + * "ADDRESS": The address of the company. + * "TOTAL": The total amount. 
+ gt_format: '{"COMPANY": "", "DATE": "", "ADDRESS": "", "TOTAL": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: 22.10.2025 | 11.21 USD +# ICVPR: 23.10.2025 | 38.61 USD +# 1950 samples @ 27.4 USD => 1.4 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/templates/cord.yaml b/data/syn_dataset_definitions/templates/cord.yaml new file mode 100755 index 0000000000000000000000000000000000000000..fb069bd03f06f336deeb77050c121a49cdb96388 --- /dev/null +++ b/data/syn_dataset_definitions/templates/cord.yaml @@ -0,0 +1,200 @@ +name: "cord" +task: "KIE" +base_dataset_name: "cord" +documents_count: 10 +valid_labels: + - MENU_NM + - MENU_NUM + - MENU_UNITPRICE + - MENU_CNT + - MENU_DISCOUNTPRICE + - MENU_PRICE + - MENU_ITEMSUBTOTAL + - MENU_VATYN + - MENU_ETC + - MENU_SUB_NM + - MENU_SUB_UNITPRICE + - MENU_SUB_CNT + - MENU_SUB_PRICE + - MENU_SUB_ETC + - VOID_MENU_NM + - VOID_MENU_PRICE + - SUB_TOTAL_SUBTOTAL_PRICE + - SUB_TOTAL_DISCOUNT_PRICE + - SUB_TOTAL_SERVICE_PRICE + - SUB_TOTAL_OTHERSVC_PRICE + - SUB_TOTAL_TAX_PRICE + - SUB_TOTAL_ETC + - TOTAL_TOTAL_PRICE + - TOTAL_TOTAL_ETC + - TOTAL_CASHPRICE + - TOTAL_CHANGEPRICE + - TOTAL_CREDITCARDPRICE + - TOTAL_EMONEYPRICE + - TOTAL_MENUTYPE_CNT + - TOTAL_MENUQTY_CNT +valid_secondary_labels: + - MENU_1 + - MENU_2 + - MENU_3 + - MENU_4 + - MENU_5 + - MENU_6 + - MENU_7 + - MENU_8 + - MENU_9 + - MENU_10 + - MENU_11 + - MENU_12 + - MENU_13 + - MENU_14 + - MENU_15 + - MENU_16 + - MENU_17 + - MENU_18 + - MENU_19 + - MENU_20 + - MENU_21 + - MENU_22 + - MENU_23 + - MENU_24 + - MENU_25 + - MENU_26 + - MENU_27 + - MENU_28 + - MENU_29 + - MENU_30 + - MENU_31 + - MENU_32 + - MENU_33 + - MENU_34 + - MENU_35 + - MENU_36 + - MENU_37 + - MENU_38 + - MENU_39 + - MENU_40 + - MENU_41 + - MENU_42 + - MENU_43 + - MENU_44 + - MENU_45 + - MENU_46 + - MENU_47 + - MENU_48 + - MENU_49 + - MENU_50 + - MENU_51 + - MENU_52 + - MENU_53 + - MENU_54 + - MENU_55 + - MENU_56 + - MENU_57 + - MENU_58 + - MENU_59 + - MENU_60 + - MENU_61 + - MENU_62 + - MENU_63 + - MENU_64 + - MENU_65 + - MENU_66 + - MENU_67 + - MENU_68 + - MENU_69 + - MENU_70 + - MENU_71 + - MENU_72 + - MENU_73 + - MENU_74 + - MENU_75 + - MENU_76 + - MENU_77 + - MENU_78 + - MENU_79 + - MENU_80 + - MENU_81 + - MENU_82 + - MENU_83 + - MENU_84 + - MENU_85 + - MENU_86 + - MENU_87 + - MENU_88 + - MENU_89 + - MENU_90 + - MENU_91 + - MENU_92 + - MENU_93 + - MENU_94 + - MENU_95 + - MENU_96 + - MENU_97 + - MENU_98 + - MENU_99 + - MENU_100 + - VOID_MENU + - VOID_MENU_1 # the LLM shouldn't do this but does + - VOID_MENU_2 + - VOID_MENU_3 + - VOID_MENU_4 + - VOID_MENU_5 + - VOID_MENU_6 + - VOID_MENU_7 + - VOID_MENU_8 + - VOID_MENU_9 + - VOID_MENU_10 + - GENERIC + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 3 + doc_type: "receipt" + language: "English" + gt_type: | + (if applicable, provide as plaintext values from the document) + // Menu items (multiple menu items are allowed) + * "MENU_NM": The menu item name. + * "MENU_NUM": The menu item number or identifier. + * "MENU_UNITPRICE": The price per unit of the menu item. + * "MENU_CNT": The quantity or count of the menu item. + * "MENU_DISCOUNTPRICE": The discount amount applied to the menu item. + * "MENU_PRICE": The final price of the menu item. + * "MENU_ITEMSUBTOTAL": The subtotal for this menu item line. + * "MENU_VATYN": The VAT indicator (yes/no) for the menu item. 
+ * "MENU_ETC": Other miscellaneous menu item information. + * "MENU_SUB_NM": The name of a sub-item or modifier. + * "MENU_SUB_UNITPRICE": The price per unit of the sub-item. + * "MENU_SUB_CNT": The quantity of the sub-item. + * "MENU_SUB_PRICE": The price of the sub-item. + * "MENU_SUB_ETC": Other sub-item information. + // Menu items that were canceled + * "VOID_MENU_NM": The name of a cancelled or voided item. + * "VOID_MENU_PRICE": The price of the cancelled item. + // Generic receipt data + * "SUB_TOTAL_SUBTOTAL_PRICE": The subtotal before additional charges. + * "SUB_TOTAL_DISCOUNT_PRICE": The total discount amount. + * "SUB_TOTAL_SERVICE_PRICE": The service charge or fee. + * "SUB_TOTAL_OTHERSVC_PRICE": Other service charges. + * "SUB_TOTAL_TAX_PRICE": The tax amount. + * "SUB_TOTAL_ETC": Other subtotal information. + * "TOTAL_TOTAL_PRICE": The final total amount on the receipt. + * "TOTAL_TOTAL_ETC": Other total-related information. + * "TOTAL_CASHPRICE": The amount paid in cash. + * "TOTAL_CHANGEPRICE": The change given back to the customer. + * "TOTAL_CREDITCARDPRICE": The amount paid by credit card. + * "TOTAL_EMONEYPRICE": The amount paid by electronic money or digital payment. + * "TOTAL_MENUTYPE_CNT": The count of different menu item types. + * "TOTAL_MENUQTY_CNT": The total quantity of all items ordered. + gt_format: | + Group individual menu items in groups using the menu item enumerator class MENU_ and a sub-field class from the list above (e.g. "MENU_1 MENU_NM", "MENU_1 MENU_CNT", "MENU_2 MENU_NM", ...). + For void/canceled menu items use the class "VOID_MENU" instead of the enumeration. + For generic receipt data use the class "GENERIC". + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/templates/publaynet.yaml b/data/syn_dataset_definitions/templates/publaynet.yaml new file mode 100755 index 0000000000000000000000000000000000000000..885158bef50feba059b2ca84e94d7d338f833ac4 --- /dev/null +++ b/data/syn_dataset_definitions/templates/publaynet.yaml @@ -0,0 +1,33 @@ +name: "publaynet" +task: "DLA" +base_dataset_name: "publaynet" +documents_count: 20 +valid_labels: + - LE-TEXT + - LE-TITLE + - LE-TABLE + - LE-FIGURE + - LE-LIST +label_mapping: +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "annotation" +prompt_params: + num_solutions: 2 + doc_type: "single A4 pages out of one and two column scientific article" + language: "English" + gt_type: | + * "LE-TEXT": Contains regular body text including paragraphs, abstracts, authors, affiliations, keywords, footnotes, footer, references, and captions for figures and tables. + * "LE-TITLE": Comprises all document titles and headings, article titles as well as standalone section or subsection headings that appear on their own line rather than inline with text. + * "LE-TABLE": Denotes the main body content of tables, excluding captions and labels. + * "LE-FIGURE": Indicates the main visual content of figures and illustrations, with multi-panel figures annotated as complete units rather than individual sub-figures. + * "LE-LIST": Represents enumerated or bulleted list structures, with nested lists annotated as single unified objects. 
+ gt_format: + +seed_selection_strategy: "v2" +seed_images_count: 4 +hdbscan_min_cluster_size: 10 +embedding_type: image +alpha: 1 +max_seed_pool: -1 diff --git a/data/syn_dataset_definitions/templates/rvlcdip.yaml b/data/syn_dataset_definitions/templates/rvlcdip.yaml new file mode 100755 index 0000000000000000000000000000000000000000..45e33017660a53598959979e9f09e489a388774a --- /dev/null +++ b/data/syn_dataset_definitions/templates/rvlcdip.yaml @@ -0,0 +1,54 @@ +name: "rvlcdip" +task: "CLASSIFICATION" +base_dataset_name: "rvlcdip" +documents_count: 10 +valid_labels: + - letter + - form + - email + - handwritten + - advertisement + - scientific report + - scientific publication + - specification + - file folder + - news article + - budget + - invoice + - presentation + - questionnaire + - resume + - memo +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "business correspondence and corporate" + language: "English" + gt_type: | + document class label + * letter + * form + * email + * handwritten + * advertisement + * scientific report + * scientific publication + * specification + * file folder + * news article + * budget + * invoice + * presentation + * questionnaire + * resume + * memo + gt_format: 'JSON object {"label": ""}' + +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 \ No newline at end of file diff --git a/data/syn_dataset_definitions/tobacco3482_alpha=1.0.yaml b/data/syn_dataset_definitions/tobacco3482_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..9c89ba154534e79281ee9fb5c7bb90ff7a3958f3 --- /dev/null +++ b/data/syn_dataset_definitions/tobacco3482_alpha=1.0.yaml @@ -0,0 +1,60 @@ +name: "tobacco3482_alpha=1.0" +task: "CLASSIFICATION" +dataloader_model_task_as: +base_dataset_name: "tobacco3482" +documents_count: 5500 +valid_labels: + - ADVERTISEMENT + - EMAIL + - FORM + - LETTER + - MEMO + - NEWS_ARTICLE + - NOTE + - REPORT + - RESUME + - SCIENTIFIC +label_mapping: + ADVERTISEMENT: ADVE + EMAIL: Email + FORM: Form + LETTER: Letter + MEMO: Memo + NEWS_ARTICLE: News + NOTE: Note + REPORT: Report + RESUME: Resume + SCIENTIFIC: Scientific + +valid_secondary_labels: + +prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "legal and corporate" + language: "English" + gt_type: | + document class labels: + * ADVERTISEMENT: Advertisement + * EMAIL: Email + * FORM: Form + * LETTER: Letter + * MEMO: Memo + * NEWS_ARTICLE: News article + * NOTE: Note/handwritten note + * REPORT: Report + * RESUME: Resume/CV + * SCIENTIFIC: Scientific publication + gt_format: 'JSON object {"label": ""}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 10 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 + +# ICVPR: start | 38.61 USD +# ICVPR: end | 50.37 USD +# 936 samples @ 11.76 USD => 1.25 ct/doc \ No newline at end of file diff --git a/data/syn_dataset_definitions/wtq_alpha=1.0.yaml b/data/syn_dataset_definitions/wtq_alpha=1.0.yaml new file mode 100755 index 0000000000000000000000000000000000000000..d03cc747e95624906ca2e0bdb4c8e7f5abf3dba9 --- /dev/null +++ b/data/syn_dataset_definitions/wtq_alpha=1.0.yaml @@ -0,0 +1,30 @@ +name: "wtq_alpha=1.0" +task: "QA" +dataloader_model_task_as: +base_dataset_name: "ex_wiki" +documents_count: 1600 # 1600 (1400 + 200 margin of error) +valid_labels: +label_mapping: +valid_secondary_labels: + 
+prompt_template: "ClaudeRefined12" +prompt_task: "json" +prompt_params: + num_solutions: 3 + doc_type: "semi-structures table" + language: "English" + gt_type: | + Multiple complex question-answer pairs in everyday language that can be answered from the associated table, with their answers taken **verbatim** from the document. + Common Question Types: + * Lookup: Finding specific cell values ("What is the capital of France?") + * Aggregation: Counting, summing, averaging ("How many players scored over 20 points?") + * Comparison: Finding max/min ("Which country has the largest population?") + * Reasoning: Requiring multiple steps ("What team did the highest scorer play for?") + gt_format: '{"": "", "": "", ...}' + +seed_selection_strategy: "v2" +seed_images_count: 6 +hdbscan_min_cluster_size: 5 +embedding_type: combined +alpha: 1 +max_seed_pool: -1 diff --git a/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_10PM (1).png b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_10PM (1).png new file mode 100755 index 0000000000000000000000000000000000000000..aa97cb337fdcd413404851c8e4d2cb7f97a32e39 --- /dev/null +++ b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_10PM (1).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8862d479ae51472629b63424e6786a6ee0affd0b46c96dff3cc2489d6fdfa85e +size 1210373 diff --git a/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_10PM.png b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_10PM.png new file mode 100755 index 0000000000000000000000000000000000000000..785f4fca8497adc0fc39aed20fc6ded7461fb82c --- /dev/null +++ b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_10PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bf61a520590255d3b96c005b62d52a60b8135fbd3efa6a68a3b8289a865e9217 +size 1435185 diff --git a/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_11PM (1).png b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_11PM (1).png new file mode 100755 index 0000000000000000000000000000000000000000..70d255d6e0cc15d9d7e46880b04e124dae035e47 --- /dev/null +++ b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_11PM (1).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:030bd85435de1b77e07a5ce579686fa807e57195914d84c438de469fb2506948 +size 1537276 diff --git a/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_11PM.png b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_11PM.png new file mode 100755 index 0000000000000000000000000000000000000000..9ac0c82944c8810570cadd00c423d117bf5e11ff --- /dev/null +++ b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_11PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f871d775be655d255086af8ea04730c235d4040ce4e7503de98f16205ad8a373 +size 1693736 diff --git a/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_12PM.png b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_12PM.png new file mode 100755 index 0000000000000000000000000000000000000000..bc618c88575a5f082debc36aeb90ee8c884b3516 --- /dev/null +++ b/data/visual_element_prefabs/figure/Generated Image October 27, 2025 - 9_12PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:1177692955eed1930228849a22197ffa6e93d7c393a82e40dfe584d16b0bcecf +size 1470095 diff --git a/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_13PM.png b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_13PM.png new file mode 100755 index 0000000000000000000000000000000000000000..c372c616b1709eee531374257ed7913b108b2de2 --- /dev/null +++ b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_13PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1de46e8129861a4e753e3a6c9b10b00c8d32c3e69a3b5ff78bc2fbf6b0d86b6f +size 1562926 diff --git a/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_14PM (1).png b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_14PM (1).png new file mode 100755 index 0000000000000000000000000000000000000000..4b3680af82196261833234f0c5d28ec8e4662fb1 --- /dev/null +++ b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_14PM (1).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04f9cb6b08eff876ba35311dfb09a1ef7994219f09997064138570d59acf7ded +size 1195767 diff --git a/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_14PM.png b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_14PM.png new file mode 100755 index 0000000000000000000000000000000000000000..95547d755a56866ad5afe728506dc399800be2d1 --- /dev/null +++ b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_14PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c6fc569379b9e99d61554940cbea022af613b1cba910d6731d23e83f9929403f +size 1344638 diff --git a/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_16PM (2).png b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_16PM (2).png new file mode 100755 index 0000000000000000000000000000000000000000..f4ae0402577db8aaa501c6f0518ef1b32ae50217 --- /dev/null +++ b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_16PM (2).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a71156800daddbf8fd1f7bfc8d5827afc24df0b080750567bcba865569d14fed +size 1439178 diff --git a/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_16PM.png b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_16PM.png new file mode 100755 index 0000000000000000000000000000000000000000..e0534c57bebd1df59c618b1dcba988d9a4a23388 --- /dev/null +++ b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_16PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d8f133f7ec17a2f0eaf58b96c941c0d420b282edc07673ff6f5b03e8dcdd7c7f +size 1261549 diff --git a/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_20PM.png b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_20PM.png new file mode 100755 index 0000000000000000000000000000000000000000..8ac33ecef7a424b0e7c93abbe737a68bd144ddf0 --- /dev/null +++ b/data/visual_element_prefabs/logo/Generated Image October 27, 2025 - 9_20PM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:205f1444edf957d3301ff43ff3fb8bbfc79210c3b84de1322269cbf77dc7fa71 +size 1697350 diff --git a/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_36AM (2).png b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_36AM (2).png new file mode 100755 index 
0000000000000000000000000000000000000000..75f45acc15292c3628d3d0f621f01d4b5d28a42f --- /dev/null +++ b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_36AM (2).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b3467d58ae9d8f2b4bdca834915ac2245096d1b673aaffb1d504246cd9bc67c9 +size 2127546 diff --git a/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_38AM (1).png b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_38AM (1).png new file mode 100755 index 0000000000000000000000000000000000000000..2c36cf358e7a7732a638f6f8cc436cd5985c50d8 --- /dev/null +++ b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_38AM (1).png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a91e2d887b5f3bc49ea5c98c86e11bbdb806234b6716e2970a5926562e8e6be5 +size 2180085 diff --git a/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_38AM.png b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_38AM.png new file mode 100755 index 0000000000000000000000000000000000000000..49003d4dba70cac7bb02ed72dddc89af277e0842 --- /dev/null +++ b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_38AM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1cb3a5912c28948e49a59d1b61886c9ba113371531a6e25de50efec70b075c74 +size 2096141 diff --git a/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_39AM.png b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_39AM.png new file mode 100755 index 0000000000000000000000000000000000000000..0b5981fde9eca21d21e95bfef7d8cb2c2e54f9b4 --- /dev/null +++ b/data/visual_element_prefabs/photo/Generated Image October 28, 2025 - 12_39AM.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00e880cf6d316a7f46f8505b8e351ca8d21e360d93ca738f339bbd5f27d979b1 +size 2241792 diff --git a/data/visual_element_prefabs/photo/photo1.jpg b/data/visual_element_prefabs/photo/photo1.jpg new file mode 100755 index 0000000000000000000000000000000000000000..c1aa252ccf87453a9a0f3bbbcd333453aa33dd11 --- /dev/null +++ b/data/visual_element_prefabs/photo/photo1.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d9f21fa89f133ca73b40e9b6051b032b7ca0a69ff69a6f0ec160fda62bbbfacd +size 547596 diff --git a/data/visual_element_prefabs/photo/photo2.jpg b/data/visual_element_prefabs/photo/photo2.jpg new file mode 100755 index 0000000000000000000000000000000000000000..2b71d7099b95ab6dd2873d01c9f4e29ee383a5b9 --- /dev/null +++ b/data/visual_element_prefabs/photo/photo2.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c55f123f45626687bea3d54bb12353447c45cce7264a13f9250610fa506209cd +size 590629 diff --git a/data/visual_element_prefabs/photo/photo3.jpg b/data/visual_element_prefabs/photo/photo3.jpg new file mode 100755 index 0000000000000000000000000000000000000000..56cfdfdc993cf167c2c49b2eadc5ddfbbd24c2db --- /dev/null +++ b/data/visual_element_prefabs/photo/photo3.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:632ae0175f03eaaf19315a250aa6fb1af62eb4e331e3185967a3348bbd04394a +size 550545 diff --git a/data/visual_element_prefabs/photo/photo4.jpg b/data/visual_element_prefabs/photo/photo4.jpg new file mode 100755 index 0000000000000000000000000000000000000000..fbd37c4d57816e3e183e4da832ef2362cba5a83c --- /dev/null +++ b/data/visual_element_prefabs/photo/photo4.jpg @@ -0,0 +1,3 @@ 
+version https://git-lfs.github.com/spec/v1 +oid sha256:2c298f261a7b24aadf96b63fc4cf560a1b69c236f884a58d11cf8e3dc8eed64d +size 552578 diff --git a/data/visual_element_prefabs/photo/photo5.jpg b/data/visual_element_prefabs/photo/photo5.jpg new file mode 100755 index 0000000000000000000000000000000000000000..da5fd64d53b6a2be84f4f9f1df4088b2f7c34433 --- /dev/null +++ b/data/visual_element_prefabs/photo/photo5.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7df040dd747bc9b53e195a71a06171199c57a57b85e10c84c6caaf6613c4defe +size 527274 diff --git a/deploy.sh b/deploy.sh new file mode 100755 index 0000000000000000000000000000000000000000..a79c88db58c289ddb028194da1b432c3c63e8472 --- /dev/null +++ b/deploy.sh @@ -0,0 +1,193 @@ +#!/bin/bash +# ============================================ +# DocGenie Deployment Helper Script +# ============================================ +# Quick deployment script for Railway + RunPod + +set -e # Exit on error + +echo "🚀 DocGenie Deployment Helper" +echo "==============================" +echo "" + +# Colors +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Function to print colored messages +print_success() { + echo -e "${GREEN}✓ $1${NC}" +} + +print_error() { + echo -e "${RED}✗ $1${NC}" +} + +print_info() { + echo -e "${YELLOW}ℹ $1${NC}" +} + +# Check prerequisites +echo "Checking prerequisites..." + +# Check if Docker is installed +if ! command -v docker &> /dev/null; then + print_error "Docker is not installed. Please install Docker first." + exit 1 +fi +print_success "Docker installed" + +# Check if .env exists +if [ ! -f "api/.env" ]; then + print_error "api/.env file not found. Please create it first." + exit 1 +fi +print_success "Environment file found" + +# Menu +echo "" +echo "Select deployment option:" +echo "1) Build Handwriting Service Docker image" +echo "2) Push Handwriting Service to Docker Hub" +echo "3) Deploy API to Railway" +echo "4) Run local test environment (docker-compose)" +echo "5) Full deployment (Handwriting + API)" +echo "0) Exit" +echo "" +read -p "Enter option (0-5): " option + +case $option in + 1) + echo "" + print_info "Building Handwriting Service Docker image..." + + # Build image + cd handwriting_service + docker buildx build --platform linux/amd64 \ + -t docgenie-handwriting:latest \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + . + + print_success "Image built successfully" + print_info "Tag: docgenie-handwriting:latest" + ;; + + 2) + echo "" + read -p "Enter your Docker Hub username: " docker_username + + print_info "Tagging image for Docker Hub..." + docker tag docgenie-handwriting:latest ${docker_username}/docgenie-handwriting:latest + + print_info "Pushing to Docker Hub..." + docker push ${docker_username}/docgenie-handwriting:latest + + print_success "Image pushed successfully" + print_info "Deploy this on RunPod: ${docker_username}/docgenie-handwriting:latest" + ;; + + 3) + echo "" + print_info "Deploying API to Railway..." + + # Check if Railway CLI is installed + if ! command -v railway &> /dev/null; then + print_error "Railway CLI not installed. Installing..." + npm i -g @railway/cli + fi + + # Deploy + railway up + + print_success "API deployed to Railway" + print_info "View logs: railway logs" + print_info "View URL: railway open" + ;; + + 4) + echo "" + print_info "Starting local test environment..." 
+ print_info "This will start: Redis, API, Worker, Handwriting Service" + + # Check if GPU is available + if command -v nvidia-smi &> /dev/null; then + print_info "GPU detected, using CUDA" + docker-compose up + else + print_info "No GPU detected, using CPU for handwriting service" + DEVICE=cpu docker-compose up + fi + ;; + + 5) + echo "" + print_info "Full deployment starting..." + + # Step 1: Build handwriting image + print_info "Step 1/4: Building Handwriting Service..." + cd handwriting_service + docker buildx build --platform linux/amd64 \ + -t docgenie-handwriting:latest \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + . + cd .. + print_success "Handwriting image built" + + # Step 2: Push to Docker Hub + echo "" + read -p "Enter your Docker Hub username: " docker_username + print_info "Step 2/4: Pushing to Docker Hub..." + docker tag docgenie-handwriting:latest ${docker_username}/docgenie-handwriting:latest + docker push ${docker_username}/docgenie-handwriting:latest + print_success "Image pushed" + + # Step 3: Deploy to RunPod (manual) + echo "" + print_info "Step 3/4: Deploy to RunPod (manual step)" + print_info "1. Go to https://runpod.io → Serverless → New Endpoint" + print_info "2. Use image: ${docker_username}/docgenie-handwriting:latest" + print_info "3. Select GPU: RTX 4090 or A40" + print_info "4. Set port: 8080" + print_info "5. Set env: DEVICE=cuda" + read -p "Press Enter when RunPod deployment is complete..." + + # Step 4: Get RunPod URL and deploy API + echo "" + read -p "Enter your RunPod endpoint URL: " runpod_url + + print_info "Step 4/4: Deploying API to Railway..." + + # Set HANDWRITING_SERVICE_URL + export HANDWRITING_SERVICE_URL=$runpod_url + + # Deploy to Railway + if ! command -v railway &> /dev/null; then + print_error "Railway CLI not installed. Installing..." + npm i -g @railway/cli + fi + + railway up + + print_success "Full deployment complete!" + echo "" + print_info "Next steps:" + print_info "1. Set HANDWRITING_SERVICE_URL in Railway dashboard" + print_info "2. railway variables set HANDWRITING_SERVICE_URL=$runpod_url" + print_info "3. Test: curl https://your-domain.up.railway.app/health" + ;; + + 0) + echo "Goodbye!" + exit 0 + ;; + + *) + print_error "Invalid option" + exit 1 + ;; +esac + +echo "" +print_success "Done!" 
diff --git a/docgenie/__init__.py b/docgenie/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..56bc87e9e06dec489b2e81b897c8125ca713ec3a --- /dev/null +++ b/docgenie/__init__.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from enum import Enum +from pathlib import Path + +_root_path = Path(__file__).parent.parent.resolve() + + +# Project paths +class ENV: + # General + ROOT_DIR: Path = _root_path + DATA_DIR: Path = ROOT_DIR / "data" + + DATASETS_DIR: Path = ROOT_DIR / "data" / "datasets" + BASE_DATASETS_DIR: Path = DATASETS_DIR / "base_v2" + SYN_DATASETS_PREPARED_DIR: Path = DATASETS_DIR / "synthesized_prepared" + SYN_DATASETS_DIR: Path = DATASETS_DIR / "synthesized_datasets" + + VISUAL_ELEMENT_PREFABS_DIR: Path = DATA_DIR / "visual_element_prefabs" + + EMBEDDINGS_DIR: Path = DATA_DIR / "embeddings" + GT_EMBEDDINGS_DIR: Path = DATA_DIR / "gt_embeddings" + CLUSTERS_DIR: Path = DATA_DIR / "clusters" + CLUSTER_PLOTS: Path = DATA_DIR / "cluster_plots" + SYN_DATASET_STAT_PLOTS: Path = DATA_DIR / "syn_dataset_statistics_plots" + + ANALYZATION_DIR: Path = DATA_DIR / "analyzation" + GT_ANALYZATION_DIR: Path = ANALYZATION_DIR / "gt" + KIE_GT_ANALYZATION_DIR: Path = GT_ANALYZATION_DIR / "kie" + CLS_GT_ANALYZATION_DIR: Path = GT_ANALYZATION_DIR / "cls" + QA_GT_ANALYZATION_DIR: Path = GT_ANALYZATION_DIR / "qa" + DLA_GT_ANALYZATION_DIR: Path = GT_ANALYZATION_DIR / "dla" + + WEBAPP_CACHE_DIR: Path = DATA_DIR / "webapp_cache" + QA_GT_WEBAPP_CACHE_DIR: Path = WEBAPP_CACHE_DIR / "qa_gt" + + TEMP_DIR: Path = DATA_DIR / "temp" + + MODELS_DIR: Path = DATA_DIR / "models" + RUNS_DIR: Path = DATA_DIR / "runs" + + EXPORTS_DIR: Path = DATA_DIR / "exports" + + # Contains combined datasets (original and synthetic) + PREPARED_DATASETS_DIR: Path = DATASETS_DIR / "prepared" + + SYN_DATA_DEFINITIONS_DIR: Path = DATA_DIR / "syn_dataset_definitions" + PROMPT_TEMPLATES_DIR: Path = DATA_DIR / "prompt_templates" + SEED_IMAGES_DIR: Path = DATA_DIR / "seed-images" + + +ENV.BASE_DATASETS_DIR.mkdir(parents=True, exist_ok=True) +ENV.SYN_DATASETS_DIR.mkdir(parents=True, exist_ok=True) +ENV.SYN_DATASETS_PREPARED_DIR.mkdir(parents=True, exist_ok=True) +ENV.VISUAL_ELEMENT_PREFABS_DIR.mkdir(parents=True, exist_ok=True) +ENV.PREPARED_DATASETS_DIR.mkdir(parents=True, exist_ok=True) +ENV.EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) +ENV.CLUSTERS_DIR.mkdir(parents=True, exist_ok=True) +ENV.TEMP_DIR.mkdir(parents=True, exist_ok=True) +ENV.MODELS_DIR.mkdir(parents=True, exist_ok=True) +ENV.EXPORTS_DIR.mkdir(parents=True, exist_ok=True) +ENV.CLUSTER_PLOTS.mkdir(parents=True, exist_ok=True) +ENV.SYN_DATASET_STAT_PLOTS.mkdir(parents=True, exist_ok=True) +ENV.GT_EMBEDDINGS_DIR.mkdir(parents=True, exist_ok=True) +ENV.KIE_GT_ANALYZATION_DIR.mkdir(parents=True, exist_ok=True) +ENV.CLS_GT_ANALYZATION_DIR.mkdir(parents=True, exist_ok=True) +ENV.DLA_GT_ANALYZATION_DIR.mkdir(parents=True, exist_ok=True) +ENV.QA_GT_ANALYZATION_DIR.mkdir(parents=True, exist_ok=True) +ENV.QA_GT_WEBAPP_CACHE_DIR.mkdir(parents=True, exist_ok=True) + + +class LLM: + CLAUDE_SONNET_4 = "claude-sonnet-4-20250514" + CLAUDE_SONNET_4_5 = "claude-sonnet-4-5-20250929" + CLAUDE_HAIKU_4_5 = "claude-haiku-4-5-20251001" + TINYLLM_CLAUDE_SONNET_4 = "anthropic/claude-sonnet-4-20250514" + + +# Default values for generation +class GENERATION: + LLM = LLM.CLAUDE_SONNET_4_5 + MAX_TOKENS = 16384 + HANDWRITING_MODEL_CHECKPOINT = ENV.MODELS_DIR / "handwriting" / "latest.pt" diff --git a/docgenie/analyzation/clustering/cmds/generate_clusters.py
b/docgenie/analyzation/clustering/cmds/generate_clusters.py new file mode 100755 index 0000000000000000000000000000000000000000..a4fdbe9de2a22450293148638779e9c958f0127a --- /dev/null +++ b/docgenie/analyzation/clustering/cmds/generate_clusters.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from pathlib import Path + +import pydantic.v1 as pydantic +import pydantic_argparse + +from docgenie import ENV +from docgenie.analyzation.clustering.core._metrics import calculate_cluster_statistics +from docgenie.analyzation.clustering.core._utilities import ( + EmbeddingType, + _save_clustering_metrics, +) +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +def main(cfg: ClusteringConfig): + """ + Generate clusters for all embedding types and save results. + """ + + import numpy as np + + from docgenie.analyzation.clustering.core._algorithms import ( + _read_and_cluster_embeddings, + ) + from docgenie.analyzation.clustering.core._metrics import ( + evaluate_clusters_unsupervised, + ) + from docgenie.analyzation.clustering.core._utilities import ( + _get_clustering_output_path, + ) + + logger.info(f"Clustering with config:\n{cfg}") + + for embedding_type in EmbeddingType.__members__.values(): + logger.info(f"Generating clusters for {embedding_type.value=}") + + # see if embeddings exist + embeddings_path = ( + Path(cfg.embeddings_dir) / cfg.dataset_name / (f"{embedding_type.value}.h5") + ) + if not embeddings_path.exists(): + logger.warning( + f"Embeddings not found for {cfg.dataset_name} at {embeddings_path}, skipping..." + ) + continue + + # save cluster labels + output_dir = Path(cfg.output_dir) / cfg.dataset_name / embedding_type.value + clusters_path = _get_clustering_output_path( + output_dir=output_dir, + intermediate_num_dims=cfg.intermediate_num_dims, + hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size, + hdbscan_metric=cfg.hdbscan_metric, + k_nn_n_neighbors=cfg.k_nn_n_neighbors, + do_knn=cfg.do_knn, + method=cfg.method, + ) + + if not Path(clusters_path).exists(): + outputs = _read_and_cluster_embeddings( + embeddings_dir=cfg.embeddings_dir, + dataset_name=cfg.dataset_name, + embedding_type=embedding_type, + intermediate_num_dims=cfg.intermediate_num_dims, + hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size, + hdbscan_metric=cfg.hdbscan_metric, + k_nn_n_neighbors=cfg.k_nn_n_neighbors, + seed=cfg.seed, + do_knn=cfg.do_knn, + cache_dir=output_dir, + method=cfg.method, + ) + + logger.info(f"Saving clusters to {clusters_path}...") + Path(clusters_path).parent.mkdir(parents=True, exist_ok=True) + np.save( + clusters_path, + outputs, + ) + cluster_labels = outputs["cluster_labels"] + num_noise = outputs.get("num_noise", 0) + embeddings_reduced_dim = outputs["embeddings_reduced_dim"] + else: + logger.info(f"Loading existing clusters from {clusters_path}...") + cluster_results = np.load(clusters_path, allow_pickle=True).item() + cluster_labels = cluster_results["cluster_labels"] + num_noise = cluster_results["num_noise"] + embeddings_reduced_dim = cluster_results["embeddings_reduced_dim"] + + # compute cluster statistics + cluster_stats = calculate_cluster_statistics( + embeddings_reduced_dim, cluster_labels + ) + cluster_stats.to_csv( + clusters_path.parent / clusters_path.name.replace(".npy", "_stats.csv"), + index=False, + ) + + # compute metrics + cluster_metrics, num_clusters = evaluate_clusters_unsupervised( + embeddings=embeddings_reduced_dim, cluster_labels=cluster_labels + ) + + # save metrics + _save_clustering_metrics( + 
output_dir=cfg.output_dir, + dataset_name=cfg.dataset_name, + hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size, + intermediate_num_dims=cfg.intermediate_num_dims, + hdbscan_metric=cfg.hdbscan_metric, + k_nn_n_neighbors=cfg.k_nn_n_neighbors, + method=cfg.method, + embedding_type=embedding_type, + embeddings=embeddings_reduced_dim, + cluster_metrics=cluster_metrics, + num_clusters=num_clusters, + num_noise=num_noise, + seed=cfg.seed, + do_knn=cfg.do_knn, + ) + + +class ClusteringConfig(pydantic.BaseModel): + """ + Configuration for clustering operations. + """ + + dataset_name: str + seed: int = 42 + hdbscan_min_cluster_size: int = 10 + intermediate_num_dims: int = 100 + hdbscan_metric: str = "euclidean" + do_knn: bool = True + k_nn_n_neighbors: int = 5 + embeddings_dir: str | Path = ENV.EMBEDDINGS_DIR + output_dir: str | Path = ENV.CLUSTERS_DIR + method: str = "hdbscan" # or "kmeans" + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=ClusteringConfig, + ) + main(parser.parse_typed_args()) diff --git a/docgenie/analyzation/clustering/cmds/generate_embeddings.py b/docgenie/analyzation/clustering/cmds/generate_embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..ee9a8ec2000831a99453c5e3d602775e9cf9917c --- /dev/null +++ b/docgenie/analyzation/clustering/cmds/generate_embeddings.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from pathlib import Path + +import pydantic.v1 as pydantic +import pydantic_argparse + +from docgenie import ENV +from docgenie.analyzation.clustering.core._embeddings import ( + _load_sample_ids_from_embeddings, + _save_embeddings, + embedding_extraction_with_cache, +) +from docgenie.analyzation.clustering.core._utilities import EmbeddingType +from docgenie.data._core._utilities import TaskType +from docgenie.data.interface import load_data_pipeline, load_preprocessed_data_pipeline +from docgenie.evaluation.utils import get_device +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +class GenerateEmbeddingsConfig(pydantic.BaseModel): + """ + Configuration for generating embeddings. + """ + + dataset_name: str + is_synth: bool = False + output_dir: str = ENV.EMBEDDINGS_DIR + kernel_size: int = 4 + split: str = "train" + batch_size: int = 16 + dataloader_num_workers: int = 8 + use_preprocessed: bool = False + verify_only: bool = False + is_synthetic: bool = False + + +def main(cfg: GenerateEmbeddingsConfig): + # setup data pipeline and dataloaders with preprocessing + # this will save preprocessed msgpacks + if cfg.use_preprocessed: + data_pipeline = load_preprocessed_data_pipeline( + dataset_name=cfg.dataset_name, + is_synthetic=cfg.is_synth, + task_type=TaskType.generate_embeddings, + split=cfg.split, + ) + else: + data_pipeline = load_data_pipeline( + dataset_name=cfg.dataset_name, + is_synthetic=cfg.is_synth, + task_type=TaskType.generate_embeddings, + split=cfg.split, + ) + + if cfg.verify_only: + output_dir = Path(cfg.output_dir) / cfg.dataset_name + sample_ids_per_type = {} + for embedding_type in list(EmbeddingType): + cache_file = Path(output_dir) / f"{embedding_type.value}.h5" + if not cache_file.exists(): + logger.warning( + f"Cache file {cache_file} does not exist. Please run the script " + "without --verify_only to generate embeddings." + ) + continue + sample_ids = _load_sample_ids_from_embeddings(cache_file) + logger.info( + f"Cache file {cache_file} exists with {len(sample_ids)} samples."
+ ) + sample_ids_per_type[embedding_type.value] = sample_ids + + # make sure sample ids are the same across all types + sample_ids = sample_ids_per_type[ + sample_ids_per_type.keys().__iter__().__next__() + ] + for embedding_type, ids in sample_ids_per_type.items(): + assert ids == sample_ids, ( + f"Sample IDs for {embedding_type} do not match those for " + f"{EmbeddingType.layout.value}" + ) + + logger.info(f"All cache files exist for dataset {cfg.dataset_name}.") + return + + # print dataset info + logger.info(data_pipeline.dataset) + + # setup dataloader + dataloader = data_pipeline.dataloader( + split=cfg.split, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.dataloader_num_workers, + ) + + # check whether batch in the dataset has ocr content + batch = next(iter(dataloader)) + has_ocr_content = batch.words is not None + + output_dir = Path(cfg.output_dir) / cfg.dataset_name + embeddings_per_type = {} + sample_ids_per_type = {} + for embedding_type in list(EmbeddingType): + if embedding_type == EmbeddingType.combined: + continue + if ( + embedding_type + in [EmbeddingType.layout, EmbeddingType.text, EmbeddingType.paper] + and not has_ocr_content + ): + logger.warning( + f"Skipping {embedding_type.value} embeddings for dataset {cfg.dataset_name} " + "as it does not have OCR content." + ) + continue + embeddings, sample_ids = embedding_extraction_with_cache( + dataloader=dataloader, + output_dir=output_dir, + embedding_type=embedding_type, + device=get_device(), + ) + embeddings_per_type[embedding_type.value] = embeddings + sample_ids_per_type[embedding_type.value] = sample_ids + logger.info( + f"Generated {embedding_type.value} embeddings for {len(sample_ids)} samples." + ) + + # make sure sample ids are the same across all types + sample_ids = sample_ids_per_type[sample_ids_per_type.keys().__iter__().__next__()] + print("Sample ids of first 10 samples: ", sample_ids[:10]) + for embedding_type, ids in sample_ids_per_type.items(): + assert ids == sample_ids, ( + f"Sample IDs for {embedding_type} do not match those for " + f"{EmbeddingType.layout.value}" + ) + + if not has_ocr_content: + logger.warning( + f"Skipping {EmbeddingType.combined.value} embeddings for dataset {cfg.dataset_name} " + "as it does not have OCR content." + ) + return + cache_file = Path(output_dir) / f"{EmbeddingType.combined.value}.h5" + if not cache_file.exists(): + import numpy as np + from sklearn.preprocessing import StandardScaler + + embeddings_per_type = { + k: StandardScaler().fit_transform(v) for k, v in embeddings_per_type.items() + } + + combined_embeddings = np.hstack( + [ + v + for k, v in embeddings_per_type.items() + if k + in [ + EmbeddingType.layout.value, + EmbeddingType.text.value, + EmbeddingType.image.value, + ] + ] + ) + + logger.info( + f"Generated {EmbeddingType.combined.value} embeddings for {len(sample_ids)} samples." 
+ ) + _save_embeddings( + embeddings=combined_embeddings, + sample_ids=sample_ids, + file_path=Path(output_dir) / f"{EmbeddingType.combined.value}.h5", + ) + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=GenerateEmbeddingsConfig, + ) + main(parser.parse_typed_args()) diff --git a/docgenie/analyzation/clustering/cmds/generate_seeds.py b/docgenie/analyzation/clustering/cmds/generate_seeds.py new file mode 100755 index 0000000000000000000000000000000000000000..fb4920c6934e2dcd5e1eb36e4595b2d69a11df16 --- /dev/null +++ b/docgenie/analyzation/clustering/cmds/generate_seeds.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import pydantic.v1 as pydantic +import pydantic_argparse +import tqdm + +from docgenie import ENV +from docgenie.analyzation.clustering.core._embeddings import ( + _load_sample_ids_from_embeddings, +) +from docgenie.analyzation.clustering.core._utilities import ( + EmbeddingType, + _get_clustering_output_path, + _visualize_images_grid, +) +from docgenie.logging import get_logger + +if TYPE_CHECKING: + import numpy as np + + +logger = get_logger(__name__) + + +def alpha_cluster_sampling_create_pool( + cluster_labels: np.ndarray, + max_seed_pool: int = -1, +) -> np.ndarray: + """ + Create a pool of candidate seed images for LLM prompt construction. + + The pool is sampled **proportional to cluster sizes**, ensuring that + each cluster is represented at least once if possible. This prevents + small clusters from being entirely excluded from the pool. + + Args: + cluster_labels: np.ndarray of cluster labels for all samples. + max_seed_pool: int, maximum number of seed images to select for the pool. + - If -1 or larger than the dataset, the full dataset is used. + + Returns: + np.ndarray: indices of samples included in the pool. 
+ """ + n_samples = len(cluster_labels) + unique_labels = np.unique(cluster_labels) + + # Use full dataset if max_seed_pool is -1 or larger than dataset + if max_seed_pool == -1 or max_seed_pool >= n_samples: + return np.arange(n_samples) + + # Step 1: guarantee one sample per cluster + guaranteed_indices = [ + np.random.choice(np.where(cluster_labels == label)[0]) + for label in unique_labels + ] + + remaining = max_seed_pool - len(guaranteed_indices) + if remaining <= 0: + # pool is smaller than number of clusters: return guaranteed samples + return np.array(guaranteed_indices) + + # Step 2: sample remaining indices proportional to cluster sizes + cluster_sizes = { + label: np.sum(cluster_labels == label).item() for label in unique_labels + } + cluster_prob = { + label: size / sum(cluster_sizes.values()) + for label, size in cluster_sizes.items() + } + doc_prob = np.array([cluster_prob[cluster_labels[i]] for i in range(n_samples)]) + + # Exclude guaranteed indices + available_indices = np.setdiff1d(np.arange(n_samples), guaranteed_indices) + available_prob = doc_prob[available_indices] + available_prob = available_prob / available_prob.sum() + + sampled_remaining = np.random.choice( + available_indices, size=remaining, replace=False, p=available_prob + ) + + pool_indices = np.concatenate([guaranteed_indices, sampled_remaining]) + return pool_indices + + +def alpha_cluster_sampling_pool( + cluster_labels: np.ndarray, + total_seeds: int, + pool_indices: np.ndarray, + alpha: float = 1.0, + seed_selection_strategy: str = "v1", +) -> tuple[list[int], list[int]]: + """ + Sample seeds from a pool using two-stage alpha-based cluster probabilities: + 1) Pick a cluster based on alpha weighting + 2) Pick a random sample from that cluster + + Args: + cluster_labels: np.ndarray of cluster labels for all samples + total_seeds: number of seeds to sample + pool_indices: available sample indices + alpha: exponent for cluster weighting + - alpha=1 -> proportional to cluster size + - alpha=0 -> uniform across clusters + - alpha<0 -> inverse-proportional to cluster size + + Returns: + Tuple of (sampled_indices, sampled_clusters) + """ + pool_labels = cluster_labels[pool_indices] + unique_labels = np.unique(pool_labels) + + # Compute cluster sizes in pool + cluster_sizes = { + label: np.sum(pool_labels == label).item() for label in unique_labels + } + + # Compute alpha-weighted cluster probabilities + cluster_probs = np.array([cluster_sizes[label] ** alpha for label in unique_labels]) + cluster_probs = cluster_probs / cluster_probs.sum() + + if seed_selection_strategy == "v1": + sampled_indices = [] + sampled_clusters = [] + + for _ in range(total_seeds): + # 1) Pick a cluster according to alpha probabilities + cluster = np.random.choice(unique_labels, p=cluster_probs) + + # 2) Pick a random sample from that cluster in the pool + cluster_pool_indices = pool_indices[pool_labels == cluster] + sample = np.random.choice(cluster_pool_indices) + + sampled_indices.append(int(sample)) + sampled_clusters.append(int(cluster)) + + return sampled_indices, sampled_clusters + + elif seed_selection_strategy == "v2": + sampled_indices = [] + sampled_clusters = [] + + cluster = np.random.choice(unique_labels, p=cluster_probs) + for _ in range(total_seeds): + cluster_pool_indices = pool_indices[pool_labels == cluster] + sample = np.random.choice(cluster_pool_indices) + + sampled_indices.append(int(sample)) + sampled_clusters.append(int(cluster)) + + return sampled_indices, sampled_clusters + else: + raise ValueError(f"Unknown 
seed selection strategy: {seed_selection_strategy}") + + +def generate_seeds_for_embedding_type( + cfg: GenerateSeedsConfig, embedding_type: EmbeddingType +) -> tuple[Path, Path]: + import random + + import numpy as np + import pandas as pd + + # set seed + np.random.seed(cfg.seed) + random.seed(cfg.seed) + + # get paths + output_dir = Path(cfg.clusters_dir) / cfg.dataset_name / embedding_type.value + embeddings_path = ( + Path(cfg.embeddings_dir) / cfg.dataset_name / f"{embedding_type.value}.h5" + ) + cluster_sample_ids = _load_sample_ids_from_embeddings(embeddings_path) + clusters_path = _get_clustering_output_path( + output_dir=output_dir, + intermediate_num_dims=cfg.intermediate_num_dims, + hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size, + hdbscan_metric=cfg.hdbscan_metric, + k_nn_n_neighbors=cfg.k_nn_n_neighbors, + do_knn=cfg.do_knn, + method=cfg.method, + ) + + # load data + # cluster_labels corresponds to the document indices in the dataset + cluster_results = np.load(clusters_path, allow_pickle=True).item() + cluster_labels = cluster_results["cluster_labels"] + logger.info(f"Cluster labels shape: {cluster_labels.shape}") + assert len(cluster_labels) == len(cluster_sample_ids), ( + "Mismatch in number of samples" + ) + + pool_indices = alpha_cluster_sampling_create_pool( + cluster_labels=cluster_labels, max_seed_pool=cfg.max_pool_size + ) + seed_samples = [] + seed_clusters = [] + for _ in tqdm.tqdm(range(cfg.total_seed_runs), desc="Sampling seeds"): + sampled_seeds, sampled_clusters = alpha_cluster_sampling_pool( + cluster_labels=cluster_labels, + total_seeds=cfg.total_seeds_per_run, + pool_indices=pool_indices, + alpha=cfg.alpha, + seed_selection_strategy=cfg.seed_selection_strategy, + ) + sampled_seed_ids = [cluster_sample_ids[i] for i in sampled_seeds] + seed_samples.append(sampled_seed_ids) + seed_clusters.append(sampled_clusters) + + # save the sampled seeds + seeds_output_path = Path(cfg.output_dir) / clusters_path.name.replace( + ".npy", + f"_alpha={cfg.alpha}_max-pool-size={cfg.max_pool_size}_strategy={cfg.seed_selection_strategy}_seeds.csv", + ) + dataframe = pd.DataFrame(seed_samples) + logger.info(f"Saving sampled seeds to {seeds_output_path}...") + dataframe.to_csv(seeds_output_path, index=False) + + # save the sampled clusters + clusters_output_path = Path(cfg.output_dir) / clusters_path.name.replace( + ".npy", + f"_alpha={cfg.alpha}_max-pool-size={cfg.max_pool_size}_strategy={cfg.seed_selection_strategy}_clusters.csv", + ) + dataframe = pd.DataFrame(seed_clusters) + logger.info(f"Saving sampled clusters to {clusters_output_path}...") + dataframe.to_csv(clusters_output_path, index=False) + + # also visualize the random 20 seed documents as an image grid + # load all seed documents into an image grid + if cfg.visualize_seeds: + from docgenie.data import load_dataset + + dataset = load_dataset(cfg.dataset_name, split="train") + seed_images = [] + for seed in sampled_seeds[: cfg.n_seeds_to_visualize]: + seed_images.append( + dataset.train.get_by_id(cluster_sample_ids[seed]).image.content + ) + vis_fname = seeds_output_path.parent / seeds_output_path.name.replace( + ".csv", ".png" + ) + _visualize_images_grid( + images=seed_images, + save_path=vis_fname, + ) + + return seeds_output_path, clusters_output_path + + +class GenerateSeedsConfig(pydantic.BaseModel): + """ + Configuration for generating clustering seeds.
+ """ + + # same as clustering config + dataset_name: str + seed: int = 42 + hdbscan_min_cluster_size: int = 10 + intermediate_num_dims: int = 100 + hdbscan_metric: str = "euclidean" + do_knn: bool = True + k_nn_n_neighbors: int = 5 + embeddings_dir: str | Path = ENV.EMBEDDINGS_DIR + clusters_dir: str | Path = ENV.CLUSTERS_DIR + output_dir: str | Path + method: str = "hdbscan" # or "kmeans" + seed_selection_strategy: str = "v1" + + # specific to seed generation + total_seed_runs: int = 10000 + total_seeds_per_run: int = 10 + visualize_seeds: bool = False + n_seeds_to_visualize: int = 20 + + # sampling strategy + max_pool_size: int = -1 # if -1, seeds are selected from complete dataset, otherwise a pool is generated via proportional sampling, where it is ensured that each cluster is selected at least once + """ + sampling exponent for clusters. + - alpha=1 -> proportional + - alpha=0 -> uniform + - alpha<0 -> inverse-proportional + """ + alpha: float = 0 + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=GenerateSeedsConfig, + ) + generate_seeds_for_embedding_type(parser.parse_typed_args(), EmbeddingType.combined) diff --git a/docgenie/analyzation/clustering/cmds/load_seed_samples.py b/docgenie/analyzation/clustering/cmds/load_seed_samples.py new file mode 100755 index 0000000000000000000000000000000000000000..98e942c1f6609c6d9cd5511873295d1e991778b2 --- /dev/null +++ b/docgenie/analyzation/clustering/cmds/load_seed_samples.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from pathlib import Path +import pydantic.v1 as pydantic +import pydantic_argparse + +from docgenie import ENV +from docgenie.analyzation.clustering.core._embeddings import ( + _load_sample_ids_from_embeddings, +) +from docgenie.analyzation.clustering.core._utilities import ( + EmbeddingType, + _get_clustering_output_path, +) +from docgenie.logging import get_logger + + +logger = get_logger(__name__) + + +def main(cfg: LoadSeedSamples): + import pandas as pd + from docgenie.data import load_dataset + + for embedding_type in EmbeddingType.__members__.values(): + output_dir = Path(cfg.output_dir) / cfg.dataset_name / embedding_type.value + embeddings_path = ( + Path(cfg.embeddings_dir) / cfg.dataset_name / f"{embedding_type.value}.h5" + ) + sample_ids = _load_sample_ids_from_embeddings(embeddings_path) + clusters_path = _get_clustering_output_path( + output_dir=output_dir, + intermediate_num_dims=cfg.intermediate_num_dims, + hdbscan_min_cluster_size=cfg.hdbscan_min_cluster_size, + hdbscan_metric=cfg.hdbscan_metric, + k_nn_n_neighbors=cfg.k_nn_n_neighbors, + do_knn=cfg.do_knn, + method=cfg.method, + ) + seeds_output_path = clusters_path.parent / clusters_path.name.replace( + ".npy", f"_strategy={cfg.sampling_strategy}_seeds.csv" + ) + + # load the sampled seeds + dataset = load_dataset(cfg.dataset_name, split="train") + seed_sample_indices = pd.read_csv(seeds_output_path) + for _, row in seed_sample_indices.iterrows(): + # get seed samples from first row + sampled_seeds = row.tolist() + seed_sample_ids = [sample_ids[int(i)] for i in sampled_seeds] + samples = [dataset.train.get_by_id(sid) for sid in seed_sample_ids] + print(f"Loaded {len(samples)} seed samples from {seeds_output_path}") + print("Example sample: ", samples[0]) + break + + +class LoadSeedSamples(pydantic.BaseModel): + # same as clustering config + dataset_name: str + seed: int = 42 + hdbscan_min_cluster_size: int = 10 + intermediate_num_dims: int = 100 + hdbscan_metric: str = "euclidean" + do_knn: bool = True + 
k_nn_n_neighbors: int = 5 + embeddings_dir: str | Path = ENV.EMBEDDINGS_DIR + output_dir: str | Path = ENV.CLUSTERS_DIR + method: str = "hdbscan" # or "kmeans" + sampling_strategy: str = "uniform_cluster_sampling" + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=LoadSeedSamples, + ) + main(parser.parse_typed_args()) diff --git a/docgenie/analyzation/clustering/compute_best_clusterings.py b/docgenie/analyzation/clustering/compute_best_clusterings.py new file mode 100755 index 0000000000000000000000000000000000000000..4941fa639fd50a5f495e8c4558b135050041fdd3 --- /dev/null +++ b/docgenie/analyzation/clustering/compute_best_clusterings.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +""" +Compute top N clustering configurations per dataset +from a single global metrics file. + +Example: + python compute_best_clusterings_all_in_one.py \ + --metrics compactness__silhouette_score balance__entropy \ + --directions max max \ + --top 5 +""" + +import argparse +import pandas as pd +import numpy as np +from sklearn.preprocessing import MinMaxScaler +from pathlib import Path + +from docgenie import ENV + + +# -------------------------------------------------------------------- +# CONFIG +# -------------------------------------------------------------------- +METRICS_FILE = ENV.CLUSTERS_DIR / "metrics-seed=42.csv" + + +# -------------------------------------------------------------------- +# FUNCTIONS +# -------------------------------------------------------------------- +valid_datasets = [ + "cord", + "doclaynet_4k", + "ex_docvqa", + "ex_klc", + "ex_wiki", + "funsd", + "icdar2019", + "publaynet", + "rvlcdip", + "sroie", + "tobacco3482", +] + + +def compute_best_per_dataset(df, metrics, directions, top_n=5, filter_datasets=False): + """Compute top N configs per dataset for selected metrics.""" + results = [] + + for dataset, group in df.groupby("dataset_name"): + if filter_datasets and dataset not in valid_datasets: + continue + + df_norm = group.copy() + scaler = MinMaxScaler() + + # normalize + direction handling + for metric, direction in zip(metrics, directions): + if metric not in group.columns: + raise ValueError( + f"Metric '{metric}' not found in columns: {list(group.columns)}" + ) + + # normed = scaler.fit_transform(group[[metric]].values) + normed = group[[metric]].values + if direction == "min": + normed = 1 - normed # flip so higher is better + df_norm[metric] = normed + + df_norm["final_score"] = df_norm[metrics].mean(axis=1) + top = df_norm.sort_values("final_score", ascending=False).head(top_n) + top["dataset_name"] = dataset + results.append(top) + + combined = pd.concat(results, ignore_index=True) + return combined + + +# Compute final embedding ranking +def compute_embedding_ranking(top_df, top_n, filter_datasets): + """Aggregate top N positions across datasets per embedding type.""" + ranking_list = [] + + for dataset, group in top_df.groupby("dataset_name"): + if filter_datasets and dataset not in valid_datasets: + continue + + # Sort by final_score descending + group_sorted = group.sort_values("final_score", ascending=False).reset_index() + # Assign position-based score + group_sorted["rank_score"] = ( + top_n - group_sorted.index + ) # top row = top_n, next = top_n-1 ... 
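+        # e.g. with top_n=5 the best row of a dataset contributes 5, the
+        # runner-up 4, ... and the fifth-best 1; summing these scores across
+        # datasets below favours configurations that are consistently near the top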
+ ranking_list.append( + group_sorted[["embedding_type", "min_cluster_size", "rank_score"]] + ) + + # Combine all datasets + all_scores = pd.concat(ranking_list) + # Sum scores per embedding type + final_ranking = ( + all_scores.groupby(["embedding_type", "min_cluster_size"], as_index=False)[ + "rank_score" + ] + .sum() + .reset_index() + ) + final_ranking = final_ranking.sort_values("rank_score", ascending=False) + final_ranking["final_rank"] = range(1, len(final_ranking) + 1) + + return final_ranking + + +# -------------------------------------------------------------------- +# MAIN +# -------------------------------------------------------------------- + + +def main(): + parser = argparse.ArgumentParser( + description="Compute top N clustering configurations per dataset." + ) + parser.add_argument( + "--metrics", nargs="+", required=True, help="Metrics to consider" + ) + parser.add_argument( + "--directions", + nargs="+", + required=True, + help="Directions for each metric (max/min)", + ) + + parser.add_argument( + "--filter", + action="store_true", + help="If set, only take into account used datasets", + ) + + parser.add_argument( + "--min-cluster-size", + type=int, + help="Only consider rows with this min_cluster_size", + ) + parser.add_argument("--top", type=int, default=5, help="Top N results per dataset") + parser.add_argument( + "--outfile", default="best_clusterings_summary.csv", help="Output CSV path" + ) + args = parser.parse_args() + + if len(args.metrics) != len(args.directions): + parser.error("Number of metrics and directions must match.") + + print(f"📂 Loading metrics from {METRICS_FILE}") + df = pd.read_csv(METRICS_FILE) + + # Apply filter if specified + if args.min_cluster_size is not None: + df = df[df["min_cluster_size"] == args.min_cluster_size] + if df.empty: + print(f"⚠️ No rows found with min_cluster_size = {args.min_cluster_size}") + return + + print(f"✅ Found {len(df)} rows across {df['dataset_name'].nunique()} datasets") + combined = compute_best_per_dataset( + df, args.metrics, args.directions, args.top, filter_datasets=args.filter + ) + + # Select main display columns + cols_to_show = [ + "dataset_name", + "embedding_type", + "min_cluster_size", + "intermediate_dims", + "method", + *args.metrics, + "final_score", + ] + cols_to_show = [c for c in cols_to_show if c in combined.columns] + + print("\n=== Top results per dataset ===") + for ds, g in combined.groupby("dataset_name"): + print(f"\n--- {ds} ---") + print(g[cols_to_show]) + + # Save combined summary + out_path = Path(args.outfile) + combined.to_csv(out_path, index=False) + print(f"\n✅ Summary saved to {out_path.resolve()}") + + final_ranking = compute_embedding_ranking( + combined, top_n=args.top, filter_datasets=args.filter + ) + + print("\n=== Final Ranking of Embedding Types ===") + print(final_ranking) + + +if __name__ == "__main__": + main() diff --git a/docgenie/analyzation/clustering/core/_algorithms.py b/docgenie/analyzation/clustering/core/_algorithms.py new file mode 100755 index 0000000000000000000000000000000000000000..38224262f5e79a5f51142d3f096ffc1e2db32aeb --- /dev/null +++ b/docgenie/analyzation/clustering/core/_algorithms.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from docgenie.analyzation.clustering.core._utilities import ( + EmbeddingType, +) +from docgenie.analyzation.clustering.core._embeddings import ( + _load_embeddings, +) +from docgenie.logging import get_logger + +if TYPE_CHECKING: + import numpy as np + import torch + 
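+
+# Clustering helpers: embeddings are L2-normalised, optionally reduced with UMAP,
+# then clustered with HDBSCAN (noise points can be re-assigned via k-NN) or KMeans.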
+logger = get_logger(__name__) + + +def _normalized_embeddings( + embeddings: np.ndarray, +) -> np.ndarray: + import numpy as np + + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + return embeddings / norms + + +def _reduce_embeddings_dims( + embeddings: torch.Tensor, + intermediate_num_dims: int = None, + reduce_dim_metric: str = "euclidean", + seed: int = None, +): + import math + + import umap + + if intermediate_num_dims is None: + intermediate_num_dims = math.floor(math.sqrt(embeddings.shape[1])) + + if intermediate_num_dims < embeddings.shape[1]: + logger.info( + f"Reducing embedding dimensions from {embeddings.shape[1]} to {intermediate_num_dims=} before clustering..." + ) + umap_engine = umap.UMAP( + n_components=intermediate_num_dims, + metric=reduce_dim_metric, + n_jobs=-1, + verbose=False, + random_state=seed, + ) + return umap_engine.fit_transform(embeddings) + return embeddings + + +def _run_hdbscan( + embeddings: torch.Tensor, + hdbscan_min_cluster_size: int = 10, + hdbscan_metric: str = "euclidean", + seed: int = None, +): + import hdbscan + import numpy as np + + approx_min_span_tree = True + if seed is not None: + np.random.seed(seed) + approx_min_span_tree = False # otherwise not deterministic + + logger.info("Running HDBSCAN...") + clusterer = hdbscan.HDBSCAN( + min_cluster_size=hdbscan_min_cluster_size, + metric=hdbscan_metric, + core_dist_n_jobs=-1, + approx_min_span_tree=approx_min_span_tree, + algorithm="best", + prediction_data=True, + ) + cluster_labels = clusterer.fit_predict(embeddings) + soft_clusters = hdbscan.all_points_membership_vectors(clusterer) + return soft_clusters, cluster_labels + + +def _run_knn( + embeddings: torch.Tensor, + cluster_labels: np.ndarray, + k_nn_n_neighbors: int = 5, +): + import copy + + from sklearn.neighbors import KNeighborsClassifier + + # train k-NN classifier + noise_mask = cluster_labels == -1 + non_noise_mask = cluster_labels != -1 + X_non_noise = embeddings[non_noise_mask] + y_non_noise = cluster_labels[non_noise_mask] + knn = KNeighborsClassifier(n_neighbors=k_nn_n_neighbors, n_jobs=-1) + knn.fit(X_non_noise, y_non_noise) + + X_noise = embeddings[noise_mask] + predicted_labels = knn.predict(X_noise) + + # assign predicted labels back to noise points + cluster_labels = copy.deepcopy(cluster_labels) + cluster_labels[noise_mask] = predicted_labels + + return cluster_labels + + +def _get_cached_reduced_embeddings( + embeddings: np.ndarray, + intermediate_num_dims: int, + reduce_dim_metric: str, + seed: int, + cache_dir: str = None, +) -> np.ndarray: + """Get reduced embeddings from cache or compute and cache them.""" + import os + import pickle + + if cache_dir is None: + # Compute without caching + return _reduce_embeddings_dims( + embeddings=embeddings, + intermediate_num_dims=intermediate_num_dims, + reduce_dim_metric=reduce_dim_metric, + seed=seed, + ) + + # Create cache key from parameters + cache_key = f"{intermediate_num_dims}_{reduce_dim_metric}_{seed}" + cache_file = os.path.join(cache_dir, f"reduced_embeddings_{cache_key}.pkl") + + # Try to load from cache + if os.path.exists(cache_file): + logger.info(f"Loading reduced embeddings from cache: {cache_file}") + with open(cache_file, "rb") as f: + return pickle.load(f) + + # Compute and cache + os.makedirs(cache_dir, exist_ok=True) + reduced_embeddings = _reduce_embeddings_dims( + embeddings=embeddings, + intermediate_num_dims=intermediate_num_dims, + reduce_dim_metric=reduce_dim_metric, + seed=seed, + ) + + with open(cache_file, "wb") as f: + 
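+        # persist the reduced embeddings so later runs that use the same
+        # (intermediate_num_dims, reduce_dim_metric, seed) cache key skip the UMAP step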
pickle.dump(reduced_embeddings, f) + logger.info(f"Cached reduced embeddings to: {cache_file}") + + return reduced_embeddings + + +# layoutlm CLS token, clip, text, combined +def _read_and_cluster_embeddings( + embeddings_dir: str, + dataset_name: str, + embedding_type: EmbeddingType, + intermediate_num_dims: int = None, + hdbscan_min_cluster_size: int = 10, + hdbscan_metric: str = "euclidean", + method: str = "hdbscan", + n_kmeans_clusters: int = 150, + k_nn_n_neighbors: int = 5, + seed: int = 42, + do_knn: bool = True, + cache_dir: str = None, +) -> dict: + """ + Read embeddings from H5PY file, reduce dimensions, and cluster them. + + This function first loads the embeddings from an H5PY file, normalizes them to unit length, + and then reduces their dimensions using UMAP if specified (by default we always use umap). + It then applies the chosen clustering algorithm (HDBSCAN or KMeans) to the reduced embeddings. Usually we only + use HDBSCAN currently with KNN, and optionally apply k-NN to label noise points. Without KNN, HDBSCAN returns + clusters with noise points associated a label of -1. KMeans is also supported as an alternative clustering method. + The function returns a dictionary containing the cluster labels, noise mask, number of noise points, + reduced embeddings, and soft cluster assignments. + + Args: + embeddings_dir (str): Directory where the embeddings H5PY file is located. + dataset_name (str): Name of the dataset (used to construct the file name). + embedding_type (EmbeddingType): Type of embeddings (layout, clip, text). + intermediate_num_dims (int, optional): Number of dimensions to reduce embeddings to before clustering. + If None, no dimensionality reduction is applied. Defaults to None. + hdbscan_min_cluster_size (int, optional): Minimum cluster size for HDBSCAN algorithm. + Defaults to 10. + hdbscan_metric (str, optional): Distance metric used by HDBSCAN algorithm. + Defaults to "euclidean". + method (str, optional): The clustering method to use ("hdbscan" or "kmeans"). Defaults to "hdbscan". + n_kmeans_clusters (int, optional): Number of clusters for KMeans algorithm. + Only used if method is "kmeans". Defaults to 150. + k_nn_n_neighbors (int, optional): Number of neighbors for k-NN algorithm. + Only used if method is "hdbscan" and do_knn is True. Defaults to 5. + seed (int, optional): Random seed for reproducibility. Defaults to 42. + do_knn (bool, optional): Whether to apply k-nearest neighbors processing. + Only used if method is "hdbscan". Defaults to True. + cache_dir (str, optional): Directory to cache reduced embeddings. + If None, no caching is done. Defaults to None. 
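+
+    Example (illustrative arguments; path and dataset name are placeholders):
+        results = _read_and_cluster_embeddings(
+            embeddings_dir="data/embeddings",
+            dataset_name="rvlcdip",
+            embedding_type=EmbeddingType.combined,
+            intermediate_num_dims=100,
+            hdbscan_min_cluster_size=10,
+        )
+        labels = results["cluster_labels"]  # one cluster id per document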
+ """ + import numpy as np + import torch + from pathlib import Path + + # read the embeddings + embeddings, _ = _load_embeddings( + file_path=Path(embeddings_dir) / dataset_name / f"{embedding_type.value}.h5" + ) + embeddings = torch.from_numpy(embeddings) + + # normalize the embeddings + embeddings = _normalized_embeddings(embeddings) + + # we also reduce embeddings to embeddings_2d for visualization + # we only run it to cache the embeddings + _get_cached_reduced_embeddings( + embeddings=embeddings, + intermediate_num_dims=2, + reduce_dim_metric=hdbscan_metric, + seed=seed, + cache_dir=cache_dir, + ) + + # reduce embedding dimensions for clustering + embeddings_reduced_dim = _get_cached_reduced_embeddings( + embeddings=embeddings, + intermediate_num_dims=intermediate_num_dims, + reduce_dim_metric=hdbscan_metric, + seed=seed, + cache_dir=cache_dir, + ) + + # convert embeddings to double + embeddings_reduced_dim = embeddings_reduced_dim.astype(np.double) + + # normalize reduced embeddings + embeddings_reduced_dim = _normalized_embeddings(embeddings_reduced_dim) + + if method == "hdbscan": + # step 1: run the clustering algorithm on the embeddings + soft_clusters, cluster_labels = _run_hdbscan( + embeddings=embeddings_reduced_dim, + hdbscan_min_cluster_size=hdbscan_min_cluster_size, + hdbscan_metric=hdbscan_metric, + seed=seed, + ) + + # step 2: train k-NN on non-noise points + # select points that are not labeled as noise + num_noise = np.sum(cluster_labels == -1) + noise_mask = cluster_labels == -1 + + logger.info("Number of noise points: %d", num_noise) + + # return if not using k-NN to label noise points + if do_knn and num_noise > 0: + cluster_labels = _run_knn( + embeddings=embeddings_reduced_dim, + cluster_labels=cluster_labels, + k_nn_n_neighbors=k_nn_n_neighbors, + ) + + return { + "cluster_labels": cluster_labels, + "noise_mask": noise_mask, + "num_noise": num_noise, + "embeddings_reduced_dim": embeddings_reduced_dim, + "soft_clusters": soft_clusters, + } + elif method == "kmeans": + from sklearn.cluster import KMeans + + kmeans = KMeans(n_clusters=n_kmeans_clusters, random_state=seed, n_init="auto") + cluster_labels = kmeans.fit_predict(embeddings_reduced_dim) + soft_clusters = np.zeros((len(cluster_labels), n_kmeans_clusters)) + soft_clusters[np.arange(len(cluster_labels)), cluster_labels] = 1.0 + return { + "cluster_labels": cluster_labels, + "num_noise": 0, + "embeddings_reduced_dim": embeddings_reduced_dim, + "soft_clusters": soft_clusters, + } + else: + raise ValueError(f"Unknown clustering method: {method}") diff --git a/docgenie/analyzation/clustering/core/_embeddings.py b/docgenie/analyzation/clustering/core/_embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..6514a464fb332b1d1abf39855f706da451032a13 --- /dev/null +++ b/docgenie/analyzation/clustering/core/_embeddings.py @@ -0,0 +1,329 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING, Callable + +import tqdm + +from docgenie.analyzation.clustering.core._utilities import EmbeddingType +from docgenie.data._core._data_types import DocumentInstanceModelInput +from docgenie.logging import get_logger + +if TYPE_CHECKING: + import numpy as np + from torch.utils.data import DataLoader + +logger = get_logger(__name__) + + +def _iterate_dataset( + model_fn: Callable, + embedding_fn: Callable, + dataloader: "DataLoader", + device: str = "cpu", +): + """Inner function that actually generates the embeddings.""" + import torch + + model = model_fn() + 
model.to(device) + model.eval() + + sample_ids = [] + embeddings = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataloader, desc="Extracting embeddings"): + batch: DocumentInstanceModelInput + batch = batch.select_first_overflow_samples() + batch = batch.to(device) + + token_bboxes = batch.token_bboxes + if token_bboxes is not None: + if token_bboxes.min() >= 0 and token_bboxes.max() <= 1.0: + # if bboxes are normalized to [0, 1], convert to [0, 1000] as expected by layoutlmv3 + token_bboxes = (token_bboxes * 1000).long() + else: + logger.warning( + f"Token bboxes must be in the range [0, 1], but got min {token_bboxes.min()} and max {token_bboxes.max()}" + ) + token_bboxes = (token_bboxes.clip(0, 1.0) * 1000).long() + + # assert check + assert token_bboxes.min() >= 0 and token_bboxes.max() <= 1000, ( + f"Token bboxes must be in the range [0, 1000], but got min {token_bboxes.min()} and max {token_bboxes.max()}" + ) + + # make sure if image is normlized 0-1 as in layoutlm we renormalize using clip stats + assert batch.image.min() >= -1.1 and batch.image.max() <= 1.1, ( + f"Image pixel values must be in the range [0, 1], but got min {batch.image.min()} and max {batch.image.max()}" + ) + + # make inputs + inputs = dict( + input_ids=batch.token_ids, + bbox=token_bboxes, + attention_mask=batch.attention_mask, + pixel_values=batch.image, + words=batch.words, + ) + + embeddings.append(embedding_fn(model, inputs)) + + # in our preprocessed dataset indices are always unqiue + # but sample_ids may not be always unique in some rare cases + sample_ids.extend(batch.sample_id) + + embeddings = torch.cat(embeddings, dim=0) + return embeddings.cpu().numpy(), sample_ids + + +def _extract_layoutlm_embeddings( + dataloader: "DataLoader", + device: str = "cpu", +): + """Inner function that actually generates the embeddings.""" + + def model_fn(): + from transformers import ( + LayoutLMv3Model, + ) + + model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base") + model.to(device) + model.eval() + return model + + def embedding_fn(model, inputs): + outputs = model( + input_ids=inputs["input_ids"], + bbox=inputs["bbox"], + attention_mask=inputs["attention_mask"], + pixel_values=inputs["pixel_values"], + ) + return outputs.last_hidden_state[:, 0, :] + + embeddings, sample_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=embedding_fn, + dataloader=dataloader, + device=device, + ) + + return embeddings, sample_ids + + +def _extract_text_embeddings( + dataloader: "DataLoader", + device: str = "cpu", +): + """Inner function that actually generates the embeddings.""" + + def model_fn(): + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer("all-mpnet-base-v2") + model.to(device) + model.eval() + return model + + def embedding_fn(model, inputs): + sentences = [" ".join(words_per_sample) for words_per_sample in inputs["words"]] + return model.encode(sentences, convert_to_tensor=True) + + embeddings, sample_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=embedding_fn, + dataloader=dataloader, + device=device, + ) + + return embeddings, sample_ids + + +def _extract_image_embeddings( + dataloader: "DataLoader", + device: str = "cpu", +): + """Inner function that actually generates the embeddings.""" + OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] + OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711] + + def model_fn(): + from transformers import ( + CLIPModel, + ) + + model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + 
model.to(device) + model.eval() + return model + + def embedding_fn(model, inputs): + from torchvision.transforms.functional import normalize + + # make sure if image is normlized 0-1 as in layoutlm we renormalize using clip stats + inputs["pixel_values"] = inputs["pixel_values"] * 0.5 + 0.5 # -1 to 1 to [0, 1] + inputs["pixel_values"] = normalize( + inputs["pixel_values"], mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD + ) + outputs = model.get_image_features(pixel_values=inputs["pixel_values"]) + return outputs.cpu() + + embeddings, sample_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=embedding_fn, + dataloader=dataloader, + device=device, + ) + + return embeddings, sample_ids + + +def _extract_paper_embeddings( + dataloader: "DataLoader", + device: str = "cpu", + paper_embedding_kernel_size: int = 4, +): + """Inner function that actually generates the embeddings.""" + + def model_fn(): + from transformers import ( + LayoutLMv3Model, + ) + + model = LayoutLMv3Model.from_pretrained("microsoft/layoutlmv3-base") + model.to(device) + model.eval() + return model + + def embedding_fn(model, inputs): + import torch + from torch import nn + + # do layoutlmv3 forward + outputs = model( + input_ids=inputs["input_ids"], + bbox=inputs["bbox"], + attention_mask=inputs["attention_mask"], + pixel_values=inputs["pixel_values"], + ) + + # get last last_hidden_state + last_hidden_state_batch = outputs.last_hidden_state + + # now apply paper embedding logic + pad_token_id = model.config.pad_token_id + num_image_tokens = (model.config.input_size // model.config.patch_size) ** 2 + embeddings = [] + for idx in range(last_hidden_state_batch.shape[0]): + last_hidden_state = last_hidden_state_batch[idx, :, :] # (L, D) + Lt = (inputs["input_ids"][idx] != pad_token_id).sum() # its a 1D tensor + text_embedding = last_hidden_state[:Lt, :] # (Lt, D) + # image_embedding_with_padding = last_hidden_state[Lt:, :] # (Lv, D) + image_embedding = last_hidden_state[-num_image_tokens:, :] # (Lv, D) + + # Step 1: Mean pooling of text embeddings + vt = text_embedding.mean(dim=0) # shape: (D,) + + # Step 2: 1D max-pooling on image embeddings to reduce feature dimension + # Reshape Hv to (Lv, 1, D) to apply 1D max-pooling along the feature dimension + Hv_reshaped = image_embedding.unsqueeze(1) # (Lv, 1, D) + maxpool = nn.MaxPool1d( + kernel_size=paper_embedding_kernel_size, + stride=paper_embedding_kernel_size, + ) + Hv_pooled = maxpool(Hv_reshaped) # (Lv, 1, N), N < D + Hv_pooled = Hv_pooled.squeeze(1) # shape: (Lv, N) + + # Step 3: Mean pooling of pooled image embeddings + vv = Hv_pooled.mean(dim=0) # shape: (N,) + + # Step 4: Concatenate text and pooled image embeddings + v = torch.cat([vt, vv], dim=0) # shape: (D + N,) + embeddings.append(v) + return torch.stack(embeddings, dim=0) # (B, D + N) + + embeddings, sample_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=embedding_fn, + dataloader=dataloader, + device=device, + ) + + return embeddings, sample_ids + + +def embedding_extraction_with_cache( + dataloader: "DataLoader", + output_dir: str | Path, + embedding_type: EmbeddingType, + device: str = "cpu", + cache_outputs: bool = True, +): + """Generic cacher function that handles caching logic for any embedding type.""" + cache_file = Path(output_dir) / f"{embedding_type.value}.h5" + if cache_outputs and cache_file.exists(): + logger.info( + f"Loading cached {embedding_type.value} embeddings from {cache_file}" + ) + return _load_embeddings(cache_file) + + # Generate new embeddings using the provided 
extraction function + if embedding_type == EmbeddingType.layout: + extraction_func = _extract_layoutlm_embeddings + embeddings, sample_ids = extraction_func(dataloader, device) + elif embedding_type == EmbeddingType.text: + extraction_func = _extract_text_embeddings + embeddings, sample_ids = extraction_func(dataloader, device) + elif embedding_type == EmbeddingType.image: + extraction_func = _extract_image_embeddings + embeddings, sample_ids = extraction_func(dataloader, device) + elif embedding_type == EmbeddingType.paper: + extraction_func = _extract_paper_embeddings + embeddings, sample_ids = extraction_func(dataloader, device) + else: + raise ValueError(f"Unsupported embedding type: {embedding_type}") + + if cache_outputs: + assert len(sample_ids) == embeddings.shape[0], ( + f"Number of sample IDs ({len(sample_ids)}) must match number of embeddings ({embeddings.shape[0]})" + ) + assert len(set(sample_ids)) == len(sample_ids), "Sample IDs must be unique" + _save_embeddings( + embeddings=embeddings, + sample_ids=sample_ids, + file_path=Path(output_dir) / f"{embedding_type.value}.h5", + ) + return _load_embeddings(cache_file) + + return embeddings, sample_ids + + +def _save_embeddings(embeddings: "np.ndarray", sample_ids: list[str], file_path: Path): + import h5py + + file_path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(file_path, "w") as f: + f.create_dataset("embeddings", data=embeddings) + f.create_dataset("sample_ids", data=sample_ids) + + +def _load_embeddings(file_path: Path): + import h5py + + with h5py.File(file_path, "r") as f: + sample_ids = f["sample_ids"][:] + embeddings = f["embeddings"][:] + return embeddings, [ + s.decode("utf-8") if isinstance(s, bytes) else s for s in sample_ids + ] + + +def _load_sample_ids_from_embeddings(file_path: Path): + import h5py + + with h5py.File(file_path, "r") as f: + sample_ids = f["sample_ids"][:] + return [ # decode and remove the index suffx + s.decode("utf-8") if isinstance(s, bytes) else s for s in sample_ids + ] diff --git a/docgenie/analyzation/clustering/core/_metrics.py b/docgenie/analyzation/clustering/core/_metrics.py new file mode 100755 index 0000000000000000000000000000000000000000..38fa58d85eee9d945cf4095b49bfd99c1493bfbe --- /dev/null +++ b/docgenie/analyzation/clustering/core/_metrics.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np + import pandas as pd + + +# Distance / Connectivity +def _normalized_connectivity(X, labels, n_neighbors=10): + """ + Normalized connectivity metric: measures if each point's nearest neighbors + are in the same cluster. 0 = perfect connectivity, 1 = worst. 
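+    Concretely: for each point, take the fraction of its n_neighbors nearest
+    neighbours that fall in a different cluster, then average this fraction over
+    all points.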
+ + Parameters: + - X: data points (n_samples x n_features) + - labels: cluster labels + - n_neighbors: number of neighbors to consider + + Returns: + - normalized connectivity score (0-1) + """ + from sklearn.neighbors import NearestNeighbors + + n_samples = X.shape[0] + nbrs = NearestNeighbors(n_neighbors=n_neighbors + 1).fit(X) + distances, indices = nbrs.kneighbors(X) + + # Exclude self from neighbors + indices = indices[:, 1:] + score = 0 + for i in range(n_samples): + for j in indices[i]: + if labels[i] != labels[j]: + score += 1 / n_neighbors # penalize different cluster + + # Maximum possible score is n_samples (each point has all neighbors in other clusters) + max_score = n_samples + normalized_score = score / max_score + return normalized_score + + +# Compactness / Separation +def _cluster_compactness_scores(embeddings, labels): + """ + Compute compactness scores for clusters using various metrics. + """ + from sklearn.metrics import ( + calinski_harabasz_score, + davies_bouldin_score, + silhouette_score, + ) + + return { + "silhouette_score": silhouette_score(embeddings, labels), + "calinski_harabasz_score": calinski_harabasz_score(embeddings, labels), + "davies_bouldin_score": davies_bouldin_score(embeddings, labels), + } + + +# Balance / Size Equity +def _cluster_balance_scores(cluster_sizes): + """ + Compute balance scores for clusters using various metrics. + """ + import numpy as np + import scipy + + sizes = np.array(cluster_sizes) + entropy = scipy.stats.entropy(sizes) + norm_entropy = entropy / np.log(len(sizes)) + + # Coefficient of variation + cv = sizes.std() / sizes.mean() + mmr = sizes.min() / sizes.max() + + # Gini coefficient + sorted_sizes = np.sort(sizes) + n = len(sizes) + gini = ( + 2 * np.sum((np.arange(1, n + 1)) * sorted_sizes) / (n * sorted_sizes.sum()) + ) - (n + 1) / n + + return { + "entropy": norm_entropy.item(), + "coefficient_of_variation": cv.item(), + "min-to-max-ratio": mmr.item(), + "gini-coefficient": gini.item(), + } + + +def evaluate_clusters_unsupervised( + embeddings: np.ndarray, cluster_labels: np.ndarray +) -> tuple[dict[str, float], int]: + """ + Evaluate clustering quality using unsupervised metrics. + """ + import numpy as np + import torch + + if isinstance(embeddings, torch.Tensor): + embeddings = embeddings.numpy() + + unique_entries, counts = np.unique(cluster_labels, return_counts=True) + result = dict() + result["connectivity"] = { + "normalized_connectivity": _normalized_connectivity( + X=embeddings, + labels=cluster_labels, + n_neighbors=int(embeddings.shape[0] * 0.01), + ), + } + result["compactness"] = _cluster_compactness_scores( + embeddings=embeddings, labels=cluster_labels + ) + result["balance"] = _cluster_balance_scores(counts) + return result, len(unique_entries) + + +def calculate_cluster_statistics( + embeddings: np.ndarray, cluster_labels: np.ndarray +) -> "pd.DataFrame": + """ + Calculate statistics for each cluster, including size and variance. + Variance is computed as the average pairwise cosine distance within the cluster. 
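+
+    Returns:
+        A DataFrame with one row per cluster and columns cluster_id, size and variance.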
+ """ + import numpy as np + import pandas as pd + from sklearn.metrics.pairwise import cosine_similarity + + unique_clusters = set(cluster_labels) + cluster_stats = [] + for cluster_id in unique_clusters: + cluster_mask = cluster_labels == cluster_id + cluster_embeddings = embeddings[cluster_mask] + cluster_size = len(cluster_embeddings) + sim_matrix = cosine_similarity(cluster_embeddings) + cosine_distances = 1 - sim_matrix[np.triu_indices_from(sim_matrix, k=1)] + cosine_diversity = np.mean(cosine_distances) + cluster_stats.append( + { + "cluster_id": cluster_id, + "size": cluster_size, + "variance": cosine_diversity, + } + ) + return pd.DataFrame(cluster_stats) diff --git a/docgenie/analyzation/clustering/core/_utilities.py b/docgenie/analyzation/clustering/core/_utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..d04d9dbd0b3e9ca154d10f170b3c5eeeaa7556a5 --- /dev/null +++ b/docgenie/analyzation/clustering/core/_utilities.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +import enum +from pathlib import Path +from typing import TYPE_CHECKING + +import pandas as pd + +from docgenie.logging import get_logger + +if TYPE_CHECKING: + import numpy as np + from PIL.Image import Image + +logger = get_logger(__name__) + + +if TYPE_CHECKING: + import torch + + +class EmbeddingType(str, enum.Enum): + """ + Enum for different types of embeddings used in DocGenie. + """ + + layout = "layout" + image = "image" + text = "text" + combined = "combined" + paper = "paper_kernel=4" + + +def _glob_clustering_output_paths(output_dir: str | Path): + """ + List all clustering output files in the specified directory. + + Args: + output_dir (str | Path): The directory to search for clustering output files. + This must point to the `output_directory/dataset_name/embedding_type` level. + """ + output_path = Path(output_dir) + return list( + output_path.glob("method=*_clusters_ind=*_hmcs=*_hm=*_do_knn=*_knn=*.npy") + ) + + +def _get_clustering_output_path( + output_dir: str | Path, + intermediate_num_dims: int, + hdbscan_min_cluster_size: int = 10, + hdbscan_metric: str = "euclidean", + do_knn: bool = True, + k_nn_n_neighbors: int = 5, + method: str = "hdbscan", +): + """ + Generate a standardized file path for clustering output results. + + This function creates a descriptive filename that encodes all the clustering + parameters used, allowing for easy identification and retrieval of clustering + results based on the specific configuration. + + Args: + output_dir (str | Path): The base directory where clustering results will be saved. + This must point to the `output_directory/dataset_name/embedding_type` level. + intermediate_num_dims (int): The number of dimensions used in intermediate processing. + hdbscan_min_cluster_size (int, optional): Minimum cluster size for HDBSCAN algorithm. + Defaults to 10. + hdbscan_metric (str, optional): Distance metric used by HDBSCAN algorithm. + Defaults to "euclidean". + do_knn (bool, optional): Whether to apply k-nearest neighbors processing. + Defaults to True. + k_nn_n_neighbors (int, optional): Number of neighbors for k-NN algorithm. + Defaults to 5. + method (str, optional): The clustering method being used. Defaults to "hdbscan". 
+ + Returns: + Path: A Path object pointing to the clustering output file with encoded parameters + in the filename format: method={method}_clusters_ind={intermediate_num_dims}_ + hmcs={hdbscan_min_cluster_size}_hm={hdbscan_metric}_do_knn={do_knn}_ + knn={k_nn_n_neighbors}.npy + """ + return ( + output_dir + / f"method={method}_clusters_ind={intermediate_num_dims}_hmcs={hdbscan_min_cluster_size}_hm={hdbscan_metric}_do_knn={do_knn}_knn={k_nn_n_neighbors}.npy" + ) + + +def _save_clustering_metrics( + output_dir: str | Path, + dataset_name: str, + hdbscan_min_cluster_size: int, + intermediate_num_dims: int, + hdbscan_metric: str, + k_nn_n_neighbors: int, + method: str, + embedding_type: "EmbeddingType", + embeddings: "np.ndarray", + cluster_metrics: dict, + num_clusters: int, + num_noise: int, + seed: int, + do_knn: bool = True, +) -> None: + import hashlib + + import torch + + cnt = embeddings.shape[0] + noise_percent = num_noise / float(cnt) + noise_percent = ( + noise_percent.item() + if isinstance(noise_percent, torch.Tensor) + else noise_percent + ) + metrics_row = { + "dataset_name": dataset_name, + "embedding_type": embedding_type.value, + "min_cluster_size": hdbscan_min_cluster_size, + "intermediate_dims": intermediate_num_dims, + "hdbscan_metric": hdbscan_metric, + "k_nn_n_neighbors": k_nn_n_neighbors, + "num_clusters": num_clusters, + "num_noise": num_noise, + "noise_percent": noise_percent, + "do_knn": do_knn, + "method": method, + } + + # Add cluster metrics to the row + for cat, items in cluster_metrics.items(): + for k, v in items.items(): + metrics_row[f"{cat}__{k}"] = v + + # Generate unique hash based on configuration parameters only (excluding results) + config_items = { + "dataset_name": dataset_name, + "embedding_type": embedding_type.value, + "min_cluster_size": hdbscan_min_cluster_size, + "intermediate_dims": intermediate_num_dims, + "hdbscan_metric": hdbscan_metric, + "k_nn_n_neighbors": k_nn_n_neighbors, + "seed": seed, + "do_knn": do_knn, + "method": method, + } + row_hash = hashlib.md5(str(sorted(config_items.items())).encode()).hexdigest() + metrics_row["row_hash"] = row_hash + + # Save metrics + metrics_path = Path(output_dir) / f"metrics-seed={seed}.csv" + if metrics_path.exists(): + df = pd.read_csv(metrics_path) + df = df[df["row_hash"] != row_hash] + df = pd.concat([df, pd.DataFrame([metrics_row])], ignore_index=True) + else: + df = pd.DataFrame([metrics_row]) + + logger.info(f"Saving clustering metrics to {metrics_path}...") + df.to_csv(metrics_path, index=False) + + +def _visualize_images_grid( + images: list[np.ndarray | "Image"], + save_path: str | Path, + nrow: int = 8, + title: str | None = None, + figsize: tuple[int, int] = (12, 8), + dpi: int = 150, +) -> None: + """ + Create and save an image grid using torchvision's make_grid utility. 
+ + Args: + images: List of numpy arrays or PIL images to arrange in grid + save_path: Path where the grid image will be saved + nrow: Number of images displayed in each row of the grid + title: Optional title for the saved image + figsize: Figure size for matplotlib + dpi: DPI for saved image + """ + import matplotlib.pyplot as plt + import numpy as np + import torch + import torchvision.transforms as transforms + from torchvision.transforms.functional import resize + from torchvision.utils import make_grid + + # Convert inputs to tensors + tensor_images = [] + for img in images: + if isinstance(img, np.ndarray): + # Handle different numpy array formats + if img.ndim == 2: # Grayscale + img = np.expand_dims(img, axis=0) # Add channel dimension + elif img.ndim == 3 and img.shape[2] == 3: # RGB with channels last + img = np.transpose(img, (2, 0, 1)) # Convert to channels first + elif img.ndim == 3 and img.shape[0] in [1, 3]: # Already channels first + pass + else: + raise ValueError(f"Unsupported numpy array shape: {img.shape}") + + tensor = torch.from_numpy(img).float() + else: # PIL Image + transform = transforms.ToTensor() + tensor = transform(img) + + tensor = resize(tensor, size=(512, 512)) # Resize to fixed size + tensor_images.append(tensor) + + # Stack all tensors + batch_tensor = torch.stack(tensor_images) + + # Create grid + grid = make_grid( + batch_tensor, + nrow=nrow, + ) + + # Convert to numpy for matplotlib (channels last) + grid_np = grid.permute(1, 2, 0).numpy() + + # Create matplotlib figure + fig, ax = plt.subplots(figsize=figsize) + ax.imshow(grid_np) + ax.axis("off") + + if title: + ax.set_title(title, fontsize=16, pad=20) + + # Save the figure + plt.tight_layout() + plt.savefig(save_path, dpi=dpi, bbox_inches="tight", pad_inches=0.1) + plt.close() + + logger.info(f"Image grid saved to {save_path}") + + +def _load_pdfs_to_pil_images(pdf_paths: list[str | Path]) -> list["Image"]: + """ + Loads a list of PDF document paths to PIL Images by rendering the first page of each PDF as PNG. 
+ + Args: + pdf_paths: List of paths to PDF files + + Returns: + List of PIL Image objects, one for each PDF's first page + """ + from pdf2image import convert_from_path + + pil_images = [] + + for pdf_path in pdf_paths: + try: + # Convert first page of PDF to PIL Image + images = convert_from_path(str(pdf_path), first_page=1, last_page=1, dpi=72) + + if images: + pil_images.append(images[0]) + else: + logger.warning(f"No images converted from PDF: {pdf_path}") + + except Exception as e: + logger.error(f"Failed to convert PDF {pdf_path}: {e}") + continue + + return pil_images diff --git a/docgenie/analyzation/clustering/utils.py b/docgenie/analyzation/clustering/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..d127a9763d81c60e8f49e8268d7cd4142b61954d --- /dev/null +++ b/docgenie/analyzation/clustering/utils.py @@ -0,0 +1,24 @@ + + +import h5py +import numpy as np +from tqdm import tqdm + +from docgenie import ENV + + +def read_embeddings_numpy(dataset_name: str, embeddings_type: str, kernel_size: int = None) -> np.ndarray: + all_embeddings = [] + fname = f'{dataset_name}_{embeddings_type}' + if embeddings_type == 'paper': + fname += f'_kernel={kernel_size}' + + fpath = ENV.EMBEDDINGS_DIR / f'{fname}.h5' + with h5py.File(fpath, "r") as f: + for id_ in tqdm(sorted(f.keys())): + emb = f[id_][:] # load tensor in numpy format + all_embeddings.append(emb) + + # Vertically stack along the first dimension + X = np.vstack(all_embeddings) + return X \ No newline at end of file diff --git a/docgenie/analyzation/clustering/webapp/__init__.py b/docgenie/analyzation/clustering/webapp/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..ebe0d18b611b4c8fdc6378618de71d7bb7fabd12 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/__init__.py @@ -0,0 +1,11 @@ +""" +Document clustering visualization web application. + +This package provides an interactive Dash web application for visualizing +document clustering results with scatter plots, cluster analysis, and +document preview capabilities. +""" + +from .app import create_app, main + +__all__ = ["create_app", "main"] diff --git a/docgenie/analyzation/clustering/webapp/_deprecated/visualize_clusters.py b/docgenie/analyzation/clustering/webapp/_deprecated/visualize_clusters.py new file mode 100755 index 0000000000000000000000000000000000000000..55876078d5b3624db4488e8366f9869467f6de30 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/_deprecated/visualize_clusters.py @@ -0,0 +1,634 @@ +""" +Document Clustering Visualization Dashboard + +A refactored modular version of the clustering visualization tool. +This file serves as the main entry point and maintains backward compatibility. 
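+
+Note: the code below the entry point appears to be the earlier monolithic
+implementation, retained in this _deprecated package for reference.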
+""" + +from .app import main + +if __name__ == "__main__": + main() +from flask import Response + +# from flask import send_from_directory +from plotly.subplots import make_subplots + +from docgenie import ENV +from docgenie.analyzation.clustering.core._utilities import ( + EmbeddingType, + _get_clustering_output_path, +) +from docgenie.data import load_dataset + +# -------------------------- +# Dash app + server route to serve PDFs +# -------------------------- +app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) +server = app.server + + +@server.route("/image/") +def serve_image(index): + global dataset + image = dataset.train[int(index)].image.content + img_io = io.BytesIO() + image.save(img_io, "PNG") + img_io.seek(0) + return Response(img_io.getvalue(), mimetype="image/png") + + +@server.route("/cluster_grid/") +def serve_cluster_grid(indices_list): + """Create and serve a grid image from multiple document PDFs.""" + import io + + from flask import Response + from PIL import Image, ImageDraw, ImageFont + + try: + print(f"Creating grid for doc IDs: {indices_list}") + # Parse document IDs from comma-separated string + indices_list = indices_list.split(",")[:12] # Limit to 12 for performance + + # Grid dimensions + cols = min(4, len(indices_list)) + rows = (len(indices_list) + cols - 1) // cols + + # Image dimensions + thumb_width, thumb_height = 200, 280 + grid_width = cols * thumb_width + (cols - 1) * 10 # 10px spacing + grid_height = rows * thumb_height + (rows - 1) * 10 + + # Create grid image + grid_img = Image.new("RGB", (grid_width, grid_height), "white") + for i, index in enumerate(indices_list): + row = i // cols + col = i % cols + x = col * (thumb_width + 10) + y = row * (thumb_height + 10) + + image = dataset.train[int(index)].image.content + + try: + # Resize to thumbnail + image.thumbnail( + (thumb_width, thumb_height - 30), Image.Resampling.LANCZOS + ) + + # Paste thumbnail into grid + grid_img.paste(image, (x, y)) + + # Add document ID label + draw = ImageDraw.Draw(grid_img) + try: + font = ImageFont.truetype( + "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12 + ) + except Exception: + font = ImageFont.load_default() + + text_y = y + image.height + 5 + draw.text((x, text_y), index, fill="black", font=font) + except Exception: + # Draw placeholder for failed PDF + draw = ImageDraw.Draw(grid_img) + draw.rectangle( + [x, y, x + thumb_width, y + thumb_height - 30], + outline="gray", + fill="lightgray", + ) + draw.text((x + 10, y + 10), f"Error loading\n{index}", fill="black") + + # Convert PIL Image to bytes for response + img_io = io.BytesIO() + grid_img.save(img_io, "PNG") + img_io.seek(0) + + return Response(img_io.getvalue(), mimetype="image/png") + + except Exception as e: + return f"Error creating grid: {str(e)}", 500 + + +# -------------------------- +# Cluster Analysis Functions +# -------------------------- +def create_cluster_visualization( + cluster_df: pd.DataFrame, + dataset_name: str, + cluster_labels: np.ndarray, +) -> go.Figure: + """Create a comprehensive visualization of cluster statistics with clickable clusters.""" + fig = make_subplots( + rows=4, + cols=1, + subplot_titles=( + "Cluster Sizes", + "Cluster Variances", + "Size vs Variance", + "Distribution", + ), + specs=[ + [{"type": "bar"}], + [{"type": "bar"}], + [{"type": "scatter"}], + [{"type": "histogram"}], + ], + ) + + # Prepare custom data for click events + cluster_indices = {} + for cluster_id in cluster_df["cluster_id"]: + indices = np.where(cluster_labels == 
cluster_id)[0].tolist() + cluster_indices[cluster_id] = indices + + # Plot 1: Cluster sizes (clickable) + fig.add_trace( + go.Bar( + x=cluster_df["cluster_id"], + y=cluster_df["size"], + name="Size", + customdata=[cluster_indices[cid] for cid in cluster_df["cluster_id"]], + hovertemplate="Cluster %{x}
<br>Size: %{y}<br>
Click to view images", + ), + row=1, + col=1, + ) + + # Plot 2: Cluster variances + fig.add_trace( + go.Bar( + x=cluster_df["cluster_id"], + y=cluster_df["variance"], + customdata=[cluster_indices[cid] for cid in cluster_df["cluster_id"]], + name="Variance", + ), + row=2, + col=1, + ) + + # Plot 3: Size vs Variance scatter (clickable) + fig.add_trace( + go.Scatter( + x=cluster_df["size"], + y=cluster_df["variance"], + mode="markers", + text=cluster_df["cluster_id"], + name="Clusters", + customdata=[cluster_indices[cid] for cid in cluster_df["cluster_id"]], + hovertemplate="Cluster %{text}
<br>Size: %{x}<br>Variance: %{y}<br>
Click to view images", + ), + row=3, + col=1, + ) + + # Plot 4: Size distribution + fig.add_trace( + go.Histogram(x=cluster_df["size"], name="Size Distribution"), row=4, col=1 + ) + + # Add JavaScript for click handling + fig.update_layout( + title_text=f"Cluster Analysis for {dataset_name}", showlegend=False, height=1200 + ) + + fig.update_layout( + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=40, r=40, t=40, b=40), + ) + _update_subplot_axes(fig) + + return fig + + +def _update_subplot_axes(fig: go.Figure) -> None: + """Update axes labels for all subplots.""" + fig.update_xaxes(title_text="Cluster ID", row=1, col=1) + fig.update_yaxes(title_text="Size", row=1, col=1) + fig.update_xaxes(title_text="Cluster ID", row=2, col=1) + fig.update_yaxes(title_text="Variance", row=2, col=1) + fig.update_xaxes(title_text="Size", row=3, col=1) + fig.update_yaxes(title_text="Variance", row=3, col=1) + fig.update_xaxes(title_text="Size", row=4, col=1) + fig.update_yaxes(title_text="Count", row=4, col=1) + + +# -------------------------- +# Globals +# -------------------------- +embedding_sources = [ + "paper_kernel=4", + "layout", + "image", + "text", + "combined", +] # example embedding models +intermediate_options = [100] +min_cluster_size_options = [5, 10] +dataset_options = [ + { + "label": name, + "value": name, + } + for name in os.listdir(ENV.CLUSTERS_DIR) +] +dataset_name = "tobacco3482" +dataset = load_dataset( + dataset_name=dataset_name, + split="train", +) + +seed = 42 +metric = "euclidean" +k_nn_n_neighbors = 5 +do_knn = False +labels = None +df = None + + +# -------------------------- +# Callbacks +# -------------------------- +@app.callback( + [ + Input("dataset-dropdown", "value"), + ], +) +def update_dataset(new_dataset_name): + global dataset, dataset_name + dataset_name = new_dataset_name + dataset = load_dataset( + dataset_name=dataset_name, + split="train", + ) + + +# -------------------------- +# Callbacks +# -------------------------- +@app.callback( + [ + Output("scatter", "figure"), + Output("cluster-analysis", "figure"), + ], + [ + Input("dataset-dropdown", "value"), + Input("intermediate-dropdown", "value"), + Input("min-cluster-size-dropdown", "value"), + Input("embedding-dropdown", "value"), + Input("method-dropdown", "value"), + # Input("cluster-size-bar", "clickData"), + ], +) +def update_scatter( + dataset_name, + intermediate_dims, + min_cluster_size, + embedding_src, + method, + # bar_click, +): + global labels, df + + output_dir = ENV.CLUSTERS_DIR / dataset_name / embedding_src + clusters_path = _get_clustering_output_path( + output_dir=output_dir, + intermediate_num_dims=intermediate_dims, + hdbscan_min_cluster_size=1 if method == "kmeans" else min_cluster_size, + hdbscan_metric=metric, + k_nn_n_neighbors=k_nn_n_neighbors, + method=method, + ) + # Create cache key from parameters + cluster_data = np.load(clusters_path, allow_pickle=True).item() + labels = cluster_data["cluster_labels"] + soft_clusters = cluster_data["soft_clusters"] + noise_mask = cluster_data.get("noise_mask", np.array([False] * len(labels))) + + cluster_stats = pd.read_csv( + clusters_path.parent / clusters_path.name.replace(".npy", "_stats.csv") + ) + + emb_2d_path = output_dir / f"reduced_embeddings_2_{metric}_{seed}.pkl" + # Try to load from cache + if not os.path.exists(emb_2d_path): + raise ValueError(f"2D embeddings not found: {emb_2d_path}") + with open(emb_2d_path, "rb") as f: + emb_2d = pickle.load(f) + + x, y = emb_2d[:, 0], emb_2d[:, 1] + df = pd.DataFrame( + { + 
"doc_id": np.arange(len(labels)), + "x": x, + "y": y, + "label": labels, + "prob": np.max(soft_clusters, axis=1), + "index": np.arange(len(labels)), + "noise_mask": noise_mask, + } + ) + + # Optional: filter by cluster if user clicked a bar + # if bar_click and "points" in bar_click: + # cluster_id = int(bar_click["points"][0]["x"]) + # df = df[df["label"] == cluster_id] + # # df_noise = df_noise[df_noise["label"] == cluster_id] + + # Create main (non-noise) scatter + fig = px.scatter( + df, + x="x", + y="y", + color="label", + hover_data={"index": True, "label": True, "doc_id": True}, + title=f"Embeddings ({embedding_src}) — Click a point to view its PDF", + ) + + fig.update_traces(marker=dict(size=7), customdata=df["index"]) + fig.update_layout( + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=20, r=20, t=40, b=20), + legend_title="Cluster", + ) + + # Cluster size bar chart + counts = pd.Series(labels).value_counts().sort_index() + df_counts = counts.reset_index() + df_counts.columns = ["Cluster", "Count"] + + # Create cluster analysis visualization + cluster_analysis_fig = create_cluster_visualization( + cluster_stats, dataset_name, labels + ) + + return fig, cluster_analysis_fig + + +@app.callback( + [ + Output("pdf-viewer", "src"), + Output("pdf-viewer", "hidden"), + Output("doc-info", "children"), + ], + [ + Input("scatter", "clickData"), + Input("cluster-analysis", "clickData"), + ], + prevent_initial_call=False, +) +def display_pdfs(scatter_click, cluster_click): + from dash import callback_context + + # Check which input triggered the callback + ctx = callback_context + if not ctx.triggered: + return "", True, "Click a point or cluster to view documents" + + trigger_id = ctx.triggered[0]["prop_id"].split(".")[0] + + # Handle cluster click - show multiple documents in grid + if trigger_id == "cluster-analysis" and cluster_click: + try: + # Get cluster indices from customdata + cluster_indices = cluster_click["points"][0]["customdata"] + if not cluster_indices: + return "", True, "No documents in this cluster" + + # Limit to first 12 documents for performance + display_indices = cluster_indices[:12] + cluster_id = labels[display_indices[0]] + + # Create comma-separated list of document IDs for the grid endpoint + doc_indices_list = [str(idx) for idx in display_indices] + doc_ids_str = ",".join(doc_indices_list) + + # Add timestamp to force browser refresh + import time + + timestamp = int(time.time() * 1000) # milliseconds + + return ( + "", + True, + html.Div( + [ + html.H5( + f"Cluster {cluster_id} ({len(cluster_indices)} documents)" + ), + html.Img( + src=f"/cluster_grid/{doc_ids_str}?t={timestamp}", + style={ + "width": "100%", + "max-height": "600px", + "object-fit": "contain", + "border": "1px solid #ddd", + "border-radius": "4px", + }, + ), + ] + ), + ) + except Exception as e: + print(f"Error displaying cluster: {e}") + return "", True, f"Error displaying cluster: {e}" + + # Handle single document click from scatter plot + if trigger_id == "scatter" and scatter_click: + try: + idx = int(scatter_click["points"][0]["pointIndex"]) + return ( + f"/image/{idx}", + False, + html.Div( + [ + html.P(f"Index: {idx}"), + html.P(f"Cluster: {labels[idx]}"), + html.P(f"DocID: {idx}"), + ] + ), + ) + except Exception as e: + return "", True, f"Error selecting point: {e}" + + return "", True, "Click a point or cluster to view documents" + + +# -------------------------- +# Layout +# -------------------------- +app.layout = html.Div( + [ + html.Div( + [ + # Dataset + dbc.Row( 
+ [ + dbc.Col( + [ + html.Label("Dataset", className="fw-bold"), + html.Div( + "Choose the dataset to analyze", + className="text-muted small mb-2", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id="dataset-dropdown", + options=dataset_options, + value="rvlcdip", + clearable=False, + ), + width=5, + ), + ] + ), + # Embedding Source + dbc.Row( + [ + dbc.Col( + [ + html.Label("Embedding source", className="fw-bold"), + html.Div( + "Which embedding model to use", + className="text-muted small mb-2", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id="embedding-dropdown", + options=[ + {"label": src, "value": src} + for src in embedding_sources + ], + value=embedding_sources[0], + clearable=False, + ), + width=5, + ), + ] + ), + # Intermediate dimensions + dbc.Row( + [ + dbc.Col( + [ + html.Label( + "Intermediate dimensions", className="fw-bold" + ), + html.Div( + "Projection size before clustering", + className="text-muted small mb-2", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id="intermediate-dropdown", + options=[ + {"label": str(d), "value": d} + for d in intermediate_options + ], + value=intermediate_options[0], + clearable=False, + ), + width=5, + ), + ] + ), + # Minimum cluster size + dbc.Row( + [ + dbc.Col( + [ + html.Label("Minimum cluster size", className="fw-bold"), + html.Div( + "Smallest allowed cluster size", + className="text-muted small mb-2", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id="min-cluster-size-dropdown", + options=[ + {"label": str(d), "value": d} + for d in min_cluster_size_options + ], + value=min_cluster_size_options[0], + clearable=False, + ), + width=5, + ), + ] + ), + # method + dbc.Row( + [ + dbc.Col( + [ + html.Label("Clustering method", className="fw-bold"), + html.Div( + "Which clustering algorithm to use", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id="method-dropdown", + options=[ + {"label": "k-means", "value": "kmeans"}, + {"label": "HDBSCAN", "value": "hdbscan"}, + ], + value="hdbscan", + clearable=False, + ), + width=5, + ), + ] + ), + ], + style={"gap": "15px"}, + ), + html.Div( + [ + dcc.Graph(id="scatter", style={"height": "700px"}), + # dcc.Graph(id="cluster-size-bar", style={"height": "400px"}), + dcc.Graph(id="cluster-analysis", style={"height": "800px"}), + ], + style={"width": "65%", "display": "inline-block", "verticalAlign": "top"}, + ), + html.Div( + [ + html.H4("Selected document"), + html.Div(id="doc-info", children="Click a point to open its PDF"), + html.Iframe( + id="pdf-viewer", + src="", + style={"width": "100%", "height": "700px"}, + hidden=True, + ), + ], + style={ + "width": "34%", + "display": "inline-block", + "paddingLeft": "10px", + "verticalAlign": "top", + }, + ), + ] +) + +# -------------------------- +if __name__ == "__main__": + app.run(debug=True, port=8055) diff --git a/docgenie/analyzation/clustering/webapp/_deprecated/visualize_metrics.py b/docgenie/analyzation/clustering/webapp/_deprecated/visualize_metrics.py new file mode 100755 index 0000000000000000000000000000000000000000..12bb3dc98fb5b14b25caaa689b89c59e0bdc515a --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/_deprecated/visualize_metrics.py @@ -0,0 +1,266 @@ +import dash +import dash_bootstrap_components as dbc +import pandas as pd +from dash import Input, Output, State, dash_table, dcc, html +from sklearn.preprocessing import MinMaxScaler + +from docgenie import ENV + +csv_fpath = ENV.CLUSTERS_DIR / "metrics-seed=42.csv" +df = pd.read_csv(csv_fpath) + +# Available metrics and their default 
optimization direction +METRICS = { + "num_clusters": "min", + "noise_percent": "min", + "connectivity__normalized_connectivity": "max", + "compactness__silhouette_score": "max", + "compactness__calinski_harabasz_score": "max", + "compactness__davies_bouldin_score": "min", + "balance__entropy": "max", + "balance__coefficient_of_variation": "min", + "balance__min-to-max-ratio": "max", + "balance__gini-coefficient": "min", +} + +METRIC_DESCRIPTIONS = { + "noise_percent (min)": "Proportion of points labeled as noise by HDBSCAN.", + "connectivity__normalized_connectivity (max)": "How connected clusters are (higher = more connected).", + "compactness__silhouette_score (max)": "Silhouette score (higher = better cluster separation).", + "compactness__calinski_harabasz_score (max)": "Calinski-Harabasz index (higher = better defined clusters).", + "compactness__davies_bouldin_score (min)": "Davies-Bouldin index (lower = better clustering).", + "balance__entropy (max)": "Entropy of cluster size distribution (higher = more balanced).", + "balance__coefficient_of_variation (min)": "Coefficient of variation of cluster sizes (lower = more balanced).", + "balance__min-to-max-ratio (max)": "Ratio of smallest to largest cluster size (higher = more balanced).", + "balance__gini-coefficient (min)": "Gini coefficient of cluster sizes (lower = more balanced).", +} + +# === Dash app === +app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) + + +@app.callback(Input("kernel-size-dropdown", "value")) +def update_direction_selectors(kernel_size): + global df + csv_fpath = ENV.CLUSTERS_DIR / "metrics-seed=42.csv" + df = pd.read_csv(csv_fpath) + print(f"Read {csv_fpath}") + + +app.layout = dbc.Container( + [ + dbc.Row( + [ + dbc.Col( + html.H2( + "Clustering Evaluation Dashboard", className="text-center my-3" + ) + ) + ] + ), + dbc.Row( + [ + dbc.Col( + [ + dbc.Alert( + [ + html.H5( + "How embeddings and clustering are created", + className="fw-bold", + ), + html.Ol( + [ + html.Li( + [ + "Embeddings are created akin to ", + html.A( + "Unsupervised Document and Template Clustering using Multimodal Embeddings", + href="https://arxiv.org/pdf/2506.12116", + target="_blank", + ), + ":", + html.Br(), + "Get mean of all text tokens, concatenate with image embedding. 
Image embedding is concatenation of all image patch tokens and then applying a kernel.", + ] + ), + html.Li( + "Embeddings are clustered in 2 stages: first HDBSCAN, the points labeled as noise (no cluster membership) are then assigned to identified clusters via k-NN" + ), + ] + ), + ], + color="light", + className="shadow-sm mb-4", + ) + ] + ) + ] + ), + dbc.Row( + [ + dbc.Col( + [ + dbc.Card( + [ + dbc.CardHeader("Metric Selection"), + dbc.CardBody( + [ + html.Label("Choose metrics to evaluate:"), + dcc.Checklist( + id="metric-checklist", + options=[ + {"label": m, "value": m} + for m in METRICS.keys() + ], + value=[], + className="mb-3", + ), + html.Div(id="direction-selectors"), + dbc.Button( + "Compute Best Results", + id="compute-btn", + color="primary", + className="mt-3", + ), + ] + ), + ], + className="mb-4", + ) + ], + width=4, + ), + dbc.Col( + [ + dbc.Card( + [ + dbc.CardHeader("Top Results"), + dbc.CardBody( + [ + dash_table.DataTable( + id="results-table", + page_size=10, + style_table={"overflowX": "auto"}, + style_cell={ + "textAlign": "left", + "padding": "8px", + "font-family": "monospace", + }, + style_header={ + "fontWeight": "bold", + "backgroundColor": "#f8f9fa", + }, + style_data_conditional=[ + { + "if": {"state": "active"}, + "backgroundColor": "#e9ecef", + "border": "1px solid #adb5bd", + }, + ], + ) + ] + ), + ] + ) + ], + width=8, + ), + ] + ), + ], + fluid=True, +) + + +@app.callback( + Output("direction-selectors", "children"), Input("metric-checklist", "value") +) +def update_direction_selectors(selected_metrics): + """Show dropdowns for choosing min/max and a description for each selected metric.""" + controls = [] + for m in selected_metrics: + description = METRIC_DESCRIPTIONS.get(m, "") + controls.append( + dbc.Card( + [ + dbc.CardBody( + [ + dbc.Row( + [ + dbc.Col( + [ + html.Label(m, className="fw-bold"), + html.Div( + description, + className="text-muted small mb-2", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id={ + "type": "direction-dropdown", + "metric": m, + }, + options=[ + {"label": "Maximize", "value": "max"}, + {"label": "Minimize", "value": "min"}, + ], + value=METRICS[m], + clearable=False, + ), + width=5, + ), + ] + ) + ] + ) + ], + className="mb-2", + ) + ) + return controls + + +@app.callback( + Output("results-table", "data"), + Output("results-table", "columns"), + Input("compute-btn", "n_clicks"), + State("metric-checklist", "value"), + State({"type": "direction-dropdown", "metric": dash.ALL}, "value"), + State({"type": "direction-dropdown", "metric": dash.ALL}, "id"), +) +def compute_best_results(n_clicks, selected_metrics, directions, ids): + if n_clicks == 0 or not selected_metrics: + return [], [] + + # Map metrics to directions + metric_directions = {i["metric"]: d for i, d in zip(ids, directions)} + + # Copy for normalization but keep original df for output + df_norm = df.copy() + + for col in selected_metrics: + scaler = MinMaxScaler() + values = df[[col]].values + normed = scaler.fit_transform(values) + if metric_directions[col] == "min": + normed = 1 - normed # flip so higher is better + df_norm[col] = normed + + df_norm["final_score"] = df_norm[selected_metrics].mean(axis=1) + + # Select top rows + best_idx = df_norm.sort_values("final_score", ascending=False).index + best = df.loc[best_idx].copy() + best["final_score"] = df_norm.loc[best_idx, "final_score"] + + # Convert to table + columns = [{"name": c, "id": c} for c in best.columns] + data = best.to_dict("records") + return data, columns + + +if __name__ == 
"__main__": + app.run(debug=True, port=8052) diff --git a/docgenie/analyzation/clustering/webapp/app.py b/docgenie/analyzation/clustering/webapp/app.py new file mode 100755 index 0000000000000000000000000000000000000000..fb49b1249f3d71db94ff8ab8e15c1ec02302dff2 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/app.py @@ -0,0 +1,37 @@ +import dash +import dash_bootstrap_components as dbc + +from docgenie.analyzation.clustering.webapp.config import settings +from docgenie.analyzation.clustering.webapp.components import create_app_layout +from docgenie.analyzation.clustering.webapp.callbacks import register_callbacks +from docgenie.analyzation.clustering.webapp.server_routes import setup_server_routes +from docgenie.analyzation.clustering.webapp.data_manager import data_manager + + +def create_app(): + """Create and configure the Dash application.""" + app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP]) + + # Setup Flask server routes + setup_server_routes(app.server) + + # Load initial dataset + data_manager.load_dataset(settings.default_dataset) + + # Set layout + app.layout = create_app_layout() + + # Register callbacks + register_callbacks(app) + + return app + + +def main(): + """Main entry point.""" + app = create_app() + app.run(debug=settings.debug, port=settings.port) + + +if __name__ == "__main__": + main() diff --git a/docgenie/analyzation/clustering/webapp/callbacks.py b/docgenie/analyzation/clustering/webapp/callbacks.py new file mode 100755 index 0000000000000000000000000000000000000000..ee2200eb0b15fd0848f4eede43d78dbc32fe32fa --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/callbacks.py @@ -0,0 +1,457 @@ +import time +import numpy as np +from dash import Input, Output, html, callback_context, dcc +import dash_bootstrap_components as dbc +import dash +import dash_bootstrap_components as dbc +import pandas as pd +from dash import Input, Output, State, dash_table, dcc, html +import plotly.graph_objects as go +from .utils.save_utils import get_graph_save_path, save_plotly_figure +from sklearn.preprocessing import MinMaxScaler +from .data_manager import data_manager +from .visualizations import ( + create_scatter_plot, + create_cluster_analysis_plot, + generate_individual_cluster_plots, +) +from .config import settings + + +def register_callbacks(app): + """Register all Dash callbacks.""" + + @app.callback( + [Input("dataset-dropdown", "value")], + ) + def update_dataset(dataset_name: str): + """Update global dataset when dropdown changes.""" + data_manager.load_dataset(dataset_name) + + @app.callback( + [ + Output("scatter", "figure"), + Output("cluster-analysis", "figure"), + ], + [ + Input("dataset-dropdown", "value"), + Input("intermediate-dropdown", "value"), + Input("min-cluster-size-dropdown", "value"), + Input("embedding-dropdown", "value"), + Input("method-dropdown", "value"), + ], + ) + def update_visualizations( + dataset_name: str, + intermediate_dims: int, + min_cluster_size: int, + embedding_src: str, + method: str, + ): + """Update scatter plot and cluster analysis when parameters change.""" + # Get cluster data + cluster_data = data_manager.get_cluster_data( + dataset_name, embedding_src, intermediate_dims, min_cluster_size, method + ) + + labels = cluster_data["cluster_data"]["cluster_labels"] + soft_clusters = cluster_data["cluster_data"]["soft_clusters"] + noise_mask = cluster_data["cluster_data"].get( + "noise_mask", np.array([False] * len(labels)) + ) + cluster_stats = cluster_data["cluster_stats"] + emb_2d = 
cluster_data["emb_2d"] + sample_ids = cluster_data["sample_ids"] + + # Create scatter plot dataframe + df = data_manager.create_scatter_dataframe( + labels, emb_2d, soft_clusters, sample_ids, noise_mask + ) + + # Create visualizations + scatter_fig = create_scatter_plot( + df, + embedding_src, + dataset_name, + min_cluster_size, + len(set(cluster_data["cluster_data"]["cluster_labels"])), + ) + cluster_analysis_fig = create_cluster_analysis_plot( + cluster_stats, dataset_name, labels + ) + + return scatter_fig, cluster_analysis_fig + + @app.callback( + [ + Output("pdf-viewer", "src"), + Output("pdf-viewer", "hidden"), + Output("doc-info", "children"), + ], + [ + Input("scatter", "clickData"), + Input("cluster-analysis", "clickData"), + Input("dataset-dropdown", "value"), + Input("intermediate-dropdown", "value"), + Input("min-cluster-size-dropdown", "value"), + Input("embedding-dropdown", "value"), + Input("method-dropdown", "value"), + ], + prevent_initial_call=False, + ) + def display_documents( + scatter_click: dict, + cluster_click: dict, + dataset_name: str, + intermediate_dims: int, + min_cluster_size: int, + embedding_src: str, + method: str, + ): + """Handle document display for both single and cluster clicks.""" + ctx = callback_context + if not ctx.triggered: + return "", True, "Click a point or cluster to view documents" + + trigger_id = ctx.triggered[0]["prop_id"].split(".")[0] + + cluster_data = data_manager.get_cluster_data( + dataset_name, embedding_src, intermediate_dims, min_cluster_size, method + ) + labels = cluster_data["cluster_data"]["cluster_labels"] + sample_ids = cluster_data["sample_ids"] + # Handle cluster click - show grid of documents + if trigger_id == "cluster-analysis" and cluster_click: + return _handle_cluster_click(cluster_click, labels, sample_ids) + + # Handle single document click + if trigger_id == "scatter" and scatter_click: + return _handle_scatter_click(scatter_click, labels, sample_ids) + + return "", True, "Click a point or cluster to view documents" + + @app.callback( + Output("direction-selectors", "children"), Input("metric-checklist", "value") + ) + def update_direction_selectors(selected_metrics: list): + """Show dropdowns for choosing min/max and a description for each selected metric.""" + controls = [] + for m in selected_metrics: + description = settings.metrics_list[m]["description"] + controls.append( + dbc.Card( + [ + dbc.CardBody( + [ + dbc.Row( + [ + dbc.Col( + [ + html.Label(m, className="fw-bold"), + html.Div( + description, + className="text-muted small mb-2", + ), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id={ + "type": "direction-dropdown", + "metric": m, + }, + options=[ + { + "label": "Maximize", + "value": "max", + }, + { + "label": "Minimize", + "value": "min", + }, + ], + value=settings.metrics_list[m][ + "direction" + ], + clearable=False, + ), + width=5, + ), + ] + ) + ] + ) + ], + className="mb-2", + ) + ) + return controls + + @app.callback( + Output("results-table", "data"), + Output("results-table", "columns"), + Input("compute-btn", "n_clicks"), + State("metric-checklist", "value"), + State({"type": "direction-dropdown", "metric": dash.ALL}, "value"), + State({"type": "direction-dropdown", "metric": dash.ALL}, "id"), + ) + def compute_best_results(n_clicks, selected_metrics, directions, ids): + if n_clicks == 0 or not selected_metrics: + return [], [] + + # Map metrics to directions + metric_directions = {i["metric"]: d for i, d in zip(ids, directions)} + + # Copy for normalization but keep original df for 
output + df = data_manager.metrics.copy() + df_norm = df.copy() + + for col in selected_metrics: + scaler = MinMaxScaler() + values = df[[col]].values + normed = scaler.fit_transform(values) + if metric_directions[col] == "min": + normed = 1 - normed # flip so higher is better + df_norm[col] = normed + + df_norm["final_score"] = df_norm[selected_metrics].mean(axis=1) + + # Select top rows + best_idx = df_norm.sort_values("final_score", ascending=False).index + best = df.loc[best_idx].copy() + best["final_score"] = df_norm.loc[best_idx, "final_score"] + + # Convert to table + columns = [{"name": c, "id": c} for c in best.columns] + data = best.to_dict("records") + return data, columns + + @app.callback( + Output("embedding-overview-table", "data"), + Output("embedding-overview-table", "columns"), + Input("dataset-dropdown", "value"), + Input("min-cluster-size-dropdown", "value"), + ) + def update_embedding_overview(dataset_name, min_cluster_size): + """Compute summary metrics (num clusters, silhouette, entropy) for all embeddings.""" + # Placeholder for results + rows = [] + + for embedding_src in settings.embedding_sources: + try: + # Load cluster data for each embedding + cluster_data = data_manager.get_cluster_data( + dataset_name, + embedding_src, + settings.default_intermediate, + min_cluster_size, + settings.default_method, + ) + + labels = cluster_data["cluster_data"]["cluster_labels"] + + # Number of clusters (excluding noise if labeled as -1) + valid_labels = labels[labels >= 0] + n_clusters = len(np.unique(valid_labels)) + + # Silhouette score (skip if only 1 cluster) + if n_clusters > 1: + from sklearn.metrics import silhouette_score + + emb = cluster_data["emb_2d"] # or full embeddings if available + sil = silhouette_score(emb, labels) + else: + sil = np.nan + + # Entropy of cluster distribution + from scipy.stats import entropy + + cluster_sizes = np.bincount(valid_labels) + probs = cluster_sizes / cluster_sizes.sum() + ent = entropy(probs) + + rows.append( + dict( + embedding=embedding_src, + n_clusters=n_clusters, + silhouette=round(sil, 3) if not np.isnan(sil) else "—", + entropy=round(ent, 3), + ) + ) + + except Exception as e: + rows.append( + dict( + embedding=embedding_src, + n_clusters="Error", + silhouette="Error", + entropy=str(e), + ) + ) + + df = pd.DataFrame(rows) + columns = [{"name": c.replace("_", " ").title(), "id": c} for c in df.columns] + return df.to_dict("records"), columns + + @app.callback( + Output("save-feedback", "children"), + Input("save-all-graphs-btn", "n_clicks"), + [ + State("scatter", "figure"), + State("cluster-analysis", "figure"), + State("dataset-dropdown", "value"), + State("embedding-dropdown", "value"), + State("min-cluster-size-dropdown", "value"), + State("intermediate-dropdown", "value"), + State("method-dropdown", "value"), + ], + prevent_initial_call=True, + ) + def save_all_graphs( + n_clicks, + scatter_fig, + cluster_fig, + dataset_name, + embedding_src, + min_cluster_size, + intermediate_dims, + method, + ): + """Save scatter plot and the cluster-analysis subplots separately.""" + if not dataset_name or not embedding_src: + return dbc.Alert("Missing dataset or embedding source.", color="danger") + + saved_paths = [] + errors = [] + ext = "pdf" + + cluster_data = data_manager.get_cluster_data( + dataset_name, embedding_src, intermediate_dims, min_cluster_size, method + ) + + if scatter_fig: + try: + nclusters = len(set(cluster_data["cluster_data"]["cluster_labels"])) + fig = go.Figure(scatter_fig) + path = get_graph_save_path( + 
dataset_name, + "scatter", + embedding_src, + min_cluster_size, + ext=ext, + nclusters=nclusters, + ) + saved = save_plotly_figure(fig, path, fmt=ext) + saved_paths.append(saved) + except Exception as e: + errors.append(f"Scatter: {e}") + + try: + cluster_data = data_manager.get_cluster_data( + dataset_name, embedding_src, intermediate_dims, min_cluster_size, method + ) + nclusters = len(set(cluster_data["cluster_data"]["cluster_labels"])) + cluster_stats = cluster_data["cluster_stats"] + + """This function is generating cluster analysis figurs sepratley + because in the webapp there is on plot having further subplots + if I save that which I tried it looks odd like text is overlapping + so I created separate plots for it and save that separately.""" + per_plot_figs = generate_individual_cluster_plots( + cluster_stats, dataset_name=dataset_name + ) + + for plot_name, fig in per_plot_figs.items(): + try: + path = get_graph_save_path( + dataset_name, + plot_name, + embedding_src, + min_cluster_size, + ext=ext, + ) + saved = save_plotly_figure(fig, path, fmt=ext) + saved_paths.append(saved) + except Exception as e_plot: + errors.append(f"{plot_name}: {e_plot}") + except Exception as e: + errors.append(f"Cluster subplots generation failed: {e}") + + if errors and saved_paths: + msg = "
".join(errors) + return dbc.Alert( + f"Some graphs saved, some failed:
{msg}", + color="warning", + dismissable=True, + ) + elif errors and not saved_paths: + msg = "
".join(errors) + return dbc.Alert( + f"Failed to save graphs:
{msg}", color="danger", dismissable=True + ) + else: + return dbc.Alert( + f"Saved {len(saved_paths)} files:
" + "
".join(saved_paths), + color="success", + dismissable=True, + ) + + +def _handle_cluster_click(cluster_click: dict, labels, sample_ids): + """Handle cluster visualization clicks.""" + try: + cluster_indices = cluster_click["points"][0]["customdata"] + if not cluster_indices: + return "", True, "No documents in this cluster" + + # Limit documents for performance + display_indices = cluster_indices[:12] + + # Create grid URL + doc_ids_str = ",".join(str(sample_ids[idx]) for idx in display_indices) + timestamp = int(time.time() * 1000) + cluster_id = labels[display_indices[0]] + + return ( + "", + True, + html.Div( + [ + html.H5(f"Cluster {cluster_id} ({len(cluster_indices)} documents)"), + html.Img( + src=f"/cluster_grid/{doc_ids_str}?t={timestamp}", + style={ + "width": "100%", + "max-height": "600px", + "object-fit": "contain", + "border": "1px solid #ddd", + "border-radius": "4px", + }, + ), + ] + ), + ) + except Exception as e: + print(f"Error displaying cluster: {e}") + return "", True, f"Error displaying cluster: {e}" + + +def _handle_scatter_click(scatter_click, labels, sample_ids): + """Handle scatter plot clicks.""" + try: + idx = int(scatter_click["points"][0]["pointIndex"]) + sample_id = str(sample_ids[idx]) + return ( + f"/image/{sample_id}", + False, + html.Div( + [ + html.P(f"Index: {idx}"), + html.P(f"Cluster: {labels[idx]}"), + html.P(f"DocID: {idx}"), + ] + ), + ) + except Exception as e: + return "", True, f"Error selecting point: {e}" diff --git a/docgenie/analyzation/clustering/webapp/components.py b/docgenie/analyzation/clustering/webapp/components.py new file mode 100755 index 0000000000000000000000000000000000000000..25fa43932d11274c7ab74f78ff22eb265c80b23d --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/components.py @@ -0,0 +1,313 @@ +import dash_bootstrap_components as dbc +from dash import dcc, html + +from .config import settings +from dash import dash_table, dcc, html + + +def create_control_panel(): + """Create the main control panel with all dropdowns.""" + return html.Div( + [ + _create_dropdown_row( + "Dataset", + "Choose the dataset to analyze", + "dataset-dropdown", + [{"label": ds, "value": ds} for ds in settings.dataset_options], + settings.default_dataset, + ), + _create_dropdown_row( + "Embedding source", + "Which embedding model to use", + "embedding-dropdown", + [{"label": src, "value": src} for src in settings.embedding_sources], + settings.default_embedding, + ), + _create_dropdown_row( + "Intermediate dimensions", + "Projection size before clustering", + "intermediate-dropdown", + [{"label": str(d), "value": d} for d in settings.intermediate_options], + settings.default_intermediate, + ), + _create_dropdown_row( + "Minimum cluster size", + "Smallest allowed cluster size", + "min-cluster-size-dropdown", + [ + {"label": str(d), "value": d} + for d in settings.min_cluster_size_options + ], + settings.default_min_cluster_size, + ), + _create_dropdown_row( + "Clustering method", + "Which clustering algorithm to use", + "method-dropdown", + [ + {"label": "HDBSCAN", "value": "hdbscan"}, + ], + settings.default_method, + ), + ], + style={"gap": "15px"}, + ) + + +def _create_dropdown_row(label, description, dropdown_id, options, value): + """Create a standardized dropdown row.""" + return dbc.Row( + [ + dbc.Col( + [ + html.Label(label, className="fw-bold"), + html.Div(description, className="text-muted small mb-2"), + ], + width=7, + ), + dbc.Col( + dcc.Dropdown( + id=dropdown_id, + options=options, + value=value, + clearable=False, + ), + 
width=5, + ), + ] + ) + + +def create_visualization_panel(): + """Create main visualization panel with Save Graphs button.""" + return html.Div( + [ + # Top Row: Control Buttons + dbc.Row( + [ + dbc.Col( + dbc.Button( + "Save Graphs", + id="save-all-graphs-btn", + color="primary", + className="me-2", + style={"width": "100%"}, + ), + width=3, + ), + dbc.Col(html.Div(id="save-feedback", style={"marginTop": "5px"}), width=9), + ], + className="mb-3", + ), + + # Graphs + dcc.Graph(id="scatter", style={"height": "700px"}), + dcc.Graph(id="cluster-analysis", style={"height": "800px"}), + ], + style={"width": "65%", "display": "inline-block", "verticalAlign": "top"}, + ) + + + +def create_document_viewer(): + """Create the document viewer panel.""" + return html.Div( + [ + html.H4("Selected document"), + html.Div(id="doc-info", children="Click a point to open its document"), + html.Iframe( + id="pdf-viewer", + src="", + style={"width": "100%", "height": "700px"}, + hidden=True, + ), + ], + style={ + "width": "34%", + "display": "inline-block", + "paddingLeft": "10px", + "verticalAlign": "top", + }, + ) + + +def create_metrics_viewer(): + """Create the metrics evaluation panel.""" + return html.Div( + [ + dbc.Container( + [ + dbc.Row( + [ + dbc.Col( + html.H2( + "Clustering Evaluation Dashboard", + className="text-center my-3", + ) + ) + ] + ), + dbc.Row( + [ + dbc.Col( + [ + dbc.Alert( + [ + html.H5( + "How embeddings and clustering are created", + className="fw-bold", + ), + html.Ol( + [ + html.Li( + [ + "Embeddings are created akin to ", + html.A( + "Unsupervised Document and Template Clustering using Multimodal Embeddings", + href="https://arxiv.org/pdf/2506.12116", + target="_blank", + ), + ":", + html.Br(), + "Get mean of all text tokens, concatenate with image embedding. 
Image embedding is concatenation of all image patch tokens and then applying a kernel.", + ] + ), + html.Li( + "Embeddings are clustered in 2 stages: first HDBSCAN, the points labeled as noise (no cluster membership) are then assigned to identified clusters via k-NN" + ), + ] + ), + ], + color="light", + className="shadow-sm mb-4", + ) + ] + ) + ] + ), + dbc.Row( + [ + dbc.Col( + [ + dbc.Card( + [ + dbc.CardHeader("Metric Selection"), + dbc.CardBody( + [ + html.Label( + "Choose metrics to evaluate:" + ), + dcc.Checklist( + id="metric-checklist", + options=[ + {"label": m, "value": m} + for m in settings.metrics_list.keys() + ], + value=[], + className="mb-3", + ), + html.Div(id="direction-selectors"), + dbc.Button( + "Compute Best Results", + id="compute-btn", + color="primary", + className="mt-3", + ), + ] + ), + ], + className="mb-4", + ) + ], + width=4, + ), + dbc.Col( + [ + dbc.Card( + [ + dbc.CardHeader("Top Results"), + dbc.CardBody( + [ + dash_table.DataTable( + id="results-table", + page_size=10, + style_table={ + "overflowX": "auto" + }, + style_cell={ + "textAlign": "left", + "padding": "8px", + "font_family": "monospace", + }, + style_header={ + "fontWeight": "bold", + "backgroundColor": "#f8f9fa", + }, + style_data_conditional=[ + { + "if": { + "state": "active" + }, + "backgroundColor": "#e9ecef", + "border": "1px solid #adb5bd", + }, + ], + ) + ] + ), + ] + ) + ], + width=8, + ), + ] + ), + ], + fluid=True, + ) + ] + ) + + +def create_overview_panel(): + """Overview table: shows per-embedding metrics for selected dataset.""" + return html.Div( + [ + html.H4("Embedding Overview", className="mt-4"), + html.Div( + [ + html.P( + "Shows summary metrics for each embedding method on the selected dataset:", + className="text-muted small", + ), + dash_table.DataTable( + id="embedding-overview-table", + style_table={"overflowX": "auto"}, + style_cell={ + "textAlign": "left", + "padding": "8px", + "font_family": "monospace", + }, + style_header={ + "fontWeight": "bold", + "backgroundColor": "#f8f9fa", + }, + ), + ] + ), + ], + style={"marginTop": "30px"}, + ) + + +def create_app_layout(): + """Create the complete app layout.""" + return html.Div( + [ + create_control_panel(), + create_overview_panel(), + create_visualization_panel(), + create_document_viewer(), + create_metrics_viewer(), + ] + ) diff --git a/docgenie/analyzation/clustering/webapp/config.py b/docgenie/analyzation/clustering/webapp/config.py new file mode 100755 index 0000000000000000000000000000000000000000..da90e5cfc2dd8da1949aa5705767c40fdd34a880 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/config.py @@ -0,0 +1,102 @@ +import os +from typing import List +from pydantic_settings import BaseSettings +from docgenie import ENV +from pathlib import Path + + +class AppSettings(BaseSettings): + # App Config + debug: bool = True + port: int = 8055 + graphs_base_dir: Path = ENV.CLUSTER_PLOTS + external_stylesheets: List[str] = [ + "https://cdn.jsdelivr.net/npm/bootstrap@5.1.3/dist/css/bootstrap.min.css" + ] + + # Clustering Options + embedding_sources: List[str] = [ + "paper_kernel=4", + "layout", + "image", + "text", + "combined", + ] + intermediate_options: List[int] = [100] + min_cluster_size_options: List[int] = [5, 10] + dataset_options: List[str] = sorted([d for d in os.listdir(ENV.CLUSTERS_DIR)]) + + # Default Values + default_dataset: str = "tobacco3482" + default_embedding: str = "paper_kernel=4" + default_intermediate: int = 100 + default_min_cluster_size: int = 5 + default_method: str = "hdbscan" + + # 
Clustering Params + seed: int = 42 + metric: str = "euclidean" + k_nn_n_neighbors: int = 5 + do_knn: bool = False + + # Grid Config + max_images: int = 12 + thumb_width: int = 200 + thumb_height: int = 280 + spacing: int = 10 + max_cols: int = 4 + + # Metric Descriptions and Optimization Direction + @property + def metrics_list(self) -> dict: + return { + "num_clusters": { + "direction": "min", + "description": "Total number of clusters formed (excluding noise).", + }, + "noise_percent": { + "direction": "min", + "description": "Proportion of points labeled as noise by HDBSCAN.", + }, + "connectivity__normalized_connectivity": { + "direction": "max", + "description": "How connected clusters are (higher = more connected).", + }, + "compactness__silhouette_score": { + "direction": "max", + "description": "Silhouette score (higher = better cluster separation).", + }, + "compactness__calinski_harabasz_score": { + "direction": "max", + "description": "Calinski-Harabasz index (higher = better defined clusters).", + }, + "compactness__davies_bouldin_score": { + "direction": "min", + "description": "Davies-Bouldin index (lower = better clustering).", + }, + "balance__entropy": { + "direction": "max", + "description": "Entropy of cluster size distribution (higher = more balanced).", + }, + "balance__coefficient_of_variation": { + "direction": "min", + "description": "Coefficient of variation of cluster sizes (lower = more balanced).", + }, + "balance__min-to-max-ratio": { + "direction": "max", + "description": "Ratio of smallest to largest cluster size (higher = more balanced).", + }, + "balance__gini-coefficient": { + "direction": "min", + "description": "Gini coefficient of cluster sizes (lower = more balanced).", + }, + } + + # Load metrics CSV + @property + def metrics_csv_path(self) -> str: + return str(ENV.CLUSTERS_DIR / f"metrics-seed={self.seed}.csv") + + +# Initialize settings +settings = AppSettings() diff --git a/docgenie/analyzation/clustering/webapp/data_manager.py b/docgenie/analyzation/clustering/webapp/data_manager.py new file mode 100755 index 0000000000000000000000000000000000000000..f222f1bf1698ac588b2b8c06121ffbd581faf5a0 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/data_manager.py @@ -0,0 +1,123 @@ +import os +import pickle +import numpy as np +import pandas as pd +from typing import Dict, Tuple, Optional + +from docgenie import ENV +from docgenie.analyzation.clustering.core._embeddings import ( + _load_sample_ids_from_embeddings, +) +from docgenie.data import load_dataset +from docgenie.analyzation.clustering.core._utilities import _get_clustering_output_path +from .config import settings +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +class DataManager: + """Manages dataset and clustering data loading.""" + + def __init__(self): + self.dataset = None + self.dataset_name = None + self.metrics = None + self.cluster_data_cache = {} + + def load_dataset(self, dataset_name: str): + """Load dataset and update internal state.""" + if self.dataset_name != dataset_name: + self.dataset = load_dataset(dataset_name=dataset_name, split="train") + self.dataset_name = dataset_name + self.metrics = pd.read_csv(settings.metrics_csv_path) + self.metrics = self.metrics[self.metrics["dataset_name"] == dataset_name] + + def get_cluster_data( + self, + dataset_name: str, + embedding_src: str, + intermediate_dims: int, + min_cluster_size: int, + method: str, + ) -> Dict: + """Load clustering results with caching.""" + cache_key = ( + dataset_name, + 
embedding_src, + intermediate_dims, + min_cluster_size, + method, + ) + + if cache_key not in self.cluster_data_cache: + output_dir = ENV.CLUSTERS_DIR / dataset_name / embedding_src + sample_ids = _load_sample_ids_from_embeddings( + file_path=ENV.EMBEDDINGS_DIR / dataset_name / f"{embedding_src}.h5" + ) + logger.info("Loading clustering results from %s", output_dir) + clusters_path = _get_clustering_output_path( + output_dir=output_dir, + intermediate_num_dims=intermediate_dims, + hdbscan_min_cluster_size=1 if method == "kmeans" else min_cluster_size, + hdbscan_metric=settings.metric, + k_nn_n_neighbors=settings.k_nn_n_neighbors, + method=method, + ) + + # Load cluster data + cluster_data = np.load(clusters_path, allow_pickle=True).item() + + # Load cluster statistics + stats_path = clusters_path.parent / clusters_path.name.replace( + ".npy", "_stats.csv" + ) + cluster_stats = pd.read_csv(stats_path) + + # Load 2D embeddings + emb_2d_path = ( + output_dir + / f"reduced_embeddings_2_{settings.metric}_{settings.seed}.pkl" + ) + if not os.path.exists(emb_2d_path): + raise ValueError(f"2D embeddings not found: {emb_2d_path}") + + with open(emb_2d_path, "rb") as f: + emb_2d = pickle.load(f) + + self.cluster_data_cache[cache_key] = { + "sample_ids": sample_ids, + "cluster_data": cluster_data, + "cluster_stats": cluster_stats, + "emb_2d": emb_2d, + } + + return self.cluster_data_cache[cache_key] + + def create_scatter_dataframe( + self, + labels: np.ndarray, + emb_2d: np.ndarray, + soft_clusters: np.ndarray, + sample_ids: np.ndarray, + noise_mask: Optional[np.ndarray] = None, + ) -> pd.DataFrame: + """Create DataFrame for scatter plot visualization.""" + if noise_mask is None: + noise_mask = np.array([False] * len(labels)) + + x, y = emb_2d[:, 0], emb_2d[:, 1] + return pd.DataFrame( + { + "doc_id": sample_ids, + "x": x, + "y": y, + "label": labels, + "prob": np.max(soft_clusters, axis=1), + "index": np.arange(len(labels)), + "noise_mask": noise_mask, + } + ) + + +# Global instance +data_manager = DataManager() diff --git a/docgenie/analyzation/clustering/webapp/server_routes.py b/docgenie/analyzation/clustering/webapp/server_routes.py new file mode 100755 index 0000000000000000000000000000000000000000..b7d5540d1592e3ceb4d7c9cca63ed5a3abfbba8e --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/server_routes.py @@ -0,0 +1,118 @@ +import io +from flask import Response +from PIL import Image, ImageDraw, ImageFont + +from .data_manager import data_manager +from .config import settings +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +def setup_server_routes(server): + """Setup Flask server routes.""" + + @server.route("/image/") + def serve_image(sample_id: str): + """Serve individual document image.""" + try: + sample = data_manager.dataset.train.get_by_id(sample_id) + assert sample.sample_id == sample_id, ( + f"Mismatched sample ID, found {sample.sample_id}, expected {sample_id}" + ) + image = sample.image.content + img_io = io.BytesIO() + image.save(img_io, "PNG") + img_io.seek(0) + return Response(img_io.getvalue(), mimetype="image/png") + except (ValueError, IndexError) as e: + return f"Invalid sample ID: {str(e)}", 404 + except Exception as e: + return f"Error serving image: {str(e)}", 500 + + @server.route("/cluster_grid/") + def serve_cluster_grid(sample_ids): + """Create and serve a grid image from multiple document images.""" + try: + sample_ids = _parse_ids(sample_ids) + grid_img = _create_grid_image(sample_ids) + return _image_response(grid_img) + except 
Exception as e: + return f"Error creating grid: {str(e)}", 500 + + +def _parse_ids(sample_ids): + """Parse and validate sample_ids list.""" + return sample_ids.split(",")[: settings.max_images] + + +def _create_grid_image(sample_ids): + """Create a grid image from document sample_ids.""" + cols = min(settings.max_cols, len(sample_ids)) + rows = (len(sample_ids) + cols - 1) // cols + + grid_width = cols * settings.thumb_width + (cols - 1) * settings.spacing + grid_height = rows * settings.thumb_height + (rows - 1) * settings.spacing + + grid_img = Image.new("RGB", (grid_width, grid_height), "white") + + for i, sample_id in enumerate(sample_ids): + _add_thumbnail_to_grid(grid_img, sample_id, i, cols) + + return grid_img + + +def _add_thumbnail_to_grid(grid_img, sample_id, position, cols): + """Add a single thumbnail to the grid.""" + row = position // cols + col = position % cols + x = col * (settings.thumb_width + settings.spacing) + y = row * (settings.thumb_height + settings.spacing) + + try: + sample = data_manager.dataset.train.get_by_id(sample_id) + assert sample.sample_id == sample_id, ( + f"Mismatched sample ID, found {sample.sample_id}, expected {sample_id}" + ) + image = sample.image.content + image.thumbnail( + (settings.thumb_width, settings.thumb_height - 30), Image.Resampling.LANCZOS + ) + grid_img.paste(image, (x, y)) + _add_label(grid_img, sample_id, x, y + image.height + 5) + except Exception as e: + logger.exception(f"Error loading image for sample ID {sample_id}") + _draw_error_placeholder(grid_img, sample_id, x, y) + + +def _add_label(grid_img, text, x, y): + """Add text label to the grid.""" + draw = ImageDraw.Draw(grid_img) + font = _get_font() + draw.text((x, y), text, fill="black", font=font) + + +def _get_font(): + """Get font for text rendering.""" + try: + return ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12) + except Exception: + return ImageFont.load_default() + + +def _draw_error_placeholder(grid_img, index, x, y): + """Draw error placeholder when image fails to load.""" + draw = ImageDraw.Draw(grid_img) + draw.rectangle( + [x, y, x + settings.thumb_width, y + settings.thumb_height - 30], + outline="gray", + fill="lightgray", + ) + draw.text((x + 10, y + 10), f"Error loading\n{index}", fill="black") + + +def _image_response(image): + """Convert PIL Image to Flask Response.""" + img_io = io.BytesIO() + image.save(img_io, "PNG") + img_io.seek(0) + return Response(img_io.getvalue(), mimetype="image/png") diff --git a/docgenie/analyzation/clustering/webapp/utils/save_utils.py b/docgenie/analyzation/clustering/webapp/utils/save_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..3d79e6f9d8ee18f0883b0a24073e04dd17e62644 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/utils/save_utils.py @@ -0,0 +1,50 @@ +import os +from pathlib import Path +import plotly.graph_objects as go +import plotly.io as pio +from datetime import datetime + +from ..config import settings + + +def ensure_dir(path: Path): + """Ensure directory exists.""" + path.mkdir(parents=True, exist_ok=True) + + +def get_graph_save_path( + dataset_name: str, + graph_type: str, + embedding_src: str, + min_cluster_size: int, + ext: str = "png", + nclusters: int | None = None, +) -> Path: + """ + Construct structured path to save graph image. 
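    An illustrative call (the argument values here are hypothetical):
        get_graph_save_path("tobacco3482", "scatter", "paper_kernel=4", 5, ext="pdf", nclusters=12)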
+ + Example: + graphs/tobacco3482/scatter/tobacco3482_scatter_paper_kernel=4_min5.png + """ + base_dir = Path(settings.graphs_base_dir) + dataset_dir = base_dir / dataset_name / graph_type + ensure_dir(dataset_dir) + + filename = f"{dataset_name}_{graph_type}_{embedding_src=}_{min_cluster_size=}_{nclusters=}.{ext}" + + return dataset_dir / filename + + +def save_plotly_figure( + fig: go.Figure, save_path: Path, fmt: str = "png", scale: int = 2 +): + """ + Save Plotly figure to disk. + Requires `kaleido` to be installed. + """ + ensure_dir(save_path.parent) + try: + pio.write_image(fig, str(save_path), format=fmt, scale=scale) + return str(save_path) + except Exception as e: + raise RuntimeError(f"Error saving figure to {save_path}: {e}") diff --git a/docgenie/analyzation/clustering/webapp/visualizations.py b/docgenie/analyzation/clustering/webapp/visualizations.py new file mode 100755 index 0000000000000000000000000000000000000000..17231b8642480166c8635313acdf454dc78d4ef7 --- /dev/null +++ b/docgenie/analyzation/clustering/webapp/visualizations.py @@ -0,0 +1,262 @@ +import numpy as np +import pandas as pd +import plotly.express as px +import plotly.graph_objects as go +from plotly.subplots import make_subplots +from pathlib import Path + + +def map_embedding_name_to_final_name(embedding_name: str): + match embedding_name: + case "layout": + return "layoutlm" + case "image": + return "clip" + case "text": + return "sentence" + case "paper_kernel=4": + return "pooled" + case "combined": + return "combined" + + +def create_scatter_plot( + df: pd.DataFrame, + embedding_src: str, + dataset_name: str, + min_cluster_size: int, + n_cluster: int, +) -> go.Figure: + """Create interactive scatter plot of document embeddings.""" + embedding_src = map_embedding_name_to_final_name(embedding_src) + + # # Force categorical colors if labels are numeric + # df = df.copy() # Avoid modifying original + # df["label"] = df["label"].astype(str) + + fig = px.scatter( + df, + x="x", + y="y", + color="label", + # labels={"label": ""}, + hover_data={"index": True, "label": True, "doc_id": True}, + # title=f"{dataset_name}: '{embedding_src}' Embeddings, κ={min_cluster_size}, {n_cluster} Clusters", + ) + margin = 0 + + fig.update_traces(marker=dict(size=7, showscale=False), customdata=df["index"]) + fig.update_layout( + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=margin, r=margin, t=margin, b=margin), + # legend_title="Cluster", + showlegend=False, + coloraxis_showscale=False, + ) + + fig.update_xaxes(visible=False) + fig.update_yaxes(visible=False) + + return fig + + +"""This function is used to display the analysis plots as subplot i.e. 
one figure containing all plots""" + + +def create_cluster_analysis_plot( + cluster_df: pd.DataFrame, + dataset_name: str, + cluster_labels: np.ndarray, +) -> go.Figure: + """Create comprehensive cluster analysis visualization with clickable clusters.""" + fig = make_subplots( + rows=4, + cols=1, + subplot_titles=( + "Cluster Sizes", + "Cluster Variances", + "Size vs Variance", + "Distribution", + ), + specs=[ + [{"type": "bar"}], + [{"type": "bar"}], + [{"type": "scatter"}], + [{"type": "histogram"}], + ], + ) + + # Prepare cluster indices for click events + cluster_indices = {} + for cluster_id in cluster_df["cluster_id"]: + indices = np.where(cluster_labels == cluster_id)[0].tolist() + cluster_indices[cluster_id] = indices + + # Plot 1: Cluster sizes (clickable) + fig.add_trace( + go.Bar( + x=cluster_df["cluster_id"], + y=cluster_df["size"], + name="Size", + customdata=[cluster_indices[cid] for cid in cluster_df["cluster_id"]], + hovertemplate="Cluster %{x}
<br>Size: %{y}<br>
Click to view images", + ), + row=1, + col=1, + ) + + # Plot 2: Cluster variances + fig.add_trace( + go.Bar( + x=cluster_df["cluster_id"], + y=cluster_df["variance"], + customdata=[cluster_indices[cid] for cid in cluster_df["cluster_id"]], + name="Variance", + ), + row=2, + col=1, + ) + + # Plot 3: Size vs Variance scatter (clickable) + fig.add_trace( + go.Scatter( + x=cluster_df["size"], + y=cluster_df["variance"], + mode="markers", + text=cluster_df["cluster_id"], + name="Clusters", + customdata=[cluster_indices[cid] for cid in cluster_df["cluster_id"]], + hovertemplate="Cluster %{text}
<br>Size: %{x}<br>Variance: %{y}<br>
Click to view images", + ), + row=3, + col=1, + ) + + # Plot 4: Size distribution + fig.add_trace( + go.Histogram(x=cluster_df["size"], name="Size Distribution"), + row=4, + col=1, + ) + + fig.update_layout( + title_text=f"Cluster Analysis for {dataset_name}", + showlegend=False, + height=1200, + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=40, r=40, t=40, b=40), + ) + + _update_subplot_axes(fig) + return fig + + +def _update_subplot_axes(fig: go.Figure) -> None: + """Update axes labels for all subplots.""" + fig.update_xaxes(title_text="Cluster ID", row=1, col=1) + fig.update_yaxes(title_text="Size", row=1, col=1) + fig.update_xaxes(title_text="Cluster ID", row=2, col=1) + fig.update_yaxes(title_text="Variance", row=2, col=1) + fig.update_xaxes(title_text="Size", row=3, col=1) + fig.update_yaxes(title_text="Variance", row=3, col=1) + fig.update_xaxes(title_text="Size", row=4, col=1) + fig.update_yaxes(title_text="Count", row=4, col=1) + + +"""This function is used to save cluster analysis plots separately not as a single plot""" + + +def generate_individual_cluster_plots(cluster_df, dataset_name: str) -> dict: + """ + Given cluster_df (DataFrame with columns 'cluster_id', 'size', 'variance'), + return a dict of plot_name -> go.Figure, one per subplot: + - cluster_sizes + - cluster_variances + - size_vs_variance + - distribution + + Note: dataset_name is unused for plotting but kept for potential titles. + """ + plots = {} + + # Ensure expected columns exist + if not {"cluster_id", "size", "variance"}.issubset(cluster_df.columns): + raise ValueError( + "cluster_df must contain 'cluster_id', 'size', and 'variance' columns" + ) + + # Cluster Sizes (bar) + fig_sizes = go.Figure() + fig_sizes.add_trace( + go.Bar( + x=cluster_df["cluster_id"], + y=cluster_df["size"], + name="Size", + ) + ) + fig_sizes.update_layout( + title_text=f"Cluster Sizes{' — ' + dataset_name if dataset_name else ''}", + xaxis_title="Cluster ID", + yaxis_title="Size", + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=20, r=20, t=40, b=20), + ) + plots["cluster_sizes"] = fig_sizes + + # Cluster Variances (bar) + fig_var = go.Figure() + fig_var.add_trace( + go.Bar( + x=cluster_df["cluster_id"], + y=cluster_df["variance"], + name="Variance", + ) + ) + fig_var.update_layout( + title_text=f"Cluster Variances{' — ' + dataset_name if dataset_name else ''}", + xaxis_title="Cluster ID", + yaxis_title="Variance", + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=20, r=20, t=40, b=20), + ) + plots["cluster_variances"] = fig_var + + # Size vs Variance (scatter) + fig_sv = go.Figure() + fig_sv.add_trace( + go.Scatter( + x=cluster_df["size"], + y=cluster_df["variance"], + mode="markers", + text=cluster_df["cluster_id"], + name="Size vs Variance", + ) + ) + fig_sv.update_layout( + title_text=f"Size vs Variance{' — ' + dataset_name if dataset_name else ''}", + xaxis_title="Size", + yaxis_title="Variance", + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=20, r=20, t=40, b=20), + ) + plots["size_vs_variance"] = fig_sv + + # Distribution (histogram) + fig_dist = go.Figure() + fig_dist.add_trace(go.Histogram(x=cluster_df["size"], name="Size Distribution")) + fig_dist.update_layout( + title_text=f"Distribution{' — ' + dataset_name if dataset_name else ''}", + xaxis_title="Size", + yaxis_title="Count", + plot_bgcolor="white", + paper_bgcolor="white", + margin=dict(l=20, r=20, t=40, b=20), + ) + plots["distribution"] = fig_dist + + return plots diff --git 
a/docgenie/analyzation/gt/cls/cls_qa_analysis.py b/docgenie/analyzation/gt/cls/cls_qa_analysis.py new file mode 100755 index 0000000000000000000000000000000000000000..f2248c523a40abc42fe24d7fc2e35945907021a0 --- /dev/null +++ b/docgenie/analyzation/gt/cls/cls_qa_analysis.py @@ -0,0 +1,381 @@ +import argparse +import json +import numpy as np +from collections import Counter +from scipy.stats import entropy +import matplotlib.pyplot as plt +import seaborn as sns + +from docgenie import ENV +from docgenie.analyzation.gt.webapp import get_base_dataset_name +from docgenie.data.interfaces.dataset import load_dataset +from docgenie.generation.models._syndatadef import SynDatasetDefinition + +# Set seaborn style for CVPR-quality plots +sns.set_theme(style="whitegrid", context="paper", palette="colorblind") +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.size"] = 10 + + +def extract_labels(dataset, label_mapping: dict[str, str] = None): + """Extract classification labels from dataset.""" + labels = [] + for sample in dataset.train: + if sample.annotations: + label = sample.annotations[0].label.name + + if label_mapping is not None and len(label_mapping) > 0: + label = label_mapping[label] + + labels.append(label) + return labels + + +def compute_distribution(labels): + """Compute label distribution.""" + counter = Counter(labels) + total = sum(counter.values()) + distribution = {k: v / total for k, v in sorted(counter.items())} + return counter, distribution + + +def compare_distributions(real_dist, synth_dist): + """Compare two distributions using various metrics.""" + # Align keys + all_labels = sorted(set(real_dist.keys()) | set(synth_dist.keys())) + + real_probs = np.array([real_dist.get(k, 0) for k in all_labels]) + synth_probs = np.array([synth_dist.get(k, 0) for k in all_labels]) + + # KL divergence (add small epsilon to avoid log(0)) + epsilon = 1e-10 + kl_div = entropy(real_probs + epsilon, synth_probs + epsilon) + + # Total Variation Distance + tvd = 0.5 * np.sum(np.abs(real_probs - synth_probs)) + + # L2 distance + l2_dist = np.linalg.norm(real_probs - synth_probs) + + return { + "kl_divergence": float(kl_div), + "total_variation_distance": float(tvd), + "l2_distance": float(l2_dist), + } + + +def plot_absolute_counts(real_counter, synth_counter, synth_dataset_name, save_path): + """Plot absolute class counts comparison.""" + all_labels = sorted(set(real_counter.keys()) | set(synth_counter.keys())) + real_counts = [real_counter.get(k, 0) for k in all_labels] + synth_counts = [synth_counter.get(k, 0) for k in all_labels] + + x = np.arange(len(all_labels)) + width = 0.35 + + fig, ax = plt.subplots(figsize=(10, 6)) + + bars1 = ax.bar( + x - width / 2, + real_counts, + width, + label="Real", + color=sns.color_palette("colorblind")[0], + alpha=0.85, + edgecolor="black", + linewidth=0.5, + ) + bars2 = ax.bar( + x + width / 2, + synth_counts, + width, + label="Synthetic", + color=sns.color_palette("colorblind")[1], + alpha=0.85, + edgecolor="black", + linewidth=0.5, + ) + + ax.set_xlabel("Class Label", fontsize=11, fontweight="bold") + ax.set_ylabel("Count", fontsize=11, fontweight="bold") + ax.set_title( + f"{synth_dataset_name} Absolute Class Counts Comparison", + fontsize=12, + fontweight="bold", + pad=15, + ) + ax.set_xticks(x) + + # Rotate labels if there are many or if they're long + max_label_len = max(len(str(label)) for label in all_labels) + rotation = 45 if len(all_labels) > 8 or max_label_len > 10 else 0 + ha = "right" if rotation > 0 else "center" + + 
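    # Apply the rotation chosen above: when there are many classes or long class names,
    # tick labels are slanted 45 degrees and right-aligned so they do not overlap.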
ax.set_xticklabels(all_labels, rotation=rotation, ha=ha) + ax.legend(frameon=True, loc="upper right", fontsize=10) + ax.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white") + print(f"Absolute counts plot saved to {save_path}") + plt.close() + + +def plot_distribution(real_dist, synth_dist, synth_dataset_name, save_path): + """Plot normalized class distribution comparison.""" + all_labels = sorted(set(real_dist.keys()) | set(synth_dist.keys())) + real_probs = [real_dist.get(k, 0) for k in all_labels] + synth_probs = [synth_dist.get(k, 0) for k in all_labels] + + x = np.arange(len(all_labels)) + width = 0.35 + + fig, ax = plt.subplots(figsize=(10, 6)) + + bars1 = ax.bar( + x - width / 2, + real_probs, + width, + label="Real", + color=sns.color_palette("colorblind")[0], + alpha=0.85, + edgecolor="black", + linewidth=0.5, + ) + bars2 = ax.bar( + x + width / 2, + synth_probs, + width, + label="Synthetic", + color=sns.color_palette("colorblind")[1], + alpha=0.85, + edgecolor="black", + linewidth=0.5, + ) + + ax.set_xticks(x) + + # Rotate labels if there are many or if they're long + max_label_len = max(len(str(label)) for label in all_labels) + rotation = 45 if len(all_labels) > 8 or max_label_len > 10 else 0 + ha = "right" if rotation > 0 else "center" + + ax.set_xticklabels(all_labels, rotation=rotation, ha=ha) + ax.legend(frameon=True, loc="upper right", fontsize=10) + ax.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + + # Format y-axis as percentage + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{y:.1%}")) + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white") + print(f"Distribution plot saved to {save_path}") + plt.close() + + +def plot_difference_heatmap(real_dist, synth_dist, synth_dataset_name, save_path): + """Plot heatmap showing per-class differences.""" + all_labels = sorted(set(real_dist.keys()) | set(synth_dist.keys())) + differences = [ + (synth_dist.get(k, 0) - real_dist.get(k, 0)) * 100 for k in all_labels + ] + + fig, ax = plt.subplots(figsize=(10, max(6, len(all_labels) * 0.4))) + + # Create diverging colormap centered at 0 + cmap = sns.diverging_palette(250, 10, as_cmap=True) + + # Create heatmap data + data = np.array(differences).reshape(-1, 1) + + sns.heatmap( + data, + annot=True, + fmt=".2f", + cmap=cmap, + center=0, + yticklabels=all_labels, + xticklabels=["Diff (%)"], + cbar_kws={"label": "Percentage Point Difference"}, + linewidths=0.5, + linecolor="gray", + ax=ax, + ) + + ax.set_title( + f"{synth_dataset_name} Per-Class Distribution Difference (Synthetic - Real)", + fontsize=12, + fontweight="bold", + pad=15, + ) + ax.set_ylabel("Class Label", fontsize=11, fontweight="bold") + + plt.tight_layout() + plt.savefig(save_path, dpi=300, bbox_inches="tight", facecolor="white") + print(f"Difference heatmap saved to {save_path}") + plt.close() + + +def save_metrics( + metrics, real_counter, synth_counter, real_dist, synth_dist, save_path +): + """Save all metrics and distributions to JSON.""" + all_labels = sorted(set(real_counter.keys()) | set(synth_counter.keys())) + + output = { + "metrics": metrics, + "class_statistics": { + label: { + "real_count": real_counter.get(label, 0), + "synth_count": synth_counter.get(label, 0), + "real_proportion": 
float(real_dist.get(label, 0)), + "synth_proportion": float(synth_dist.get(label, 0)), + "difference_percentage_points": float( + (synth_dist.get(label, 0) - real_dist.get(label, 0)) * 100 + ), + } + for label in all_labels + }, + "summary": { + "total_real_samples": sum(real_counter.values()), + "total_synth_samples": sum(synth_counter.values()), + "num_classes": len(all_labels), + }, + } + + with open(save_path, "w") as f: + json.dump(output, f, indent=2) + + print(f"Metrics saved to {save_path}") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="CLS GT Comparison", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "synthdataset", + type=str, + help="Name of the synthetic dataset", + ) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + # Configuration + synth_dataset_name = parse_args().synthdataset + + # Load datasets + base_dataset_name = get_base_dataset_name(synth_dataset_name) + print( + f"Loading datasets: {base_dataset_name} (real) vs {synth_dataset_name} (synthetic)" + ) + + base_dataset = load_dataset(base_dataset_name, is_synthetic=False) + synth_dataset = load_dataset(synth_dataset_name, is_synthetic=True) + + deffile = ENV.SYN_DATA_DEFINITIONS_DIR / f"{synth_dataset_name}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file(deffile) + label_mapping = dsdef.label_mapping + + # Extract labels + real_labels = extract_labels(base_dataset, label_mapping=None) + synth_labels = extract_labels(synth_dataset, label_mapping=label_mapping) + + print(f"\nDataset sizes:") + print(f" Real: {len(real_labels)} samples") + print(f" Synthetic: {len(synth_labels)} samples") + + # Compute distributions + real_counter, real_dist = compute_distribution(real_labels) + synth_counter, synth_dist = compute_distribution(synth_labels) + + # Print distributions + print("\n" + "=" * 80) + print("CLASS DISTRIBUTION COMPARISON") + print("=" * 80) + print( + f"{'Class':<20} {'Real Count':<15} {'Real %':<12} {'Synth Count':<15} {'Synth %':<12} {'Diff %':<10}" + ) + print("-" * 80) + + all_labels = sorted(set(real_counter.keys()) | set(synth_counter.keys())) + for label in all_labels: + real_count = real_counter.get(label, 0) + synth_count = synth_counter.get(label, 0) + real_pct = real_dist.get(label, 0) * 100 + synth_pct = synth_dist.get(label, 0) * 100 + diff_pct = synth_pct - real_pct + + print( + f"{label:<20} {real_count:<15} {real_pct:<12.2f} {synth_count:<15} {synth_pct:<12.2f} {diff_pct:+.2f}" + ) + + # Compare distributions + metrics = compare_distributions(real_dist, synth_dist) + + print("\n" + "=" * 80) + print("DISTRIBUTION SIMILARITY METRICS") + print("=" * 80) + print(f"KL Divergence (Real || Synth): {metrics['kl_divergence']:.4f}") + print( + f"Total Variation Distance: {metrics['total_variation_distance']:.4f}" + ) + print(f"L2 Distance: {metrics['l2_distance']:.4f}") + print("=" * 80) + print("\nInterpretation:") + print(" - Lower values indicate more similar distributions") + print(" - KL divergence: 0 = identical, higher = more different") + print(" - TVD: [0, 1], 0 = identical, 1 = completely different") + + # Create output directory + output_dir = ENV.CLS_GT_ANALYZATION_DIR / synth_dataset_name + output_dir.mkdir(parents=True, exist_ok=True) + + # Generate and save plots + print("\n" + "=" * 80) + print("GENERATING PLOTS") + print("=" * 80) + + # plot_absolute_counts( + # real_counter, + # synth_counter, + # synth_dataset_name=synth_dataset_name, + # save_path=output_dir / 
"absolute_counts.pdf", + # ) + + plot_distribution( + real_dist, + synth_dist, + synth_dataset_name=synth_dataset_name, + save_path=output_dir / f"{synth_dataset_name}_distribution.pdf", + ) + + # plot_difference_heatmap( + # real_dist, + # synth_dist, + # synth_dataset_name=synth_dataset_name, + # save_path=output_dir / "difference_heatmap.pdf", + # ) + + # Save metrics to JSON + save_metrics( + metrics, + real_counter, + synth_counter, + real_dist, + synth_dist, + save_path=output_dir / "metrics.json", + ) + + print("\n" + "=" * 80) + print(f"All outputs saved to: {output_dir}") + print("=" * 80) diff --git a/docgenie/analyzation/gt/dla/dla_gt_analysis.py b/docgenie/analyzation/gt/dla/dla_gt_analysis.py new file mode 100755 index 0000000000000000000000000000000000000000..834110e8b5b12bd5b349244e32c309b0f954b156 --- /dev/null +++ b/docgenie/analyzation/gt/dla/dla_gt_analysis.py @@ -0,0 +1,885 @@ +#!/usr/bin/env python3 +""" +Compare Document Layout Analysis Ground Truth between Synthetic and Real Datasets +For CVPR paper on synthesis of document understanding datasets +""" + +import argparse +import os +from pathlib import Path +from collections import defaultdict +import atria_core +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from typing import Dict, List, Tuple +from dataclasses import dataclass +from scipy.stats import entropy, wasserstein_distance +from scipy.spatial.distance import jensenshannon + +from docgenie import ENV +from docgenie.analyzation.gt.webapp import get_base_dataset_name +from docgenie.data.interfaces.dataset import load_dataset +from docgenie.generation.models._syndatadef import SynDatasetDefinition + +# Set seaborn style for CVPR-quality figures +sns.set_theme(style="whitegrid", context="paper", palette="colorblind") +plt.rcParams["figure.dpi"] = 300 +plt.rcParams["savefig.dpi"] = 300 +plt.rcParams["font.size"] = 9 +plt.rcParams["axes.labelsize"] = 10 +plt.rcParams["axes.titlesize"] = 11 +plt.rcParams["xtick.labelsize"] = 8 +plt.rcParams["ytick.labelsize"] = 8 +plt.rcParams["legend.fontsize"] = 9 +plt.rcParams["figure.titlesize"] = 12 + + +@dataclass +class DatasetMetrics: + """Container for computed dataset metrics""" + + bbox_sizes: Dict[str, List[Tuple[float, float]]] # label -> [(width, height), ...] + bbox_areas: Dict[str, List[float]] # label -> [area, ...] + bbox_aspect_ratios: Dict[str, List[float]] # label -> [aspect_ratio, ...] + centroids: Dict[str, List[Tuple[float, float]]] # label -> [(x, y), ...] + region_counts: Dict[str, List[int]] # label -> [count_per_doc, ...] + page_coverages: List[float] # coverage per document + pairwise_distances: Dict[ + Tuple[str, str], List[float] + ] # (label1, label2) -> [distances, ...] 
+ adjacency_counts: Dict[Tuple[str, str], int] # (label1, label2) -> count + + +def compute_bbox_metrics( + bbox_abs: List[float], + img_width: int, + img_height: int, + bbox_norm: List[float] = None, +) -> Dict: + """Compute metrics for a single bounding box + + Args: + bbox_abs: Bounding box in absolute coordinates [x1, y1, x2, y2] + img_width: Image width + img_height: Image height + bbox_norm: Bounding box in normalized coordinates [x1, y1, x2, y2] (0-1 range) + """ + x1, y1, x2, y2 = bbox_abs + width = x2 - x1 + height = y2 - y1 + area = width * height + aspect_ratio = width / height if height > 0 else 0 + + # Use normalized coords for centroid if provided, otherwise compute from absolute + if bbox_norm is not None: + x1_n, y1_n, x2_n, y2_n = bbox_norm + centroid_x = (x1_n + x2_n) / 2 + centroid_y = (y1_n + y2_n) / 2 + else: + centroid_x = (x1 + x2) / 2 / img_width + centroid_y = (y1 + y2) / 2 / img_height + + return { + "size": (width, height), + "area": area, + "aspect_ratio": aspect_ratio, + "centroid": (centroid_x, centroid_y), + "norm_area": area / (img_width * img_height), + } + + +def compute_pairwise_distances( + bboxes: List[Dict], labels: List[str] +) -> Dict[Tuple[str, str], List[float]]: + """Compute pairwise distances between region centroids""" + distances = defaultdict(list) + + for i in range(len(bboxes)): + for j in range(i + 1, len(bboxes)): + label_pair = tuple(sorted([labels[i], labels[j]])) + c1 = bboxes[i]["centroid"] + c2 = bboxes[j]["centroid"] + dist = np.sqrt((c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2) + distances[label_pair].append(dist) + + return distances + + +def compute_adjacency( + bboxes: List[Dict], labels: List[str], threshold: float = 0.1 +) -> Dict[Tuple[str, str], int]: + """Compute adjacency matrix (regions within threshold distance)""" + adjacency = defaultdict(int) + + for i in range(len(bboxes)): + for j in range(i + 1, len(bboxes)): + c1 = bboxes[i]["centroid"] + c2 = bboxes[j]["centroid"] + dist = np.sqrt((c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2) + + if dist < threshold: + label_pair = tuple(sorted([labels[i], labels[j]])) + adjacency[label_pair] += 1 + + return adjacency + + +def extract_metrics(dataset, label_mapping: dict[str, str] = None) -> DatasetMetrics: + """Extract all metrics from a dataset""" + bbox_sizes = defaultdict(list) + bbox_areas = defaultdict(list) + bbox_aspect_ratios = defaultdict(list) + centroids = defaultdict(list) + region_counts = defaultdict(list) + page_coverages = [] + all_pairwise_distances = defaultdict(list) + all_adjacency_counts = defaultdict(int) + + # Debug: check first sample + first_sample = True + + for sample in dataset.train: + img_width = sample.image.width + img_height = sample.image.height + + # Get annotations + annotation = sample.annotations[0] # LayoutAnalysisAnnotation + if type(annotation).__name__ == "ClassificationAnnotation": + annotation = sample.annotations[1] # LayoutAnalysisAnnotation + + labels = annotation.annotated_objects.label.name + bboxes_list = annotation.annotated_objects.bbox.value + + # Check if bboxes are normalized + is_normalized = annotation.annotated_objects.bbox.normalized + + if first_sample: + print(f" First sample debug info:") + print(f" Image size: {img_width}x{img_height}") + print(f" Bboxes normalized flag: {is_normalized}") + if len(bboxes_list) > 0: + print(f" First bbox value: {bboxes_list[0]}") + print(f" Number of bboxes: {len(bboxes_list)}") + first_sample = False + + if label_mapping is not None and len(label_mapping) > 0: + labels = 
[label_mapping[l] for l in labels] + + # Count regions per label + label_count = defaultdict(int) + doc_bboxes = [] + doc_labels = [] + total_area = 0 + + for label, bbox_value in zip(labels, bboxes_list): + # Handle normalized vs absolute coordinates + if is_normalized: + # Bboxes are already normalized (0-1), convert to absolute for metrics + x1, y1, x2, y2 = bbox_value + x1_abs = x1 * img_width + y1_abs = y1 * img_height + x2_abs = x2 * img_width + y2_abs = y2 * img_height + bbox_abs = [x1_abs, y1_abs, x2_abs, y2_abs] + bbox_norm = bbox_value # Already normalized for centroid + else: + # Bboxes are in absolute coordinates + bbox_abs = bbox_value + x1, y1, x2, y2 = bbox_value + # Normalize for centroid calculation + bbox_norm = [ + x1 / img_width, + y1 / img_height, + x2 / img_width, + y2 / img_height, + ] + + metrics = compute_bbox_metrics(bbox_abs, img_width, img_height, bbox_norm) + + # Store per-label metrics + bbox_sizes[label].append(metrics["size"]) + bbox_areas[label].append(metrics["area"]) + bbox_aspect_ratios[label].append(metrics["aspect_ratio"]) + centroids[label].append(metrics["centroid"]) + + label_count[label] += 1 + doc_bboxes.append(metrics) + doc_labels.append(label) + total_area += metrics["norm_area"] + + # Region counts per document + for label, count in label_count.items(): + region_counts[label].append(count) + + # Page coverage + page_coverages.append(min(total_area, 1.0)) # Cap at 1.0 + + # Pairwise distances + doc_distances = compute_pairwise_distances(doc_bboxes, doc_labels) + for pair, dists in doc_distances.items(): + all_pairwise_distances[pair].extend(dists) + + # Adjacency + doc_adjacency = compute_adjacency(doc_bboxes, doc_labels) + for pair, count in doc_adjacency.items(): + all_adjacency_counts[pair] += count + + return DatasetMetrics( + bbox_sizes=dict(bbox_sizes), + bbox_areas=dict(bbox_areas), + bbox_aspect_ratios=dict(bbox_aspect_ratios), + centroids=dict(centroids), + region_counts=dict(region_counts), + page_coverages=page_coverages, + pairwise_distances=dict(all_pairwise_distances), + adjacency_counts=dict(all_adjacency_counts), + ) + + +def compute_distribution_metrics( + real_data: np.ndarray, synth_data: np.ndarray, bins: int = 50 +) -> Dict[str, float]: + """Compute various distribution comparison metrics""" + # Create histograms with same bins for both + min_val = min(real_data.min(), synth_data.min()) + max_val = max(real_data.max(), synth_data.max()) + bin_edges = np.linspace(min_val, max_val, bins + 1) + + real_hist, _ = np.histogram(real_data, bins=bin_edges, density=True) + synth_hist, _ = np.histogram(synth_data, bins=bin_edges, density=True) + + # Normalize histograms to sum to 1 for probability distributions + real_hist = real_hist / (real_hist.sum() + 1e-10) + synth_hist = synth_hist / (synth_hist.sum() + 1e-10) + + # Add small epsilon to avoid log(0) + real_hist = real_hist + 1e-10 + synth_hist = synth_hist + 1e-10 + + # KL Divergence (synth || real) + kl_div = entropy(synth_hist, real_hist) + + # Jensen-Shannon Divergence (symmetric) + js_div = jensenshannon(real_hist, synth_hist) ** 2 + + # Wasserstein Distance (Earth Mover's Distance) + w_dist = wasserstein_distance(real_data, synth_data) + + # Mean and std comparison + mean_diff = abs(np.mean(synth_data) - np.mean(real_data)) + std_diff = abs(np.std(synth_data) - np.std(real_data)) + + return { + "kl_divergence": kl_div, + "js_divergence": js_div, + "wasserstein_distance": w_dist, + "mean_difference": mean_diff, + "std_difference": std_diff, + "real_mean": np.mean(real_data), 
+ "synth_mean": np.mean(synth_data), + "real_std": np.std(real_data), + "synth_std": np.std(synth_data), + } + + +# def plot_size_distribution_comparison( +# real_metrics: DatasetMetrics, +# synth_metrics: DatasetMetrics, +# output_dir: Path, +# dataset_name: str, +# ): +# """Plot comprehensive size distribution comparison for all classes""" +# all_labels = sorted( +# set(real_metrics.bbox_areas.keys()) | set(synth_metrics.bbox_areas.keys()) +# ) + +# # Compute metrics for each label +# area_metrics = {} +# width_metrics = {} +# height_metrics = {} +# aspect_ratio_metrics = {} + +# for label in all_labels: +# if label in real_metrics.bbox_areas and label in synth_metrics.bbox_areas: +# # Area metrics +# real_areas = np.array(real_metrics.bbox_areas[label]) +# synth_areas = np.array(synth_metrics.bbox_areas[label]) +# area_metrics[label] = compute_distribution_metrics(real_areas, synth_areas) + +# # Width and height metrics +# real_sizes = np.array(real_metrics.bbox_sizes[label]) +# synth_sizes = np.array(synth_metrics.bbox_sizes[label]) +# width_metrics[label] = compute_distribution_metrics( +# real_sizes[:, 0], synth_sizes[:, 0] +# ) +# height_metrics[label] = compute_distribution_metrics( +# real_sizes[:, 1], synth_sizes[:, 1] +# ) + +# # Aspect ratio metrics +# real_ar = np.array(real_metrics.bbox_aspect_ratios[label]) +# synth_ar = np.array(synth_metrics.bbox_aspect_ratios[label]) +# aspect_ratio_metrics[label] = compute_distribution_metrics( +# real_ar, synth_ar +# ) + +# # Create comprehensive comparison plot +# fig, axes = plt.subplots(2, 2, figsize=(12, 10)) +# fig.suptitle( +# f"{dataset_name}: Size Distribution Comparison", fontsize=14, fontweight="bold" +# ) + +# labels_list = list(area_metrics.keys()) +# x_pos = np.arange(len(labels_list)) + +# # Plot 1: KL Divergence comparison +# kl_areas = [area_metrics[l]["kl_divergence"] for l in labels_list] +# kl_widths = [width_metrics[l]["kl_divergence"] for l in labels_list] +# kl_heights = [height_metrics[l]["kl_divergence"] for l in labels_list] +# kl_ar = [aspect_ratio_metrics[l]["kl_divergence"] for l in labels_list] + +# width = 0.2 +# axes[0, 0].bar(x_pos - 1.5 * width, kl_areas, width, label="Area", alpha=0.8) +# axes[0, 0].bar(x_pos - 0.5 * width, kl_widths, width, label="Width", alpha=0.8) +# axes[0, 0].bar(x_pos + 0.5 * width, kl_heights, width, label="Height", alpha=0.8) +# axes[0, 0].bar(x_pos + 1.5 * width, kl_ar, width, label="Aspect Ratio", alpha=0.8) +# axes[0, 0].set_ylabel("KL Divergence") +# axes[0, 0].set_title("KL Divergence (Synth || Real)", fontweight="bold") +# axes[0, 0].set_xticks(x_pos) +# axes[0, 0].set_xticklabels([l[:20] for l in labels_list], rotation=45, ha="right") +# axes[0, 0].legend() +# axes[0, 0].grid(axis="y", alpha=0.3) + +# # Plot 2: Jensen-Shannon Divergence +# js_areas = [area_metrics[l]["js_divergence"] for l in labels_list] +# js_widths = [width_metrics[l]["js_divergence"] for l in labels_list] +# js_heights = [height_metrics[l]["js_divergence"] for l in labels_list] +# js_ar = [aspect_ratio_metrics[l]["js_divergence"] for l in labels_list] + +# axes[0, 1].bar(x_pos - 1.5 * width, js_areas, width, label="Area", alpha=0.8) +# axes[0, 1].bar(x_pos - 0.5 * width, js_widths, width, label="Width", alpha=0.8) +# axes[0, 1].bar(x_pos + 0.5 * width, js_heights, width, label="Height", alpha=0.8) +# axes[0, 1].bar(x_pos + 1.5 * width, js_ar, width, label="Aspect Ratio", alpha=0.8) +# axes[0, 1].set_ylabel("JS Divergence") +# axes[0, 1].set_title("Jensen-Shannon Divergence", fontweight="bold") +# 
axes[0, 1].set_xticks(x_pos) +# axes[0, 1].set_xticklabels([l[:20] for l in labels_list], rotation=45, ha="right") +# axes[0, 1].legend() +# axes[0, 1].grid(axis="y", alpha=0.3) + +# # Plot 3: Wasserstein Distance +# w_areas = [area_metrics[l]["wasserstein_distance"] for l in labels_list] +# w_widths = [width_metrics[l]["wasserstein_distance"] for l in labels_list] +# w_heights = [height_metrics[l]["wasserstein_distance"] for l in labels_list] + +# axes[1, 0].bar(x_pos - width, w_areas, width, label="Area", alpha=0.8) +# axes[1, 0].bar(x_pos, w_widths, width, label="Width", alpha=0.8) +# axes[1, 0].bar(x_pos + width, w_heights, width, label="Height", alpha=0.8) +# axes[1, 0].set_ylabel("Wasserstein Distance") +# axes[1, 0].set_title("Wasserstein Distance (Earth Mover)", fontweight="bold") +# axes[1, 0].set_xticks(x_pos) +# axes[1, 0].set_xticklabels([l[:20] for l in labels_list], rotation=45, ha="right") +# axes[1, 0].legend() +# axes[1, 0].grid(axis="y", alpha=0.3) + +# # Plot 4: Mean and Std differences for area +# mean_diffs = [area_metrics[l]["mean_difference"] for l in labels_list] +# std_diffs = [area_metrics[l]["std_difference"] for l in labels_list] + +# ax4_twin = axes[1, 1].twinx() +# p1 = axes[1, 1].bar( +# x_pos - width / 2, mean_diffs, width, label="Mean Diff", alpha=0.8, color="C0" +# ) +# p2 = ax4_twin.bar( +# x_pos + width / 2, std_diffs, width, label="Std Diff", alpha=0.8, color="C1" +# ) +# axes[1, 1].set_ylabel("Mean Difference (pixels²)", color="C0") +# ax4_twin.set_ylabel("Std Difference (pixels²)", color="C1") +# axes[1, 1].set_title("Area: Mean & Std Differences", fontweight="bold") +# axes[1, 1].set_xticks(x_pos) +# axes[1, 1].set_xticklabels([l[:20] for l in labels_list], rotation=45, ha="right") +# axes[1, 1].tick_params(axis="y", labelcolor="C0") +# ax4_twin.tick_params(axis="y", labelcolor="C1") +# axes[1, 1].grid(axis="y", alpha=0.3) + +# # Add combined legend for plot 4 +# lines = [p1, p2] +# labels = ["Mean Diff", "Std Diff"] +# axes[1, 1].legend(lines, labels, loc="upper left") + +# plt.tight_layout() +# plt.savefig( +# output_dir / "size_distribution_comparison.png", dpi=300, bbox_inches="tight" +# ) +# plt.close() + +# # Create detailed overlay histograms for each class +# n_labels = len(labels_list) +# print(labels_list) +# n_cols = 3 +# n_rows = (n_labels + n_cols - 1) // n_cols + +# fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows)) +# fig.suptitle( +# f"{dataset_name}: Area Distribution Overlays", fontsize=14, fontweight="bold" +# ) +# axes = axes.flatten() if n_labels > 1 else [axes] + +# for idx, label in enumerate(labels_list): +# if label in real_metrics.bbox_areas and label in synth_metrics.bbox_areas: +# real_areas = np.array(real_metrics.bbox_areas[label]) +# synth_areas = np.array(synth_metrics.bbox_areas[label]) + +# axes[idx].hist( +# real_areas, +# bins=40, +# alpha=0.5, +# label="Real", +# density=True, +# color="C0", +# edgecolor="black", +# linewidth=0.5, +# ) +# axes[idx].hist( +# synth_areas, +# bins=40, +# alpha=0.5, +# label="Synth", +# density=True, +# color="C1", +# edgecolor="black", +# linewidth=0.5, +# ) + +# axes[idx].set_xlabel("Area (pixels²)") +# axes[idx].set_ylabel("Density") +# axes[idx].set_title(f"{label}", fontweight="bold") +# axes[idx].legend() +# axes[idx].grid(axis="y", alpha=0.3) + +# # Add metrics text +# metrics_text = ( +# f"KL: {area_metrics[label]['kl_divergence']:.3f}\n" +# f"JS: {area_metrics[label]['js_divergence']:.3f}" +# ) +# axes[idx].text( +# 0.98, +# 0.98, +# metrics_text, +# 
transform=axes[idx].transAxes, +# verticalalignment="top", +# horizontalalignment="right", +# bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), +# fontsize=8, +# ) + +# # Hide unused subplots +# for idx in range(n_labels, len(axes)): +# axes[idx].axis("off") + +# plt.tight_layout() +# plt.savefig( +# output_dir / "area_distributions_overlay.png", dpi=300, bbox_inches="tight" +# ) +# plt.close() + + +def plot_spatial_heatmaps( + real_metrics: DatasetMetrics, + synth_metrics: DatasetMetrics, + output_dir: Path, + dataset_name: str, +): + """Plot 2D heatmaps showing complete region coverage (location and size)""" + all_labels = sorted( + set(real_metrics.centroids.keys()) | set(synth_metrics.centroids.keys()) + ) + + n_classes = len(all_labels) + n_cols = min(3, n_classes) + n_rows = (n_classes + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols * 2, figsize=(6 * n_cols, 4 * n_rows)) + + # Handle different subplot array shapes based on ACTUAL subplot dimensions + # We always have n_cols * 2 columns (real + synth for each class) + if n_rows == 1 and n_cols == 1: + # Single class: 1 row, 2 columns → axes is 1D array of length 2 + axes = axes.reshape(1, -1) + elif n_rows == 1: + # Multiple classes, single row → axes is 1D array + axes = axes.reshape(1, -1) + elif n_cols == 1: + # Multiple rows, single column of classes → axes is 2D + # Already correct shape: (n_rows, 2) + pass + # else: axes is already 2D with correct shape + + for idx, label in enumerate(all_labels): + row = idx // n_cols + col_base = (idx % n_cols) * 2 + + # Real data coverage map + if label in real_metrics.bbox_sizes and len(real_metrics.bbox_sizes[label]) > 0: + bbox_sizes = real_metrics.bbox_sizes[label] + bbox_areas = real_metrics.bbox_areas[label] + centroids = real_metrics.centroids[label] + + # Create high-resolution coverage grid + grid_size = 200 + coverage = np.zeros((grid_size, grid_size)) + + # For each region, fill in the coverage area based on size + for (width, height), area, (cx, cy) in zip( + bbox_sizes, bbox_areas, centroids + ): + typical_img_size = 2000.0 + width_norm = min(width / typical_img_size, 0.5) + height_norm = min(height / typical_img_size, 0.5) + + x_start = np.clip(cx - width_norm / 2, 0, 1) + x_end = np.clip(cx + width_norm / 2, 0, 1) + y_start = np.clip(cy - height_norm / 2, 0, 1) + y_end = np.clip(cy + height_norm / 2, 0, 1) + + x_start_idx = int(x_start * grid_size) + x_end_idx = min(int(x_end * grid_size) + 1, grid_size) + y_start_idx = int(y_start * grid_size) + y_end_idx = min(int(y_end * grid_size) + 1, grid_size) + + coverage[y_start_idx:y_end_idx, x_start_idx:x_end_idx] += 1 + + print(f"\n {label} - Real coverage:") + print(f" Total regions: {len(bbox_sizes)}") + print(f" Coverage area: {np.sum(coverage > 0) / (grid_size**2):.2%}") + print(f" Max overlap: {coverage.max():.0f} regions") + + im1 = axes[row, col_base].imshow( + coverage, + origin="upper", + cmap="YlOrRd", + extent=[0, 1, 0, 1], + aspect="auto", + interpolation="bilinear", + ) + axes[row, col_base].set_title( + f"{label}\n(Real, n={len(bbox_sizes)})", fontweight="bold", fontsize=9 + ) + axes[row, col_base].set_xlabel("Normalized X", fontsize=8) + axes[row, col_base].set_ylabel("Normalized Y", fontsize=8) + axes[row, col_base].grid(True, alpha=0.3, linewidth=0.5) + plt.colorbar(im1, ax=axes[row, col_base], label="Overlap") + else: + axes[row, col_base].text( + 0.5, 0.5, "No data", ha="center", va="center", fontsize=12 + ) + axes[row, col_base].set_title( + f"{label}\n(Real, n=0)", fontweight="bold", 
fontsize=9 + ) + axes[row, col_base].set_xlim(0, 1) + axes[row, col_base].set_ylim(0, 1) + axes[row, col_base].set_xlabel("Normalized X", fontsize=8) + axes[row, col_base].set_ylabel("Normalized Y", fontsize=8) + + # Synthetic data coverage map + if ( + label in synth_metrics.bbox_sizes + and len(synth_metrics.bbox_sizes[label]) > 0 + ): + bbox_sizes = synth_metrics.bbox_sizes[label] + bbox_areas = synth_metrics.bbox_areas[label] + centroids = synth_metrics.centroids[label] + + grid_size = 200 + coverage = np.zeros((grid_size, grid_size)) + + for (width, height), area, (cx, cy) in zip( + bbox_sizes, bbox_areas, centroids + ): + typical_img_size = 2000.0 + width_norm = min(width / typical_img_size, 0.5) + height_norm = min(height / typical_img_size, 0.5) + + x_start = np.clip(cx - width_norm / 2, 0, 1) + x_end = np.clip(cx + width_norm / 2, 0, 1) + y_start = np.clip(cy - height_norm / 2, 0, 1) + y_end = np.clip(cy + height_norm / 2, 0, 1) + + x_start_idx = int(x_start * grid_size) + x_end_idx = min(int(x_end * grid_size) + 1, grid_size) + y_start_idx = int(y_start * grid_size) + y_end_idx = min(int(y_end * grid_size) + 1, grid_size) + + coverage[y_start_idx:y_end_idx, x_start_idx:x_end_idx] += 1 + + print(f"\n {label} - Synth coverage:") + print(f" Total regions: {len(bbox_sizes)}") + print(f" Coverage area: {np.sum(coverage > 0) / (grid_size**2):.2%}") + print(f" Max overlap: {coverage.max():.0f} regions") + + im2 = axes[row, col_base + 1].imshow( + coverage, + origin="upper", + cmap="YlOrRd", + extent=[0, 1, 0, 1], + aspect="auto", + interpolation="bilinear", + ) + axes[row, col_base + 1].set_title( + f"{label}\n(Synth, n={len(bbox_sizes)})", fontweight="bold", fontsize=9 + ) + axes[row, col_base + 1].set_xlabel("Normalized X", fontsize=8) + axes[row, col_base + 1].set_ylabel("Normalized Y", fontsize=8) + axes[row, col_base + 1].grid(True, alpha=0.3, linewidth=0.5) + plt.colorbar(im2, ax=axes[row, col_base + 1], label="Overlap") + else: + axes[row, col_base + 1].text( + 0.5, 0.5, "No data", ha="center", va="center", fontsize=12 + ) + axes[row, col_base + 1].set_title( + f"{label}\n(Synth, n=0)", fontweight="bold", fontsize=9 + ) + axes[row, col_base + 1].set_xlim(0, 1) + axes[row, col_base + 1].set_ylim(0, 1) + axes[row, col_base + 1].set_xlabel("Normalized X", fontsize=8) + axes[row, col_base + 1].set_ylabel("Normalized Y", fontsize=8) + + # Hide empty subplots + for idx in range(n_classes, n_rows * n_cols): + row = idx // n_cols + col_base = (idx % n_cols) * 2 + axes[row, col_base].axis("off") + axes[row, col_base + 1].axis("off") + + plt.tight_layout() + plt.savefig( + output_dir / f"{dataset_name}_spatial_heatmaps_grid.pdf", + dpi=300, + bbox_inches="tight", + ) + plt.close() + + +def plot_region_counts( + real_metrics: DatasetMetrics, + synth_metrics: DatasetMetrics, + output_dir: Path, + dataset_name: str, +): + """Plot distribution of region counts per document in a grid layout""" + all_labels = sorted( + set(real_metrics.region_counts.keys()) | set(synth_metrics.region_counts.keys()) + ) + + n_classes = len(all_labels) + n_cols = min(4, n_classes) + n_rows = (n_classes + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows)) + + # Handle different subplot array shapes + if n_classes == 1: + axes = np.array([axes]) + else: + axes = axes.flatten() + + for idx, label in enumerate(all_labels): + ax = axes[idx] + + # Get data + real_counts = real_metrics.region_counts.get(label, []) + synth_counts = synth_metrics.region_counts.get(label, 
[]) + + if len(real_counts) > 0 or len(synth_counts) > 0: + # Determine data range + all_counts = real_counts + synth_counts + min_count = min(all_counts) if all_counts else 0 + max_count = max(all_counts) if all_counts else 1 + + # Create bins with equal width + # Use integer bins for count data (each count gets its own bin) + bin_width = 1 + bins = np.arange(min_count - 0.5, max_count + 1.5, bin_width) + + # If there are too many bins, increase bin width + if len(bins) > 30: + bin_width = max(1, int(np.ceil((max_count - min_count) / 30))) + bins = np.arange(min_count - 0.5, max_count + 1.5, bin_width) + + # Plot histograms with explicit bin edges for consistent bar width + if len(real_counts) > 0: + ax.hist( + real_counts, + bins=bins, + alpha=0.6, + label="Real", + density=True, + color="C0", + edgecolor="black", + linewidth=0.5, + ) + if len(synth_counts) > 0: + ax.hist( + synth_counts, + bins=bins, + alpha=0.6, + label="Synth", + density=True, + color="C1", + edgecolor="black", + linewidth=0.5, + ) + + ax.set_xlabel("Count per Document", fontsize=9) + ax.set_ylabel("Density", fontsize=9) + ax.set_title(f"{label}", fontweight="bold", fontsize=10) + ax.legend(fontsize=8) + ax.grid(axis="y", alpha=0.3) + ax.tick_params(labelsize=8) + + # Add statistics as text + if len(real_counts) > 0 and len(synth_counts) > 0: + stats_text = ( + f"Real: μ={np.mean(real_counts):.1f}\n" + f"Synth: μ={np.mean(synth_counts):.1f}" + ) + ax.text( + 0.98, + 0.98, + stats_text, + transform=ax.transAxes, + verticalalignment="top", + horizontalalignment="right", + bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5), + fontsize=7, + ) + else: + ax.text( + 0.5, + 0.5, + "No data", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=10, + ) + ax.set_title(f"{label}", fontweight="bold", fontsize=10) + + # Hide unused subplots + for idx in range(n_classes, len(axes)): + axes[idx].axis("off") + + plt.tight_layout() + plt.savefig( + output_dir / f"{dataset_name}_region_counts.pdf", dpi=300, bbox_inches="tight" + ) + plt.close() + + +def save_summary_stats( + real_metrics: DatasetMetrics, + synth_metrics: DatasetMetrics, + output_dir: Path, + dataset_name: str, +): + """Save summary statistics to text file""" + with open(output_dir / "summary_stats.txt", "w") as f: + f.write("=" * 80 + "\n") + f.write(f"DATASET COMPARISON SUMMARY: {dataset_name}\n") + f.write("=" * 80 + "\n\n") + + # Page coverage + f.write("PAGE COVERAGE:\n") + f.write( + f" Real - Mean: {np.mean(real_metrics.page_coverages):.3f}, " + f"Std: {np.std(real_metrics.page_coverages):.3f}, " + f"Median: {np.median(real_metrics.page_coverages):.3f}\n" + ) + f.write( + f" Synth - Mean: {np.mean(synth_metrics.page_coverages):.3f}, " + f"Std: {np.std(synth_metrics.page_coverages):.3f}, " + f"Median: {np.median(synth_metrics.page_coverages):.3f}\n\n" + ) + + # Region counts + f.write("AVERAGE REGION COUNTS PER DOCUMENT:\n") + all_labels = sorted( + set(real_metrics.region_counts.keys()) + | set(synth_metrics.region_counts.keys()) + ) + for label in all_labels: + real_mean = np.mean(real_metrics.region_counts.get(label, [0])) + synth_mean = np.mean(synth_metrics.region_counts.get(label, [0])) + real_std = np.std(real_metrics.region_counts.get(label, [0])) + synth_std = np.std(synth_metrics.region_counts.get(label, [0])) + f.write( + f" {label:30s} - Real: {real_mean:6.2f}±{real_std:5.2f}, " + f"Synth: {synth_mean:6.2f}±{synth_std:5.2f}\n" + ) + + f.write("\n" + "=" * 80 + "\n") + + +def parse_args(): + parser = argparse.ArgumentParser( + 
description="DLA GT Comparison", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "synthdataset", + type=str, + help="Name of the synthetic dataset", + ) + + args = parser.parse_args() + return args + + +def main(): + """Main comparison workflow""" + args = parse_args() + synth_dataset_name = args.synthdataset + + # Setup output directory + output_dir = ENV.DLA_GT_ANALYZATION_DIR / synth_dataset_name + output_dir.mkdir(parents=True, exist_ok=True) + + print("Loading datasets...") + base_dataset_name = get_base_dataset_name(synth_dataset_name) + base_dataset = load_dataset(base_dataset_name, is_synthetic=False) + print(base_dataset.metadata.dataset_labels) + synth_dataset = load_dataset(synth_dataset_name, is_synthetic=True) + + deffile = ENV.SYN_DATA_DEFINITIONS_DIR / f"{synth_dataset_name}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file(deffile) + label_mapping = dsdef.label_mapping + + print("Extracting metrics from real dataset...") + real_metrics = extract_metrics(base_dataset, label_mapping=None) + + print("Extracting metrics from synthetic dataset...") + synth_metrics = extract_metrics(synth_dataset, label_mapping=label_mapping) + + print("Generating visualizations...") + + # print(" - Size distribution comparison...") + # plot_size_distribution_comparison( + # real_metrics, synth_metrics, output_dir, synth_dataset_name + # ) + + print(" - Spatial heatmaps...") + plot_spatial_heatmaps(real_metrics, synth_metrics, output_dir, synth_dataset_name) + + print(" - Region count distributions...") + plot_region_counts(real_metrics, synth_metrics, output_dir, synth_dataset_name) + + print("Saving summary statistics...") + save_summary_stats(real_metrics, synth_metrics, output_dir, synth_dataset_name) + + print(f"\nAnalysis complete! Results saved to: {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/docgenie/analyzation/gt/embeddings_qa.py b/docgenie/analyzation/gt/embeddings_qa.py new file mode 100755 index 0000000000000000000000000000000000000000..59287f63adb067e7984566dcd1f453bd5359990f --- /dev/null +++ b/docgenie/analyzation/gt/embeddings_qa.py @@ -0,0 +1,263 @@ +""" +TODO: include answers in QA GT embeddings? 
+""" + +from __future__ import annotations +import h5py +import argparse +from pathlib import Path +from typing import TYPE_CHECKING, Callable, TypeVar +import numpy as np +import tqdm +from docgenie import ENV +from docgenie.analyzation.clustering.core._utilities import EmbeddingType +from docgenie.data._core._data_types import DocumentInstanceModelInput +from docgenie.data.interfaces.dataset import load_dataset +from docgenie.logging import get_logger +from atria_core.types.data_instance.base import ( + BaseDataInstance, +) +from docgenie.analyzation.clustering.core._utilities import EmbeddingType +from docgenie.data.interfaces.data_pipeline import ( + load_preprocessed_data_pipeline, +) +from typing import Literal +from docgenie.data._core._utilities import TaskType +from docgenie.data.interface import load_transform +from docgenie.generation.models import ( + SyntheticDatasetFileStructure, + SynDatasetDefinition, +) + +import numpy as np +from torch.utils.data import DataLoader + +T_BaseDataInstance = TypeVar("T_BaseDataInstance", bound=BaseDataInstance) + + +logger = get_logger(__name__) + + +def _iterate_dataset( + model_fn: Callable, + embedding_fn: Callable, + dataloader: "DataLoader", + device: str = "cuda", +): + """Inner function that actually generates the embeddings.""" + import torch + + model = model_fn() + model.to(device) + model.eval() + print("Model is on:", next(model.parameters()).device) + + sample_ids = [] + doc_ids = [] + questions = [] + answers = [] + embeddings = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataloader, desc="Extracting embeddings"): + embeddings.append(embedding_fn(model, batch)) + sample_ids.extend(batch["sample_ids"]) + doc_ids.extend(batch["doc_ids"]) + questions.extend(batch["questions"]) + answers.extend(batch["answers"]) + + embeddings = torch.cat(embeddings, dim=0) + return embeddings.cpu().numpy(), questions, answers, sample_ids, doc_ids + + +def _extract_text_embeddings(dataloader: "DataLoader", device: str = "cuda"): + """Inner function that actually generates the embeddings.""" + + def model_fn(): + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer("all-mpnet-base-v2") + model.to(device) + model.eval() + return model + + def embedding_fn(model, inputs): + sentences = [qa_question for qa_question in inputs["questions"]] + return model.encode(sentences, convert_to_tensor=True) + + question_embeddings, questions, answers, sample_ids, doc_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=embedding_fn, + dataloader=dataloader, + device=device, + ) + + def qa_embedding_fn(model, inputs): + sentences = [ + f"Question: {q} Answer: {a}" + for q, a in zip(inputs["questions"], inputs["answers"]) + ] + return model.encode(sentences, convert_to_tensor=True) + + qa_embeddings, questions, answers, sample_ids, doc_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=qa_embedding_fn, + dataloader=dataloader, + device=device, + ) + + return question_embeddings, qa_embeddings, questions, answers, sample_ids, doc_ids + + +def extract_embeddings( # MsgpackDatasetReader[T_BaseDataInstance] | None + dataloader: "DataLoader", + output_dir: Path, + device: str = "cuda", +): + q_path = output_dir / f"Q.h5" + qa_path = output_dir / f"QA.h5" + + if q_path.exists() and qa_path.exists(): + print(f"Found existing QA embeddings at {q_path} and {qa_path} - SKIPPING") + else: + extraction_func = _extract_text_embeddings + question_embeddings, qa_embeddings, questions, answers, sample_ids, doc_ids = ( + 
extraction_func(dataloader, device) + ) + + _save_embeddings( + embeddings=question_embeddings, + questions=questions, + answers=answers, + sample_ids=sample_ids, + document_ids=doc_ids, + file_path=q_path, + ) + + _save_embeddings( + embeddings=qa_embeddings, + questions=questions, + answers=answers, + sample_ids=sample_ids, + document_ids=doc_ids, + file_path=Path(output_dir) / f"QA.h5", + ) + + +def _save_embeddings( + embeddings: "np.ndarray", + questions: list[str], + answers: list[str], + sample_ids: list[str], + document_ids: list[str], + file_path: Path, +): + file_path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(file_path, "w") as f: + f.create_dataset("embeddings", data=embeddings) + f.create_dataset("questions", data=questions) + f.create_dataset("answers", data=answers) + f.create_dataset("sample_ids", data=sample_ids) + f.create_dataset("document_ids", data=document_ids) + + +def load_qa_embeddings(dataset_name: str, embedding_type: Literal["Q", "QA"]): + file_path: Path = ENV.GT_EMBEDDINGS_DIR / dataset_name / f"{embedding_type}.h5" + print(f"Loading embeddings from {file_path}") + + def decode_str_collection(col): + return [s.decode("utf-8") if isinstance(s, bytes) else s for s in col] + + with h5py.File(file_path, "r") as f: + embeddings = f["embeddings"][:] + questions = f["questions"][:] + answers = f["answers"][:] + sample_ids = f["sample_ids"][:] + doc_ids = f["document_ids"][:] + + return ( + embeddings, + decode_str_collection(questions), + decode_str_collection(answers), + decode_str_collection(sample_ids), + decode_str_collection(doc_ids), + ) + + +def collate_fn_extract_questions(batch): + all_questions = [] + all_answers = [] + all_doc_ids = [] + all_sample_ids = [] + + for doc in batch: + for a in doc.annotations: + for qa in a.qa_pairs: + all_questions.append(qa.question_text) + all_answers.append(qa.answer_text[0]) + all_doc_ids.append(doc.sample_id) + all_sample_ids.append(f"{doc.sample_id}_{qa.id}") + + return { + "questions": all_questions, + "answers": all_answers, + "sample_ids": all_sample_ids, + "doc_ids": all_doc_ids, + } + + +def main(dataset_name: str, is_synth: bool): + if is_synth: + ymal_file = ENV.SYN_DATA_DEFINITIONS_DIR / f"{dataset_name}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file( + yaml_path=ymal_file + ) + # prepare_synthetic_dataset(dsdef=dsdef) + + # data_pipeline = load_preprocessed_data_pipeline( + # dataset_name=dataset_name, + # # task_type=TaskType.generate_embeddings, + # is_synthetic=is_synth, + # ) + # train_dataloader = data_pipeline.train_dataloader(batch_size=512, num_workers=2) + + dataset = load_dataset(dataset_name=dataset_name, is_synthetic=is_synth) + train_dataloader = DataLoader( + dataset=dataset.train, + batch_size=512, + num_workers=0, + collate_fn=collate_fn_extract_questions, + ) + + output_dir = ENV.GT_EMBEDDINGS_DIR / dataset_name + extract_embeddings( + dataloader=train_dataloader, + output_dir=output_dir, + ) + + # embeddings, questions, answers, sample_ids, doc_ids = load_qa_embeddings( + # dataset_name=dataset_name, embedding_type="Q" + # ) + # for e, q, a, s, d in zip(embeddings, questions, answers, sample_ids, doc_ids): + # print(e, q, a, s, d) + # input() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate GT embeddings") + parser.add_argument( + "dataset", + type=str, + help="Name of the dataset (e.g., docvqa, mysynthetic, pubtabnet)", + ) + + parser.add_argument( + "--is_synth", + action="store_true", + help="If set, determines that the 
dataset is a synthetic dataset", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(dataset_name=args.dataset, is_synth=args.is_synth) diff --git a/docgenie/analyzation/gt/embeddings_qa_using_datapipeline.py b/docgenie/analyzation/gt/embeddings_qa_using_datapipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..2cef91050ee115c0df3dfd26879a9a5fef9177ac --- /dev/null +++ b/docgenie/analyzation/gt/embeddings_qa_using_datapipeline.py @@ -0,0 +1,291 @@ +""" +TODO: include answers in QA GT embeddings? +""" + +from __future__ import annotations +import h5py +import argparse +from pathlib import Path +from typing import TYPE_CHECKING, Callable, TypeVar +import numpy as np +import tqdm +from docgenie import ENV +from docgenie.analyzation.clustering.core._utilities import EmbeddingType +from docgenie.data._core._data_types import DocumentInstanceModelInput +from docgenie.logging import get_logger +from atria_core.types.data_instance.base import ( + BaseDataInstance, +) +from docgenie.analyzation.clustering.core._utilities import EmbeddingType +from docgenie.data.interfaces.synthetic_data import ( + prepare_synthetic_dataset, +) +from docgenie.data.interfaces.data_pipeline import ( + load_preprocessed_data_pipeline, +) +from typing import Literal +from docgenie.data._core._utilities import TaskType +from docgenie.data.interface import load_transform +from docgenie.generation.models import ( + SyntheticDatasetFileStructure, + SynDatasetDefinition, +) +from docgenie.data._core._dataset import Dataset +from docgenie.data._core._msgpack_dataset_reader import MsgpackDatasetReader + +T_BaseDataInstance = TypeVar("T_BaseDataInstance", bound=BaseDataInstance) +if TYPE_CHECKING: + import numpy as np + from torch.utils.data import DataLoader + +logger = get_logger(__name__) + + +def _iterate_dataset( + model_fn: Callable, + embedding_fn: Callable, + dataloader: "DataLoader", + device: str = "cuda", +): + """Inner function that actually generates the embeddings.""" + import torch + + model = model_fn() + model.to(device) + model.eval() + print("Model is on:", next(model.parameters()).device) + + sample_ids = [] + embeddings = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataloader, desc="Extracting embeddings"): + batch_dict = batch.to_dict() + batch: DocumentInstanceModelInput + batch = batch.select_first_overflow_samples() + batch = batch.to(device) + + token_bboxes = batch.token_bboxes + if token_bboxes is not None: + if token_bboxes.min() >= 0 and token_bboxes.max() <= 1.0: + # if bboxes are normalized to [0, 1], convert to [0, 1000] as expected by layoutlmv3 + token_bboxes = (token_bboxes * 1000).long() + else: + logger.warning( + f"Token bboxes must be in the range [0, 1], but got min {token_bboxes.min()} and max {token_bboxes.max()}" + ) + token_bboxes = (token_bboxes.clip(0, 1.0) * 1000).long() + + # assert check + assert token_bboxes.min() >= 0 and token_bboxes.max() <= 1000, ( + f"Token bboxes must be in the range [0, 1000], but got min {token_bboxes.min()} and max {token_bboxes.max()}" + ) + + # make sure if image is normlized 0-1 as in layoutlm we renormalize using clip stats + assert batch.image.min() >= -1.1 and batch.image.max() <= 1.1, ( + f"Image pixel values must be in the range [0, 1], but got min {batch.image.min()} and max {batch.image.max()}" + ) + + # make inputs + inputs = dict( + qa_answers=batch.qa_answers, + qa_question=batch.qa_question, + sample_ids=batch.sample_id, + ) + + 
embeddings.append(embedding_fn(model, inputs)) + + # in our preprocessed dataset indices are always unqiue + # but sample_ids may not be always unique in some rare cases + sample_ids.extend(batch.sample_id) + + embeddings = torch.cat(embeddings, dim=0) + return embeddings.cpu().numpy(), sample_ids + + +def _extract_text_embeddings( + dataloader: "DataLoader", + device: str = "cuda", +): + """Inner function that actually generates the embeddings.""" + + def model_fn(): + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer("all-mpnet-base-v2") + model.to(device) + model.eval() + return model + + print("Extracting embeddings only for Questions.................") + + def embedding_fn(model, inputs): + sentences = [qa_question for qa_question in inputs["qa_question"]] + return model.encode(sentences, convert_to_tensor=True) + + question_embeddings, question_sample_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=embedding_fn, + dataloader=dataloader, + device=device, + ) + + print("Extracting embeddings for both Questions and Answers...............") + + def qa_embedding_fn(model, inputs): + """I asked gpt and It said this type of approach is common in SBERT/Text-encoders""" + sentences = [ + f"Question: {q} Answer: {a}" + for q, a in zip(inputs["qa_question"], inputs["qa_answers"]) + ] + return model.encode(sentences, convert_to_tensor=True) + + qa_embeddings, qa_sample_ids = _iterate_dataset( + model_fn=model_fn, + embedding_fn=qa_embedding_fn, + dataloader=dataloader, + device=device, + ) + + return dict( + question_embeddings=question_embeddings, + question_sample_ids=question_sample_ids, + qa_embeddings=qa_embeddings, + qa_sample_ids=qa_sample_ids, + ) + + +def embedding_extraction_with_cache( # MsgpackDatasetReader[T_BaseDataInstance] | None + dataloader: "DataLoader", + output_dir: str | Path, + embedding_type: EmbeddingType, + device: str = "cuda", + cache_outputs: bool = True, + load_embeddings: Literal[ + "question_only", "QA" + ] = "question_only", # used to load embeddings from chache +): + """By default it returns question only embeddings""" + """Generic cacher function that handles caching logic for any embedding type.""" + if load_embeddings == "QA": + cache_file = Path(output_dir) / f"QA_{embedding_type.value}.h5" + elif load_embeddings == "question_only": + cache_file = Path(output_dir) / f"Q_{embedding_type.value}.h5" + + if cache_outputs and cache_file.exists(): + logger.info( + f"Loading cached {load_embeddings}_{embedding_type.value} embeddings from {cache_file}" + ) + return _load_embeddings(cache_file) + + extraction_func = _extract_text_embeddings + all_embeddings = extraction_func(dataloader, device) + + # Question only embeddings + question_embeddings = all_embeddings["question_embeddings"] + question_sample_ids = all_embeddings["question_sample_ids"] + + # Question + Answer embeddings + qa_embeddings = all_embeddings["qa_embeddings"] + qa_sample_ids = all_embeddings["qa_sample_ids"] + + if cache_outputs: + """Checking that embeddings and sample_ids have same length""" + assert len(question_sample_ids) == question_embeddings.shape[0], logger.warning( + f"[Error in Questuion only Embedding] Number of sample IDs ({len(question_sample_ids)}) must match number of embeddings ({question_embeddings.shape[0]})" + ) + + assert len(qa_sample_ids) == qa_embeddings.shape[0], logger.warning( + f"[Error in QA Embedding] Number of sample IDs ({len(qa_sample_ids)}) must match number of embeddings ({qa_embeddings.shape[0]})" + ) + 
"""Checking that sample_ids are unique""" + assert len(set(question_sample_ids)) == len(question_sample_ids), ( + logger.warning( + "[ERROR in Question only Embedding] Sample IDs must be unique" + ) + ) + assert len(set(qa_sample_ids)) == len(qa_sample_ids), logger.warning( + "[ERROR in QA Embedding] Sample IDs must be unique" + ) + """Saving question only embeddings""" + _save_embeddings( + embeddings=question_embeddings, + sample_ids=question_sample_ids, + file_path=Path(output_dir) / f"Q_{embedding_type.value}.h5", + ) + """Saving QA only embeddings""" + _save_embeddings( + embeddings=qa_embeddings, + sample_ids=qa_sample_ids, + file_path=Path(output_dir) / f"QA_{embedding_type.value}.h5", + ) + return _load_embeddings(cache_file) + + return question_embeddings, question_sample_ids + + +def _save_embeddings(embeddings: "np.ndarray", sample_ids: list[str], file_path: Path): + import h5py + + file_path.parent.mkdir(parents=True, exist_ok=True) + with h5py.File(file_path, "w") as f: + f.create_dataset("embeddings", data=embeddings) + f.create_dataset("sample_ids", data=sample_ids) + + +def _load_embeddings(file_path: Path): + import h5py + + print(f"Loading embeddings from {file_path}") + + with h5py.File(file_path, "r") as f: + sample_ids = f["sample_ids"][:] + embeddings = f["embeddings"][:] + return embeddings, [ + s.decode("utf-8") if isinstance(s, bytes) else s for s in sample_ids + ] + + +def main(dataset_name: str, is_synth: bool): + if is_synth: + ymal_file = ENV.SYN_DATA_DEFINITIONS_DIR / f"{dataset_name}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file( + yaml_path=ymal_file + ) + prepare_synthetic_dataset(dsdef=dsdef) + + data_pipeline = load_preprocessed_data_pipeline( + dataset_name=dataset_name, + # task_type=TaskType.generate_embeddings, + is_synthetic=is_synth, + ) + + train_dataloader = data_pipeline.train_dataloader(batch_size=512, num_workers=2) + + output_dir = ENV.GT_EMBEDDINGS_DIR / dataset_name + embedding, sample_ids = embedding_extraction_with_cache( + dataloader=train_dataloader, + output_dir=output_dir, + embedding_type=EmbeddingType.text, + ) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate GT embeddings") + parser.add_argument( + "--dataset", + type=str, + required=True, + help="Name of the dataset (e.g., docvqa, mysynthetic, pubtabnet)", + ) + + parser.add_argument( + "--is_synth", + action="store_true", + help="If set, determines that the dataset is a synthetic dataset", + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + main(dataset_name=args.dataset, is_synth=args.is_synth) diff --git a/docgenie/analyzation/gt/kie/kie_gt_analysis.py b/docgenie/analyzation/gt/kie/kie_gt_analysis.py new file mode 100755 index 0000000000000000000000000000000000000000..83ace6ae524fd883ed9b5711fa9454519f91c848 --- /dev/null +++ b/docgenie/analyzation/gt/kie/kie_gt_analysis.py @@ -0,0 +1,568 @@ +""" +Compare KIE Ground Truth between Synthetic and Real Datasets +For CVPR paper on synthesis of document understanding datasets +""" + +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from collections import defaultdict, Counter +from typing import List, Dict, Tuple +import pandas as pd +from scipy import stats + +from docgenie import ENV +from docgenie.analyzation.gt.webapp import get_base_dataset_name +from docgenie.data.interfaces.dataset import load_dataset +from docgenie.generation.models._syndatadef import SynDatasetDefinition + + +# Set publication-quality style 
+sns.set_style("whitegrid") +sns.set_context("paper", font_scale=1.3) +plt.rcParams["font.family"] = "serif" +plt.rcParams["font.serif"] = ["Times New Roman"] + plt.rcParams["font.serif"] +plt.rcParams["figure.dpi"] = 300 +plt.rcParams["savefig.dpi"] = 300 +plt.rcParams["savefig.bbox"] = "tight" + + +def parse_bio_tags_to_entities( + word_labels_names: List[str], +) -> List[Tuple[str, int, int]]: + """ + Parse BIO tags to extract complete entities. + + Args: + word_labels_names: List of BIO tags (e.g., ['B-HEADER', 'I-HEADER', 'B-QUESTION']) + + Returns: + List of tuples (entity_class, start_idx, end_idx) + """ + entities = [] + current_entity = None + current_start = None + + for idx, label in enumerate(word_labels_names): + if label.startswith("B-"): + # Save previous entity if exists + if current_entity is not None: + entities.append((current_entity, current_start, idx - 1)) + + # Start new entity + current_entity = label[2:] # Remove 'B-' prefix + current_start = idx + elif label.startswith("I-"): + # Continue current entity (if it matches) + entity_class = label[2:] # Remove 'I-' prefix + if current_entity is None or current_entity != entity_class: + # Start new entity if no current or mismatch + if current_entity is not None: + entities.append((current_entity, current_start, idx - 1)) + current_entity = entity_class + current_start = idx + else: + # 'O' tag or other - end current entity + if current_entity is not None: + entities.append((current_entity, current_start, idx - 1)) + current_entity = None + current_start = None + + # Don't forget last entity + if current_entity is not None: + entities.append((current_entity, current_start, len(word_labels_names) - 1)) + + return entities + + +def get_entity_spatial_info( + entities: List[Tuple[str, int, int]], word_bboxes: List[List[float]] +) -> Dict[str, List[Tuple[float, float]]]: + """ + Extract spatial information (centers) for each entity class. + + Args: + entities: List of (entity_class, start_idx, end_idx) + word_bboxes: List of normalized bboxes in XYXY format + + Returns: + Dict mapping entity_class to list of (x_center, y_center) positions + """ + spatial_info = defaultdict(list) + + for entity_class, start_idx, end_idx in entities: + # Get all bboxes for this entity + entity_bboxes = word_bboxes[start_idx : end_idx + 1] + + # Calculate entity center (average of all word centers) + x_centers = [(bbox[0] + bbox[2]) / 2 for bbox in entity_bboxes] + y_centers = [(bbox[1] + bbox[3]) / 2 for bbox in entity_bboxes] + + entity_x_center = np.mean(x_centers) + entity_y_center = np.mean(y_centers) + + spatial_info[entity_class].append((entity_x_center, entity_y_center)) + + return spatial_info + + +def analyze_dataset( + dataset, is_synthetic=False, label_mapping: dict[str, str] = None +) -> Dict: + """ + Analyze a dataset and extract statistics. 
+ + Returns: + Dictionary with various statistics + """ + stats_dict = { + "entity_counts": defaultdict(int), + "entity_counts_per_sample": defaultdict(list), + "spatial_distributions": defaultdict(list), + "entity_lengths": defaultdict(list), # Number of words per entity + "total_samples": 0, + "total_entities": 0, + } + + for sample in dataset.train: + stats_dict["total_samples"] += 1 + + # Get word labels and bboxes + annotation = sample.annotations[0] # EntityLabelingAnnotation + word_labels_names = annotation.word_labels.name + word_bboxes = sample.content.word_bboxes.value + + # Parse entities + entities = parse_bio_tags_to_entities(word_labels_names) + + if label_mapping is not None and len(label_mapping) > 0: + entities = [(label_mapping[e], _s, _e) for (e, _s, _e) in entities] + + # Count entities per class in this sample + sample_entity_counts = Counter([e[0] for e in entities]) + + for entity_class, count in sample_entity_counts.items(): + stats_dict["entity_counts"][entity_class] += count + stats_dict["entity_counts_per_sample"][entity_class].append(count) + + # Add zeros for missing classes in this sample + all_classes = set(stats_dict["entity_counts"].keys()) + for entity_class in all_classes: + if entity_class not in sample_entity_counts: + stats_dict["entity_counts_per_sample"][entity_class].append(0) + + # Get spatial info + spatial_info = get_entity_spatial_info(entities, word_bboxes) + for entity_class, positions in spatial_info.items(): + stats_dict["spatial_distributions"][entity_class].extend(positions) + + # Entity lengths + for entity_class, start_idx, end_idx in entities: + length = end_idx - start_idx + 1 + stats_dict["entity_lengths"][entity_class].append(length) + + stats_dict["total_entities"] += len(entities) + + return stats_dict + + +def plot_entity_distribution_comparison( + real_stats: Dict, synth_stats: Dict, output_prefix: str +): + """ + Plot comparison of entity class distributions. 
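+
+    Counts are normalized to per-dataset percentages before plotting, so the
+    real and synthetic datasets remain comparable even when their sizes differ.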
+ """ + all_classes = sorted( + set( + list(real_stats["entity_counts"].keys()) + + list(synth_stats["entity_counts"].keys()) + ) + ) + + real_counts = [real_stats["entity_counts"].get(cls, 0) for cls in all_classes] + synth_counts = [synth_stats["entity_counts"].get(cls, 0) for cls in all_classes] + + # Normalize to percentages + real_total = sum(real_counts) + synth_total = sum(synth_counts) + real_pcts = [c / real_total * 100 for c in real_counts] + synth_pcts = [c / synth_total * 100 for c in synth_counts] + + # Create DataFrame for seaborn + df_pct = pd.DataFrame( + { + "Entity Class": all_classes * 2, + "Percentage": real_pcts + synth_pcts, + "Dataset": ["Real"] * len(all_classes) + ["Synthetic"] * len(all_classes), + } + ) + + # Plot with seaborn + fig, ax = plt.subplots(1, 1, figsize=(7, 5)) + + # Color palette + palette = sns.color_palette("Set2", 2) + + # Percentages + sns.barplot( + data=df_pct, + x="Entity Class", + y="Percentage", + hue="Dataset", + palette=palette, + ax=ax, + alpha=0.85, + ) + ax.set_xlabel("") + ax.set_ylabel("") + ax.tick_params(axis="x", rotation=90) + plt.setp(ax.xaxis.get_majorticklabels(), rotation=90, ha="center") + + # Format y-axis to show percentage symbol + ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y, _: f"{int(y)}%")) + + ax.legend(frameon=True, loc="upper right", fontsize=11) + ax.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax.set_axisbelow(True) + + plt.tight_layout() + plt.savefig( + ENV.KIE_GT_ANALYZATION_DIR / f"{output_prefix}_distribution_comparison.pdf", + dpi=300, + bbox_inches="tight", + ) + print( + f"Saved: {ENV.KIE_GT_ANALYZATION_DIR / output_prefix}_distribution_comparison.pdf" + ) + plt.close() + + +def plot_spatial_heatmaps(real_stats: Dict, synth_stats: Dict, output_prefix: str): + """ + Plot spatial heatmaps showing where entities appear on the page. 
+ """ + all_classes = sorted( + set( + list(real_stats["spatial_distributions"].keys()) + + list(synth_stats["spatial_distributions"].keys()) + ) + ) + + n_classes = len(all_classes) + n_cols = min(4, n_classes) + n_rows = (n_classes + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols * 2, figsize=(6 * n_cols, 4 * n_rows)) + if n_rows == 1 and n_cols == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes.reshape(1, -1) + elif n_cols == 1: + axes = axes.reshape(-1, 1) + + # Use better colormap + cmap = sns.color_palette("rocket_r", as_cmap=True) + + for idx, entity_class in enumerate(all_classes): + row = idx // n_cols + col_base = (idx % n_cols) * 2 + + # Real data heatmap + real_positions = real_stats["spatial_distributions"].get(entity_class, []) + if len(real_positions) > 0: + x_coords = [pos[0] for pos in real_positions] + y_coords = [pos[1] for pos in real_positions] + + # Create 2D histogram (heatmap) + heatmap, xedges, yedges = np.histogram2d( + x_coords, y_coords, bins=20, range=[[0, 1], [0, 1]] + ) + + im1 = axes[row, col_base].imshow( + heatmap.T, + origin="upper", + cmap=cmap, + extent=[0, 1, 0, 1], + aspect="auto", + interpolation="bilinear", + ) + axes[row, col_base].set_title( + f"{entity_class}\n(Real, n={len(real_positions)})", + fontweight="bold", + fontsize=12, + pad=10, + ) + axes[row, col_base].set_xlabel("X Position", fontsize=11) + axes[row, col_base].set_ylabel("Y Position", fontsize=11) + cbar1 = plt.colorbar(im1, ax=axes[row, col_base], fraction=0.046, pad=0.04) + cbar1.ax.tick_params(labelsize=9) + else: + axes[row, col_base].text( + 0.5, 0.5, "No data", ha="center", va="center", fontsize=12 + ) + axes[row, col_base].set_title( + f"{entity_class}\n(Real, n=0)", fontweight="bold", fontsize=12, pad=10 + ) + axes[row, col_base].set_xticks([]) + axes[row, col_base].set_yticks([]) + + # Synthetic data heatmap + synth_positions = synth_stats["spatial_distributions"].get(entity_class, []) + if len(synth_positions) > 0: + x_coords = [pos[0] for pos in synth_positions] + y_coords = [pos[1] for pos in synth_positions] + + heatmap, xedges, yedges = np.histogram2d( + x_coords, y_coords, bins=20, range=[[0, 1], [0, 1]] + ) + + im2 = axes[row, col_base + 1].imshow( + heatmap.T, + origin="upper", + cmap=cmap, + extent=[0, 1, 0, 1], + aspect="auto", + interpolation="bilinear", + ) + axes[row, col_base + 1].set_title( + f"{entity_class}\n(Synth, n={len(synth_positions)})", + fontweight="bold", + fontsize=12, + pad=10, + ) + axes[row, col_base + 1].set_xlabel("X Position", fontsize=11) + axes[row, col_base + 1].set_ylabel("Y Position", fontsize=11) + cbar2 = plt.colorbar( + im2, ax=axes[row, col_base + 1], fraction=0.046, pad=0.04 + ) + cbar2.ax.tick_params(labelsize=9) + else: + axes[row, col_base + 1].text( + 0.5, 0.5, "No data", ha="center", va="center", fontsize=12 + ) + axes[row, col_base + 1].set_title( + f"{entity_class}\n(Synth, n=0)", fontweight="bold", fontsize=12, pad=10 + ) + axes[row, col_base + 1].set_xticks([]) + axes[row, col_base + 1].set_yticks([]) + + # Hide empty subplots + for idx in range(n_classes, n_rows * n_cols): + row = idx // n_cols + col_base = (idx % n_cols) * 2 + axes[row, col_base].axis("off") + axes[row, col_base + 1].axis("off") + + plt.tight_layout() + plt.savefig( + ENV.KIE_GT_ANALYZATION_DIR / f"{output_prefix}_spatial_heatmaps.pdf", + dpi=300, + bbox_inches="tight", + ) + print(f"Saved: {ENV.KIE_GT_ANALYZATION_DIR / output_prefix}_spatial_heatmaps.pdf") + plt.close() + + +def plot_entity_length_comparison( + 
real_stats: Dict, synth_stats: Dict, output_prefix: str +): + """ + Compare distribution of entity lengths (number of words per entity). + """ + all_classes = sorted( + set( + list(real_stats["entity_lengths"].keys()) + + list(synth_stats["entity_lengths"].keys()) + ) + ) + + n_classes = len(all_classes) + n_cols = min(3, n_classes) + n_rows = (n_classes + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows)) + if n_classes == 1: + axes = np.array([axes]) + axes = axes.flatten() + + # Color palette + colors = sns.color_palette("Set2", 2) + + for idx, entity_class in enumerate(all_classes): + real_lengths = real_stats["entity_lengths"].get(entity_class, []) + synth_lengths = synth_stats["entity_lengths"].get(entity_class, []) + + ax = axes[idx] + + if len(real_lengths) > 0 or len(synth_lengths) > 0: + # Plot histograms + bins = ( + np.arange( + 1, + max(max(real_lengths, default=1), max(synth_lengths, default=1)) + + 2, + ) + - 0.5 + ) + + ax.hist( + real_lengths, + bins=bins, + alpha=0.7, + label="Real", + density=True, + color=colors[0], + edgecolor="black", + linewidth=0.5, + ) + ax.hist( + synth_lengths, + bins=bins, + alpha=0.7, + label="Synthetic", + density=True, + color=colors[1], + edgecolor="black", + linewidth=0.5, + ) + + # Add statistics + real_mean = np.mean(real_lengths) if len(real_lengths) > 0 else 0 + synth_mean = np.mean(synth_lengths) if len(synth_lengths) > 0 else 0 + + ax.axvline( + real_mean, color=colors[0], linestyle="--", linewidth=2.5, alpha=0.8 + ) + ax.axvline( + synth_mean, color=colors[1], linestyle="--", linewidth=2.5, alpha=0.8 + ) + + ax.set_title( + f"{entity_class}\nReal μ={real_mean:.2f}, Synth μ={synth_mean:.2f}", + fontweight="bold", + fontsize=12, + pad=10, + ) + ax.set_xlabel("Entity Length (words)", fontsize=11, fontweight="bold") + ax.set_ylabel("Density", fontsize=11, fontweight="bold") + ax.legend(frameon=True, loc="best", fontsize=10) + ax.grid(axis="y", alpha=0.3, linestyle="--", linewidth=0.5) + ax.set_axisbelow(True) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + else: + ax.text( + 0.5, + 0.5, + "No data", + ha="center", + va="center", + transform=ax.transAxes, + fontsize=12, + ) + ax.set_title(f"{entity_class}", fontweight="bold", fontsize=12, pad=10) + ax.set_xticks([]) + ax.set_yticks([]) + + # Hide empty subplots + for idx in range(n_classes, len(axes)): + axes[idx].axis("off") + + plt.tight_layout() + plt.savefig( + ENV.KIE_GT_ANALYZATION_DIR / f"{output_prefix}_entity_length_comparison.png", + dpi=300, + bbox_inches="tight", + ) + print( + f"Saved: {ENV.KIE_GT_ANALYZATION_DIR / output_prefix}_entity_length_comparison.png" + ) + plt.close() + + +def compute_statistical_tests(real_stats: Dict, synth_stats: Dict) -> pd.DataFrame: + """ + Compute statistical tests to compare distributions. 
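+
+    Per entity class this runs a Mann-Whitney U test on per-document counts and
+    two-sample KS tests on the x/y coordinates of entity centers, returning one
+    row per class as a DataFrame.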
+ """ + all_classes = sorted( + set( + list(real_stats["entity_counts"].keys()) + + list(synth_stats["entity_counts"].keys()) + ) + ) + + results = [] + + for entity_class in all_classes: + real_counts_per_sample = real_stats["entity_counts_per_sample"].get( + entity_class, [] + ) + synth_counts_per_sample = synth_stats["entity_counts_per_sample"].get( + entity_class, [] + ) + + # Mann-Whitney U test (non-parametric) + if len(real_counts_per_sample) > 0 and len(synth_counts_per_sample) > 0: + statistic, p_value = stats.mannwhitneyu( + real_counts_per_sample, synth_counts_per_sample, alternative="two-sided" + ) + else: + statistic, p_value = np.nan, np.nan + + # KS test for spatial distributions + real_spatial = real_stats["spatial_distributions"].get(entity_class, []) + synth_spatial = synth_stats["spatial_distributions"].get(entity_class, []) + + ks_x, ks_y = np.nan, np.nan + if len(real_spatial) > 0 and len(synth_spatial) > 0: + real_x = [pos[0] for pos in real_spatial] + synth_x = [pos[0] for pos in synth_spatial] + real_y = [pos[1] for pos in real_spatial] + synth_y = [pos[1] for pos in synth_spatial] + + ks_x, _ = stats.ks_2samp(real_x, synth_x) + ks_y, _ = stats.ks_2samp(real_y, synth_y) + + results.append( + { + "Entity Class": entity_class, + "Real Count": real_stats["entity_counts"].get(entity_class, 0), + "Synth Count": synth_stats["entity_counts"].get(entity_class, 0), + "Real Mean/Sample": np.mean(real_counts_per_sample) + if real_counts_per_sample + else 0, + "Synth Mean/Sample": np.mean(synth_counts_per_sample) + if synth_counts_per_sample + else 0, + "Mann-Whitney p-value": p_value, + "KS Stat (X)": ks_x, + "KS Stat (Y)": ks_y, + } + ) + + return pd.DataFrame(results) + + +def print_summary_statistics(real_stats: Dict, synth_stats: Dict): + """ + Print summary statistics. 
+ """ + print("\n" + "=" * 80) + print("SUMMARY STATISTICS") + print("=" * 80) + + print(f"\nReal Dataset:") + print(f" Total samples: {real_stats['total_samples']}") + print(f" Total entities: {real_stats['total_entities']}") + print( + f" Avg entities/sample: {real_stats['total_entities'] / real_stats['total_samples']:.2f}" + ) + + print(f"\nSynthetic Dataset:") + print(f" Total samples: {synth_stats['total_samples']}") + print(f" Total entities: {synth_stats['total_entities']}") + print( + f" Avg entities/sample: {synth_stats['total_entities'] / synth_stats['total_samples']:.2f}" + ) + + print("\n" + "=" * 80) diff --git a/docgenie/analyzation/gt/kie/kie_gt_analysis_full.py b/docgenie/analyzation/gt/kie/kie_gt_analysis_full.py new file mode 100755 index 0000000000000000000000000000000000000000..6dcbd659cd79f225a567510be7ce6e80201a71d4 --- /dev/null +++ b/docgenie/analyzation/gt/kie/kie_gt_analysis_full.py @@ -0,0 +1,211 @@ +""" +Complete KIE GT Comparison Pipeline +Example usage with all metrics and visualizations +""" + +import argparse +from docgenie import ENV +from docgenie.analyzation.gt.kie.kie_gt_analysis import ( + analyze_dataset, + plot_entity_distribution_comparison, + plot_spatial_heatmaps, + plot_entity_length_comparison, + compute_statistical_tests, + print_summary_statistics, +) + +from docgenie.analyzation.gt.kie.kie_gt_analysis_utils import ( + compute_jensen_shannon_divergence, + compute_spatial_coverage_metrics, + plot_entity_co_occurrence_matrix, + plot_document_level_statistics, + generate_latex_table, + comprehensive_analysis, +) +from docgenie.analyzation.gt.webapp import get_base_dataset_name +from docgenie.data.interfaces.dataset import load_dataset +from docgenie.generation.models._syndatadef import SynDatasetDefinition + + +def full_comparison_pipeline( + synth_dataset_name: str, +): + """ + Complete comparison pipeline with all metrics and visualizations. 
+ + Args: + synth_dataset_name: Name of synthetic dataset + get_base_dataset_name_func: Function to get base dataset name + load_dataset_func: Function to load datasets + output_prefix: Prefix for output files + """ + + print("=" * 80) + print("COMPLETE KIE GROUND TRUTH COMPARISON PIPELINE") + print("=" * 80) + + # ========== STEP 1: Load Datasets ========== + print("\n[1/6] Loading datasets...") + base_dataset_name = get_base_dataset_name(synth_dataset_name) + print(f" Base dataset: {base_dataset_name}") + print(f" Synthetic dataset: {synth_dataset_name}") + + base_dataset = load_dataset(base_dataset_name, is_synthetic=False) + synth_dataset = load_dataset(synth_dataset_name, is_synthetic=True) + print(" ✓ Datasets loaded") + + deffile = ENV.SYN_DATA_DEFINITIONS_DIR / f"{synth_dataset_name}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file(deffile) + label_mapping = dsdef.label_mapping + + # ========== STEP 2: Analyze Datasets ========== + print("\n[2/6] Analyzing datasets...") + print(" Analyzing real dataset...") + real_stats = analyze_dataset(base_dataset, is_synthetic=False, label_mapping=None) + print( + f" ✓ {real_stats['total_samples']} samples, {real_stats['total_entities']} entities" + ) + + print(" Analyzing synthetic dataset...") + synth_stats = analyze_dataset( + synth_dataset, is_synthetic=True, label_mapping=label_mapping + ) + print( + f" ✓ {synth_stats['total_samples']} samples, {synth_stats['total_entities']} entities" + ) + + # ========== STEP 3: Summary Statistics ========== + print("\n[3/6] Computing summary statistics...") + print_summary_statistics(real_stats, synth_stats) + + # ========== STEP 4: Statistical Tests ========== + print("\n[4/6] Running statistical tests...") + stats_df = compute_statistical_tests(real_stats, synth_stats) + print("\n" + stats_df.to_string(index=False)) + stats_df.to_csv( + f"{ENV.KIE_GT_ANALYZATION_DIR / synth_dataset_name}_statistics.csv", index=False + ) + print( + f"\n ✓ Saved: {ENV.KIE_GT_ANALYZATION_DIR / synth_dataset_name}_statistics.csv" + ) + + # Divergence metrics + divergence = compute_jensen_shannon_divergence(real_stats, synth_stats) + print(f"\n Distribution Similarity:") + print(f" Jensen-Shannon Divergence: {divergence['overall_js_divergence']:.4f}") + print(f" (Lower is better, 0 = identical, 1 = completely different)") + + # Spatial metrics + spatial_metrics = compute_spatial_coverage_metrics(real_stats, synth_stats) + print(f"\n Spatial Distribution (Centroid Distances):") + for entity_class, metrics in spatial_metrics.items(): + print(f" {entity_class}: {metrics['centroid_distance']:.4f}") + + # ========== STEP 5: Generate Visualizations ========== + print("\n[5/6] Generating visualizations...") + + print(" Creating distribution comparison plots...") + plot_entity_distribution_comparison(real_stats, synth_stats, synth_dataset_name) + + print(" Creating spatial heatmaps...") + plot_spatial_heatmaps(real_stats, synth_stats, synth_dataset_name) + + print(" Creating entity length comparison plots...") + plot_entity_length_comparison(real_stats, synth_stats, synth_dataset_name) + + print(" Creating co-occurrence matrices...") + plot_entity_co_occurrence_matrix( + real_stats, + synth_stats, + base_dataset.train, + synth_dataset.train, + synth_dataset_name, + ) + + print(" Creating document-level statistics...") + plot_document_level_statistics(real_stats, synth_stats, synth_dataset_name) + + # ========== STEP 6: Generate Paper Materials ========== + print("\n[6/6] Generating paper materials...") + 
generate_latex_table(stats_df, divergence, synth_dataset_name) + + # ========== Summary Report ========== + print("\n" + "=" * 80) + print("COMPARISON COMPLETE!") + print("=" * 80) + print("\nGenerated files:") + print(f" • {synth_dataset_name}_statistics.csv") + print(f" • {synth_dataset_name}_distribution_comparison.png") + print(f" • {synth_dataset_name}_spatial_heatmaps.png") + print(f" • {synth_dataset_name}_entity_length_comparison.png") + print(f" • {synth_dataset_name}_cooccurrence_matrix.png") + print(f" • {synth_dataset_name}_document_statistics.png") + print(f" • {synth_dataset_name}_table.tex") + + print("\nKey Findings:") + print( + f" • Dataset sizes: {real_stats['total_samples']} real vs {synth_stats['total_samples']} synthetic" + ) + print( + f" • Total entities: {real_stats['total_entities']} real vs {synth_stats['total_entities']} synthetic" + ) + print( + f" • Distribution similarity (JS): {divergence['overall_js_divergence']:.4f}" + ) + print( + f" • Average entities/doc: {real_stats['total_entities'] / real_stats['total_samples']:.2f} real vs " + f"{synth_stats['total_entities'] / synth_stats['total_samples']:.2f} synthetic" + ) + + # Identify classes with significant differences + sig_diff_classes = stats_df[stats_df["Mann-Whitney p-value"] < 0.05][ + "Entity Class" + ].tolist() + if sig_diff_classes: + print( + f" • Significantly different classes (p<0.05): {', '.join(sig_diff_classes)}" + ) + else: + print(f" • No significantly different classes detected (p<0.05)") + + print("\n" + "=" * 80) + + return { + "real_stats": real_stats, + "synth_stats": synth_stats, + "statistics_df": stats_df, + "divergence": divergence, + "spatial_metrics": spatial_metrics, + } + + +def parse_args(): + parser = argparse.ArgumentParser( + description="KIE GT Comparison", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "synthdataset", + type=str, + help="Name of the synthetic dataset", + ) + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + # Run the complete pipeline + args = parse_args() + synth_dataset_name = args.synthdataset + + results = full_comparison_pipeline( + synth_dataset_name=synth_dataset_name, + ) + + # Access results for further analysis + print("\nResults dictionary contains:") + print(f" - real_stats: {list(results['real_stats'].keys())}") + print(f" - synth_stats: {list(results['synth_stats'].keys())}") + print(f" - statistics_df: {results['statistics_df'].shape}") diff --git a/docgenie/analyzation/gt/kie/kie_gt_analysis_utils.py b/docgenie/analyzation/gt/kie/kie_gt_analysis_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..5efb76186a3c5e42561983cbc1df92d9778b3d54 --- /dev/null +++ b/docgenie/analyzation/gt/kie/kie_gt_analysis_utils.py @@ -0,0 +1,428 @@ +""" +Supplementary analysis utilities for KIE GT comparison +Additional metrics for CVPR paper +""" + +import numpy as np +import matplotlib.pyplot as plt +from collections import defaultdict +from typing import Dict, List, Tuple +import seaborn as sns +from scipy.spatial.distance import jensenshannon +from scipy.stats import entropy + +from docgenie import ENV + + +def compute_jensen_shannon_divergence( + real_stats: Dict, synth_stats: Dict +) -> Dict[str, float]: + """ + Compute Jensen-Shannon divergence for entity distributions. + Lower is better (0 = identical distributions, 1 = completely different). 
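+
+    Class-count vectors are normalized to probability distributions before the
+    divergence is computed. Illustrative usage (a sketch; stats dicts come from
+    ``analyze_dataset``):
+
+        div = compute_jensen_shannon_divergence(real_stats, synth_stats)
+        print(div["overall_js_divergence"], div["overall_kl_divergence"])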
+ """ + all_classes = sorted( + set( + list(real_stats["entity_counts"].keys()) + + list(synth_stats["entity_counts"].keys()) + ) + ) + + real_counts = np.array( + [real_stats["entity_counts"].get(cls, 0) for cls in all_classes] + ) + synth_counts = np.array( + [synth_stats["entity_counts"].get(cls, 0) for cls in all_classes] + ) + + # Normalize to probabilities + real_probs = ( + real_counts / real_counts.sum() if real_counts.sum() > 0 else real_counts + ) + synth_probs = ( + synth_counts / synth_counts.sum() if synth_counts.sum() > 0 else synth_counts + ) + + # JS divergence + js_div = jensenshannon(real_probs, synth_probs) + + return { + "overall_js_divergence": js_div, + "overall_kl_divergence": entropy(real_probs, synth_probs), + } + + +def compute_spatial_coverage_metrics(real_stats: Dict, synth_stats: Dict) -> Dict: + """ + Compute spatial coverage metrics comparing how well synthetic data + covers the spatial distribution of real data. + """ + results = {} + + for entity_class in real_stats["spatial_distributions"].keys(): + real_pos = real_stats["spatial_distributions"].get(entity_class, []) + synth_pos = synth_stats["spatial_distributions"].get(entity_class, []) + + if len(real_pos) == 0 or len(synth_pos) == 0: + continue + + real_x = np.array([p[0] for p in real_pos]) + real_y = np.array([p[1] for p in real_pos]) + synth_x = np.array([p[0] for p in synth_pos]) + synth_y = np.array([p[1] for p in synth_pos]) + + # Mean absolute difference in centroids + real_centroid = (real_x.mean(), real_y.mean()) + synth_centroid = (synth_x.mean(), synth_y.mean()) + centroid_distance = np.sqrt( + (real_centroid[0] - synth_centroid[0]) ** 2 + + (real_centroid[1] - synth_centroid[1]) ** 2 + ) + + # Standard deviation comparison + real_std = (real_x.std(), real_y.std()) + synth_std = (synth_x.std(), synth_y.std()) + std_diff = (abs(real_std[0] - synth_std[0]), abs(real_std[1] - synth_std[1])) + + results[entity_class] = { + "centroid_distance": centroid_distance, + "std_x_diff": std_diff[0], + "std_y_diff": std_diff[1], + "real_centroid": real_centroid, + "synth_centroid": synth_centroid, + "real_std": real_std, + "synth_std": synth_std, + } + + return results + + +def plot_entity_co_occurrence_matrix( + real_stats: Dict, synth_stats: Dict, samples_real, samples_synth, output_prefix: str +): + """ + Plot co-occurrence matrices showing which entities appear together in documents. + Useful for understanding document structure preservation. 
+ """ + + # Build co-occurrence matrices + def build_cooccurrence(samples): + from itertools import combinations + + co_occur = defaultdict(int) + all_classes = set() + + for sample in samples: + annotation = sample.annotations[0] + word_labels_names = annotation.word_labels.name + + # Get unique entity classes in this sample + entities_in_sample = set() + for label in word_labels_names: + if label.startswith("B-") or label.startswith("I-"): + entity_class = label[2:] + entities_in_sample.add(entity_class) + all_classes.add(entity_class) + + # Count co-occurrences + for pair in combinations(sorted(entities_in_sample), 2): + co_occur[pair] += 1 + + return co_occur, sorted(all_classes) + + real_cooccur, real_classes = build_cooccurrence(samples_real) + synth_cooccur, synth_classes = build_cooccurrence(samples_synth) + + all_classes = sorted(set(real_classes + synth_classes)) + n = len(all_classes) + + # Build matrices + real_matrix = np.zeros((n, n)) + synth_matrix = np.zeros((n, n)) + + class_to_idx = {cls: idx for idx, cls in enumerate(all_classes)} + + for (cls1, cls2), count in real_cooccur.items(): + i, j = class_to_idx[cls1], class_to_idx[cls2] + real_matrix[i, j] = count + real_matrix[j, i] = count + + for (cls1, cls2), count in synth_cooccur.items(): + i, j = class_to_idx[cls1], class_to_idx[cls2] + synth_matrix[i, j] = count + synth_matrix[j, i] = count + + # Normalize + real_matrix = ( + real_matrix / real_matrix.sum() if real_matrix.sum() > 0 else real_matrix + ) + synth_matrix = ( + synth_matrix / synth_matrix.sum() if synth_matrix.sum() > 0 else synth_matrix + ) + + # Plot + fig, axes = plt.subplots(1, 3, figsize=(20, 6)) + + # Real + sns.heatmap( + real_matrix, + annot=False, + cmap="Blues", + ax=axes[0], + xticklabels=all_classes, + yticklabels=all_classes, + cbar_kws={"label": "Frequency"}, + ) + axes[0].set_title("Real Data Co-occurrence", fontsize=14, fontweight="bold") + axes[0].set_xticklabels(axes[0].get_xticklabels(), rotation=45, ha="right") + axes[0].set_yticklabels(axes[0].get_yticklabels(), rotation=0) + + # Synthetic + sns.heatmap( + synth_matrix, + annot=False, + cmap="Oranges", + ax=axes[1], + xticklabels=all_classes, + yticklabels=all_classes, + cbar_kws={"label": "Frequency"}, + ) + axes[1].set_title("Synthetic Data Co-occurrence", fontsize=14, fontweight="bold") + axes[1].set_xticklabels(axes[1].get_xticklabels(), rotation=45, ha="right") + axes[1].set_yticklabels(axes[1].get_yticklabels(), rotation=0) + + # Difference + diff_matrix = np.abs(real_matrix - synth_matrix) + sns.heatmap( + diff_matrix, + annot=False, + cmap="Reds", + ax=axes[2], + xticklabels=all_classes, + yticklabels=all_classes, + cbar_kws={"label": "Abs Difference"}, + ) + axes[2].set_title("Absolute Difference", fontsize=14, fontweight="bold") + axes[2].set_xticklabels(axes[2].get_xticklabels(), rotation=45, ha="right") + axes[2].set_yticklabels(axes[2].get_yticklabels(), rotation=0) + + plt.tight_layout() + plt.savefig( + ENV.KIE_GT_ANALYZATION_DIR / f"{output_prefix}_cooccurrence_matrix.png", + dpi=300, + bbox_inches="tight", + ) + print( + f"Saved: {ENV.KIE_GT_ANALYZATION_DIR / output_prefix}_cooccurrence_matrix.png" + ) + plt.close() + + +def plot_document_level_statistics( + real_stats: Dict, synth_stats: Dict, output_prefix: str +): + """ + Plot document-level statistics (entities per document, etc.). 
+ """ + # Compute entities per document + real_entities_per_doc = [] + synth_entities_per_doc = [] + + for entity_class in real_stats["entity_counts_per_sample"].keys(): + real_entities_per_doc.extend( + real_stats["entity_counts_per_sample"][entity_class] + ) + + for entity_class in synth_stats["entity_counts_per_sample"].keys(): + synth_entities_per_doc.extend( + synth_stats["entity_counts_per_sample"][entity_class] + ) + + # Aggregate by document + n_docs_real = real_stats["total_samples"] + n_docs_synth = synth_stats["total_samples"] + + # Reshape data properly - sum across entity types per document + # We need to reorganize the per_sample data + all_classes = sorted( + set( + list(real_stats["entity_counts_per_sample"].keys()) + + list(synth_stats["entity_counts_per_sample"].keys()) + ) + ) + + # Get max document count + max_docs = max( + max( + [ + len(real_stats["entity_counts_per_sample"].get(cls, [])) + for cls in all_classes + ], + default=0, + ), + max( + [ + len(synth_stats["entity_counts_per_sample"].get(cls, [])) + for cls in all_classes + ], + default=0, + ), + ) + + real_total_per_doc = np.zeros(n_docs_real) + synth_total_per_doc = np.zeros(n_docs_synth) + + for entity_class in all_classes: + real_counts = real_stats["entity_counts_per_sample"].get(entity_class, []) + synth_counts = synth_stats["entity_counts_per_sample"].get(entity_class, []) + + real_total_per_doc[: len(real_counts)] += np.array(real_counts) + synth_total_per_doc[: len(synth_counts)] += np.array(synth_counts) + + # Plot + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + # Histogram + axes[0].hist(real_total_per_doc, bins=20, alpha=0.6, label="Real", density=True) + axes[0].hist( + synth_total_per_doc, bins=20, alpha=0.6, label="Synthetic", density=True + ) + axes[0].axvline( + real_total_per_doc.mean(), + color="blue", + linestyle="--", + linewidth=2, + label=f"Real μ={real_total_per_doc.mean():.1f}", + ) + axes[0].axvline( + synth_total_per_doc.mean(), + color="orange", + linestyle="--", + linewidth=2, + label=f"Synth μ={synth_total_per_doc.mean():.1f}", + ) + axes[0].set_xlabel("Total Entities per Document", fontsize=12) + axes[0].set_ylabel("Density", fontsize=12) + axes[0].set_title( + "Distribution of Entities per Document", fontsize=14, fontweight="bold" + ) + axes[0].legend() + axes[0].grid(axis="y", alpha=0.3) + + # Cumulative distribution + real_sorted = np.sort(real_total_per_doc) + synth_sorted = np.sort(synth_total_per_doc) + real_cdf = np.arange(1, len(real_sorted) + 1) / len(real_sorted) + synth_cdf = np.arange(1, len(synth_sorted) + 1) / len(synth_sorted) + + axes[1].plot(real_sorted, real_cdf, label="Real", linewidth=2) + axes[1].plot(synth_sorted, synth_cdf, label="Synthetic", linewidth=2) + axes[1].set_xlabel("Total Entities per Document", fontsize=12) + axes[1].set_ylabel("Cumulative Probability", fontsize=12) + axes[1].set_title("Cumulative Distribution", fontsize=14, fontweight="bold") + axes[1].legend() + axes[1].grid(alpha=0.3) + + plt.tight_layout() + plt.savefig( + ENV.KIE_GT_ANALYZATION_DIR / f"{output_prefix}_document_statistics.png", + dpi=300, + bbox_inches="tight", + ) + print( + f"Saved: {ENV.KIE_GT_ANALYZATION_DIR / output_prefix}_document_statistics.png" + ) + plt.close() + + +def generate_latex_table(stats_df, divergence_metrics, output_prefix: str): + """ + Generate LaTeX table for paper. 
+ """ + latex_str = "\\begin{table}[t]\n" + latex_str += "\\centering\n" + latex_str += "\\caption{Comparison of Entity Distributions between Real and Synthetic Datasets}\n" + latex_str += "\\label{tab:entity_comparison}\n" + latex_str += "\\begin{tabular}{lrrrrr}\n" + latex_str += "\\toprule\n" + latex_str += ( + "Entity Class & Real & Synth & Real (\\%) & Synth (\\%) & p-value \\\\\n" + ) + latex_str += "\\midrule\n" + + total_real = stats_df["Real Count"].sum() + total_synth = stats_df["Synth Count"].sum() + + for _, row in stats_df.iterrows(): + entity = row["Entity Class"] + real_count = int(row["Real Count"]) + synth_count = int(row["Synth Count"]) + real_pct = (real_count / total_real * 100) if total_real > 0 else 0 + synth_pct = (synth_count / total_synth * 100) if total_synth > 0 else 0 + p_val = row["Mann-Whitney p-value"] + + p_str = f"{p_val:.4f}" if not np.isnan(p_val) else "---" + if not np.isnan(p_val) and p_val < 0.001: + p_str = "$<$0.001" + + latex_str += f"{entity} & {real_count} & {synth_count} & {real_pct:.1f} & {synth_pct:.1f} & {p_str} \\\\\n" + + latex_str += "\\midrule\n" + latex_str += f"Total & {total_real} & {total_synth} & 100.0 & 100.0 & --- \\\\\n" + latex_str += "\\bottomrule\n" + latex_str += "\\end{tabular}\n" + latex_str += "\\end{table}\n" + + # Add divergence metrics as separate note + latex_str += "\n% Divergence Metrics:\n" + latex_str += f"% JS Divergence: {divergence_metrics['overall_js_divergence']:.4f}\n" + latex_str += f"% KL Divergence: {divergence_metrics['overall_kl_divergence']:.4f}\n" + + with open(ENV.KIE_GT_ANALYZATION_DIR / f"{output_prefix}_table.tex", "w") as f: + f.write(latex_str) + + print(f"Saved: {ENV.KIE_GT_ANALYZATION_DIR / output_prefix}_table.tex") + print("\nLaTeX Table Preview:") + print(latex_str) + + +def comprehensive_analysis( + synth_dataset_name: str, + real_stats: Dict, + synth_stats: Dict, + real_samples, + synth_samples, + output_prefix: str, +): + """ + Run all supplementary analyses. + """ + print("\n" + "=" * 80) + print("ADVANCED ANALYSIS") + print("=" * 80) + + # JS divergence + print("\nComputing distribution divergence metrics...") + divergence = compute_jensen_shannon_divergence(real_stats, synth_stats) + print(f" Jensen-Shannon Divergence: {divergence['overall_js_divergence']:.4f}") + print(f" KL Divergence: {divergence['overall_kl_divergence']:.4f}") + + # Spatial coverage + print("\nComputing spatial coverage metrics...") + spatial_metrics = compute_spatial_coverage_metrics(real_stats, synth_stats) + print(" Centroid distances:") + for entity_class, metrics in spatial_metrics.items(): + print(f" {entity_class}: {metrics['centroid_distance']:.4f}") + + # Generate additional plots + print("\nGenerating additional visualizations...") + plot_entity_co_occurrence_matrix( + real_stats, synth_stats, real_samples, synth_samples, output_prefix + ) + plot_document_level_statistics(real_stats, synth_stats, output_prefix) + + print("\n" + "=" * 80) + + +if __name__ == "__main__": + ... diff --git a/docgenie/analyzation/gt/qa/qa_gt_analysis old.py b/docgenie/analyzation/gt/qa/qa_gt_analysis old.py new file mode 100755 index 0000000000000000000000000000000000000000..acf0c79e84e72fddd2354c895a40b3ad3a25112a --- /dev/null +++ b/docgenie/analyzation/gt/qa/qa_gt_analysis old.py @@ -0,0 +1,589 @@ +""" +Compare QA Ground Truth between Synthetic and Real Document Understanding Datasets + +This script compares question-answer pairs from synthetic and real datasets using: +1. Question type distribution analysis +2. 
Embedding similarity metrics (KL/JS divergence, MMD) +3. UMAP projection overlay quantification + +For CVPR paper on synthesis of document understanding datasets. +""" + +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import entropy +from scipy.spatial.distance import jensenshannon, cdist +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.neighbors import KernelDensity +from typing import Literal, Tuple, Dict, List +import re +from collections import Counter + +from docgenie import ENV +from docgenie.analyzation.gt.embeddings_qa import load_qa_embeddings +from docgenie.analyzation.gt.webapp import get_base_dataset_name + + +def extract_question_type(question: str) -> str: + """ + Extract question type based on starting word. + + Args: + question: Question string + + Returns: + Question type: 'who', 'what', 'when', 'where', 'why', 'how', or 'other' + """ + question = question.lower().strip() + + # Common question starters + if question.startswith("who"): + return "who" + elif question.startswith("what"): + return "what" + elif question.startswith("when"): + return "when" + elif question.startswith("where"): + return "where" + elif question.startswith("why"): + return "why" + elif question.startswith("how"): + return "how" + else: + return "other" + + +def compute_question_type_distribution(questions: List[str]) -> Dict[str, float]: + """ + Compute distribution of question types. + + Args: + questions: List of question strings + + Returns: + Dictionary mapping question type to ratio + """ + types = [extract_question_type(q) for q in questions] + counter = Counter(types) + total = len(types) + + distribution = { + "who": counter.get("who", 0) / total, + "what": counter.get("what", 0) / total, + "when": counter.get("when", 0) / total, + "where": counter.get("where", 0) / total, + "why": counter.get("why", 0) / total, + "how": counter.get("how", 0) / total, + "other": counter.get("other", 0) / total, + } + + return distribution + + +def compute_cosine_similarity_histogram( + emb1: np.ndarray, emb2: np.ndarray, n_bins: int = 50 +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Compute histograms of cosine similarities within each dataset. 
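+    Similarities are computed within each dataset (upper triangle of the
+    pairwise matrix, diagonal excluded), embeddings are subsampled to at most
+    5000 rows, and both histograms are binned over [-1, 1] and normalized to
+    sum to one.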
+ + Args: + emb1: Embeddings from dataset 1 (N1 x D) + emb2: Embeddings from dataset 2 (N2 x D) + n_bins: Number of histogram bins + + Returns: + hist1, hist2, bins1, bins2 + """ + # Compute pairwise cosine similarities within each dataset + # Sample if datasets are too large + max_samples = 5000 + if len(emb1) > max_samples: + idx1 = np.random.choice(len(emb1), max_samples, replace=False) + emb1_sample = emb1[idx1] + else: + emb1_sample = emb1 + + if len(emb2) > max_samples: + idx2 = np.random.choice(len(emb2), max_samples, replace=False) + emb2_sample = emb2[idx2] + else: + emb2_sample = emb2 + + # Compute cosine similarities + sim1 = cosine_similarity(emb1_sample) + sim2 = cosine_similarity(emb2_sample) + + # Get upper triangle (excluding diagonal) to avoid self-similarities + triu_idx = np.triu_indices_from(sim1, k=1) + sim1_values = sim1[triu_idx] + + triu_idx = np.triu_indices_from(sim2, k=1) + sim2_values = sim2[triu_idx] + + # Compute histograms on same bins + bins = np.linspace(-1, 1, n_bins + 1) + hist1, _ = np.histogram(sim1_values, bins=bins, density=True) + hist2, _ = np.histogram(sim2_values, bins=bins, density=True) + + # Normalize to get probability distributions + hist1 = hist1 / hist1.sum() + hist2 = hist2 / hist2.sum() + + return hist1, hist2, bins[:-1], bins + + +def compute_kl_divergence( + p: np.ndarray, q: np.ndarray, epsilon: float = 1e-10 +) -> float: + """ + Compute KL divergence KL(P||Q). + + Args: + p: Probability distribution P + q: Probability distribution Q + epsilon: Small value to avoid log(0) + + Returns: + KL divergence value + """ + p = np.array(p) + epsilon + q = np.array(q) + epsilon + p = p / p.sum() + q = q / q.sum() + + return entropy(p, q) + + +def compute_js_divergence(p: np.ndarray, q: np.ndarray) -> float: + """ + Compute Jensen-Shannon divergence. + + Args: + p: Probability distribution P + q: Probability distribution Q + + Returns: + JS divergence value (0 to 1) + """ + return jensenshannon(p, q) + + +def compute_mmd_rbf(X: np.ndarray, Y: np.ndarray, gamma: float = None) -> float: + """ + Compute Maximum Mean Discrepancy with RBF kernel. + + Args: + X: Samples from distribution P (N1 x D) + Y: Samples from distribution Q (N2 x D) + gamma: RBF kernel bandwidth (if None, uses median heuristic) + + Returns: + MMD^2 value + """ + # Sample if datasets are too large + max_samples = 2000 + if len(X) > max_samples: + X = X[np.random.choice(len(X), max_samples, replace=False)] + if len(Y) > max_samples: + Y = Y[np.random.choice(len(Y), max_samples, replace=False)] + + # Use median heuristic for gamma if not provided + if gamma is None: + XY = np.vstack([X, Y]) + dists = cdist(XY, XY) + gamma = 1.0 / (2 * np.median(dists[dists > 0]) ** 2) + + def rbf_kernel(X, Y, gamma): + """RBF kernel matrix.""" + XX = np.sum(X**2, axis=1)[:, np.newaxis] + YY = np.sum(Y**2, axis=1)[np.newaxis, :] + XY = X @ Y.T + dists_sq = XX + YY - 2 * XY + return np.exp(-gamma * dists_sq) + + K_XX = rbf_kernel(X, X, gamma) + K_YY = rbf_kernel(Y, Y, gamma) + K_XY = rbf_kernel(X, Y, gamma) + + m = len(X) + n = len(Y) + + # MMD^2 estimator + mmd_sq = (K_XX.sum() - np.trace(K_XX)) / (m * (m - 1)) + mmd_sq += (K_YY.sum() - np.trace(K_YY)) / (n * (n - 1)) + mmd_sq -= 2 * K_XY.mean() + + return mmd_sq + + +def compute_umap_overlay_metrics( + umap_real: np.ndarray, umap_synth: np.ndarray +) -> Dict[str, float]: + """ + Quantify the quality of UMAP projection overlay. 
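+    Reported metrics include per-axis (and, when the POT package is available,
+    2-D) Wasserstein distances, KL/JS divergence between KDE density estimates
+    on a shared grid, symmetric Chamfer distance, and a nearest-neighbour
+    coverage score.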
+ + Args: + umap_real: UMAP 2D projections of real data (N1 x 2) + umap_synth: UMAP 2D projections of synthetic data (N2 x 2) + + Returns: + Dictionary with overlay quality metrics + """ + metrics = {} + + # 1. Wasserstein distance (Earth Mover's Distance) + from scipy.stats import wasserstein_distance + + # Compute 1D Wasserstein on each dimension + w_dist_x = wasserstein_distance(umap_real[:, 0], umap_synth[:, 0]) + w_dist_y = wasserstein_distance(umap_real[:, 1], umap_synth[:, 1]) + metrics["wasserstein_x"] = w_dist_x + metrics["wasserstein_y"] = w_dist_y + metrics["wasserstein_avg"] = (w_dist_x + w_dist_y) / 2 + + # 2. 2D Wasserstein (using optimal transport if available) + try: + import ot + + # Normalize to uniform weights + a = np.ones(len(umap_real)) / len(umap_real) + b = np.ones(len(umap_synth)) / len(umap_synth) + M = ot.dist(umap_real, umap_synth, metric="euclidean") + w_dist_2d = ot.emd2(a, b, M) + metrics["wasserstein_2d"] = w_dist_2d + except ImportError: + print("Note: Python Optimal Transport (POT) not available for 2D Wasserstein") + + # 3. KL divergence of 2D density estimates + # Estimate densities using KDE + kde_real = KernelDensity(bandwidth=0.5, kernel="gaussian") + kde_synth = KernelDensity(bandwidth=0.5, kernel="gaussian") + + kde_real.fit(umap_real) + kde_synth.fit(umap_synth) + + # Create grid for density evaluation + x_min = min(umap_real[:, 0].min(), umap_synth[:, 0].min()) + x_max = max(umap_real[:, 0].max(), umap_synth[:, 0].max()) + y_min = min(umap_real[:, 1].min(), umap_synth[:, 1].min()) + y_max = max(umap_real[:, 1].max(), umap_synth[:, 1].max()) + + x_grid = np.linspace(x_min, x_max, 50) + y_grid = np.linspace(y_min, y_max, 50) + X_grid, Y_grid = np.meshgrid(x_grid, y_grid) + grid_points = np.column_stack([X_grid.ravel(), Y_grid.ravel()]) + + # Evaluate densities + log_dens_real = kde_real.score_samples(grid_points) + log_dens_synth = kde_synth.score_samples(grid_points) + + dens_real = np.exp(log_dens_real) + dens_synth = np.exp(log_dens_synth) + + # Normalize + dens_real = dens_real / dens_real.sum() + dens_synth = dens_synth / dens_synth.sum() + + # Compute KL and JS divergence + metrics["kl_divergence_2d"] = compute_kl_divergence(dens_real, dens_synth) + metrics["js_divergence_2d"] = compute_js_divergence(dens_real, dens_synth) + + # 4. Chamfer distance (average nearest neighbor distance) + from scipy.spatial import distance_matrix + + # Real to Synth + dists_r2s = distance_matrix(umap_real, umap_synth) + chamfer_r2s = dists_r2s.min(axis=1).mean() + + # Synth to Real + dists_s2r = distance_matrix(umap_synth, umap_real) + chamfer_s2r = dists_s2r.min(axis=0).mean() + + metrics["chamfer_real_to_synth"] = chamfer_r2s + metrics["chamfer_synth_to_real"] = chamfer_s2r + metrics["chamfer_symmetric"] = (chamfer_r2s + chamfer_s2r) / 2 + + # 5. Coverage metric (what fraction of real data is "covered" by synth) + # Define "coverage" as having a synthetic point within threshold distance + threshold = np.percentile(dists_r2s.min(axis=1), 95) # Adaptive threshold + coverage = (dists_r2s.min(axis=1) < threshold).mean() + metrics["coverage"] = coverage + + return metrics + + +def plot_comparison_results( + real_dist: Dict[str, float], + synth_dist: Dict[str, float], + hist1: np.ndarray, + hist2: np.ndarray, + bins: np.ndarray, + umap_real: np.ndarray, + umap_synth: np.ndarray, + save_path: str = None, +): + """ + Create visualization of comparison results. 
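+    Produces a 2x2 figure: question-type bars, within-dataset cosine-similarity
+    curves, a UMAP scatter overlay, and KDE density contours (real solid,
+    synthetic dashed).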
+ + Args: + real_dist: Question type distribution for real data + synth_dist: Question type distribution for synthetic data + hist1: Cosine similarity histogram for real data + hist2: Cosine similarity histogram for synthetic data + bins: Histogram bins + umap_real: UMAP projections of real data + umap_synth: UMAP projections of synthetic data + save_path: Path to save figure (if None, shows plot) + """ + fig, axes = plt.subplots(2, 2, figsize=(15, 12)) + + # 1. Question type distribution + ax = axes[0, 0] + question_types = list(real_dist.keys()) + x = np.arange(len(question_types)) + width = 0.35 + + ax.bar( + x - width / 2, + [real_dist[t] for t in question_types], + width, + label="Real", + alpha=0.8, + ) + ax.bar( + x + width / 2, + [synth_dist[t] for t in question_types], + width, + label="Synthetic", + alpha=0.8, + ) + + ax.set_xlabel("Question Type") + ax.set_ylabel("Ratio") + ax.set_title("Question Type Distribution") + ax.set_xticks(x) + ax.set_xticklabels(question_types) + ax.legend() + ax.grid(axis="y", alpha=0.3) + + # 2. Cosine similarity histograms + ax = axes[0, 1] + bin_centers = (bins[:-1] + bins[1:]) / 2 + ax.plot(bin_centers, hist1, label="Real", alpha=0.7, linewidth=2) + ax.plot(bin_centers, hist2, label="Synthetic", alpha=0.7, linewidth=2) + ax.fill_between(bin_centers, hist1, alpha=0.3) + ax.fill_between(bin_centers, hist2, alpha=0.3) + + ax.set_xlabel("Cosine Similarity") + ax.set_ylabel("Density") + ax.set_title("Pairwise Cosine Similarity Distribution") + ax.legend() + ax.grid(alpha=0.3) + + # 3. UMAP overlay - side by side + ax = axes[1, 0] + ax.scatter( + umap_real[:, 0], umap_real[:, 1], c="blue", alpha=0.3, s=10, label="Real" + ) + ax.scatter( + umap_synth[:, 0], umap_synth[:, 1], c="red", alpha=0.3, s=10, label="Synthetic" + ) + ax.set_xlabel("UMAP 1") + ax.set_ylabel("UMAP 2") + ax.set_title("UMAP Projection Overlay") + ax.legend() + ax.grid(alpha=0.3) + + # 4. UMAP density contours + ax = axes[1, 1] + from scipy.stats import gaussian_kde + + # Compute KDE for contours + kde_real = gaussian_kde(umap_real.T) + kde_synth = gaussian_kde(umap_synth.T) + + # Create grid + x_min = min(umap_real[:, 0].min(), umap_synth[:, 0].min()) + x_max = max(umap_real[:, 0].max(), umap_synth[:, 0].max()) + y_min = min(umap_real[:, 1].min(), umap_synth[:, 1].min()) + y_max = max(umap_real[:, 1].max(), umap_synth[:, 1].max()) + + xx, yy = np.mgrid[x_min:x_max:100j, y_min:y_max:100j] + positions = np.vstack([xx.ravel(), yy.ravel()]) + + z_real = np.reshape(kde_real(positions).T, xx.shape) + z_synth = np.reshape(kde_synth(positions).T, xx.shape) + + ax.contour(xx, yy, z_real, colors="blue", alpha=0.6, levels=5) + ax.contour(xx, yy, z_synth, colors="red", alpha=0.6, levels=5, linestyles="dashed") + + ax.set_xlabel("UMAP 1") + ax.set_ylabel("UMAP 2") + ax.set_title("UMAP Density Contours (Real: solid, Synth: dashed)") + ax.grid(alpha=0.3) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches="tight") + print(f"Figure saved to {save_path}") + else: + plt.show() + + +def compare_qa_datasets( + synth_dataset_name: str, + embedding_type: Literal["Q", "QA"], +) -> Dict[str, any]: + """ + Main function to compare QA datasets. 
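+
+    Illustrative usage (a sketch mirroring the __main__ example at the bottom
+    of this file; requires a cached UMAP projection for the dataset pair):
+
+        results = compare_qa_datasets("wtq_alpha=1.0", embedding_type="Q")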
+ + Args: + synth_dataset_name: Name of synthetic dataset + embedding_type: Type of embeddings to use + load_qa_embeddings: Function to load embeddings + get_base_dataset_name: Function to get base dataset name + umap_real: Optional pre-computed UMAP projections for real data + umap_synth: Optional pre-computed UMAP projections for synthetic data + + Returns: + Dictionary with all comparison metrics + """ + print("=" * 80) + print("QA Dataset Comparison") + print("=" * 80) + + # Load data + print(f"\nLoading datasets...") + base_dataset_name = get_base_dataset_name(synth_dataset_name) + print(f"Base dataset: {base_dataset_name}") + print(f"Synthetic dataset: {synth_dataset_name}") + + real_emb, real_q, real_a, real_sample_ids, real_doc_ids = load_qa_embeddings( + base_dataset_name, embedding_type + ) + synth_emb, synth_q, synth_a, synth_sample_ids, synth_doc_ids = load_qa_embeddings( + synth_dataset_name, embedding_type + ) + + print(f"Real dataset: {len(real_q)} QA pairs") + print(f"Synthetic dataset: {len(synth_q)} QA pairs") + + results = {} + + # 1. Question type distribution + print("\n" + "-" * 80) + print("1. Question Type Distribution") + print("-" * 80) + + real_dist = compute_question_type_distribution(real_q) + synth_dist = compute_question_type_distribution(synth_q) + + print("\nReal data:") + for qtype, ratio in real_dist.items(): + print(f" {qtype:8s}: {ratio:6.2%}") + + print("\nSynthetic data:") + for qtype, ratio in synth_dist.items(): + print(f" {qtype:8s}: {ratio:6.2%}") + + # Compute distribution divergence + types_ordered = ["who", "what", "when", "where", "why", "how", "other"] + real_dist_arr = np.array([real_dist[t] for t in types_ordered]) + synth_dist_arr = np.array([synth_dist[t] for t in types_ordered]) + + qtype_kl = compute_kl_divergence(real_dist_arr, synth_dist_arr) + qtype_js = compute_js_divergence(real_dist_arr, synth_dist_arr) + + print(f"\nKL divergence (Real||Synth): {qtype_kl:.4f}") + print(f"JS divergence: {qtype_js:.4f}") + + results["question_type_real"] = real_dist + results["question_type_synth"] = synth_dist + results["question_type_kl"] = qtype_kl + results["question_type_js"] = qtype_js + + # 2. Cosine similarity histograms + print("\n" + "-" * 80) + print("2. Embedding Similarity Distribution") + print("-" * 80) + + hist_real, hist_synth, bin_centers, bins = compute_cosine_similarity_histogram( + real_emb, synth_emb, n_bins=50 + ) + + sim_kl = compute_kl_divergence(hist_real, hist_synth) + sim_js = compute_js_divergence(hist_real, hist_synth) + + print(f"KL divergence (Real||Synth): {sim_kl:.4f}") + print(f"JS divergence: {sim_js:.4f}") + + results["similarity_hist_kl"] = sim_kl + results["similarity_hist_js"] = sim_js + results["similarity_hist_real"] = hist_real + results["similarity_hist_synth"] = hist_synth + results["similarity_bins"] = bin_centers + + # 3. MMD + print("\n" + "-" * 80) + print("3. Maximum Mean Discrepancy (MMD)") + print("-" * 80) + + mmd_value = compute_mmd_rbf(real_emb, synth_emb) + print(f"MMD² (RBF kernel): {mmd_value:.6f}") + + results["mmd_rbf"] = mmd_value + + # 4. UMAP overlay metrics (if provided) + projection_method = "umap" + projection_cache_path = ( + ENV.QA_GT_WEBAPP_CACHE_DIR + / f"projection_{projection_method}_{synth_dataset_name}_{embedding_type}.npy" + ) + projection_2d = np.load(projection_cache_path) + umap_real = projection_2d[: len(real_emb)] + umap_synth = projection_2d[len(real_emb) :] + + if umap_real is not None and umap_synth is not None: + print("\n" + "-" * 80) + print("4. 
UMAP Projection Overlay Metrics") + print("-" * 80) + + overlay_metrics = compute_umap_overlay_metrics(umap_real, umap_synth) + + for metric_name, value in overlay_metrics.items(): + print(f"{metric_name:30s}: {value:.6f}") + + results["umap_overlay"] = overlay_metrics + + # Create visualization + plot_comparison_results( + real_dist, + synth_dist, + hist_real, + hist_synth, + bins, + umap_real, + umap_synth, + save_path=ENV.QA_GT_ANALYZATION_DIR / "qa_comparison.png", + ) + + print("\n" + "=" * 80) + print("Comparison complete!") + print("=" * 80) + + return results + + +# Example usage +if __name__ == "__main__": + synth_dataset_name = "wtq_alpha=1.0" + + # Run comparison + results = compare_qa_datasets( + synth_dataset_name=synth_dataset_name, embedding_type="Q" + ) + + print("\n\nResults summary available in 'results' dictionary:") + print(results) + print("Visualization saved to 'qa_comparison.png'") diff --git a/docgenie/analyzation/gt/qa/qa_gt_analysis.py b/docgenie/analyzation/gt/qa/qa_gt_analysis.py new file mode 100755 index 0000000000000000000000000000000000000000..6e2fda40fa97d51a5ae814fd23af66c58046e86f --- /dev/null +++ b/docgenie/analyzation/gt/qa/qa_gt_analysis.py @@ -0,0 +1,576 @@ +""" +Compare QA Ground Truth between Synthetic and Real Document Understanding Datasets + +This script compares question-answer pairs from synthetic and real datasets using: +1. Question type distribution analysis +2. Embedding similarity metrics (KL/JS divergence, MMD) +3. UMAP projection overlay quantification + +For CVPR paper on synthesis of document understanding datasets. +""" + +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import entropy +from scipy.spatial.distance import jensenshannon, cdist +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.neighbors import KernelDensity +from typing import Literal, Tuple, Dict, List +import re +from collections import Counter + +from docgenie import ENV +from docgenie.analyzation.gt.embeddings_qa import load_qa_embeddings +from docgenie.analyzation.gt.webapp import get_base_dataset_name + + +def extract_question_type(question: str) -> str: + """ + Extract question type based on starting word. + + Args: + question: Question string + + Returns: + Question type: 'who', 'what', 'when', 'where', 'why', 'how', or 'other' + """ + question = question.lower().strip() + + # Common question starters + if question.startswith("who"): + return "who" + elif question.startswith("what"): + return "what" + elif question.startswith("when"): + return "when" + elif question.startswith("where"): + return "where" + elif question.startswith("why"): + return "why" + elif question.startswith("how"): + return "how" + else: + return "other" + + +def compute_question_type_distribution(questions: List[str]) -> Dict[str, float]: + """ + Compute distribution of question types. 
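+
+    Illustrative usage (a sketch with made-up questions):
+
+        dist = compute_question_type_distribution(
+            ["What is the total amount?", "Who signed the form?"]
+        )
+        # dist["what"] == 0.5, dist["who"] == 0.5, all other types == 0.0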
+ + Args: + questions: List of question strings + + Returns: + Dictionary mapping question type to ratio + """ + types = [extract_question_type(q) for q in questions] + counter = Counter(types) + total = len(types) + + distribution = { + "who": counter.get("who", 0) / total, + "what": counter.get("what", 0) / total, + "when": counter.get("when", 0) / total, + "where": counter.get("where", 0) / total, + "why": counter.get("why", 0) / total, + "how": counter.get("how", 0) / total, + "other": counter.get("other", 0) / total, + } + + return distribution + + +def compute_cosine_similarity_histogram( + emb1: np.ndarray, emb2: np.ndarray, n_bins: int = 50 +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Compute histograms of cosine similarities within each dataset. + + Args: + emb1: Embeddings from dataset 1 (N1 x D) + emb2: Embeddings from dataset 2 (N2 x D) + n_bins: Number of histogram bins + + Returns: + hist1, hist2, bins1, bins2 + """ + # Compute pairwise cosine similarities within each dataset + # Sample if datasets are too large + max_samples = 5000 + if len(emb1) > max_samples: + idx1 = np.random.choice(len(emb1), max_samples, replace=False) + emb1_sample = emb1[idx1] + else: + emb1_sample = emb1 + + if len(emb2) > max_samples: + idx2 = np.random.choice(len(emb2), max_samples, replace=False) + emb2_sample = emb2[idx2] + else: + emb2_sample = emb2 + + # Compute cosine similarities + sim1 = cosine_similarity(emb1_sample) + sim2 = cosine_similarity(emb2_sample) + + # Get upper triangle (excluding diagonal) to avoid self-similarities + triu_idx = np.triu_indices_from(sim1, k=1) + sim1_values = sim1[triu_idx] + + triu_idx = np.triu_indices_from(sim2, k=1) + sim2_values = sim2[triu_idx] + + # Compute histograms on same bins + bins = np.linspace(-1, 1, n_bins + 1) + hist1, _ = np.histogram(sim1_values, bins=bins, density=True) + hist2, _ = np.histogram(sim2_values, bins=bins, density=True) + + # Normalize to get probability distributions + hist1 = hist1 / hist1.sum() + hist2 = hist2 / hist2.sum() + + return hist1, hist2, bins[:-1], bins + + +def compute_kl_divergence( + p: np.ndarray, q: np.ndarray, epsilon: float = 1e-10 +) -> float: + """ + Compute KL divergence KL(P||Q). + + Args: + p: Probability distribution P + q: Probability distribution Q + epsilon: Small value to avoid log(0) + + Returns: + KL divergence value + """ + p = np.array(p) + epsilon + q = np.array(q) + epsilon + p = p / p.sum() + q = q / q.sum() + + return entropy(p, q) + + +def compute_js_divergence(p: np.ndarray, q: np.ndarray) -> float: + """ + Compute Jensen-Shannon divergence. + + Args: + p: Probability distribution P + q: Probability distribution Q + + Returns: + JS divergence value (0 to 1) + """ + return jensenshannon(p, q) + + +def compute_mmd_rbf(X: np.ndarray, Y: np.ndarray, gamma: float = None) -> float: + """ + Compute Maximum Mean Discrepancy with RBF kernel. 
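+
+    Uses the standard estimator MMD^2 ~= E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)],
+    where the within-sample terms average the RBF kernel over off-diagonal pairs
+    and the cross term over all pairs. When ``gamma`` is None the bandwidth is
+    set by the median-distance heuristic, and each input is subsampled to at
+    most 2000 rows.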
+ + Args: + X: Samples from distribution P (N1 x D) + Y: Samples from distribution Q (N2 x D) + gamma: RBF kernel bandwidth (if None, uses median heuristic) + + Returns: + MMD^2 value + """ + # Sample if datasets are too large + max_samples = 2000 + if len(X) > max_samples: + X = X[np.random.choice(len(X), max_samples, replace=False)] + if len(Y) > max_samples: + Y = Y[np.random.choice(len(Y), max_samples, replace=False)] + + # Use median heuristic for gamma if not provided + if gamma is None: + XY = np.vstack([X, Y]) + dists = cdist(XY, XY) + gamma = 1.0 / (2 * np.median(dists[dists > 0]) ** 2) + + def rbf_kernel(X, Y, gamma): + """RBF kernel matrix.""" + XX = np.sum(X**2, axis=1)[:, np.newaxis] + YY = np.sum(Y**2, axis=1)[np.newaxis, :] + XY = X @ Y.T + dists_sq = XX + YY - 2 * XY + return np.exp(-gamma * dists_sq) + + K_XX = rbf_kernel(X, X, gamma) + K_YY = rbf_kernel(Y, Y, gamma) + K_XY = rbf_kernel(X, Y, gamma) + + m = len(X) + n = len(Y) + + # MMD^2 estimator + mmd_sq = (K_XX.sum() - np.trace(K_XX)) / (m * (m - 1)) + mmd_sq += (K_YY.sum() - np.trace(K_YY)) / (n * (n - 1)) + mmd_sq -= 2 * K_XY.mean() + + return mmd_sq + + +def compute_umap_overlay_metrics( + umap_real: np.ndarray, umap_synth: np.ndarray +) -> Dict[str, float]: + """ + Quantify the quality of UMAP projection overlay. + + Args: + umap_real: UMAP 2D projections of real data (N1 x 2) + umap_synth: UMAP 2D projections of synthetic data (N2 x 2) + + Returns: + Dictionary with overlay quality metrics + """ + metrics = {} + + # 1. Wasserstein distance (Earth Mover's Distance) + from scipy.stats import wasserstein_distance + + # Compute 1D Wasserstein on each dimension + w_dist_x = wasserstein_distance(umap_real[:, 0], umap_synth[:, 0]) + w_dist_y = wasserstein_distance(umap_real[:, 1], umap_synth[:, 1]) + metrics["wasserstein_x"] = w_dist_x + metrics["wasserstein_y"] = w_dist_y + metrics["wasserstein_avg"] = (w_dist_x + w_dist_y) / 2 + + # 2. 2D Wasserstein (using optimal transport if available) + try: + import ot + + # Normalize to uniform weights + a = np.ones(len(umap_real)) / len(umap_real) + b = np.ones(len(umap_synth)) / len(umap_synth) + M = ot.dist(umap_real, umap_synth, metric="euclidean") + w_dist_2d = ot.emd2(a, b, M) + metrics["wasserstein_2d"] = w_dist_2d + except ImportError: + print("Note: Python Optimal Transport (POT) not available for 2D Wasserstein") + + # 3. 
KL divergence of 2D density estimates + # Estimate densities using KDE + kde_real = KernelDensity(bandwidth=0.5, kernel="gaussian") + kde_synth = KernelDensity(bandwidth=0.5, kernel="gaussian") + + kde_real.fit(umap_real) + kde_synth.fit(umap_synth) + + # Create grid for density evaluation + x_min = min(umap_real[:, 0].min(), umap_synth[:, 0].min()) + x_max = max(umap_real[:, 0].max(), umap_synth[:, 0].max()) + y_min = min(umap_real[:, 1].min(), umap_synth[:, 1].min()) + y_max = max(umap_real[:, 1].max(), umap_synth[:, 1].max()) + + x_grid = np.linspace(x_min, x_max, 50) + y_grid = np.linspace(y_min, y_max, 50) + X_grid, Y_grid = np.meshgrid(x_grid, y_grid) + grid_points = np.column_stack([X_grid.ravel(), Y_grid.ravel()]) + + # Evaluate densities + log_dens_real = kde_real.score_samples(grid_points) + log_dens_synth = kde_synth.score_samples(grid_points) + + dens_real = np.exp(log_dens_real) + dens_synth = np.exp(log_dens_synth) + + # Normalize + dens_real = dens_real / dens_real.sum() + dens_synth = dens_synth / dens_synth.sum() + + # Compute KL and JS divergence + metrics["kl_divergence_2d"] = compute_kl_divergence(dens_real, dens_synth) + metrics["js_divergence_2d"] = compute_js_divergence(dens_real, dens_synth) + + # 4. Chamfer distance (average nearest neighbor distance) + from scipy.spatial import distance_matrix + + # Real to Synth + dists_r2s = distance_matrix(umap_real, umap_synth) + chamfer_r2s = dists_r2s.min(axis=1).mean() + + # Synth to Real + dists_s2r = distance_matrix(umap_synth, umap_real) + chamfer_s2r = dists_s2r.min(axis=0).mean() + + metrics["chamfer_real_to_synth"] = chamfer_r2s + metrics["chamfer_synth_to_real"] = chamfer_s2r + metrics["chamfer_symmetric"] = (chamfer_r2s + chamfer_s2r) / 2 + + # 5. Coverage metric (what fraction of real data is "covered" by synth) + # Define "coverage" as having a synthetic point within threshold distance + threshold = np.percentile(dists_r2s.min(axis=1), 95) # Adaptive threshold + coverage = (dists_r2s.min(axis=1) < threshold).mean() + metrics["coverage"] = coverage + + return metrics + + +def plot_comparison_results( + real_dist: Dict[str, float], + synth_dist: Dict[str, float], + hist1: np.ndarray, + hist2: np.ndarray, + bins: np.ndarray, + umap_real: np.ndarray, + umap_synth: np.ndarray, + save_path: str = None, +): + """ + Create visualization of comparison results. + + Args: + real_dist: Question type distribution for real data + synth_dist: Question type distribution for synthetic data + hist1: Cosine similarity histogram for real data + hist2: Cosine similarity histogram for synthetic data + bins: Histogram bins + umap_real: UMAP projections of real data + umap_synth: UMAP projections of synthetic data + save_path: Path to save figure (if None, shows plot) + """ + + # 1. UMAP overlay + fig1, ax = plt.subplots(1, 1, figsize=(7, 6)) + ax.scatter( + umap_real[:, 0], umap_real[:, 1], c="blue", alpha=0.3, s=10, label="Real" + ) + ax.scatter( + umap_synth[:, 0], umap_synth[:, 1], c="red", alpha=0.3, s=10, label="Synthetic" + ) + ax.legend() + ax.grid(alpha=0.3) + ax.set_xticks([]) + ax.set_yticks([]) + + # Remove borders + for spine in ax.spines.values(): + spine.set_visible(False) + + plt.tight_layout() + + if save_path: + save_path = str(save_path) + # Create filename for UMAP plot + base_path = save_path.rsplit(".", 1)[0] + ext = save_path.rsplit(".", 1)[1] if "." 
in save_path else "pdf" + umap_path = f"{base_path}_umap.png" # PDFs have too many elements + plt.savefig(umap_path, dpi=300, bbox_inches="tight") + print(f"UMAP figure saved to {umap_path}") + else: + plt.show() + + plt.close(fig1) + + # 2. Question type distribution + fig2, ax = plt.subplots(1, 1, figsize=(7, 6)) + question_types = list(real_dist.keys()) + x = np.arange(len(question_types)) + width = 0.35 + + ax.bar( + x - width / 2, + [real_dist[t] for t in question_types], + width, + label="Real", + alpha=0.8, + color="blue", + ) + ax.bar( + x + width / 2, + [synth_dist[t] for t in question_types], + width, + label="Synthetic", + alpha=0.8, + color="red", + ) + + ax.set_xlabel("Question Type") + ax.set_ylabel("Ratio") + ax.set_xticks(x) + ax.set_xticklabels(question_types) + ax.legend() + ax.grid(axis="y", alpha=0.3) + + # Remove borders + for spine in ax.spines.values(): + spine.set_visible(False) + + plt.tight_layout() + + if save_path: + # Create filename for question types plot + base_path = save_path.rsplit(".", 1)[0] + ext = save_path.rsplit(".", 1)[1] if "." in save_path else "pdf" + qt_path = f"{base_path}_question_types.{ext}" + plt.savefig(qt_path, dpi=300, bbox_inches="tight") + print(f"Question types figure saved to {qt_path}") + else: + plt.show() + + plt.close(fig2) + + +def compare_qa_datasets( + synth_dataset_name: str, + embedding_type: Literal["Q", "QA"], +) -> Dict[str, any]: + """ + Main function to compare QA datasets. + + Args: + synth_dataset_name: Name of synthetic dataset + embedding_type: Type of embeddings to use + load_qa_embeddings: Function to load embeddings + get_base_dataset_name: Function to get base dataset name + umap_real: Optional pre-computed UMAP projections for real data + umap_synth: Optional pre-computed UMAP projections for synthetic data + + Returns: + Dictionary with all comparison metrics + """ + print("=" * 80) + print("QA Dataset Comparison") + print("=" * 80) + + # Load data + print(f"\nLoading datasets...") + base_dataset_name = get_base_dataset_name(synth_dataset_name) + print(f"Base dataset: {base_dataset_name}") + print(f"Synthetic dataset: {synth_dataset_name}") + + real_emb, real_q, real_a, real_sample_ids, real_doc_ids = load_qa_embeddings( + base_dataset_name, embedding_type + ) + synth_emb, synth_q, synth_a, synth_sample_ids, synth_doc_ids = load_qa_embeddings( + synth_dataset_name, embedding_type + ) + + print(f"Real dataset: {len(real_q)} QA pairs") + print(f"Synthetic dataset: {len(synth_q)} QA pairs") + + results = {} + + # 1. Question type distribution + print("\n" + "-" * 80) + print("1. 
Question Type Distribution") + print("-" * 80) + + real_dist = compute_question_type_distribution(real_q) + synth_dist = compute_question_type_distribution(synth_q) + + print("\nReal data:") + for qtype, ratio in real_dist.items(): + print(f" {qtype:8s}: {ratio:6.2%}") + + print("\nSynthetic data:") + for qtype, ratio in synth_dist.items(): + print(f" {qtype:8s}: {ratio:6.2%}") + + # Compute distribution divergence + types_ordered = ["who", "what", "when", "where", "why", "how", "other"] + real_dist_arr = np.array([real_dist[t] for t in types_ordered]) + synth_dist_arr = np.array([synth_dist[t] for t in types_ordered]) + + qtype_kl = compute_kl_divergence(real_dist_arr, synth_dist_arr) + qtype_js = compute_js_divergence(real_dist_arr, synth_dist_arr) + + print(f"\nKL divergence (Real||Synth): {qtype_kl:.4f}") + print(f"JS divergence: {qtype_js:.4f}") + + results["question_type_real"] = real_dist + results["question_type_synth"] = synth_dist + results["question_type_kl"] = qtype_kl + results["question_type_js"] = qtype_js + + # 2. Cosine similarity histograms + print("\n" + "-" * 80) + print("2. Embedding Similarity Distribution") + print("-" * 80) + + hist_real, hist_synth, bin_centers, bins = compute_cosine_similarity_histogram( + real_emb, synth_emb, n_bins=50 + ) + + sim_kl = compute_kl_divergence(hist_real, hist_synth) + sim_js = compute_js_divergence(hist_real, hist_synth) + + print(f"KL divergence (Real||Synth): {sim_kl:.4f}") + print(f"JS divergence: {sim_js:.4f}") + + results["similarity_hist_kl"] = sim_kl + results["similarity_hist_js"] = sim_js + results["similarity_hist_real"] = hist_real + results["similarity_hist_synth"] = hist_synth + results["similarity_bins"] = bin_centers + + # 3. MMD + print("\n" + "-" * 80) + print("3. Maximum Mean Discrepancy (MMD)") + print("-" * 80) + + mmd_value = compute_mmd_rbf(real_emb, synth_emb) + print(f"MMD² (RBF kernel): {mmd_value:.6f}") + + results["mmd_rbf"] = mmd_value + + # 4. UMAP overlay metrics (if provided) + projection_method = "umap" + projection_cache_path = ( + ENV.QA_GT_WEBAPP_CACHE_DIR + / f"projection_{projection_method}_{synth_dataset_name}_{embedding_type}.npy" + ) + projection_2d = np.load(projection_cache_path) + umap_real = projection_2d[: len(real_emb)] + umap_synth = projection_2d[len(real_emb) :] + + if umap_real is not None and umap_synth is not None: + print("\n" + "-" * 80) + print("4. 
UMAP Projection Overlay Metrics") + print("-" * 80) + + overlay_metrics = compute_umap_overlay_metrics(umap_real, umap_synth) + + for metric_name, value in overlay_metrics.items(): + print(f"{metric_name:30s}: {value:.6f}") + + results["umap_overlay"] = overlay_metrics + + # Create visualization + plot_comparison_results( + real_dist, + synth_dist, + hist_real, + hist_synth, + bins, + umap_real, + umap_synth, + save_path=ENV.QA_GT_ANALYZATION_DIR + / f"{synth_dataset_name}_qa_comparison.pdf", + ) + + print("\n" + "=" * 80) + print("Comparison complete!") + print("=" * 80) + + return results + + +# Example usage +if __name__ == "__main__": + synth_dataset_name = "wtq_alpha=1.0" + + # Run comparison + results = compare_qa_datasets( + synth_dataset_name=synth_dataset_name, embedding_type="Q" + ) + + print("\n\nResults summary available in 'results' dictionary:") + print(results) + print("Visualization saved to 'qa_comparison.png'") diff --git a/docgenie/analyzation/layoutfid/fid_calculator.py b/docgenie/analyzation/layoutfid/fid_calculator.py new file mode 100755 index 0000000000000000000000000000000000000000..912a54bac11906480355d79eff68a27d982c6df4 --- /dev/null +++ b/docgenie/analyzation/layoutfid/fid_calculator.py @@ -0,0 +1,241 @@ +import warnings +from pathlib import Path + +import numpy as np +import pandas as pd +import pydantic.v1 as pydantic +import pydantic_argparse +import torch +from PIL import Image +from scipy import linalg +from torch.utils.data import DataLoader +import tqdm +from transformers import AutoModel, AutoProcessor + +from docgenie.data._core._data_types import DocumentInstance +from docgenie.data._core._msgpack_dataset_reader import MsgpackDatasetReader +from docgenie.data.interface import ( + load_dataset, + load_synthetic_dataset, +) +from docgenie.logging import get_logger + +import torchvision.transforms.functional as TF +from torch.nn.functional import adaptive_avg_pool2d +from pytorch_fid.inception import InceptionV3 + +logger = get_logger(__name__) + +warnings.filterwarnings("ignore") + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert ( + mu1.shape == mu2.shape + ), "Training and test mean vectors have different lengths" + assert ( + sigma1.shape == sigma2.shape + ), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + +def get_activations( + dataset, model, batch_size=50, dims=2048, device="cpu", num_workers=1 +): + model.eval() + + dataset.set_transform(lambda sample: TF.to_tensor(sample.image.content.convert("RGB").resize((1024, 1024)))) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + 
drop_last=False, + num_workers=num_workers, + ) + + pred_arr = np.empty((len(dataset), dims)) + + start_idx = 0 + + for batch in tqdm.tqdm(dataloader): + batch = batch.to(device) + print('batch',batch.shape) + + with torch.no_grad(): + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + + pred_arr[start_idx : start_idx + pred.shape[0]] = pred + + start_idx = start_idx + pred.shape[0] + + return pred_arr + + +def calculate_activation_statistics( + dataset, model, batch_size=50, dims=2048, device="cpu", num_workers=1 +): + act = get_activations(dataset, model, batch_size, dims, device, num_workers) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + +def calculate_fid_given_datasets(real_dataset, syn_dataset, batch_size, device, dims, num_workers=1): + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + model = InceptionV3([block_idx]).to(device) + m1, s1 = calculate_activation_statistics( + real_dataset, model, batch_size, dims, device, num_workers + ) + m2, s2 = calculate_activation_statistics( + syn_dataset, model, batch_size, dims, device, num_workers + ) + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value + + +class FIDCalculatorConfig(pydantic.BaseModel): + """ + Configuration for clustering operations. + """ + + seed: int = 42 + real_dataset_name: str + synth_dataset_name: str + batch_size: int = 50 + limit_sizes_to_smallest: bool = True + + +def main( + cfg: FIDCalculatorConfig, +): + """Example usage of FID calculator.""" + + # load the results csv + output_df_path = Path("data/results/fid.csv") + + # load the results csv and check if row with same real and synth dataset exists + if output_df_path.exists(): + output_df = pd.read_csv(output_df_path) + existing_row = output_df[ + (output_df["real_dataset"] == cfg.real_dataset_name) + & (output_df["synth_dataset"] == cfg.synth_dataset_name) + ] + if not existing_row.empty: + logger.info( + f"FID already calculated for real dataset '{cfg.real_dataset_name}' and synthetic dataset '{cfg.synth_dataset_name}'. Skipping calculation." 
+ ) + logger.info( + f"Existing FID Score: {existing_row['fid'].values[0]:.4f}" + ) + return + else: + output_df = pd.DataFrame( + columns=["real_dataset", "synth_dataset", "fid", "num_samples"] + ) + + # torch manual seed for reproducibility + torch.manual_seed(42) + + # logging config + logger.info("Calculating FID with config:") + logger.info(cfg.json(indent=4)) + + # load real dataset pipeline + real_dataset = load_dataset( + dataset_name=cfg.real_dataset_name, + create_train_val_splits=False, + ).train + + synth_dataset = load_synthetic_dataset( + dataset_name=cfg.synth_dataset_name, + ).train + + # assert datasets are not None + assert real_dataset is not None, "Real dataset train split is None" + assert synth_dataset is not None, "Synthetic dataset train split is None" + + # log dataset sizes + logger.info(f"Real dataset size: {len(real_dataset)}") + logger.info(f"Synthetic dataset size: {len(synth_dataset)}") + + # limit both datasets to smallest size + if cfg.limit_sizes_to_smallest: + real_size = len(real_dataset) # type: ignore + synth_size = len(synth_dataset) # type: ignore + + if real_size > synth_size: + logger.info( + f"Real dataset is bigger ({real_size} samples) than synthetic dataset ({synth_size} samples)." + ) + random_indices = torch.randperm(real_size)[:synth_size] + real_dataset.set_subset_indices(random_indices.tolist()) + else: + logger.info( + f"Synthetic dataset is bigger ({synth_size} samples) than real dataset ({real_size} samples)." + ) + random_indices = torch.randperm(synth_size)[:real_size] + synth_dataset.set_subset_indices(random_indices.tolist()) + + total_real_dataset_samples = len(real_dataset) # type: ignore + total_synth_dataset_samples = len(synth_dataset) # type: ignore + assert total_real_dataset_samples == total_synth_dataset_samples, ( + "FID calculation requires both datasets to have the same number of samples. " + f"Got {total_real_dataset_samples} real and {total_synth_dataset_samples} synthetic samples." 
+ ) + + num_samples = total_real_dataset_samples + fid = calculate_fid_given_datasets(real_dataset, synth_dataset, cfg.batch_size, device="cuda", dims=2048) + logger.info(f"\FID Score: {fid:.4f} over {num_samples} samples") + + # append result to csv + new_row = { + "real_dataset": cfg.real_dataset_name, + "synth_dataset": cfg.synth_dataset_name, + "fid": fid, + "num_samples": len(real_dataset), + } + output_df = pd.concat([output_df, pd.DataFrame([new_row])], ignore_index=True) + output_df.to_csv("data/results/fid.csv", index=False) + logger.info("FID score saved to data/results/fid.csv") + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=FIDCalculatorConfig, + ) + main(parser.parse_typed_args()) diff --git a/docgenie/analyzation/layoutfid/layoutfid_from_embeddings.py b/docgenie/analyzation/layoutfid/layoutfid_from_embeddings.py new file mode 100755 index 0000000000000000000000000000000000000000..2d128a0a15ad9feff04616ce584e6c42cb324b9f --- /dev/null +++ b/docgenie/analyzation/layoutfid/layoutfid_from_embeddings.py @@ -0,0 +1,402 @@ +import warnings +from pathlib import Path + +import numpy as np +import pandas as pd +from docgenie import ENV +from docgenie.analyzation.clustering.core._embeddings import _load_embeddings +import pydantic.v1 as pydantic +import pydantic_argparse +import torch +from PIL import Image +from scipy import linalg +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModel, AutoProcessor + +from docgenie.data._core._data_types import DocumentInstance +from docgenie.data._core._msgpack_dataset_reader import MsgpackDatasetReader +from docgenie.data.interface import ( + load_dataset, + load_synthetic_dataset, +) +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +warnings.filterwarnings("ignore") + + +class LayoutFIDCalculator: + """ + GPU-accelerated LayoutFID score calculator using LayoutLMv3 embeddings. + """ + + def __init__( + self, device: str = "cuda", model_name: str = "microsoft/layoutlmv3-base" + ): + """ + Initialize LayoutFID calculator. + + Args: + device: 'cuda' or 'cpu' + model_name: HuggingFace model identifier for LayoutLMv3 + """ + self.device = device if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {self.device}") + + # Load LayoutLMv3 model and processor + self.processor = AutoProcessor.from_pretrained(model_name, apply_ocr=False) + self.model = AutoModel.from_pretrained(model_name) + self.model.to(self.device) + self.model.eval() + + def _get_embeddings( + self, + dataset: MsgpackDatasetReader, + batch_size: int, + use_image_only: bool = False, + ) -> np.ndarray: + """ + Extract LayoutLMv3 embeddings for images. 
+ + Args: + image_paths: List of paths to document images + batch_size: Batch size for processing + + Returns: + Embeddings array of shape (n_images, embedding_dim) + """ + + embeddings_list = [] + + with torch.no_grad(): + dataloader = DataLoader( + dataset, # type: ignore + batch_size=batch_size, + shuffle=False, + num_workers=4, + pin_memory=True, + collate_fn=lambda x: x, + ) + for batch in tqdm( + dataloader, + desc=f"Extracting embeddings batch_size=[{batch_size}]", + total=len(dataloader), + ): + batch: list[DocumentInstance] + + # get images, words, boxes from batch + words, word_bboxes, images = [], [], [] + for sample in batch: + assert sample.image is not None, "Sample image is None" + assert isinstance(sample.image.content, Image.Image), ( + "Sample image content is not PIL Image" + ) + images.append(sample.image.content.convert("RGB")) + if use_image_only: + continue + assert sample.content is not None, "Sample content is None" + assert sample.content.word_bboxes is not None, ( + "Sample word bboxes are None" + ) + + words.append(sample.content.words) + word_bboxes.append(sample.content.word_bboxes.value) + + # Process images with LayoutLMv3 processor + inputs = self.processor( + text=words, + boxes=word_bboxes, + images=images, + return_tensors="pt", + padding=True, + truncation=True, + ) + + # layoutlmv3 expects bboxes in range [0, 1000] + # we assume to get normalized bboxes in [0, 1] + # scale bboxes + # if ( + # inputs["bbox"].max() > 1.01 or inputs["bbox"].min() < -0.01 + # ): # 1.1 to account for any floating point precision issues + # raise ValueError( + # f"Expected normalized bounding boxes in range [0, 1], Got max value {inputs['bbox'].max()}" + # ) + + inputs["bbox"] = (inputs["bbox"].clip(0.0, 1.0) * 1000).long() + + # Move to device + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + # Get model output + outputs = self.model(**inputs, output_hidden_states=True) + + # Use last hidden state (CLS token or mean pooling) + # Extract the [CLS] token representation (first token) + batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + embeddings_list.append(batch_embeddings) + + embeddings = np.concatenate(embeddings_list, axis=0) + return embeddings + + def _compute_statistics(self, embeddings: np.ndarray) -> tuple: + """ + Compute mean and covariance of embeddings. + + Args: + embeddings: Array of shape (n_samples, embedding_dim) + + Returns: + Tuple of (mean, covariance) + """ + mu = np.mean(embeddings, axis=0) + sigma = np.cov(embeddings.T) + + # Ensure sigma is 2D (handle 1D case) + if sigma.ndim == 1: + sigma = np.diag(sigma) + + return mu, sigma + + # def _compute_fid( # this works same as calculate_frechet_distance but i kept the original as its taken from well-known FID implementation + # https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py + # self, mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray + # ) -> float: + # """ + # Compute Fréchet Inception Distance. 
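For reference, a minimal single-document sketch of the embedding step implemented in `_get_embeddings` above. The image, words and boxes are hypothetical; boxes are assumed normalised to [0, 1] and then scaled to LayoutLMv3's expected [0, 1000] range, and exact processor behaviour may vary with the installed `transformers` version.

```python
# Minimal sketch of the per-document LayoutLMv3 [CLS] embedding used above (toy inputs).
import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
model = AutoModel.from_pretrained("microsoft/layoutlmv3-base").eval()

image = Image.new("RGB", (224, 224), "white")                     # stand-in document image
words = ["Invoice", "Total", "42.00"]                             # hypothetical OCR words
boxes = [[0.10, 0.10, 0.30, 0.15], [0.10, 0.80, 0.25, 0.85], [0.30, 0.80, 0.45, 0.85]]

inputs = processor(text=[words], boxes=[boxes], images=[image],
                   return_tensors="pt", padding=True, truncation=True)
inputs["bbox"] = (inputs["bbox"].clip(0.0, 1.0) * 1000).long()    # normalised [0, 1] -> [0, 1000]

with torch.no_grad():
    cls_embedding = model(**inputs).last_hidden_state[:, 0, :]    # (1, hidden_size) [CLS] vector
```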
+ + # Args: + # mu1, sigma1: Mean and covariance of real embeddings + # mu2, sigma2: Mean and covariance of generated embeddings + + # Returns: + # FID score + # """ + # # Euclidean distance between means + # diff = mu1 - mu2 + # diff_norm = np.sum(diff**2) + + # # Trace of covariance matrices + # trace_cov = np.trace(sigma1 + sigma2) + + # # Matrix square root of product of covariances + # # Using eigenvalue decomposition for numerical stability + # sqrt_cov_prod = self._sqrtm(sigma1 @ sigma2) + # trace_sqrt_prod = np.trace(sqrt_cov_prod) + + # # FID = ||µr - µg||^2 + Tr(Σr + Σg - 2√(ΣrΣg)) + # fid = diff_norm + trace_cov - 2 * trace_sqrt_prod + + # return float(np.real(fid)) + + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, ( + "Training and test mean vectors have different lengths" + ) + assert sigma1.shape == sigma2.shape, ( + "Training and test covariances have different dimensions" + ) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps + logger.info(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + @staticmethod + def _sqrtm(matrix: np.ndarray) -> np.ndarray: + """ + Compute matrix square root using eigenvalue decomposition. + More numerically stable than scipy.linalg.sqrtm for this use case. 
+ """ + try: + # Use scipy's sqrtm for general case + sqrt_m = linalg.sqrtm(matrix) + # Return real part if imaginary component is negligible + if np.iscomplexobj(sqrt_m): + sqrt_m = np.real(sqrt_m) + return sqrt_m + except np.linalg.LinAlgError: + # Fallback: eigenvalue decomposition + eigvals, eigvecs = np.linalg.eigh(matrix) + eigvals = np.maximum(eigvals, 0) # Ensure non-negative + sqrt_m = eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T + return np.real(sqrt_m) + + def calculate_layoutfid( + self, + real_embeddings: "np.ndarray", + synth_embeddings: "np.ndarray", + limit_sizes_to_smallest: bool = True, + ) -> tuple[float, int]: + # limit both datasets to smallest size + if limit_sizes_to_smallest: + real_size = len(real_embeddings) # type: ignore + synth_size = len(synth_embeddings) # type: ignore + + # layout fix see which dataset is smaller in size + if real_size > synth_size: + logger.info( + f"Real embeddings is bigger ({real_size} samples) than synthetic dataset ({synth_size} samples)." + ) + random_indices = torch.randperm(real_size)[:synth_size] + real_embeddings = real_embeddings[random_indices.tolist()] + else: + logger.info( + f"Synthetic dataset is bigger ({synth_size} samples) than real dataset ({real_size} samples)." + ) + random_indices = torch.randperm(synth_size)[:real_size] + synth_embeddings = synth_embeddings[random_indices.tolist()] + + total_real_dataset_samples = len(real_embeddings) # type: ignore + total_synth_dataset_samples = len(synth_embeddings) # type: ignore + assert total_real_dataset_samples == total_synth_dataset_samples, ( + "FID calculation requires both datasets to have the same number of samples. " + f"Got {total_real_dataset_samples} real and {total_synth_dataset_samples} synthetic samples." + ) + logger.info("Calculating real statistics...") + mu_real, sigma_real = self._compute_statistics(real_embeddings) + logger.info("Calculating synthetic statistics...") + mu_gen, sigma_gen = self._compute_statistics(synth_embeddings) + layoutfid = self.calculate_frechet_distance( + mu_real, sigma_real, mu_gen, sigma_gen + ) + return layoutfid, real_embeddings.shape[0] + + +class LayoutFIDCalculatorConfig(pydantic.BaseModel): + """ + Configuration for clustering operations. + """ + + seed: int = 42 + real_dataset_name: str + synth_dataset_name: str + limit_sizes_to_smallest: bool = True + embedding_src: str = "layout" + + +def main( + cfg: LayoutFIDCalculatorConfig, +): + """Example usage of LayoutFID calculator.""" + + # load the results csv + output_df_path = Path("data/results/layout_fid_embeddings.csv") + + # load the results csv and check if row with same real and synth dataset exists + if output_df_path.exists(): + output_df = pd.read_csv(output_df_path) + existing_row = output_df[ + (output_df["real_dataset"] == cfg.real_dataset_name) + & (output_df["synth_dataset"] == cfg.synth_dataset_name) + & (output_df["embedding_src"] == cfg.embedding_src) + ] + if not existing_row.empty: + logger.info( + f"LayoutFID already calculated for real dataset '{cfg.real_dataset_name}' and synthetic dataset '{cfg.synth_dataset_name}'. Skipping calculation." 
+ ) + logger.info( + f"Existing LayoutFID Score: {existing_row['layoutfid_score'].values[0]:.4f}" + ) + return + else: + output_df = pd.DataFrame( + columns=["real_dataset", "synth_dataset", "layoutfid_score", "num_samples"] + ) + + # torch manual seed for reproducibility + torch.manual_seed(42) + + # logging config + logger.info("Calculating LayoutFID with config:") + logger.info(cfg.json(indent=4)) + + # load the real embeddings + real_embeddings, _ = _load_embeddings( + file_path=ENV.EMBEDDINGS_DIR / cfg.real_dataset_name / f"{cfg.embedding_src}.h5" + ) + + # load the synthetic embeddings + synthetic_embeddings, _ = _load_embeddings( + file_path=ENV.EMBEDDINGS_DIR + / 'synth' + / cfg.synth_dataset_name + / f"{cfg.embedding_src}.h5", + ) + + # Initialize calculator + calculator = LayoutFIDCalculator(device="cuda") + + # Calculate LayoutFID + layoutfid_score, num_samples = calculator.calculate_layoutfid( + real_embeddings, + synthetic_embeddings, + limit_sizes_to_smallest=cfg.limit_sizes_to_smallest, + ) + logger.info(f"\nLayoutFID Score: {layoutfid_score:.4f} over {num_samples} samples") + + # append result to csv + new_row = { + "real_dataset": cfg.real_dataset_name, + "synth_dataset": cfg.synth_dataset_name, + "layoutfid_score": layoutfid_score, + "num_samples": len(real_embeddings), + "embedding_src": cfg.embedding_src, + } + output_df = pd.concat([output_df, pd.DataFrame([new_row])], ignore_index=True) + output_df_path.parent.mkdir(parents=True, exist_ok=True) + output_df.to_csv(output_df_path, index=False) + logger.info("LayoutFID score saved to data/results/layout_fid.csv") + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=LayoutFIDCalculatorConfig, + ) + main(parser.parse_typed_args()) diff --git a/docgenie/analyzation/layoutfid/layoutfidcalculator.py b/docgenie/analyzation/layoutfid/layoutfidcalculator.py new file mode 100755 index 0000000000000000000000000000000000000000..aab33d1bd61a2adab76c848ae78dbebcfc69d280 --- /dev/null +++ b/docgenie/analyzation/layoutfid/layoutfidcalculator.py @@ -0,0 +1,440 @@ +import warnings +from pathlib import Path + +import numpy as np +import pandas as pd +import pydantic.v1 as pydantic +import pydantic_argparse +import torch +from PIL import Image +from scipy import linalg +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModel, AutoProcessor + +from docgenie.data._core._data_types import DocumentInstance +from docgenie.data._core._msgpack_dataset_reader import MsgpackDatasetReader +from docgenie.data.interface import ( + load_dataset, + load_synthetic_dataset, +) +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +warnings.filterwarnings("ignore") + + +class LayoutFIDCalculator: + """ + GPU-accelerated LayoutFID score calculator using LayoutLMv3 embeddings. + """ + + def __init__( + self, device: str = "cuda", model_name: str = "microsoft/layoutlmv3-base" + ): + """ + Initialize LayoutFID calculator. 
+ + Args: + device: 'cuda' or 'cpu' + model_name: HuggingFace model identifier for LayoutLMv3 + """ + self.device = device if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {self.device}") + + # Load LayoutLMv3 model and processor + self.processor = AutoProcessor.from_pretrained(model_name, apply_ocr=False) + self.model = AutoModel.from_pretrained(model_name) + self.model.to(self.device) + self.model.eval() + + def _get_embeddings( + self, + dataset: MsgpackDatasetReader, + batch_size: int, + use_image_only: bool = False, + ) -> np.ndarray: + """ + Extract LayoutLMv3 embeddings for images. + + Args: + image_paths: List of paths to document images + batch_size: Batch size for processing + + Returns: + Embeddings array of shape (n_images, embedding_dim) + """ + + embeddings_list = [] + + with torch.no_grad(): + dataloader = DataLoader( + dataset, # type: ignore + batch_size=batch_size, + shuffle=False, + num_workers=4, + pin_memory=True, + collate_fn=lambda x: x, + ) + for batch in tqdm( + dataloader, + desc=f"Extracting embeddings batch_size=[{batch_size}]", + total=len(dataloader), + ): + batch: list[DocumentInstance] + + # get images, words, boxes from batch + words, word_bboxes, images = [], [], [] + for sample in batch: + assert sample.image is not None, "Sample image is None" + assert isinstance(sample.image.content, Image.Image), ( + "Sample image content is not PIL Image" + ) + images.append(sample.image.content.convert("RGB")) + if use_image_only: + words.append(["None"]) + word_bboxes.append([[0, 0, 0, 0]]) + + continue + assert sample.content is not None, "Sample content is None" + assert sample.content.word_bboxes is not None, ( + "Sample word bboxes are None" + ) + + words.append(sample.content.words) + word_bboxes.append(sample.content.word_bboxes.value) + + # Process images with LayoutLMv3 processor + inputs = self.processor( + text=words, + boxes=word_bboxes, + images=images, + return_tensors="pt", + padding=True, + truncation=True, + ) + + # layoutlmv3 expects bboxes in range [0, 1000] + # we assume to get normalized bboxes in [0, 1] + # scale bboxes + # if ( + # inputs["bbox"].max() > 1.01 or inputs["bbox"].min() < -0.01 + # ): # 1.1 to account for any floating point precision issues + # raise ValueError( + # f"Expected normalized bounding boxes in range [0, 1], Got max value {inputs['bbox'].max()}" + # ) + + inputs["bbox"] = (inputs["bbox"].clip(0.0, 1.0) * 1000).long() + + # Move to device + inputs = {k: v.to(self.device) for k, v in inputs.items()} + + for key, value in inputs.items(): + assert isinstance(value, torch.Tensor), ( + f"Expected tensor for input '{key}', got {type(value)}" + ) + if value is not None: + print(f"Input '{key}' shape: {value.shape}") + else: + print(f"Input '{key}' is None") + + # Get model output + outputs = self.model(**inputs, output_hidden_states=True) + + # Use last hidden state (CLS token or mean pooling) + # Extract the [CLS] token representation (first token) + batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy() + embeddings_list.append(batch_embeddings) + + embeddings = np.concatenate(embeddings_list, axis=0) + return embeddings + + def _compute_statistics(self, embeddings: np.ndarray) -> tuple: + """ + Compute mean and covariance of embeddings. 
+ + Args: + embeddings: Array of shape (n_samples, embedding_dim) + + Returns: + Tuple of (mean, covariance) + """ + mu = np.mean(embeddings, axis=0) + sigma = np.cov(embeddings.T) + + # Ensure sigma is 2D (handle 1D case) + if sigma.ndim == 1: + sigma = np.diag(sigma) + + return mu, sigma + + # def _compute_fid( # this works same as calculate_frechet_distance but i kept the original as its taken from well-known FID implementation + # https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py + # self, mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray + # ) -> float: + # """ + # Compute Fréchet Inception Distance. + + # Args: + # mu1, sigma1: Mean and covariance of real embeddings + # mu2, sigma2: Mean and covariance of generated embeddings + + # Returns: + # FID score + # """ + # # Euclidean distance between means + # diff = mu1 - mu2 + # diff_norm = np.sum(diff**2) + + # # Trace of covariance matrices + # trace_cov = np.trace(sigma1 + sigma2) + + # # Matrix square root of product of covariances + # # Using eigenvalue decomposition for numerical stability + # sqrt_cov_prod = self._sqrtm(sigma1 @ sigma2) + # trace_sqrt_prod = np.trace(sqrt_cov_prod) + + # # FID = ||µr - µg||^2 + Tr(Σr + Σg - 2√(ΣrΣg)) + # fid = diff_norm + trace_cov - 2 * trace_sqrt_prod + + # return float(np.real(fid)) + + def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, ( + "Training and test mean vectors have different lengths" + ) + assert sigma1.shape == sigma2.shape, ( + "Training and test covariances have different dimensions" + ) + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates" + ) % eps + logger.info(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + @staticmethod + def _sqrtm(matrix: np.ndarray) -> np.ndarray: + """ + Compute matrix square root using eigenvalue decomposition. + More numerically stable than scipy.linalg.sqrtm for this use case. 
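A small numerical check (toy data) of the eigen-decomposition fallback in `_sqrtm`: for a symmetric positive semi-definite matrix, the result squared should recover the input.

```python
# Verify the eigh-based matrix square root on a toy PSD covariance matrix.
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(8, 8))
cov = a @ a.T                                     # symmetric positive semi-definite

eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.maximum(eigvals, 0)                  # clamp tiny negative eigenvalues
sqrt_m = eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T

assert np.allclose(sqrt_m @ sqrt_m, cov, atol=1e-6)
```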
+ """ + try: + # Use scipy's sqrtm for general case + sqrt_m = linalg.sqrtm(matrix) + # Return real part if imaginary component is negligible + if np.iscomplexobj(sqrt_m): + sqrt_m = np.real(sqrt_m) + return sqrt_m + except np.linalg.LinAlgError: + # Fallback: eigenvalue decomposition + eigvals, eigvecs = np.linalg.eigh(matrix) + eigvals = np.maximum(eigvals, 0) # Ensure non-negative + sqrt_m = eigvecs @ np.diag(np.sqrt(eigvals)) @ eigvecs.T + return np.real(sqrt_m) + + def calculate_layoutfid( + self, + real_dataset: "MsgpackDatasetReader", + synth_dataset: "MsgpackDatasetReader", + batch_size: int = 32, + limit_sizes_to_smallest: bool = True, + use_image_only: bool = False, + ) -> tuple[float, int]: + """ + Calculate LayoutFID between real and generated document images. + + Args: + real_image_paths: List of paths to real document images (seed documents) + generated_image_paths: List of paths to generated document images + batch_size: Batch size for embedding extraction + + Returns: + LayoutFID score (lower is better) + """ + # limit both datasets to smallest size + if limit_sizes_to_smallest: + real_size = len(real_dataset) # type: ignore + synth_size = len(synth_dataset) # type: ignore + + # layout fix see which dataset is smaller in size + if real_size > synth_size: + logger.info( + f"Real dataset is bigger ({real_size} samples) than synthetic dataset ({synth_size} samples)." + ) + random_indices = torch.randperm(real_size)[:synth_size] + real_dataset.set_subset_indices(random_indices.tolist()) + else: + logger.info( + f"Synthetic dataset is bigger ({synth_size} samples) than real dataset ({real_size} samples)." + ) + random_indices = torch.randperm(synth_size)[:real_size] + synth_dataset.set_subset_indices(random_indices.tolist()) + + total_real_dataset_samples = len(real_dataset) # type: ignore + total_synth_dataset_samples = len(synth_dataset) # type: ignore + assert total_real_dataset_samples == total_synth_dataset_samples, ( + "FID calculation requires both datasets to have the same number of samples. " + f"Got {total_real_dataset_samples} real and {total_synth_dataset_samples} synthetic samples." + ) + + logger.info( + f"Extracting embeddings for {total_real_dataset_samples} real images..." + ) + real_embeddings = self._get_embeddings( + real_dataset, batch_size, use_image_only=use_image_only + ) + + logger.info( + f"Extracting embeddings for {total_synth_dataset_samples} generated images..." + ) + gen_embeddings = self._get_embeddings(synth_dataset, batch_size) + mu_real, sigma_real = self._compute_statistics(real_embeddings) + mu_gen, sigma_gen = self._compute_statistics(gen_embeddings) + layoutfid = self.calculate_frechet_distance( + mu_real, sigma_real, mu_gen, sigma_gen + ) + return layoutfid, real_embeddings.shape[0] + + +class LayoutFIDCalculatorConfig(pydantic.BaseModel): + """ + Configuration for clustering operations. 
+ """ + + seed: int = 42 + real_dataset_name: str + synth_dataset_name: str + batch_size: int = 32 + limit_sizes_to_smallest: bool = True + use_image_only: bool = False + + +def main( + cfg: LayoutFIDCalculatorConfig, +): + """Example usage of LayoutFID calculator.""" + + # load the results csv + output_df_path = Path("data/results/layout_fid.csv") + + # load the results csv and check if row with same real and synth dataset exists + if output_df_path.exists(): + output_df = pd.read_csv(output_df_path) + existing_row = output_df[ + (output_df["real_dataset"] == cfg.real_dataset_name) + & (output_df["synth_dataset"] == cfg.synth_dataset_name) + ] + if not existing_row.empty: + logger.info( + f"LayoutFID already calculated for real dataset '{cfg.real_dataset_name}' and synthetic dataset '{cfg.synth_dataset_name}'. Skipping calculation." + ) + logger.info( + f"Existing LayoutFID Score: {existing_row['layoutfid_score'].values[0]:.4f}" + ) + return + else: + output_df = pd.DataFrame( + columns=["real_dataset", "synth_dataset", "layoutfid_score", "num_samples"] + ) + + # torch manual seed for reproducibility + torch.manual_seed(42) + + # logging config + logger.info("Calculating LayoutFID with config:") + logger.info(cfg.json(indent=4)) + + # load real dataset pipeline + real_dataset = load_dataset( + dataset_name=cfg.real_dataset_name, + create_train_val_splits=False, + ).train + + synth_dataset = load_synthetic_dataset( + dataset_name=cfg.synth_dataset_name, + ).train + + # assert datasets are not None + assert real_dataset is not None, "Real dataset train split is None" + assert synth_dataset is not None, "Synthetic dataset train split is None" + + # log dataset sizes + logger.info(f"Real dataset size: {len(real_dataset)}") + logger.info(f"Synthetic dataset size: {len(synth_dataset)}") + + # Initialize calculator + calculator = LayoutFIDCalculator(device="cuda") + + # Calculate LayoutFID + layoutfid_score, num_samples = calculator.calculate_layoutfid( + real_dataset, + synth_dataset, + batch_size=cfg.batch_size, + limit_sizes_to_smallest=cfg.limit_sizes_to_smallest, + use_image_only=cfg.use_image_only, + ) + logger.info(f"\nLayoutFID Score: {layoutfid_score:.4f} over {num_samples} samples") + + # append result to csv + new_row = { + "real_dataset": cfg.real_dataset_name, + "synth_dataset": cfg.synth_dataset_name, + "layoutfid_score": layoutfid_score, + "num_samples": len(real_dataset), + } + output_df = pd.concat([output_df, pd.DataFrame([new_row])], ignore_index=True) + output_df.to_csv("data/results/layout_fid.csv", index=False) + logger.info("LayoutFID score saved to data/results/layout_fid.csv") + + +if __name__ == "__main__": + parser = pydantic_argparse.ArgumentParser( + model=LayoutFIDCalculatorConfig, + ) + main(parser.parse_typed_args()) diff --git a/docgenie/analyzation/synth/analyze_policy_violations.py b/docgenie/analyzation/synth/analyze_policy_violations.py new file mode 100755 index 0000000000000000000000000000000000000000..30c9befef45345dfbaacca24570d33bff2b0227f --- /dev/null +++ b/docgenie/analyzation/synth/analyze_policy_violations.py @@ -0,0 +1,112 @@ +import argparse +from collections import Counter +import json +import matplotlib.pyplot as plt +from tqdm import tqdm + +from docgenie import ENV +from docgenie.generation.models._file import SyntheticDatasetFileStructure +from docgenie.generation.models._syndatadef import SynDatasetDefinition + + +def parse_args(): + parser = argparse.ArgumentParser( + description="DocGenie Synthetic Document Generator", + 
formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "SynDatasetDefinition", + type=str, + help="Filename without extension of the SynDatasetDefinition in data/syn_dataset_definitions", + ) + + args = parser.parse_args() + assert args.SynDatasetDefinition + print(args) + return args + + +def search_for_refusals(dsname): + deffile = ENV.SYN_DATA_DEFINITIONS_DIR / f"{dsname}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file(deffile) + dsfiles: SyntheticDatasetFileStructure = dsdef.get_file_structure() + + msg_ids_from_all_batches = set() + msg_id_to_batch_id = dict() + for prompt_batch_log_path in dsfiles.prompt_batches_directory.iterdir(): + prompt_batch_log = json.loads(prompt_batch_log_path.read_text(encoding="utf-8")) + for msg_id in prompt_batch_log["message_ids"]: + msg_ids_from_all_batches.add(msg_id) + msg_id_to_batch_id[msg_id] = prompt_batch_log["id"] + + missing_message_results = set() + refusals = set() + for msg_id in msg_ids_from_all_batches: + # Look for missing message results, as previously we didn't save them + msg_res_path = dsfiles.message_results_directory / f"{msg_id}.json" + if not msg_res_path.exists(): + missing_message_results.add(msg_id) + else: + msg_res = json.loads(msg_res_path.read_text(encoding="utf-8")) + if msg_res["error"] == "refusal": + refusals.add(msg_id) + + # Search seed images + all_refusals = missing_message_results | refusals + prompt_batch_log_lookup = dict() + problematic_seeds = list() + for msg_id in all_refusals: + batch_id = msg_id_to_batch_id[msg_id] + if batch_id not in prompt_batch_log_lookup: + prompt_batch_log_path = ( + dsfiles.prompt_batches_directory / f"{batch_id}.json" + ) + prompt_batch_log = json.loads( + prompt_batch_log_path.read_text(encoding="utf-8") + ) + prompt_batch_log_lookup[batch_id] = prompt_batch_log + + prompt_batch_log = prompt_batch_log_lookup[batch_id] + msg_seeds = prompt_batch_log["message_id_to_seed_docids"][msg_id] + + # Previously there was a bug, such that every message got ALL seeds of the batch saved as list of lists in message_id_to_seed_docids + # In newer versions, this is just a single list + is_buggy_lookup = all(isinstance(elem, list) for elem in msg_seeds) + if is_buggy_lookup: + # we need to retrive the correct sublist via index + i = prompt_batch_log["message_ids"].index(msg_id) + msg_seeds = msg_seeds[i] + else: + # msg_seeds is already in correct format + ... 
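To make the legacy lookup format concrete, here is a small illustrative example (hypothetical IDs) of the two `message_id_to_seed_docids` shapes handled above.

```python
# Hypothetical example of the buggy (list-of-lists) vs. fixed seed lookup formats.
message_ids = ["msg-a", "msg-b"]
legacy_lookup = {                                          # buggy: every message stores ALL batch seeds
    "msg-a": [["seed-1"], ["seed-2", "seed-3"]],
    "msg-b": [["seed-1"], ["seed-2", "seed-3"]],
}
fixed_lookup = {"msg-a": ["seed-1"], "msg-b": ["seed-2", "seed-3"]}   # newer: only this message's seeds

msg_seeds = legacy_lookup["msg-b"]
if all(isinstance(elem, list) for elem in msg_seeds):      # same check as is_buggy_lookup above
    msg_seeds = msg_seeds[message_ids.index("msg-b")]      # recovers ["seed-2", "seed-3"]
```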
+ problematic_seeds.extend(msg_seeds) + c = Counter(problematic_seeds) + sc = sorted(c.items(), key=lambda item: item[1], reverse=True) + for seed, cnt in sc[:3]: + print(f"{cnt=} {dsfiles.preprocessed_seed_images_directory / f'{seed}.jpg'}") + + return all_refusals + + +if __name__ == "__main__": + dsnames = [ + "cord_alpha=0.75", + "cord_alpha=1.0_v1", + "docvqa_alpha=0.5", + "docvqa_alpha=0.5_v1", + "docvqa_alpha=0.75", + "docvqa_alpha=0.75_v1", + "docvqa_alpha=1.0", + "docvqa_alpha=1.0_v1", + "publaynet_alpha=0.75", + "rvlcdip_alpha=0.5", + "rvlcdip_alpha=0.5_v1", + "rvlcdip_alpha=0.75", + "rvlcdip_alpha=0.75_v1", + "rvlcdip_alpha=1.0", + "rvlcdip_alpha=1.0_v1", + ] + for n in dsnames: + refusals = search_for_refusals(n) + print(f"{n} {len(refusals)=}") diff --git a/docgenie/analyzation/synth/select_hw_examples.py b/docgenie/analyzation/synth/select_hw_examples.py new file mode 100755 index 0000000000000000000000000000000000000000..620ab895fe9f1640b3dbd31910ecc29a765b196c --- /dev/null +++ b/docgenie/analyzation/synth/select_hw_examples.py @@ -0,0 +1,107 @@ +import argparse +from collections import Counter +import json +import matplotlib.pyplot as plt +from tqdm import tqdm + +from docgenie import ENV +from docgenie.generation.models._file import SyntheticDatasetFileStructure +from docgenie.generation.models._syndatadef import SynDatasetDefinition + + +def parse_args(): + parser = argparse.ArgumentParser( + description="DocGenie Synthetic Document Generator", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "SynDatasetDefinition", + type=str, + help="Filename without extension of the SynDatasetDefinition in data/syn_dataset_definitions", + ) + + args = parser.parse_args() + assert args.SynDatasetDefinition + print(args) + return args + + +def get_all_hw_sentence_imgs(dsname): + deffile = ENV.SYN_DATA_DEFINITIONS_DIR / f"{dsname}.yaml" + dsdef: SynDatasetDefinition = SynDatasetDefinition.from_file(deffile) + dsfiles: SyntheticDatasetFileStructure = dsdef.get_file_structure() + + images_path = dsfiles.handwritten_text_images_directory / "sentences" + if not images_path.exists(): + return + + for d in images_path.iterdir(): + if d.is_dir(): + for f in d.iterdir(): + yield f + + +if __name__ == "__main__": + datasets=[ + "cord_alpha=0.5", + "cord_alpha=0.5_v1", + "cord_alpha=0.75", + "cord_alpha=0.75_v1", + "cord_alpha=1.0", + "cord_alpha=1.0_v1", + "doclaynet4k_alpha=1.0_CLS", + "doclaynet4k_alpha=1.0_DLA", + "doclaynet_alpha=1.0_CLS", + "doclaynet_alpha=1.0_DLA", + "docvqa_alpha=0.5", + "docvqa_alpha=0.5_v1", + "docvqa_alpha=0.75", + "docvqa_alpha=0.75_v1", + "docvqa_alpha=1.0", + "docvqa_alpha=1.0_v1", + "funsd_alpha=1.0", + "icdar2019_alpha=1.0", + "kleister_alpha=1.0", + "publaynet_alpha=0.5", + "publaynet_alpha=0.5_v1", + "publaynet_alpha=0.75", + "publaynet_alpha=0.75_v1", + "publaynet_alpha=1.0", + "publaynet_alpha=1.0_v1", + "publaynet_correct-sampling_alpha=0.5", + "publaynet_correct-sampling_alpha=0.5_v1", + "publaynet_correct-sampling_alpha=0.75", + "publaynet_correct-sampling_alpha=0.75_v1", + "publaynet_correct-sampling_alpha=1.0", + "publaynet_correct-sampling_alpha=1.0_v1", + "rvlcdip_alpha=0.5", + "rvlcdip_alpha=0.5_v1", + "rvlcdip_alpha=0.75", + "rvlcdip_alpha=0.75_v1", + "rvlcdip_alpha=1.0", + "rvlcdip_alpha=1.0_v1", + "sroie_alpha=1.0", + "tobacco3482_alpha=1.0", + "wtq_alpha=1.0", + ] + + all_hw_sent_imgs = [] + for n in datasets: + hw_sent_imgs = get_all_hw_sentence_imgs(n) + all_hw_sent_imgs.extend(list(hw_sent_imgs)) + + 
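The sampling block that follows is meant to draw a reproducible subset of the pooled handwriting crops; note that `random.seed` must be called as a function for the shuffle to actually be seeded (assigning to it, as in `random.seed = 42`, replaces the function and leaves the shuffle unseeded). A minimal sketch of the intended behaviour:

```python
# Reproducible sampling sketch: seed the RNG (as a call), shuffle, take the first N.
import random

random.seed(42)                      # note: must be called, not assigned
random.shuffle(all_hw_sent_imgs)     # all_hw_sent_imgs is the list pooled above
sampled = all_hw_sent_imgs[:200]
```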
import random + random.seed = 42 + random.shuffle(all_hw_sent_imgs) + + import shutil + import pathlib + NUM_IMGS = 200 + + f: pathlib.Path + for f in all_hw_sent_imgs[:NUM_IMGS]: + d = ENV.DATA_DIR / "hw_imgs" / f'{f.parent.stem}-{f.name}' + print(d) + input() + shutil.copy(f, d) diff --git a/docgenie/analyzation/utils.py b/docgenie/analyzation/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..bab0d9a1e9b9da7b713c4d82e4c152aba90b8bf7 --- /dev/null +++ b/docgenie/analyzation/utils.py @@ -0,0 +1,16 @@ +import pathlib +import h5py +import numpy as np +from tqdm import tqdm + + +def read_h5_numpy(path: pathlib.Path) -> np.ndarray: + all_embeddings = [] + all_ids = [] + with h5py.File(path, "r") as f: + for id_ in tqdm(sorted(f.keys())): + emb = f[id_][:] # load tensor in numpy format + all_embeddings.append(emb) + all_ids.append(id_) + + return all_embeddings, all_ids \ No newline at end of file diff --git a/docgenie/data/README.md b/docgenie/data/README.md new file mode 100755 index 0000000000000000000000000000000000000000..3fe47f7b8fb3fdc8450897d82eca5bc302977160 --- /dev/null +++ b/docgenie/data/README.md @@ -0,0 +1,83 @@ +# DocGenie +## Setup environment +```bash +uv sync +source .venv/bin/activate +``` + +## Run visualizations scripts for datasets for sanity check +```bash +# classification +uv run python docgenie/data/cmds/visualize.py --dataset-name tobacco3482 +uv run python docgenie/data/cmds/visualize.py --dataset-name rvlcdip + +# entity labeling +uv run python docgenie/data/cmds/visualize.py --dataset-name cord +uv run python docgenie/data/cmds/visualize.py --dataset-name sroie +uv run python docgenie/data/cmds/visualize.py --dataset-name funsd +uv run python docgenie/data/cmds/visualize.py --dataset-name wild_receipts +uv run python docgenie/data/cmds/visualize.py --dataset-name docile + +# extractive qa +uv run python docgenie/data/cmds/visualize.py --dataset-name ex_docvqa # avg pages ~1 +uv run python docgenie/data/cmds/visualize.py --dataset-name ex_deepform # avg pages ~5 +uv run python docgenie/data/cmds/visualize.py --dataset-name ex_tabfact # avg pages ~1 +uv run python docgenie/data/cmds/visualize.py --dataset-name ex_wiki # avg pages ~1 +uv run python docgenie/data/cmds/visualize.py --dataset-name ex_infographics # avg pages ~1 +uv run python docgenie/data/cmds/visualize.py --dataset-name ex_klc # avg pages ~23 +``` + +## How to load a specific dataset without transforms +This script assumes that datasets are already prepared in the /path/to/datasets/ dir in msgpack format +The dataset preparation itself is managed using a separate atria_datasets library. +To keep docgenie code clean the two are separated. +```python +from docgenie.data import load_dataset +dataset = load_dataset(dataset_name, root_datasets_dir="/path/to/datasets/") + +# read samples or use dataset.train[0] +train_dataset = dataset.train # could be None, check for actual use +for sample in dataset.train: + print("Sample: ", sample) + +validation_dataset = dataset.validation # could be None, check for actual use +for sample in dataset.validation: + print("Sample: ", sample) + +test_dataset = dataset.test # could be None, check for actual use +for sample in dataset.test: + print("Sample: ", sample) +``` + +## How to load a specific dataset with task-specific transforms +This script assumes that datasets are already prepared in the /path/to/datasets/ dir in msgpack format +The dataset preparation itself is managed using a separate atria_datasets library. 
+To keep docgenie code clean the two are separated. +```python +from docgenie.data import load_data_pipeline + +# load sequence classification dataset pipeline +data_pipeline = load_data_pipeline( + dataset_name=dataset_name, +) + +# load tokenized batch from train dataloader +for batch in data_pipeline.train_dataloader: + print(batch) + +# load tokenized batch from validation dataloader +for batch in data_pipeline.validation_dataloader: + print(batch) + +# load tokenized batch from test dataloader +for batch in data_pipeline.test_dataloader: + print(batch) +``` + +## Run tests for data pipeline to make sure its correct +This can be run to test datasets. If tests fail this means something wrong with preparation +of that dataset. + +``` +uv run pytest ./tests/test_data_pipeline.py -q --tb=line +``` \ No newline at end of file diff --git a/docgenie/data/__init__.py b/docgenie/data/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..3665022392003e561972c94f52bc1eabc20705c6 --- /dev/null +++ b/docgenie/data/__init__.py @@ -0,0 +1,3 @@ +from ._transforms import * # noqa +from .interface import * # noqa +from .constants import * # noqa diff --git a/docgenie/data/_core/_data_pipeline.py b/docgenie/data/_core/_data_pipeline.py new file mode 100755 index 0000000000000000000000000000000000000000..26dbbd581d81fb3bc3775a228f2f2911ee8d5653 --- /dev/null +++ b/docgenie/data/_core/_data_pipeline.py @@ -0,0 +1,233 @@ +""" +A data pipeline that wraps around a dataset and provides dataloaders for training, validation, and testing. +""" + +from typing import TYPE_CHECKING, Callable + +from atria_core.utilities.repr import RepresentationMixin + +from docgenie.data._core._data_types import MMDetInput +from docgenie.data._core._msgpack_dataset_reader import MsgpackDatasetReader +from docgenie.logging import get_logger + +from ._dataset import Dataset +from ._utilities import ( + auto_dataloader, + default_collate, +) + +if TYPE_CHECKING: + from torch.utils.data import DataLoader + + +logger = get_logger(__name__) + + +def mmdet_pseudo_collate(batch: list["MMDetInput"]): + """ + Default collate function for MMDetInput inputs. + + This function collates a batch of data instances into a single batch. It is used when + the `collate_fn` argument is not provided to the DataLoader. + + Args: + batch (List[MMDetInput]): A batch of data instances. + + Returns: + Any: The collated batch. + + Raises: + ValueError: If the batch is empty or not a list. + """ + from mmengine.dataset.utils import pseudo_collate + + return MMDetInput( + **pseudo_collate( + [ + { + "inputs": sample.inputs, + "data_samples": sample.data_samples, + } + for sample in batch + ] + ) + ) + + +class DataPipeline(RepresentationMixin): + def __init__( + self, + dataset: "Dataset", + # dataset split args + dataset_splitting_enabled: bool = False, + split_ratio: float = 0.9, + # collate_fn + collate_fn: str | None = "default_collate", + ): + self._dataset = dataset + self._sharded_storage_kwargs = {} + self._dataset_splitter = None + + # if dataset_splitting_enabled and self._dataset.validation is None: # just make sure to turn this off for now + # assert self._dataset.train is not None, ( + # "Dataset splitting enabled but no training dataset found." 
+ # ) + # self._dataset_splitter = StandardSplitter( + # split_ratio=split_ratio, shuffle=True + # ) + # self._dataset.train, self._dataset.validation = self._dataset_splitter( + # self._dataset.train + # ) + + # logger.info("Dataset splitting enabled.") + # logger.info( + # f"Train set size: {self._dataset.train_size}, Validation set size: {self._dataset.validation_size}" + # ) + + if collate_fn == "default_collate": + self._collate_fn = default_collate + elif collate_fn == "mmdet_pseudo_collate": + self._collate_fn = mmdet_pseudo_collate + elif collate_fn == "identity": + self._collate_fn = lambda x: x + else: + raise ValueError(f"Invalid collate_fn: {collate_fn}") + + @property + def dataset(self): + return self._dataset + + @property + def dataset_metadata(self): + return self._dataset.metadata + + def set_transform( + self, transform: Callable, for_train: bool = True, for_eval: bool = True + ): + from torch.utils.data import ConcatDataset + + if for_train and self._dataset.train is not None: + if isinstance(self._dataset.train, ConcatDataset): + for ds in self._dataset.train.datasets: + ds.set_transform(transform) + else: + self._dataset.train.set_transform(transform) + + if for_eval and self._dataset.validation is not None: + self._dataset.validation.set_transform(transform) + + if for_eval and self._dataset.test is not None: + self._dataset.test.set_transform(transform) + + def dataloader( + self, + split: str, + batch_size: int = 1, + pin_memory: bool = True, + num_workers: int = 4, + shuffle: bool = True, + ): + if split == "train": + return self.train_dataloader( + batch_size=batch_size, + pin_memory=pin_memory, + num_workers=num_workers, + shuffle=shuffle, + ) + elif split == "validation": + return self.validation_dataloader( + batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers + ) + elif split == "test": + return self.test_dataloader( + batch_size=batch_size, pin_memory=pin_memory, num_workers=num_workers + ) + else: + raise ValueError(f"Invalid split name: {split}") + + def train_dataloader( + self, + batch_size: int = 1, + pin_memory: bool = True, + num_workers: int = 4, + shuffle: bool = True, + ) -> "DataLoader | None": + import ignite.distributed as idist + from torch.utils.data import RandomSampler, SequentialSampler + + if self._dataset.train is None: + return + + return auto_dataloader( + dataset=self._dataset.train, + collate_fn=self._collate_fn, + sampler=RandomSampler(self._dataset.train) + if shuffle + else SequentialSampler(self._dataset.train), + drop_last=idist.get_world_size() > 1, + batch_size=batch_size * idist.get_world_size(), + num_workers=num_workers, + pin_memory=pin_memory, + ) + + def validation_dataloader( + self, batch_size: int = 1, pin_memory: bool = True, num_workers: int = 4 + ) -> "DataLoader | None": + dataset = self._dataset.validation or self._dataset.test + if dataset is None: + return + + if self._dataset.validation is None: + logger.warning( + "No validation dataset found, using test dataset for validation." 
+ ) + + return self._build_evaluation_dataloader( + dataset, + batch_size=batch_size, + pin_memory=pin_memory, + num_workers=num_workers, + ) + + def test_dataloader( + self, batch_size: int = 1, pin_memory: bool = True, num_workers: int = 4 + ) -> "DataLoader | None": + if self._dataset.test is None: + return None + return self._build_evaluation_dataloader( + self._dataset.test, + batch_size=batch_size, + pin_memory=pin_memory, + num_workers=num_workers, + ) + + def _build_evaluation_dataloader( + self, + dataset: "MsgpackDatasetReader", + batch_size: int = 1, + pin_memory: bool = True, + num_workers: int = 4, + ) -> "DataLoader": + if dataset is None: + return None + + import ignite.distributed as idist # type: ignore + from torch.utils.data import SequentialSampler # type: ignore + + if idist.get_world_size() > 1: + if len(dataset) % idist.get_world_size() != 0: + logger.warning( + "Enabling distributed evaluation with an eval dataset not divisible by process number. " + "This will slightly alter validation results as extra duplicate entries are added to achieve " + "equal num of samples per-process." + ) + return auto_dataloader( + dataset=dataset, + collate_fn=self._collate_fn, + shuffle=False, + drop_last=False, + sampler=SequentialSampler(dataset), + batch_size=batch_size * idist.get_world_size(), + pin_memory=pin_memory, + num_workers=num_workers, + ) diff --git a/docgenie/data/_core/_data_types.py b/docgenie/data/_core/_data_types.py new file mode 100755 index 0000000000000000000000000000000000000000..df0f8f2cf0e583d9214a44139de120c0b8c8412c --- /dev/null +++ b/docgenie/data/_core/_data_types.py @@ -0,0 +1,537 @@ +from __future__ import annotations + +import enum +from dataclasses import dataclass, field, fields, replace +from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar + +from atria_core.types import * +from mmdet.structures import DetDataSample +from pydantic import ConfigDict + +if TYPE_CHECKING: + from typing import Any + + import torch + + +class OverflowStrategy(str, enum.Enum): + select_first = "select_first" + select_all = "select_all" + select_random = "select_random" + + +if TYPE_CHECKING: + import torch + +T = TypeVar("T", bound="BaseModelInput") + + +@dataclass(frozen=True) +class MMDetInput: + model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") + inputs: list[Any] | Any + data_samples: DetDataSample + + def to(self, device: torch.device | str) -> "MMDetInput": + inputs = [tensor.to(device) for tensor in self.inputs] + return MMDetInput( + inputs=inputs, + data_samples=self.data_samples, + ) + + +@dataclass(frozen=True) +class BaseModelInput: + """ + Base class for model input dataclasses. + - Frozen (immutable) + - Prevents nested BaseModelInput instances + - Provides transform utilities (like .to(device)) + """ + + _is_batched: bool = field(default=False, repr=False, compare=False) + + def __post_init__(self): + # Disallow nested BaseModelInput instances + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, BaseModelInput): + raise TypeError( + f"Field '{f.name}' cannot be another BaseModelInput " + f"({type(value).__name__}). Nesting is not allowed." + ) + + def _map_tensors(self, fn: callable): + """ + Internal helper: apply a function to all torch.Tensor fields. + Returns a new instance with transformed fields. 
+ """ + import torch + + updates = {} + for f in fields(self): + val = getattr(self, f.name) + if isinstance(val, torch.Tensor): + updates[f.name] = fn(val) + elif isinstance(val, list): + # If it's a list of tensors, map them too + updates[f.name] = [ + fn(v) if isinstance(v, torch.Tensor) else v for v in val + ] + else: + updates[f.name] = val + return replace(self, **updates) + + def to(self, device: torch.device | str): + """Move all tensor fields to a given device.""" + return self._map_tensors(lambda t: t.to(device)) + + def cpu(self): + """Move all tensor fields to CPU.""" + return self._map_tensors(lambda t: t.cpu()) + + def cuda(self): + """Move all tensor fields to CUDA.""" + return self._map_tensors(lambda t: t.cuda()) + + def numpy(self): + """Convert all tensor fields to numpy arrays.""" + return self._map_tensors( + lambda t: t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t + ) + + @classmethod + def batch(cls: Type[T], instances: list[T]) -> T: + """ + Batch a list of BaseModelInput instances into a single instance. + - Tensor fields are stacked along dim=0. + - Non-tensor fields become lists. + """ + import torch + + if not instances: + raise ValueError("Cannot batch an empty list of inputs.") + if not all(isinstance(x, cls) for x in instances): + raise TypeError(f"All elements must be instances of {cls.__name__}.") + + field_values = {} + for f in fields(instances[0]): + if f.name.startswith("_"): + field_values[f.name] = getattr(instances[0], f.name) + continue + + vals = [getattr(x, f.name) for x in instances] + if vals[0] is None: # we assume if any value is None, all are None + field_values[f.name] = None + continue + + if all(isinstance(v, torch.Tensor) for v in vals): + field_values[f.name] = torch.stack(vals, dim=0) + else: + field_values[f.name] = vals + + return cls(**field_values) + + def __repr__(self) -> str: + """ + Generates a developer-friendly string representation of the object. + + Returns: + str: A developer-friendly string representation of the object. + """ + + import torch + from rich.pretty import pretty_repr + + torch.set_printoptions(edgeitems=2, threshold=100) + + return pretty_repr(self, max_length=4, max_string=128, max_depth=3) + + def __str__(self) -> str: + """ + Generates a human-readable string representation of the object. + + Returns: + str: A human-readable string representation of the object. 
+ """ + + import torch + from rich.pretty import pretty_repr + + torch.set_printoptions(edgeitems=2, threshold=100) + + return pretty_repr(self, max_length=4, max_string=128, max_depth=3) + + +@dataclass(frozen=True) +class DocumentInstanceModelInput(BaseModelInput): + tokenizer_config: dict | None = None + + # token level fields + token_ids: "torch.Tensor" = None + token_bboxes: Optional["torch.Tensor"] = None + token_type_ids: Optional["torch.Tensor"] = None + token_labels: Optional["torch.Tensor"] = None + attention_mask: "torch.Tensor" = None + word_ids: "torch.Tensor" = None + sequence_ids: "torch.Tensor" = None + overflow_to_sample_mapping: "torch.Tensor" = None + + # segment level fields + segment_index: "torch.Tensor" = None + segment_inner_token_rank: "torch.Tensor" = None + first_token_idxes: "torch.Tensor" = None + first_token_idxes_mask: "torch.Tensor" = None + + # sample level fields + index: Optional["torch.Tensor"] = ( + None # index is used to uniquely identify a sample in a batch + ) + sample_id: str = None + image: Optional["torch.Tensor"] = None + label: Optional["torch.Tensor"] = None + words: list[str] = None + + # extractive QA specific fields + question_id: int | None = None + qa_question: str | None = None + qa_answers: list[str] | None = None + token_answer_start: Optional["torch.Tensor"] = None + token_answer_end: Optional["torch.Tensor"] = None + + def select_overflow_samples_by_id(self, is_random: bool = False): + import torch + + assert self._is_batched, ( + "select_all_overflow_samples can only be called on batched inputs." + ) + + def _gather_idx_from_sequence_list( + samples_batch: list[torch.Tensor], + ) -> torch.Tensor | None: + if samples_batch is None: + return None + + resolved_samples_batch = [] + for sample_data in samples_batch: + if len(sample_data) == 1: + resolved_samples_batch.append(sample_data[0]) + else: + idx = ( + 0 + if not is_random + else torch.randint(0, sample_data.shape[0], (1,)).item() + ) + resolved_samples_batch.append(sample_data[idx]) + return torch.stack(resolved_samples_batch) + + token_ids = _gather_idx_from_sequence_list(self.token_ids) + token_type_ids = _gather_idx_from_sequence_list(self.token_type_ids) + token_bboxes = _gather_idx_from_sequence_list(self.token_bboxes) + token_labels = _gather_idx_from_sequence_list(self.token_labels) + attention_mask = _gather_idx_from_sequence_list(self.attention_mask) + word_ids = _gather_idx_from_sequence_list(self.word_ids) + sequence_ids = _gather_idx_from_sequence_list(self.sequence_ids) + overflow_to_sample_mapping = _gather_idx_from_sequence_list( + self.overflow_to_sample_mapping + ) + + # segment level fields + segment_index = _gather_idx_from_sequence_list(self.segment_index) + segment_inner_token_rank = _gather_idx_from_sequence_list( + self.segment_inner_token_rank + ) + first_token_idxes = _gather_idx_from_sequence_list(self.first_token_idxes) + first_token_idxes_mask = _gather_idx_from_sequence_list( + self.first_token_idxes_mask + ) + + # sample level fields remain unchanged + token_answer_start, token_answer_end = None, None + if self.token_answer_start is not None: + token_answer_start = _gather_idx_from_sequence_list(self.token_answer_start) + token_answer_end = _gather_idx_from_sequence_list(self.token_answer_end) + + return replace( + self, + token_ids=token_ids, + token_type_ids=token_type_ids, + token_bboxes=token_bboxes, + token_labels=token_labels, + attention_mask=attention_mask, + word_ids=word_ids, + sequence_ids=sequence_ids, + 
overflow_to_sample_mapping=overflow_to_sample_mapping, + token_answer_start=token_answer_start, + token_answer_end=token_answer_end, + # segment level fields + segment_index=segment_index, + segment_inner_token_rank=segment_inner_token_rank, + first_token_idxes=first_token_idxes, + first_token_idxes_mask=first_token_idxes_mask, + # stack tensors + image=self.image if self.image is None else torch.stack(self.image), + label=self.label if self.label is None else torch.stack(self.label), + # index=self.index if self.index is None else torch.tensor(self.index), + ) + + def resolve_sample_overflow( + self, overflow_strategy: OverflowStrategy = OverflowStrategy.select_all + ) -> DocumentInstanceModelInput: + if not isinstance(self.token_ids, list): + # already resolved + return self + + if overflow_strategy == OverflowStrategy.select_all: + return self.select_all_overflow_samples() + elif overflow_strategy == OverflowStrategy.select_first: + return self.select_first_overflow_samples() + elif overflow_strategy == OverflowStrategy.select_random: + return self.select_random_overflow_samples() + else: + raise ValueError(f"Unknown overflow strategy: {overflow_strategy}") + + def select_first_overflow_samples(self): + return self.select_overflow_samples_by_id(is_random=False) + + def select_random_overflow_samples(self): + return self.select_overflow_samples_by_id(is_random=True) + + def select_all_overflow_samples(self) -> tuple[bool, list[int], list[str]]: + import torch + + assert self._is_batched, ( + "select_all_overflow_samples can only be called on batched inputs." + ) + repeat_indices = [sample.shape[0] for sample in self.token_ids] + + # we concatenate all lists of overflowed samples into a single tensor + def _cat_tensor_fields(samples_list: list[torch.Tensor]) -> torch.Tensor | None: + if samples_list is not None: + return torch.cat(samples_list, dim=0) + return None + + # these are all fields that are already in overflowed format + token_ids = _cat_tensor_fields(self.token_ids) + token_bboxes = _cat_tensor_fields(self.token_bboxes) + token_type_ids = _cat_tensor_fields(self.token_type_ids) + token_labels = _cat_tensor_fields(self.token_labels) + attention_mask = _cat_tensor_fields(self.attention_mask) + word_ids = _cat_tensor_fields(self.word_ids) + sequence_ids = _cat_tensor_fields(self.sequence_ids) + overflow_to_sample_mapping = _cat_tensor_fields(self.overflow_to_sample_mapping) + + # segment level fields + segment_index = _cat_tensor_fields(self.segment_index) + segment_inner_token_rank = _cat_tensor_fields(self.segment_inner_token_rank) + first_token_idxes = _cat_tensor_fields(self.first_token_idxes) + first_token_idxes_mask = _cat_tensor_fields(self.first_token_idxes_mask) + + token_answer_start, token_answer_end = None, None + if self.token_answer_start is not None and self.token_answer_end is not None: + token_answer_start = _cat_tensor_fields(self.token_answer_start) + token_answer_end = _cat_tensor_fields(self.token_answer_end) + + # these are fields that are at sample level and need to be repeated in case of overflow + # sample level fields + index = self._repeat_field(self.index, repeat_indices) + sample_id = self._repeat_field(self.sample_id, repeat_indices) + image = self._repeat_field(self.image, repeat_indices) + label = self._repeat_field(self.label, repeat_indices) + words = self._repeat_field(self.words, repeat_indices) + + # extractive QA specific fields + question_id = self._repeat_field(self.question_id, repeat_indices) + qa_question = 
self._repeat_field(self.qa_question, repeat_indices) + qa_answers = self._repeat_field(self.qa_answers, repeat_indices) + + repeated_instance = replace( + self, + token_ids=token_ids, + token_bboxes=token_bboxes, + token_type_ids=token_type_ids, + token_labels=token_labels, + attention_mask=attention_mask, + word_ids=word_ids, + sequence_ids=sequence_ids, + overflow_to_sample_mapping=overflow_to_sample_mapping, + # segment level fields + segment_index=segment_index, + segment_inner_token_rank=segment_inner_token_rank, + first_token_idxes=first_token_idxes, + first_token_idxes_mask=first_token_idxes_mask, + # sample level fields + index=index, + sample_id=sample_id, + image=image, + label=label, + words=words, + question_id=question_id, + qa_question=qa_question, + qa_answers=qa_answers, + token_answer_start=token_answer_start, + token_answer_end=token_answer_end, + ) + + for key, value in repeated_instance.to_dict().items(): + if isinstance(value, list) and len(value) != sum(repeat_indices): + raise ValueError( + f"Field '{key}' length {len(value)} does not match expected {sum(repeat_indices)}" + ) + if isinstance(value, torch.Tensor) and value.size(0) != sum(repeat_indices): + raise ValueError( + f"Field '{key}' size {value.size(0)} does not match expected {sum(repeat_indices)}" + ) + return repeated_instance + + def _repeat_field(self, field_value: Any, repeat_indices: list[int]) -> Any: + import torch + + if isinstance(field_value, list): + if len(field_value) == 0: + return field_value + if len(field_value) != len(repeat_indices): + raise ValueError( + f"List length ({len(field_value)}) doesn't match repeat_indices length ({len(repeat_indices)})" + ) + repeated_list = [ + item + for item, count in zip(field_value, repeat_indices, strict=True) + for _ in range(count) + ] + + if isinstance(field_value[0], torch.Tensor): + return torch.stack(repeated_list, dim=0) + return repeated_list + + elif isinstance(field_value, torch.Tensor): + if field_value.size(0) != len(repeat_indices): + raise ValueError( + f"Tensor batch size ({field_value.size(0)}) doesn't match repeat_indices length ({len(repeat_indices)})" + ) + return field_value.repeat_interleave( + torch.tensor(repeat_indices, device=field_value.device), dim=0 + ) + + return field_value + + @classmethod + def batch(cls: DocumentInstanceModelInput, instances: list[T]) -> T: + if not instances: + raise ValueError("Cannot batch an empty list of inputs.") + if not all(isinstance(x, cls) for x in instances): + raise TypeError(f"All elements must be instances of {cls.__name__}.") + + field_values = {} + for f in fields(instances[0]): + if f.name == "_is_batched": + field_values[f.name] = True + continue + if f.name == "tokenizer_config": + # For tokenizer_config, we take from the first instance + field_values[f.name] = getattr(instances[0], f.name) + continue + if f.name.startswith("_"): + field_values[f.name] = getattr(instances[0], f.name) + continue + + vals = [getattr(x, f.name) for x in instances] + if vals[0] is None: # we assume if any value is None, all are None + field_values[f.name] = None + continue + + # we simply put all fields in a list and batch them later + # for example we can have sequences like following due to overflow mapping + # seq 1 -> token ids of size (2, 512) + # seq 2 -> token ids of size (1, 512) + # seq 3 -> token ids of size (4, 512) + field_values[f.name] = vals + + return cls(**field_values) + + def print_info(self): + import torch + + print("DocumentInstanceModelInput:") + for f in fields(self): + val = 
getattr(self, f.name) + if isinstance(val, torch.Tensor): + print(f" {f.name}: Tensor shape {val.shape}, dtype {val.dtype}") + elif isinstance(val, list): + if len(val) > 0 and isinstance(val[0], torch.Tensor): + shapes = [v.shape for v in val] + print(f" {f.name}: List of Tensors with shapes {shapes}") + else: + print(f" {f.name}: List of length {len(val)}") + else: + print(f" {f.name}: {type(val).__name__} value: {val}") + + def to_dict(self): + import torch + + result = {} + for f in fields(self): + val = getattr(self, f.name) + if isinstance(val, torch.Tensor): + result[f.name] = val.detach().cpu().numpy() + elif isinstance(val, list): + if len(val) > 0 and isinstance(val[0], torch.Tensor): + result[f.name] = [v.detach().cpu().numpy() for v in val] + else: + result[f.name] = val + else: + result[f.name] = val + + return result + + @classmethod + def from_dict(cls: Type[T], data: dict[str, Any]) -> T: + import numpy as np + import torch + + for key, value in data.items(): + if isinstance(value, np.ndarray): + data[key] = torch.tensor(value) + elif ( + isinstance(value, list) + and len(value) > 0 + and isinstance(value[0], np.ndarray) + ): + data[key] = [torch.tensor(v) for v in value] + else: + data[key] = value + + return cls(**data) + + +@dataclass(frozen=True) +class ConditionalGenerationModelInput(BaseModelInput): + index: Optional["torch.Tensor"] = None + sample_id: Optional[str] = None + input_ids: Optional["torch.Tensor"] = None + bbox: Optional["torch.Tensor"] = None + attention_mask: Optional["torch.Tensor"] = None + pixel_values: Optional["torch.Tensor"] = None + question_text: Optional[str] = None + target_text: Optional[str] = None + target_token_ids: Optional["torch.Tensor"] = None + words: Optional[list[str]] = None + word_labels: Optional[list[str]] = None + label: Optional["torch.Tensor"] = None + + +@dataclass(frozen=True) +class VLMModelInput(BaseModelInput): + index: Optional["torch.Tensor"] = None + sample_id: Optional[str] = None + input_ids: Optional["torch.Tensor"] = None + bbox: Optional["torch.Tensor"] = None + attention_mask: Optional["torch.Tensor"] = None + pixel_values: Optional["torch.Tensor"] = None + image_grid_thw: Optional["torch.Tensor"] = None + question_text: Optional[str] = None + target_text: Optional[str] = None + target_token_ids: Optional["torch.Tensor"] = None + words: Optional[list[str]] = None + word_labels: Optional[list[str]] = None + label: Optional["torch.Tensor"] = None diff --git a/docgenie/data/_core/_dataset.py b/docgenie/data/_core/_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..6b45751666c5071adefd52350fb41fa4b1c3df05 --- /dev/null +++ b/docgenie/data/_core/_dataset.py @@ -0,0 +1,114 @@ +""" +A simple dataset class that holds multiple split iterators. 
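+
+A hedged usage sketch (the config object and split access below are illustrative; datasets
+are normally built via ``DatasetFactory.load_dataset`` rather than constructed directly):
+
+    dataset = DatasetFactory.load_dataset(dataset_load_config)
+    print(dataset.train_size, dataset.validation_size, dataset.test_size)
+    for sample in dataset.train:
+        ...  # iterate over loaded data instances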
+""" + +from __future__ import annotations + +from typing import TypeVar + +from atria_core.types import DatasetMetadata +from atria_core.types.common import DatasetSplitType +from atria_core.types.data_instance.base import ( + BaseDataInstance, +) +from atria_core.utilities.repr import RepresentationMixin + +from docgenie.logging import get_logger + +from ._msgpack_dataset_reader import MsgpackDatasetReader +from ._utilities import TaskType + +logger = get_logger(__name__) + + +T_BaseDataInstance = TypeVar("T_BaseDataInstance", bound=BaseDataInstance) + + +class Dataset(RepresentationMixin): + def __init__( + self, + name: str, + split_iterators: dict, + metadata: DatasetMetadata, + task_type: TaskType, + ) -> None: + self._name = name + self._split_iterators: dict[DatasetSplitType, MsgpackDatasetReader] = ( + split_iterators + ) + self._metadata = metadata + self._task_type = task_type + + @property + def name(self) -> str: + """Dataset name.""" + return self._name + + @property + def task_type(self) -> TaskType: + """Dataset task type.""" + return self._task_type + + @task_type.setter + def task_type(self, value: TaskType) -> None: + self._task_type = value + + @property + def split_iterators( + self, + ) -> dict[DatasetSplitType, MsgpackDatasetReader]: + """Dictionary of split iterators.""" + return self._split_iterators + + @property + def train(self) -> MsgpackDatasetReader | None: + """Training split iterator. Returns None if training split is not available.""" + return self._split_iterators.get(DatasetSplitType.train, None) + + @property + def validation(self) -> MsgpackDatasetReader | None: + """Validation split iterator. Returns None if validation split is not available.""" + return self._split_iterators.get(DatasetSplitType.validation, None) + + @property + def test(self) -> MsgpackDatasetReader | None: + """Test split iterator. Returns None if test split is not available.""" + return self._split_iterators.get(DatasetSplitType.test, None) + + @train.setter + def train(self, value: MsgpackDatasetReader) -> None: + self._split_iterators[DatasetSplitType.train] = value + + @validation.setter + def validation(self, value: MsgpackDatasetReader) -> None: + self._split_iterators[DatasetSplitType.validation] = value + + @test.setter + def test(self, value: MsgpackDatasetReader) -> None: + self._split_iterators[DatasetSplitType.test] = value + + @property + def metadata(self) -> DatasetMetadata: + """Dataset metadata.""" + return self._metadata + + @property + def train_size(self) -> int: + """Length of the training split. Returns 0 if training split is not available.""" + if self.train is None: + return 0 + return len(self.train) + + @property + def validation_size(self) -> int: + """Length of the validation split. Returns 0 if validation split is not available.""" + if self.validation is None: + return 0 + return len(self.validation) + + @property + def test_size(self) -> int: + """Length of the test split. 
Returns 0 if test split is not available.""" + if self.test is None: + return 0 + return len(self.test) diff --git a/docgenie/data/_core/_dataset_factory.py b/docgenie/data/_core/_dataset_factory.py new file mode 100755 index 0000000000000000000000000000000000000000..ea8837e933e19abfc3bc5bb1cd3316fd46e4e96e --- /dev/null +++ b/docgenie/data/_core/_dataset_factory.py @@ -0,0 +1,177 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Callable + +import yaml +from atria_core.types import DatasetMetadata +from atria_core.types.common import DatasetSplitType +from atria_core.types.data_instance.base import ( + BaseDataInstance, +) +from atria_core.types.data_instance.document_instance import ( + DocumentInstance, +) + +from docgenie.data.constants import DatasetLoadConfig +from docgenie.logging import get_logger + +from ._dataset import Dataset +from ._msgpack_dataset_reader import ( + MsgpackDatasetReader, +) + +logger = get_logger(__name__) + + +class DatasetFactory: + """ + Factory class for creating and loading datasets from msgpack shard files. + + The DatasetFactory provides a centralized way to load datasets stored in a specific + directory structure with msgpack format. It automatically discovers available datasets + and configurations, validates paths, and creates appropriate data iterators for each split. + + Expected Directory Structure: + root_datasets_dir/ + └── dataset_name/ + └── storage/ + └── dataset_config_name/ + └── msgpack/ + ├── train/ + │ ├── shard_001.msgpack + │ └── shard_002.msgpack + ├── validation/ + │ └── shard_001.msgpack + └── test/ + └── shard_001.msgpack + + Usage: + # Basic usage with default DocumentInstance data model + dataset = DatasetFactory.load_dataset( + root_datasets_dir="/path/to/datasets", + dataset_name="my_dataset", + dataset_config_name="default" + + # With custom data model and output transformation + dataset = DatasetFactory.load_dataset( + root_datasets_dir="/path/to/datasets", + dataset_name="my_dataset", + dataset_config_name="processed", + data_model=CustomDataInstance, + output_transform=lambda x: preprocess(x) + + # Access splits + for sample in dataset.train: + # Process training samples + pass + + The factory handles: + - Automatic discovery of dataset splits (train, validation, test, etc.) + - Loading msgpack shard files for each split + - Data model instantiation and transformation + - Error handling with helpful messages about available datasets/configs + """ + + @classmethod + def get_preprocess_transform(self, preprocess_image_size: int) -> Callable: + def preprocess_transform(sample: BaseDataInstance) -> dict: + resized_image = sample.image.resize( + width=preprocess_image_size, height=preprocess_image_size + ) + return sample.model_copy(update={"image": resized_image}) + + return preprocess_transform + + @classmethod + def prepare_paths( + cls, root_datasets_dir: str | Path, dataset_name: str, dataset_config_name: str + ): + # construct paths + data_dir = Path(root_datasets_dir) / dataset_name / "storage" + metadata_file = data_dir / "metadata.yaml" + msgpack_dir = data_dir / dataset_config_name / "msgpack" + + if not data_dir.exists(): + raise ValueError( + f"Data directory {data_dir} does not exist. " + f"Please check the dataset {dataset_name} is prepared with config name {dataset_config_name}. " + ) + + assert metadata_file.exists(), f"Metadata file {metadata_file} does not exist. " + assert msgpack_dir.exists(), f"Data directory {msgpack_dir} does not exist. 
" + return metadata_file, msgpack_dir + + @classmethod + def load_metadata( + cls, + metadata_file: str | Path, + ) -> DatasetMetadata: + # load metadata + with open(metadata_file, "r") as f: + metadata = yaml.safe_load(f) + return DatasetMetadata(**metadata) + + @classmethod + def get_available_splits(cls, msgpack_dir: Path): + available_splits = [DatasetSplitType(x) for x in os.listdir(msgpack_dir)] + assert len(available_splits) > 0, ( + f"No splits found in {msgpack_dir}. Found {available_splits}" + ) + return available_splits + + @classmethod + def load_split_from_disk( + cls, + msgpack_dir: Path, + split: DatasetSplitType, + data_model: type[BaseDataInstance], + ) -> MsgpackDatasetReader: + # load msgpack files for this split + split_files = list((msgpack_dir / split.value).glob("*.msgpack")) + return MsgpackDatasetReader(msgpack_files=split_files, data_model=data_model) + + @classmethod + def load_dataset( + cls, + dataset_load_config: DatasetLoadConfig, + data_model: type[BaseDataInstance] = DocumentInstance, + split: str | None = None, + ) -> Dataset: + # get dataset name and config name + dataset_name, dataset_config_name = ( + dataset_load_config.dataset_name, + dataset_load_config.dataset_config_name, + ) + + # handle tuple config names + if isinstance(dataset_config_name, tuple): + dataset_name, dataset_config_name = dataset_config_name + + # construct paths + metadata_file, msgpack_dir = cls.prepare_paths( + dataset_load_config.root_datasets_dir, dataset_name, dataset_config_name + ) + + # load metadata + metadata = cls.load_metadata(metadata_file) + + # load split files + available_splits = cls.get_available_splits(msgpack_dir) + + # load split iterators + split_iterators = {} + for current_split in available_splits: + if split is not None and current_split.value != split: + continue + split_iterators[current_split] = cls.load_split_from_disk( + msgpack_dir, current_split, data_model + ) + + return Dataset( + name=dataset_name, + split_iterators=split_iterators, + metadata=metadata, + task_type=dataset_load_config.task_type, + ) diff --git a/docgenie/data/_core/_msgpack_dataset_reader.py b/docgenie/data/_core/_msgpack_dataset_reader.py new file mode 100755 index 0000000000000000000000000000000000000000..ae9f332fa5a92090231b2f2047dff21b6827767c --- /dev/null +++ b/docgenie/data/_core/_msgpack_dataset_reader.py @@ -0,0 +1,174 @@ +""" +Msgpack shard list dataset module taken from atria_datasets +""" + +from collections.abc import Sequence +from pathlib import Path +from typing import Any, Callable, TypeVar + +import numpy as np +from atria_core.types import BaseDataInstance +from datadings.reader import MsgpackReader as MsgpackFileReader + +from docgenie.data._core._data_types import DocumentInstanceModelInput +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +T_BaseDataInstance = TypeVar("T_BaseDataInstance", bound=BaseDataInstance) + + +class MsgpackDatasetReader(Sequence[Any]): + """ + A dataset class for reading Msgpack-based shard files. + + This class provides functionality for loading and iterating over datasets stored + in Msgpack-based shard files. It supports efficient indexing and cumulative size + calculations for handling multiple shards. + + Attributes: + _shard_files (list[str]): A list of Msgpack file path for each shard. + _cumulative_sizes (list[int]): Cumulative sizes of the shards for efficient indexing. + _total_size (int): The total number of samples across all shards. 
+ """ + + def __init__( + self, + msgpack_files: list[str] | list[Path], + data_model: type, + transform: Callable | None = None, + ) -> None: + """ + Initializes the `MsgpackShardListDataset`. + + Args: + shard_files (List[DatasetShardInfo]): A list of shard metadata containing file URLs. + """ + logger.info(f"Loading dataset from files: {msgpack_files}") + self._msgpack_files = sorted(msgpack_files) + self._total_size: int = 0 + + cumulative_sizes: list[int] = [] + for f in self._msgpack_files: + with MsgpackFileReader(f) as reader: + self._total_size += len(reader) + cumulative_sizes.append(self._total_size) + self._cumulative_sizes = np.array(cumulative_sizes) + + self._data_model = data_model + self._transform = transform + self._subset_indices = None + self._msgpack_file_readers = [MsgpackFileReader(f) for f in self._msgpack_files] + self._data_dir = Path(self._msgpack_files[0]).parent + + @property + def data_dir(self) -> Path: + return self._data_dir + + def set_subset_indices(self, indices: list[int]) -> None: + """ + Sets the subset indices for the dataset. + + Args: + indices (List[int]): A list of indices to subset the dataset. + """ + self._subset_indices = indices + + def set_transform(self, transform: Callable) -> None: + """ + Sets the transform function for the dataset. + + Args: + transform (Callable): A function to transform each data instance. + """ + self._transform = transform + + def _transform_input(self, input: Any) -> BaseDataInstance: + if issubclass(self._data_model, BaseDataInstance): + if "total_num_pages" in input: + input.pop("total_num_pages") + data_instance: BaseDataInstance = self._data_model.model_validate(input) + + # assert that the transformed instance is of the expected data model type + assert isinstance(data_instance, self._data_model), ( + f"self._input_transform(sample) should return {self._data_model}, but got {type(data_instance)}" + ) + + # load the data instance from disk if not already loaded + data_instance.load() + + # yield the transformed data instance if output transform is enabled + if self._transform is not None: + data_instance = self._transform(data_instance) + return data_instance + elif issubclass(self._data_model, DocumentInstanceModelInput): + data_instance = self._data_model.from_dict(input) + if self._transform is not None: + data_instance = self._transform(data_instance) + return data_instance + else: + raise ValueError( + f"Unsupported data model type: {self._data_model}. Must be a subclass of BaseDataInstance or DocumentInstanceModelInput." + ) + + def get_by_id(self, sample_id: str) -> int: + for reader in self._msgpack_file_readers: + try: + sample_id = str(sample_id) + index = reader.find_index(sample_id.replace(".", "_")) + sample = reader[index] + sample.pop("key", None) + sample = self._transform_input(sample) + assert sample.sample_id == sample_id, ( # this should never happen + f"Sample ID mismatch: expected {sample_id} ({type(sample_id)}), got {sample.sample_id} ({type(sample.sample_id)})" + ) + return sample + except KeyError: + continue + raise ValueError(f"Sample ID {sample_id} not found in any shard.") + + def __getitem__(self, index: int) -> dict[str, Any]: # type: ignore[override] + """ + Retrieves a sample from the dataset by index. + + Args: + index (int): The index of the sample to retrieve. + + Returns: + Dict[str, Any]: The sample at the specified index. 
+ """ + if self._subset_indices is not None: + index = self._subset_indices[index] + + shard_index = np.searchsorted(self._cumulative_sizes, index, side="right") + if shard_index == 0: + inner_index = index + else: + inner_index = index - self._cumulative_sizes[shard_index - 1] + sample = self._msgpack_file_readers[shard_index][inner_index] + sample.pop("key", None) + return self._transform_input(sample) + + def __len__(self) -> int: + """ + Returns the total number of samples in the dataset. + + Returns: + int: The total number of samples. + """ + if self._subset_indices is not None: + return len(self._subset_indices) + return self._total_size + + def close(self) -> None: + """ + Closes all shard file readers to release resources. + """ + for reader in self._msgpack_file_readers: + reader._close() + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}, " + f"total_size={self._total_size}, num_shards={len(self._msgpack_files)})" + ) diff --git a/docgenie/data/_core/_msgpack_dataset_writer.py b/docgenie/data/_core/_msgpack_dataset_writer.py new file mode 100755 index 0000000000000000000000000000000000000000..016c04c5023458c69b127d02ea12c20c5e92b8ef --- /dev/null +++ b/docgenie/data/_core/_msgpack_dataset_writer.py @@ -0,0 +1,108 @@ +""" +Defines interface for docgenie components to load datasets using DatasetFactory and log relevant information. +""" + +from __future__ import annotations + +from pathlib import Path + +import tqdm +from torch.utils.data import Dataset + +from docgenie.data._core._msgpack_dataset_reader import MsgpackDatasetReader +from docgenie.logging import get_logger + +from ._data_types import BaseDataInstance, DocumentInstance + +logger = get_logger(__name__) + + +class MsgpackDatasetWriter: + def __init__( + self, + dataset_reader: MsgpackDatasetReader | Dataset, + output_file: Path, + data_model: type | type[BaseDataInstance] = DocumentInstance, + ): + self._dataset_reader = dataset_reader + self._output_file = output_file + self._data_model = data_model + + def _get_dataloader(self): + import torch + + # setup dataloader + dataloader = torch.utils.data.DataLoader( + self._dataset_reader, + batch_size=16, + shuffle=False, + num_workers=0, + collate_fn=lambda x: x, + drop_last=False, + ) + return dataloader + + def write(self, force_overwrite: bool = False) -> MsgpackDatasetReader: + if force_overwrite: + logger.warning( + f"Force overwrite is enabled. Existing file at {self._output_file} will be deleted if it exists." 
+ ) + self._output_file.unlink(missing_ok=True) + + if not self._output_file.exists(): + self._write() + return self.read() + + def read(self): + return MsgpackDatasetReader( + msgpack_files=[str(self._output_file)], + data_model=self._data_model, + ) + + def _write(self): + from datadings.writer import FileWriter + + try: + dataloader = self._get_dataloader() + total_sample = len(self._dataset_reader) + self._output_file.parent.mkdir(parents=True, exist_ok=True) + with FileWriter( + self._output_file, + overwrite=True, + ) as writer: + for batch in tqdm.tqdm( + dataloader, + desc=f"Preprocessing dataset to {self._output_file} with total samples {total_sample}", + ): + for sample_or_sample_list in batch: + sample_list = ( + [sample_or_sample_list] + if not isinstance(sample_or_sample_list, list) + else sample_or_sample_list + ) + for sample in sample_list: + try: + sample_dict = ( + sample.to_dict() + if hasattr(sample, "to_dict") + else sample.model_dump() + ) + writer.write( + { + "key": sample.sample_id, + **sample_dict, + } + ) + except ValueError as e: + logger.error( + f"[WriteError] Failed to write sample '{getattr(sample, 'sample_id', 'unknown')}': {e}" + ) + continue + except Exception as e: + logger.error(f"Error while writing preprocessed data: {e}") + self._output_file.unlink(missing_ok=True) + raise e + except KeyboardInterrupt as e: + logger.error("Preprocessing interrupted by user.") + self._output_file.unlink(missing_ok=True) + raise e diff --git a/docgenie/data/_core/_standard_splitter.py b/docgenie/data/_core/_standard_splitter.py new file mode 100755 index 0000000000000000000000000000000000000000..e9b2a41338ae63f58f9e02b7d27dc1754d68d291 --- /dev/null +++ b/docgenie/data/_core/_standard_splitter.py @@ -0,0 +1,140 @@ +""" +Dataset Splitter Module + +This module defines the `StandardSplitter` class, which provides utilities for splitting +datasets into training and validation subsets. It supports both sequential and random +splitting strategies, with configurable options for shuffle, and split ratio. + +Classes: + - StandardSplitter: A class for splitting datasets into training and validation subsets. + +Dependencies: + - copy: For deep copying datasets. + - typing: For type annotations. + - torch.utils.data: For dataset splitting utilities. + - atria_core.logger.logger: For logging utilities. + - atria_registry: For registering dataset splitters. + - atria_datasets.core.datasets.atria_dataset: For the base dataset class. + +Author: Your Name (your.email@example.com) +Date: 2025-04-07 +Version: 1.0.0 +License: MIT +""" + +from atria_core.utilities.repr import RepresentationMixin + +from ._msgpack_dataset_reader import MsgpackDatasetReader + + +class StandardSplitter(RepresentationMixin): + """ + A class for splitting datasets into training and validation subsets. + + This class provides methods for creating sequential and random splits of datasets. + It supports configurable options for shuffle, and split ratio. + + Attributes: + split_ratio (float): The ratio of the training split. Defaults to 0.8. + shuffle (bool): Whether to shuffle the dataset before splitting. Defaults to True. + """ + + def __init__(self, split_ratio: float = 0.8, shuffle: bool = True): + """ + Initializes the `StandardSplitter`. + + Args: + split_ratio (float): The ratio of the training split. Defaults to 0.8. + shuffle (bool): Whether to shuffle the dataset before splitting. Defaults to True. 
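+
+        Example (illustrative numbers): with ``split_ratio=0.8`` and a training split of
+        1000 samples, the first returned subset keeps 800 samples and the second keeps
+        the remaining 200 (chosen randomly when ``shuffle=True``, sequentially otherwise).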
+ """ + self.split_ratio = split_ratio + self.shuffle = shuffle + + def create_sequential_split( + self, train: "MsgpackDatasetReader" + ) -> tuple["MsgpackDatasetReader", "MsgpackDatasetReader"]: + """ + Creates a sequential split of the dataset. + + The dataset is split into training and validation subsets based on the split ratio, + without shuffling. + + Args: + train_dataset (AtriaDataset): The dataset to split. + + Returns: + Tuple[AtriaDataset, AtriaDataset]: The training and validation subsets. + """ + import copy + + dataset_size = len(train) + split_point = int(dataset_size * round(self.split_ratio, 2)) + + validation = copy.deepcopy(train) + train.set_subset_indices(list(range(split_point))) + validation.set_subset_indices(list(range(split_point))) + return train, validation + + def create_random_split( + self, train: "MsgpackDatasetReader" + ) -> tuple["MsgpackDatasetReader", "MsgpackDatasetReader"]: + """ + Creates a random split of the dataset. + + The dataset is split into training and validation subsets based on the split ratio, + with shuffling. + + Args: + train_dataset (AtriaDataset): The dataset to split. + + Returns: + Tuple[AtriaDataset, AtriaDataset]: The training and validation subsets. + """ + import copy + + from sklearn.model_selection import train_test_split + + assert train is not None, ( + "The dataset must have a 'train' split defined for sequential splitting." + ) + + train_dataset_size = len(train) + validation = copy.deepcopy(train) + train_subset, validation_subset = train_test_split( + list(range(train_dataset_size)), + test_size=1 - self.split_ratio, + random_state=42, + ) + train.set_subset_indices(list(train_subset)) + validation.set_subset_indices(list(validation_subset)) + return train, validation + + def __call__( + self, train_split: "MsgpackDatasetReader" + ) -> tuple["MsgpackDatasetReader", "MsgpackDatasetReader"]: + """ + Splits the dataset into training and validation subsets. + + The splitting strategy (sequential or random) is determined by the `shuffle` attribute. + + Args: + train_dataset (AtriaDataset): The dataset to split. + + Returns: + Tuple[AtriaDataset, AtriaDataset]: The training and validation subsets. + + Raises: + AssertionError: If the dataset is not an instance of `AtriaDataset` or if the + dataset size is unknown (e.g., in iterable mode). + """ + assert isinstance(train_split, MsgpackDatasetReader), ( + "The dataset must be a PyTorch or Hugging Face dataset." + ) + assert len(train_split) != "unknown", ( + "The dataset size is unknown. This means that the dataset is set up " + "in iterable mode and splitting is not supported." 
+ ) + if self.shuffle: + return self.create_random_split(train_split) + else: + return self.create_sequential_split(train_split) diff --git a/docgenie/data/_core/_synth.py b/docgenie/data/_core/_synth.py new file mode 100755 index 0000000000000000000000000000000000000000..73286a103956a48377d264819965313af0b2b65e --- /dev/null +++ b/docgenie/data/_core/_synth.py @@ -0,0 +1,589 @@ +import json + +import cv2 +import fitz +import numpy as np +import textdistance as td +import tqdm +from PIL import Image as PILImageLoader +from torch.utils.data import Dataset + +from docgenie.generation.constants import IMAGE_RENDER_EXT +from docgenie.generation.models import ( + SynDatasetDefinition, + SyntheticDatasetFileStructure, +) +from docgenie.generation.models._consts import DatasetTask +from docgenie.generation.models._log import SynDocumentLog +from docgenie.generation.utils.bboxes import read_syn_dataset_bboxes +from docgenie.logging import get_logger + +from ._data_types import ( + AnnotatedObject, + AnnotatedObjectList, + BoundingBox, + BoundingBoxList, + ClassificationAnnotation, + DocumentContent, + DocumentInstance, + EntityLabelingAnnotation, + ExtractiveQAAnnotation, + ExtractiveQAPair, + Image, + Label, + LabelList, + LayoutAnalysisAnnotation, +) +from ._utilities import TaskType + +logger = get_logger(__name__) + + +def _compute_anls( + predictions: list[list[str]], gold_labels: list[list[str]], tau=0.5, rank=0 +): + res = [] + for i, (preds, golds) in enumerate(zip(predictions, gold_labels)): + max_s = 0 + for pred in preds: + for gold in golds: + dis = td.levenshtein.distance(pred.lower(), gold.lower()) + max_len = max(len(pred), len(gold)) + if max_len == 0: + s = 0 + else: + nl = dis / max_len + s = 1 - nl if nl < tau else 0 + max_s = max(s, max_s) + res.append(max_s) + return res, sum(res) / len(res) + + +def _compute_iou(box1, box2): + """Compute IoU between two bounding boxes in format [x1, y1, x2, y2]""" + x1 = max(box1[0], box2[0]) + y1 = max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + if x2 <= x1 or y2 <= y1: + return 0.0 + + intersection = (x2 - x1) * (y2 - y1) + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0.0 + + +def _foreground_bbox_clip( + image, + bboxes, + coords_are_inclusive=True, + min_area=10, + morph_kernel_size=3, + debug=False, + unnormalize=True, +) -> list: + if image is None: + raise ValueError("Image is None") + + gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) + H, W = gray.shape + + refined = [] + debug_vis = image.copy() + + for i, box in enumerate(bboxes): + x1, y1, x2, y2 = box + + # Handle normalized input + if unnormalize: + x1, y1, x2, y2 = x1 * W, y1 * H, x2 * W, y2 * H + + # Convert to ints + x1, y1, x2, y2 = map(lambda v: int(round(v)), (x1, y1, x2, y2)) + + if coords_are_inclusive: + x2_slice, y2_slice = x2 + 1, y2 + 1 + else: + x2_slice, y2_slice = x2, y2 + + # Clip to image boundaries + x1c, y1c = max(0, min(W - 1, x1)), max(0, min(H - 1, y1)) + x2c, y2c = max(0, min(W, x2_slice)), max(0, min(H, y2_slice)) + + if x2c <= x1c or y2c <= y1c: + refined.append([x1c, y1c, x2c, y2c]) + continue + + crop = gray[y1c:y2c, x1c:x2c] + blur = cv2.GaussianBlur(crop, (5, 5), 0) + + mean_val = float(np.mean(blur)) + invert = mean_val > 127 + + # Apply Otsu threshold + if invert: + _, mask = cv2.threshold( + blur, 0, 255, cv2.THRESH_BINARY_INV + 
cv2.THRESH_OTSU + ) + else: + _, mask = cv2.threshold(blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) + + # ---- REMOVE HORIZONTAL LINES ---- + # Tune these values depending on your document scale + # horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (mask.shape[1] // 8, 1)) + # detect_horizontal = cv2.morphologyEx(mask, cv2.MORPH_OPEN, horizontal_kernel, iterations=1) + + # Subtract detected lines from the mask + # mask = cv2.subtract(mask, detect_horizontal) + + # (Optional) Also remove very thin components (height < 3 px) + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + clean_mask = np.zeros_like(mask) + for i in range(1, num_labels): + x, y, w, h, area = stats[i] + if h > 3: # ignore 1–2 pixel tall components (likely lines) + clean_mask[labels == i] = 255 + mask = clean_mask + + # plt.figure(figsize=(12, 12)) + # plt.imshow(mask, cmap='gray') + # plt.axis('off') + # plt.show() + + # Morphological closing + if morph_kernel_size and morph_kernel_size > 1: + kernel = cv2.getStructuringElement( + cv2.MORPH_RECT, (morph_kernel_size, morph_kernel_size) + ) + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=1) + + # plt.figure(figsize=(12, 12)) + # plt.imshow(mask, cmap='gray') + # plt.axis('off') + # plt.show() + + # Remove small noise components + n_labels, labels, stats, _ = cv2.connectedComponentsWithStats( + mask, connectivity=8 + ) + keep_mask = np.zeros_like(mask, dtype=np.uint8) + + for label in range(1, n_labels): + if stats[label, cv2.CC_STAT_AREA] >= min_area: + keep_mask[labels == label] = 255 + + # If no foreground remains, keep original box + if np.count_nonzero(keep_mask) == 0: + refined.append([x1c, y1c, x2c, y2c]) + continue + + # Find tight bounds + ys, xs = np.where(keep_mask > 0) + y_min_local, y_max_local = int(ys.min()), int(ys.max()) + x_min_local, x_max_local = int(xs.min()), int(xs.max()) + + new_x1, new_y1 = x1c + x_min_local, y1c + y_min_local + new_x2, new_y2 = x1c + x_max_local, y1c + y_max_local + + new_x1, new_y1 = max(0, new_x1), max(0, new_y1) + new_x2, new_y2 = min(W - 1, new_x2), min(H - 1, new_y2) + + refined.append([new_x1, new_y1, new_x2, new_y2]) + + # --- Debug Visualization --- + if debug: + # Overlay mask in red channel + overlay = debug_vis.copy() + colored_mask = cv2.cvtColor(keep_mask, cv2.COLOR_GRAY2BGR) + colored_mask = cv2.resize(colored_mask, (x2c - x1c, y2c - y1c)) + overlay[y1c:y2c, x1c:x2c, 2] = np.maximum( + overlay[y1c:y2c, x1c:x2c, 2], colored_mask[:, :, 2] + ) + + debug_vis = cv2.addWeighted(debug_vis, 0.7, overlay, 0.3, 0) + + # Draw original bbox (yellow) and new bbox (green) + cv2.rectangle(debug_vis, (x1, y1), (x2, y2), (0, 255, 255), 1) + cv2.rectangle(debug_vis, (new_x1, new_y1), (new_x2, new_y2), (0, 255, 0), 2) + + # Label with index + cv2.putText( + debug_vis, + f"{i}", + (x1, max(10, y1 - 5)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + (255, 0, 255), + 1, + cv2.LINE_AA, + ) + + if debug: + return refined, debug_vis + return refined + + +class SynthesizedDataset(Dataset): + def __init__( + self, + dsdef: SynDatasetDefinition, + task_type: TaskType, + dataset_labels: list[str], + resize_images: bool = False, + clip_bboxes_to_foreground: bool = False, + ): + self.dataset_labels = dataset_labels + self.data = self._load_your_synthesized_data(dsdef) + self.task_type = task_type + self.resize_images = resize_images + self.clip_bboxes_to_foreground = clip_bboxes_to_foreground + + # remap dataset labels if cord + if dsdef.name.startswith("cord"): + 
self.dataset_labels = [x.replace(".", "_") for x in self.dataset_labels] + if dsdef.name.startswith("publaynet"): + self.dataset_labels = ["LE-" + x.upper() for x in self.dataset_labels] + if dsdef.name.startswith("doclaynet") and task_type == TaskType.layout_analysis: + self.dataset_labels = ["LE-" + x.upper() for x in self.dataset_labels] + if dsdef.name.startswith("icdar2019"): + self.dataset_labels = ["LE-" + x.upper() for x in self.dataset_labels] + if dsdef.name.startswith("tobacco3482"): + self.dataset_labels = [x.upper() for x in self.dataset_labels] + self.dataset_labels[self.dataset_labels.index("NEWS")] = "NEWS_ARTICLE" + self.dataset_labels[self.dataset_labels.index("ADVE")] = "ADVERTISEMENT" + + def _load_qa_gt(self, annotations: dict) -> dict: + qa_annotations = [] + for i, a in enumerate(annotations): + # if no answer is found we remove the sample + if len(a["answer_bbox_indices"]) == 0: + logger.warning( + f"No answer found for question id {i} in synthesized data. Skipping annotation." + ) + continue + + qa_annotation = { + "question_id": i, + "question": a["question"], + "answer_text": [a["answer"]], + "answer_start_indices": [a["answer_bbox_indices"][0]], + "answer_end_indices": [a["answer_bbox_indices"][-1]], + } + qa_annotations.append(qa_annotation) + + return {"qa_annotations": qa_annotations} + + def _load_kie_as_qa_gt( + self, annotations: dict, dsdef: SynDatasetDefinition + ) -> dict: + assert dsdef.prompt_task == "json", ( + "Modelling KIE tasks as QA in dataloader not implemented for annotation-type KIE." + ) + qa_annotations = [] + for i, a in enumerate(annotations["entities"]): + # if no answer is found we remove the sample + if len(a["bbox_indices"]) == 0: + logger.warning( + f"No answer found for KIE (modelled as QA) question id {i} in synthesized data. Skipping sample." + ) + continue + + qa_annotation = { + "question_id": i, + "question": a["key"], + "answer_text": [a["value"]], + "answer_start_indices": [a["bbox_indices"][0]], + "answer_end_indices": [a["bbox_indices"][-1]], + } + qa_annotations.append(qa_annotation) + + return {"qa_annotations": qa_annotations} + + def _load_classification_gt(self, annotations: dict) -> dict: + assert len(annotations) == 1 + return annotations # is already in correct format: {"label": "FORM"} + + def _load_kie_as_qa_gt( + self, annotations: dict, dsdef: SynDatasetDefinition + ) -> dict: + assert dsdef.prompt_task == "json", ( + "Modelling KIE tasks as QA in dataloader not implemented for annotation-type KIE." + ) + qa_annotations = [] + for i, a in enumerate(annotations["entities"]): + # if no answer is found we remove the sample + if len(a["bbox_indices"]) == 0: + logger.warning( + f"No answer found for KIE (modelled as QA) question id {i} in synthesized data. Skipping sample." 
+ ) + continue + + qa_annotation = { + "question_id": i, + "question": a["key"], + "answer_text": [a["value"]], + "answer_start_indices": [a["bbox_indices"][0]], + "answer_end_indices": [a["bbox_indices"][-1]], + } + qa_annotations.append(qa_annotation) + + return {"qa_annotations": qa_annotations} + + def _load_kie_gt(self, annotations: dict) -> dict: + return {"word_labels": annotations["word_labels"]} + + def _load_dla_gt(self, annotations: dict) -> dict: + dla_annotations = [] + for i, a in enumerate(annotations): + dla_annotation = { + "label": a["label"], + "bbox": [a["x0"], a["y0"], a["x2"], a["y2"]], # already normalized + } + dla_annotations.append(dla_annotation) + + return {"annotations": dla_annotations} + + def _load_your_synthesized_data(self, dsdef: SynDatasetDefinition) -> list[dict]: + dsfiles: SyntheticDatasetFileStructure = dsdef.get_file_structure() + dslog_path = dsfiles.base_path / "dataset_log.json" + dslog: dict = json.loads(dslog_path.read_text(encoding="utf-8")) + valid_samples = dslog["valid_samples"]["items"] + + samples = list() + for docid in tqdm.tqdm( + valid_samples, desc="Loading synthesized dataset samples" + ): + doclog = SynDocumentLog( + document_id=docid, logdir=dsfiles.document_logs_directory + ) + + annotations_path = dsfiles.gt_directory / f"{docid}.json" + annotations = json.loads(annotations_path.read_text(encoding="utf-8")) + + sample_annotations = None + match dsdef.task: + case DatasetTask.QA.value: + sample_annotations = self._load_qa_gt(annotations=annotations) + case DatasetTask.CLASSIFICATION.value: + sample_annotations = self._load_classification_gt( + annotations=annotations + ) + case DatasetTask.KIE.value: + if dsdef.dataloader_model_task_as == DatasetTask.QA.value: + sample_annotations = self._load_kie_as_qa_gt( + annotations=annotations, dsdef=dsdef + ) + else: + sample_annotations = self._load_kie_gt(annotations=annotations) + case DatasetTask.DLA.value: + sample_annotations = self._load_dla_gt(annotations=annotations) + case _: + raise ValueError(f"Unknown synthetic dataset task: {dsdef.task}") + + # TODO: implement other tasks than QA + + word_bbox_path = dsfiles.get_final_normalized_bbox_path( + level="word", doc_id=docid + ) + word_bboxes_raw = read_syn_dataset_bboxes(word_bbox_path) + seg_bbox_path = dsfiles.get_final_normalized_bbox_path( + level="segment", doc_id=docid + ) + seg_bboxes_raw = read_syn_dataset_bboxes(seg_bbox_path) + + words = [b.text for b in word_bboxes_raw] + word_bboxes = [[b.x0, b.y0, b.x2, b.y2] for b in word_bboxes_raw] + segment_level_bboxes = [[b.x0, b.y0, b.x2, b.y2] for b in seg_bboxes_raw] + + if len(word_bboxes) == 0: + logger.warning( + f"No word bboxes found for document id {docid} in synthesized data. Skipping sample." + ) + continue + + if doclog.ocr_required: + image_file_path = dsfiles.img_directory / f"{docid}.{IMAGE_RENDER_EXT}" + else: + image_file_path = dsfiles.final_pdf_directory / f"{docid}.pdf" + + sample = { + "sample_id": docid, + "image_file_path": image_file_path, + "words": words, + "word_bboxes": word_bboxes, + "segment_level_bboxes": segment_level_bboxes, + } + sample.update(sample_annotations) + samples.append(sample) + return samples + + def _prepare_annotations(self, sample, image) -> list: + if self.task_type == TaskType.sequence_classification: + assert self.dataset_labels is not None, "Dataset labels must be provided." 
+ return [ + ClassificationAnnotation( # assuming label is present as category map label to whichever classification category is output for synthesized data + label=Label( + name=sample["label"], + value=self.dataset_labels.index(sample["label"]), + ) + ) + ] + elif self.task_type == TaskType.token_classification: + # for token classification we use bio tagging. so we need to make sure label indices + # map back to riginal + assert self.dataset_labels is not None, "Dataset labels must be provided." + return [ + EntityLabelingAnnotation( + word_labels=LabelList.from_list( + [ + Label(value=self.dataset_labels.index(label), name=label) + for label in sample[ + "word_labels" + ] # here we assume word_labels are provided in synthesized data + ] + ) + ), + ] + + elif self.task_type == TaskType.extractive_qa: + qa_pairs = [] + for i, qa_annotation in enumerate(sample["qa_annotations"]): + qa_pair = ExtractiveQAPair( + id=qa_annotation["question_id"], # unique id if available + question_text=qa_annotation["question"], # question text + answer_start=qa_annotation[ + "answer_start_indices" + ], # start index answer in word tokens + answer_end=qa_annotation[ + "answer_end_indices" + ], # end index of answer in word tokens + answer_text=qa_annotation["answer_text"], # actual answer text + ) + qa_pairs.append(qa_pair) + return [ExtractiveQAAnnotation(qa_pairs=qa_pairs)] + + elif self.task_type == TaskType.layout_analysis: + assert self.dataset_labels is not None, "Dataset labels must be provided." + annotated_objects = [] + for annotation in sample["annotations"]: + label = annotation["label"] + assert label in self.dataset_labels, ( + f"Label {label} not in dataset labels. Found labels: {self.dataset_labels}" + ) + bbox = BoundingBox(value=annotation["bbox"], normalized=True) + + annotated_object = AnnotatedObject( + label=Label(value=self.dataset_labels.index(label), name=label), + bbox=bbox, + ) + annotated_objects.append(annotated_object) + + # convert to AnnotatedObjectList + annotated_objects = AnnotatedObjectList.from_list(annotated_objects) + + if self.clip_bboxes_to_foreground: + image = np.array(image) + refined_bboxes = _foreground_bbox_clip( + image, + annotated_objects.bbox.value, + coords_are_inclusive=False, + min_area=10, + morph_kernel_size=3, + unnormalize=annotated_objects.bbox.normalized, + ) + annotated_objects = annotated_objects.model_copy( + update={ + "bbox": BoundingBoxList(value=refined_bboxes).normalize( + image.shape[1], image.shape[0] + ) + } + ) + + return [ + LayoutAnalysisAnnotation(annotated_objects=annotated_objects), + ] + else: + raise ValueError(f"Unsupported task type: {self.task_type}") + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + sample = self.data[idx] + + image_file_path = str(sample["image_file_path"]) + if image_file_path.endswith(IMAGE_RENDER_EXT): + image = PILImageLoader.open(image_file_path) + elif image_file_path.endswith(".pdf"): + doc = fitz.open(image_file_path) + page = doc[0] + mat = fitz.Matrix(1, 1) + pix = page.get_pixmap(matrix=mat) + image = PILImageLoader.frombytes( + "RGB", [pix.width, pix.height], pix.samples + ) + else: + raise ValueError(f"Unsupported image file format: {image_file_path}") + + image = Image(file_path=sample["image_file_path"], content=image) + word_bboxes = sample["word_bboxes"] + segment_level_bboxes = sample["segment_level_bboxes"] + + # remap segment level bboxes to word level if counts mismatch + if len(word_bboxes) != len(segment_level_bboxes): + remapped_segment_level_bboxes = [] 
+ for word_bbox in word_bboxes: + best_iou = 0.0 + best_segment_bbox = word_bbox # fallback to word bbox if no good match + + for segment_bbox in segment_level_bboxes: + iou = _compute_iou(word_bbox, segment_bbox) + if iou > best_iou: + best_iou = iou + best_segment_bbox = segment_bbox + + remapped_segment_level_bboxes.append(best_segment_bbox) + segment_level_bboxes = remapped_segment_level_bboxes + + assert len(segment_level_bboxes) == len(word_bboxes) == len(sample["words"]), ( + f"Length mismatch after remapping for sample {sample['sample_id']}. " + f"Words: {len(sample['words'])}, Word BBoxes: {len(word_bboxes)}, " + f"Segment Level BBoxes: {len(segment_level_bboxes)}" + ) + + if self.resize_images: + image = image.resize_with_aspect_ratio(1024) + + return DocumentInstance( + sample_id=sample["sample_id"], + image=image, + content=DocumentContent( + words=sample["words"], # simple list of words + word_bboxes=BoundingBoxList(value=word_bboxes, normalized=True), + word_segment_level_bboxes=BoundingBoxList( + value=segment_level_bboxes, normalized=True + ), + ), + annotations=self._prepare_annotations(sample, image.content), + ) + + +""" +hey man I checked your file and it was just a small mistake on read. +1. I also fixed some other mistakes on write +2. added metadata file copying for labels +3. added normalization to word bboxes +4. I noticed you use xywh format is that correct? if so it'd be better to just change it to x1y1x2y2 right here + +Hey man i just wrote this. +i havent tested it for anything but it will give you the idea of what you need to do. +You will also have to add the synthesized dataset name for each in DATASET_CONFIG_MAP i guess for it to finally be loaded after being saved. + After that it could be loaded like any other dataset and preprocessed as well. + We need to do preprocessing in later step only because different training will result in different preprocssing +""" diff --git a/docgenie/data/_core/_utilities.py b/docgenie/data/_core/_utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..7f55894f3649516aabc446126c2b1e4848769261 --- /dev/null +++ b/docgenie/data/_core/_utilities.py @@ -0,0 +1,137 @@ +import enum +from typing import Any + +from docgenie.data._core._data_types import BaseModelInput +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +class TaskType(str, enum.Enum): + generate_embeddings = "generate_embeddings" + sequence_classification = "sequence_classification" + token_classification = "token_classification" + extractive_qa = "extractive_qa" + layout_analysis = "layout_analysis" + table_extraction = "table_extraction" + table_detection = "table_detection" + + +def auto_dataloader(dataset: Any, **kwargs: Any) -> Any: + """ + Automatically configures a DataLoader for distributed training. + + This function adjusts DataLoader settings based on the distributed training configuration, + including rank, world size, and device type. It supports XLA devices and provides warnings + for incompatible configurations. + + Args: + iterator (Iterator): The dataset split iterator to load data from. + **kwargs (Any): Additional arguments for configuring the DataLoader. + + Returns: + DataLoader: A configured DataLoader instance. + + Raises: + ValueError: If incompatible configurations are detected. 
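+
+    Example (illustrative sketch; assumes the distributed context has been set up via
+    ``ignite.distributed`` and that ``train_dataset`` exists):
+
+        loader = auto_dataloader(
+            train_dataset, batch_size=64, num_workers=4, collate_fn=default_collate
+        )
+        for batch in loader:
+            ...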
+ """ + from ignite.distributed import DistributedProxySampler + from ignite.distributed import utils as idist + from ignite.distributed.comp_models import xla as idist_xla + from torch.utils.data import DataLoader, IterableDataset + from torch.utils.data.distributed import DistributedSampler + from torch.utils.data.sampler import Sampler + + rank = idist.get_rank() + world_size = idist.get_world_size() + + if world_size > 1: + if "batch_size" in kwargs and kwargs["batch_size"] >= world_size: + kwargs["batch_size"] //= world_size + + nproc = idist.get_nproc_per_node() + if "num_workers" in kwargs and kwargs["num_workers"] >= nproc: + kwargs["num_workers"] = (kwargs["num_workers"] + nproc - 1) // nproc + + if "batch_sampler" not in kwargs: + if isinstance(dataset, IterableDataset): + logger.info( + "Found iterable dataset, dataloader will be created without any distributed sampling. " + "Please, make sure that the dataset itself produces different data on different ranks." + ) + else: + sampler: DistributedProxySampler | DistributedSampler | Sampler | None + sampler = kwargs.get("sampler", None) + if isinstance(sampler, DistributedSampler): + if sampler.rank != rank: + logger.warning( + f"Found distributed sampler with rank={sampler.rank}, but process rank is {rank}" + ) + if sampler.num_replicas != world_size: + logger.warning( + f"Found distributed sampler with num_replicas={sampler.num_replicas}, " + f"but world size is {world_size}" + ) + elif sampler is None: + shuffle = kwargs.pop("shuffle", True) + sampler = DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle + ) + else: + sampler = DistributedProxySampler( + sampler, num_replicas=world_size, rank=rank + ) + kwargs["sampler"] = sampler + else: + logger.warning( + "Found batch_sampler in provided kwargs. Please, make sure that it is compatible " + "with distributed configuration" + ) + + if ( + idist.has_xla_support + and idist.backend() == idist_xla.XLA_TPU + and kwargs.get("pin_memory", False) + ): + logger.warning( + "Found incompatible options: xla support and pin_memory args equal True. " + "Argument `pin_memory=False` will be used to construct data loader." 
+ ) + kwargs["pin_memory"] = False + else: + kwargs["pin_memory"] = kwargs.get("pin_memory", "cuda" in idist.device().type) + + dataloader = DataLoader(dataset, **kwargs) + if ( + idist.has_xla_support + and idist.backend() == idist_xla.XLA_TPU + and world_size > 1 + ): + logger.info("DataLoader is wrapped by `MpDeviceLoader` on XLA") + + from torch_xla.distributed.parallel_loader import MpDeviceLoader # type: ignore + + mp_device_loader_cls = MpDeviceLoader + mp_dataloader = mp_device_loader_cls(dataloader, idist.device()) + mp_dataloader.sampler = dataloader.sampler # type: ignore[attr-defined] + return mp_dataloader + + return dataloader + + +def default_collate(list_of_inputs: list[BaseModelInput] | list[list[BaseModelInput]]): + if isinstance(list_of_inputs[0], list): + list_of_inputs = [item for sublist in list_of_inputs for item in sublist] + if isinstance(list_of_inputs, list) and len(list_of_inputs) > 0: + return list_of_inputs[0].batch(list_of_inputs) + else: + raise ValueError("Batch is empty or not a list.") + + +def mmdet_pseudo_collate(batch: list["MMDetInput"]): + from atria_datasets.core.transforms.mmdet import MMDetInput + from mmengine.dataset.utils import pseudo_collate + + return MMDetInput.model_construct( + **pseudo_collate([sample.model_dump() for sample in batch]) + ) diff --git a/docgenie/data/_core/_visualization_utilities.py b/docgenie/data/_core/_visualization_utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..25571260b9e7ec5559906f4786ab9dcd8d527fc3 --- /dev/null +++ b/docgenie/data/_core/_visualization_utilities.py @@ -0,0 +1,706 @@ +from __future__ import annotations + +import os +import random +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import textdistance as td +from PIL import Image, ImageDraw, ImageFont +import textwrap +from docgenie.data._core._data_types import ( + AnnotatedObjectList, + BoundingBoxList, + DatasetLabels, + DocumentInstance, + ExtractiveQAPair, +) +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +GT_BG_FILL=(0, 0, 0, 120) + +def merge_bio_bboxes( + words: List[str], bboxes: List[List[int]], labels: List[str] +) -> Tuple[List[str], List[List[int]], List[str]]: + """ + Merge BIO-style labeled bounding boxes into combined entity boxes and labels. + + Args: + words: List of words in sequence. + bboxes: List of bounding boxes [x1, y1, x2, y2] corresponding to each word. + labels: BIO labels, e.g., ["B-ANSWER", "I-ANSWER", "O", "B-QUESTION", "I-QUESTION"]. + + Returns: + merged_words: List of concatenated entity strings. + merged_bboxes: List of merged bounding boxes for each entity. + merged_labels: List of entity types (e.g., ["ANSWER", "QUESTION"]). 
+ """ + merged_words = [] + merged_bboxes = [] + merged_labels = [] + + current_words = [] + current_boxes = [] + current_label_type = None + + for word, bbox, label in zip(words, bboxes, labels): + if label.startswith("B-"): + # Finalize previous entity if any + if current_words: + x1 = min(b[0] for b in current_boxes) + y1 = min(b[1] for b in current_boxes) + x2 = max(b[2] for b in current_boxes) + y2 = max(b[3] for b in current_boxes) + merged_words.append(" ".join(current_words)) + merged_bboxes.append([x1, y1, x2, y2]) + merged_labels.append(current_label_type) + + # Start new entity + current_label_type = label.split("-", 1)[1] + current_words = [word] + current_boxes = [bbox] + + elif label.startswith("I-") and current_label_type == label.split("-", 1)[1]: + # Continue same entity + current_words.append(word) + current_boxes.append(bbox) + + else: + # Finalize previous if we hit O or mismatch + if current_words: + x1 = min(b[0] for b in current_boxes) + y1 = min(b[1] for b in current_boxes) + x2 = max(b[2] for b in current_boxes) + y2 = max(b[3] for b in current_boxes) + merged_words.append(" ".join(current_words)) + merged_bboxes.append([x1, y1, x2, y2]) + merged_labels.append(current_label_type) + current_words, current_boxes, current_label_type = [], [], None + + # If "O", skip (non-entity) + continue + + # Finalize last entity + if current_words: + x1 = min(b[0] for b in current_boxes) + y1 = min(b[1] for b in current_boxes) + x2 = max(b[2] for b in current_boxes) + y2 = max(b[3] for b in current_boxes) + merged_words.append(" ".join(current_words)) + merged_bboxes.append([x1, y1, x2, y2]) + merged_labels.append(current_label_type) + + return merged_words, merged_bboxes, merged_labels + + +def _save_visualization( + sample: DocumentInstance, + dataset_name: str, + output_dir: str, + split: str, + dataset_labels: DatasetLabels, + visualize_gt_only: bool = True, +): + """Save visualizations of document instance with bounding boxes and annotations.""" + + # Create output directory + sample_id = sample.sample_id.split("/")[-1] + output_path = Path(output_dir) / dataset_name / split + os.makedirs(output_path, exist_ok=True) + + # Extract annotations + annotations = _extract_annotations(sample=sample) + + # Extract content + words, word_bboxes, word_segment_level_bboxes = _extract_content_data(sample=sample) + + # Create filename suffix + label_suffix = "" + if "label" in annotations and annotations["label"] is not None: + label_suffix = ( + f"_label={annotations['label'].name}" if annotations["label"] else "" + ) + + # # Save visualizations + image = sample.image.content + if not visualize_gt_only: + if words is not None and word_bboxes is not None: + _save_word_bbox_visualization( + image=image, + word_bboxes=word_bboxes, + words=words, + word_labels=annotations["word_labels"], + output_path=output_path, + sample_id=sample_id, + label_suffix=label_suffix, + ) + + if words is not None and word_segment_level_bboxes is not None: + _save_segment_bbox_visualization( + image=image, + segment_bboxes=word_segment_level_bboxes, + words=words, + word_labels=annotations["word_labels"], + output_path=output_path, + sample_id=sample_id, + label_suffix=label_suffix, + ) + else: + if words is not None and word_bboxes is not None and annotations["word_labels"]: + _save_word_labels_visualization( + image=image, + word_bboxes=word_bboxes, + words=words, + word_labels=annotations["word_labels"], + output_path=output_path, + sample_id=sample_id, + label_suffix=label_suffix, + ) + + if 
annotations["qa_pairs"]: + _save_qa_visualization( + image=image, + word_bboxes=word_bboxes, + words=words, + qa_pairs=annotations["qa_pairs"], + output_path=output_path, + sample_id=sample_id, + label_suffix=label_suffix, + ) + + if annotations["annotated_objects"]: + _save_layout_visualization( + image=image, + annotated_objects=annotations["annotated_objects"], + image_size=sample.image.size, + output_path=output_path, + sample_id=sample_id, + layout_labels=dataset_labels.layout, + label_suffix=label_suffix, + ) + + +def _extract_annotations(sample: DocumentInstance) -> Dict[str, Any]: + """Extract annotations from sample.""" + annotations = { + "label": None, + "word_labels": None, + "qa_pairs": None, + "annotated_objects": None, + } + + for annotation in sample.annotations: + if annotation._type == "classification": + annotations["label"] = annotation.label + elif annotation._type == "entity_labeling": + annotations["word_labels"] = annotation.word_labels + elif annotation._type == "extractive_qa": + annotations["qa_pairs"] = annotation.qa_pairs + elif annotation._type == "layout": + annotations["annotated_objects"] = annotation.annotated_objects + + return annotations + + +def _extract_content_data( + sample: DocumentInstance, +) -> tuple[list[str], BoundingBoxList, Optional[BoundingBoxList]]: + """Extract content data from sample.""" + if sample.content is None: + return None, None, None + + words, word_bboxes, word_segment_level_bboxes = ( + sample.content.words, + sample.content.word_bboxes, + sample.content.word_segment_level_bboxes, + ) + + # Unnormalize bounding boxes + word_bboxes: BoundingBoxList = ( + _unnormalize_bboxes(word_bboxes, sample.image.size) + if word_bboxes.normalized + else word_bboxes + ) + + if word_segment_level_bboxes: + word_segment_level_bboxes = ( + _unnormalize_bboxes(word_segment_level_bboxes, sample.image.size) + if word_segment_level_bboxes.normalized + else word_segment_level_bboxes + ) + return ( + words, + word_bboxes, + word_segment_level_bboxes, + ) + + +def _unnormalize_bboxes(bbox_data, img_size): + """Unnormalize bounding boxes from 0-1 to pixel coordinates.""" + if not bbox_data or not bbox_data.value: + return None + + img_width, img_height = img_size + unnormalized_bboxes = [] + + for bbox in bbox_data.value: + unnormalized_bboxes.append( + [ + int(bbox[0] * img_width), + int(bbox[1] * img_height), + int(bbox[2] * img_width), + int(bbox[3] * img_height), + ] + ) + + return bbox_data.model_copy( + update={"value": unnormalized_bboxes, "normalized": False} + ) + + +def _draw_bboxes_on_image( + image: Image, + bboxes_data, + word_labels=None, +) -> Image: + """Draw bounding boxes with warm colors and readable transparent labels.""" + if not bboxes_data or not getattr(bboxes_data, "value", None): + return image.copy() + + img_copy = image.copy().convert("RGB") + draw = ImageDraw.Draw(img_copy, "RGBA") + + # Warm color palette (soft oranges, reds, and golds) + warm_colors = [ + (255, 99, 71), # tomato + (255, 140, 0), # dark orange + (255, 165, 0), # orange + (255, 69, 0), # red-orange + (255, 215, 0), # gold + (255, 182, 80), # light orange + ] + + # Calculate font size based on image dimensions as a ratio + img_width, img_height = image.size + base_size = img_height + font_size = max(12, int(base_size * 0.015)) # 2% of smaller dimension, minimum 12px + + try: + font = ImageFont.truetype("DejaVuSans.ttf", font_size) + except IOError: + font = ImageFont.load_default() + + unique_labels = [] + if word_labels: + unique_labels = list( + 
set(word_labels if isinstance(word_labels, list) else word_labels.name) + ) + + label_to_color = {} + for idx, label in enumerate(unique_labels): + label_to_color[label] = warm_colors[idx % len(warm_colors)] + + for i, bbox in enumerate(bboxes_data.value): + if len(bbox) < 4: + continue + + # Assign color based on label, fallback to random if no label + if word_labels and i < len(word_labels): + current_label = ( + word_labels[i] if isinstance(word_labels, list) else word_labels.name[i] + ) + color = label_to_color.get(current_label, random.choice(warm_colors)) + else: + color = random.choice(warm_colors) + + # Draw bounding box + try: + draw.rectangle(bbox[:4], outline=color + (255,), width=2) + except Exception as e: + print(f"Error drawing bounding box {bbox}: {e}") + continue + + # Prepare label text + text = "" + # if words and i < len(words): + # text = words[i] + if word_labels and i < len(word_labels): + text += f"{word_labels[i]}" + + if not text: + continue + + # Compute text size (modern Pillow uses textbbox) + try: + text_bbox = draw.textbbox((0, 0), text, font=font) + text_w, text_h = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1] + except AttributeError: + # Fallback for older Pillow versions + text_w, text_h = font.getsize(text) + + # Place text slightly above bbox + text_x = bbox[0] + text_y = max(0, bbox[1] - text_h - 4) + + # Draw transparent black background behind text + draw.rectangle( + [text_x, text_y, text_x + text_w + 6, text_y + text_h + 4], + fill=GT_BG_FILL, + ) + + # Draw white text on top + draw.text((text_x + 3, text_y + 2), text, fill=(255, 255, 255, 255), font=font) + + return img_copy + + +def _draw_qa_answers_on_image(image, word_bboxes, qa_pairs): + if not word_bboxes or not word_bboxes.value: + return image.copy() + + img_copy = image.copy().convert("RGB") + draw = ImageDraw.Draw(img_copy, "RGBA") + + warm_colors = [ + (255, 99, 71), + (255, 140, 0), + (255, 165, 0), + (255, 69, 0), + (255, 215, 0), + (255, 182, 80), + ] + + img_width, img_height = image.size + font_size = max(12, int(img_height * 0.018)) + try: + font = ImageFont.truetype("DejaVuSans.ttf", font_size) + except IOError: + font = ImageFont.load_default() + + max_text_width = int(img_width * 0.6) + + for qa_idx, qa_pair in enumerate(qa_pairs): + color = warm_colors[qa_idx % len(warm_colors)] + + question_text = getattr(qa_pair, "question_text", f"Q{qa_idx + 1}") + answer_starts = getattr(qa_pair, "answer_start", []) + answer_ends = getattr(qa_pair, "answer_end", []) + + for start, end in zip(answer_starts, answer_ends): + if start == -1 or end == -1 or start >= len(word_bboxes.value): + continue + + # Merge bounding boxes for full answer span + boxes = word_bboxes.value[start : min(end + 1, len(word_bboxes.value))] + if not boxes: + continue + + x1 = min(b[0] for b in boxes) + y1 = min(b[1] for b in boxes) + x2 = max(b[2] for b in boxes) + y2 = max(b[3] for b in boxes) + + draw.rectangle([x1, y1, x2, y2], outline=color + (255,), width=2) + + # Create wrapped text using textwrap + label_text = f"Q{qa_idx + 1}: {question_text}" + # Approximate chars per line based on width and font metrics + char_width = font.getlength("A") or font_size * 0.6 + max_chars = max_text_width // int(char_width) + wrapped = textwrap.fill(label_text, width=max_chars) + + # Compute text block size + text_bbox = draw.multiline_textbbox((0, 0), wrapped, font=font, spacing=4) + tw = text_bbox[2] - text_bbox[0] + th = text_bbox[3] - text_bbox[1] + + # Define a box above the answer span (or clamp to top of 
image)
+            text_x = x1
+            text_y = max(0, y1 - th - 8)
+
+            # Background rectangle
+            draw.rectangle(
+                [text_x, text_y, text_x + tw + 8, text_y + th + 6],
+                fill=GT_BG_FILL,
+            )
+
+            # Draw wrapped text directly
+            draw.multiline_text(
+                (text_x + 4, text_y + 3),
+                wrapped,
+                font=font,
+                fill=(255, 255, 255, 255),
+                spacing=4,
+            )
+
+    return img_copy
+
+
+def _save_word_labels_visualization(
+    image: Image,
+    word_bboxes,
+    words: List[str],
+    word_labels: Optional[List[str]],
+    output_path: Path,
+    sample_id: str,
+    label_suffix: str,
+):
+    """Save word-level label visualization, merging BIO-tagged spans into entity boxes."""
+    has_bio_tagging = False
+    if word_labels and any(
+        label.startswith("B-") or label.startswith("I-") for label in word_labels.name
+    ):
+        has_bio_tagging = True
+    if has_bio_tagging:
+        words, word_bboxes, word_labels = merge_bio_bboxes(
+            words, word_bboxes.value, word_labels.name
+        )
+        word_bboxes = BoundingBoxList(value=word_bboxes, normalized=False)
+
+    image_with_bboxes = _draw_bboxes_on_image(
+        image,
+        word_bboxes,
+        word_labels if isinstance(word_labels, list) else word_labels.name,
+    )
+    bbox_path = output_path / f"{sample_id}{label_suffix}_word_bboxes.png"
+    image_with_bboxes.save(bbox_path)
+    logger.info(f"Saved word bbox visualization: {bbox_path}")
+
+
+def _save_word_bbox_visualization(
+    image: Image,
+    word_bboxes,
+    words: List[str],
+    word_labels: Optional[List[str]],
+    output_path: Path,
+    sample_id: str,
+    label_suffix: str,
+):
+    """Save word-level bounding box visualization."""
+    # _draw_bboxes_on_image accepts (image, bboxes, word_labels); the previously passed
+    # extra words/colour arguments did not match its signature and are dropped here.
+    image_with_bboxes = _draw_bboxes_on_image(
+        image,
+        word_bboxes,
+        word_labels
+        if word_labels is None or isinstance(word_labels, list)
+        else word_labels.name,
+    )
+    bbox_path = output_path / f"{sample_id}{label_suffix}_word_bboxes.png"
+    image_with_bboxes.save(bbox_path)
+    logger.info(f"Saved word bbox visualization: {bbox_path}")
+
+
+def _save_segment_bbox_visualization(
+    image: Image,
+    segment_bboxes,
+    words: List[str],
+    word_labels: Optional[List[str]],
+    output_path: Path,
+    sample_id: str,
+    label_suffix: str,
+):
+    """Save segment-level bounding box visualization."""
+    try:
+        # Same signature fix as in _save_word_bbox_visualization above.
+        image_with_bboxes = _draw_bboxes_on_image(
+            image,
+            segment_bboxes,
+            word_labels
+            if word_labels is None or isinstance(word_labels, list)
+            else word_labels.name,
+        )
+    except Exception:
+        logger.error(
+            f"Error drawing segment bounding boxes for sample {sample_id}. Skipping visualization."
+ ) + return + bbox_path = output_path / f"{sample_id}{label_suffix}_segment_bboxes.png" + image_with_bboxes.save(bbox_path) + logger.info(f"Saved segment bbox visualization: {bbox_path}") + + +def _save_qa_visualization( + image: Image, + word_bboxes, + words: List[str], + qa_pairs: List[ExtractiveQAPair], + output_path: Path, + sample_id: str, + label_suffix: str, +): + """Save QA answer visualization and text file.""" + # Save QA image + image_with_qa = _draw_qa_answers_on_image(image, word_bboxes, qa_pairs) + qa_image_path = output_path / f"{sample_id}{label_suffix}_qa_answers.png" + image_with_qa.save(qa_image_path) + logger.info(f"Saved QA answers visualization: {qa_image_path}") + + # qa_txt_path = output_path / f"{sample_id}{label_suffix}_qa.txt" + # with open(qa_txt_path, "w", encoding="utf-8") as f: + # f.write(f"Document Index: {sample_id}\n\n") + # f.write("Document OCR:\n") + # f.write(",".join(words) + "\n\n") + + # for i, qa_pair in enumerate(qa_pairs): + # f.write(f"Q{i + 1}: {qa_pair.question_text}\n") + # f.write(f"A{i + 1}: {qa_pair.answer_text}\n") + + # answer_starts, answer_ends = qa_pair.answer_start, qa_pair.answer_end + # for idx, (start, end) in enumerate(zip(answer_starts, answer_ends)): + # f.write(f"Answer Span [{idx}]: ({start}, {end})\n") + # f.write(f"Extracted Answer: {' '.join(words[start : end + 1])}\n") + # f.write("\n") + + # logger.info(f"Saved QA info: {qa_txt_path}") + + +def _draw_layout_bboxes_on_image( + image: Image, + annotated_objects, + image_size: tuple[int, int], + layout_labels: List[str], +) -> Image: + """ + Draw layout bounding boxes with warm colors (one color per label) and filled area. + """ + + # Unnormalize if needed + bboxes = ( + _unnormalize_bboxes(annotated_objects.bbox, image_size) + if annotated_objects.bbox.normalized + else annotated_objects.bbox + ) + + img_copy = image.copy().convert("RGB") + draw = ImageDraw.Draw(img_copy, "RGBA") + + # Warm color palette + warm_colors = [ + (255, 99, 71), # tomato + (255, 140, 0), # dark orange + (255, 165, 0), # orange + (255, 69, 0), # red-orange + (255, 215, 0), # gold + (255, 182, 80), # light orange + ] + + # Calculate font size based on image dimensions + img_width, img_height = image.size + base_size = img_height + font_size = max(12, int(base_size * 0.015)) + try: + font = ImageFont.truetype("DejaVuSans.ttf", font_size) + except IOError: + font = ImageFont.load_default() + + # Map each layout label to a warm color + unique_labels = list(set(layout_labels)) + label_to_color = { + label: warm_colors[idx % len(warm_colors)] + for idx, label in enumerate(unique_labels) + } + + for bbox, label_idx in zip(bboxes.value, annotated_objects.label.value): + if len(bbox) < 4: + continue + + x1, y1, x2, y2 = bbox + label_text = layout_labels[label_idx] + color = label_to_color.get(label_text, random.choice(warm_colors)) + + # Draw bounding box (outline only) + draw.rectangle([x1, y1, x2, y2], outline=color + (255,), width=3) + + # Draw label text with transparent black background + text_x, text_y = x1, max(0, y1 - font_size - 4) + try: + text_bbox = draw.textbbox((text_x, text_y), label_text, font=font) + text_w, text_h = text_bbox[2] - text_bbox[0], text_bbox[3] - text_bbox[1] + except AttributeError: + text_w, text_h = font.getsize(label_text) + + # Background rectangle + draw.rectangle( + [text_x, text_y, text_x + text_w + 6, text_y + text_h + 4], + fill=GT_BG_FILL, + ) + + # Draw text + draw.text( + (text_x + 3, text_y + 2), label_text, fill=(255, 255, 255, 255), font=font + ) + + return 
img_copy + + +def _save_layout_visualization( + image: Image, + annotated_objects: AnnotatedObjectList, + image_size: tuple[int, int], + output_path: Path, + sample_id: str, + layout_labels: list[str], + label_suffix: str, +): + """Save layout annotation visualization.""" + # Placeholder function for layout visualization + layout_image_path = output_path / f"{sample_id}{label_suffix}_layout.png" + img_copy = _draw_layout_bboxes_on_image( + image, + annotated_objects, + image_size=image_size, + layout_labels=layout_labels, + ) + img_copy.save(str(layout_image_path) + ".png") + logger.info(f"Saved layout visualization (placeholder): {layout_image_path}.png") + + +def _anls_metric_str( + predictions: list[list[str]], gold_labels: list[list[str]], tau=0.5, rank=0 +): + res = [] + for i, (preds, golds) in enumerate(zip(predictions, gold_labels)): + max_s = 0 + for pred in preds: + for gold in golds: + dis = td.levenshtein.distance(pred.lower(), gold.lower()) + max_len = max(len(pred), len(gold)) + if max_len == 0: + s = 0 + else: + nl = dis / max_len + s = 1 - nl if nl < tau else 0 + max_s = max(s, max_s) + res.append(max_s) + return res, sum(res) / len(res) + + +def _compute_qa_stats(split_reader, split_name): + """Compute QA statistics for a given dataset split.""" + import tqdm + + total_questions = 0 + total_answers_found = 0 + all_extracted_answers = [] + all_gold_answers = [] + + for sample in tqdm.tqdm(split_reader, f"Computing QA stats for {split_name}..."): + # Extract annotations + words = sample.content.words if sample.content else [] + annotations = _extract_annotations(sample=sample) + for qa_pair in annotations["qa_pairs"]: + total_questions += 1 + extracted_answers = [] + for ans_start, ans_end in zip(qa_pair.answer_start, qa_pair.answer_end): + if ans_start != -1 and ans_end != -1: + extracted_answers.append( + " ".join(words[ans_start : ans_end + 1]) + if ans_start != -1 and ans_end != -1 + else "" + ) + if len(extracted_answers) > 0: + total_answers_found += 1 + all_extracted_answers.append(extracted_answers) + all_gold_answers.append(qa_pair.answer_text) + + logger.info(f"{split_name} - total_questions: {total_questions}") + logger.info(f"{split_name} - total_answers_found: {total_answers_found}") + + if total_questions > 0: + logger.info("Computing ANLS metric...") + logger.info("First 10 extracted answers:\n%s", all_extracted_answers[:50]) + logger.info("First 10 gold answers:\n%s", all_gold_answers[:50]) + _, anls = _anls_metric_str(all_extracted_answers, all_gold_answers) + logger.info(f"{split_name} - anls: {anls}") diff --git a/docgenie/data/_transforms/__init__.py b/docgenie/data/_transforms/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..2ccc382d21fc34473bf968ba936eba3b5a08718d --- /dev/null +++ b/docgenie/data/_transforms/__init__.py @@ -0,0 +1,11 @@ +from ._tokenizers._document_processors import ( + QuestionAnsweringDocumentProcessor, + SequenceClassificationDocumentProcessor, + TokenClassificationDocumentProcessor, +) + +__all__ = [ + "SequenceClassificationDocumentProcessor", + "TokenClassificationDocumentProcessor", + "QuestionAnsweringDocumentProcessor", +] diff --git a/docgenie/data/_transforms/_generics/_base.py b/docgenie/data/_transforms/_generics/_base.py new file mode 100755 index 0000000000000000000000000000000000000000..d5ec29fa2a1977cc44dacd0664c52b869256b0a7 --- /dev/null +++ b/docgenie/data/_transforms/_generics/_base.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from typing import Generic, TypeVar + 
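+# This module holds the minimal transform interface used by the concrete transforms in
+# this package: ToRGB normalises images to three channels, and BaseTransform is the
+# generic pydantic base class that the document/image processors subclass.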
+import torch +from atria_core.utilities.repr import RepresentationMixin +from PIL.Image import Image as PILImage +from pydantic import BaseModel + +from docgenie.logging import get_logger + +logger = get_logger(__name__) + +T = TypeVar("T") + + +class ToRGB(object): + def __call__(self, image: PILImage | torch.Tensor) -> PILImage | torch.Tensor: + if isinstance(image, torch.Tensor): + if image.shape[0] == 3: + return image + return image.repeat(3, 1, 1) + else: + return image.convert("RGB") + + +class BaseTransform(RepresentationMixin, BaseModel, Generic[T]): + def get_output_data_model(self) -> type[T]: + raise NotImplementedError + + def __call__(self, *args, **kwargs) -> T | list[T]: + raise NotImplementedError diff --git a/docgenie/data/_transforms/_generics/_hf_processor.py b/docgenie/data/_transforms/_generics/_hf_processor.py new file mode 100755 index 0000000000000000000000000000000000000000..3dae2d7a7f78af97b362aa364933b520332160db --- /dev/null +++ b/docgenie/data/_transforms/_generics/_hf_processor.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import inspect +from typing import Any + +from pydantic import Field +from transformers import ( + AutoProcessor, + BatchEncoding, + BertTokenizerFast, + RobertaTokenizerFast, +) + +from docgenie.data._transforms._generics._base import BaseTransform + +# add custom models +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +class HuggingfaceProcessor(BaseTransform[BatchEncoding]): + _TOKENIZERS_REQUIRING_SPLIT_TEXT = (BertTokenizerFast, RobertaTokenizerFast) + + tokenizer_name: str = "microsoft/layoutlmv3-base" + init_kwargs: dict = Field(default_factory=dict) + call_kwargs: dict = Field(default_factory=dict) + cache_dir: str = "./cache" + overflow_sampling: str = "return_all" + + @property + def tokenizer(self): + return ( + self._hf_processor.tokenizer + if hasattr(self._hf_processor, "tokenizer") + else self._hf_processor + ) + + @property + def all_special_ids(self) -> set[int]: + return set(self.tokenizer.all_special_ids) + + def model_post_init(self, context) -> None: + assert self.overflow_sampling in [ + "return_all", + "return_random_n", + "no_overflow", + "return_first_n", + ], f"Overflow sampling strategy {self.overflow_sampling} is not supported." 
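+        # "no_overflow" disables return_overflowing_tokens in the default call kwargs
+        # below; the remaining strategies ("return_all", "return_first_n",
+        # "return_random_n") presumably control how the overflow windows produced by
+        # the tokenizer are sub-sampled further downstream.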
+ + self._hf_processor = self._initialize_transform() + + def _get_default_call_kwargs(self): + return { + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 512, + "stride": 0, + "pad_to_multiple_of": 8, + "is_split_into_words": True, + "return_overflowing_tokens": self.overflow_sampling + != "no_overflow", # set some arguments that we need to stay fixed for our case + "return_token_type_ids": None, + "return_attention_mask": True, + "return_special_tokens_mask": False, + "return_offsets_mapping": False, + "return_length": False, + "return_tensors": "pt", + "verbose": True, + } + + def _initialize_transform(self): + processor = AutoProcessor.from_pretrained( + self.tokenizer_name, + cache_dir=self.cache_dir, + local_files_only=False, + apply_ocr=False, + add_prefix_space=True, + do_lower_case=True, + do_normalize=False, + do_resize=False, + do_rescale=False, + **self.init_kwargs, + ) + + self.call_kwargs = {**self._get_default_call_kwargs(), **self.call_kwargs} + self._possible_args = inspect.signature(processor.__call__).parameters + for key in list(self.call_kwargs.keys()): + if key not in self._possible_args: + logger.warning( + f"Invalid keyword argument '{key}' found in call_kwargs for {self.__class__.__name__}. Skipping it." + ) + self.call_kwargs.pop(key) + return processor + + def get_config(self): + return { + "tokenizer_name": self.tokenizer_name, + "init_kwargs": self.init_kwargs, + "call_kwargs": self.call_kwargs, + } + + def get_output_data_model(self) -> type[BatchEncoding]: + return BatchEncoding + + def _convert_text_to_list(self, text: Any) -> list[str]: + if isinstance(text, str): + return text.split() + elif isinstance(text, list): + return text + else: + raise ValueError("Input text must be a string or a list of strings.") + + def __call__(self, **inputs) -> BatchEncoding: + if isinstance(self.tokenizer, self._TOKENIZERS_REQUIRING_SPLIT_TEXT): + text = inputs.get("text", None) + text_pair = inputs.get("text_pair", None) + + if text is not None and text_pair is not None: + inputs["text"] = self._convert_text_to_list(text) + inputs["text_pair"] = self._convert_text_to_list(text_pair) + + assert isinstance(inputs["text"], list), ( + "Input 'text' must be a list of strings." + ) + assert isinstance(inputs["text_pair"], list), ( + "Input 'text_pair' must be a list of strings." 
+ ) + filtered_inputs = {k: v for k, v in inputs.items() if k in self._possible_args} + return self._hf_processor(**filtered_inputs, **self.call_kwargs) diff --git a/docgenie/data/_transforms/_generics/_image_processor.py b/docgenie/data/_transforms/_generics/_image_processor.py new file mode 100755 index 0000000000000000000000000000000000000000..ae6a24ef8ae438f32801e7d97eefd508b0bde2f2 --- /dev/null +++ b/docgenie/data/_transforms/_generics/_image_processor.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from PIL.Image import Image as PILImage + +from docgenie.data._transforms._generics._base import BaseTransform, ToRGB +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +class ImageProcessor(BaseTransform[PILImage]): + do_normalize: bool = True # Normalize the image to ImageNet mean and std + do_resize: bool = True # Resize the image to 224x224 + use_imagenet_mean_std: bool = False + resize_height: int = 224 + resize_width: int = 224 + image_mean: list[float] | None = None + image_std: list[float] | None = None + + def model_post_init(self, context) -> None: + from transformers.utils.constants import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + IMAGENET_STANDARD_MEAN, + IMAGENET_STANDARD_STD, + ) + + self.image_mean = self.image_mean or IMAGENET_STANDARD_MEAN + self.image_std = self.image_std or IMAGENET_STANDARD_STD + if self.use_imagenet_mean_std: + self.image_mean = IMAGENET_DEFAULT_MEAN + self.image_std = IMAGENET_DEFAULT_STD + + # prepare image transform + self._transform = self._prepare_image_transform() + + def _prepare_image_transform(self): + from torchvision.transforms import Compose, Normalize, Resize, ToTensor + + transform = [ToRGB(), ToTensor()] + if self.do_resize: + transform += [ + Resize( + (self.resize_height, self.resize_width), + interpolation=2, # type: ignore[attr-defined] + antialias=True, # type: ignore[attr-defined] + ), + ] + if self.do_normalize: + transform += [ + Normalize(mean=self.image_mean, std=self.image_std), + ] + transform = Compose(transform) + return transform + + def get_output_data_model(self) -> type[PILImage]: + return PILImage + + def __call__(self, image: PILImage) -> PILImage: + return self._transform(image) diff --git a/docgenie/data/_transforms/_tokenizers/_conditional_generation.py b/docgenie/data/_transforms/_tokenizers/_conditional_generation.py new file mode 100755 index 0000000000000000000000000000000000000000..124cba1e8e9c9ecaadcf22d5b475c0ac8d224b06 --- /dev/null +++ b/docgenie/data/_transforms/_tokenizers/_conditional_generation.py @@ -0,0 +1,547 @@ +from __future__ import annotations + +import json + +from docgenie.data._transforms._tokenizers._document_processors import BaseTransform +from docgenie.data._transforms._tokenizers._udop_processor import CustomUdopProcessor +from docgenie.logging import get_logger + +from ..._core._data_types import ( + AnnotatedObjectList, + ConditionalGenerationModelInput, + DatasetLabels, + DocumentInstance, + ExtractiveQAPair, + Label, + LabelList, +) +from ..._core._utilities import TaskType +from ._utilities import _extract_annotations + +logger = get_logger(__name__) + + +class ConditionalGenerationTokenizer(BaseTransform): + task_type: TaskType + tokenizer_name: str = "microsoft/udop-large" + tokenizer_cache_dir: str = "./cache" + is_training: bool = True + generate_entity_vocabulary: bool = True + dataset_labels: DatasetLabels + + def get_output_data_model(self) -> type: + return ConditionalGenerationModelInput + + def model_post_init(self, 
context) -> None: + from transformers import AutoProcessor + + self._default_init_kwargs = { + "cache_dir": self.tokenizer_cache_dir, + "local_files_only": False, + "apply_ocr": False, + } + self._default_call_kwargs = { + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 1024, + "stride": 0, + "pad_to_multiple_of": 8, + "return_tensors": "pt", + } + if self.task_type == TaskType.token_classification: + self._default_call_kwargs["return_overflowing_tokens"] = True + self._default_call_kwargs["return_offsets_mapping"] = True + self._default_call_kwargs["stride"] = 128 + self._default_call_kwargs["max_length"] = 512 + self._processor = CustomUdopProcessor.from_pretrained( + self.tokenizer_name, + **self._default_init_kwargs, + clean_up_tokenization_spaces=False, + ) + else: + self._processor = AutoProcessor.from_pretrained( + self.tokenizer_name, **self._default_init_kwargs + ) + self._tokenizer = ( + self._processor.tokenizer + if hasattr(self._processor, "tokenizer") + else self._processor + ) + + # if self.task_type == TaskType.token_classification and self.generate_entity_vocabulary: + # possible_labels = ( + # self.dataset_labels.ser + # if self.dataset_labels.ser is not None + # else [] + # ) + # possible_labels = [f"<{lbl}>" for lbl in possible_labels] + # num_added_tokens = self._tokenizer.add_special_tokens({"additional_special_tokens": possible_labels}) + # logger.info(f"Added {num_added_tokens} special tokens for entity labels: {possible_labels}") + + def _get_common_kwargs(self, document_instance: DocumentInstance) -> tuple: + # get pil image from the document instance + image = document_instance.image.load().content.convert("RGB") + + # get words from the document instance + words = ( + document_instance.content.words + if document_instance.content is not None + else [] + ) + + # get bounding boxes from the document instance + boxes = ( + document_instance.content.word_bboxes.value + if document_instance.content is not None + else [] + ) + + return image, words, boxes + + def _prepare_instances_for_sequence_classification( + self, document_instance: DocumentInstance, label: Label + ) -> ConditionalGenerationModelInput: + import torch + + possible_labels = ( + self.dataset_labels.classification + if self.dataset_labels.classification is not None + else [] + ) + image, words, boxes = self._get_common_kwargs(document_instance) + prompt = f"Document Classification. Classify the document into one of these categories: {', '.join(possible_labels)}. 
Document: " + target_text = label.name + + if not words: + # Supply a dummy token and box so UDOP doesn't crash + words = ["None"] + boxes = [[0, 0, 0, 0]] + + tokenized_instance = {} + if self.tokenizer_name == "microsoft/udop-large": + tokenized_instance = self._processor( + image, prompt, text_pair=words, boxes=boxes, **self._default_call_kwargs + ) + elif self.tokenizer_name in ["google-t5/t5-large", "google-t5/t5-base"]: + tokenized_instance = self._processor( + prompt, text_pair=" ".join(words), **self._default_call_kwargs + ) + + for key, value in tokenized_instance.items(): + tokenized_instance[key] = value.squeeze(0) + + # # for debugging decode the input ids + # decoded_input = self._processor.decode(tokenized_instance['input_ids'], skip_special_tokens=True) + # print('Decoded input:', decoded_input) + + # Tokenize target text to get target_token_ids + target_token_ids = self._tokenizer.encode( # this takes text but returns a batch, truly a garbage design + target_text, + add_special_tokens=True, + return_tensors="pt", + max_length=16, + truncation=True, + padding="max_length", + )[0] + + # decoded_target_text = self._processor.decode(target_token_ids, skip_special_tokens=True) + # print('Decoded target_text:', decoded_target_text) + + # Set padding token IDs to -100 to ignore in loss computation + target_token_ids[target_token_ids == 0] = -100 + + return ConditionalGenerationModelInput( + **tokenized_instance, + index=torch.tensor(document_instance.index), + sample_id=document_instance.sample_id, + words=words, + target_text=target_text, + target_token_ids=target_token_ids, + _tokenizer_name=self.tokenizer_name, + _tokenizer_init_kwargs=self._default_init_kwargs, + ) + + def _generate_target_text_for_token_classification( + self, + words: list[str], + word_labels: list[str], + target_text_type: str = "key_value_pairs", + ) -> str: + # entities = {} + + target_text = "" + for word_idx, (word, word_label) in enumerate( + zip(words, word_labels, strict=True) + ): + target_text += f"{word} {word_label} " + target_text = target_text.strip() + return target_text + + # if word_label == "O" or word_label.startswith("I-"): + # continue + + # if word_label.startswith("B-"): + # entity_words = [word] + # for next_word, next_label in zip( + # words[word_idx + 1 :], word_labels[word_idx + 1 :] + # ): + # if next_label == f"I-{word_label[2:]}": + # entity_words.append(next_word) + # else: + # break + + # if word_label[2:] not in entities: + # entities[word_label[2:]] = [] + # entities[word_label[2:]].append(" ".join(entity_words)) + + if len(entities) == 0: + return None + + if target_text_type == "csv": + lines = [] + separator = "|" + for key, values in entities.items(): + for value in values: + line = f"{key}={value}{separator}" + lines.append(line) + lines[-1] = lines[-1].rstrip(f"{separator}") # remove sep from last line + return "".join(lines) + elif target_text_type == "json": + return json.dumps(entities) + else: + raise NotImplementedError( + f"Target text type {target_text_type} not supported." + ) + + def _prepare_instances_for_token_classification( + self, document_instance: DocumentInstance, word_labels: LabelList + ) -> list[ConditionalGenerationModelInput]: + import torch + + image, words, boxes = self._get_common_kwargs(document_instance) + prompt = "Information Extraction. 
Extract all the entities present in this document: Document: " + + if not words: + words = ["None"] + boxes = [[0, 0, 0, 0]] + word_labels.name = ["O"] + + features, encoded_batch = None, None + if self.tokenizer_name == "microsoft/udop-large": + features, encoded_batch = self._processor( + image, prompt, text_pair=words, boxes=boxes, **self._default_call_kwargs + ) + elif self.tokenizer_name in ["google-t5/t5-large", "google-t5/t5-base"]: + raise NotImplementedError( + "Token classification not implemented for T5 models yet." + ) + + sequence_ids = [] + word_ids = [] + for i in range(len(encoded_batch["input_ids"])): + sequence_ids_per_overflow = encoded_batch.sequence_ids(i) + word_ids_per_overflow = encoded_batch.word_ids(i) + + # filter sequence_ids + sequence_ids_per_overflow = [ + -100 if x is None else x for x in sequence_ids_per_overflow + ] + word_ids_per_overflow = [ + -100 if x is None else x for x in word_ids_per_overflow + ] + if max(sequence_ids_per_overflow) > 0: + word_ids_per_overflow = [ + -100 if sequence_id == 0 else word_id + for word_id, sequence_id in zip( + word_ids_per_overflow, sequence_ids_per_overflow + ) + ] + sequence_ids.append(sequence_ids_per_overflow) + word_ids.append(word_ids_per_overflow) + + sequence_ids = torch.tensor(sequence_ids) + word_ids = torch.tensor(word_ids) + + # to compare the targets we need to know where the start of next overlfow sequence is after stride + last_max_word_id = -1 + instances = [] + for overflow_idx in range(len(encoded_batch["input_ids"])): + # find min max word ids + input_ids_per_per_overflow = encoded_batch["input_ids"][overflow_idx] + word_ids_per_overflow = word_ids[overflow_idx] + min_word_id = min( + [wid for wid in word_ids_per_overflow.tolist() if wid != -100] + ) + max_word_id = max( + [wid for wid in word_ids_per_overflow.tolist() if wid != -100] + ) + + # words in this overflow + words_in_this_overflow = words[min_word_id : max_word_id + 1] + word_labels_in_this_overflow = word_labels.name[ + min_word_id : max_word_id + 1 + ] + + target_text = self._generate_target_text_for_token_classification( + words=words_in_this_overflow, + word_labels=word_labels_in_this_overflow, + target_text_type="csv", + ) + + if target_text is None: + continue + + target_token_ids = self._tokenizer.encode( + target_text, + add_special_tokens=True, + return_tensors="pt", + max_length=1024, + truncation=True, + padding="max_length", + )[0] + + # word labels after stride + word_to_extract_in_this_overflow = words[ + last_max_word_id + 1 : max_word_id + 1 + ] + word_labels_to_extract_in_this_overflow = word_labels.name[ + last_max_word_id + 1 : max_word_id + 1 + ] + # decoded_target_text = tokenizer.decode(target_token_ids, skip_special_tokens=True) + # decoded_input_text = tokenizer.decode(input_ids_per_overflow, skip_special_tokens=True) + last_max_word_id = max_word_id + + # index: Optional["torch.Tensor"] = None + # sample_id: Optional[str] = None + # input_ids: Optional["torch.Tensor"] = None + # bbox: Optional["torch.Tensor"] = None + # attention_mask: Optional["torch.Tensor"] = None + # pixel_values: Optional["torch.Tensor"] = None + # question_text: Optional[str] = None + # target_text: Optional[str] = None + # target_token_ids: Optional["torch.Tensor"] = None + # words: Optional[list[str]] = None + # word_labels: Optional[list[str]] = None + # label: Optional["torch.Tensor"] = None + # _tokenizer_name: Optional[str] = None + # _tokenizer_init_kwargs: Optional[dict] = None + + # Set padding token IDs to -100 to ignore in loss 
computation + target_token_ids[target_token_ids == 0] = -100 + + instance = ConditionalGenerationModelInput( + index=torch.tensor(document_instance.index), + sample_id=document_instance.sample_id, + input_ids=input_ids_per_per_overflow, + attention_mask=features["attention_mask"][overflow_idx], + pixel_values=features["pixel_values"][overflow_idx], + bbox=features["bbox"][overflow_idx], + words=word_to_extract_in_this_overflow, + word_labels=word_labels_to_extract_in_this_overflow, + target_text=target_text, + target_token_ids=target_token_ids, + _tokenizer_name=self.tokenizer_name, + _tokenizer_init_kwargs=self._default_init_kwargs, + ) + instances.append(instance) + + if self.is_training: + random_index = int(torch.randint(0, len(instances), (1,)).item()) + return instances[random_index] + + return instances + + def _prepare_instances_for_question_answering( + self, document_instance: DocumentInstance, qa_pairs: list[ExtractiveQAPair] + ) -> list[ConditionalGenerationModelInput]: + import torch + + image, words, boxes = self._get_common_kwargs(document_instance) + + instances = [] + for qa_pair in qa_pairs: + # since we can have multiple answers per question, we need to handle that here and just take one which is not + # -1 # we don't need to remove no answer indices in conditional generation setting as we always have the answer anyway + # word_ans_start, word_ans_end = -1, -1 + # for ans_start, ans_end in zip(qa_pair.answer_start, qa_pair.answer_end): + # if ans_start != -1 and ans_end != -1: + # word_ans_start = ans_start + # word_ans_end = ans_end + # break + + # if word_ans_start == -1 or word_ans_end == -1: + # if self.is_training: + # logger.warning(f"Skipping QA pair with no answer during training: {qa_pair}") + # continue + + prompt = f"Question answering. 
{qa_pair.question_text}" + target_text = qa_pair.answer_text[0] + + tokenized_instance = {} + if self.tokenizer_name == "microsoft/udop-large": + tokenized_instance = self._processor( + image, + prompt, + text_pair=words, + boxes=boxes, + **self._default_call_kwargs, + ) + elif self.tokenizer_name in ["google-t5/t5-large", "google-t5/t5-base"]: + tokenized_instance = self._processor( + prompt, text_pair=" ".join(words), **self._default_call_kwargs + ) + + for key, value in tokenized_instance.items(): + tokenized_instance[key] = value.squeeze(0) + + # # # for debugging decode the input ids + # decoded_input = self._processor.decode(tokenized_instance['input_ids'], skip_special_tokens=True) + # print('Decoded input:', decoded_input) + + # Tokenize target text to get target_token_ids + target_token_ids = self._tokenizer.encode( # this takes text but returns a batch, truly a garbage design + target_text, + add_special_tokens=True, + return_tensors="pt", + max_length=128, + truncation=True, + padding="max_length", + )[0] + + # decoded_target_text = self._processor.decode(target_token_ids, skip_special_tokens=True) + # print('Decoded target_text:', decoded_target_text) + + # Set padding token IDs to -100 to ignore in loss computation + target_token_ids[target_token_ids == 0] = -100 + + instance = ConditionalGenerationModelInput( + **tokenized_instance, + index=torch.tensor(document_instance.index), + sample_id=document_instance.sample_id, + words=words, + target_text=target_text, + question_text=qa_pair.question_text, + target_token_ids=target_token_ids, + _tokenizer_name=self.tokenizer_name, + _tokenizer_init_kwargs=self._default_init_kwargs, + ) + + instances.append(instance) + + if self.is_training: + random_index = int(torch.randint(0, len(instances), (1,)).item()) + return instances[random_index] + + return instances + + def _prepare_instances_for_layout_analysis( + self, + document_instance: DocumentInstance, + annotated_objects: AnnotatedObjectList, + ) -> str: + possible_labels = ( + self.dataset_labels.layout if self.dataset_labels.layout is not None else [] + ) + image, words, boxes = self._get_common_kwargs(document_instance) + prompt = f"Layout Analysis. Extract the layout entities present in the document into one of these categories: {', '.join(possible_labels)}. 
Document: "
+        # torch is used below but is not imported at module level; import it locally
+        # here, mirroring the other _prepare_* methods.
+        import torch
+
+        bbox_labels_concatenated = []
+        for label, bbox in zip(
+            annotated_objects.label.name,
+            annotated_objects.bbox,
+        ):
+            bbox = [int(x * 1000) for x in bbox]
+            bbox_str = "".join([f"<{coord}>" for coord in bbox])
+            bbox_labels_concatenated.append(f"{bbox_str}<{label}>")
+
+        target_text = ",".join(bbox_labels_concatenated)
+
+        tokenized_instance = {}
+        if self.tokenizer_name == "microsoft/udop-large":
+            tokenized_instance = self._processor(
+                image, prompt, **self._default_call_kwargs
+            )
+        elif self.tokenizer_name in ["google-t5/t5-large", "google-t5/t5-base"]:
+            tokenized_instance = self._processor(
+                prompt, text_pair=" ".join(words), **self._default_call_kwargs
+            )
+
+        for key, value in tokenized_instance.items():
+            tokenized_instance[key] = value.squeeze(0)
+
+        # for debugging decode the input ids
+        # decoded_input = self._processor.decode(tokenized_instance['input_ids'], skip_special_tokens=True)
+        # print('Decoded input:', decoded_input)
+
+        # Tokenize target text to get target_token_ids
+        tokenizer = (
+            self._processor.tokenizer
+            if hasattr(self._processor, "tokenizer")
+            else self._processor
+        )
+        target_token_ids = tokenizer.encode(  # encode() returns a batch even for a single string, hence the [0]
+            target_text,
+            add_special_tokens=True,
+            return_tensors="pt",
+            max_length=256,
+            truncation=True,
+            padding="max_length",
+        )[0]
+
+        # debugging
+        # print(tokenizer.special_tokens_map)
+        # print(tokenizer.additional_special_tokens)
+        # print(tokenizer.additional_special_tokens_ids)
+        # print('target_token_ids', target_token_ids)
+        # for idx, token in enumerate(target_token_ids):
+        #     decoded_token = tokenizer.decode([token.item()])
+        #     print(f'Token ID: {token.item()} -> Decoded Token: "{decoded_token}"')
+        #     if idx > 10:
+        #         break
+        # decoded_input = self._processor.decode(target_token_ids, skip_special_tokens=True)
+        # print('Decoded target_text:', target_text)
+
+        # Set padding token IDs to -100 to ignore in loss computation
+        target_token_ids[target_token_ids == 0] = -100
+
+        return ConditionalGenerationModelInput(
+            **tokenized_instance,
+            index=torch.tensor(document_instance.index),
+            sample_id=document_instance.sample_id,
+            words=words,
+            target_text=target_text,
+            target_token_ids=target_token_ids,
+            image_size=document_instance.image.load().content.size,
+            _tokenizer_name=self.tokenizer_name,
+            _tokenizer_init_kwargs=self._default_init_kwargs,
+        )
+
+    def __call__(
+        self, document_instance: DocumentInstance
+    ) -> ConditionalGenerationModelInput | list[ConditionalGenerationModelInput]:
+        # prepare prompt based on task type
+        annotations = _extract_annotations(document_instance)
+        if self.task_type == TaskType.sequence_classification:
+            return self._prepare_instances_for_sequence_classification(
+                document_instance, annotations.label
+            )
+        elif self.task_type == TaskType.token_classification:
+            return self._prepare_instances_for_token_classification(
+                document_instance, annotations.word_labels
+            )
+        elif self.task_type == TaskType.extractive_qa:
+            return self._prepare_instances_for_question_answering(
+                document_instance, annotations.qa_pairs
+            )
+        elif self.task_type == TaskType.layout_analysis:
+            return self._prepare_instances_for_layout_analysis(
+                document_instance, annotations.annotated_objects
+            )
+        else:
+            raise NotImplementedError(f"Task type {self.task_type} not supported.")
+
+    def __repr__(self) -> str:
+        return f"ConditionalGenerationTokenizer(task_type={self.task_type}, is_training={self.is_training})"
+
+    def __str__(self) -> str:
+        return 
f"ConditionalGenerationTokenizer(task_type={self.task_type}, is_training={self.is_training})" diff --git a/docgenie/data/_transforms/_tokenizers/_document_processors.py b/docgenie/data/_transforms/_tokenizers/_document_processors.py new file mode 100755 index 0000000000000000000000000000000000000000..116c0808d28386e3bcb963b7dd0523b6da5ee55b --- /dev/null +++ b/docgenie/data/_transforms/_tokenizers/_document_processors.py @@ -0,0 +1,296 @@ +from __future__ import annotations + +from dataclasses import replace + +import numpy as np +import torch +from pydantic import ConfigDict, Field + +from docgenie.data._core._data_types import ( + DocumentInstance, + DocumentInstanceModelInput, +) +from docgenie.data._transforms._generics._base import BaseTransform +from docgenie.data._transforms._generics._hf_processor import HuggingfaceProcessor +from docgenie.data._transforms._generics._image_processor import ImageProcessor +from docgenie.logging import get_logger + +from ._utilities import ( + _document_instance_to_hf_processor_inputs, + _extract_annotations, + _generate_qa_token_ids, + _post_process_tokenizer_outputs, +) + +logger = get_logger(__name__) + + +class BaseDocumentProcessor(BaseTransform[DocumentInstanceModelInput]): + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, extra="forbid" + ) + + # tokenizer args + tokenizer_name: str = "microsoft/layoutlmv3-base" + init_kwargs: dict = Field(default_factory=dict) + call_kwargs: dict = Field(default_factory=dict) + overflow_sampling: str = "return_all" + max_overflow_samples: int = 10 + use_segment_level_bboxes: bool = False + cache_dir: str = "./cache" + + # image processor args + do_normalize: bool = True # Normalize the image to ImageNet mean and std + do_resize: bool = True # Resize the image to 224x224 + use_imagenet_mean_std: bool = False + resize_height: int = 224 + resize_width: int = 224 + image_mean: list[float] | None = None + image_std: list[float] | None = None + + # segment-level-rank info args + add_segment_level_info: bool = False + max_segment_num: int = 150 + + def model_post_init(self, context) -> None: + self._hf_processor = HuggingfaceProcessor( + tokenizer_name=self.tokenizer_name, + init_kwargs=self.init_kwargs, + call_kwargs=self.call_kwargs, + overflow_sampling=self.overflow_sampling, + cache_dir=self.cache_dir, + ) + self._image_transform = ImageProcessor( + do_normalize=self.do_normalize, + do_resize=self.do_resize, + use_imagenet_mean_std=self.use_imagenet_mean_std, + resize_height=self.resize_height, + resize_width=self.resize_width, + image_mean=self.image_mean, + image_std=self.image_std, + ) + + def get_output_data_model(self): + return DocumentInstanceModelInput + + def __call__( + self, document_instance: DocumentInstance + ) -> DocumentInstanceModelInput | list[DocumentInstanceModelInput]: + hf_processor_inputs = _document_instance_to_hf_processor_inputs( + document_instance, + use_segment_level_bboxes=self.use_segment_level_bboxes, + image_transform=self._image_transform, + ) + tokenization_data = self._hf_processor(**hf_processor_inputs) + processed_outputs = _post_process_tokenizer_outputs( + tokenization_data=tokenization_data, + input_word_boxes=hf_processor_inputs.get("boxes", None), + input_word_labels=hf_processor_inputs.get("word_labels", None), + input_image=hf_processor_inputs.get("images", None), + add_segment_level_info=self.add_segment_level_info, + all_special_ids=self._hf_processor.tokenizer.all_special_ids, + max_segment_num=self.max_segment_num, + ) + return 
DocumentInstanceModelInput( + index=torch.tensor(document_instance.index) + if document_instance.index is not None + else None, + sample_id=document_instance.sample_id, + words=hf_processor_inputs.pop("text", None), + tokenizer_config=self._hf_processor.get_config(), + **processed_outputs, + ) + + +class SequenceClassificationDocumentProcessor(BaseDocumentProcessor): + def __call__( + self, document_instance: DocumentInstance + ) -> DocumentInstanceModelInput | list[DocumentInstanceModelInput]: + instance = super().__call__(document_instance) + annotations = _extract_annotations(document_instance) + assert annotations.label is not None, "No label found in the document instance." + if isinstance(instance, list): + return [ + replace( + inst, + label=torch.tensor(annotations.label.value), + ) + for inst in instance + ] + return replace( + instance, + label=torch.tensor(annotations.label.value), + ) + + +class TokenClassificationDocumentProcessor(BaseDocumentProcessor): + pass + + +class QuestionAnsweringDocumentProcessor(BaseDocumentProcessor): + ignore_samples_with_no_answer: bool = False + is_training: bool = False + + def model_post_init(self, context) -> None: + # update call kwargs + self.call_kwargs["truncation"] = "only_second" + + super().model_post_init(context) + + def _is_no_answer_sample( + self, token_answer_start, token_answer_end, tokenization_data + ): + total_answers = len(token_answer_start) + for key, value in tokenization_data.items(): + if value is None: + continue + if key == "image": + continue + assert len(value) == total_answers, ( + f"Length mismatch in tokenization data for key {key}. " + f"Expected length: {total_answers}, Actual length: {len(value)}" + ) + + valid_indices = [] + for idx, (s, e) in enumerate(zip(token_answer_start, token_answer_end)): + if s != -1 and e != -1: + valid_indices.append(idx) + + if len(valid_indices) == 0: + return True # skip this sample entirely + + if len(valid_indices) < total_answers: + tokenization_data = { + k: v[valid_indices] if v is not None and k not in ["image"] else v + for k, v in tokenization_data.items() + } + token_answer_start = token_answer_start[valid_indices] + token_answer_end = token_answer_end[valid_indices] + + assert (np.array(token_answer_end) != -1).all(), ( + f"Some end answer indices are -1 in document {token_answer_end}" + ) + assert (np.array(token_answer_start) != -1).all(), ( + f"Some start answer indices are -1 in document {token_answer_start}" + ) + total_answers = len(token_answer_start) + for key, value in tokenization_data.items(): + if value is None: + continue + if key == "image": + continue + assert len(value) == total_answers, ( + f"Length mismatch in tokenization data for key {key}. " + f"Expected length: {total_answers}, Actual length: {len(value)}" + ) + return False + + def __call__( + self, document_instance: DocumentInstance + ) -> DocumentInstanceModelInput | list[DocumentInstanceModelInput]: + qa_pairs = _extract_annotations(document_instance).qa_pairs + assert qa_pairs is not None, "No QA pairs found in the document instance." + assert len(qa_pairs) > 0, "No QA pairs found in the document instance." 
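+        # Each QA pair is tokenized into its own DocumentInstanceModelInput below; the
+        # resulting sample_id carries a "_subsample_<i>" suffix so every instance can be
+        # traced back to its source question.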
+ + transformed_instances = [] + for qa_pair_index in range(len(qa_pairs)): + # prepare model input + hf_processor_inputs = _document_instance_to_hf_processor_inputs( + document_instance, + use_segment_level_bboxes=self.use_segment_level_bboxes, + image_transform=self._image_transform, + context=qa_pairs[qa_pair_index].question_text, + ) + + text_pair = hf_processor_inputs.get("text_pair", None) + boxes = hf_processor_inputs.get("boxes", None) + assert len(text_pair) == len(boxes), ( + f"Length mismatch between text_pair and boxes for sample {document_instance.sample_id}. " + f"Length of text_pair: {len(text_pair)}, Length of boxes: {len(boxes)}" + ) + + tokenization_data = self._hf_processor(**hf_processor_inputs) + processed_outputs = _post_process_tokenizer_outputs( + tokenization_data=tokenization_data, + input_word_boxes=hf_processor_inputs.get("boxes", None), + input_word_labels=hf_processor_inputs.get("word_labels", None), + input_image=hf_processor_inputs.get("images", None), + add_segment_level_info=self.add_segment_level_info, + all_special_ids=self._hf_processor.tokenizer.all_special_ids, + max_segment_num=self.max_segment_num, + ) + + token_answer_start, token_answer_end = _generate_qa_token_ids( + qa_pair=qa_pairs[qa_pair_index], + word_ids=processed_outputs["word_ids"], + sequence_ids=processed_outputs["sequence_ids"], + sequence_length=processed_outputs["token_ids"].shape[-1], + ) + + # if all token_answer_start and token_answer_end are 0, it means we could not find the answer in the context + # therefore using this sample as a training sample will not help the model learn anything + if self.is_training and self.ignore_samples_with_no_answer: + total_answers = len(token_answer_start) + for key, value in processed_outputs.items(): + if value is None: + continue + if key == "image": + continue + assert len(value) == total_answers, ( + f"Length mismatch in tokenization data for key {key}. " + f"Expected length: {total_answers}, Actual length: {len(value)}" + ) + + valid_indices = [] + for idx, (s, e) in enumerate(zip(token_answer_start, token_answer_end)): + if s != -1 and e != -1: + valid_indices.append(idx) + + if len(valid_indices) == 0: + continue # skip this sample entirely + + if len(valid_indices) < total_answers: + processed_outputs = { + k: v[valid_indices] + if v is not None and k not in ["image"] + else v + for k, v in processed_outputs.items() + } + token_answer_start = token_answer_start[valid_indices] + token_answer_end = token_answer_end[valid_indices] + + assert (np.array(token_answer_end) != -1).all(), ( + f"Some end answer indices are -1 in document {token_answer_end}" + ) + assert (np.array(token_answer_start) != -1).all(), ( + f"Some start answer indices are -1 in document {token_answer_start}" + ) + total_answers = len(token_answer_start) + for key, value in processed_outputs.items(): + if value is None: + continue + if key == "image": + continue + assert len(value) == total_answers, ( + f"Length mismatch in tokenization data for key {key}. 
" + f"Expected length: {total_answers}, Actual length: {len(value)}" + ) + + # make sure afterwards we always have one length for all processed outputs + sample_id = document_instance.sample_id + "_subsample_" + str(qa_pair_index) + transformed_instance = DocumentInstanceModelInput( + index=torch.tensor(document_instance.index) + if document_instance.index is not None + else None, + sample_id=sample_id, + words=hf_processor_inputs.pop("text_pair", None), + question_id=qa_pair_index, + qa_question=qa_pairs[qa_pair_index].question_text, + qa_answers=qa_pairs[qa_pair_index].answer_text, + token_answer_start=token_answer_start, + token_answer_end=token_answer_end, + tokenizer_config=self._hf_processor.get_config(), + **processed_outputs, + ) + transformed_instances.append(transformed_instance) + return transformed_instances diff --git a/docgenie/data/_transforms/_tokenizers/_udop_processor.py b/docgenie/data/_transforms/_tokenizers/_udop_processor.py new file mode 100755 index 0000000000000000000000000000000000000000..937cc93aab9d601ab0b5b2f4c94405051607824a --- /dev/null +++ b/docgenie/data/_transforms/_tokenizers/_udop_processor.py @@ -0,0 +1,89 @@ +# verify input +# patch udop processor to return word and sequence ids +from transformers.models.udop.processing_udop import UdopProcessor, UdopProcessorKwargs +from typing import List, Optional, Union + +from transformers import logging + +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput + + + +class CustomUdopProcessor(UdopProcessor): + def __call__( + self, + images: Optional[ImageInput] = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + *args, + audio=None, + videos=None, + **kwargs: Unpack[UdopProcessorKwargs], + ) -> BatchFeature: + output_kwargs = self._merge_kwargs( + UdopProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + **self.prepare_and_validate_optional_call_args(*args), + ) + + boxes = output_kwargs["text_kwargs"].pop("boxes", None) + word_labels = output_kwargs["text_kwargs"].pop("word_labels", None) + text_pair = output_kwargs["text_kwargs"].pop("text_pair", None) + return_overflowing_tokens = output_kwargs["text_kwargs"].get("return_overflowing_tokens", False) + return_offsets_mapping = output_kwargs["text_kwargs"].get("return_offsets_mapping", False) + text_target = output_kwargs["text_kwargs"].get("text_target", None) + + if self.image_processor.apply_ocr and (boxes is not None): + raise ValueError( + "You cannot provide bounding boxes if you initialized the image processor with apply_ocr set to True." + ) + + if self.image_processor.apply_ocr and (word_labels is not None): + raise ValueError( + "You cannot provide word labels if you initialized the image processor with apply_ocr set to True." 
+ ) + + if return_overflowing_tokens and not return_offsets_mapping: + raise ValueError("You cannot return overflowing tokens without returning the offsets mapping.") + + if text_target is not None: + # use the processor to prepare the targets of UDOP + return self.tokenizer( + **output_kwargs["text_kwargs"], + ) + + else: + # use the processor to prepare the inputs of UDOP + # first, apply the image processor + features = self.image_processor(images=images, **output_kwargs["images_kwargs"]) + features_words = features.pop("words", None) + features_boxes = features.pop("boxes", None) + + output_kwargs["text_kwargs"].pop("text_target", None) + output_kwargs["text_kwargs"].pop("text_pair_target", None) + output_kwargs["text_kwargs"]["text_pair"] = text_pair + output_kwargs["text_kwargs"]["boxes"] = boxes if boxes is not None else features_boxes + output_kwargs["text_kwargs"]["word_labels"] = word_labels + + # second, apply the tokenizer + if text is not None and self.image_processor.apply_ocr and text_pair is None: + if isinstance(text, str): + text = [text] # add batch dimension (as the image processor always adds a batch dimension) + output_kwargs["text_kwargs"]["text_pair"] = features_words + + encoded_inputs = self.tokenizer( + text=text if text is not None else features_words, + **output_kwargs["text_kwargs"], + ) + + # add pixel values + if return_overflowing_tokens is True: + features["pixel_values"] = self.get_overflowing_images( + features["pixel_values"], encoded_inputs["overflow_to_sample_mapping"] + ) + features.update(encoded_inputs) + + return features, encoded_inputs \ No newline at end of file diff --git a/docgenie/data/_transforms/_tokenizers/_utilities.py b/docgenie/data/_transforms/_tokenizers/_utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..ec2d3dc2e17c08b723586c79e4c7cf1b260f34f9 --- /dev/null +++ b/docgenie/data/_transforms/_tokenizers/_utilities.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Mapping + +import torch +from transformers import BatchEncoding + +from docgenie.logging import get_logger + +from ..._core._data_types import ( + AnnotatedObjectList, + ClassificationAnnotation, + DocumentInstance, + EntityLabelingAnnotation, + ExtractiveQAAnnotation, + ExtractiveQAPair, + Label, + LabelList, + LayoutAnalysisAnnotation, +) + +logger = get_logger(__name__) + + +@dataclass +class Annotations: + label: Label | None = None + word_labels: LabelList | None = None + qa_pairs: list[ExtractiveQAPair] | None = None + annotated_objects: AnnotatedObjectList | None = None + + +def _document_instance_to_hf_processor_inputs( + document_instance: DocumentInstance, + use_segment_level_bboxes: bool = False, + image_transform: Callable | None = None, + context: str | None = None, +) -> dict[str, Any]: + if document_instance.content is None: + return {} + + inputs = {} + + if context is None: + if document_instance.content.words is not None: + inputs["text"] = document_instance.content.words + else: + qa_pairs = _extract_annotations(document_instance).qa_pairs + assert qa_pairs is not None and len(qa_pairs) > 0, ( + "No QA pairs found in the document instance for extractive QA task." 
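+        # Illustrative note: in the extractive-QA path the question/context string is
+        # passed as `text` and the OCR words as `text_pair`, so downstream code can use
+        # sequence_ids() to tell question tokens (0) from document tokens (1).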
+ ) + inputs["text"] = context + inputs["text_pair"] = document_instance.content.words + + if document_instance.content.word_bboxes is not None: + inputs["boxes"] = ( + document_instance.content.word_segment_level_bboxes.value + if use_segment_level_bboxes + and document_instance.content.word_segment_level_bboxes is not None + else document_instance.content.word_bboxes.value + ) + + if document_instance.image is not None: + inputs["images"] = ( + image_transform(document_instance.image.content) + if image_transform is not None + else document_instance.image.content + ) + + # extract annotations + annotations = _extract_annotations(document_instance) + if annotations.label is not None: + inputs["label"] = annotations.label.value + if annotations.word_labels is not None: + inputs["word_labels"] = annotations.word_labels.value + + return inputs + + +def _extract_sequence_and_word_ids( + tokenization_data: BatchEncoding, +) -> tuple[torch.Tensor, torch.Tensor]: + sequence_ids = [] + word_ids = [] + input_ids = tokenization_data["input_ids"] + num_overflow_samples = len(input_ids) # type: ignore + for i in range(num_overflow_samples): + sequence_ids_per_overflow = tokenization_data.sequence_ids(i) + word_ids_per_overflow = tokenization_data.word_ids(i) + + # filter sequence_ids + sequence_ids_per_overflow = [ + -100 if x is None else x for x in sequence_ids_per_overflow + ] + word_ids_per_overflow = [ + -100 if x is None else x for x in word_ids_per_overflow + ] + if max(sequence_ids_per_overflow) > 0: + word_ids_per_overflow = [ + -100 if sequence_id == 0 else word_id + for word_id, sequence_id in zip( + word_ids_per_overflow, sequence_ids_per_overflow + ) + ] + sequence_ids.append(sequence_ids_per_overflow) + word_ids.append(word_ids_per_overflow) + + sequence_ids = torch.tensor(sequence_ids) + word_ids = torch.tensor(word_ids) + return sequence_ids, word_ids + + +def _extract_token_bboxes_from_word_bboxes( + word_bboxes: list[list[float]], word_ids: torch.Tensor +) -> torch.Tensor: + token_bboxes = [] + for word_ids_per_sample in word_ids: + token_bboxes_per_sample = [ + [0, 0, 0, 0] if word_id == -100 else word_bboxes[word_id] + for word_id in word_ids_per_sample.tolist() + ] + token_bboxes.append(token_bboxes_per_sample) + return torch.tensor(token_bboxes) + + +def _extract_token_labels_from_word_labels( + word_labels: list[int], word_ids: Any +) -> torch.Tensor: + token_labels = [] + for word_ids_per_sample in word_ids: + token_labels_per_sample = [] + last_word_id = None + for word_id in word_ids_per_sample.tolist(): + if word_id == -100 or word_id == last_word_id: + token_labels_per_sample.append(-100) # padding label + else: + token_labels_per_sample.append(word_labels[word_id]) + last_word_id = word_id + token_labels.append(token_labels_per_sample) + return torch.tensor(token_labels) + + +def _extract_segment_level_data( + token_ids: torch.Tensor, + token_bboxes: torch.Tensor, + all_special_ids: set[int], + max_segment_num: int = 150, +) -> Mapping[str, Any]: + segment_index = _generate_segment_level_bbox_ranks( + token_ids=token_ids, + segment_level_bboxes=token_bboxes, + all_special_ids=all_special_ids, + ) + segment_inner_token_rank = _generate_segment_level_inner_ranks( + line_rank_id=segment_index + ) + first_token_idxes, first_token_idxes_mask = _generate_first_token_idxes( + line_rank_id=segment_index, max_segment_num=max_segment_num + ) + return { + "segment_index": segment_index, + "segment_inner_token_rank": segment_inner_token_rank, + "first_token_idxes": first_token_idxes, 
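+        # (assumed shapes: segment_index / segment_inner_token_rank follow token_ids at
+        #  [num_chunks, seq_len]; first_token_idxes and its mask are zero-padded or
+        #  truncated to [num_chunks, max_segment_num])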
+ "first_token_idxes_mask": first_token_idxes_mask, + } + + +def _post_process_tokenizer_outputs( + tokenization_data: BatchEncoding, + input_word_boxes: list[list[float]] | None, + input_word_labels: list[int] | None, + input_image: Any | None, + add_segment_level_info: bool = False, + all_special_ids: set[int] = set(), + max_segment_num: int = 150, +) -> Mapping[str, Any]: + sequence_ids, word_ids = _extract_sequence_and_word_ids(tokenization_data) + token_bboxes = tokenization_data.get("bbox", None) + if token_bboxes is None and input_word_boxes is not None: + token_bboxes = _extract_token_bboxes_from_word_bboxes( + input_word_boxes, word_ids + ) + token_labels = tokenization_data.get("labels", None) + if token_labels is None and input_word_labels is not None: + token_labels = _extract_token_labels_from_word_labels( + input_word_labels, word_ids + ) + image = tokenization_data.get("pixel_values", None) + if image is not None: + image = image[0] + if image is None and input_image is not None: + image = input_image + + outputs = { + "token_ids": tokenization_data.get("input_ids"), + "attention_mask": tokenization_data.get("attention_mask"), + "token_bboxes": token_bboxes, + "token_type_ids": tokenization_data.get("token_type_ids", None), + "token_labels": token_labels, + "sequence_ids": sequence_ids, + "word_ids": word_ids, + "image": image, + } + + if add_segment_level_info: + segment_level_data = _extract_segment_level_data( + token_ids=outputs["token_ids"], + token_bboxes=outputs["token_bboxes"], + all_special_ids=all_special_ids, + max_segment_num=max_segment_num, + ) + outputs.update(segment_level_data) + + # assert that we have all the keys + assert outputs["token_ids"] is not None, ( + "token_ids is None in the tokenizer outputs." + ) + assert outputs["attention_mask"] is not None, ( + "attention_mask is None in the tokenizer outputs." + ) + assert outputs["token_bboxes"] is not None, ( + "token_bboxes is None in the tokenizer outputs." + ) + assert outputs["sequence_ids"] is not None, ( + "sequence_ids is None in the tokenizer outputs." + ) + assert outputs["word_ids"] is not None, "word_ids is None in the tokenizer outputs." + assert outputs["image"] is not None, "image is None in the tokenizer outputs." + if input_word_labels is not None: + assert outputs["token_labels"] is not None, ( + "token_labels is None in the tokenizer outputs." 
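+        # (token_labels is only required when word-level labels were supplied, e.g. for
+        #  entity labelling; QA and classification samples may legitimately leave it None)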
+ ) + + return outputs + + +def _get_subword_start_end(word_start, word_end, word_ids, sequence_ids): + start_of_context = -1 + for i in range(len(sequence_ids)): + if sequence_ids[i] == 1: + start_of_context = i + break + num_question_tokens = start_of_context + assert start_of_context != -1, "Could not find the start of the context" + subword_start = -1 + subword_end = -1 + for i in range(start_of_context, len(word_ids)): + if word_start == word_ids[i] and subword_start == -1: + subword_start = i + if word_end == word_ids[i]: + subword_end = i + return subword_start, subword_end, num_question_tokens + + +def _generate_qa_token_ids( + qa_pair: ExtractiveQAPair, + word_ids: torch.Tensor, + sequence_ids: torch.Tensor, + sequence_length: int = 512, +) -> tuple[torch.Tensor, torch.Tensor]: + import torch + + # since we can have multiple answers per question, we need to handle that here and just take one which is not + # -1 + word_ans_start, word_ans_end = -1, -1 + for ans_start, ans_end in zip(qa_pair.answer_start, qa_pair.answer_end): + if ans_start != -1 and ans_end != -1: + word_ans_start = ans_start + word_ans_end = ans_end + break + + # now we have one answer, with start and end indices in the word level + # we need to convert them to token level + + token_answer_starts, token_answer_ends = [], [] + for word_ids_per_overflow, sequence_ids_per_overflow in zip( + word_ids, sequence_ids, strict=True + ): + token_answer_start, token_answer_end = None, None + if word_ans_start == -1: + token_answer_start = -1 + token_answer_end = -1 + else: + (token_answer_start, token_answer_end, _) = _get_subword_start_end( + word_ans_start, + word_ans_end, + word_ids_per_overflow, + sequence_ids_per_overflow, + ) + if token_answer_start == -1: + token_answer_start = -1 + token_answer_end = -1 + if token_answer_end == -1: + token_answer_end = sequence_length - 1 + assert token_answer_end >= token_answer_start, ( + "End token index is less than start token index. " + "Something is wrong in the conversion from answer word indices to answer token indices." 
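+            # (if the answer's end word falls outside this overflow chunk, token_answer_end
+            #  was clamped to sequence_length - 1 above, so the invariant also holds for
+            #  truncated answers)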
+ ) + token_answer_starts.append(token_answer_start) + token_answer_ends.append(token_answer_end) + token_answer_start = torch.tensor( + token_answer_starts, dtype=torch.long, device=word_ids.device + ) + token_answer_end = torch.tensor( + token_answer_ends, dtype=torch.long, device=word_ids.device + ) + return token_answer_start, token_answer_end + + +def _extract_annotations(sample: DocumentInstance) -> Annotations: + """Extract annotations from sample.""" + annotations = Annotations() + + if sample.annotations is not None: + for ann in sample.annotations: + if isinstance(ann, ClassificationAnnotation): + annotations.label = ann.label + elif isinstance(ann, ExtractiveQAAnnotation): + annotations.qa_pairs = ann.qa_pairs + elif isinstance(ann, EntityLabelingAnnotation): + annotations.word_labels = ann.word_labels + elif isinstance(ann, LayoutAnalysisAnnotation): + annotations.annotated_objects = ann.annotated_objects + + return annotations + + +def _generate_segment_level_bbox_ranks( + token_ids: torch.Tensor, + segment_level_bboxes: torch.Tensor, + all_special_ids: set[int], +): + import torch + + line_rank_ids = [] + assert len(token_ids) == len(segment_level_bboxes), ( + f"Token ids and segment level bboxes must have the same batch size, Got {len(token_ids)} and {len(segment_level_bboxes)}" + ) + for token_ids_per_sample, bboxes_per_sample in zip( + token_ids, segment_level_bboxes + ): # this is a shape of [batch_size, seq_len, 4] in xyxy format and normalized 0-1000 + assert len(token_ids_per_sample) == len(bboxes_per_sample), ( + "Token ids and segment level bboxes must have the same sequence length" + ) + line_rank_ids_per_sample = [] + line_rank = 0 + last_b = None + for token_id, b in zip(token_ids_per_sample, bboxes_per_sample): + if last_b is not None and not torch.equal(b, last_b): + line_rank += 1 + if token_id in all_special_ids: + line_rank_ids_per_sample.append(0) + else: + line_rank_ids_per_sample.append(line_rank) + last_b = b + line_rank_ids.append(line_rank_ids_per_sample) + + return torch.tensor(line_rank_ids, device=segment_level_bboxes.device) + + +def _generate_segment_level_inner_ranks(line_rank_id: torch.Tensor): + # line_inner_rank_id is the inner rank as follows 1 means start 2 for all middle tokens 3 for end token ... for each token in the line/segment. + # if there is no middle token, start token will be 1 and end token will be 3. 
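+    # Worked example (illustrative, single sample): line_rank_id = [0, 1, 1, 1, 2, 2, 0]
+    # yields spans (0,0), (1,3), (4,5), (6,6) and inner ranks [1, 1, 2, 3, 1, 3, 1]:
+    # single-token segments get 1, two-token segments 1/3, longer segments 1, 2, ..., 3.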
+ inner_ranks = [] + for line_ranks_per_sample in line_rank_id: + inner_ranks_per_sample = torch.zeros_like( + line_ranks_per_sample, device=line_ranks_per_sample.device + ) + + line_segment_spans = [] + start_idx = 0 + last_lr = None + for curr_idx, lr in enumerate(line_ranks_per_sample): + if last_lr is not None and lr != last_lr: + line_segment_spans.append((start_idx, curr_idx - 1)) + start_idx = curr_idx + last_lr = lr + line_segment_spans.append( + (start_idx, start_idx) + ) # add the last segment for sep token + + for span in line_segment_spans: + span_start, span_end = span + span_length = span_end - span_start + if span_length == 0: + inner_ranks_per_sample[span_start] = 1 # only one token in the line + elif span_length == 1: + inner_ranks_per_sample[span_start] = 1 # start + inner_ranks_per_sample[span_end] = 3 # end + else: + inner_ranks_per_sample[span_start] = 1 # start + inner_ranks_per_sample[span_start + 1 : span_end] = 2 + inner_ranks_per_sample[span_end] = 3 # end + inner_ranks.append(inner_ranks_per_sample) + return torch.stack(inner_ranks) + + +def _generate_first_token_idxes(line_rank_id: torch.Tensor, max_segment_num: int = 150): + first_token_idxes = [] + first_token_idxes_mask = [] + for line_ranks_per_sample in line_rank_id: + first_token_idxes_per_sample = [] + first_token_idxes_mask_per_sample = [] + last_lr = None + for curr_idx, lr in enumerate(line_ranks_per_sample): + if last_lr is not None and lr != last_lr and lr != 0: + first_token_idxes_per_sample.append(curr_idx) + last_lr = lr + + # make mask + if len(first_token_idxes_per_sample) > max_segment_num: + first_token_idxes_per_sample = first_token_idxes_per_sample[ + :max_segment_num + ] + + first_token_idxes_mask_per_sample = [1] * len(first_token_idxes_per_sample) + [ + 0 + ] * (max_segment_num - len(first_token_idxes_per_sample)) + first_token_idxes_per_sample = first_token_idxes_per_sample + [0] * ( + max_segment_num - len(first_token_idxes_per_sample) + ) + first_token_idxes_mask.append(first_token_idxes_mask_per_sample) + first_token_idxes.append(first_token_idxes_per_sample) + + first_token_idxes = torch.tensor(first_token_idxes, device=line_rank_id.device) + first_token_idxes_mask = torch.tensor( + first_token_idxes_mask, device=line_rank_id.device, dtype=torch.float32 + ) + return first_token_idxes, first_token_idxes_mask diff --git a/docgenie/data/_transforms/mmdet.py b/docgenie/data/_transforms/mmdet.py new file mode 100755 index 0000000000000000000000000000000000000000..5600dafe4931771330f44c66cbe9083656931754 --- /dev/null +++ b/docgenie/data/_transforms/mmdet.py @@ -0,0 +1,345 @@ +from __future__ import annotations + +from collections.abc import Sequence + +import numpy as np +from mmcv.transforms.base import BaseTransform +from mmcv.transforms.utils import cache_randomness +from mmdet.registry import TRANSFORMS +from PIL.Image import Image as PILImage +from pydantic import Field + +from docgenie.data._core._data_types import ( + DocumentInstance, + LayoutAnalysisAnnotation, + MMDetInput, +) +from docgenie.data._transforms._generics._base import ( + BaseTransform as DocGenieBaseTransform, +) +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +@TRANSFORMS.register_module() +class RandomChoiceResize(BaseTransform): + def __init__(self, scales: Sequence[int | tuple], **resize_kwargs) -> None: + super().__init__() + + import mmengine + from mmdet.datasets.transforms import Resize + + if isinstance(scales, list): + self.scales = scales + else: + self.scales = [scales] + 
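+        # Each transform() call picks one entry of self.scales uniformly at random
+        # (multi-scale training); the single Resize instance below is reused with its
+        # `scale` attribute swapped in per call.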
assert mmengine.is_seq_of(self.scales, (tuple, int)) + self.resize = Resize(scale=0, backend="pillow", **resize_kwargs) + + @cache_randomness + def _random_select(self) -> tuple[int, int]: + """Randomly select an scale from given candidates. + + Returns: + (tuple, int): Returns a tuple ``(scale, scale_dix)``, + where ``scale`` is the selected image scale and + ``scale_idx`` is the selected index in the given candidates. + """ + + scale_idx = np.random.randint(len(self.scales)) + scale = self.scales[scale_idx] + return scale, scale_idx + + def transform(self, results: dict) -> dict: + """Apply resize transforms on results from a list of scales. + + Args: + results (dict): Result dict contains the data to transform. + + Returns: + dict: Resized results, 'img', 'gt_bboxes', 'gt_seg_map', + 'gt_keypoints', 'scale', 'scale_factor', 'img_shape', + and 'keep_ratio' keys are updated in result dict. + """ + + target_scale, scale_idx = self._random_select() + self.resize.scale = target_scale + results = self.resize(results) + results["scale_idx"] = scale_idx + return results + + def __repr__(self) -> str: + repr_str = self.__class__.__name__ + repr_str += f"(scales={self.scales}" + repr_str += f", resize={self.resize})" + return repr_str + + +class DocumentInstanceMMDetTransform(DocGenieBaseTransform[MMDetInput]): + train_scale: list[tuple[int, int]] | tuple[int, int] = Field( + default=[ + (480, 1333), + (512, 1333), + (800, 1333), + ], + description="Scale for training images.", + ) + test_scale: list[tuple[int, int]] | tuple[int, int] = Field( + default=(1333, 800), description="Scale for testing images." + ) + is_training: bool = Field( + default=False, description="Whether the transform is used for training." + ) + use_test_time_augmentation: bool = Field( + default=False, description="Whether to use test time augmentation." + ) + use_flip: bool = Field( + default=False, description="Whether to use flip augmentation during testing." + ) + use_fixed_size: bool = Field( + default=False, description="Whether to use fixed size resizing." + ) + fixed_size: int = Field( + default=800, description="Fixed size to resize the shorter side to." 
+ ) + + def get_output_data_model(self) -> type[MMDetInput]: + return MMDetInput + + def model_post_init(self, context) -> None: + import torchvision.transforms as T + + # self._transform = T.Compose([]) + # return + from mmdet.datasets.transforms import ( + LoadAnnotations, + PackDetInputs, + RandomFlip, + Resize, + ) + + if self.is_training: + # from mmcv.transforms import RandomChoiceResize, TestTimeAug + from mmcv.transforms import TestTimeAug + + train_scale = self.train_scale + if isinstance(self.train_scale, tuple): + train_scale = [self.train_scale] + + self._transform = T.Compose( + [ + LoadAnnotations(with_bbox=True, with_mask=False, box_type=None), + Resize(scale=self.fixed_size, keep_ratio=False) + if self.use_fixed_size + else RandomChoiceResize( + scales=train_scale, keep_ratio=self.use_fixed_size is False + ), + *([RandomFlip(prob=0.5)] if self.use_flip else []), + PackDetInputs( + meta_keys=( + "id", + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + "flip", + "flip_direction", + ) + ), + ] + ) + else: + from mmcv.transforms import TestTimeAug + + if self.use_test_time_augmentation: + if isinstance(self.test_scale, tuple): + test_scale = [self.test_scale] + self._transform = T.Compose( + [ + LoadAnnotations(with_bbox=True, with_mask=False, box_type=None), + TestTimeAug( + transforms=[ + [ + RandomChoiceResize( + scales=test_scale, keep_ratio=True + ) + ], + [RandomFlip(prob=0.0), RandomFlip(prob=1.0)], + [ + PackDetInputs( + meta_keys=( + "__key__", + "__index__", + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + "flip", + "flip_direction", + ) + ) + ], + ] + ), + ] + ) + else: + import torchvision.transforms as T + from mmcv.transforms import TestTimeAug + + if isinstance(self.test_scale, list): + test_scale = self.test_scale[0] + else: + test_scale = self.test_scale + + self._transform = T.Compose( + [ + LoadAnnotations(with_bbox=True, with_mask=False, box_type=None), + Resize( + scale=self.fixed_size, + keep_ratio=self.use_fixed_size is False, + ), + PackDetInputs( + meta_keys=( + "__key__", + "__index__", + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + "flip", + "flip_direction", + ) + ), + ] + ) + + def _extract_annotated_objects(self, document_instance: DocumentInstance): + assert document_instance.annotations is not None, ( + f"Document instance must have annotations for {self.__class__} ." + ) + layout_annotations = None + for annotation in document_instance.annotations: + if isinstance(annotation, LayoutAnalysisAnnotation): + layout_annotations = annotation.annotated_objects + break + assert layout_annotations is not None, ( + f"Document instance must have layout annotations for {self.__class__}." + ) + return layout_annotations + + def _get_image(self, document_instance: DocumentInstance) -> PILImage: + assert document_instance.image is not None, ( + "DocumentInstance image must be loaded before applying transforms." + ) + assert isinstance(document_instance.image.content, PILImage), ( + "DocumentInstance image content must be a PIL Image." 
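+        # (the mmdet pipeline below converts this PIL image to a numpy array and reads
+        #  width/height from it, so the image must already be decoded at this point)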
+ ) + return document_instance.image.content + + def _is_valid_bbox( + self, bbox: list[float], image_width: int, image_height: int + ) -> bool: + x1, y1, x2, y2 = bbox + if 0 <= x1 < x2 <= image_width and 0 <= y1 < y2 <= image_height: + return True + if (x2 - x1) > 1 and (y2 - y1) > 1: + return True + return False + + def _unnormalize_bbox( + self, bbox: list[float], image_width: int, image_height: int + ) -> list[float]: + x1, y1, x2, y2 = bbox + return [ + x1 * image_width, + y1 * image_height, + x2 * image_width, + y2 * image_height, + ] + + def _clip_bbox( + self, bbox: list[float], image_width: int, image_height: int + ) -> list[float]: + x1, y1, x2, y2 = bbox + x1 = min(max(x1, 0), image_width - 1) + x2 = min(max(x2, 0), image_width - 1) + y1 = min(max(y1, 0), image_height - 1) + y2 = min(max(y2, 0), image_height - 1) + return [x1, y1, x2, y2] + + def _prepare_instances( + self, document_instance: DocumentInstance, image_width: int, image_height: int + ) -> list[dict]: + annotated_objects = self._extract_annotated_objects(document_instance) + + instances = [] + is_bbox_normalized = annotated_objects.bbox.normalized + for bbox, label, iscrowd in zip( + annotated_objects.bbox.value, + annotated_objects.label.value, + annotated_objects.iscrowd, + strict=True, + ): + if is_bbox_normalized: + bbox = self._unnormalize_bbox(bbox, image_width, image_height) + + # first clip the bbox to be within image bounds, then check validity + bbox = self._clip_bbox(bbox, image_width, image_height) + + if not self._is_valid_bbox(bbox, image_width, image_height): + logger.warning( + f"Invalid bbox {bbox} for image of size ({image_width}, {image_height}) in document instance {document_instance.sample_id}. Skipping this bbox." + ) + continue + + instance = { + "bbox": [float(coord) for coord in bbox], + "bbox_label": label, + "ignore_flag": 1 if iscrowd else 0, + } + instances.append(instance) + + return instances + + def _apply_transforms(self, document_instance: DocumentInstance) -> MMDetInput: + image = self._get_image(document_instance) + output = self._transform( + { + "id": document_instance.sample_id, + "img_id": document_instance.index, + "instances": self._prepare_instances( + document_instance, + image_width=image.width, + image_height=image.height, + ), + "img": np.array(image), + "img_shape": ( + image.height, + image.width, + ), + "ori_shape": ( + image.height, + image.width, + ), + } + ) + output = MMDetInput(**output) + + return output + + def __call__(self, document_instance: DocumentInstance): + return self._apply_transforms(document_instance) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(\n" + f" is_training={self.is_training},\n" + f" use_test_time_augmentation={self.use_test_time_augmentation},\n" + f" transform={self._transform},\n" + f")" + ) diff --git a/docgenie/data/_transforms/utilities.py b/docgenie/data/_transforms/utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..4f4d8ab740687ac000dc82c904b93eb24e6c2a40 --- /dev/null +++ b/docgenie/data/_transforms/utilities.py @@ -0,0 +1,8 @@ +import hashlib +import json + +def generate_transform_hash(kwargs_dict): + """Generate a unique hash for transform kwargs.""" + # Sort the dictionary to ensure consistent hashing + sorted_kwargs = json.dumps(kwargs_dict, sort_keys=True, default=str) + return hashlib.md5(sorted_kwargs.encode()).hexdigest()[:8] diff --git a/docgenie/data/_transforms/vlms/tranforms.py b/docgenie/data/_transforms/vlms/tranforms.py new file mode 100755 index 
0000000000000000000000000000000000000000..dc2f05cded6a9beb2bcdfed56d86a37a7e8ca0f3 --- /dev/null +++ b/docgenie/data/_transforms/vlms/tranforms.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod + +from atria_core.utilities.repr import RepresentationMixin +from pydantic import BaseModel + +from docgenie.data._core._data_types import ( + ConditionalGenerationModelInput, + DatasetLabels, + DocumentInstance, +) +from docgenie.data._transforms._tokenizers._utilities import _extract_annotations +from docgenie.logging import get_logger + +logger = get_logger(__name__) + + +class BaseVLMTokenizer(RepresentationMixin, BaseModel, ABC): + """Base class for VLM tokenizers""" + + tokenizer_name: str = "deepseek-community/deepseek-vl-1.3b-base" + tokenizer_cache_dir: str = "./cache" + is_training: bool = True + dataset_labels: DatasetLabels + + def model_post_init(self, context) -> None: + self._default_init_kwargs = { + "cache_dir": self.tokenizer_cache_dir, + "local_files_only": False, + "apply_ocr": False, + } + self._default_call_kwargs = { + "add_special_tokens": True, + "padding": "max_length", + "truncation": True, + "max_length": 1024, + "stride": 0, + "pad_to_multiple_of": 8, + "return_tensors": "pt", + } + + self._setup_processor() + self._tokenizer = ( + self._processor.tokenizer + if hasattr(self._processor, "tokenizer") + else self._processor + ) + + def _setup_processor(self): + """Setup processor - can be overridden by child classes""" + from transformers import AutoProcessor + + self._processor = AutoProcessor.from_pretrained( + self.tokenizer_name, **self._default_init_kwargs + ) + + def _get_common_kwargs(self, document_instance: DocumentInstance) -> tuple: + """Extract common data from document instance""" + image = document_instance.image.load().content.convert("RGB") + words = ( + document_instance.content.words + if document_instance.content is not None + else [] + ) + boxes = ( + document_instance.content.word_bboxes.value + if document_instance.content is not None + else [] + ) + + if not words: + words = ["None"] + boxes = [[0, 0, 0, 0]] + + return image, words, boxes + + def _tokenize_target(self, target_text: str, max_length: int = 128): + """Common target tokenization logic""" + + target_token_ids = self._tokenizer.encode( + target_text, + add_special_tokens=True, + return_tensors="pt", + max_length=max_length, + truncation=True, + padding="max_length", + )[0] + + # Set padding token IDs to -100 to ignore in loss computation + target_token_ids[target_token_ids == 0] = -100 + return target_token_ids + + @abstractmethod + def _prepare_instances( + self, document_instance: DocumentInstance, annotations + ) -> ConditionalGenerationModelInput: + """Prepare instances for the specific task type""" + pass + + def __call__(self, document_instance: DocumentInstance): + annotations = _extract_annotations(document_instance) + return self._prepare_instances(document_instance, annotations) + + +class SequenceClassificationVLMTokenizer(BaseVLMTokenizer): + """Tokenizer for sequence classification tasks""" + + def _prepare_instances( + self, document_instance: DocumentInstance, annotations + ) -> ConditionalGenerationModelInput: + import torch + + possible_labels = self.dataset_labels.classification or [] + image, words, boxes = self._get_common_kwargs(document_instance) + + prompt = f"Document Classification. Classify the document into one of these categories: {', '.join(possible_labels)}. 
Document: " + target_text = annotations.label.name + + # Tokenize input + if self.tokenizer_name == "microsoft/udop-large": + tokenized_instance = self._processor( + image, prompt, text_pair=words, boxes=boxes, **self._default_call_kwargs + ) + elif self.tokenizer_name in ["google-t5/t5-large", "google-t5/t5-base"]: + tokenized_instance = self._processor( + prompt, text_pair=" ".join(words), **self._default_call_kwargs + ) + + for key, value in tokenized_instance.items(): + tokenized_instance[key] = value.squeeze(0) + + target_token_ids = self._tokenize_target(target_text, max_length=16) + + return ConditionalGenerationModelInput( + **tokenized_instance, + index=torch.tensor(document_instance.index), + sample_id=document_instance.sample_id, + words=words, + target_text=target_text, + target_token_ids=target_token_ids, + _tokenizer_name=self.tokenizer_name, + _tokenizer_init_kwargs=self._default_init_kwargs, + ) diff --git a/docgenie/data/_transforms/vlms/utilities.py b/docgenie/data/_transforms/vlms/utilities.py new file mode 100755 index 0000000000000000000000000000000000000000..d755ee4f4d2d680cf0e8d7ecbb69eb53985849a1 --- /dev/null +++ b/docgenie/data/_transforms/vlms/utilities.py @@ -0,0 +1,63 @@ + + +import json +from docgenie.data._core._utilities import TaskType + + +def _prepare_system_messages(task_type: TaskType, labels: list[str]) -> str: + if task_type == TaskType.sequence_classification: + return f"You are a document classification model. Classify the document into one of the given categories: {json.dumps(labels)}." + + elif task_type == TaskType.token_classification: + return f"You are an information extraction model. Extract all the entities present in this document. Choose from the given entity categories: {json.dumps(labels)}." + + elif task_type == TaskType.extractive_qa: + return "You are a question answering model. Answer the question based on the content of the document." + + elif task_type == TaskType.layout_analysis: + return f""" + You are a layout analysis model. Extract the layout entities present in the document. + Choose from the given layout categories: {json.dumps(labels)}. + Provide the output in the format