# BocchiMing's picture
# Update agent.py
# 4c024a3 verified
import os
import json
import dotenv
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import ToolNode,tools_condition
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader,ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool
from langchain_core.tools import tool
from supabase.client import Client, create_client
from langchain.chat_models import init_chat_model
import random
from typing import Annotated,TypedDict
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage,SystemMessage
from langgraph.graph.message import add_messages
# Load variables from a local .env file into the process environment.
load_dotenv()

# Read the metadata file: one JSON object per line.
# The encoding is pinned to UTF-8 (the JSON interchange encoding) so parsing
# does not depend on the platform's default locale, and blank lines (such as
# a trailing newline at end of file) are skipped instead of crashing
# json.loads.
with open('metadata.jsonl', 'r', encoding='utf-8') as jsonl_file:
    json_list = [line for line in jsonl_file if line.strip()]

json_QA = [json.loads(json_str) for json_str in json_list]

# Fixed seed so the same sample is drawn on every run.
# random.sample raises ValueError if the metadata file held no records.
random.seed(42)
random_samples = random.sample(json_QA, 1)

# Supabase credentials come from the environment; os.environ.get yields None
# for a missing variable, and create_client is left to reject that.
supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)
# Base instructions; the sampled worked examples are appended below.
system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
If the tool is not available, you can try to find the information online. You can also use your own knowledge to answer the question.
You need to provide a step-by-step explanation of how you arrived at the answer.
==========================
Here is a few examples showing you how to answer the question step by step.
"""
# Append one worked example (question, steps, tools, final answer) per sample.
for i, sample in enumerate(random_samples):
    system_prompt += f"\nQuestion {i+1}: {sample['Question']}\nSteps:\n{sample['Annotator Metadata']['Steps']}\nTools:\n{sample['Annotator Metadata']['Tools']}\nFinal Answer: {sample['Final answer']}\n"
system_prompt += "\n==========================\n"
system_prompt += "Now, please answer the following question step by step.And if you can, please answer in Vietnamese.\n"

# Save the system prompt to a file, then load it back.
# Both handles pin UTF-8: the sampled question text may contain non-ASCII
# characters, and the platform-default encoding could fail to encode them or
# decode them differently on the way back in.
with open('system_prompt.txt', 'w', encoding='utf-8') as f:
    f.write(system_prompt)
with open('system_prompt.txt', 'r', encoding='utf-8') as f:
    system_prompt = f.read()
print(system_prompt)
# Embedding model handed to the vector store below.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
# Tavily API key read from the environment (None when the variable is unset).
tavily_key = os.getenv("TAVILY_API_KEY")

# Create or access the vector table
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents_langchain",
)

# Build the retriever once and reuse it for the tool, instead of calling
# as_retriever() a second time and leaving the first result unused.
retriever = vector_store.as_retriever()

# NOTE: this assignment rebinds the imported factory name to the tool it
# returns, so the factory cannot be called again afterwards. The name is kept
# because the `tools` list refers to the tool as `create_retriever_tool`.
create_retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="Question_Retriever",
    description="Find similar questions in the vector database for the given question.",
)
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers
    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the @tool decorator uses it as the
    # tool description.
    product = a * b
    return product
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers:
    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the @tool decorator uses it as the
    # tool description.
    difference = a - b
    return difference
@tool
def add(a: int, b: int) -> int:
    """Add two numbers
    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the @tool decorator uses it as the
    # tool description.
    total = a + b
    return total
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.
    Args:
        a: first int
        b: second int
    """
    # `/` is true division and always produces a float (4 / 2 == 2.0), so the
    # return annotation is float rather than the original int.
    # A zero divisor raises ZeroDivisionError, exactly as before.
    return a / b
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    # Docstring wording is kept as-is: the @tool decorator uses it as the
    # tool description. b == 0 raises ZeroDivisionError.
    remainder = a % b
    return remainder
@tool
def wiki_search(query: str) -> dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results.
    Args:
        query: The search query."""
    # The return annotation now matches what is actually returned: a one-key
    # dict ('wiki_results') whose value is the formatted documents, not a
    # bare string. The returned value itself is unchanged.
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()

    # One <Document> section per hit, separated by a horizontal rule.
    sections = []
    for doc in search_docs:
        source = doc.metadata["source"]
        page = doc.metadata.get("page", "")
        sections.append(
            f'<Document source="{source}" page="{page}"/>\n{doc.page_content}\n</Document>'
        )
    formatted_search_docs = "\n\n---\n\n".join(sections)
    return {'wiki_results': formatted_search_docs}
@tool
def web_search(query: str) -> dict[str, str]:
    """Search Tavily for a query and return maximum 3 results.
    Args:
        query: The search query."""
    search_tool = TavilySearchResults(max_results=3, tavily_api_key=tavily_key)

    # invoke() takes the tool input as its first positional argument. The
    # original call passed only a `query=` keyword, leaving that required
    # argument unset, so every call failed with a TypeError.
    search_results = search_tool.invoke({"query": query})

    # NOTE(review): the Tavily tool appears to hand back an error string
    # (rather than raising) when the request fails - confirm against the
    # installed langchain_community version. Pass such a string through
    # instead of iterating over its characters.
    if isinstance(search_results, str):
        return {"web_results": search_results}

    # Each hit is a plain dict, not a Document, so it has no
    # .metadata / .page_content attributes. Presumably the keys are "url" and
    # "content" - verify; missing keys fall back to an empty string.
    sections = [
        f'<Document source="{hit.get("url", "")}" page=""/>\n{hit.get("content", "")}\n</Document>'
        for hit in search_results
    ]
    formatted_search_docs = "\n\n---\n\n".join(sections)
    return {"web_results": formatted_search_docs}
@tool
def arvix_search(query: str) -> dict[str, str]:
    """Search Arxiv for a query and return maximum 3 result.
    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()

    sections = []
    for doc in search_docs:
        metadata = doc.metadata
        # NOTE(review): ArxivLoader metadata does not appear to carry a
        # "source" key (it exposes fields such as "Title" and, optionally,
        # "entry_id"), so indexing metadata["source"] raised KeyError.
        # Fall back through the identifiers that may be present - confirm the
        # key names against the installed loader.
        source = (
            metadata.get("source")
            or metadata.get("entry_id")
            or metadata.get("Title", "")
        )
        page = metadata.get("page", "")
        # Only the first 1000 characters of each paper are included.
        sections.append(
            f'<Document source="{source}" page="{page}"/>\n{doc.page_content[:1000]}\n</Document>'
        )
    formatted_search_docs = "\n\n---\n\n".join(sections)
    return {"arvix_results": formatted_search_docs}
# Tools offered to the model, grouped by purpose. The resulting order is the
# same as before: multiply, add, subtract, divide, modulus, the three search
# tools, then the question retriever.
_arithmetic_tools = [multiply, add, subtract, divide, modulus]
_search_tools = [wiki_search, web_search, arvix_search]
tools = [*_arithmetic_tools, *_search_tools, create_retriever_tool]
def build_graph():
    """Build and compile the agent graph.

    Flow: START -> "retriever" -> "assistant"; from "assistant",
    ``tools_condition`` routes to "tools" when the reply contains a tool call
    (and "tools" feeds back into "assistant"), otherwise to END.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.

    Raises:
        KeyError: If the ``GOOGLE_API_KEY`` environment variable is not set.
    """
    llm = init_chat_model("google_genai:gemini-2.0-flash", google_api_key=os.environ["GOOGLE_API_KEY"])
    llm_with_tools = llm.bind_tools(tools)
    sys_msg = SystemMessage(content=system_prompt)

    class MessagesState(TypedDict):
        # Conversation so far; node updates are combined via add_messages.
        messages: Annotated[list[AnyMessage], add_messages]

    # Nodes
    def assistant(state: MessagesState):
        """Assistant node: call the tool-enabled model on the conversation."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Retriever node: add the system prompt and a similar stored question."""
        similar_question = vector_store.similarity_search(state["messages"][0].content)
        answer_template = "NO MORE EXPLAIN, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]."
        if similar_question:
            content = f"Here I provide a question and answer using query for reference if it is similar to question below: \n\n{similar_question[0].page_content}\n\n{answer_template}"
        else:
            # No match in the vector store: indexing [0] would raise
            # IndexError, so send only the answer-format instruction.
            content = answer_template
        example_msg = HumanMessage(content=content)
        # NOTE(review): add_messages appears to merge messages by id, so
        # sys_msg may end up after the original question in the state rather
        # than first - confirm the chat model still treats it as the system
        # instruction.
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    # Build graph
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        # If tool call -> tools_condition routes to tools
        # If not a tool call -> tools_condition routes to END
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    return builder.compile()
# Manual smoke test: run one sample question through the agent and print the
# whole conversation.
if __name__ == "__main__":
    question = "What is the capital of Vietnam?"
    # Build the graph.
    # `builder` exists only inside build_graph(), so referencing it here
    # raised NameError; call the factory, which returns the compiled graph.
    graph = build_graph()
    # Run the graph
    messages = [HumanMessage(content=question)]
    messages = graph.invoke({"messages": messages})
    for m in messages["messages"]:
        m.pretty_print()