"""LangGraph question-answering agent.

Wires a set of LangChain tools (web / Wikipedia / Arxiv search plus basic
arithmetic) to an LLM (Groq, Gemini, or HuggingFace), backed by a Supabase
vector store used as a question-similarity retriever.  The currently active
graph is retriever-only; the assistant/tool loop is present but commented out.
"""

import cmath
import os
from typing import Dict, List, Sequence, TypedDict, cast  # noqa: F401 — kept for compatibility

from dotenv import load_dotenv
from langchain.tools.retriever import create_retriever_tool
from langchain_community.document_loaders import ArxivLoader, WikipediaLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import (
    ChatHuggingFace,
    HuggingFaceEmbeddings,
    HuggingFaceEndpoint,
)
from langchain_tavily import TavilySearch
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from pydantic import BaseModel
from supabase.client import Client, create_client

# Load environment variables from .env file
load_dotenv()


# --- Tool input schemas (currently unused by the @tool decorators, kept
# for callers that want explicit argument validation) -------------------------

class WebSearchInput(BaseModel):
    query: str


class WikipediaSearchInput(BaseModel):
    query: str


class ArxivSearchInput(BaseModel):
    query: str


# --- Search tools ------------------------------------------------------------

@tool
def search_web(query: str) -> dict:
    """Search the web using Tavily and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        dict with key ``web_results`` holding the formatted result text.
    """
    # TavilySearch.invoke returns a dict payload; the hits live under
    # the "results" key, each a dict with "url"/"content" fields.
    response = TavilySearch(max_results=3).invoke({"query": query})
    formatted_search_docs = "\n\n---\n\n".join(
        f'Source: {item.get("url", "")}\n{item.get("content", "")}'
        for item in response.get("results", [])
    )
    return {"web_results": formatted_search_docs}


@tool
def search_wikipedia(query: str) -> dict:
    """Search Wikipedia and return up to 2 article summaries.

    Args:
        query: The search query.

    Returns:
        dict with key ``wiki_results`` on success, or ``error`` on failure.
    """
    try:
        loader = WikipediaLoader(query=query, lang="en", load_max_docs=2)
        docs = loader.load()
        if not docs:
            return {"error": f"No Wikipedia articles found for query: {query}"}
        formatted_docs = "\n\n---\n\n".join(
            [f"Wikipedia Article: {query}\n\n{doc.page_content}" for doc in docs]
        )
        return {"wiki_results": formatted_docs}
    except Exception as e:
        # Best-effort tool: surface the failure to the model instead of raising.
        return {"error": f"Error searching Wikipedia: {str(e)}"}


@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query.

    Returns:
        dict with key ``arxiv_results`` holding the formatted result text.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Truncate each paper to its first 1000 characters to keep context small.
    formatted_search_docs = "\n\n---\n\n".join(
        [f"\n{doc.page_content[:1000]}\n" for doc in search_docs]
    )
    return {"arxiv_results": formatted_search_docs}


# --- Math tools --------------------------------------------------------------

@tool
def power(a: float, b: float) -> float:
    """Get the power of two numbers.

    Args:
        a (float): the first number
        b (float): the second number
    """
    return a**b


@tool
def square_root(a: float) -> float | complex:
    """Get the square root of a number.

    Returns a complex result for negative inputs.

    Args:
        a (float): the number to get the square root of
    """
    if a >= 0:
        return a**0.5
    return cmath.sqrt(a)


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    return a - b


@tool
def divide(a: float, b: float) -> float:
    """Divide two numbers.

    Args:
        a (float): the first float number
        b (float): the second float number

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    return a % b


# --- System prompt -----------------------------------------------------------
# NOTE(review): reconstructed line breaks — original whitespace was mangled.
system_prompt = SystemMessage(
    content="""You are a helpful assistant tasked with answering questions using a set of tools.
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.
Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
"""
)

# --- Vector store / retriever setup ------------------------------------------

supabase_url = os.environ.get("SUPABASE_URL")
supabase_service_key = os.environ.get("SUPABASE_SERVICE_KEY")

# Embedding model must match the dimension of the Supabase table (dim=768).
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768

supabase: Client = create_client(supabase_url, supabase_service_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Renamed from `create_retriever_tool` — the original assignment shadowed the
# imported factory function of the same name.
question_retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Tools exposed to the LLM (the retriever tool is deliberately not included;
# retrieval is a dedicated graph node instead).
tools = [
    search_wikipedia,
    search_web,
    arxiv_search,
    power,
    square_root,
    multiply,
    divide,
    subtract,
    add,
    modulus,
]


def build_agent_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    Args:
        provider: LLM backend — "groq", "gemini", or "huggingface".

    Returns:
        A compiled LangGraph graph (currently retriever-only).

    Raises:
        Exception: if the chosen LLM fails to initialize.
    """
    # Initialize LLM
    try:
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if provider == "groq":
            # Groq https://console.groq.com/docs/models
            chat_model = ChatGroq(
                model="qwen-qwq-32b", temperature=0
            )  # optional : qwen-qwq-32b gemma2-9b-it
        elif provider == "gemini":
            chat_model = ChatGoogleGenerativeAI(
                model="gemini-2.5-pro",
                temperature=1.0,
                max_retries=2,
                google_api_key=gemini_api_key,
            )
        elif provider == "huggingface":
            # Fixed keyword: HuggingFaceEndpoint takes `endpoint_url`, not `url`.
            llm = HuggingFaceEndpoint(
                endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
                temperature=0,
            )
            chat_model = ChatHuggingFace(llm=llm, verbose=True)
        else:
            raise ValueError("Invalid provider.")
    except Exception as e:
        raise Exception(f"Failed to initialize LLM: {str(e)}")

    llm_with_tools = chat_model.bind_tools(tools)

    # Create nodes
    def assistant(state: MessagesState):
        """Assistant node: one LLM step over the accumulated messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: answer from the closest stored question, if any."""
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)
        if not results:
            print(f"[retriever] No similar documents found for query: {query}")
            return {
                "messages": [
                    AIMessage(content="I couldn't find any similar content in memory.")
                ]
            }
        similar_doc = results[0]
        content = similar_doc.page_content
        # Stored documents may embed the answer after a "Final answer :" marker.
        if "Final answer :" in content:
            answer = content.split("Final answer :")[-1].strip()
        else:
            answer = content.strip()
        return {"messages": [AIMessage(content=answer)]}

    # Build graph — currently retriever-only; the assistant/tool loop is
    # preserved below for when the full agent path is re-enabled.
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    # builder.add_node("assistant", assistant)
    # builder.add_node("tools", ToolNode(tools))
    # builder.add_edge(START, "retriever")
    # builder.add_edge("retriever", "assistant")
    # builder.add_conditional_edges(
    #     "assistant",
    #     tools_condition,
    # )
    # builder.add_edge("tools", "assistant")
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")
    return builder.compile()


# Manual test function
def test_agent():
    """Run a manual smoke test of the agent against one fixed question."""
    print("\n" + "=" * 50)
    print("Starting Agent Test")
    print("=" * 50)

    # Check environment variables
    if not os.getenv("HUGGINGFACEHUB_API_TOKEN"):
        print("\nError: HUGGINGFACEHUB_API_TOKEN not set")
        return
    if not os.getenv("GEMINI_API_KEY"):
        print("\nError: GEMINI_API_KEY not set")
        return
    if not os.getenv("TAVILY_API_KEY"):
        print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")
    if not os.getenv("SUPABASE_URL"):
        # Fixed copy-paste: this affects the vector store, not web search.
        print("\nWarning: SUPABASE_URL not set - vector store will be unavailable")

    print("\nInitializing agent...")
    try:
        graph = build_agent_graph(provider="groq")
        print("Agent initialized successfully")
    except Exception as e:
        print(f"Failed to initialize agent: {str(e)}")
        return

    # Test a single question
    question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
    print("\nTesting question:", question)
    print("-" * 50)

    try:
        # Create messages state
        messages = [HumanMessage(content=question)]
        # Run agent
        print("\nWaiting for response...")
        result = graph.invoke({"messages": messages})
        # Get answer
        if result and "messages" in result and result["messages"]:
            answer = result["messages"][-1].content
            print("\nResponse received:")
            print("-" * 20)
            print(answer)
            print("-" * 20)
        else:
            print("\nError: No response from agent")
    except Exception as e:
        print(f"\nError processing question: {str(e)}")

    print("\n" + "=" * 50)
    print("Test Complete")
    print("=" * 50 + "\n")


# Run test if script is run directly
if __name__ == "__main__":
    test_agent()