Macmill commited on
Commit
673dfa2
·
verified ·
1 Parent(s): d81521e

Update final_agent.py

Browse files
Files changed (1) hide show
  1. final_agent.py +170 -59
final_agent.py CHANGED
@@ -13,6 +13,7 @@ import pytesseract # For image text extraction
13
  from urllib.parse import urlparse # For download tool
14
  from typing import Annotated, List, TypedDict, Optional
15
  from dotenv import load_dotenv
 
16
 
17
  # LangChain and LangGraph Imports
18
  from langgraph.graph import StateGraph, START, END
@@ -20,14 +21,14 @@ from langgraph.graph.message import add_messages
20
  from langgraph.prebuilt import ToolNode, tools_condition
21
  from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
22
  from langchain_core.tools import tool
23
- from langchain_groq import ChatGroq # Using Groq
 
24
  from langchain_community.tools.tavily_search import TavilySearchResults
25
 
26
  # ==============================================================================
27
  # Environment Setup & LLM
28
  # ==============================================================================
29
  load_dotenv()
30
- # Removed Gemini Key handling
31
  tavily_api_key = os.getenv("TAVILY_API_KEY")
32
  groq_api_key = os.getenv("GROQ_API_KEY")
33
 
@@ -36,23 +37,23 @@ groq_api_key = os.getenv("GROQ_API_KEY")
36
  # uncomment the following line and set the correct path to tesseract.exe
37
  # try:
38
  # pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' # Example path for Windows
39
- # except NameError: pass
40
  # except Exception as e: print(f"Warning: Could not set tesseract_cmd path: {e}")
41
 
42
 
43
  # --- Validate API Keys ---
44
  if not tavily_api_key:
45
- raise ValueError("TAVILY_API_KEY not found in Space secrets. Required for search.")
46
  if not groq_api_key:
47
- raise ValueError("GROQ_API_KEY not found in Space secrets. Required for Groq LLM.")
48
 
49
  # --- Initialize LLM (Using Groq) ---
50
  try:
51
  llm = ChatGroq(
52
- model="llama3-70b-8192", # Powerful model available on Groq, good for reasoning
53
- # model="gemma2-9b-it", # Alternative if Llama3 causes issues
54
  api_key=groq_api_key,
55
- temperature=0.1 # Low temperature for factual tasks
56
  )
57
  print(f"LLM Initialized: Groq - {llm.model_name}")
58
  except Exception as e:
@@ -65,10 +66,10 @@ except Exception as e:
65
  # ==============================================================================
66
  class AgentState(TypedDict):
67
  """Defines the structure of the information the agent tracks during its run."""
68
- input_question: str
69
- messages: Annotated[List[BaseMessage], add_messages]
70
- error: Optional[str]
71
- iterations: int
72
 
73
  # ==============================================================================
74
  # Tools Definitions
@@ -90,14 +91,22 @@ def web_browser(url: str) -> str:
90
  response = requests.get(url, headers=headers, timeout=20)
91
  response.raise_for_status()
92
  response.encoding = response.apparent_encoding or 'utf-8'
93
- h = html2text.HTML2Text(bodywidth=0); h.ignore_links = True; h.ignore_images = True
 
 
 
 
94
  clean_text = h.handle(response.text)
 
95
  max_length = 6000
96
- if len(clean_text) > max_length: return clean_text[:max_length] + "\n\n... [Content Truncated]"
 
97
  cleaned_and_stripped = clean_text.strip()
98
  return cleaned_and_stripped if cleaned_and_stripped else f"Error: No meaningful content via html2text for {url}."
99
- except requests.exceptions.RequestException as e: return f"Error: Network request failed for URL: {url}. Reason: {e}"
100
- except Exception as e: return f"Error: Unexpected error processing URL with html2text: {url}. Reason: {str(e)}"
 
 
101
 
102
  # --- File Download Tool ---
103
  @tool
@@ -105,30 +114,43 @@ def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
105
  """Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
106
  print(f"--- [Tool] Downloading file from: {url} ---")
107
  try:
 
108
  if not filename:
109
  try: path = urlparse(url).path; filename = os.path.basename(path) if path else None
110
  except Exception: filename = None
111
  if not filename: import uuid; filename = f"downloaded_{uuid.uuid4().hex[:8]}"
 
112
  temp_dir = tempfile.gettempdir(); filepath = os.path.join(temp_dir, filename)
 
113
  response = requests.get(url, stream=True, timeout=30); response.raise_for_status()
114
  with open(filepath, 'wb') as f:
115
  for chunk in response.iter_content(chunk_size=8192): f.write(chunk)
116
  print(f"--- [Tool] File downloaded to: {filepath} ---")
117
  return f"File downloaded to {filepath}. Use appropriate tools (e.g., analyze_csv_file) to process it."
118
- except requests.exceptions.RequestException as e: return f"Error downloading file: Network issue for {url}. Reason: {e}"
119
- except Exception as e: return f"Error downloading file: Unexpected error for {url}. Reason: {str(e)}"
 
 
120
 
121
  # --- CSV Analysis Tool ---
122
  @tool
123
  def analyze_csv_file(file_path: str) -> str:
124
  """Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
125
  print(f"--- [Tool] Analyzing CSV: {file_path} ---")
 
126
  if not os.path.exists(file_path): return f"Error: CSV file not found at path: {file_path}"
127
  try:
128
- df = pd.read_csv(file_path); summary = f"CSV Analysis Report for {os.path.basename(file_path)}:\n- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n- Columns: {', '.join(df.columns)}\n\nFirst 5 rows:\n{df.head().to_string()}\n"
 
 
 
 
 
129
  numeric_cols = df.select_dtypes(include=['number'])
130
- if not numeric_cols.empty: summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
131
- else: summary += "\nNo numeric columns for stats."
 
 
132
  return summary
133
  except ImportError: return "Error: 'pandas' required but not installed."
134
  except Exception as e: return f"Error analyzing CSV {file_path}: {str(e)}"
@@ -140,10 +162,17 @@ def analyze_excel_file(file_path: str) -> str:
140
  print(f"--- [Tool] Analyzing Excel: {file_path} ---")
141
  if not os.path.exists(file_path): return f"Error: Excel file not found at path: {file_path}"
142
  try:
143
- df = pd.read_excel(file_path, engine='openpyxl'); summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n- Columns: {', '.join(df.columns)}\n\nFirst 5 rows:\n{df.head().to_string()}\n"
 
 
 
 
 
144
  numeric_cols = df.select_dtypes(include=['number'])
145
- if not numeric_cols.empty: summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
146
- else: summary += "\nNo numeric columns for stats."
 
 
147
  return summary
148
  except ImportError: return "Error: 'pandas' and 'openpyxl' required but not installed."
149
  except Exception as e: return f"Error analyzing Excel {file_path}: {str(e)}"
@@ -155,8 +184,10 @@ def extract_text_from_image(file_path: str) -> str:
155
  print(f"--- [Tool] Extracting text from image: {file_path} ---")
156
  if not os.path.exists(file_path): return f"Error: Image file not found at path: {file_path}"
157
  try:
 
158
  text = pytesseract.image_to_string(Image.open(file_path))
159
  text_stripped = text.strip()
 
160
  return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}" if text_stripped else "No text found in image."
161
  except ImportError: return "Error: 'Pillow' or 'pytesseract' required but not installed."
162
  except pytesseract.TesseractNotFoundError: return "Error: Tesseract OCR not installed or not in PATH."
@@ -164,14 +195,24 @@ def extract_text_from_image(file_path: str) -> str:
164
 
165
  # --- Basic Math Tools ---
166
  @tool
167
- def add(a: float, b: float) -> float: """Adds two numbers (a + b)."""
 
 
 
168
  @tool
169
- def subtract(a: float, b: float) -> float: """Subtracts the second number from the first (a - b)."""
 
 
 
170
  @tool
171
- def multiply(a: float, b: float) -> float: """Multiplies two numbers (a * b)."""
 
 
 
172
  @tool
173
  def divide(a: float, b: float) -> float | str:
174
- """Divides the first number by the second (a / b). Handles division by zero."""
 
175
  if b == 0: return "Error: Cannot divide by zero."
176
  return a / b
177
 
@@ -187,41 +228,87 @@ llm_with_tools = llm.bind_tools(tools)
187
  print(f"Agent initialized with {len(tools)} tools.")
188
 
189
  # ==============================================================================
190
- # Node Definitions
191
  # ==============================================================================
192
  print("Defining graph nodes...")
193
 
194
  # --- Agent Node ---
195
  def call_agent_node(state: AgentState) -> dict:
196
  """Invokes the LLM with current state to decide the next step."""
197
- print(f"\n--- [Node] Agent thinking... (Iteration {state['iterations']}) ---")
198
- MAX_ITERATIONS = 10 # Max steps allowed for the task
 
199
  current_iterations = state.get('iterations', 0)
200
  if current_iterations >= MAX_ITERATIONS:
201
- print(f"Warning: Reached max iterations ({MAX_ITERATIONS}). Stopping.")
202
  return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}
203
  try:
 
204
  # Ensure LLM is bound with tools before invoking
205
  if 'llm_with_tools' not in globals():
206
  return {"error": "LLM tools not bound."}
 
207
  response = llm_with_tools.invoke(state['messages'])
208
- print("--- [Node] AI Response/Action ---"); response.pretty_print()
 
 
 
209
  return {"messages": [response], "iterations": current_iterations + 1}
210
  except Exception as e:
211
- error_message = f"LLM invocation failed: {str(e)}"; print(f"--- [Node] ERROR: {error_message} ---")
212
- traceback.print_exc()
213
- return {"messages": [AIMessage(content=f"Sorry, I encountered an error: {error_message}")], "error": error_message, "iterations": current_iterations + 1}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
- # --- Tool Node ---
216
- tool_node = ToolNode(tools)
217
 
218
  # ==============================================================================
219
- # Graph Construction (Non-conversational)
220
  # ==============================================================================
221
  print("Building agent graph...")
222
  builder = StateGraph(AgentState)
223
  builder.add_node("agent", call_agent_node)
224
- builder.add_node("tools", tool_node)
225
  builder.add_edge(START, "agent")
226
  builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", END: END})
227
  builder.add_edge("tools", "agent")
@@ -233,7 +320,8 @@ try:
233
  except Exception as e:
234
  print(f"ERROR: Failed to compile LangGraph graph: {e}")
235
  traceback.print_exc()
236
- graph = None # Set graph to None if compilation fails
 
237
 
238
  # ==============================================================================
239
  # Main Execution Function for GAIA Benchmark <<<< WRAPPER FUNCTION >>>>
@@ -252,21 +340,21 @@ def answer_gaia_task(question: str, file_path: Optional[str] = None) -> str:
252
  file_context_info = f"An associated file is provided at path: '{file_path}'. Use this path if relevant." if file_path else ""
253
 
254
  # Define the initial prompt sent to the agent, incorporating strict formatting rules
255
- prompt_content = f"""You are a precise AI assistant answering a specific question based *only* on information obtained using your tools.
256
 
257
  {file_context_info}
258
 
259
  Follow these steps methodically:
260
  1. Analyze the question: {question}
261
- 2. Use tools (web_search, web_browser, download_file_from_url, analyze_csv_file, analyze_excel_file, extract_text_from_image, add, subtract, multiply, divide) ONLY if necessary to gather the specific information required. Assume local file paths mentioned in the question (like 'data.csv') are accessible.
262
  3. Synthesize the final answer from the gathered information.
263
 
264
  **CRITICAL OUTPUT FORMATTING RULES:**
265
- * Your final response MUST be ONLY the answer, without any other text, explanations, or introductions.
266
- * **Numbers:** Do not use commas (e.g., 1000). Do not include units ($ , %) unless explicitly asked for.
267
- * **Strings:** Do not use articles (a, an, the) unless part of a required proper noun. Do not use abbreviations (e.g., write "Saint Petersburg") unless the abbreviation is the answer. Write digits as numerals (5).
268
- * **Lists:** If a list is required, provide it as comma-separated values (e.g., apple,banana,cherry). Apply number/string rules to elements.
269
- * If you cannot find the answer using the tools, output only the exact phrase: Information not found
270
 
271
  Provide ONLY the final answer according to these rules.
272
  """
@@ -283,13 +371,15 @@ Provide ONLY the final answer according to these rules.
283
 
284
  try:
285
  # Invoke the compiled graph
286
- final_state = graph.invoke(initial_state, config={"recursion_limit": 15}) # Set recursion limit
287
 
288
  # Process the final state to extract the answer
289
  if final_state:
 
290
  if final_state.get("error"):
291
  print(f"--- Agent stopped due to ERROR: {final_state['error']} ---")
292
  final_answer = f"Error: {final_state['error']}"
 
293
  elif final_state.get('messages') and isinstance(final_state['messages'][-1], AIMessage):
294
  potential_answer = final_state['messages'][-1].content
295
  # Basic cleanup for potential quotes added by LLM
@@ -301,6 +391,7 @@ Provide ONLY the final answer according to these rules.
301
  final_answer = potential_answer
302
  else:
303
  print("--- Could not determine final answer (last message not AI or missing). Check logs. ---")
 
304
  print(f"Final State: Error={final_state.get('error')}, Iterations={final_state.get('iterations')}")
305
 
306
  except Exception as e:
@@ -319,31 +410,51 @@ Provide ONLY the final answer according to these rules.
319
  # This block allows you to test the agent by running final_agent.py directly.
320
  if __name__ == "__main__":
321
  print("\n--- Running Local Test ---")
 
322
  test_question = "What is the result of multiplying the number of rows (excluding the header) in 'data.csv' by the number found after the phrase 'total items:' in 'image.png'?"
 
 
323
  print("Creating dummy files for local test...")
324
  dummy_files_created = True
325
  try:
326
- with open("data.csv", "w") as f: f.write("Header1,Header2\nRow1Val1,Row1Val2\nRow2Val1,Row2Val2\nRow3Val1,Row3Val2")
 
 
 
 
327
  try:
328
- img = Image.new('RGB', (300, 50), color = (255, 255, 255))
329
- from PIL import ImageDraw, ImageFont
330
  draw = ImageDraw.Draw(img)
 
331
  try: font = ImageFont.truetype("arial.ttf", 15)
332
  except IOError: font = ImageFont.load_default()
333
- draw.text((10,10), "Some random info... total items: 7 ... more text", fill=(0,0,0), font=font)
334
  img.save("image.png")
335
  print("Dummy data.csv and image.png created successfully.")
336
- except ImportError: print("Pillow/ImageDraw/ImageFont not installed. Cannot create dummy image."); dummy_files_created = False
337
- except Exception as img_e: print(f"Error creating dummy image: {img_e}"); dummy_files_created = False
338
- except Exception as file_e: print(f"Error creating dummy files: {file_e}"); dummy_files_created = False
339
-
 
 
 
 
 
 
 
 
 
340
  if dummy_files_created:
 
341
  result = answer_gaia_task(question=test_question, file_path=None)
342
  print(f"\n--- Local Test Result ---")
343
  print(f"Returned Answer: {result}")
344
- print(f"Expected Answer (for dummy files): 21")
345
- else: print("Skipping test execution due to issues creating dummy files.")
 
346
 
 
347
  print("\nCleaning up dummy files...")
348
  for dummy_file in ["data.csv", "image.png"]:
349
  if os.path.exists(dummy_file):
 
13
  from urllib.parse import urlparse # For download tool
14
  from typing import Annotated, List, TypedDict, Optional
15
  from dotenv import load_dotenv
16
+ import time # For adding potential delays if needed later
17
 
18
  # LangChain and LangGraph Imports
19
  from langgraph.graph import StateGraph, START, END
 
21
  from langgraph.prebuilt import ToolNode, tools_condition
22
  from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage
23
  from langchain_core.tools import tool
24
+ # LLM Import - Using Groq
25
+ from langchain_groq import ChatGroq
26
  from langchain_community.tools.tavily_search import TavilySearchResults
27
 
28
  # ==============================================================================
29
  # Environment Setup & LLM
30
  # ==============================================================================
31
  load_dotenv()
 
32
  tavily_api_key = os.getenv("TAVILY_API_KEY")
33
  groq_api_key = os.getenv("GROQ_API_KEY")
34
 
 
37
  # uncomment the following line and set the correct path to tesseract.exe
38
  # try:
39
  # pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' # Example path for Windows
40
+ # except NameError: pass # Handles case where pytesseract might not be imported yet if PIL fails first
41
  # except Exception as e: print(f"Warning: Could not set tesseract_cmd path: {e}")
42
 
43
 
44
  # --- Validate API Keys ---
45
  if not tavily_api_key:
46
+ raise ValueError("TAVILY_API_KEY not found in environment variables/Space secrets.")
47
  if not groq_api_key:
48
+ raise ValueError("GROQ_API_KEY not found in environment variables/Space secrets.")
49
 
50
  # --- Initialize LLM (Using Groq) ---
51
  try:
52
  llm = ChatGroq(
53
+ model="meta-llama/llama-4-maverick-17b-128e-instruct", # Powerful model available on Groq, good for reasoning
54
+ # model="gemma2-9b-it", # Alternative lighter model
55
  api_key=groq_api_key,
56
+ temperature=0.3 # Low temperature for factual tasks
57
  )
58
  print(f"LLM Initialized: Groq - {llm.model_name}")
59
  except Exception as e:
 
66
  # ==============================================================================
67
  class AgentState(TypedDict):
68
  """Defines the structure of the information the agent tracks during its run."""
69
+ input_question: str # The original question from the benchmark
70
+ messages: Annotated[List[BaseMessage], add_messages] # History of interactions (Human, AI, Tool)
71
+ error: Optional[str] # Stores any error message encountered
72
+ iterations: int # Counter for agent steps to prevent loops
73
 
74
  # ==============================================================================
75
  # Tools Definitions
 
91
  response = requests.get(url, headers=headers, timeout=20)
92
  response.raise_for_status()
93
  response.encoding = response.apparent_encoding or 'utf-8'
94
+ # Configure html2text
95
+ h = html2text.HTML2Text(bodywidth=0)
96
+ h.ignore_links = True
97
+ h.ignore_images = True
98
+ # Convert HTML to text
99
  clean_text = h.handle(response.text)
100
+ # Limit content length
101
  max_length = 6000
102
+ if len(clean_text) > max_length:
103
+ return clean_text[:max_length] + "\n\n... [Content Truncated]"
104
  cleaned_and_stripped = clean_text.strip()
105
  return cleaned_and_stripped if cleaned_and_stripped else f"Error: No meaningful content via html2text for {url}."
106
+ except requests.exceptions.RequestException as e:
107
+ return f"Error: Network request failed for URL: {url}. Reason: {e}"
108
+ except Exception as e:
109
+ return f"Error: Unexpected error processing URL with html2text: {url}. Reason: {str(e)}"
110
 
111
  # --- File Download Tool ---
112
  @tool
 
114
  """Downloads a file from a URL to a temporary directory. Input: file URL. Returns: path to downloaded file or error."""
115
  print(f"--- [Tool] Downloading file from: {url} ---")
116
  try:
117
+ # Generate filename if needed
118
  if not filename:
119
  try: path = urlparse(url).path; filename = os.path.basename(path) if path else None
120
  except Exception: filename = None
121
  if not filename: import uuid; filename = f"downloaded_{uuid.uuid4().hex[:8]}"
122
+ # Define save path
123
  temp_dir = tempfile.gettempdir(); filepath = os.path.join(temp_dir, filename)
124
+ # Download file
125
  response = requests.get(url, stream=True, timeout=30); response.raise_for_status()
126
  with open(filepath, 'wb') as f:
127
  for chunk in response.iter_content(chunk_size=8192): f.write(chunk)
128
  print(f"--- [Tool] File downloaded to: {filepath} ---")
129
  return f"File downloaded to {filepath}. Use appropriate tools (e.g., analyze_csv_file) to process it."
130
+ except requests.exceptions.RequestException as e:
131
+ return f"Error downloading file: Network issue for {url}. Reason: {e}"
132
+ except Exception as e:
133
+ return f"Error downloading file: Unexpected error for {url}. Reason: {str(e)}"
134
 
135
  # --- CSV Analysis Tool ---
136
  @tool
137
  def analyze_csv_file(file_path: str) -> str:
138
  """Analyzes a CSV file at the given path using pandas. Returns a summary of content or error."""
139
  print(f"--- [Tool] Analyzing CSV: {file_path} ---")
140
+ # GAIA might provide relative paths, ensure they work or adjust logic if needed
141
  if not os.path.exists(file_path): return f"Error: CSV file not found at path: {file_path}"
142
  try:
143
+ df = pd.read_csv(file_path)
144
+ # Generate summary string
145
+ summary = f"CSV Analysis Report for {os.path.basename(file_path)}:\n"
146
+ summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
147
+ summary += f"- Columns: {', '.join(df.columns)}\n"
148
+ summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
149
  numeric_cols = df.select_dtypes(include=['number'])
150
+ if not numeric_cols.empty:
151
+ summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
152
+ else:
153
+ summary += "\nNo numeric columns for stats."
154
  return summary
155
  except ImportError: return "Error: 'pandas' required but not installed."
156
  except Exception as e: return f"Error analyzing CSV {file_path}: {str(e)}"
 
162
  print(f"--- [Tool] Analyzing Excel: {file_path} ---")
163
  if not os.path.exists(file_path): return f"Error: Excel file not found at path: {file_path}"
164
  try:
165
+ df = pd.read_excel(file_path, engine='openpyxl')
166
+ # Generate summary string
167
+ summary = f"Excel Analysis Report for {os.path.basename(file_path)} (First Sheet):\n"
168
+ summary += f"- Shape: {df.shape[0]} rows, {df.shape[1]} columns\n"
169
+ summary += f"- Columns: {', '.join(df.columns)}\n"
170
+ summary += f"\nFirst 5 rows:\n{df.head().to_string()}\n"
171
  numeric_cols = df.select_dtypes(include=['number'])
172
+ if not numeric_cols.empty:
173
+ summary += f"\nBasic Stats (Numeric):\n{numeric_cols.describe().to_string()}"
174
+ else:
175
+ summary += "\nNo numeric columns for stats."
176
  return summary
177
  except ImportError: return "Error: 'pandas' and 'openpyxl' required but not installed."
178
  except Exception as e: return f"Error analyzing Excel {file_path}: {str(e)}"
 
184
  print(f"--- [Tool] Extracting text from image: {file_path} ---")
185
  if not os.path.exists(file_path): return f"Error: Image file not found at path: {file_path}"
186
  try:
187
+ # Need to explicitly handle potential empty string from pytesseract
188
  text = pytesseract.image_to_string(Image.open(file_path))
189
  text_stripped = text.strip()
190
+ # Return a clear message if no text found, otherwise return extracted text
191
  return f"Extracted text from image '{os.path.basename(file_path)}':\n{text_stripped}" if text_stripped else "No text found in image."
192
  except ImportError: return "Error: 'Pillow' or 'pytesseract' required but not installed."
193
  except pytesseract.TesseractNotFoundError: return "Error: Tesseract OCR not installed or not in PATH."
 
195
 
196
  # --- Basic Math Tools ---
197
  @tool
198
+ def add(a: float, b: float) -> float:
199
+ """Adds two numbers (a + b). Handles float inputs."""
200
+ print(f"--- [Tool] Calculating: {a} + {b} ---")
201
+ return a + b
202
  @tool
203
+ def subtract(a: float, b: float) -> float:
204
+ """Subtracts the second number from the first (a - b). Handles float inputs."""
205
+ print(f"--- [Tool] Calculating: {a} - {b} ---")
206
+ return a - b
207
  @tool
208
+ def multiply(a: float, b: float) -> float:
209
+ """Multiplies two numbers (a * b). Handles float inputs."""
210
+ print(f"--- [Tool] Calculating: {a} * {b} ---")
211
+ return a * b
212
  @tool
213
  def divide(a: float, b: float) -> float | str:
214
+ """Divides the first number by the second (a / b). Handles float inputs and division by zero."""
215
+ print(f"--- [Tool] Calculating: {a} / {b} ---")
216
  if b == 0: return "Error: Cannot divide by zero."
217
  return a / b
218
 
 
228
  print(f"Agent initialized with {len(tools)} tools.")
229
 
230
  # ==============================================================================
231
+ # Node Definitions (With Logging Added)
232
  # ==============================================================================
233
  print("Defining graph nodes...")
234
 
235
  # --- Agent Node ---
236
  def call_agent_node(state: AgentState) -> dict:
237
  """Invokes the LLM with current state to decide the next step."""
238
+ # --- Logging: Node Entry ---
239
+ print(f"\n>>> Entering Agent Node (Iteration {state['iterations']})")
240
+ MAX_ITERATIONS = 15 # Max steps allowed for the task - Increased slightly
241
  current_iterations = state.get('iterations', 0)
242
  if current_iterations >= MAX_ITERATIONS:
243
+ print(f"!!! Agent Node: Max iterations ({MAX_ITERATIONS}) reached. Setting error.")
244
  return {"error": f"Max iterations ({MAX_ITERATIONS}) reached."}
245
  try:
246
+ print(f"--- Agent Node: Invoking LLM ({llm.model_name})... ---") # Log before LLM call
247
  # Ensure LLM is bound with tools before invoking
248
  if 'llm_with_tools' not in globals():
249
  return {"error": "LLM tools not bound."}
250
+
251
  response = llm_with_tools.invoke(state['messages'])
252
+ print(f"--- Agent Node: LLM Invocation Complete. ---") # Log after LLM call
253
+ # response.pretty_print() # Optional: Print formatted LLM response
254
+ # --- Logging: Node Exit (Success) ---
255
+ print(f"<<< Exiting Agent Node (Success, Iteration {current_iterations + 1})")
256
  return {"messages": [response], "iterations": current_iterations + 1}
257
  except Exception as e:
258
+ error_message = f"LLM invocation failed: {str(e)}"
259
+ print(f"!!! Agent Node ERROR: {error_message} !!!")
260
+ traceback.print_exc() # Print full traceback for debugging LLM errors
261
+ # --- Logging: Node Exit (Error) ---
262
+ print(f"<<< Exiting Agent Node (LLM Error, Iteration {current_iterations})")
263
+ # Return an error message and set error state, still increment iteration to prevent infinite error loops
264
+ return {"messages": [AIMessage(content=f"Error during LLM call: {error_message}")], "error": error_message, "iterations": current_iterations + 1}
265
+
266
+ # --- Tool Node Wrapper (for Logging) ---
267
+ # We still use the prebuilt ToolNode, but wrap its call for logging
268
+ tool_executor = ToolNode(tools) # Keep the instance
269
+
270
+ def logged_tool_node(state: AgentState) -> dict:
271
+ """Logs tool execution start/end and calls the actual ToolNode."""
272
+ print(f">>> Entering Tool Node")
273
+ # Log requested tools
274
+ last_message = state['messages'][-1]
275
+ requested_tools_str = "None"
276
+ tool_calls = []
277
+ if hasattr(last_message, "tool_calls") and last_message.tool_calls:
278
+ tool_calls = last_message.tool_calls
279
+ tool_names = [tc.get('name', 'unknown') for tc in tool_calls]
280
+ requested_tools_str = ", ".join(tool_names)
281
+ print(f"--- Tool Node: Executing tools: {requested_tools_str} ---")
282
+ if tool_calls: print(f"--- Tool Node: Tool Args: {[tc.get('args') for tc in tool_calls]} ---")
283
+
284
+
285
+ try:
286
+ # Call the actual ToolNode instance
287
+ result = tool_executor.invoke(state)
288
+ # Log truncated results
289
+ print("--- Tool Node: Tool Execution Results ---")
290
+ if isinstance(result.get("messages"), list):
291
+ for msg in result["messages"]:
292
+ if isinstance(msg, ToolMessage):
293
+ print(f" - Tool: {msg.name}, Result (start): {str(msg.content)[:200]}...") # Slightly more context
294
+ print(f"<<< Exiting Tool Node (Success)")
295
+ return result # Return the dictionary containing ToolMessages
296
+ except Exception as e:
297
+ error_message = f"ToolNode invocation exception: {str(e)}"
298
+ print(f"!!! Tool Node ERROR: {error_message} !!!")
299
+ traceback.print_exc()
300
+ print(f"<<< Exiting Tool Node (Error)")
301
+ # Return an error message in the expected format
302
+ return {"messages": [ToolMessage(content=error_message, tool_call_id="tool_node_error")]}
303
 
 
 
304
 
305
  # ==============================================================================
306
+ # Graph Construction (Non-conversational, using logged tool node)
307
  # ==============================================================================
308
  print("Building agent graph...")
309
  builder = StateGraph(AgentState)
310
  builder.add_node("agent", call_agent_node)
311
+ builder.add_node("tools", logged_tool_node) # Use the logging wrapper node
312
  builder.add_edge(START, "agent")
313
  builder.add_conditional_edges("agent", tools_condition, {"tools": "tools", END: END})
314
  builder.add_edge("tools", "agent")
 
320
  except Exception as e:
321
  print(f"ERROR: Failed to compile LangGraph graph: {e}")
322
  traceback.print_exc()
323
+ graph = None # Ensure graph is None if compilation fails
324
+ raise # Reraise exception to make startup failure clear
325
 
326
  # ==============================================================================
327
  # Main Execution Function for GAIA Benchmark <<<< WRAPPER FUNCTION >>>>
 
340
  file_context_info = f"An associated file is provided at path: '{file_path}'. Use this path if relevant." if file_path else ""
341
 
342
  # Define the initial prompt sent to the agent, incorporating strict formatting rules
343
+ prompt_content = f"""Your task is to accurately answer the following question based *only* on information obtained using your tools (web search, web browser, file download, csv/excel analysis, image OCR, math).
344
 
345
  {file_context_info}
346
 
347
  Follow these steps methodically:
348
  1. Analyze the question: {question}
349
+ 2. Use tools ONLY if necessary to gather the specific information required. Assume local file paths mentioned (like 'data.csv') are accessible.
350
  3. Synthesize the final answer from the gathered information.
351
 
352
  **CRITICAL OUTPUT FORMATTING RULES:**
353
+ * Your final response MUST be ONLY the answer, without any other text/explanations.
354
+ * **Numbers:** No commas (1000). No units ($ , %) unless asked.
355
+ * **Strings:** No articles (a, an, the) unless proper noun. No abbreviations (Saint Petersburg) unless answer is abbreviation. Use numerals (5).
356
+ * **Lists:** Comma-separated (apple,banana,cherry). Apply number/string rules to elements.
357
+ * If answer not found, output only the exact phrase: Information not found
358
 
359
  Provide ONLY the final answer according to these rules.
360
  """
 
371
 
372
  try:
373
  # Invoke the compiled graph
374
+ final_state = graph.invoke(initial_state, config={"recursion_limit": 20}) # Increased recursion limit
375
 
376
  # Process the final state to extract the answer
377
  if final_state:
378
+ # Prioritize showing agent error if one occurred
379
  if final_state.get("error"):
380
  print(f"--- Agent stopped due to ERROR: {final_state['error']} ---")
381
  final_answer = f"Error: {final_state['error']}"
382
+ # Otherwise, try to get the last AI message content
383
  elif final_state.get('messages') and isinstance(final_state['messages'][-1], AIMessage):
384
  potential_answer = final_state['messages'][-1].content
385
  # Basic cleanup for potential quotes added by LLM
 
391
  final_answer = potential_answer
392
  else:
393
  print("--- Could not determine final answer (last message not AI or missing). Check logs. ---")
394
+ # Log final state details for debugging
395
  print(f"Final State: Error={final_state.get('error')}, Iterations={final_state.get('iterations')}")
396
 
397
  except Exception as e:
 
410
  # This block allows you to test the agent by running final_agent.py directly.
411
  if __name__ == "__main__":
412
  print("\n--- Running Local Test ---")
413
+ # --- Define Test Question ---
414
  test_question = "What is the result of multiplying the number of rows (excluding the header) in 'data.csv' by the number found after the phrase 'total items:' in 'image.png'?"
415
+
416
+ # --- Create Dummy Files for Local Test ---
417
  print("Creating dummy files for local test...")
418
  dummy_files_created = True
419
  try:
420
+ # Dummy CSV with 3 data rows + header
421
+ with open("data.csv", "w") as f:
422
+ f.write("Header1,Header2\nRow1Val1,Row1Val2\nRow2Val1,Row2Val2\nRow3Val1,Row3Val2")
423
+
424
+ # Dummy Image containing the required text
425
  try:
426
+ img = Image.new('RGB', (300, 50), color = (255, 255, 255)) # White background
427
+ from PIL import ImageDraw, ImageFont # Import drawing tools locally
428
  draw = ImageDraw.Draw(img)
429
+ # Use a basic font if specific ones aren't found
430
  try: font = ImageFont.truetype("arial.ttf", 15)
431
  except IOError: font = ImageFont.load_default()
432
+ draw.text((10,10), "Some random info... total items: 7 ... more text", fill=(0,0,0), font=font) # Black text
433
  img.save("image.png")
434
  print("Dummy data.csv and image.png created successfully.")
435
+ except ImportError:
436
+ print("Pillow/ImageDraw/ImageFont not installed. Cannot create dummy image file.")
437
+ dummy_files_created = False
438
+ except Exception as img_e:
439
+ print(f"Error creating dummy image: {img_e}")
440
+ dummy_files_created = False
441
+
442
+ except Exception as file_e:
443
+ print(f"Error creating dummy files: {file_e}")
444
+ dummy_files_created = False
445
+ # ---------------------------------------------
446
+
447
+ # --- Run the Test ---
448
  if dummy_files_created:
449
+ # Call the main function, simulating how the benchmark runner would call it.
450
  result = answer_gaia_task(question=test_question, file_path=None)
451
  print(f"\n--- Local Test Result ---")
452
  print(f"Returned Answer: {result}")
453
+ print(f"Expected Answer (for dummy files): 21") # 3 data rows * 7 = 21
454
+ else:
455
+ print("Skipping test execution due to issues creating dummy files.")
456
 
457
+ # --- Clean up Dummy Files ---
458
  print("\nCleaning up dummy files...")
459
  for dummy_file in ["data.csv", "image.png"]:
460
  if os.path.exists(dummy_file):