import os
import json
import random
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool
from langchain_core.tools import tool
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain.chat_models import init_chat_model
from supabase.client import Client, create_client

load_dotenv()

# Load the question/answer metadata used to build few-shot examples.
with open('metadata.jsonl', 'r') as jsonl_file:
    json_QA = [json.loads(line) for line in jsonl_file]

random.seed(42)
random_samples = random.sample(json_QA, 1)

supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)

system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
If a tool is not available, you can try to find the information online. You can
also use your own knowledge to answer the question. You need to provide a
step-by-step explanation of how you arrived at the answer.
==========================
Here are a few examples showing you how to answer the question step by step.
"""

for i, sample in enumerate(random_samples):
    system_prompt += (
        f"\nQuestion {i+1}: {sample['Question']}\n"
        f"Steps:\n{sample['Annotator Metadata']['Steps']}\n"
        f"Tools:\n{sample['Annotator Metadata']['Tools']}\n"
        f"Final Answer: {sample['Final answer']}\n"
    )

system_prompt += "\n==========================\n"
system_prompt += "Now, please answer the following question step by step. If you can, please answer in Vietnamese.\n"

# Save the system prompt to a file, then read it back for use below.
with open('system_prompt.txt', 'w') as f:
    f.write(system_prompt)
with open('system_prompt.txt', 'r') as f:
    system_prompt = f.read()
print(system_prompt)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
tavily_key = os.getenv("TAVILY_API_KEY")

# Create or connect to the vector table in Supabase.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)
retriever = vector_store.as_retriever()

# Renamed from `create_retriever_tool` so the imported factory is not shadowed.
question_retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="Question_Retriever",
    description="Find similar questions in the vector database for the given question.",
)


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.

    Args:
        a: first int
        b: second int
    """
    return a * b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.

    Args:
        a: first int
        b: second int
    """
    return a - b


@tool
def add(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a: first int
        b: second int
    """
    return a + b


@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.

    Args:
        a: first int
        b: second int
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.

    Args:
        a: first int
        b: second int
    """
    return a % b
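# NOTE (assumption): SupabaseVectorStore does not create its backing objects.
# Per LangChain's Supabase setup guide, a `documents` table with an
# `embedding vector(768)` column (all-mpnet-base-v2 produces 768-dimensional
# vectors) and a SQL function matching the `query_name` above
# ("match_documents_langchain") must already exist in the Supabase project.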
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return a maximum of 2 results.

    Args:
        query: The search query.
    """
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
    return "\n\n---\n\n".join(f"\n{doc.page_content}\n" for doc in search_docs)


@tool
def web_search(query: str) -> str:
    """Search Tavily for a query and return a maximum of 3 results.

    Args:
        query: The search query.
    """
    # TavilySearchResults returns a list of dicts with "url" and "content" keys,
    # not Document objects, so we read doc["content"] here.
    search_docs = TavilySearchResults(max_results=3, tavily_api_key=tavily_key).invoke({"query": query})
    return "\n\n---\n\n".join(f"\n{doc['content']}\n" for doc in search_docs)


@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for a query and return a maximum of 3 results.

    Args:
        query: The search query.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    return "\n\n---\n\n".join(f"\n{doc.page_content[:1000]}\n" for doc in search_docs)


tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arxiv_search,
    question_retriever_tool,
]


def build_graph():
    """Build the agent graph: retriever -> assistant <-> tools."""
    llm = init_chat_model("google_genai:gemini-2.0-flash", google_api_key=os.environ["GOOGLE_API_KEY"])
    llm_with_tools = llm.bind_tools(tools)
    sys_msg = SystemMessage(content=system_prompt)

    class MessagesState(TypedDict):
        messages: Annotated[list[AnyMessage], add_messages]

    # Nodes
    def assistant(state: MessagesState):
        """Assistant node: call the tool-enabled LLM on the current messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever_node(state: MessagesState):
        """Retriever node: prepend the system prompt and a similar solved example."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        example_msg = HumanMessage(
            content=(
                "Here is a reference question and answer that may be similar to the question below:\n\n"
                f"{similar_question[0].page_content}\n\n"
                "Do not add further explanation, and finish your answer with the following template: "
                "FINAL ANSWER: [YOUR FINAL ANSWER]."
            ),
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If the assistant made a tool call, tools_condition routes to "tools";
        # otherwise it routes to END.
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()


if __name__ == "__main__":
    question = "What is the capital of Vietnam?"
    # Build the graph (the original called builder.compile() on a builder that
    # is local to build_graph and out of scope here).
    graph = build_graph()
    # Run the graph
    messages = graph.invoke({"messages": [HumanMessage(content=question)]})
    for m in messages["messages"]:
        m.pretty_print()
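# Optional sanity checks (a sketch; assumes the env vars above are set).
# Uncomment to exercise individual tools without running the full graph:
# print(multiply.invoke({"a": 6, "b": 7}))       # structured tools accept dict args
# print(wiki_search.invoke({"query": "Hanoi"}))  # formatted Wikipedia excerpts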