grshot committed on
Commit
a58a2c6
·
1 Parent(s): c5a9cfd

Include manual testing

Browse files
Files changed (3) hide show
  1. .env.sample +7 -0
  2. agent.py +77 -3
  3. requirements.txt +12 -14
.env.sample ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # Required API Keys
2
+ GROQ_API_KEY=your-groq-api-key-here
3
+ TAVILY_API_KEY=your-tavily-api-key-here
4
+
5
+ # Optional Environment Variables
6
+ PYTHONPATH=.
7
+ DEBUG=false
agent.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
- from typing import Dict
3
 
 
4
  from langchain_community.document_loaders import WikipediaLoader
5
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
6
  from langchain_core.tools import tool
@@ -128,9 +129,22 @@ def build_agent_graph(provider: str = "groq"):
128
  llm_with_tools = llm.bind_tools(tools)
129
 
130
  # Create nodes
131
- def assistant(state: MessagesState):
132
  """Assistant node"""
133
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
  # Build graph
136
  builder = StateGraph(MessagesState)
@@ -143,3 +157,63 @@ def build_agent_graph(provider: str = "groq"):
143
  builder.add_edge("assistant", END)
144
 
145
  return builder.compile()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ from typing import Dict, List, Sequence, TypedDict, cast
3
 
4
+ from dotenv import load_dotenv
5
  from langchain_community.document_loaders import WikipediaLoader
6
  from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
7
  from langchain_core.tools import tool
 
129
  llm_with_tools = llm.bind_tools(tools)
130
 
131
  # Create nodes
132
def assistant(state: MessagesState) -> Dict[str, List[AIMessage]]:
    """Assistant node: invoke the tool-bound LLM on the conversation state.

    Always returns a dict with a single-element "messages" list containing
    an AIMessage, even on failure, so the graph never receives a raised
    exception from this node.
    """
    try:
        history = state.get("messages", [])
        if not history:
            # Nothing to respond to — surface an explicit error message.
            return {"messages": [AIMessage(content="Error: No messages found")]}

        reply = llm_with_tools.invoke(history)
        # Normalize: guarantee the node always emits an AIMessage.
        if not isinstance(reply, AIMessage):
            reply = AIMessage(content=str(reply))
        return {"messages": [reply]}

    except Exception as e:
        # Never raise out of the graph node; report the failure as a message.
        return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
148
 
149
  # Build graph
150
  builder = StateGraph(MessagesState)
 
157
  builder.add_edge("assistant", END)
158
 
159
  return builder.compile()
160
+
161
+
162
+ # Manual test function
163
def test_agent():
    """Run a manual smoke test of the agent on one sample question.

    Loads credentials from a local .env file, verifies required API keys,
    builds the agent graph, asks a single hard question, and prints the
    final answer. All outcomes (missing keys, init failure, invocation
    failure) are reported to stdout rather than raised.
    """
    # Pull API keys from a local .env file, if present.
    load_dotenv()

    banner = "=" * 50
    print("\n" + banner)
    print("Starting Agent Test")
    print(banner)

    # Verify credentials before doing any work; GROQ is mandatory,
    # TAVILY only degrades web search when absent.
    if not os.getenv("GROQ_API_KEY"):
        print("\nError: GROQ_API_KEY not set")
        return
    if not os.getenv("TAVILY_API_KEY"):
        print("\nWarning: TAVILY_API_KEY not set - web search will be unavailable")

    print("\nInitializing agent...")
    try:
        compiled_graph = build_agent_graph()
    except Exception as err:
        print(f"Failed to initialize agent: {str(err)}")
        return
    print("Agent initialized successfully")

    # One representative GAIA-style question.
    question = "What is the first name of the only Malko Competition recipient from the 20th Century (after 1977) whose nationality on record is a country that no longer exists?"
    print("\nTesting question:", question)
    print("-" * 50)

    try:
        print("\nWaiting for response...")
        outcome = compiled_graph.invoke(
            {"messages": [system_prompt, HumanMessage(content=question)]}
        )

        # The final message in the state carries the agent's answer.
        if outcome and "messages" in outcome and outcome["messages"]:
            divider = "-" * 20
            print("\nResponse received:")
            print(divider)
            print(outcome["messages"][-1].content)
            print(divider)
        else:
            print("\nError: No response from agent")

    except Exception as err:
        print(f"\nError processing question: {str(err)}")

    print("\n" + banner)
    print("Test Complete")
    print(banner + "\n")


# Run test if script is run directly
if __name__ == "__main__":
    test_agent()
requirements.txt CHANGED
@@ -1,15 +1,13 @@
1
- gradio>=4.0.0
2
- requests>=2.31.0
3
- langchain>=0.1.0
4
- langchain-core>=0.1.0
5
- langchain-community>=0.0.10
6
- langchain_huggingface>=0.0.10
7
- langchain-groq>=0.0.5
8
- langchain-experimental>=0.0.40
9
- langchain-tavily>=0.0.5
10
- langgraph>=0.0.15
11
- tavily-python>=0.3.0
12
- wikipedia>=1.4.0
13
  youtube-transcript-api==0.6.3
14
- pytube>=15.0.0
15
- pydantic>=2.0.0
 
1
+ langchain
2
+ langchain-core
3
+ langchain-community
4
+ langchain-groq
5
+ langchain-tavily
6
+ langgraph
7
+ python-dotenv
8
+ wikipedia
9
+ requests
10
+ pydantic
11
+ lxml
 
12
  youtube-transcript-api==0.6.3
13
+ pytube>=15.0.0