BocchiMing committed
Commit 50c0210 · verified · 1 Parent(s): ee3d024

Update agent.py

Files changed (1)
  1. agent.py +231 -228
agent.py CHANGED
@@ -1,228 +1,231 @@
import os
import json
import random
from typing import Annotated, TypedDict

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain.tools.retriever import create_retriever_tool
from langchain.chat_models import init_chat_model
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langchain_core.tools import tool
from supabase.client import Client, create_client

load_dotenv()

# Load the few-shot examples from metadata.jsonl (one JSON object per line).
with open('metadata.jsonl', 'r') as jsonl_file:
    json_list = list(jsonl_file)

json_QA = []
for json_str in json_list:
    json_QA.append(json.loads(json_str))

random.seed(42)
random_samples = random.sample(json_QA, 1)

supabase_url = os.environ.get("SUPABASE_URL")
supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
supabase: Client = create_client(supabase_url, supabase_key)

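# Each metadata.jsonl record is expected to carry at least the fields read
# below: 'Question', 'Final answer', and 'Annotator Metadata' with 'Steps'
# and 'Tools'. An illustrative (made-up) line:
# {"Question": "...", "Final answer": "...", "Annotator Metadata": {"Steps": "1. ...", "Tools": "..."}}
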
system_prompt = """
You are a helpful assistant tasked with answering questions using a set of tools.
If the tool is not available, you can try to find the information online. You can also use your own knowledge to answer the question.
You need to provide a step-by-step explanation of how you arrived at the answer.
==========================
Here are a few examples showing you how to answer the question step by step.
"""
for i, sample in enumerate(random_samples):
    system_prompt += f"\nQuestion {i+1}: {sample['Question']}\nSteps:\n{sample['Annotator Metadata']['Steps']}\nTools:\n{sample['Annotator Metadata']['Tools']}\nFinal Answer: {sample['Final answer']}\n"
system_prompt += "\n==========================\n"
system_prompt += "Now, please answer the following question step by step. And if you can, please answer in Vietnamese.\n"

# Save the system prompt to a file, then read it back.
with open('system_prompt.txt', 'w') as f:
    f.write(system_prompt)

with open('system_prompt.txt', 'r') as f:
    system_prompt = f.read()
print(system_prompt)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
tavily_key = os.getenv("TAVILY_API_KEY")

# Access the vector table in Supabase.
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",                  # table must already exist; it is not auto-created
    query_name="match_documents_langchain",  # matching SQL function defined in Supabase
)
retriever = vector_store.as_retriever()

# Named question_retriever_tool so it does not shadow the imported create_retriever_tool factory.
question_retriever_tool = create_retriever_tool(
    retriever=retriever,
    name="Question_Retriever",
    description="Find similar questions in the vector database for the given question.",
)

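# The "documents" table and its match function are assumed to already exist in
# Supabase. If the table were empty, a one-off seeding step along these lines
# (a sketch, not part of this commit; field names follow metadata.jsonl) would
# embed the questions so the retriever has something to match:
#
#   from langchain_core.documents import Document
#   docs = [
#       Document(page_content=f"{qa['Question']}\n\nFinal answer: {qa['Final answer']}")
#       for qa in json_QA
#   ]
#   SupabaseVectorStore.from_documents(
#       docs, embeddings, client=supabase,
#       table_name="documents", query_name="match_documents_langchain",
#   )
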
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers.
    Args:
        a: first int
        b: second int
    """
    return a * b

@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers.
    Args:
        a: first int
        b: second int
    """
    return a - b

@tool
def add(a: int, b: int) -> int:
    """Add two numbers.
    Args:
        a: first int
        b: second int
    """
    return a + b

@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers.
    Args:
        a: first int
        b: second int
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b

@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers.
    Args:
        a: first int
        b: second int
    """
    return a % b

@tool
def wiki_search(query: str) -> dict:
    """Search Wikipedia for a query and return maximum 2 results.

    Args:
        query: The search query."""
    search_docs = WikipediaLoader(
        query=query,
        load_max_docs=2,
    ).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        ]
    )
    return {'wiki_results': formatted_search_docs}

@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return maximum 3 results.

    Args:
        query: The search query."""
    # TavilySearchResults takes the query string as its invoke input and
    # returns a list of dicts with "url" and "content" keys, not Documents.
    search_docs = TavilySearchResults(max_results=3, tavily_api_key=tavily_key).invoke(query)
    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
            for doc in search_docs
        ]
    )
    return {"web_results": formatted_search_docs}

@tool
def arvix_search(query: str) -> dict:
    """Search Arxiv for a query and return maximum 3 results.

    Args:
        query: The search query."""
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    formatted_search_docs = "\n\n---\n\n".join(
        [
            # ArxivLoader metadata has no "source" key; fall back to the paper title.
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        ]
    )
    return {"arvix_results": formatted_search_docs}

tools = [
    multiply,
    add,
    subtract,
    divide,
    modulus,
    wiki_search,
    web_search,
    arvix_search,
    question_retriever_tool,
]

llm = init_chat_model("google_genai:gemini-2.0-flash", google_api_key=os.environ["GOOGLE_API_KEY"])
llm_with_tools = llm.bind_tools(tools)

sys_msg = SystemMessage(content=system_prompt)

class MessagesState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

# Nodes
def assistant(state: MessagesState):
    """Assistant node"""
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

def retriever(state: MessagesState):
    """Retriever node: supply the system prompt plus a similar solved question."""
    similar_question = vector_store.similarity_search(state["messages"][0].content)
    if not similar_question:  # nothing similar in the store; pass the question through
        return {"messages": [sys_msg] + state["messages"]}
    example_msg = HumanMessage(
        content=f"For reference, here is a similar question with its answer; use it only if it matches the question below:\n\n{similar_question[0].page_content}",
    )
    # With the add_messages reducer, messages already in state keep their position;
    # sys_msg and example_msg are appended as new messages.
    return {"messages": [sys_msg] + state["messages"] + [example_msg]}

# Build graph
builder = StateGraph(MessagesState)
builder.add_node("retriever", retriever)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "retriever")
builder.add_edge("retriever", "assistant")
builder.add_conditional_edges(
    "assistant",
    # If the latest assistant message is a tool call, tools_condition routes to "tools";
    # otherwise it routes to END.
    tools_condition,
)
builder.add_edge("tools", "assistant")

# Compile graph
graph = builder.compile()

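# Resulting flow: START -> retriever -> assistant -> (tools -> assistant)* -> END
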
- question = "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?"
- messages = [HumanMessage(content=question)]
- messages = graph.invoke({"messages": messages})
-
- for m in messages['messages']:
-     m.pretty_print()
 
 
 
 
+ if __name__ == "__main__":
+     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+     # Build the graph
+     graph = builder.compile()
+     # Run the graph
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke({"messages": messages})
+     for m in messages["messages"]:
+         m.pretty_print()
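
For anyone reproducing this commit, the script reads four secrets from the environment (or a .env file picked up by load_dotenv()). A minimal preflight check, using exactly the variable names referenced in agent.py:

import os
for var in ("SUPABASE_URL", "SUPABASE_SERVICE_KEY", "TAVILY_API_KEY", "GOOGLE_API_KEY"):
    assert os.getenv(var), f"missing required environment variable: {var}"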