"""
MemoryAgent for SPARKNET
Provides a vector memory system using ChromaDB and LangChain
Supports episodic, semantic, and stakeholder memory
"""
from typing import Optional, Dict, Any, List, Literal
from datetime import datetime
from loguru import logger
import json
from langchain_chroma import Chroma
from langchain_core.documents import Document
from .base_agent import BaseAgent, Task, Message
from ..llm.langchain_ollama_client import LangChainOllamaClient
from ..workflow.langgraph_state import ScenarioType, TaskStatus
MemoryType = Literal["episodic", "semantic", "stakeholders", "all"]
class MemoryAgent(BaseAgent):
"""
Vector memory system using ChromaDB and LangChain.
Stores and retrieves context for agent decision-making.
Three collections:
- episodic_memory: Past workflow executions, outcomes, lessons learned
- semantic_memory: Domain knowledge (patents, legal frameworks, market data)
- stakeholder_profiles: Researcher and industry partner profiles
"""
def __init__(
self,
llm_client: LangChainOllamaClient,
persist_directory: str = "data/vector_store",
memory_agent: Optional['MemoryAgent'] = None,
):
"""
Initialize MemoryAgent with ChromaDB collections.
Args:
llm_client: LangChain Ollama client for embeddings
persist_directory: Directory to persist ChromaDB data
memory_agent: Not used (for interface compatibility)
"""
self.llm_client = llm_client
self.persist_directory = persist_directory
# Get embeddings from LangChain client
self.embeddings = llm_client.get_embeddings()
# Initialize ChromaDB collections
self._initialize_collections()
# Store for backward compatibility
self.name = "MemoryAgent"
self.description = "Vector memory and context retrieval"
logger.info(f"Initialized MemoryAgent with ChromaDB at {persist_directory}")
def _initialize_collections(self):
"""Initialize three ChromaDB collections."""
try:
# Episodic memory: Past workflow executions
self.episodic_memory = Chroma(
collection_name="episodic_memory",
embedding_function=self.embeddings,
persist_directory=f"{self.persist_directory}/episodic"
)
logger.debug("Initialized episodic_memory collection")
# Semantic memory: Domain knowledge
self.semantic_memory = Chroma(
collection_name="semantic_memory",
embedding_function=self.embeddings,
persist_directory=f"{self.persist_directory}/semantic"
)
logger.debug("Initialized semantic_memory collection")
# Stakeholder profiles
self.stakeholder_profiles = Chroma(
collection_name="stakeholder_profiles",
embedding_function=self.embeddings,
persist_directory=f"{self.persist_directory}/stakeholders"
)
logger.debug("Initialized stakeholder_profiles collection")
except Exception as e:
logger.error(f"Failed to initialize ChromaDB collections: {e}")
raise
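    # Resulting on-disk layout (one persistent Chroma store per collection),
    # e.g. with the default persist_directory="data/vector_store":
    #     data/vector_store/episodic/        <- episodic_memory
    #     data/vector_store/semantic/        <- semantic_memory
    #     data/vector_store/stakeholders/    <- stakeholder_profiles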
async def process_task(self, task: Task) -> Task:
"""
Process memory-related task.
Args:
task: Task with memory operation
Returns:
Updated task with results
"""
logger.info(f"MemoryAgent processing task: {task.id}")
task.status = "in_progress"
try:
operation = task.metadata.get('operation') if task.metadata else None
if operation == 'store_episode':
# Store episode
episode_data = task.metadata.get('episode_data', {})
await self.store_episode(**episode_data)
task.result = {"stored": True}
elif operation == 'retrieve_context':
# Retrieve context
query = task.metadata.get('query', '')
context_type = task.metadata.get('context_type', 'all')
top_k = task.metadata.get('top_k', 3)
results = await self.retrieve_relevant_context(
query=query,
context_type=context_type,
top_k=top_k
)
task.result = {"contexts": results}
elif operation == 'store_knowledge':
# Store knowledge
documents = task.metadata.get('documents', [])
metadatas = task.metadata.get('metadatas', [])
category = task.metadata.get('category', 'general')
await self.store_knowledge(documents, metadatas, category)
task.result = {"stored": len(documents)}
else:
raise ValueError(f"Unknown memory operation: {operation}")
task.status = "completed"
logger.info(f"Memory operation completed: {operation}")
except Exception as e:
logger.error(f"Memory operation failed: {e}")
task.status = "failed"
task.error = str(e)
return task
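    # Illustrative task.metadata payloads for the three operations handled
    # above (a sketch; the values are placeholders):
    #
    #     {"operation": "store_episode",
    #      "episode_data": {"task_id": "...", "task_description": "...",
    #                       "scenario": <ScenarioType>, "workflow_steps": [...],
    #                       "outcome": {...}, "quality_score": 0.9}}
    #
    #     {"operation": "retrieve_context",
    #      "query": "similar valorization tasks", "context_type": "all", "top_k": 3}
    #
    #     {"operation": "store_knowledge",
    #      "documents": ["..."], "metadatas": [{...}], "category": "market_data"}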
async def store_episode(
self,
task_id: str,
task_description: str,
scenario: ScenarioType,
workflow_steps: List[Dict],
outcome: Dict,
quality_score: float,
execution_time: Optional[float] = None,
iterations_used: Optional[int] = None,
) -> None:
"""
Store a completed workflow execution for learning.
Args:
task_id: Unique task identifier
task_description: Natural language task description
scenario: VISTA scenario type
workflow_steps: List of subtasks executed
outcome: Final output and results
quality_score: Quality score from validation (0.0-1.0)
execution_time: Total execution time in seconds
iterations_used: Number of refinement iterations
"""
try:
# Create document content
content = f"""
Task: {task_description}
Scenario: {scenario.value if hasattr(scenario, 'value') else scenario}
Quality Score: {quality_score:.2f}
Steps: {len(workflow_steps)}
Outcome: {json.dumps(outcome, indent=2)[:500]}
"""
# Create metadata
metadata = {
"task_id": task_id,
"scenario": scenario.value if hasattr(scenario, 'value') else str(scenario),
"quality_score": float(quality_score),
"timestamp": datetime.now().isoformat(),
"num_steps": len(workflow_steps),
"execution_time": execution_time or 0.0,
"iterations": iterations_used or 0,
"success": quality_score >= 0.85,
}
# Create document
document = Document(
page_content=content,
metadata=metadata
)
# Add to episodic memory
self.episodic_memory.add_documents([document])
logger.info(f"Stored episode: {task_id} (score: {quality_score:.2f})")
except Exception as e:
logger.error(f"Failed to store episode: {e}")
raise
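    # Example call (illustrative sketch; `scenario` stands for whichever
    # ScenarioType member applies to the workflow):
    #
    #     await memory.store_episode(
    #         task_id="task-001",
    #         task_description="Draft a valorization brief for a sensor patent",
    #         scenario=scenario,
    #         workflow_steps=[{"step": "analysis"}, {"step": "drafting"}],
    #         outcome={"summary": "brief produced"},
    #         quality_score=0.91,
    #         execution_time=42.0,
    #         iterations_used=2,
    #     )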
async def retrieve_relevant_context(
self,
query: str,
context_type: MemoryType = "episodic",
top_k: int = 3,
scenario_filter: Optional[ScenarioType] = None,
min_quality_score: Optional[float] = None,
) -> List[Document]:
"""
Semantic search across specified memory type.
Args:
query: Natural language query
context_type: Memory type to search
top_k: Number of results to return
scenario_filter: Filter by VISTA scenario
min_quality_score: Minimum quality score for episodes
Returns:
List of Document objects with content and metadata
"""
try:
results = []
# Build filter if needed
# Note: ChromaDB requires compound filters with $and operator
where_filter = None
if scenario_filter and min_quality_score is not None:
where_filter = {
"$and": [
{"scenario": scenario_filter.value if hasattr(scenario_filter, 'value') else str(scenario_filter)},
{"quality_score": {"$gte": min_quality_score}}
]
}
elif scenario_filter:
where_filter = {"scenario": scenario_filter.value if hasattr(scenario_filter, 'value') else str(scenario_filter)}
elif min_quality_score is not None:
where_filter = {"quality_score": {"$gte": min_quality_score}}
# Search appropriate collection(s)
if context_type == "episodic" or context_type == "all":
episodic_results = self.episodic_memory.similarity_search(
query=query,
k=top_k,
filter=where_filter if where_filter else None
)
results.extend(episodic_results)
logger.debug(f"Found {len(episodic_results)} episodic memories")
if context_type == "semantic" or context_type == "all":
semantic_results = self.semantic_memory.similarity_search(
query=query,
k=top_k
)
results.extend(semantic_results)
logger.debug(f"Found {len(semantic_results)} semantic memories")
if context_type == "stakeholders" or context_type == "all":
stakeholder_results = self.stakeholder_profiles.similarity_search(
query=query,
k=top_k
)
results.extend(stakeholder_results)
logger.debug(f"Found {len(stakeholder_results)} stakeholder profiles")
# Deduplicate and limit
unique_results = list({doc.page_content: doc for doc in results}.values())
return unique_results[:top_k]
except Exception as e:
logger.error(f"Failed to retrieve context: {e}")
return []
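    # Example call (illustrative): fetch only high-quality past episodes.
    #
    #     docs = await memory.retrieve_relevant_context(
    #         query="licensing strategy for a battery patent",
    #         context_type="episodic",
    #         top_k=3,
    #         min_quality_score=0.85,
    #     )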
async def store_knowledge(
self,
documents: List[str],
metadatas: List[Dict],
category: str,
) -> None:
"""
Store domain knowledge in semantic memory.
Args:
documents: List of knowledge documents (text)
metadatas: List of metadata dicts
category: Knowledge category
Categories:
- "patent_templates": Common patent structures
- "legal_frameworks": GDPR, Law 25 regulations
- "market_data": Industry sectors, trends
- "best_practices": Successful valorization strategies
"""
try:
# Create documents with metadata
docs = []
for i, (text, metadata) in enumerate(zip(documents, metadatas)):
# Add category to metadata
metadata['category'] = category
metadata['timestamp'] = datetime.now().isoformat()
metadata['doc_id'] = f"{category}_{i}"
doc = Document(
page_content=text,
metadata=metadata
)
docs.append(doc)
# Add to semantic memory
self.semantic_memory.add_documents(docs)
logger.info(f"Stored {len(docs)} knowledge documents in category: {category}")
except Exception as e:
logger.error(f"Failed to store knowledge: {e}")
raise
async def store_stakeholder_profile(
self,
name: str,
profile: Dict,
categories: List[str],
) -> None:
"""
Store researcher or industry partner profile.
Args:
name: Stakeholder name
profile: Profile data
categories: List of categories (expertise areas)
Profile includes:
- expertise: List of expertise areas
- interests: Research interests
- collaborations: Past collaborations
- technologies: Technology domains
- location: Geographic location
- contact: Contact information
"""
try:
# Create profile text
content = f"""
Name: {name}
Expertise: {', '.join(profile.get('expertise', []))}
Interests: {', '.join(profile.get('interests', []))}
Technologies: {', '.join(profile.get('technologies', []))}
Location: {profile.get('location', 'Unknown')}
Past Collaborations: {profile.get('collaborations', 'None listed')}
"""
# Create metadata (ChromaDB only accepts str, int, float, bool, None)
metadata = {
"name": name,
"categories": ", ".join(categories), # Convert list to string
"timestamp": datetime.now().isoformat(),
"location": profile.get('location', 'Unknown'),
"num_expertise": len(profile.get('expertise', [])),
}
# Add full profile to metadata as JSON string (for retrieval)
metadata['profile'] = json.dumps(profile)
# Create document
document = Document(
page_content=content,
metadata=metadata
)
# Add to stakeholder collection
self.stakeholder_profiles.add_documents([document])
logger.info(f"Stored stakeholder profile: {name}")
except Exception as e:
logger.error(f"Failed to store stakeholder profile: {e}")
raise
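    # Example call (illustrative; the profile keys shown are the ones this
    # method reads, and the values are placeholders):
    #
    #     await memory.store_stakeholder_profile(
    #         name="Dr. Jane Doe",
    #         profile={
    #             "expertise": ["machine learning", "IP strategy"],
    #             "interests": ["medical imaging"],
    #             "technologies": ["computer vision"],
    #             "collaborations": "2 prior industry projects",
    #             "location": "Montreal",
    #             "contact": "jane.doe@example.org",
    #         },
    #         categories=["ai", "health"],
    #     )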
async def learn_from_feedback(
self,
task_id: str,
feedback: str,
updated_score: Optional[float] = None,
) -> None:
"""
Update episodic memory with user feedback.
Mark successful strategies for reuse.
Args:
task_id: Task identifier
feedback: User feedback text
updated_score: Updated quality score after feedback
"""
try:
# Search for existing episode
results = self.episodic_memory.similarity_search(
query=task_id,
k=1,
filter={"task_id": task_id}
)
if results:
logger.info(f"Found episode {task_id} for feedback update")
# Store feedback as new episode variant
original = results[0]
content = f"{original.page_content}\n\nUser Feedback: {feedback}"
metadata = original.metadata.copy()
if updated_score is not None:
metadata['quality_score'] = updated_score
metadata['has_feedback'] = True
metadata['feedback_timestamp'] = datetime.now().isoformat()
# Add updated version
doc = Document(page_content=content, metadata=metadata)
self.episodic_memory.add_documents([doc])
logger.info(f"Updated episode {task_id} with feedback")
else:
logger.warning(f"Episode {task_id} not found for feedback")
except Exception as e:
logger.error(f"Failed to learn from feedback: {e}")
async def get_similar_episodes(
self,
task_description: str,
scenario: Optional[ScenarioType] = None,
min_quality_score: float = 0.8,
top_k: int = 3,
) -> List[Dict]:
"""
Find similar past episodes for learning.
Args:
task_description: Current task description
scenario: Optional scenario filter
min_quality_score: Minimum quality threshold
top_k: Number of results
Returns:
List of episode dictionaries with metadata
"""
results = await self.retrieve_relevant_context(
query=task_description,
context_type="episodic",
top_k=top_k,
scenario_filter=scenario,
min_quality_score=min_quality_score
)
episodes = []
for doc in results:
episodes.append({
"content": doc.page_content,
"metadata": doc.metadata
})
return episodes
async def get_domain_knowledge(
self,
query: str,
category: Optional[str] = None,
top_k: int = 3,
) -> List[Document]:
"""
Retrieve domain knowledge from semantic memory.
Args:
query: Knowledge query
category: Optional category filter
top_k: Number of results
Returns:
List of knowledge documents
"""
where_filter = {"category": category} if category else None
results = self.semantic_memory.similarity_search(
query=query,
k=top_k,
filter=where_filter
)
return results
async def find_matching_stakeholders(
self,
requirements: str,
categories: Optional[List[str]] = None,
location: Optional[str] = None,
top_k: int = 5,
) -> List[Dict]:
"""
Find stakeholders matching requirements.
Args:
requirements: Description of needed expertise/capabilities
categories: Optional category filters
location: Optional location filter
top_k: Number of matches
Returns:
List of matching stakeholder profiles
"""
# Build filter
where_filter = {}
if location:
where_filter["location"] = location
results = self.stakeholder_profiles.similarity_search(
query=requirements,
k=top_k,
filter=where_filter if where_filter else None
)
        stakeholders = []
        for doc in results:
            # Apply the optional categories filter as a post-filter: categories
            # are stored as a comma-joined string in metadata, so substring
            # matching is used here
            if categories and not any(
                cat in doc.metadata.get('categories', '') for cat in categories
            ):
                continue
            profile_data = json.loads(doc.metadata.get('profile', '{}'))
stakeholders.append({
"name": doc.metadata.get('name'),
"profile": profile_data,
"match_text": doc.page_content,
"metadata": doc.metadata
})
return stakeholders
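    # Example call (illustrative):
    #
    #     matches = await memory.find_matching_stakeholders(
    #         requirements="expertise in computer vision for medical imaging",
    #         categories=["ai", "health"],
    #         location="Montreal",
    #         top_k=5,
    #     )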
def get_collection_stats(self) -> Dict[str, int]:
"""
Get statistics about memory collections.
Returns:
Dictionary with collection counts
"""
try:
stats = {
"episodic_count": self.episodic_memory._collection.count(),
"semantic_count": self.semantic_memory._collection.count(),
"stakeholders_count": self.stakeholder_profiles._collection.count(),
}
return stats
except Exception as e:
logger.error(f"Failed to get collection stats: {e}")
return {"episodic_count": 0, "semantic_count": 0, "stakeholders_count": 0}
# Convenience function
def create_memory_agent(
llm_client: LangChainOllamaClient,
persist_directory: str = "data/vector_store",
) -> MemoryAgent:
"""
Create a MemoryAgent instance.
Args:
llm_client: LangChain Ollama client
persist_directory: Directory for ChromaDB persistence
Returns:
MemoryAgent instance
"""
return MemoryAgent(
llm_client=llm_client,
persist_directory=persist_directory
)
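

# Minimal smoke-test sketch. The no-argument LangChainOllamaClient() call is an
# assumption; adjust it to the client's actual constructor in this project.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        client = LangChainOllamaClient()  # assumed default construction
        memory = create_memory_agent(client)
        await memory.store_knowledge(
            documents=["GDPR requires a lawful basis for processing personal data."],
            metadatas=[{"source": "demo"}],
            category="legal_frameworks",
        )
        docs = await memory.get_domain_knowledge("lawful basis", category="legal_frameworks")
        logger.info(f"Stats: {memory.get_collection_stats()}, retrieved: {len(docs)}")

    asyncio.run(_demo())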