import os
from langchain_openai import ChatOpenAI
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated, Literal
from langchain_community.tools import BraveSearch  # web search
from langchain_experimental.tools.python.tool import PythonAstREPLTool  # for logic/math problems
from tools import (calculator_basic, datetime_tools, transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel)
from prompt import system_prompt
from langchain_core.runnables import RunnableConfig  # for LangSmith tracking

# LangSmith observes the agent; tracing is picked up from these environment variables
langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
langsmith_tracing = os.getenv("LANGSMITH_TRACING")
# Default LLM, served through Groq's OpenAI-compatible endpoint
llm = ChatOpenAI(
    base_url="https://api.groq.com/openai/v1",
    api_key=os.getenv("GROQ_API_KEY"),
    model="openai/gpt-oss-120b",  # the model must support function calling (tool use)
    temperature=1
)
python_tool = PythonAstREPLTool()
search_tool = BraveSearch.from_api_key(
    api_key=os.getenv("BRAVE_SEARCH_API"),
    search_kwargs={"count": 4},  # return the 4 best results and their URLs
    description="Web search using Brave"
)

community_tools = [search_tool, python_tool]
custom_tools = calculator_basic + datetime_tools + [transcribe_audio, transcribe_youtube, query_image, webpage_content, read_excel]
tools = community_tools + custom_tools

# Index the tools by name so the tool node can dispatch calls
tools_by_name = {tool.name: tool for tool in tools}
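# For illustration, the mapping looks roughly like the dict below; the exact keys
# come from each tool's `name` attribute (the community tool names shown here are
# an assumption, not verified against the installed versions):
#   {"brave_search": search_tool, "python_repl_ast": python_tool, ...}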
class MessagesState(TypedDict):  # the state: acts as the agent's memory at any moment
    messages: Annotated[list[AnyMessage], add_messages]
    llm_with_tools: object  # the tool-bound LLM, injected by the caller at invoke time
# LLM node
def llm_call(state: MessagesState):
    """Runs the LLM on the system prompt plus the conversation so far."""
    return {
        "messages": [
            state["llm_with_tools"].invoke(
                [SystemMessage(content=system_prompt)] + state["messages"]
            )
        ]
    }
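# Note: only the new AIMessage is returned; the add_messages reducer merges it
# into the existing history instead of overwriting it. The system prompt is
# re-prepended on every turn but never stored in the state itself.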
# Tool node
def tool_node(state: MessagesState):
    """Executes the tools"""
    result = []
    for tool_call in state["messages"][-1].tool_calls:  # the list of tools the LLM decided to call
        tool = tools_by_name[tool_call["name"]]  # look up the actual tool function using the dictionary
        observation = tool.invoke(tool_call["args"])  # execute the tool
        result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))  # the result from the tool is added to the memory
    return {"messages": result}  # thanks to add_messages, LangGraph automatically appends the results to the agent's message history
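# Each entry in `tool_calls` follows LangChain's tool-call schema, e.g.
# (the values here are illustrative):
#   {"name": "brave_search", "args": {"query": "..."}, "id": "call_abc123"}
# The ToolMessage must echo the same id as tool_call_id so the model can pair
# each result with the request that produced it.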
# Conditional edge function to route to the tool node or end, based upon whether the LLM made a tool call
def should_continue(state: MessagesState) -> Literal["Action", END]:
    """Decide if we should continue the loop or stop, based upon whether the LLM made a tool call"""
    last_message = state["messages"][-1]  # looks at the last message (usually from the LLM)
    # If the LLM made a tool call, then perform an action
    if last_message.tool_calls:
        return "Action"
    # Otherwise, we stop (reply to the user)
    return END
# Build workflow
builder = StateGraph(MessagesState)

# Add nodes
builder.add_node("llm_call", llm_call)
builder.add_node("environment", tool_node)

# Add edges to connect nodes
builder.add_edge(START, "llm_call")
builder.add_conditional_edges(
    "llm_call",
    should_continue,
    {"Action": "environment",  # name returned by should_continue : name of the next node
     END: END}
)
# If tool calls -> "Action" -> environment (executes the tool)
# If no tool calls -> END
builder.add_edge("environment", "llm_call")  # after running the tools, go back to the LLM for another round of reasoning

gaia_agent = builder.compile()  # converts the builder into a runnable agent, used via gaia_agent.invoke()
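# The compiled graph forms a ReAct-style loop:
#   START -> llm_call -> environment -> llm_call -> ... -> END
# where should_continue routes to "environment" while the LLM keeps requesting
# tools, and to END once it answers directly.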
# Wrapper class to initialize and call the LangGraph agent with a user question
class LangGraphAgent:
    def __init__(self):
        # Several Groq API keys, rotated across calls to spread out rate limits
        self.api_keys = [
            os.getenv("GROQ_API_KEY"),
            os.getenv("GROQ_API_KEY_1"),
            os.getenv("GROQ_API_KEY_2"),
            os.getenv("GROQ_API_KEY_3")
        ]
        self.key_index = 0
        print("LangGraphAgent initialized.")
    def __call__(self, question: str) -> str:
        # Round-robin over the API keys
        api_key = self.api_keys[self.key_index]
        self.key_index = (self.key_index + 1) % len(self.api_keys)

        llm = ChatOpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=api_key,
            model="openai/gpt-oss-120b",
            temperature=1
        )
        llm_with_tools = llm.bind_tools(tools)

        input_state = {"messages": [HumanMessage(content=question)], "llm_with_tools": llm_with_tools}  # prepare the initial user message
        print(f"Running LangGraphAgent with input: {question[:150]}...")

        # Tracing configuration for LangSmith (RunnableConfig is a TypedDict, so the
        # settings are passed as keyword arguments, not wrapped in a config= dict)
        config = RunnableConfig(
            run_name="GAIA Agent",
            tags=["gaia", "langgraph", "agent"],
            metadata={"user_input": question},
            recursion_limit=30  # prevents infinite looping when the LLM keeps calling tools over and over
        )

        result = gaia_agent.invoke(input_state, config)
        final_response = result["messages"][-1].content
        # str.split never raises when the marker is missing, so test for it explicitly
        if "FINAL ANSWER:" in final_response:
            return final_response.split("FINAL ANSWER:")[-1].strip()  # parse out only what's after "FINAL ANSWER:"
        print("Could not find 'FINAL ANSWER:' in the response")
        return final_response
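# Minimal usage sketch (assumes the GROQ_API_KEY* and BRAVE_SEARCH_API
# environment variables are set; the question text is a hypothetical example):
if __name__ == "__main__":
    agent = LangGraphAgent()
    print(agent("What is the capital of France? Reply with 'FINAL ANSWER: <city>'."))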