| | |
| | """ |
| | Quick test script for specific GAIA questions. |
| | Use this to verify fixes without running full evaluation. |
| | |
| | Usage: |
| | uv run python test/test_quick_fixes.py |
| | """ |
| |
|
| | import os |
| | import sys |
| |
|
| | |
# Put the directory two levels above this file at the front of sys.path so
# the `src.*` imports below resolve when the script is run directly.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _PROJECT_ROOT)
| |
|
| | from src.agent.graph import GAIAAgent |
| | from dotenv import load_dotenv |
| |
|
| | |
| | load_dotenv() |
| |
|
| | |
| | |
| | |
| |
|
# Hand-picked GAIA questions used to spot-check individual fixes. Each entry
# carries the task id, a human-readable label, the prompt sent to the agent,
# and the answer the agent is expected to produce.
TEST_QUESTIONS = [
    {
        "task_id": "2d83110e-a098-4ebb-9987-066c06fa42d0",
        "name": "Reverse sentence (calculator threading fix)",
        "question": '.rewsna eht sa "tfel" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI',
        "expected": "Right",
    },
    {
        "task_id": "6f37996b-2ac7-44b0-8e68-6d28256631b4",
        "name": "Table commutativity (LLM issue - table in question)",
        # The prompt embeds a markdown table; it is assembled line by line so
        # the table rows stay readable in the source.
        "question": "\n".join(
            [
                "Given this table defining * on the set S = {a, b, c, d, e}",
                "",
                "|*|a|b|c|d|e|",
                "|---|---|---|---|---|",
                "|a|a|b|c|b|d|",
                "|b|b|c|a|e|c|",
                "|c|c|a|b|b|a|",
                "|d|b|e|b|e|d|",
                "|e|d|b|a|d|c|",
                "",
                "provide the subset of S involved in any possible counter-examples that prove * is not commutative. Provide your answer as a comma separated list of the elements in the set in alphabetical order.",
            ]
        ),
        "expected": "b, e",
    },
]
| |
|
| | |
| |
|
| |
|
def test_question(agent: "GAIAAgent", test_case: dict) -> dict:
    """Run one test case through the agent and report the outcome.

    Args:
        agent: Callable invoked as ``agent(question, file_path=None)``.
        test_case: Mapping with ``task_id``, ``name`` and ``question`` keys
            and an optional ``expected`` key (defaults to ``"N/A"``).

    Returns:
        A dict with ``task_id``, ``name``, ``question`` (truncated to 100
        characters plus ``"..."`` when longer), ``expected``, ``answer``,
        ``correct``, ``status`` (``"success"`` or ``"error"``) and
        ``system_error`` (``"yes"`` or ``"no"``). ``error_log`` is added when
        the agent raised or answered with an error message.
    """
    task_id = test_case["task_id"]
    question = test_case["question"]
    expected = test_case.get("expected", "N/A")
    banner = "=" * 60

    print(f"\n{banner}")
    print(f"Testing: {test_case['name']}")
    print(f"Task ID: {task_id}")
    print(f"Expected: {expected}")
    print(banner)

    # Fields shared by the success and failure records.
    result = {
        "task_id": task_id,
        "name": test_case["name"],
        "question": question[:100] + "..." if len(question) > 100 else question,
        "expected": expected,
    }

    try:
        raw_answer = agent(question, file_path=None)
    except Exception as e:
        # Harness boundary: any agent failure is recorded as a system error
        # so the remaining test cases still run.
        result.update(
            answer=f"ERROR: {str(e)}",
            correct=False,
            status="error",
            system_error="yes",
            error_log=str(e),
        )
    else:
        # Normalise before calling any str method: a None answer is treated
        # as an empty answer instead of surfacing as an AttributeError.
        answer = "" if raw_answer is None else str(raw_answer)
        normalized = answer.strip().lower()

        result.update(
            answer=answer,
            correct=normalized == str(expected).strip().lower(),
            status="success",
        )

        # Classify on the stripped text so blank answers and error messages
        # with leading whitespace are still recognised as system errors.
        if not normalized:
            result["system_error"] = "yes"
        elif normalized.startswith("error:") or "no evidence collected" in normalized:
            result["system_error"] = "yes"
            result["error_log"] = answer
        else:
            result["system_error"] = "no"

    status_icon = "✅" if result["correct"] else "❌" if result["system_error"] == "no" else "⚠️"
    print(f"\n{status_icon} Result: {result['answer'][:100]}")
    if result["system_error"] == "yes":
        print(" System Error: Yes")
        if result.get("error_log"):
            print(f" Error: {result['error_log'][:100]}")

    return result


# The name matches pytest's ``test_*`` pattern, but this is a helper that takes
# plain arguments rather than fixtures; opt it out of pytest collection.
test_question.__test__ = False
| |
|
| |
|
def main():
    """Run every case in TEST_QUESTIONS through a fresh agent and print a summary.

    Prints the configured LLM provider, builds a ``GAIAAgent``, runs each test
    case via ``test_question`` and then prints totals followed by a one-line
    result per case. Returns early, after printing the failure, when the agent
    cannot be constructed.
    """
    rule = "=" * 60

    print("\n" + rule)
    print("GAIA Quick Test - Verify Fixes")
    print(rule)

    # Report which provider is configured (falls back to "gemini" when unset).
    provider_name = os.getenv("LLM_PROVIDER", "gemini")
    print(f"\nLLM Provider: {provider_name}")

    print("\nInitializing agent...")
    try:
        agent = GAIAAgent()
        print("✅ Agent initialized")
    except Exception as exc:
        print(f"❌ Agent initialization failed: {exc}")
        return

    # Run the cases in order, keeping one result record per case.
    outcomes = []
    for case in TEST_QUESTIONS:
        outcomes.append(test_question(agent, case))

    print(f"\n{rule}")
    print("SUMMARY")
    print(rule)

    # Tally the three categories in a single pass over the results.
    correct_total = 0
    system_error_total = 0
    wrong_total = 0
    for outcome in outcomes:
        if outcome["correct"]:
            correct_total += 1
        if outcome["system_error"] == "yes":
            system_error_total += 1
        if outcome["system_error"] == "no" and not outcome["correct"]:
            wrong_total += 1

    print(f"\nTotal: {len(outcomes)}")
    print(f"✅ Correct: {correct_total}")
    print(f"⚠️ System Errors: {system_error_total}")
    print(f"❌ AI Wrong: {wrong_total}")

    print("\nDetailed Results:")
    for outcome in outcomes:
        if outcome["correct"]:
            marker = "✅"
        elif outcome["system_error"] == "yes":
            marker = "⚠️"
        else:
            marker = "❌"
        answer_text = outcome["answer"]
        # Only show the first 50 characters, flagging longer answers with "...".
        suffix = "..." if len(answer_text) > 50 else ""
        print(f" {marker} {outcome['name']}: {answer_text[:50]}{suffix}")


if __name__ == "__main__":
    main()
| |
|