Spaces:
Sleeping
Sleeping
| import os | |
| import gradio as gr | |
| import requests | |
| import json | |
| import re | |
| import tempfile | |
| import base64 | |
| import io | |
| import time | |
| import threading | |
| from typing import TypedDict, Annotated, Sequence, List, Dict, Any, Generator | |
| from datetime import datetime | |
| import operator | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ToolMessage, SystemMessage | |
| from langchain_core.tools import tool | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.prebuilt import ToolNode | |
| from langchain_core.utils.function_calling import convert_to_openai_function | |
| from bs4 import BeautifulSoup | |
| from youtube_transcript_api import YouTubeTranscriptApi | |
| # ============================================================================= | |
| # 配置常量 | |
| # ============================================================================= | |
# Base URL of the scoring service; the code below appends /questions, /submit
# and /files/{task_id} to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# Base URL of the chat-completions provider; can be overridden via the environment.
AGICTO_BASE_URL = os.getenv("AGICTO_BASE_URL", "https://api.agicto.cn")
# Bearer token sent to the provider; empty string when the variable is unset.
AGICTO_API_KEY = os.getenv("AGICTO_API_KEY", "")
# Model identifier placed in the "model" field of every chat request.
QWEN_MODEL = "qwen3.5-35b-a3b"
| # ============================================================================= | |
| # 进度监控器(不变) | |
| # ============================================================================= | |
class ProgressMonitor:
    """Lock-protected progress tracker that renders its state as an HTML card.

    Attributes are updated through ``start``/``step`` and read through
    ``get_html``; every access happens while holding ``_lock``.
    """

    def __init__(self):
        self.current = 0          # number of questions completed so far
        self.total = 0            # number of questions in the current run
        self.last_question = ""   # most recently answered question
        self.last_answer = ""     # answer produced for ``last_question``
        self.logs = []            # one human-readable line per completed question
        self._lock = threading.Lock()

    def start(self, total: int):
        """Reset all counters and displayed text for a run of ``total`` questions."""
        with self._lock:
            self.total = total
            self.current = 0
            # Also clear the preview so a previous run's text is not shown at 0%.
            self.last_question = ""
            self.last_answer = ""
            self.logs = []

    def step(self, question: str, answer: str):
        """Record one finished question together with the answer that was produced."""
        with self._lock:
            self.current += 1
            self.last_question = question
            self.last_answer = answer
            self.logs.append(f"✅ 第 {self.current}/{self.total} 题完成:{answer[:50]}...")

    def get_html(self) -> str:
        """Return an HTML snippet with a progress bar, the latest Q/A and the log.

        Question, answer and log text originate from the remote question set and
        from model output, so they are HTML-escaped before being interpolated.
        """
        # Imported locally so the method does not rely on a module-level import.
        from html import escape

        with self._lock:
            pct = int(self.current / self.total * 100) if self.total > 0 else 0
            question = self.last_question
            # Truncate first, then escape, so an entity is never cut in half.
            question_preview = escape(question[:100]) + ("..." if len(question) > 100 else "")
            answer_text = escape(str(self.last_answer))
            log_text = escape(chr(10).join(self.logs))
            html = f"""
            <div style="border:1px solid #ddd; padding:10px; border-radius:8px; background:#fafafa;">
            <h3>📊 实时进度</h3>
            <div style="background:#eee; height:20px; border-radius:10px; margin-bottom:10px;">
            <div style="width:{pct}%; background:#4CAF50; height:100%; border-radius:10px; text-align:center; color:white; font-size:12px; line-height:20px;">
            {pct}% ({self.current}/{self.total})
            </div>
            </div>
            <p><b>最新题目:</b> {question_preview}</p>
            <p><b>答案:</b> <span style="color:#2e7d32;">{answer_text}</span></p>
            <details>
            <summary>详细日志</summary>
            <pre style="background:#f5f5f5; padding:10px; border-radius:4px; max-height:200px; overflow:auto;">{log_text}</pre>
            </details>
            </div>
            """
        return html
| # ============================================================================= | |
| # Qwen LLM 封装(不变) | |
| # ============================================================================= | |
class QwenLLM:
    """Small chat-completions client that accepts and returns LangChain messages.

    The client posts to ``{base_url}/v1/chat/completions`` and converts the
    first choice of the response into an ``AIMessage``. Only the first tool
    call of a response is honoured.
    """

    def __init__(self, model=QWEN_MODEL):
        self.model = model
        self.api_key = AGICTO_API_KEY
        # Normalise the base URL: no trailing slash and no trailing "/v1",
        # because "/v1/..." is appended when the request URL is built.
        base = AGICTO_BASE_URL.rstrip('/')
        if base.endswith('/v1'):
            base = base[:-3]
        self.base_url = base
        if not self.api_key:
            print("⚠️ 未设置 AGICTO_API_KEY,请检查环境变量")

    def _call_api(self, messages: list, functions: list = None, max_tokens=2000):
        """POST one chat request and return the decoded JSON body.

        Args:
            messages: Already formatted role/content dictionaries.
            functions: Optional function schemas, sent as ``tools``.
            max_tokens: Completion length limit.

        Returns:
            The parsed JSON response, or ``None`` when the request or the
            decoding fails (the error is printed, not raised).
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        body = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": max_tokens
        }
        if functions:
            body["tools"] = [{"type": "function", "function": f} for f in functions]
            body["tool_choice"] = "auto"
        url = f"{self.base_url}/v1/chat/completions"
        try:
            resp = requests.post(url, headers=headers, json=body, timeout=60)
            resp.raise_for_status()
            return resp.json()
        except Exception as e:
            # Deliberate best-effort: callers turn None into a failure message.
            print(f"API 调用失败: {e}")
            return None

    @staticmethod
    def _to_ai_message(result) -> AIMessage:
        """Convert a raw chat-completions response into an ``AIMessage``.

        A missing/empty response or an unexpected response shape yields the
        fixed failure message instead of raising.
        """
        if not result:
            return AIMessage(content="模型调用失败")
        try:
            msg = result["choices"][0]["message"]
        except (KeyError, IndexError, TypeError):
            return AIMessage(content="模型调用失败")

        # The API may send "content": null (typically next to tool calls);
        # AIMessage needs a string, so fall back to "".
        content = msg.get("content") or ""
        raw_calls = msg.get("tool_calls")
        if not raw_calls:
            return AIMessage(content=content)

        call = raw_calls[0]
        name = call["function"]["name"]
        arguments = call["function"].get("arguments") or "{}"
        if isinstance(arguments, dict):
            parsed_args = arguments
            arguments = json.dumps(arguments)
        else:
            try:
                parsed_args = json.loads(arguments)
            except (TypeError, ValueError):
                parsed_args = {}
            if not isinstance(parsed_args, dict):
                parsed_args = {}

        return AIMessage(
            content=content,
            # Kept for the routing code in this file, which looks for "function_call".
            additional_kwargs={
                "function_call": {
                    "name": name,
                    "arguments": arguments
                }
            },
            # Structured form; the prebuilt tool node reads tool requests from
            # this attribute rather than from additional_kwargs.
            tool_calls=[{
                "name": name,
                "args": parsed_args,
                "id": call.get("id") or "call_1"
            }]
        )

    def invoke(self, messages: list) -> AIMessage:
        """Send ``messages`` without any tool schema and return the reply."""
        formatted = self._format_messages(messages)
        return self._to_ai_message(self._call_api(formatted))

    def bind_functions(self, functions: list):
        """Return a wrapper whose ``invoke`` always sends ``functions`` as tools."""
        class BoundLLM:
            def __init__(self, llm, funcs):
                self.llm = llm
                self.functions = funcs

            def invoke(self, messages: list) -> AIMessage:
                formatted = self.llm._format_messages(messages)
                result = self.llm._call_api(formatted, functions=self.functions)
                return self.llm._to_ai_message(result)

        return BoundLLM(self, functions)

    def _format_messages(self, messages: list) -> list:
        """Translate LangChain messages into role/content dictionaries.

        Message types other than system/human/AI/tool are dropped.
        """
        formatted = []
        for m in messages:
            if isinstance(m, SystemMessage):
                formatted.append({"role": "system", "content": m.content})
            elif isinstance(m, HumanMessage):
                formatted.append({"role": "user", "content": m.content})
            elif isinstance(m, AIMessage):
                entry = {"role": "assistant", "content": m.content}
                function_call = getattr(m, "additional_kwargs", {}).get("function_call")
                if function_call:
                    # Reuse the id recorded on the message so the matching
                    # tool reply can reference the same call.
                    structured = getattr(m, "tool_calls", None) or []
                    call_id = (structured[0].get("id") if structured else None) or "call_1"
                    entry["tool_calls"] = [{
                        "id": call_id,
                        "type": "function",
                        "function": function_call
                    }]
                formatted.append(entry)
            elif isinstance(m, ToolMessage):
                formatted.append({
                    "role": "tool",
                    "tool_call_id": getattr(m, "tool_call_id", None) or "call_1",
                    "content": m.content
                })
        return formatted
| # ============================================================================= | |
| # 工具定义(同之前,包含 search_wikipedia 等) | |
| # ============================================================================= | |
# Base URL from which download_file_for_task fetches task attachments.
api_url_tasks = DEFAULT_API_URL
def _get_api_base():
    """Return AGICTO_BASE_URL without trailing slashes or a trailing '/v1'."""
    trimmed = AGICTO_BASE_URL.rstrip('/')
    # Drop the three characters of "/v1" when present; callers append it themselves.
    return trimmed[:-3] if trimmed.endswith('/v1') else trimmed
def web_search(query: str) -> str:
    """Search the web through the DuckDuckGo instant-answer endpoint.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        query: Free-text search query.

    Returns:
        The abstract (if any) followed by up to three related-topic texts,
        one per line; a fixed "not found" message when nothing is returned;
        or an error message when the request fails.
    """
    try:
        url = "https://api.duckduckgo.com/"
        params = {"q": query, "format": "json", "no_html": 1}
        resp = requests.get(url, params=params, timeout=10)
        # Surface HTTP errors explicitly instead of failing later in JSON decoding.
        resp.raise_for_status()
        data = resp.json()
        parts = []
        if data.get("AbstractText"):
            parts.append(f"摘要: {data['AbstractText']}")
        # Only the first three entries are inspected; entries without "Text" are skipped.
        parts.extend(
            topic["Text"]
            for topic in data.get("RelatedTopics", [])[:3]
            if isinstance(topic, dict) and "Text" in topic
        )
        return "\n".join(parts) if parts else "未找到相关信息"
    except Exception as e:
        # Tools report failures as text so the agent loop keeps running.
        return f"搜索失败: {e}"
def web_scraper(url: str) -> str:
    """Fetch a web page and return its visible text.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        url: Address of the page to download.

    Returns:
        The page text with script/style/nav/footer elements removed, collapsed
        onto one line and limited to 5000 characters, or an error message when
        the download or parsing fails.
    """
    try:
        headers = {"User-Agent": "Mozilla/5.0"}
        resp = requests.get(url, headers=headers, timeout=15)
        # Do not return the body of an error page as if it were the content.
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        for el in soup(["script", "style", "nav", "footer"]):
            el.decompose()
        text = soup.get_text()
        lines = [line.strip() for line in text.splitlines() if line.strip()]
        return " ".join(lines)[:5000]
    except Exception as e:
        # Tools report failures as text so the agent loop keeps running.
        return f"抓取失败: {e}"
def calculator(expression: str) -> str:
    """Evaluate a mathematical expression and return the result as text.

    Names from the ``math`` module (``sqrt``, ``pi``, ``log`` ...) are
    available; Python builtins are not.

    Args:
        expression: Expression such as ``"2 * (3 + 4)"`` or ``"sqrt(16)"``.

    Returns:
        ``str(result)`` on success, otherwise an error message.
    """
    try:
        import math

        # SECURITY: this uses eval() on text produced by the model. Emptying
        # __builtins__ alone does not sandbox eval, because attribute chains
        # such as ().__class__.__mro__ can still reach arbitrary objects.
        # Rejecting double-underscore names closes that route; a dedicated
        # expression parser would be the stricter long-term replacement.
        if "__" in expression:
            raise ValueError("double-underscore names are not allowed")
        allowed = {k: v for k, v in math.__dict__.items() if not k.startswith("__")}
        result = eval(expression, {"__builtins__": {}}, allowed)
        return str(result)
    except Exception as e:
        return f"计算失败: {e}"
def analyze_image(image_data: str) -> str:
    """Ask the chat model to describe an image.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        image_data: An http(s) URL, a ``data:`` URI, or bare base64 image
            data. Bare base64 is wrapped as a JPEG data URI.

    Returns:
        The model's description, or an error message containing the HTTP
        status or exception text.
    """
    try:
        headers = {"Authorization": f"Bearer {AGICTO_API_KEY}", "Content-Type": "application/json"}
        # URLs and ready-made data URIs are passed through untouched; wrapping
        # an existing data URI a second time would produce an invalid value.
        if not image_data.startswith(("http", "data:")):
            image_data = f"data:image/jpeg;base64,{image_data}"
        body = {
            "model": QWEN_MODEL,
            "messages": [{"role": "user", "content": [
                {"type": "text", "text": "请详细描述这张图片的内容,包括文字、数字等信息。"},
                {"type": "image_url", "image_url": {"url": image_data}}
            ]}],
            "max_tokens": 800
        }
        base = _get_api_base()
        url = f"{base}/v1/chat/completions"
        resp = requests.post(url, headers=headers, json=body, timeout=30)
        if resp.status_code == 200:
            return resp.json()["choices"][0]["message"]["content"]
        return f"图片分析失败: {resp.status_code}"
    except Exception as e:
        return f"图片分析失败: {e}"
def transcribe_audio(audio_path: str) -> str:
    """Transcribe an audio file through the provider's transcription endpoint.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        audio_path: Local file path, or an http(s) URL that is downloaded first.

    Returns:
        The transcribed text, or an error message containing the HTTP status
        or exception text.
    """
    try:
        headers = {"Authorization": f"Bearer {AGICTO_API_KEY}"}
        if audio_path.startswith("http"):
            download = requests.get(audio_path, timeout=30)
            # Do not upload an error page as if it were audio.
            download.raise_for_status()
            audio_data = io.BytesIO(download.content)
            audio_data.name = "audio.mp3"
        else:
            audio_data = open(audio_path, "rb")
        try:
            files = {"file": audio_data, "model": (None, "whisper-1")}
            base = _get_api_base()
            url = f"{base}/v1/audio/transcriptions"
            resp = requests.post(url, headers=headers, files=files, timeout=60)
        finally:
            # Always release the file handle / buffer, even when the upload fails.
            audio_data.close()
        if resp.status_code == 200:
            return resp.json()["text"]
        return f"转录失败: {resp.status_code}"
    except Exception as e:
        return f"转录失败: {e}"
def get_youtube_transcript(video_url: str) -> str:
    """Return the captions of a YouTube video as plain text.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        video_url: A YouTube link. Recognised forms are ``...?v=ID`` (in any
            query position), ``youtu.be/ID`` and ``/shorts/ID``, ``/embed/ID``,
            ``/live/ID`` paths.

    Returns:
        The caption text (English or Chinese) limited to 4000 characters, a
        fixed message when no video id can be found, or an error message.
    """
    try:
        # One pattern covers the supported URL shapes and stops the id at the
        # first character that cannot be part of it (&, ?, #, / ...).
        found = re.search(
            r"(?:[?&]v=|youtu\.be/|/(?:shorts|embed|live)/)([A-Za-z0-9_-]+)",
            video_url,
        )
        if not found:
            return "无法提取视频 ID"
        vid = found.group(1)
        # NOTE(review): newer youtube-transcript-api releases appear to replace
        # get_transcript with an instance-level fetch(); confirm the pinned version.
        transcript = YouTubeTranscriptApi.get_transcript(vid, languages=['en', 'zh'])
        return " ".join([t['text'] for t in transcript])[:4000]
    except Exception as e:
        return f"获取字幕失败: {e}"
def download_file_for_task(task_id: str) -> str:
    """Download the attachment of a task and return its content as text.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        task_id: Identifier of the task whose file should be fetched.

    Returns:
        For images, the description produced by ``analyze_image``; for audio,
        the transcript produced by ``transcribe_audio``; otherwise the first
        4000 characters of the response text. A message is returned when the
        file does not exist or the download fails.
    """
    try:
        url = f"{api_url_tasks}/files/{task_id}"
        resp = requests.get(url, timeout=20)
        if resp.status_code != 200:
            return f"文件不存在 (HTTP {resp.status_code})"
        content_type = resp.headers.get("content-type", "")
        if "image" in content_type:
            b64 = base64.b64encode(resp.content).decode()
            return analyze_image(b64)
        if "audio" in content_type:
            # delete=False because the transcriber reopens the file by path.
            with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
                f.write(resp.content)
                temp_path = f.name
            try:
                return transcribe_audio(temp_path)
            finally:
                # Remove the temp file even if transcription raises.
                os.unlink(temp_path)
        # NOTE(review): binary formats (spreadsheets, PDFs ...) also land here
        # and are returned as decoded text, which is unlikely to be useful.
        return resp.text[:4000]
    except Exception as e:
        return f"文件下载失败: {e}"
def search_wikipedia(query: str) -> str:
    """Look up a topic on English Wikipedia and return the article introduction.

    The docstring also serves as the tool description when this function is
    registered as an agent tool.

    Args:
        query: Search terms; the best-matching article title is used.

    Returns:
        ``"Wikipedia - <title>:"`` followed by up to 2000 characters of the
        plain-text introduction, a message when no page or no extract exists,
        or an error message when a request fails.
    """
    try:
        search_url = "https://en.wikipedia.org/w/api.php"
        # Wikimedia asks API clients to identify themselves with a descriptive agent.
        headers = {"User-Agent": "GAIA-Agent/1.0 (python-requests)"}
        params = {
            "action": "opensearch",
            "search": query,
            "limit": 1,
            "format": "json"
        }
        resp = requests.get(search_url, params=params, headers=headers, timeout=10)
        resp.raise_for_status()
        data = resp.json()
        # opensearch answers with [query, [titles], [descriptions], [urls]].
        titles = data[1]
        if not titles:
            return "维基百科未找到相关页面。"
        title = titles[0]
        extract_params = {
            "action": "query",
            "prop": "extracts",
            "exintro": True,
            "explaintext": True,
            # Follow redirects: the suggested title may be a redirect page,
            # which has no extract of its own.
            "redirects": 1,
            "titles": title,
            "format": "json"
        }
        resp2 = requests.get(search_url, params=extract_params, headers=headers, timeout=10)
        resp2.raise_for_status()
        data2 = resp2.json()
        pages = data2.get("query", {}).get("pages", {})
        for page_info in pages.values():
            extract = page_info.get("extract", "")
            if extract:
                return f"Wikipedia - {title}:\n{extract[:2000]}"
        return f"维基百科页面 '{title}' 未提供摘要。"
    except Exception as e:
        return f"维基百科搜索失败: {e}"
| # ============================================================================= | |
| # LangGraph 状态与节点(允许多次工具调用,最大3次) | |
| # ============================================================================= | |
class AgentState(TypedDict):
    """State dictionary passed between the graph nodes defined in this file."""
    # Conversation so far; annotated with operator.add as the combining function
    # for the message lists that nodes return.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Answer text written by finish_node.
    final_answer: str
    # Identifier of the task being answered; interpolated into the system prompt.
    task_id: str
    tool_attempts: int  # number of tool calls used so far
# Plain functions offered to the model as tools.
tools = [search_wikipedia, web_search, web_scraper, calculator,
         analyze_image, transcribe_audio, get_youtube_transcript, download_file_for_task]
# Graph node built from the tool functions above.
tool_node = ToolNode(tools)
llm = QwenLLM()
# Function schemas derived from the tool functions, sent with every agent request.
functions = [convert_to_openai_function(t) for t in tools]
llm_with_tools = llm.bind_functions(functions)
MAX_TOOL_CALLS = 3  # maximum number of tool calls allowed
def agent_node(state: AgentState) -> dict:
    """Ask the tool-enabled model for its next message.

    The system prompt is rebuilt on every call (it embeds the task id) and is
    placed in front of the accumulated conversation. The reply is returned as
    a one-element ``messages`` update.
    """
    history = state["messages"]
    task_id = state.get("task_id", "")
    # System prompt: steer the model towards tools, but require a bare final answer (no chit-chat).
    sys_prompt = f"""You are a helpful assistant answering GAIA Level 1 questions.
You can use the following tools to find information:
- search_wikipedia: search Wikipedia for facts.
- web_search: general web search.
- web_scraper: fetch content from a URL.
- download_file_for_task: download a file associated with the current task (task_id: {task_id}). This can handle images, audio, and Python/text files.
- analyze_image: describe an image given a URL or base64 data.
- transcribe_audio: transcribe audio from a path or URL.
- get_youtube_transcript: get captions from a YouTube video.
- calculator: evaluate a mathematical expression.
Instructions:
1. Use the most appropriate tool(s) to gather the information needed to answer the question.
2. If you need to follow up (e.g., search then scrape a specific page), you may use another tool.
3. Once you have enough information, output ONLY the final answer as a short string (a word, number, date, or phrase). Do NOT include explanations, greetings, or the phrase "FINAL ANSWER:".
4. If after using tools you still cannot find the answer, output exactly: "Unable to determine answer: insufficient information."
5. Do not make up an answer; only respond based on the information you retrieved.
Current task ID: {task_id}."""
    prompt_message = SystemMessage(content=sys_prompt)
    reply = llm_with_tools.invoke([prompt_message, *history])
    return {"messages": [reply]}
def should_continue(state: AgentState) -> str:
    """Choose the node that follows the agent node.

    Returns:
        ``"finish"`` when the tool budget is used up, when the model answered
        after at least one tool result, or when reminding the model has not
        helped; ``"tools"`` when the last message requests a tool;
        ``"force_tool"`` when no tool has been used yet and a reminder may
        still be issued.
    """
    # Upper bound on "please use a tool" reminders. Without it, a model that
    # never requests a tool (or an API that keeps failing) would bounce between
    # the reminder node and the agent until the graph's recursion limit is hit.
    max_reminders = 2

    messages = state["messages"]
    last = messages[-1]
    tool_attempts = state.get("tool_attempts", 0)
    # Budget exhausted: go straight to the final answer.
    if tool_attempts >= MAX_TOOL_CALLS:
        return "finish"
    # The model requested a tool: execute it.
    if hasattr(last, "additional_kwargs") and "function_call" in last.additional_kwargs:
        return "tools"
    # No tool result yet: insist on a tool call, but only a bounded number of times.
    tool_msg_count = sum(1 for m in messages if isinstance(m, ToolMessage))
    if tool_msg_count == 0:
        # The conversation starts with one human message (the question);
        # every further human message is a reminder that was already sent.
        reminders_sent = sum(1 for m in messages if isinstance(m, HumanMessage)) - 1
        if reminders_sent < max_reminders:
            return "force_tool"
        return "finish"
    # Otherwise the model has produced its answer.
    return "finish"
def force_tool_node(state: AgentState) -> dict:
    """Append a human-role reminder telling the model to call a tool."""
    reminder = "You haven't used any tool yet. Please use an appropriate tool to find the answer."
    return {"messages": [HumanMessage(content=reminder)]}
def increment_tool_count(state: AgentState) -> dict:
    """Return a state update that raises ``tool_attempts`` by one (missing counts as 0)."""
    used_so_far = state.get("tool_attempts", 0)
    return {"tool_attempts": used_so_far + 1}
def finish_node(state: AgentState) -> dict:
    """Extract a concise final answer from the last message of the conversation.

    Steps:
      1. Drop inline ``<think>...</think>`` blocks and tolerate non-text content.
      2. If the text contains the standard "Unable to determine answer" marker,
         return the line that carries it.
      3. Remove a leading "FINAL ANSWER:" prefix.
      4. For long or question-like text, keep the first sentence that looks
         like a statement; fall back to the first 100 characters.
      5. Replace an empty or failed reply with an explanatory message.

    Returns:
        ``{"final_answer": <text>}``.
    """
    last = state["messages"][-1]
    raw = last.content
    # Content may be missing or structured; only plain text can be post-processed.
    content = raw if isinstance(raw, str) else ""
    # Some deployments of reasoning models inline their chain of thought in
    # <think> tags; it must never be submitted as part of the answer.
    content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

    # Standard failure message: return the line that contains it, which is not
    # necessarily the first line of the reply.
    marker = "Unable to determine answer"
    if marker in content:
        marker_line = next(line for line in content.split("\n") if marker in line)
        return {"final_answer": marker_line.strip()}

    # Strip a possible prefix.
    answer = content.split("FINAL ANSWER:")[-1].strip()

    # Long or question-like text: keep the first statement-like sentence.
    if len(answer) > 50 or "?" in answer:
        sentences = re.split(r'(?<=[.!?])\s+', answer)
        for s in sentences:
            s = s.strip()
            if s and "?" not in s and not s.startswith(("Let me", "I ", "You ", "Please")):
                answer = s
                break
        else:
            # No usable sentence found: fall back to a hard truncation.
            answer = answer[:100].strip()

    # Still empty or a failed model call: explain why there is no answer.
    if not answer or answer in ("模型调用失败",):
        if state.get("tool_attempts", 0) >= MAX_TOOL_CALLS:
            answer = "Unable to determine answer: maximum tool calls reached."
        else:
            answer = "Unable to determine answer: insufficient information."
    return {"final_answer": answer}
def build_graph():
    """Assemble and compile the agent state graph.

    Flow: ``agent`` routes through ``should_continue`` to ``tools``,
    ``force_tool`` or ``finish``; ``tools`` leads to ``count_tools`` and back
    to ``agent``; ``force_tool`` leads back to ``agent``; ``finish`` ends.
    """
    workflow = StateGraph(AgentState)

    node_table = (
        ("agent", agent_node),
        ("tools", tool_node),
        ("force_tool", force_tool_node),
        ("count_tools", increment_tool_count),
        ("finish", finish_node),
    )
    for node_name, handler in node_table:
        workflow.add_node(node_name, handler)

    workflow.set_entry_point("agent")

    # Each routing label maps to the node of the same name.
    routes = {label: label for label in ("tools", "force_tool", "finish")}
    workflow.add_conditional_edges("agent", should_continue, routes)

    fixed_edges = (
        ("tools", "count_tools"),   # count the call once the tool has run ...
        ("count_tools", "agent"),   # ... then let the agent continue thinking
        ("force_tool", "agent"),    # after the reminder the agent decides again
        ("finish", END),            # finish terminates the run
    )
    for source, target in fixed_edges:
        workflow.add_edge(source, target)

    return workflow.compile()
| # ============================================================================= | |
| # Agent 类 | |
| # ============================================================================= | |
class LangGraphAgent:
    """Callable wrapper that answers one question by running the compiled graph."""

    def __init__(self):
        self.graph = build_graph()
        print("LangGraphAgent 初始化完成,使用模型:", QWEN_MODEL)

    def __call__(self, question: str, task_id: str = "") -> str:
        """Run the graph for ``question`` and return the final answer text.

        Any exception raised while the graph runs is printed and returned as
        an ``"Error: ..."`` string instead of propagating.
        """
        initial_state = dict(
            messages=[HumanMessage(content=question)],
            final_answer="",
            task_id=task_id,
            tool_attempts=0,
        )
        try:
            outcome = self.graph.invoke(initial_state)
            return outcome["final_answer"]
        except Exception as e:
            print(f"Agent 运行失败: {e}")
            return f"Error: {e}"
| # ============================================================================= | |
| # 主运行函数(生成器,实时进度) | |
| # ============================================================================= | |
| import pandas as pd | |
def run_and_submit_all(profile: gr.OAuthProfile | None) -> Generator:
    """Fetch all questions, answer them one by one and submit the answers.

    This is a generator: after every step it yields a
    ``(progress_html, status_text, results_dataframe)`` triple so the UI can
    refresh while the run is in progress.

    Args:
        profile: The logged-in user's profile, or ``None`` when nobody is
            logged in (the run is then refused).
    """
    from html import escape  # exception text is shown inside HTML below

    space_id = os.getenv("SPACE_ID")
    if not profile:
        yield "<div>请先登录</div>", "", pd.DataFrame()
        return
    username = profile.username
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    api_url = DEFAULT_API_URL

    try:
        agent = LangGraphAgent()
        monitor = ProgressMonitor()
    except Exception as e:
        yield f"<div>Agent 初始化失败: {escape(str(e))}</div>", f"Agent 初始化失败: {e}", pd.DataFrame()
        return

    try:
        resp = requests.get(f"{api_url}/questions", timeout=15)
        resp.raise_for_status()
        questions = resp.json()
        if not questions:
            yield "<div>没有题目</div>", "没有题目", pd.DataFrame()
            return
    except Exception as e:
        yield f"<div>获取题目失败: {escape(str(e))}</div>", f"获取题目失败: {e}", pd.DataFrame()
        return

    monitor.start(len(questions))
    results_log = []
    answers_payload = []
    yield monitor.get_html(), "", pd.DataFrame()

    for item in questions:
        task_id = item.get("task_id")
        question = item.get("question", "")
        # Entries without an id or question text cannot be answered or submitted.
        if not task_id or not question:
            continue
        try:
            answer = agent(question, task_id=task_id)
        except Exception as e:
            answer = f"ERROR: {e}"
        answers_payload.append({"task_id": task_id, "submitted_answer": answer})
        results_log.append({"Task ID": task_id, "Question": question, "Submitted Answer": answer})
        monitor.step(question, answer)
        yield monitor.get_html(), "", pd.DataFrame(results_log)

    if not answers_payload:
        yield monitor.get_html(), "没有答案可提交", pd.DataFrame(results_log)
        return

    submission = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    try:
        resp = requests.post(f"{api_url}/submit", json=submission, timeout=60)
        resp.raise_for_status()
        result = resp.json()
        final_status = (
            f"✅ 提交成功!\n"
            f"用户:{username}\n"
            f"总分:{result.get('score', 'N/A')}% "
            f"({result.get('correct_count', 0)}/{result.get('total_attempted', 0)} 正确)\n"
            f"消息:{result.get('message', '')}"
        )
    except Exception as e:
        final_status = f"提交失败: {e}"
    yield monitor.get_html(), final_status, pd.DataFrame(results_log)
| # ============================================================================= | |
| # Gradio 界面 | |
| # ============================================================================= | |
# Page layout: intro text, login button, run button and three output widgets.
with gr.Blocks(title="GAIA Agent") as demo:
    gr.Markdown("""
# 🤖 GAIA Level 1 Agent (LangGraph + Qwen)
**模型:** Qwen3.5-35B-A3B | **API:** agicto.com
点击按钮获取题目,Agent 可调用多个工具(最多3次)以获取答案,最后提交评分。
**工具:** 维基百科、网页搜索/抓取、图片分析、音频转录、YouTube字幕、文件下载。
""")
    gr.LoginButton()
    run_btn = gr.Button("🚀 运行评测并提交", variant="primary")
    progress_html = gr.HTML(label="实时进度")
    status_output = gr.Textbox(label="提交结果 / 总分", lines=5, interactive=False)
    results_table = gr.DataFrame(label="题目与 Agent 答案", wrap=True)
    # No explicit inputs are wired; the three outputs line up with the triples
    # yielded by run_and_submit_all.
    run_btn.click(
        fn=run_and_submit_all,
        outputs=[progress_html, status_output, results_table]
    )
# Script entry point: print configuration warnings, then start the web app.
if __name__ == "__main__":
    if not AGICTO_API_KEY:
        print("❌ 错误:AGICTO_API_KEY 未设置!请在 Space 的 Settings -> Repository Secrets 中添加。")
    # NOTE(review): this is a substring test, so it fires for any URL containing
    # "v1", not only for a trailing "/v1" segment.
    if "v1" in AGICTO_BASE_URL:
        print("⚠️ 提示:AGICTO_BASE_URL 不应包含 /v1,已自动去除。请考虑设置为 https://api.agicto.cn")
    print("启动 Gradio App...")
    # Listen on all interfaces on port 7860, with request queuing enabled.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)