Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3
"""
Test script for async document generation API with Google Drive upload.

Tests the complete async workflow with all features enabled:
- Handwriting insertion
- Visual elements (stamps, logos, figures, barcodes, photos)
- OCR processing
- Ground truth verification
- Analysis and debug visualization
- Dataset export
- Google Drive upload

Usage:
    export GOOGLE_DRIVE_TOKEN=...          # Google OAuth access token
    export GOOGLE_DRIVE_REFRESH_TOKEN=...  # Google OAuth refresh token
    python test_async_api.py

The API base URL defaults to http://localhost:8000 and can be overridden
with the ASYNC_API_BASE_URL environment variable. After submitting a job the
script polls its status every POLL_INTERVAL seconds until the job completes
or fails (Ctrl+C stops polling).
"""
import os
import sys
import time

import requests

# Configuration
BASE_URL = os.environ.get("ASYNC_API_BASE_URL", "http://localhost:8000")
POLL_INTERVAL = 10  # seconds between status checks

# Test payload with all features enabled.
#
# Google Drive OAuth credentials are read from the environment instead of
# being written into the source: anything committed to a repository is
# readable by everyone with access to it, so any token that was ever
# committed should be treated as leaked and revoked. An unset variable
# yields None for the corresponding field.
PAYLOAD = {
    "user_id": 123,
    "google_drive_token": os.environ.get("GOOGLE_DRIVE_TOKEN"),
    "google_drive_refresh_token": os.environ.get("GOOGLE_DRIVE_REFRESH_TOKEN"),
    "seed_images": [
        "https://ocr.space/Content/Images/receipt-ocr-original.webp"
    ],
    "prompt_params": {
        "language": "English",
        "doc_type": "business and administrative",
        "gt_type": "Multiple questions about each document, with their answers taken **verbatim** from the document.",
        "gt_format": "{\"<Text of question 1>\": \"<Answer to question 1>\", \"<Text of question 2>\": \"<Answer to question 2>\", ...}",
        "num_solutions": 1,
        "enable_handwriting": True,
        "handwriting_ratio": 0.3,
        "enable_visual_elements": True,
        "visual_element_types": [
            "stamp",
            "logo",
            "figure",
            "barcode",
            "photo",
        ],
        "seed": None,  # Use None for random behavior, or set to integer for reproducibility
        "enable_ocr": True,
        "ocr_language": "en",
        "enable_bbox_normalization": True,
        "enable_gt_verification": True,
        "enable_analysis": True,
        "enable_debug_visualization": True,
        "enable_dataset_export": True,
        "dataset_export_format": "msgpack",
        "output_detail": "dataset",
    },
}
def test_health():
    """Probe the API's /health endpoint.

    Sends ``GET {BASE_URL}/health`` with a 5-second timeout and prints the
    decoded JSON body on success.

    Returns:
        True when the request succeeded with a 2xx status and its body
        decoded and printed; False if any exception occurred along the way
        (the exception is printed).
    """
    separator = "=" * 80
    for banner_line in (separator, "TESTING API HEALTH", separator):
        print(banner_line)

    try:
        reply = requests.get(f"{BASE_URL}/health", timeout=5)
        reply.raise_for_status()
        # Decoding and printing stay inside the try so that a bad body (or a
        # print failure) reports as an unhealthy API instead of raising.
        print(f"β API is healthy: {reply.json()}\n")
    except Exception as exc:
        print(f"β Health check failed: {exc}\n")
        return False
    return True
def submit_async_job():
    """Submit PAYLOAD as an async document-generation job.

    Prints a summary of the enabled features, POSTs PAYLOAD as JSON to
    ``{BASE_URL}/generate/async`` (30-second timeout) and prints the reply.

    Returns:
        The ``request_id`` from the server's JSON reply, or None if the
        request failed, returned a non-2xx status, or the reply had no
        ``request_id``. Errors are printed, never raised.
    """
    params = PAYLOAD["prompt_params"]

    print("=" * 80)
    print("SUBMITTING ASYNC JOB")
    print("=" * 80)
    print("\nConfiguration:")
    print(f" User ID: {PAYLOAD['user_id']}")
    print(f" Seed Images: {len(PAYLOAD['seed_images'])}")
    print(f" Num Solutions: {params['num_solutions']}")
    print(f" Handwriting: {params['enable_handwriting']} (ratio: {params['handwriting_ratio']})")
    print(f" Visual Elements: {params['enable_visual_elements']} (types: {len(params['visual_element_types'])})")
    print(f" OCR: {params['enable_ocr']}")
    print(f" GT Verification: {params['enable_gt_verification']}")
    print(f" Analysis: {params['enable_analysis']}")
    print(f" Debug Viz: {params['enable_debug_visualization']}")
    print(f" Dataset Export: {params['enable_dataset_export']}")
    print(" Google Drive Upload: Yes")
    print()

    try:
        print("β³ Submitting job to /generate/async...")
        response = requests.post(
            f"{BASE_URL}/generate/async",
            json=PAYLOAD,
            timeout=30,
        )
        response.raise_for_status()
        result = response.json()
        request_id = result["request_id"]

        print("\nβ Job submitted successfully!")
        print(f" Request ID: {request_id}")
        # The job is already queued at this point; a missing optional field
        # must not turn a successful submission into a reported failure.
        print(f" Status: {result.get('status', 'N/A')}")
        print(f" Estimated Time: {result.get('estimated_time_minutes', 'N/A')} minutes")
        print(f" Poll URL: {result.get('poll_url', 'N/A')}")
        return request_id
    except requests.exceptions.HTTPError as e:
        print(f"\nβ Job submission failed: {e}")
        # Response.__bool__ returns Response.ok, which is False for exactly
        # the 4xx/5xx responses that raise HTTPError, so compare to None.
        if e.response is not None:
            print(f" Response: {e.response.text}")
        return None
    except Exception as e:
        print(f"\nβ Unexpected error: {e}")
        return None
def poll_job_status(request_id, max_polls=None):
    """Poll a job's status until it completes, fails, or polling stops.

    Queries ``GET {BASE_URL}/jobs/{request_id}/status`` (10-second timeout)
    every POLL_INTERVAL seconds and prints a line only when the reported
    status or progress changes. Errors while fetching or decoding a status
    reply are printed and the poll is retried.

    Args:
        request_id: Job identifier returned by ``submit_async_job``.
        max_polls: Optional upper bound on the number of status requests.
            ``None`` (the default) keeps polling until the job reaches a
            terminal state or the user presses Ctrl+C.

    Returns:
        The status reply dict when the job reports "completed" or "failed";
        ``{"status": "interrupted"}`` if the user pressed Ctrl+C; otherwise
        (``max_polls`` exhausted) the last status reply received, or
        ``{"status": None}`` if no reply was received at all.
    """
    print("\n" + "=" * 80)
    print("CONTINUOUS STATUS POLLING")
    print("=" * 80)
    print(f"Request ID: {request_id}")
    print(f"Polling every {POLL_INTERVAL} seconds...")
    print("Press Ctrl+C to stop polling\n")

    poll_count = 0
    last_status = None
    last_progress = None
    last_data = {"status": None}

    # The KeyboardInterrupt handler wraps the whole loop so Ctrl+C is handled
    # gracefully at any point, including the wait after a failed poll.
    try:
        while max_polls is None or poll_count < max_polls:
            # Wait between polls (not before the first, not after the last).
            if poll_count:
                time.sleep(POLL_INTERVAL)
            poll_count += 1
            timestamp = time.strftime("%H:%M:%S")

            try:
                response = requests.get(
                    f"{BASE_URL}/jobs/{request_id}/status",
                    timeout=10,
                )
                response.raise_for_status()
                status_data = response.json()
                current_status = status_data["status"]
                current_progress = status_data.get("progress")
            except Exception as e:
                print(f"\nβ Error polling status: {e}")
                continue

            last_data = status_data

            # Only print if status or progress changed
            if current_status != last_status or current_progress != last_progress:
                print(f"[{timestamp}] Poll #{poll_count}: {str(current_status).upper()}", end="")
                if current_progress:
                    print(f" - {current_progress}", end="")
                print()
                last_status = current_status
                last_progress = current_progress

            # Check terminal states
            if current_status == "completed":
                print("\n" + "=" * 80)
                print("β JOB COMPLETED!")
                print("=" * 80)
                # A reply of "results": null must not raise AttributeError;
                # treat it as "no result details".
                results = status_data.get("results") or {}
                download_url = results.get("download_url")
                if download_url:
                    print(f" β Google Drive URL: {download_url}")
                else:
                    print(" β Google Drive URL not available")
                size_mb = results.get("file_size_mb")
                if size_mb:
                    # Only numbers accept the ":.2f" format spec.
                    if isinstance(size_mb, (int, float)):
                        print(f" File Size: {size_mb:.2f} MB")
                    else:
                        print(f" File Size: {size_mb} MB")
                print(f" Document Count: {results.get('document_count', 'N/A')}")
                print(f" Created: {status_data.get('created_at')}")
                print(f" Completed: {status_data.get('updated_at')}")
                return status_data

            if current_status == "failed":
                print("\n" + "=" * 80)
                print("β JOB FAILED!")
                print("=" * 80)
                print(f" Error: {status_data.get('error_message', 'Unknown error')}")
                print(f" Created: {status_data.get('created_at')}")
                print(f" Failed: {status_data.get('updated_at')}")
                return status_data
    except KeyboardInterrupt:
        print("\n\nβ Polling interrupted by user")
        print("You can continue polling manually:")
        print(f" GET {BASE_URL}/jobs/{request_id}/status")
        return {"status": "interrupted"}

    return last_data
def list_user_jobs():
    """Print and return the test user's most recent jobs.

    Sends ``GET {BASE_URL}/jobs/user/{user_id}`` with ``limit=10`` and
    ``offset=0`` (10-second timeout) for ``PAYLOAD['user_id']``.

    Returns:
        The list from the reply's ``jobs`` field (empty if the field is
        missing or null), or an empty list if the request failed. Errors are
        printed, never raised.
    """
    print("\n" + "=" * 80)
    print("LISTING USER JOBS")
    print("=" * 80)
    user_id = PAYLOAD['user_id']
    try:
        response = requests.get(
            f"{BASE_URL}/jobs/user/{user_id}",
            params={"limit": 10, "offset": 0},
            timeout=10,
        )
        response.raise_for_status()
        result = response.json()
        # "jobs": null is treated the same as an empty list.
        jobs = result.get("jobs") or []
        print(f"\nβ Found {len(jobs)} jobs for user {user_id}:\n")
        for i, job in enumerate(jobs, 1):
            # Tolerate partially populated entries so that one incomplete job
            # does not abort the listing of the others.
            job_id = str(job.get("request_id", "?"))
            print(f"{i}. Request {job_id[:8]}...")
            print(f" Status: {job.get('status', 'N/A')}")
            print(f" Created: {job.get('created_at', 'N/A')}")
            download_url = job.get("download_url")
            if download_url:
                print(f" Download: {download_url}")
            print()
        return jobs
    except Exception as e:
        print(f"\nβ Error listing jobs: {e}")
        return []
def main():
    """Run the full async API test and exit the process with a status code.

    Steps: health check, job submission, continuous status polling, and a
    listing of the test user's jobs, followed by a summary.

    Exit codes: 0 when the job completed or polling was interrupted with
    Ctrl+C; 1 when the API is unreachable, submission failed, the job
    failed, or polling ended with the job in any other state.
    """
    print("\n" + "=" * 80)
    print(" " * 15 + "ASYNC PDF API TEST - FULL FEATURE SET")
    print("=" * 80)
    print(f"Base URL: {BASE_URL}")
    print(f"User ID: {PAYLOAD['user_id']}")
    print("=" * 80)
    print()

    # Warn (without aborting) when no Drive access token is configured: the
    # payload requests a Google Drive upload, which presumably needs it.
    if not PAYLOAD.get("google_drive_token"):
        print("WARNING: google_drive_token is empty; the Google Drive upload may fail.\n")

    # Step 1: Health check
    if not test_health():
        print("\nβ API is not accessible. Make sure the server is running.")
        print(f" Expected URL: {BASE_URL}")
        sys.exit(1)

    # Step 2: Submit job
    request_id = submit_async_job()
    if not request_id:
        print("\nβ Failed to submit job. Test aborted.")
        sys.exit(1)

    # Step 3: Poll status continuously
    final_status = poll_job_status(request_id)

    # Step 4: List all user jobs
    list_user_jobs()

    # Final summary
    print("\n" + "=" * 80)
    print(" " * 30 + "SUMMARY")
    print("=" * 80)
    status = final_status.get("status")

    if status == "completed":
        visual_types = PAYLOAD["prompt_params"]["visual_element_types"]
        # A reply of "results": null must not crash the summary.
        results = final_status.get("results") or {}
        print("β ALL TESTS PASSED!")
        print("\nFeatures tested:")
        print(" β Async job submission")
        print(" β Handwriting insertion")
        print(f" β Visual elements ({len(visual_types)} types)")
        print(" β OCR processing")
        print(" β Ground truth verification")
        print(" β Analysis & debug visualization")
        print(" β Dataset export")
        print(" β Google Drive upload")
        print(" β Continuous status polling")
        print("\nβ Your documents are available at:")
        print(f" {results.get('download_url')}")
        sys.exit(0)
    elif status == "failed":
        print("β JOB FAILED")
        print(f"Error: {final_status.get('error_message')}")
        sys.exit(1)
    elif status == "interrupted":
        print("βΈ POLLING INTERRUPTED")
        print("Job is still running. Check status manually:")
        print(f" GET {BASE_URL}/jobs/{request_id}/status")
        sys.exit(0)
    else:
        print("β± JOB STILL IN PROGRESS")
        print(f"Check status manually: GET {BASE_URL}/jobs/{request_id}/status")
        sys.exit(1)


if __name__ == "__main__":
    main()