import json
import os
from typing import List, Dict, Optional

from agent import GAIAAgent


def normalize_answer(answer: str) -> str:
    """Normalize answer for comparison."""
    if not answer:
        return ""

    answer = answer.strip()
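    # Remove one matching pair of surrounding quotes.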
    if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
        answer = answer[1:-1]
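    # Lowercase and re-strip for case-insensitive comparison.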
    return answer.lower().strip()


def extract_final_answer(response: str) -> str:
    """Extract the final answer from the model response."""
| if "FINAL ANSWER:" in response: |
| answer = response.split("FINAL ANSWER:")[1].strip() |
| |
| answer = answer.split('\n')[0].strip() |
| return answer |
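    # No marker found: fall back to the last line of the response.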
    lines = response.strip().split('\n')
    return lines[-1].strip()


def load_gaia_dataset(dataset_path: str) -> List[Dict]:
    """Load GAIA dataset from JSON/JSONL file."""
    tasks = []
    if not os.path.exists(dataset_path):
        print(f"Dataset file not found: {dataset_path}")
        return tasks

    try:
        with open(dataset_path, "r", encoding="utf-8") as f:
            if dataset_path.endswith('.jsonl'):
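                # JSONL: one JSON task object per line.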
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            task = json.loads(line)
                            tasks.append(task)
                        except json.JSONDecodeError as e:
                            print(f"Error parsing line {line_num}: {e}")
            else:
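                # Plain JSON: either a top-level list of tasks or a dict with a "tasks" key.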
                data = json.load(f)
                if isinstance(data, list):
                    tasks = data
                elif isinstance(data, dict) and 'tasks' in data:
                    tasks = data['tasks']
                else:
                    print("Unexpected JSON format")
    except Exception as e:
        print(f"Error loading dataset: {e}")

    print(f"Loaded {len(tasks)} tasks from {dataset_path}")
    return tasks


def create_sample_dataset() -> List[Dict]:
    """Create a sample dataset for testing when no GAIA dataset is available."""
    sample_tasks = [
        {
            "task_id": "sample_1",
            "question": "What is 15 + 27?",
            "answer": "42",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_2",
            "question": "What is the capital of France?",
            "answer": "Paris",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_3",
            "question": "How many days are in a leap year?",
            "answer": "366",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_4",
            "question": "What is 2 * 6 * 7?",
            "answer": "84",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_5",
            "question": "What year did World War II end?",
            "answer": "1945",
            "level": 1,
            "file_name": None
        }
    ]
    print("Using sample dataset for testing")
    return sample_tasks


def evaluate_agent(dataset_path: Optional[str] = None, max_tasks: Optional[int] = None) -> float:
    """Evaluate the GAIA agent on the dataset and return the score as a percentage."""
    if dataset_path and os.path.exists(dataset_path):
        tasks = load_gaia_dataset(dataset_path)
    else:
        print("No dataset file found, using sample tasks for testing")
        tasks = create_sample_dataset()

    if not tasks:
        print("No tasks to evaluate")
        return 0.0
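    # Optionally cap the number of tasks for a quicker run.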
    if max_tasks:
        tasks = tasks[:max_tasks]
        print(f"Evaluating on first {len(tasks)} tasks")

    print("Initializing GAIA agent...")
    agent = GAIAAgent()
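    # Smoke-test the API connection before running the full evaluation.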
    print("Testing API connection...")
    test_response = agent.test_grok()
    if "error" in test_response.lower():
        print(f"API test failed: {test_response}")
        return 0.0
    print("API connection successful!")
    correct = 0
    total = len(tasks)
    submission_entries = []

    print(f"\nStarting evaluation on {total} tasks...")
    print("=" * 50)

    for i, task in enumerate(tasks, 1):
        task_id = task.get("task_id", f"task_{i}")
        question = task.get("question", "")
        expected_answer = task.get("answer", "")

        print(f"\nTask {i}/{total}: {task_id}")
        print(f"Question: {question[:100]}{'...' if len(question) > 100 else ''}")
        try:
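            # Run the agent on the task and extract its final answer.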
            response = agent.process_task(task)
            predicted_answer = extract_final_answer(response)

            print(f"Expected: {expected_answer}")
            print(f"Predicted: {predicted_answer}")
            is_correct = normalize_answer(predicted_answer) == normalize_answer(expected_answer)

            if is_correct:
                correct += 1
                print("✅ CORRECT")
            else:
                print("❌ INCORRECT")
            submission_entries.append({
                "task_id": task_id,
                "model_answer": predicted_answer,
                "reasoning_trace": response
            })
        except Exception as e:
            print(f"Error processing task {task_id}: {e}")
            submission_entries.append({
                "task_id": task_id,
                "model_answer": "ERROR",
                "reasoning_trace": f"Error: {str(e)}"
            })
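        # Report the running score after each task.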
        current_score = (correct / i) * 100
        print(f"Current score: {correct}/{i} = {current_score:.1f}%")
        print("-" * 30)
    final_score = (correct / total) * 100
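    # Save predictions in JSONL format: one JSON object per line.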
    try:
        with open("submission.jsonl", "w", encoding="utf-8") as f:
            for entry in submission_entries:
                f.write(json.dumps(entry) + "\n")
        print("\nSubmission saved to submission.jsonl")
    except Exception as e:
        print(f"Error saving submission: {e}")
    print("=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Total tasks: {total}")
    print(f"Correct answers: {correct}")
    print(f"Final score: {final_score:.2f}%")
    if final_score >= 30:
        print("🎉 CONGRATULATIONS! Score ≥ 30% - Certificate achieved!")
    else:
        print(f"📈 Score below 30%. Need {30 - final_score:.2f}% more for the certificate.")

    return final_score


def main():
    """Parse command-line arguments and run the evaluation."""
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate GAIA agent")
    parser.add_argument("--dataset", type=str, default="gaia_test.json",
                        help="Path to GAIA dataset file")
    parser.add_argument("--max-tasks", type=int, default=None,
                        help="Maximum number of tasks to evaluate")
    args = parser.parse_args()

    score = evaluate_agent(args.dataset, args.max_tasks)

    print(f"\nFinal evaluation score: {score:.2f}%")

    if score >= 30:
        print("Certificate requirements met! 🎉")
    else:
        print("Keep working to reach 30% for the certificate! 💪")


if __name__ == "__main__":
    main()