Spaces:
Configuration error
Configuration error
| import cmath | |
| import os | |
| from typing import Dict, List, Sequence, TypedDict, cast | |
| from dotenv import load_dotenv | |
| from langchain.tools.retriever import create_retriever_tool | |
| from langchain_community.document_loaders import ArxivLoader, WikipediaLoader | |
| from langchain_community.vectorstores import SupabaseVectorStore | |
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
| from langchain_core.tools import tool | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_groq import ChatGroq | |
| from langchain_huggingface import ( | |
| ChatHuggingFace, | |
| HuggingFaceEmbeddings, | |
| HuggingFaceEndpoint, | |
| ) | |
| from langchain_tavily import TavilySearch | |
| from langgraph.graph import END, START, MessagesState, StateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| from pydantic import BaseModel | |
| from supabase.client import Client, create_client | |
# Load environment variables from a local .env file into the process
# environment so the os.environ / os.getenv lookups later in this module
# can see them.
load_dotenv()


# Pydantic input schemas, each with a single free-text ``query`` field;
# presumably one per search tool, judging by the names.
# NOTE(review): none of these models is referenced by the code in this
# module — confirm they are still needed.
class WebSearchInput(BaseModel):
    # Free-text search query.
    query: str


class WikipediaSearchInput(BaseModel):
    # Free-text search query.
    query: str


class ArxivSearchInput(BaseModel):
    # Free-text search query.
    query: str
def _format_tavily_results(response) -> str:
    """Render a Tavily response as ``<Document>`` blocks separated by ``---``.

    ``response`` may be the raw Tavily response dict (hits under
    ``"results"``) or a bare list of hit dicts. Each hit contributes its
    ``"url"`` and ``"content"`` values. Anything else renders as "".
    """
    if isinstance(response, dict):
        hits = response.get("results") or []
    elif isinstance(response, list):
        hits = response
    else:
        hits = []
    return "\n\n---\n\n".join(
        f'<Document source="{hit.get("url", "")}" page=""/>\n{hit.get("content", "")}\n</Document>'
        for hit in hits
        if isinstance(hit, dict)
    )


def search_web(query: str) -> Dict[str, str]:
    """Search the web using Tavily and return relevant results.

    At most 3 results are returned, formatted as ``<Document>`` blocks under
    the ``"web_results"`` key; failures are reported under ``"error"``.

    Args:
        query: The search query.
    """
    try:
        # Invoked with a plain dict, TavilySearch returns the raw Tavily
        # response dict (hits under "results"), not LangChain Documents,
        # so hits are read as dicts rather than via ``.metadata``.
        response = TavilySearch(max_results=3).invoke({"query": query})
    except Exception as e:
        return {"error": f"Error searching the web: {str(e)}"}
    if isinstance(response, dict) and response.get("error"):
        return {"error": f"Error searching the web: {response['error']}"}
    return {"web_results": _format_tavily_results(response)}
def _format_wikipedia_docs(docs, query: str) -> str:
    """Join Wikipedia documents with ``---`` separators, each headed by its title.

    Falls back to ``query`` as the heading when a document has no ``"title"``
    metadata.
    """
    return "\n\n---\n\n".join(
        f"Wikipedia Article: {doc.metadata.get('title', query)}\n\n{doc.page_content}"
        for doc in docs
    )


def search_wikipedia(query: str) -> Dict[str, str]:
    """Search Wikipedia and return the content of up to 2 matching articles.

    Results are returned under ``"wiki_results"``; when nothing is found or
    the lookup fails, a message is returned under ``"error"`` instead.

    Args:
        query: The search query.
    """
    try:
        docs = WikipediaLoader(query=query, lang="en", load_max_docs=2).load()
    except Exception as e:
        # Best-effort tool: report the failure to the model instead of raising.
        return {"error": f"Error searching Wikipedia: {str(e)}"}
    if not docs:
        return {"error": f"No Wikipedia articles found for query: {query}"}
    return {"wiki_results": _format_wikipedia_docs(docs, query)}
def _format_arxiv_docs(docs, max_chars: int = 1000) -> str:
    """Render arXiv documents as ``<Document>`` blocks separated by ``---``.

    Each document body is truncated to ``max_chars`` characters.
    """
    blocks = []
    for doc in docs:
        meta = doc.metadata
        # ArxivLoader's default metadata (Published/Title/Authors/Summary)
        # has no "source" key, so fall back to entry_id or the title rather
        # than raising KeyError.
        source = meta.get("source") or meta.get("entry_id") or meta.get("Title", "")
        blocks.append(
            f'<Document source="{source}" page="{meta.get("page", "")}"/>\n'
            f"{doc.page_content[:max_chars]}\n</Document>"
        )
    return "\n\n---\n\n".join(blocks)


def arxiv_search(query: str) -> Dict[str, str]:
    """Search Arxiv for a query and return at most 3 results.

    Results are returned under ``"arxiv_results"``; a failed lookup is
    reported under ``"error"`` instead of raising.

    Args:
        query: The search query.
    """
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
    except Exception as e:
        return {"error": f"Error searching Arxiv: {str(e)}"}
    return {"arxiv_results": _format_arxiv_docs(docs)}
def power(a: float, b: float) -> float:
    """Raise one number to the power of another.

    Args:
        a (float): the first number
        b (float): the second number

    Returns:
        ``a`` raised to the exponent ``b``.
    """
    # Two-argument pow() is the same operation as the ** operator.
    result = pow(a, b)
    return result
| def square_root(a: float) -> float | complex: | |
| """ | |
| Get the square root of a number. | |
| Args: | |
| a (float): the number to get the square root of | |
| """ | |
| if a >= 0: | |
| return a**0.5 | |
| return cmath.sqrt(a) | |
def multiply(a: int, b: int) -> int:
    """Return the product of two numbers.

    Args:
        a: first int
        b: second int
    """
    product = a * b
    return product
def add(a: int, b: int) -> int:
    """Return the sum of two numbers.

    Args:
        a: first int
        b: second int
    """
    total = a + b
    return total
def subtract(a: int, b: int) -> int:
    """Return the difference of two numbers (``a`` minus ``b``).

    Args:
        a: first int
        b: second int
    """
    difference = a - b
    return difference
def divide(a: float, b: float) -> float:
    """
    Divides two numbers.

    Args:
        a (float): the first float number (dividend)
        b (float): the second float number (divisor)

    Raises:
        ValueError: if ``b`` is zero.
    """
    if b == 0:
        # Raise a descriptive ValueError instead of a bare ZeroDivisionError.
        raise ValueError("Cannot divide by zero.")
    return a / b
def modulus(a: int, b: int) -> int:
    """Return the remainder of ``a`` divided by ``b`` (Python ``%`` semantics).

    Args:
        a: first int
        b: second int
    """
    remainder = a % b
    return remainder
# System prompt: asks the model to report its reasoning and finish with a
# strictly formatted "FINAL ANSWER: ..." line.
# NOTE(review): this message is not inserted into the graph's messages by
# the code in this module — confirm where it is meant to be used.
system_prompt = SystemMessage(
    content=(
        "You are a helpful assistant tasked with answering questions using a set of tools.\n"
        "Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:\n"
        "FINAL ANSWER: [YOUR FINAL ANSWER].\n"
        "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. "
        "If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. "
        "If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
        "If you are asked for a comma separated list, Apply the rules above for each element (number or string), ensure there is exactly one space after each comma.\n"
        'Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '
    )
)
# Supabase connection settings, read from the environment (or .env file).
supabase_url = os.environ.get("SUPABASE_URL")
supabase_service_key = os.environ.get("SUPABASE_SERVICE_KEY")
if not supabase_url or not supabase_service_key:
    # Fail fast with an actionable message instead of an opaque error from
    # inside the Supabase client.
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SERVICE_KEY must be set (e.g. in .env) "
        "to build the Supabase vector store."
    )

# Build a retriever over the Supabase "documents" table.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)  # dim=768
supabase: Client = create_client(supabase_url, supabase_service_key)
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Keep the tool under its own name instead of rebinding (and shadowing) the
# imported ``create_retriever_tool`` factory. The tool name is
# identifier-style because OpenAI-style function-calling APIs only accept
# names matching ^[a-zA-Z0-9_-]+$ (no spaces).
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="question_search",
    description="A tool to retrieve similar questions from a vector store.",
)

# Tools bound to the LLM in build_agent_graph.
# NOTE(review): ``retriever_tool`` is not part of this list — add it if the
# model should be able to call it directly.
tools = [
    search_wikipedia,
    search_web,
    arxiv_search,
    power,
    square_root,
    multiply,
    divide,
    subtract,
    add,
    modulus,
]
def _init_chat_model(provider: str):
    """Construct the chat model for ``provider``.

    Args:
        provider: "groq", "gemini" or "huggingface".

    Raises:
        ValueError: if ``provider`` is not one of the names above.
    """
    if provider == "groq":
        # Groq https://console.groq.com/docs/models
        # optional : qwen-qwq-32b gemma2-9b-it
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    if provider == "gemini":
        return ChatGoogleGenerativeAI(
            model="gemini-2.5-pro",
            temperature=1.0,
            max_retries=2,
            google_api_key=os.getenv("GEMINI_API_KEY"),
        )
    if provider == "huggingface":
        # HuggingFaceEndpoint takes a full inference URL as ``endpoint_url``;
        # ``url`` is not one of its fields.
        # NOTE(review): confirm this model path exists on the Hub (the
        # official Llama-2 chat weights are published under "meta-llama/").
        llm = HuggingFaceEndpoint(
            endpoint_url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0,
        )
        return ChatHuggingFace(llm=llm, verbose=True)
    raise ValueError("Invalid provider.")


def build_agent_graph(provider: str = "groq"):
    """Build and compile the agent graph.

    The chat model for ``provider`` is created and bound to ``tools``. As
    currently wired, though, the compiled graph runs only the "retriever"
    node, which answers with the closest match from ``vector_store``.

    Args:
        provider: "groq" (default), "gemini" or "huggingface".

    Returns:
        The compiled LangGraph graph.

    Raises:
        Exception: "Failed to initialize LLM: ..." when the provider is
            unknown or its client cannot be constructed. The underlying
            error is chained as ``__cause__``.
    """
    try:
        chat_model = _init_chat_model(provider)
    except Exception as e:
        raise Exception(f"Failed to initialize LLM: {str(e)}") from e
    llm_with_tools = chat_model.bind_tools(tools)

    def assistant(state: MessagesState):
        """Assistant node: invoke the tool-bound LLM on the conversation so far."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: reply with the most similar stored document."""
        query = state["messages"][-1].content
        results = vector_store.similarity_search(query, k=1)
        if not results:
            print(f"[retriever] No similar documents found for query: {query}")
            return {
                "messages": [
                    AIMessage(content="I couldn't find any similar content in memory.")
                ]
            }
        # Keep only the text after the last "Final answer :" marker. When the
        # marker is absent, split() yields the whole string, so this also
        # covers the no-marker case.
        answer = results[0].page_content.split("Final answer :")[-1].strip()
        return {"messages": [AIMessage(content=answer)]}

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    # NOTE: the tool-calling loop below is disabled, so the graph is a single
    # "retriever" step and ``assistant`` is currently unused.
    # builder.add_node("assistant", assistant)
    # builder.add_node("tools", ToolNode(tools))
    # builder.add_edge(START, "retriever")
    # builder.add_edge("retriever", "assistant")
    # builder.add_conditional_edges(
    #     "assistant",
    #     tools_condition,
    # )
    # builder.add_edge("tools", "assistant")
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")
    return builder.compile()
# Manual smoke test: run one sample question through the agent graph.
def test_agent(provider: str = "groq"):
    """Run a manual test of the agent on one sample question.

    Prints progress and the agent's answer to stdout and returns ``None``.
    Returns early, after printing an error, when the API key required by
    ``provider`` is missing or the graph cannot be built.

    Args:
        provider: LLM provider passed to ``build_agent_graph``.
    """
    print("\n" + "=" * 50)
    print("Starting Agent Test")
    print("=" * 50)

    # Check environment variables: only the key for the provider actually
    # being exercised is mandatory.
    provider_keys = {
        "groq": "GROQ_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "huggingface": "HUGGINGFACEHUB_API_TOKEN",
    }
    required_key = provider_keys.get(provider)
    if required_key and not os.getenv(required_key):
        print(f"\nError: {required_key} not set")
        return
    if not os.getenv("TAVILY_API_KEY"):
        print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")
    if not os.getenv("SUPABASE_URL") or not os.getenv("SUPABASE_SERVICE_KEY"):
        print(
            "\nWarning: SUPABASE_URL / SUPABASE_SERVICE_KEY not set - "
            "vector-store retrieval will be unavailable"
        )

    print("\nInitializing agent...")
    try:
        graph = build_agent_graph(provider=provider)
        print("Agent initialized successfully")
    except Exception as e:
        print(f"Failed to initialize agent: {str(e)}")
        return

    # Test a single question
    question = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
    print("\nTesting question:", question)
    print("-" * 50)
    try:
        messages = [HumanMessage(content=question)]
        print("\nWaiting for response...")
        result = graph.invoke({"messages": messages})
        if result and "messages" in result and result["messages"]:
            answer = result["messages"][-1].content
            print("\nResponse received:")
            print("-" * 20)
            print(answer)
            print("-" * 20)
        else:
            print("\nError: No response from agent")
    except Exception as e:
        print(f"\nError processing question: {str(e)}")

    print("\n" + "=" * 50)
    print("Test Complete")
    print("=" * 50 + "\n")


# Run test if script is run directly
if __name__ == "__main__":
    test_agent()