|
|
""" |
|
|
Test MemoryAgent with ChromaDB |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
from src.llm.langchain_ollama_client import get_langchain_client |
|
|
from src.agents.memory_agent import create_memory_agent |
|
|
from src.workflow.langgraph_state import ScenarioType |
|
|
|
|
|
async def test_memory_agent():
    """Smoke-test MemoryAgent against its ChromaDB-backed collections.

    Steps, in order:
      1. Read the initial collection stats.
      2. Store one workflow episode (task_id ``"test_001"``).
      3. Store two domain-knowledge documents under ``"best_practices"``.
      4. Store one stakeholder profile (``"Dr. Jane Smith"``).
      5-7. Query similar episodes, domain knowledge and matching
         stakeholders, printing how many results each returned.
    Finally the stats are re-read and checked: every collection must hold at
    least as many entries as were written in steps 2-4, and no fewer than it
    held at the start.

    Raises:
        AssertionError: if a collection count is lower than expected after
            the store steps.

    Notes:
        * Nothing written here is deleted by this function, so the sample
          records remain in the collections after the run.
        * This is an ``async def`` named ``test_*``: plain pytest (without an
          asyncio plugin such as pytest-asyncio) will not await it. Run the
          module as a script instead.
    """
    print("Testing MemoryAgent with ChromaDB...")
    print()

    client = get_langchain_client(default_complexity='standard', enable_monitoring=False)
    print("✓ LangChain client initialized")

    memory = create_memory_agent(llm_client=client)
    print("✓ MemoryAgent created")
    print()

    # Test 1: baseline counts; kept for the final sanity check below.
    print("Test 1: ChromaDB collections")
    stats = memory.get_collection_stats()
    print(f" ✓ Episodic memory: {stats['episodic_count']} episodes")
    print(f" ✓ Semantic memory: {stats['semantic_count']} documents")
    print(f" ✓ Stakeholder profiles: {stats['stakeholders_count']} profiles")
    print()

    print("Test 2: Store episode")
    await memory.store_episode(
        task_id="test_001",
        task_description="Analyze patent for commercialization",
        scenario=ScenarioType.PATENT_WAKEUP,
        workflow_steps=[
            {"id": "step1", "description": "Extract patent claims"},
            {"id": "step2", "description": "Identify market opportunities"},
        ],
        outcome={"success": True, "matches": 3},
        quality_score=0.92,
        execution_time=45.3,
        iterations_used=1,
    )
    print(" ✓ Episode stored successfully")
    print()

    print("Test 3: Store domain knowledge")
    await memory.store_knowledge(
        documents=[
            "Patents typically include claims, description, drawings, and abstract.",
            "GDPR requires explicit consent for personal data processing.",
        ],
        metadatas=[
            {"source": "patent_guide", "topic": "patent_structure"},
            {"source": "gdpr_regulation", "topic": "data_protection"},
        ],
        category="best_practices",
    )
    print(" ✓ Knowledge stored successfully")
    print()

    print("Test 4: Store stakeholder profile")
    await memory.store_stakeholder_profile(
        name="Dr. Jane Smith",
        profile={
            "expertise": ["AI", "Drug Discovery", "Machine Learning"],
            "interests": ["Pharmaceutical AI", "Clinical Trials"],
            "technologies": ["Neural Networks", "NLP", "Computer Vision"],
            "location": "Montreal, QC",
            "collaborations": "Worked with XYZ Corp on AI diagnostics",
        },
        categories=["AI", "Healthcare"],
    )
    print(" ✓ Stakeholder profile stored")
    print()

    print("Test 5: Retrieve similar episodes")
    episodes = await memory.get_similar_episodes(
        task_description="Patent analysis workflow",
        scenario=ScenarioType.PATENT_WAKEUP,
        min_quality_score=0.8,
        top_k=2,
    )
    print(f" ✓ Found {len(episodes)} similar episodes")
    if episodes:
        # NOTE(review): assumes each result is a dict with a "metadata"
        # mapping. "Latest" may be a misnomer if results are ranked by
        # similarity rather than recency — confirm against get_similar_episodes.
        print(f" ✓ Latest episode score: {episodes[0]['metadata'].get('quality_score', 0):.2f}")
    print()

    print("Test 6: Retrieve domain knowledge")
    knowledge = await memory.get_domain_knowledge(
        query="patent structure and components",
        category="best_practices",
        top_k=2,
    )
    print(f" ✓ Found {len(knowledge)} knowledge documents")
    print()

    print("Test 7: Find matching stakeholders")
    stakeholders = await memory.find_matching_stakeholders(
        requirements="AI researcher with drug discovery experience",
        location="Montreal, QC",
        top_k=2,
    )
    print(f" ✓ Found {len(stakeholders)} matching stakeholders")
    if stakeholders:
        # NOTE(review): assumes each match is a dict exposing a "name" key.
        print(f" ✓ Top match: {stakeholders[0]['name']}")
    print()

    print("Final collection stats:")
    final_stats = memory.get_collection_stats()
    print(f" Episodes: {final_stats['episodic_count']}")
    print(f" Knowledge: {final_stats['semantic_count']}")
    print(f" Stakeholders: {final_stats['stakeholders_count']}")
    print()

    # Verify the store steps actually landed before claiming success.
    # Each collection must contain at least what Tests 2-4 wrote, and must
    # not have shrunk relative to the baseline read in Test 1.
    expected_minimums = {
        "episodic_count": 1,      # one episode stored in Test 2
        "semantic_count": 2,      # two documents stored in Test 3
        "stakeholders_count": 1,  # one profile stored in Test 4
    }
    for key, written in expected_minimums.items():
        floor = max(written, stats[key])
        assert final_stats[key] >= floor, (
            f"{key} is {final_stats[key]}, expected at least {floor}"
        )

    print("✓ All MemoryAgent tests passed!")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the async smoke test to completion on a
    # fresh event loop created (and closed) by asyncio.run.
    asyncio.run(main=test_memory_agent())
|
|
|