File size: 4,464 Bytes
a9dc537 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
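"""
Smoke test for MemoryAgent's ChromaDB-backed memory: stores a sample episode,
domain-knowledge documents and a stakeholder profile, then queries each store
and reports what it finds. Can also be run directly as a script.
"""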
import asyncio
from src.llm.langchain_ollama_client import get_langchain_client
from src.agents.memory_agent import create_memory_agent
from src.workflow.langgraph_state import ScenarioType
async def test_memory_agent():
    """Smoke-test MemoryAgent's ChromaDB-backed memory stores end to end.

    Builds a LangChain client and a MemoryAgent, writes one episode, two
    domain-knowledge documents and one stakeholder profile, then queries each
    store and checks that the freshly written data can be found again.

    Side effects: the sample records (task ``test_001``, stakeholder
    ``Dr. Jane Smith``, the two knowledge documents) are written through the
    agent and are not removed afterwards.

    NOTE(review): when collected by pytest, an ``async def`` test needs an
    asyncio plugin (e.g. pytest-asyncio) — confirm the project config.

    Raises:
        AssertionError: if a retrieval that should match the data stored
            above comes back empty, or a collection count is lower than the
            number of records just stored.
    """
    # NOTE(review): the "β" prefix in the messages below looks like a
    # mis-encoded check mark; the strings are kept exactly as they were.
    print("Testing MemoryAgent with ChromaDB...")
    print()

    # Initialize LangChain client
    client = get_langchain_client(default_complexity='standard', enable_monitoring=False)
    print("β LangChain client initialized")

    # Create MemoryAgent
    memory = create_memory_agent(llm_client=client)
    print("β MemoryAgent created")
    print()

    # Test 1: Collection stats. This is only a baseline — collections may
    # already hold data from earlier runs, so no exact counts are asserted.
    print("Test 1: ChromaDB collections")
    stats = memory.get_collection_stats()
    print(f" β Episodic memory: {stats['episodic_count']} episodes")
    print(f" β Semantic memory: {stats['semantic_count']} documents")
    print(f" β Stakeholder profiles: {stats['stakeholders_count']} profiles")
    print()

    # Test 2: Store episode
    print("Test 2: Store episode")
    await memory.store_episode(
        task_id="test_001",
        task_description="Analyze patent for commercialization",
        scenario=ScenarioType.PATENT_WAKEUP,
        workflow_steps=[
            {"id": "step1", "description": "Extract patent claims"},
            {"id": "step2", "description": "Identify market opportunities"},
        ],
        outcome={"success": True, "matches": 3},
        quality_score=0.92,
        execution_time=45.3,
        iterations_used=1,
    )
    print(" β Episode stored successfully")
    print()

    # Test 3: Store knowledge
    print("Test 3: Store domain knowledge")
    knowledge_docs = [
        "Patents typically include claims, description, drawings, and abstract.",
        "GDPR requires explicit consent for personal data processing.",
    ]
    await memory.store_knowledge(
        documents=knowledge_docs,
        metadatas=[
            {"source": "patent_guide", "topic": "patent_structure"},
            {"source": "gdpr_regulation", "topic": "data_protection"},
        ],
        category="best_practices",
    )
    print(" β Knowledge stored successfully")
    print()

    # Test 4: Store stakeholder profile
    print("Test 4: Store stakeholder profile")
    await memory.store_stakeholder_profile(
        name="Dr. Jane Smith",
        profile={
            "expertise": ["AI", "Drug Discovery", "Machine Learning"],
            "interests": ["Pharmaceutical AI", "Clinical Trials"],
            "technologies": ["Neural Networks", "NLP", "Computer Vision"],
            "location": "Montreal, QC",
            "collaborations": "Worked with XYZ Corp on AI diagnostics",
        },
        categories=["AI", "Healthcare"],
    )
    print(" β Stakeholder profile stored")
    print()

    # Test 5: Retrieve similar episodes. The episode stored in Test 2 has the
    # same scenario and quality 0.92 >= 0.8, so at least one must come back.
    print("Test 5: Retrieve similar episodes")
    episodes = await memory.get_similar_episodes(
        task_description="Patent analysis workflow",
        scenario=ScenarioType.PATENT_WAKEUP,
        min_quality_score=0.8,
        top_k=2,
    )
    print(f" β Found {len(episodes)} similar episodes")
    assert episodes, "expected the episode stored in Test 2 to be retrievable"
    # NOTE(review): episodes[0] is presumably the best similarity match rather
    # than the most recent one — confirm the ordering behind this label.
    print(f" β Latest episode score: {episodes[0]['metadata'].get('quality_score', 0):.2f}")
    print()

    # Test 6: Get domain knowledge. Test 3 stored documents in this category.
    print("Test 6: Retrieve domain knowledge")
    knowledge = await memory.get_domain_knowledge(
        query="patent structure and components",
        category="best_practices",
        top_k=2,
    )
    print(f" β Found {len(knowledge)} knowledge documents")
    assert knowledge, "expected the 'best_practices' documents from Test 3"
    print()

    # Test 7: Find matching stakeholders. Test 4 stored a profile whose
    # location is exactly the one filtered on here.
    print("Test 7: Find matching stakeholders")
    stakeholders = await memory.find_matching_stakeholders(
        requirements="AI researcher with drug discovery experience",
        location="Montreal, QC",
        top_k=2,
    )
    print(f" β Found {len(stakeholders)} matching stakeholders")
    assert stakeholders, "expected the Montreal profile stored in Test 4"
    print(f" β Top match: {stakeholders[0]['name']}")
    print()

    # Final stats
    print("Final collection stats:")
    final_stats = memory.get_collection_stats()
    print(f" Episodes: {final_stats['episodic_count']}")
    print(f" Knowledge: {final_stats['semantic_count']}")
    print(f" Stakeholders: {final_stats['stakeholders_count']}")
    print()

    # Each collection must hold at least the records written above. Only
    # lower bounds are checked: earlier runs may have left data behind, and
    # re-storing the same ids may upsert rather than append.
    minimum_counts = {
        'episodic_count': 1,
        'semantic_count': len(knowledge_docs),
        'stakeholders_count': 1,
    }
    for key, minimum in minimum_counts.items():
        assert final_stats[key] >= minimum, (
            f"{key} is {final_stats[key]}, expected at least {minimum}"
        )

    print("β All MemoryAgent tests passed!")


if __name__ == "__main__":
    # Running the file directly drives the coroutine on a new event loop;
    # a failed assertion propagates and yields a non-zero exit status.
    asyncio.run(test_memory_agent())
|