import cmath
import os
from typing import Dict, List, Sequence, TypedDict, cast
from dotenv import load_dotenv
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import (
ChatHuggingFace,
HuggingFaceEmbeddings,
HuggingFaceEndpoint,
)
from langchain_tavily import TavilySearch
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel
from supabase.client import Client, create_client
# Load variables from a .env file into os.environ so that the os.environ.get /
# os.getenv lookups further down (Supabase, Gemini, Tavily, ...) can see them.
load_dotenv()
class WebSearchInput(BaseModel):
    """Input model holding a single ``query`` string for a web search.

    NOTE(review): not referenced by the visible code -- ``search_web`` declares
    its ``query`` argument directly in its signature; confirm it is still needed.
    """

    query: str
class WikipediaSearchInput(BaseModel):
    """Input model holding a single ``query`` string for a Wikipedia search.

    NOTE(review): not referenced by the visible code -- ``search_wikipedia``
    declares its ``query`` argument directly; confirm it is still needed.
    """

    query: str
class ArxivSearchInput(BaseModel):
    """Input model holding a single ``query`` string for an arXiv search.

    NOTE(review): not referenced by the visible code -- ``arxiv_search``
    declares its ``query`` argument directly; confirm it is still needed.
    """

    query: str
def _format_web_result(item) -> str:
    """Render one web-search hit as a text chunk for the model.

    Tavily hits are dicts carrying "title", "url" and "content" keys; any other
    item (a Document-like object or a plain string) is rendered from its
    ``page_content`` attribute, falling back to ``str()``.
    """
    if isinstance(item, dict):
        header = "\n".join(
            part for part in (item.get("title"), item.get("url")) if part
        )
        content = item.get("content", "")
        return f"\n{header}\n{content}\n" if header else f"\n{content}\n"
    return f"\n{getattr(item, 'page_content', str(item))}\n"


@tool
def search_web(query: str) -> Dict[str, str]:
    """Search the web using Tavily and return up to 3 relevant results.

    Args:
        query: The search query.

    Returns:
        {"web_results": <results separated by --->} on success, or
        {"error": <message>} when the search fails or finds nothing.
    """
    try:
        response = TavilySearch(max_results=3).invoke({"query": query})
    except Exception as e:
        # Tool boundary: report the failure (e.g. a missing TAVILY_API_KEY) to
        # the model as data instead of aborting the graph run, the same way
        # search_wikipedia reports its errors.
        return {"error": f"Error searching the web: {str(e)}"}

    # TavilySearch returns the raw API payload: a dict whose "results" key
    # holds the hits. Iterating that dict directly would yield its keys
    # (strings without ``page_content``). A bare list of hits is accepted too.
    if isinstance(response, dict):
        results = response.get("results") or []
    elif isinstance(response, list):
        results = response
    else:
        results = [response] if response else []

    if not results:
        return {"error": f"No web results found for query: {query}"}

    formatted_search_docs = "\n\n---\n\n".join(
        _format_web_result(item) for item in results
    )
    return {"web_results": formatted_search_docs}
@tool
def search_wikipedia(query: str) -> Dict[str, str]:
    """Search English Wikipedia and return the content of up to 2 matching articles.

    Args:
        query: The search query.

    Returns:
        {"wiki_results": <articles separated by --->} on success, or
        {"error": <message>} when loading fails or nothing matches.
    """
    try:
        docs = WikipediaLoader(query=query, lang="en", load_max_docs=2).load()
    except Exception as e:
        # Tool boundary: surface the failure to the model as data.
        return {"error": f"Error searching Wikipedia: {str(e)}"}

    if not docs:
        return {"error": f"No Wikipedia articles found for query: {query}"}

    # Label each article with its own title (falling back to the query) so the
    # model can tell the returned articles apart.
    formatted_docs = "\n\n---\n\n".join(
        f"Wikipedia Article: {doc.metadata.get('title', query)}\n\n{doc.page_content}"
        for doc in docs
    )
    return {"wiki_results": formatted_docs}
def _format_arxiv_doc(doc, max_chars: int = 1000) -> str:
    """Render one arXiv document: its title (when present) plus the first *max_chars* characters."""
    title = doc.metadata.get("Title")
    header = f"{title}\n" if title else ""
    return f"\n{header}{doc.page_content[:max_chars]}\n"


@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search arXiv for a query and return at most 3 results.

    Only the first 1000 characters of each paper's text are returned.

    Args:
        query: The search query.

    Returns:
        {"arxiv_results": <papers separated by --->} on success, or
        {"error": <message>} when loading fails or nothing matches.
    """
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    except Exception as e:
        # Tool boundary: report the failure to the model, consistent with
        # search_wikipedia, instead of raising out of the tool.
        return {"error": f"Error searching arXiv: {str(e)}"}

    if not search_docs:
        return {"error": f"No arXiv papers found for query: {query}"}

    formatted_search_docs = "\n\n---\n\n".join(
        _format_arxiv_doc(doc) for doc in search_docs
    )
    return {"arxiv_results": formatted_search_docs}
@tool
def power(a: float, b: float) -> float:
    """
    Raise a number to a power, returning a ** b.
    Args:
        a (float): the base
        b (float): the exponent
    Note: a negative base with a fractional exponent yields a complex number.
    """
    return pow(a, b)
@tool
def square_root(a: float) -> float | complex:
    """
    Compute the square root of a number.
    Non-negative inputs produce a real result; anything else goes through
    cmath.sqrt and produces a complex result.
    Args:
        a (float): the number whose square root is wanted
    """
    return a**0.5 if a >= 0 else cmath.sqrt(a)
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.
    Args:
        a: the first factor
        b: the second factor
    """
    product = a * b
    return product
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.
    Args:
        a: the first addend
        b: the second addend
    """
    total = a + b
    return total
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference a - b of two integers.
    Args:
        a: the minuend
        b: the subtrahend (value taken away from a)
    """
    difference = a - b
    return difference
@tool
def divide(a: float, b: float) -> float:
    """
    Divide two numbers, returning a / b.
    Args:
        a (float): the dividend
        b (float): the divisor; must be non-zero
    Raises:
        ValueError: if b is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b
@tool
def modulus(a: int, b: int) -> int:
    """Get the remainder of a divided by b (Python's % operator).
    The result takes the sign of b, e.g. -7 % 3 == 2.
    Args:
        a: the dividend
        b: the divisor; must be non-zero
    Raises:
        ValueError: if b is zero (same error type as divide).
    """
    if b == 0:
        raise ValueError("Cannot take the modulus by zero.")
    return a % b
# System prompt instructing the model to reason and then end its reply with a
# strictly formatted "FINAL ANSWER: ..." line.
# NOTE(review): build_agent_graph's assistant node sends only state["messages"]
# to the model and never prepends this message -- confirm whether it should.
system_prompt = SystemMessage(
    content="""You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. """
)
# Supabase credentials, read from the environment (which load_dotenv populates
# from .env at the top of this module).
supabase_url = os.environ.get("SUPABASE_URL")
supabase_service_key = os.environ.get("SUPABASE_SERVICE_KEY")

# Fail fast with an actionable message instead of passing None to
# create_client -- and before the embedding model below is loaded.
if not supabase_url or not supabase_service_key:
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SERVICE_KEY must be set to build the "
        "vector-store retriever."
    )

# Build a retriever over the Supabase "documents" table.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768
supabase: Client = create_client(supabase_url, supabase_service_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Bind the tool to its own name so the imported ``create_retriever_tool``
# factory is not shadowed. The tool name uses only [A-Za-z0-9_-]: names with
# spaces are rejected by OpenAI-compatible function-calling APIs (e.g. Groq).
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)
# Tools bound to the chat model in build_agent_graph via bind_tools (and used
# by the commented-out ToolNode wiring there). The order is kept as given.
# The vector-store retriever tool built above is not part of this list.
tools = [
    search_wikipedia,
    search_web,
    arxiv_search,
    power,
    square_root,
    multiply,
    divide,
    subtract,
    add,
    modulus,
]
# LLM providers accepted by build_agent_graph.
_SUPPORTED_PROVIDERS = ("groq", "gemini", "huggingface")


def _build_chat_model(provider: str):
    """Instantiate the chat model for one of the supported provider names."""
    if provider == "groq":
        # Groq https://console.groq.com/docs/models
        # optional : qwen-qwq-32b gemma2-9b-it
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "gemini":
        return ChatGoogleGenerativeAI(
            model="gemini-2.5-pro",
            temperature=1.0,
            max_retries=2,
            google_api_key=os.getenv("GEMINI_API_KEY"),
        )
    if provider == "huggingface":
        llm = HuggingFaceEndpoint(
            # HuggingFaceEndpoint takes the URL as ``endpoint_url``; it has no
            # ``url`` field, so passing ``url=`` leaves the endpoint unset.
            endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0,
        )
        return ChatHuggingFace(llm=llm, verbose=True)
    raise ValueError(f"Invalid provider: {provider!r}")


def build_agent_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    A chat model for ``provider`` is created and bound to ``tools``. The
    compiled graph currently consists of the single "retriever" node, which is
    both entry and finish point; the "assistant" node is defined but not added.

    Args:
        provider: One of "groq", "gemini" or "huggingface".

    Returns:
        The compiled LangGraph graph.

    Raises:
        ValueError: if ``provider`` is not supported.
        Exception: "Failed to initialize LLM: ..." when the chat model cannot
            be created; the original error is chained as ``__cause__``.
    """
    # Validate the argument up front so a bad provider surfaces as a plain
    # ValueError instead of being wrapped into a generic Exception.
    if provider not in _SUPPORTED_PROVIDERS:
        raise ValueError(
            f"Invalid provider: {provider!r}. "
            f"Expected one of: {', '.join(_SUPPORTED_PROVIDERS)}."
        )
    try:
        chat_model = _build_chat_model(provider)
    except Exception as e:
        raise Exception(f"Failed to initialize LLM: {str(e)}") from e
    llm_with_tools = chat_model.bind_tools(tools)

    # Create nodes
    def assistant(state: MessagesState):
        """Assistant node: run the tool-bound LLM over the conversation so far."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: answer with the most similar document in the vector store."""
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)
        if not results:
            print(f"[retriever] No similar documents found for query: {query}")
            return {
                "messages": [
                    AIMessage(content="I couldn't find any similar content in memory.")
                ]
            }
        content = results[0].page_content
        # When the document contains the "Final answer :" marker, keep only the
        # text after its last occurrence; otherwise return the whole document.
        if "Final answer :" in content:
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = content.strip()
        return {"messages": [AIMessage(content=answer)]}

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    # NOTE: alternative wiring (retriever -> assistant <-> tools), currently disabled:
    # builder.add_node("assistant", assistant)
    # builder.add_node("tools", ToolNode(tools))
    # builder.add_edge(START, "retriever")
    # builder.add_edge("retriever", "assistant")
    # builder.add_conditional_edges(
    #     "assistant",
    #     tools_condition,
    # )
    # builder.add_edge("tools", "assistant")
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")
    return builder.compile()
# Manual test function

# API-key environment variable needed by each LLM provider of build_agent_graph.
_PROVIDER_API_KEY_ENV = {
    "groq": "GROQ_API_KEY",
    "gemini": "GEMINI_API_KEY",
    "huggingface": "HUGGINGFACEHUB_API_TOKEN",
}


def test_agent(provider: str = "groq"):
    """Run a manual smoke test of the agent on one sample question.

    Progress and the agent's answer are printed to stdout; nothing is
    returned. The test stops early, after printing an error, if the selected
    provider's API key is missing or the graph cannot be built.

    Args:
        provider: LLM provider passed to ``build_agent_graph``.
    """
    print("\n" + "=" * 50)
    print("Starting Agent Test")
    print("=" * 50)

    # Only the API key of the provider actually used is mandatory.
    required_key = _PROVIDER_API_KEY_ENV.get(provider)
    if required_key and not os.getenv(required_key):
        print(f"\nError: {required_key} not set")
        return
    if not os.getenv("TAVILY_API_KEY"):
        print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")
    if not os.getenv("SUPABASE_URL"):
        print("\nWarning: SUPABASE_URL not set - vector store retrieval will be unavailable")

    print("\nInitializing agent...")
    try:
        graph = build_agent_graph(provider=provider)
        print("Agent initialized successfully")
    except Exception as e:
        print(f"Failed to initialize agent: {str(e)}")
        return

    # Test a single question
    question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
    print("\nTesting question:", question)
    print("-" * 50)
    try:
        # Create messages state
        messages = [HumanMessage(content=question)]
        # Run agent
        print("\nWaiting for response...")
        result = graph.invoke({"messages": messages})
        # Get answer
        if result and "messages" in result and result["messages"]:
            answer = result["messages"][-1].content
            print("\nResponse received:")
            print("-" * 20)
            print(answer)
            print("-" * 20)
        else:
            print("\nError: No response from agent")
    except Exception as e:
        print(f"\nError processing question: {str(e)}")

    print("\n" + "=" * 50)
    print("Test Complete")
    print("=" * 50 + "\n")
# Run the manual smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_agent()