"""GAIA Level 1 agent: LangGraph tool-calling loop over a Qwen model, served via Gradio.

The app fetches questions from the course scoring API, answers each one with a
LangGraph agent that may call a small set of tools, and submits the answers.
"""
import base64
import html
import io
import json
import math
import operator
import os
import re
import tempfile
import threading
import time
from datetime import datetime
from typing import TypedDict, Annotated, Sequence, List, Dict, Any, Generator
from urllib.parse import urlparse, parse_qs

import gradio as gr
import pandas as pd
import requests
from bs4 import BeautifulSoup
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool
from langchain_core.utils.function_calling import convert_to_openai_function
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
from youtube_transcript_api import YouTubeTranscriptApi

# =============================================================================
# Configuration constants
# =============================================================================
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
AGICTO_BASE_URL = os.getenv("AGICTO_BASE_URL", "https://api.agicto.cn")
AGICTO_API_KEY = os.getenv("AGICTO_API_KEY", "")
QWEN_MODEL = "qwen3.5-35b-a3b"


def _get_api_base() -> str:
    """Return AGICTO_BASE_URL without a trailing slash or trailing ``/v1``.

    Every caller appends ``/v1/...`` itself, so a configured ``.../v1`` suffix
    is stripped to avoid requesting ``/v1/v1/...``.
    """
    base = AGICTO_BASE_URL.rstrip("/")
    if base.endswith("/v1"):
        base = base[: -len("/v1")]
    return base


# =============================================================================
# Progress monitor
# =============================================================================
class ProgressMonitor:
    """Thread-safe progress tracker rendered as an HTML fragment for Gradio."""

    def __init__(self):
        self.current = 0
        self.total = 0
        self.last_question = ""
        self.last_answer = ""
        self.logs = []
        self._lock = threading.Lock()

    def start(self, total: int):
        """Reset all counters, the last question/answer and the log for a new run."""
        with self._lock:
            self.total = total
            self.current = 0
            self.last_question = ""
            self.last_answer = ""
            self.logs = []

    def step(self, question: str, answer: str):
        """Record one finished question and append a short log line."""
        with self._lock:
            self.current += 1
            self.last_question = question
            self.last_answer = answer
            self.logs.append(f"✅ 第 {self.current}/{self.total} 题完成:{answer[:50]}...")

    def get_html(self) -> str:
        """Render the current progress snapshot as HTML.

        Question, answer and log text come from the scoring API and the LLM,
        so they are HTML-escaped before being interpolated into the markup.
        """
        with self._lock:
            pct = int(self.current / self.total * 100) if self.total > 0 else 0
            suffix = "..." if len(self.last_question) > 100 else ""
            question = html.escape(self.last_question[:100] + suffix)
            answer = html.escape(self.last_answer)
            logs = html.escape("\n".join(self.logs))
            current, total = self.current, self.total
        return f"""
<div style="padding:12px;border:1px solid #ddd;border-radius:8px;">
  <h3 style="margin-top:0;">📊 实时进度</h3>
  <div style="background:#eee;border-radius:4px;height:16px;">
    <div style="width:{pct}%;background:#4caf50;height:100%;border-radius:4px;"></div>
  </div>
  <p>{pct}% ({current}/{total})</p>
  <p><b>最新题目:</b> {question}</p>
  <p><b>答案:</b> {answer}</p>
  <details>
    <summary>详细日志</summary>
    <pre style="white-space:pre-wrap;">{logs}</pre>
  </details>
</div>
"""


# =============================================================================
# Qwen LLM wrapper (OpenAI-compatible chat completions)
# =============================================================================
class QwenLLM:
    """Minimal OpenAI-compatible chat client for the Qwen model behind AGICTO."""

    def __init__(self, model=QWEN_MODEL):
        self.model = model
        self.api_key = AGICTO_API_KEY
        self.base_url = _get_api_base()
        if not self.api_key:
            print("⚠️ 未设置 AGICTO_API_KEY,请检查环境变量")

    def _call_api(self, messages: list, functions: list | None = None, max_tokens=2000):
        """POST a chat-completions request.

        Returns the parsed JSON response, or None if the request or HTTP status
        failed (the error is printed).
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        body = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": max_tokens,
        }
        if functions:
            body["tools"] = [{"type": "function", "function": f} for f in functions]
            body["tool_choice"] = "auto"
        url = f"{self.base_url}/v1/chat/completions"
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=60)
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            print(f"API 调用失败: {e}")
            return None

    @staticmethod
    def _to_ai_message(result) -> AIMessage:
        """Convert a chat-completions response into an AIMessage.

        Tool calls are exposed via ``AIMessage.tool_calls`` (the field that
        LangGraph's ToolNode executes), with JSON arguments parsed into dicts
        and the provider's call ids preserved so ToolMessages can be matched
        back to their request.
        """
        if not result:
            return AIMessage(content="模型调用失败")
        try:
            msg = result["choices"][0]["message"]
        except (KeyError, IndexError, TypeError):
            print(f"API 响应格式异常: {str(result)[:200]}")
            return AIMessage(content="模型调用失败")

        # "content" is usually null when the model only requests tools;
        # AIMessage rejects None, so normalise it to "".
        content = msg.get("content") or ""
        tool_calls = []
        for i, tc in enumerate(msg.get("tool_calls") or []):
            fn = tc.get("function") or {}
            raw_args = fn.get("arguments") or "{}"
            if isinstance(raw_args, dict):
                args = raw_args
            else:
                try:
                    args = json.loads(raw_args)
                except json.JSONDecodeError:
                    # Let the tool's own argument validation report the problem.
                    args = {}
            if not isinstance(args, dict):
                args = {}
            tool_calls.append({
                "name": fn.get("name", ""),
                "args": args,
                "id": tc.get("id") or f"call_{i}",
            })
        if tool_calls:
            return AIMessage(content=content, tool_calls=tool_calls)
        return AIMessage(content=content)

    def invoke(self, messages: list) -> AIMessage:
        """Call the model without offering any tools."""
        return self._to_ai_message(self._call_api(self._format_messages(messages)))

    def bind_functions(self, functions: list):
        """Return a wrapper whose ``invoke`` offers ``functions`` (OpenAI schemas) as tools."""

        class BoundLLM:
            def __init__(self, llm, funcs):
                self.llm = llm
                self.functions = funcs

            def invoke(self, messages: list) -> AIMessage:
                formatted = self.llm._format_messages(messages)
                result = self.llm._call_api(formatted, functions=self.functions)
                return self.llm._to_ai_message(result)

        return BoundLLM(self, functions)

    def _format_messages(self, messages: list) -> list:
        """Translate LangChain messages into OpenAI chat-completions dicts.

        Assistant tool calls and tool results keep their real call ids so the
        API can pair every ``tool`` message with the call that requested it.
        """
        formatted = []
        for m in messages:
            if isinstance(m, SystemMessage):
                formatted.append({"role": "system", "content": m.content})
            elif isinstance(m, HumanMessage):
                formatted.append({"role": "user", "content": m.content})
            elif isinstance(m, AIMessage):
                entry = {"role": "assistant", "content": m.content}
                if m.tool_calls:
                    entry["tool_calls"] = [
                        {
                            "id": tc.get("id") or f"call_{i}",
                            "type": "function",
                            "function": {
                                "name": tc["name"],
                                "arguments": json.dumps(tc["args"], ensure_ascii=False),
                            },
                        }
                        for i, tc in enumerate(m.tool_calls)
                    ]
                formatted.append(entry)
            elif isinstance(m, ToolMessage):
                formatted.append({
                    "role": "tool",
                    "tool_call_id": m.tool_call_id,
                    "content": str(m.content),
                })
        return formatted


# =============================================================================
# Tool definitions
# (Explanations are kept in comments rather than docstrings so the explicit
#  @tool descriptions sent to the model stay exactly as written.)
# =============================================================================
api_url_tasks = DEFAULT_API_URL


# DuckDuckGo Instant Answer API: abstract text plus up to three related topics.
@tool(description="搜索互联网信息,返回相关摘要。")
def web_search(query: str) -> str:
    try:
        params = {"q": query, "format": "json", "no_html": 1}
        resp = requests.get("https://api.duckduckgo.com/", params=params, timeout=10)
        resp.raise_for_status()
        data = resp.json()
        parts = []
        if data.get("AbstractText"):
            parts.append(f"摘要: {data['AbstractText']}")
        for topic in data.get("RelatedTopics", [])[:3]:
            if isinstance(topic, dict) and "Text" in topic:
                parts.append(topic["Text"])
        return "\n".join(parts) if parts else "未找到相关信息"
    except Exception as e:
        return f"搜索失败: {e}"


# Fetch a page, drop script/style/nav/footer, return at most 5000 chars of text.
@tool(description="抓取网页并提取纯文本内容。")
def web_scraper(url: str) -> str:
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(url, headers=headers, timeout=15)
        # Do not hand an error page's text to the model as if it were content.
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        for el in soup(["script", "style", "nav", "footer"]):
            el.decompose()
        lines = [line.strip() for line in soup.get_text().splitlines() if line.strip()]
        return " ".join(lines)[:5000]
    except Exception as e:
        return f"抓取失败: {e}"


# Evaluate a math expression with only the `math` module's names available.
# NOTE(security): `expression` is produced by the LLM, which may be steered by
# untrusted web content. eval() with empty builtins is NOT a sandbox (dunder
# attribute chains can escape it), so any dunder access is rejected outright.
@tool(description="执行数学表达式计算。")
def calculator(expression: str) -> str:
    try:
        if "__" in expression:
            return "计算失败: 表达式包含不允许的内容"
        allowed = {k: v for k, v in math.__dict__.items() if not k.startswith("__")}
        result = eval(expression, {"__builtins__": {}}, allowed)
        return str(result)
    except Exception as e:
        return f"计算失败: {e}"


# Describe an image with the chat model. Accepts an http(s) URL, a data: URL,
# or bare base64 (assumed JPEG).
@tool(description="分析图片内容(支持URL或base64编码)。")
def analyze_image(image_data: str) -> str:
    try:
        headers = {"Authorization": f"Bearer {AGICTO_API_KEY}", "Content-Type": "application/json"}
        if not image_data.startswith(("http", "data:")):
            image_data = f"data:image/jpeg;base64,{image_data}"
        body = {
            "model": QWEN_MODEL,
            "messages": [{"role": "user", "content": [
                {"type": "text", "text": "请详细描述这张图片的内容,包括文字、数字等信息。"},
                {"type": "image_url", "image_url": {"url": image_data}},
            ]}],
            "max_tokens": 800,
        }
        url = f"{_get_api_base()}/v1/chat/completions"
        resp = requests.post(url, headers=headers, json=body, timeout=30)
        if resp.status_code == 200:
            return resp.json()["choices"][0]["message"]["content"] or ""
        return f"图片分析失败: {resp.status_code}"
    except Exception as e:
        return f"图片分析失败: {e}"


# Transcribe audio from a local path or URL via the /v1/audio/transcriptions endpoint.
@tool(description="将音频文件(路径或URL)转录为文字。")
def transcribe_audio(audio_path: str) -> str:
    try:
        headers = {"Authorization": f"Bearer {AGICTO_API_KEY}"}
        if audio_path.startswith("http"):
            download = requests.get(audio_path, timeout=30)
            download.raise_for_status()
            audio_file = io.BytesIO(download.content)
            audio_file.name = "audio.mp3"
        else:
            audio_file = open(audio_path, "rb")
        # `with` guarantees the local file handle is closed after upload.
        with audio_file:
            files = {"file": audio_file, "model": (None, "whisper-1")}
            url = f"{_get_api_base()}/v1/audio/transcriptions"
            resp = requests.post(url, headers=headers, files=files, timeout=60)
        if resp.status_code == 200:
            return resp.json()["text"]
        return f"转录失败: {resp.status_code}"
    except Exception as e:
        return f"转录失败: {e}"


def _extract_youtube_id(video_url: str) -> str:
    """Return the video id from a youtube.com / youtu.be URL, or "" if none is found."""
    url = video_url.strip()
    if "://" not in url:
        url = "https://" + url
    parsed = urlparse(url)
    host = parsed.netloc.lower()
    path_parts = [p for p in parsed.path.split("/") if p]
    if host.endswith("youtu.be"):
        return path_parts[0] if path_parts else ""
    if host.endswith("youtube.com"):
        vid = parse_qs(parsed.query).get("v", [""])[0]
        if vid:
            return vid
        if len(path_parts) >= 2 and path_parts[0] in ("shorts", "embed", "live", "v"):
            return path_parts[1]
    return ""


# Fetch English/Chinese captions; supports both the legacy classmethod API and
# the instance-based `fetch` API of newer youtube-transcript-api releases.
@tool(description="获取YouTube视频的字幕文本。")
def get_youtube_transcript(video_url: str) -> str:
    try:
        vid = _extract_youtube_id(video_url)
        if not vid:
            return "无法提取视频 ID"
        if hasattr(YouTubeTranscriptApi, "get_transcript"):
            transcript = YouTubeTranscriptApi.get_transcript(vid, languages=["en", "zh"])
            texts = [t["text"] for t in transcript]
        else:
            fetched = YouTubeTranscriptApi().fetch(vid, languages=["en", "zh"])
            texts = [snippet.text for snippet in fetched]
        return " ".join(texts)[:4000]
    except Exception as e:
        return f"获取字幕失败: {e}"


# Download the task's attachment from the scoring API and turn it into text:
# images -> analyze_image, audio -> transcribe_audio, spreadsheets -> CSV text
# (via pandas, falling back to raw text), anything else -> raw text.
@tool(description="下载指定任务关联的文件,并返回文本内容或分析结果。")
def download_file_for_task(task_id: str) -> str:
    try:
        resp = requests.get(f"{api_url_tasks}/files/{task_id}", timeout=20)
        if resp.status_code != 200:
            return f"文件不存在 (HTTP {resp.status_code})"
        content_type = resp.headers.get("content-type", "")
        mime = content_type.split(";")[0].strip().lower()

        if "image" in mime:
            b64 = base64.b64encode(resp.content).decode()
            # Pass the real MIME type instead of assuming JPEG.
            return analyze_image.invoke({"image_data": f"data:{mime};base64,{b64}"})

        if "audio" in mime:
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
                f.write(resp.content)
                temp_path = f.name
            try:
                return transcribe_audio.invoke({"audio_path": temp_path})
            finally:
                os.unlink(temp_path)

        if "spreadsheet" in mime or "excel" in mime:
            try:
                df = pd.read_excel(io.BytesIO(resp.content))
                return df.to_csv(index=False)[:4000]
            except Exception as e:
                print(f"Excel 解析失败,按文本返回: {e}")

        return resp.text[:4000]
    except Exception as e:
        return f"文件下载失败: {e}"


# Look up the best-matching English Wikipedia page and return its intro extract.
@tool(description="在维基百科中搜索关键词,返回页面摘要或详细信息。")
def search_wikipedia(query: str) -> str:
    try:
        search_url = "https://en.wikipedia.org/w/api.php"
        # Wikimedia asks API clients to send a descriptive User-Agent.
        headers = {"User-Agent": "GAIA-Agent/1.0 (Hugging Face Space)"}
        params = {"action": "opensearch", "search": query, "limit": 1, "format": "json"}
        resp = requests.get(search_url, params=params, headers=headers, timeout=10)
        data = resp.json()
        titles = data[1]
        if not titles:
            return "维基百科未找到相关页面。"
        title = titles[0]
        extract_params = {
            "action": "query",
            "prop": "extracts",
            "exintro": True,
            "explaintext": True,
            "titles": title,
            "format": "json",
        }
        resp2 = requests.get(search_url, params=extract_params, headers=headers, timeout=10)
        pages = resp2.json().get("query", {}).get("pages", {})
        for page_info in pages.values():
            extract = page_info.get("extract", "")
            if extract:
                return f"Wikipedia - {title}:\n{extract[:2000]}"
        return f"维基百科页面 '{title}' 未提供摘要。"
    except Exception as e:
        return f"维基百科搜索失败: {e}"


# =============================================================================
# LangGraph state and nodes (up to MAX_TOOL_CALLS tool rounds)
# =============================================================================
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    final_answer: str
    task_id: str
    tool_attempts: int  # number of completed tool rounds


tools = [search_wikipedia, web_search, web_scraper, calculator, analyze_image,
         transcribe_audio, get_youtube_transcript, download_file_for_task]
tool_node = ToolNode(tools)
llm = QwenLLM()
functions = [convert_to_openai_function(t) for t in tools]
llm_with_tools = llm.bind_functions(functions)

MAX_TOOL_CALLS = 3  # maximum number of tool rounds per question

# Sent (at most once per question) when the model answers without any tool use.
FORCE_TOOL_PROMPT = "You haven't used any tool yet. Please use an appropriate tool to find the answer."


def agent_node(state: AgentState) -> dict:
    """Ask the tool-enabled model for the next step (a tool call or the final answer)."""
    messages = state["messages"]
    task_id = state.get("task_id", "")
    # Steer the model towards tools, but require a bare final answer (no chat).
    sys_prompt = f"""You are a helpful assistant answering GAIA Level 1 questions.
You can use the following tools to find information:
- search_wikipedia: search Wikipedia for facts.
- web_search: general web search.
- web_scraper: fetch content from a URL.
- download_file_for_task: download a file associated with the current task (task_id: {task_id}). This can handle images, audio, spreadsheets, and Python/text files.
- analyze_image: describe an image given a URL or base64 data.
- transcribe_audio: transcribe audio from a path or URL.
- get_youtube_transcript: get captions from a YouTube video.
- calculator: evaluate a mathematical expression.

Instructions:
1. Use the most appropriate tool(s) to gather the information needed to answer the question.
2. If you need to follow up (e.g., search then scrape a specific page), you may use another tool.
3. Once you have enough information, output ONLY the final answer as a short string (a word, number, date, or phrase). Do NOT include explanations, greetings, or the phrase "FINAL ANSWER:".
4. If after using tools you still cannot find the answer, output exactly: "Unable to determine answer: insufficient information."
5. Do not make up an answer; only respond based on the information you retrieved.

Current task ID: {task_id}."""
    full = [SystemMessage(content=sys_prompt)] + list(messages)
    response = llm_with_tools.invoke(full)
    return {"messages": [response]}


def should_continue(state: AgentState) -> str:
    """Route after the agent: "tools", "force_tool" (at most once) or "finish"."""
    messages = state["messages"]
    last = messages[-1]

    # Tool budget exhausted: stop regardless of what the model asked for.
    if state.get("tool_attempts", 0) >= MAX_TOOL_CALLS:
        return "finish"

    # The model requested tools -> execute them (ToolNode reads `tool_calls`).
    if getattr(last, "tool_calls", None):
        return "tools"

    # Push the model to use a tool once; never loop on this, otherwise a model
    # that keeps answering directly would run until the recursion limit.
    used_tool = any(isinstance(m, ToolMessage) for m in messages)
    already_forced = any(
        isinstance(m, HumanMessage) and m.content == FORCE_TOOL_PROMPT for m in messages
    )
    if not used_tool and not already_forced:
        return "force_tool"

    return "finish"


def force_tool_node(state: AgentState) -> dict:
    """Append a user nudge asking the model to use a tool."""
    return {"messages": [HumanMessage(content=FORCE_TOOL_PROMPT)]}


def increment_tool_count(state: AgentState) -> dict:
    """Count one completed tool round."""
    return {"tool_attempts": state.get("tool_attempts", 0) + 1}


def finish_node(state: AgentState) -> dict:
    """Extract a short final answer from the last message and normalise it."""
    last = state["messages"][-1]
    content = last.content if isinstance(last.content, str) else str(last.content)
    # Some Qwen3 deployments prepend a <think>...</think> reasoning block.
    content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

    # Standard "cannot answer" message: return the line that carries it.
    if "Unable to determine answer" in content:
        line = next(l for l in content.splitlines() if "Unable to determine answer" in l)
        return {"final_answer": line.strip()}

    # Drop an optional "FINAL ANSWER:" prefix.
    answer = content.split("FINAL ANSWER:")[-1].strip()

    # If the answer is long or contains a question, keep the first sentence
    # that looks like an answer rather than chatter.
    if len(answer) > 50 or "?" in answer:
        sentences = re.split(r'(?<=[.!?])\s+', answer)
        for s in sentences:
            s = s.strip()
            if s and "?" not in s and not s.startswith(("Let me", "I ", "You ", "Please")):
                answer = s
                break
        else:
            answer = answer[:100].strip()

    # Empty or failed answers become an explicit reason.
    if not answer or answer in ("模型调用失败",):
        if state.get("tool_attempts", 0) >= MAX_TOOL_CALLS:
            answer = "Unable to determine answer: maximum tool calls reached."
        else:
            answer = "Unable to determine answer: insufficient information."
    return {"final_answer": answer}


def build_graph():
    """Wire the agent -> (tools -> count_tools | force_tool)* -> finish graph."""
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.add_node("force_tool", force_tool_node)
    workflow.add_node("count_tools", increment_tool_count)
    workflow.add_node("finish", finish_node)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"tools": "tools", "force_tool": "force_tool", "finish": "finish"},
    )
    # After a tool round: count it, then let the agent think again.
    workflow.add_edge("tools", "count_tools")
    workflow.add_edge("count_tools", "agent")
    # After the nudge: back to the agent for a new decision.
    workflow.add_edge("force_tool", "agent")
    workflow.add_edge("finish", END)
    return workflow.compile()


# =============================================================================
# Agent class
# =============================================================================
class LangGraphAgent:
    """Callable wrapper: question (+ task id) -> final answer string."""

    def __init__(self):
        self.graph = build_graph()
        print("LangGraphAgent 初始化完成,使用模型:", QWEN_MODEL)

    def __call__(self, question: str, task_id: str = "") -> str:
        state = {
            "messages": [HumanMessage(content=question)],
            "final_answer": "",
            "task_id": task_id,
            "tool_attempts": 0,
        }
        try:
            final_state = self.graph.invoke(state)
            return final_state["final_answer"]
        except Exception as e:
            print(f"Agent 运行失败: {e}")
            return f"Error: {e}"


# =============================================================================
# Main run function (generator streaming live progress to Gradio)
# =============================================================================
def _status_html(text: str) -> str:
    """Wrap a plain status/error message as an escaped HTML block."""
    return f"<div style='padding:10px;'>{html.escape(text)}</div>"


def run_and_submit_all(profile: gr.OAuthProfile | None) -> Generator:
    """Answer every question, streaming (progress HTML, status, table) tuples, then submit."""
    space_id = os.getenv("SPACE_ID")
    if not profile:
        yield _status_html("请先登录"), "", pd.DataFrame()
        return
    username = profile.username
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    api_url = DEFAULT_API_URL

    try:
        agent = LangGraphAgent()
        monitor = ProgressMonitor()
    except Exception as e:
        yield _status_html(f"Agent 初始化失败: {e}"), f"Agent 初始化失败: {e}", pd.DataFrame()
        return

    try:
        resp = requests.get(f"{api_url}/questions", timeout=15)
        resp.raise_for_status()
        questions = resp.json()
    except Exception as e:
        yield _status_html(f"获取题目失败: {e}"), f"获取题目失败: {e}", pd.DataFrame()
        return
    if not questions:
        yield _status_html("没有题目"), "没有题目", pd.DataFrame()
        return

    # Only items with both a task id and a question are answered; sizing the
    # monitor on them lets the progress bar actually reach 100%.
    valid_items = [q for q in questions if q.get("task_id") and q.get("question", "")]
    monitor.start(len(valid_items))
    results_log = []
    answers_payload = []
    yield monitor.get_html(), "", pd.DataFrame()

    for item in valid_items:
        task_id = item["task_id"]
        question = item["question"]
        try:
            answer = agent(question, task_id=task_id)
        except Exception as e:
            answer = f"ERROR: {e}"
        answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        results_log.append({"Task ID": task_id, "Question": question, "Submitted Answer": answer})
        monitor.step(question, answer)
        yield monitor.get_html(), "", pd.DataFrame(results_log)

    if not answers_payload:
        yield monitor.get_html(), "没有答案可提交", pd.DataFrame(results_log)
        return

    submission = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    try:
        resp = requests.post(f"{api_url}/submit", json=submission, timeout=60)
        resp.raise_for_status()
        result = resp.json()
        final_status = (
            f"✅ 提交成功!\n"
            f"用户:{username}\n"
            f"总分:{result.get('score', 'N/A')}% "
            f"({result.get('correct_count', 0)}/{result.get('total_attempted', 0)} 正确)\n"
            f"消息:{result.get('message', '')}"
        )
    except Exception as e:
        final_status = f"提交失败: {e}"
    yield monitor.get_html(), final_status, pd.DataFrame(results_log)


# =============================================================================
# Gradio UI
# =============================================================================
with gr.Blocks(title="GAIA Agent") as demo:
    gr.Markdown("""
# 🤖 GAIA Level 1 Agent (LangGraph + Qwen)

**模型:** Qwen3.5-35B-A3B | **API:** agicto.com

点击按钮获取题目,Agent 可调用多个工具(最多3次)以获取答案,最后提交评分。

**工具:** 维基百科、网页搜索/抓取、图片分析、音频转录、YouTube字幕、文件下载。
""")
    gr.LoginButton()
    run_btn = gr.Button("🚀 运行评测并提交", variant="primary")
    progress_html = gr.HTML(label="实时进度")
    status_output = gr.Textbox(label="提交结果 / 总分", lines=5, interactive=False)
    results_table = gr.DataFrame(label="题目与 Agent 答案", wrap=True)
    run_btn.click(
        fn=run_and_submit_all,
        outputs=[progress_html, status_output, results_table],
    )


if __name__ == "__main__":
    if not AGICTO_API_KEY:
        print("❌ 错误:AGICTO_API_KEY 未设置!请在 Space 的 Settings -> Repository Secrets 中添加。")
    # Mirrors _get_api_base(): only a trailing "/v1" segment is stripped.
    if AGICTO_BASE_URL.rstrip("/").endswith("/v1"):
        print("⚠️ 提示:AGICTO_BASE_URL 不应包含 /v1,已自动去除。请考虑设置为 https://api.agicto.cn")
    print("启动 Gradio App...")
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)