BiGuan committed on
Commit
a223bb0
·
verified ·
1 Parent(s): 951a5f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -62
app.py CHANGED
@@ -12,14 +12,12 @@ from typing import TypedDict, Annotated, Sequence, List, Dict, Any, Generator
12
  from datetime import datetime
13
  import operator
14
 
15
- # LangChain / LangGraph
16
  from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
17
  from langchain_core.tools import tool
18
  from langgraph.graph import StateGraph, END
19
  from langgraph.prebuilt import ToolNode
20
  from langchain_core.utils.function_calling import convert_to_openai_function
21
 
22
- # 其他工具依赖
23
  from bs4 import BeautifulSoup
24
  from youtube_transcript_api import YouTubeTranscriptApi
25
 
@@ -32,10 +30,9 @@ AGICTO_API_KEY = os.getenv("AGICTO_API_KEY", "")
32
  QWEN_MODEL = "qwen3.5-35b-a3b"
33
 
34
  # =============================================================================
35
- # 进度监控器
36
  # =============================================================================
37
  class ProgressMonitor:
38
- # ... 保持不变 ...
39
  def __init__(self):
40
  self.current = 0
41
  self.total = 0
@@ -79,10 +76,9 @@ class ProgressMonitor:
79
  return html
80
 
81
  # =============================================================================
82
- # Qwen LLM 封装
83
  # =============================================================================
84
  class QwenLLM:
85
- # ... 保持不变 ...
86
  def __init__(self, model=QWEN_MODEL):
87
  self.model = model
88
  self.api_key = AGICTO_API_KEY
@@ -187,7 +183,7 @@ class QwenLLM:
187
  return formatted
188
 
189
  # =============================================================================
190
- # 工具定义
191
  # =============================================================================
192
  api_url_tasks = DEFAULT_API_URL
193
 
@@ -197,7 +193,6 @@ def _get_api_base():
197
  base = base[:-3]
198
  return base
199
 
200
- # --- 原有工具 ---
201
  @tool(description="搜索互联网信息,返回相关摘要。")
202
  def web_search(query: str) -> str:
203
  try:
@@ -315,20 +310,13 @@ def download_file_for_task(task_id: str) -> str:
315
  os.unlink(temp_path)
316
  return result
317
  else:
318
- # 对于文本文件(包括 .py, .txt 等),直接返回文本内容
319
  return resp.text[:4000]
320
  except Exception as e:
321
  return f"文件下载失败: {e}"
322
 
323
- # --- 新增:维基百科搜索工具 ---
324
  @tool(description="在维基百科中搜索关键词,返回页面摘要或详细信息。")
325
  def search_wikipedia(query: str) -> str:
326
- """
327
- 使用维基百科 API 搜索关键词。
328
- 首先尝试 opensearch 获取页面标题,然后用 extract 获取摘要。
329
- """
330
  try:
331
- # 第一步:搜索相关页面标题
332
  search_url = "https://en.wikipedia.org/w/api.php"
333
  params = {
334
  "action": "opensearch",
@@ -338,11 +326,10 @@ def search_wikipedia(query: str) -> str:
338
  }
339
  resp = requests.get(search_url, params=params, timeout=10)
340
  data = resp.json()
341
- titles = data[1] # 标题列表
342
  if not titles:
343
  return "维基百科未找到相关页面。"
344
  title = titles[0]
345
- # 第二步:获取页面摘要
346
  extract_params = {
347
  "action": "query",
348
  "prop": "extracts",
@@ -354,52 +341,55 @@ def search_wikipedia(query: str) -> str:
354
  resp2 = requests.get(search_url, params=extract_params, timeout=10)
355
  data2 = resp2.json()
356
  pages = data2.get("query", {}).get("pages", {})
357
- for page_id, page_info in pages.items():
358
  extract = page_info.get("extract", "")
359
  if extract:
360
- # 返回前2000字符,避免过长
361
  return f"Wikipedia - {title}:\n{extract[:2000]}"
362
  return f"维基百科页面 '{title}' 未提供摘要。"
363
  except Exception as e:
364
  return f"维基百科搜索失败: {e}"
365
 
366
  # =============================================================================
367
- # LangGraph 状态与节点
368
  # =============================================================================
369
  class AgentState(TypedDict):
370
  messages: Annotated[Sequence[BaseMessage], operator.add]
371
  final_answer: str
372
  task_id: str
373
- tool_attempts: int
374
-
375
- # 所有工具(包含新增的 search_wikipedia)
376
- tools = [
377
- search_wikipedia, # 优先搜索维基百科
378
- web_search, # 备用网络搜索
379
- web_scraper,
380
- calculator,
381
- analyze_image,
382
- transcribe_audio,
383
- get_youtube_transcript,
384
- download_file_for_task
385
- ]
386
 
 
 
387
  tool_node = ToolNode(tools)
388
  llm = QwenLLM()
389
  functions = [convert_to_openai_function(t) for t in tools]
390
  llm_with_tools = llm.bind_functions(functions)
391
 
 
 
392
  def agent_node(state: AgentState) -> dict:
393
  messages = state["messages"]
394
  task_id = state.get("task_id", "")
395
- # 更新系统提示,强调维基百科、文件处理和 YouTube 工具的使用
396
  sys_prompt = f"""You are a helpful assistant answering GAIA Level 1 questions.
397
- IMPORTANT GUIDELINES:
398
- - For fact-based questions, first try to find the answer using the `search_wikipedia` tool. Only if Wikipedia fails, use `web_search` or other tools.
399
- - If the question provides a file (image, audio, or code), use `download_file_for_task` with the given task_id to retrieve it. The tool will automatically analyze images, transcribe audio, or return text for Python/text files.
400
- - For YouTube links, use `get_youtube_transcript` to obtain the captions.
401
- - When you have the final answer, output ONLY the answer string (a word, number, short phrase, or letter). Do NOT include any extra text, explanations, or "FINAL ANSWER:".
402
- Current task ID: {task_id}. If the question requires a file, use download_file_for_task with task_id="{task_id}"."""
 
 
 
 
 
 
 
 
 
 
 
 
403
  full = [SystemMessage(content=sys_prompt)] + list(messages)
404
  response = llm_with_tools.invoke(full)
405
  return {"messages": [response]}
@@ -408,28 +398,26 @@ def should_continue(state: AgentState) -> str:
408
  messages = state["messages"]
409
  last = messages[-1]
410
  tool_attempts = state.get("tool_attempts", 0)
411
- MAX_TOOL_CALLS = 3 # 限制最多3次工具调用,避免循环
412
 
 
413
  if tool_attempts >= MAX_TOOL_CALLS:
414
  return "finish"
415
 
 
416
  if hasattr(last, "additional_kwargs") and "function_call" in last.additional_kwargs:
417
  return "tools"
418
 
 
419
  tool_msg_count = sum(1 for m in messages if isinstance(m, ToolMessage))
420
  if tool_msg_count == 0:
421
  return "force_tool"
422
 
423
- # 如果 LLM 已经给出了一个简洁答案,结束
424
- content = last.content
425
- if "?" not in content and len(content.strip()) < 100:
426
- return "finish"
427
-
428
  return "finish"
429
 
430
  def force_tool_node(state: AgentState) -> dict:
431
  new_msg = HumanMessage(
432
- content="You haven't used any tool yet. Please use an appropriate tool (e.g., search_wikipedia, download_file_for_task) to find the answer."
433
  )
434
  return {"messages": [new_msg]}
435
 
@@ -437,21 +425,31 @@ def increment_tool_count(state: AgentState) -> dict:
437
  return {"tool_attempts": state.get("tool_attempts", 0) + 1}
438
 
439
  def finish_node(state: AgentState) -> dict:
 
440
  last = state["messages"][-1]
441
  content = last.content
442
- answer = content.strip().split("\n")[-1].strip()
443
- if "FINAL ANSWER:" in answer:
444
- answer = answer.split("FINAL ANSWER:")[-1].strip()
445
-
446
- if not answer:
447
- for m in reversed(state["messages"]):
448
- if isinstance(m, AIMessage) and m.content.strip():
449
- answer = m.content.strip().split("\n")[-1].strip()
 
 
 
 
 
 
450
  break
 
 
451
 
452
- if not answer:
453
- if state.get("tool_attempts", 0) >= 3:
454
- answer = "Unable to determine answer: max tool calls reached."
 
455
  else:
456
  answer = "Unable to determine answer: insufficient information."
457
 
@@ -461,19 +459,28 @@ def build_graph():
461
  workflow = StateGraph(AgentState)
462
  workflow.add_node("agent", agent_node)
463
  workflow.add_node("tools", tool_node)
464
- workflow.add_node("finish", finish_node)
465
  workflow.add_node("force_tool", force_tool_node)
466
  workflow.add_node("count_tools", increment_tool_count)
 
467
 
468
  workflow.set_entry_point("agent")
 
469
  workflow.add_conditional_edges(
470
  "agent",
471
  should_continue,
472
- {"tools": "tools", "force_tool": "force_tool", "finish": "finish"}
 
 
 
 
473
  )
 
 
474
  workflow.add_edge("tools", "count_tools")
475
  workflow.add_edge("count_tools", "agent")
 
476
  workflow.add_edge("force_tool", "agent")
 
477
  workflow.add_edge("finish", END)
478
 
479
  return workflow.compile()
@@ -581,8 +588,8 @@ with gr.Blocks(title="GAIA Agent") as demo:
581
  gr.Markdown("""
582
  # 🤖 GAIA Level 1 Agent (LangGraph + Qwen)
583
  **模型:** Qwen3.5-35B-A3B | **API:** agicto.com
584
- 点击按钮获取题目,Agent 自动调用工具并回答,最后提交评分。
585
- **新增维基百科搜索、文件处理(图片/音频/代码)、YouTube 字幕提取**
586
  """)
587
  gr.LoginButton()
588
  run_btn = gr.Button("🚀 运行评测并提交", variant="primary")
 
12
  from datetime import datetime
13
  import operator
14
 
 
15
  from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
16
  from langchain_core.tools import tool
17
  from langgraph.graph import StateGraph, END
18
  from langgraph.prebuilt import ToolNode
19
  from langchain_core.utils.function_calling import convert_to_openai_function
20
 
 
21
  from bs4 import BeautifulSoup
22
  from youtube_transcript_api import YouTubeTranscriptApi
23
 
 
30
  QWEN_MODEL = "qwen3.5-35b-a3b"
31
 
32
  # =============================================================================
33
+ # 进度监控器(不变)
34
  # =============================================================================
35
  class ProgressMonitor:
 
36
  def __init__(self):
37
  self.current = 0
38
  self.total = 0
 
76
  return html
77
 
78
  # =============================================================================
79
+ # Qwen LLM 封装(不变)
80
  # =============================================================================
81
  class QwenLLM:
 
82
  def __init__(self, model=QWEN_MODEL):
83
  self.model = model
84
  self.api_key = AGICTO_API_KEY
 
183
  return formatted
184
 
185
  # =============================================================================
186
+ # 工具定义(同之前,包含 search_wikipedia 等)
187
  # =============================================================================
188
  api_url_tasks = DEFAULT_API_URL
189
 
 
193
  base = base[:-3]
194
  return base
195
 
 
196
  @tool(description="搜索互联网信息,返回相关摘要。")
197
  def web_search(query: str) -> str:
198
  try:
 
310
  os.unlink(temp_path)
311
  return result
312
  else:
 
313
  return resp.text[:4000]
314
  except Exception as e:
315
  return f"文件下载失败: {e}"
316
 
 
317
  @tool(description="在维基百科中搜索关键词,返回页面摘要或详细信息。")
318
  def search_wikipedia(query: str) -> str:
 
 
 
 
319
  try:
 
320
  search_url = "https://en.wikipedia.org/w/api.php"
321
  params = {
322
  "action": "opensearch",
 
326
  }
327
  resp = requests.get(search_url, params=params, timeout=10)
328
  data = resp.json()
329
+ titles = data[1]
330
  if not titles:
331
  return "维基百科未找到相关页面。"
332
  title = titles[0]
 
333
  extract_params = {
334
  "action": "query",
335
  "prop": "extracts",
 
341
  resp2 = requests.get(search_url, params=extract_params, timeout=10)
342
  data2 = resp2.json()
343
  pages = data2.get("query", {}).get("pages", {})
344
+ for page_info in pages.values():
345
  extract = page_info.get("extract", "")
346
  if extract:
 
347
  return f"Wikipedia - {title}:\n{extract[:2000]}"
348
  return f"维基百科页面 '{title}' 未提供摘要。"
349
  except Exception as e:
350
  return f"维基百科搜索失败: {e}"
351
 
352
  # =============================================================================
353
+ # LangGraph 状态与节点(允许多次工具调用,最大3次)
354
  # =============================================================================
355
  class AgentState(TypedDict):
356
  messages: Annotated[Sequence[BaseMessage], operator.add]
357
  final_answer: str
358
  task_id: str
359
+ tool_attempts: int # 已使用的工具调用次数
 
 
 
 
 
 
 
 
 
 
 
 
360
 
361
+ tools = [search_wikipedia, web_search, web_scraper, calculator,
362
+ analyze_image, transcribe_audio, get_youtube_transcript, download_file_for_task]
363
  tool_node = ToolNode(tools)
364
  llm = QwenLLM()
365
  functions = [convert_to_openai_function(t) for t in tools]
366
  llm_with_tools = llm.bind_functions(functions)
367
 
368
+ MAX_TOOL_CALLS = 3 # 最多允许的工具调用次数
369
+
370
  def agent_node(state: AgentState) -> dict:
371
  messages = state["messages"]
372
  task_id = state.get("task_id", "")
373
+ # 系统提示:引导使用工具,但最终必须给出答案(不要闲聊)
374
  sys_prompt = f"""You are a helpful assistant answering GAIA Level 1 questions.
375
+ You can use the following tools to find information:
376
+ - search_wikipedia: search Wikipedia for facts.
377
+ - web_search: general web search.
378
+ - web_scraper: fetch content from a URL.
379
+ - download_file_for_task: download a file associated with the current task (task_id: {task_id}). This can handle images, audio, and Python/text files.
380
+ - analyze_image: describe an image given a URL or base64 data.
381
+ - transcribe_audio: transcribe audio from a path or URL.
382
+ - get_youtube_transcript: get captions from a YouTube video.
383
+ - calculator: evaluate a mathematical expression.
384
+
385
+ Instructions:
386
+ 1. Use the most appropriate tool(s) to gather the information needed to answer the question.
387
+ 2. If you need to follow up (e.g., search then scrape a specific page), you may use another tool.
388
+ 3. Once you have enough information, output ONLY the final answer as a short string (a word, number, date, or phrase). Do NOT include explanations, greetings, or the phrase "FINAL ANSWER:".
389
+ 4. If after using tools you still cannot find the answer, output exactly: "Unable to determine answer: insufficient information."
390
+ 5. Do not make up an answer; only respond based on the information you retrieved.
391
+
392
+ Current task ID: {task_id}."""
393
  full = [SystemMessage(content=sys_prompt)] + list(messages)
394
  response = llm_with_tools.invoke(full)
395
  return {"messages": [response]}
 
398
  messages = state["messages"]
399
  last = messages[-1]
400
  tool_attempts = state.get("tool_attempts", 0)
 
401
 
402
+ # 如果已达到最大调用次数,强制进入 finish
403
  if tool_attempts >= MAX_TOOL_CALLS:
404
  return "finish"
405
 
406
+ # 如果 LLM 请求了工具调用,则去执行工具
407
  if hasattr(last, "additional_kwargs") and "function_call" in last.additional_kwargs:
408
  return "tools"
409
 
410
+ # 尚未使用过任何工具?强制要求使用工具(确保至少一次)
411
  tool_msg_count = sum(1 for m in messages if isinstance(m, ToolMessage))
412
  if tool_msg_count == 0:
413
  return "force_tool"
414
 
415
+ # 否则,LLM 已经给出了最终答案,进入 finish
 
 
 
 
416
  return "finish"
417
 
418
  def force_tool_node(state: AgentState) -> dict:
419
  new_msg = HumanMessage(
420
+ content="You haven't used any tool yet. Please use an appropriate tool to find the answer."
421
  )
422
  return {"messages": [new_msg]}
423
 
 
425
  return {"tool_attempts": state.get("tool_attempts", 0) + 1}
426
 
427
  def finish_node(state: AgentState) -> dict:
428
+ """从最后一条 AI 消息中提取最终答案,并清理格式"""
429
  last = state["messages"][-1]
430
  content = last.content
431
+ # 如果已经包含标准错误信息,直接返回
432
+ if "Unable to determine answer" in content:
433
+ return {"final_answer": content.split("\n")[0].strip()}
434
+
435
+ # 去除可能的前缀
436
+ answer = content.split("FINAL ANSWER:")[-1].strip()
437
+
438
+ # 尝试提取简洁答案:如果过长或包含问句,取第一句
439
+ if len(answer) > 50 or "?" in answer:
440
+ sentences = re.split(r'(?<=[.!?])\s+', answer)
441
+ for s in sentences:
442
+ s = s.strip()
443
+ if s and "?" not in s and not s.startswith(("Let me", "I ", "You ", "Please")):
444
+ answer = s
445
  break
446
+ else:
447
+ answer = answer[:100].strip()
448
 
449
+ # 若最终答案仍为空或无效,给出错误原因
450
+ if not answer or answer in ("模型调用失败",):
451
+ if state.get("tool_attempts", 0) >= MAX_TOOL_CALLS:
452
+ answer = "Unable to determine answer: maximum tool calls reached."
453
  else:
454
  answer = "Unable to determine answer: insufficient information."
455
 
 
459
  workflow = StateGraph(AgentState)
460
  workflow.add_node("agent", agent_node)
461
  workflow.add_node("tools", tool_node)
 
462
  workflow.add_node("force_tool", force_tool_node)
463
  workflow.add_node("count_tools", increment_tool_count)
464
+ workflow.add_node("finish", finish_node)
465
 
466
  workflow.set_entry_point("agent")
467
+
468
  workflow.add_conditional_edges(
469
  "agent",
470
  should_continue,
471
+ {
472
+ "tools": "tools",
473
+ "force_tool": "force_tool",
474
+ "finish": "finish"
475
+ }
476
  )
477
+
478
+ # 工具调用后计数,然后返回 agent 继续思考
479
  workflow.add_edge("tools", "count_tools")
480
  workflow.add_edge("count_tools", "agent")
481
+ # force_tool 后返回 agent 重新决策
482
  workflow.add_edge("force_tool", "agent")
483
+ # finish 结束
484
  workflow.add_edge("finish", END)
485
 
486
  return workflow.compile()
 
588
  gr.Markdown("""
589
  # 🤖 GAIA Level 1 Agent (LangGraph + Qwen)
590
  **模型:** Qwen3.5-35B-A3B | **API:** agicto.com
591
+ 点击按钮获取题目,Agent 调用多个工具(最多3次)以获取答案,最后提交评分。
592
+ **工具:** 维基百科、网页搜索/抓取、图片分析、音频转录、YouTube字幕、文件下载
593
  """)
594
  gr.LoginButton()
595
  run_btn = gr.Button("🚀 运行评测并提交", variant="primary")