"""
MemoryAgent for SPARKNET.

Provides a vector memory system using ChromaDB and LangChain.
Supports episodic, semantic, and stakeholder memory.
"""

from typing import Optional, Dict, Any, List, Literal
from datetime import datetime
from loguru import logger
import json

from langchain_chroma import Chroma
from langchain_core.documents import Document

from .base_agent import BaseAgent, Task, Message
from ..llm.langchain_ollama_client import LangChainOllamaClient
from ..workflow.langgraph_state import ScenarioType, TaskStatus


MemoryType = Literal["episodic", "semantic", "stakeholders", "all"]


class MemoryAgent(BaseAgent):
    """
    Vector memory system using ChromaDB and LangChain.
    Stores and retrieves context for agent decision-making.

    Three collections:
    - episodic_memory: Past workflow executions, outcomes, lessons learned
    - semantic_memory: Domain knowledge (patents, legal frameworks, market data)
    - stakeholder_profiles: Researcher and industry partner profiles
    """

    def __init__(
        self,
        llm_client: LangChainOllamaClient,
        persist_directory: str = "data/vector_store",
        memory_agent: Optional['MemoryAgent'] = None,
    ):
        """
        Initialize MemoryAgent with ChromaDB collections.

        Args:
            llm_client: LangChain Ollama client for embeddings
            persist_directory: Directory to persist ChromaDB data
            memory_agent: Not used (for interface compatibility)
        """
        self.llm_client = llm_client
        self.persist_directory = persist_directory

        self.embeddings = llm_client.get_embeddings()

        self._initialize_collections()

        self.name = "MemoryAgent"
        self.description = "Vector memory and context retrieval"

        logger.info(f"Initialized MemoryAgent with ChromaDB at {persist_directory}")

    def _initialize_collections(self):
        """Initialize three ChromaDB collections."""
        try:
            self.episodic_memory = Chroma(
                collection_name="episodic_memory",
                embedding_function=self.embeddings,
                persist_directory=f"{self.persist_directory}/episodic"
            )
            logger.debug("Initialized episodic_memory collection")

            self.semantic_memory = Chroma(
                collection_name="semantic_memory",
                embedding_function=self.embeddings,
                persist_directory=f"{self.persist_directory}/semantic"
            )
            logger.debug("Initialized semantic_memory collection")

            self.stakeholder_profiles = Chroma(
                collection_name="stakeholder_profiles",
                embedding_function=self.embeddings,
                persist_directory=f"{self.persist_directory}/stakeholders"
            )
            logger.debug("Initialized stakeholder_profiles collection")

        except Exception as e:
            logger.error(f"Failed to initialize ChromaDB collections: {e}")
            raise
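
    # With the default persist_directory, the three collections land in
    # sibling directories (the paths follow directly from the f-strings above):
    #
    #   data/vector_store/episodic
    #   data/vector_store/semantic
    #   data/vector_store/stakeholders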

    async def process_task(self, task: Task) -> Task:
        """
        Process a memory-related task.

        Args:
            task: Task with memory operation

        Returns:
            Updated task with results
        """
        logger.info(f"MemoryAgent processing task: {task.id}")
        task.status = "in_progress"

        try:
            # Dispatch on the requested memory operation
            operation = task.metadata.get('operation') if task.metadata else None

            if operation == 'store_episode':
                episode_data = task.metadata.get('episode_data', {})
                await self.store_episode(**episode_data)
                task.result = {"stored": True}

            elif operation == 'retrieve_context':
                query = task.metadata.get('query', '')
                context_type = task.metadata.get('context_type', 'all')
                top_k = task.metadata.get('top_k', 3)

                results = await self.retrieve_relevant_context(
                    query=query,
                    context_type=context_type,
                    top_k=top_k
                )
                task.result = {"contexts": results}

            elif operation == 'store_knowledge':
                documents = task.metadata.get('documents', [])
                metadatas = task.metadata.get('metadatas', [])
                category = task.metadata.get('category', 'general')

                await self.store_knowledge(documents, metadatas, category)
                task.result = {"stored": len(documents)}

            else:
                raise ValueError(f"Unknown memory operation: {operation}")

            task.status = "completed"
            logger.info(f"Memory operation completed: {operation}")

        except Exception as e:
            logger.error(f"Memory operation failed: {e}")
            task.status = "failed"
            task.error = str(e)

        return task
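
    # Illustrative only (not from the original module): a caller would shape a
    # Task roughly like this for a context lookup. The exact Task constructor
    # comes from .base_agent and may differ.
    #
    #   task = Task(
    #       id="ctx-001",
    #       metadata={
    #           "operation": "retrieve_context",
    #           "query": "battery recycling patents",
    #           "context_type": "semantic",
    #           "top_k": 3,
    #       },
    #   )
    #   task = await memory_agent.process_task(task)
    #   contexts = task.result["contexts"]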

    async def store_episode(
        self,
        task_id: str,
        task_description: str,
        scenario: ScenarioType,
        workflow_steps: List[Dict],
        outcome: Dict,
        quality_score: float,
        execution_time: Optional[float] = None,
        iterations_used: Optional[int] = None,
    ) -> None:
        """
        Store a completed workflow execution for learning.

        Args:
            task_id: Unique task identifier
            task_description: Natural language task description
            scenario: VISTA scenario type
            workflow_steps: List of subtasks executed
            outcome: Final output and results
            quality_score: Quality score from validation (0.0-1.0)
            execution_time: Total execution time in seconds
            iterations_used: Number of refinement iterations
        """
        try:
            content = f"""
            Task: {task_description}
            Scenario: {scenario.value if hasattr(scenario, 'value') else scenario}
            Quality Score: {quality_score:.2f}
            Steps: {len(workflow_steps)}
            Outcome: {json.dumps(outcome, indent=2)[:500]}
            """

            metadata = {
                "task_id": task_id,
                "scenario": scenario.value if hasattr(scenario, 'value') else str(scenario),
                "quality_score": float(quality_score),
                "timestamp": datetime.now().isoformat(),
                "num_steps": len(workflow_steps),
                "execution_time": execution_time or 0.0,
                "iterations": iterations_used or 0,
                "success": quality_score >= 0.85,
            }

            document = Document(
                page_content=content,
                metadata=metadata
            )

            self.episodic_memory.add_documents([document])

            logger.info(f"Stored episode: {task_id} (score: {quality_score:.2f})")

        except Exception as e:
            logger.error(f"Failed to store episode: {e}")
            raise

    async def retrieve_relevant_context(
        self,
        query: str,
        context_type: MemoryType = "episodic",
        top_k: int = 3,
        scenario_filter: Optional[ScenarioType] = None,
        min_quality_score: Optional[float] = None,
    ) -> List[Document]:
        """
        Semantic search across the specified memory type.

        Args:
            query: Natural language query
            context_type: Memory type to search
            top_k: Number of results to return
            scenario_filter: Filter by VISTA scenario
            min_quality_score: Minimum quality score for episodes

        Returns:
            List of Document objects with content and metadata
        """
        try:
            results = []

            # Build a Chroma metadata filter; multiple conditions must be
            # combined under an explicit "$and" operator.
            where_filter = None
            if scenario_filter and min_quality_score is not None:
                where_filter = {
                    "$and": [
                        {"scenario": scenario_filter.value if hasattr(scenario_filter, 'value') else str(scenario_filter)},
                        {"quality_score": {"$gte": min_quality_score}}
                    ]
                }
            elif scenario_filter:
                where_filter = {"scenario": scenario_filter.value if hasattr(scenario_filter, 'value') else str(scenario_filter)}
            elif min_quality_score is not None:
                where_filter = {"quality_score": {"$gte": min_quality_score}}

            if context_type in ("episodic", "all"):
                episodic_results = self.episodic_memory.similarity_search(
                    query=query,
                    k=top_k,
                    filter=where_filter if where_filter else None
                )
                results.extend(episodic_results)
                logger.debug(f"Found {len(episodic_results)} episodic memories")

            if context_type in ("semantic", "all"):
                semantic_results = self.semantic_memory.similarity_search(
                    query=query,
                    k=top_k
                )
                results.extend(semantic_results)
                logger.debug(f"Found {len(semantic_results)} semantic memories")

            if context_type in ("stakeholders", "all"):
                stakeholder_results = self.stakeholder_profiles.similarity_search(
                    query=query,
                    k=top_k
                )
                results.extend(stakeholder_results)
                logger.debug(f"Found {len(stakeholder_results)} stakeholder profiles")

            # De-duplicate by page content (preserving insertion order),
            # then trim to the requested number of results.
            unique_results = list({doc.page_content: doc for doc in results}.values())
            return unique_results[:top_k]

        except Exception as e:
            logger.error(f"Failed to retrieve context: {e}")
            return []
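
    # Illustrative call (a sketch, not from the original module):
    #
    #   docs = await memory_agent.retrieve_relevant_context(
    #       query="licensing strategy for a sensor patent",
    #       context_type="episodic",
    #       top_k=3,
    #       min_quality_score=0.85,
    #   )
    #
    # docs is a list of Document objects; the metadata filter restricts
    # episodic hits to those with quality_score >= 0.85.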

    async def store_knowledge(
        self,
        documents: List[str],
        metadatas: List[Dict],
        category: str,
    ) -> None:
        """
        Store domain knowledge in semantic memory.

        Args:
            documents: List of knowledge documents (text)
            metadatas: List of metadata dicts
            category: Knowledge category

        Categories:
        - "patent_templates": Common patent structures
        - "legal_frameworks": GDPR, Law 25 regulations
        - "market_data": Industry sectors, trends
        - "best_practices": Successful valorization strategies
        """
        try:
            docs = []
            for i, (text, metadata) in enumerate(zip(documents, metadatas)):
                metadata['category'] = category
                metadata['timestamp'] = datetime.now().isoformat()
                metadata['doc_id'] = f"{category}_{i}"

                doc = Document(
                    page_content=text,
                    metadata=metadata
                )
                docs.append(doc)

            self.semantic_memory.add_documents(docs)

            logger.info(f"Stored {len(docs)} knowledge documents in category: {category}")

        except Exception as e:
            logger.error(f"Failed to store knowledge: {e}")
            raise

    async def store_stakeholder_profile(
        self,
        name: str,
        profile: Dict,
        categories: List[str],
    ) -> None:
        """
        Store a researcher or industry partner profile.

        Args:
            name: Stakeholder name
            profile: Profile data
            categories: List of categories (expertise areas)

        Profile includes:
        - expertise: List of expertise areas
        - interests: Research interests
        - collaborations: Past collaborations
        - technologies: Technology domains
        - location: Geographic location
        - contact: Contact information
        """
        try:
            content = f"""
            Name: {name}
            Expertise: {', '.join(profile.get('expertise', []))}
            Interests: {', '.join(profile.get('interests', []))}
            Technologies: {', '.join(profile.get('technologies', []))}
            Location: {profile.get('location', 'Unknown')}
            Past Collaborations: {profile.get('collaborations', 'None listed')}
            """

            metadata = {
                "name": name,
                "categories": ", ".join(categories),
                "timestamp": datetime.now().isoformat(),
                "location": profile.get('location', 'Unknown'),
                "num_expertise": len(profile.get('expertise', [])),
            }

            # ChromaDB metadata values must be primitives, so the full profile
            # is serialized to a JSON string for later retrieval.
            metadata['profile'] = json.dumps(profile)

            document = Document(
                page_content=content,
                metadata=metadata
            )

            self.stakeholder_profiles.add_documents([document])

            logger.info(f"Stored stakeholder profile: {name}")

        except Exception as e:
            logger.error(f"Failed to store stakeholder profile: {e}")
            raise
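
    # Example profile shape (illustrative only; the field names follow the
    # docstring above rather than a schema enforced elsewhere):
    #
    #   await memory_agent.store_stakeholder_profile(
    #       name="Dr. Jane Doe",
    #       profile={
    #           "expertise": ["battery chemistry", "recycling"],
    #           "interests": ["circular economy"],
    #           "collaborations": ["industrial pilot project (2023)"],
    #           "technologies": ["lithium-ion"],
    #           "location": "Montreal",
    #           "contact": "jane.doe@example.org",
    #       },
    #       categories=["materials", "energy"],
    #   )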

    async def learn_from_feedback(
        self,
        task_id: str,
        feedback: str,
        updated_score: Optional[float] = None,
    ) -> None:
        """
        Update episodic memory with user feedback.
        Mark successful strategies for reuse.

        Args:
            task_id: Task identifier
            feedback: User feedback text
            updated_score: Updated quality score after feedback
        """
        try:
            results = self.episodic_memory.similarity_search(
                query=task_id,
                k=1,
                filter={"task_id": task_id}
            )

            if results:
                logger.info(f"Found episode {task_id} for feedback update")

                original = results[0]
                content = f"{original.page_content}\n\nUser Feedback: {feedback}"

                metadata = original.metadata.copy()
                if updated_score is not None:
                    metadata['quality_score'] = updated_score
                metadata['has_feedback'] = True
                metadata['feedback_timestamp'] = datetime.now().isoformat()

                # The feedback-enriched episode is stored as a new document;
                # the original entry remains in the collection.
                doc = Document(page_content=content, metadata=metadata)
                self.episodic_memory.add_documents([doc])

                logger.info(f"Updated episode {task_id} with feedback")
            else:
                logger.warning(f"Episode {task_id} not found for feedback")

        except Exception as e:
            logger.error(f"Failed to learn from feedback: {e}")

    async def get_similar_episodes(
        self,
        task_description: str,
        scenario: Optional[ScenarioType] = None,
        min_quality_score: float = 0.8,
        top_k: int = 3,
    ) -> List[Dict]:
        """
        Find similar past episodes for learning.

        Args:
            task_description: Current task description
            scenario: Optional scenario filter
            min_quality_score: Minimum quality threshold
            top_k: Number of results

        Returns:
            List of episode dictionaries with metadata
        """
        results = await self.retrieve_relevant_context(
            query=task_description,
            context_type="episodic",
            top_k=top_k,
            scenario_filter=scenario,
            min_quality_score=min_quality_score
        )

        episodes = []
        for doc in results:
            episodes.append({
                "content": doc.page_content,
                "metadata": doc.metadata
            })

        return episodes

    async def get_domain_knowledge(
        self,
        query: str,
        category: Optional[str] = None,
        top_k: int = 3,
    ) -> List[Document]:
        """
        Retrieve domain knowledge from semantic memory.

        Args:
            query: Knowledge query
            category: Optional category filter
            top_k: Number of results

        Returns:
            List of knowledge documents
        """
        where_filter = {"category": category} if category else None

        results = self.semantic_memory.similarity_search(
            query=query,
            k=top_k,
            filter=where_filter
        )

        return results

    async def find_matching_stakeholders(
        self,
        requirements: str,
        categories: Optional[List[str]] = None,
        location: Optional[str] = None,
        top_k: int = 5,
    ) -> List[Dict]:
        """
        Find stakeholders matching requirements.

        Args:
            requirements: Description of needed expertise/capabilities
            categories: Optional category filters
            location: Optional location filter
            top_k: Number of matches

        Returns:
            List of matching stakeholder profiles
        """
        # Note: categories is not currently applied as a metadata filter;
        # matching relies on semantic similarity over the requirements text,
        # plus an optional exact-match location filter.
        where_filter = {}
        if location:
            where_filter["location"] = location

        results = self.stakeholder_profiles.similarity_search(
            query=requirements,
            k=top_k,
            filter=where_filter if where_filter else None
        )

        stakeholders = []
        for doc in results:
            profile_data = json.loads(doc.metadata.get('profile', '{}'))
            stakeholders.append({
                "name": doc.metadata.get('name'),
                "profile": profile_data,
                "match_text": doc.page_content,
                "metadata": doc.metadata
            })

        return stakeholders

    def get_collection_stats(self) -> Dict[str, int]:
        """
        Get statistics about memory collections.

        Returns:
            Dictionary with collection counts
        """
        try:
            # Counts come from the underlying ChromaDB collection objects
            # (accessed via LangChain's private `_collection` attribute).
            stats = {
                "episodic_count": self.episodic_memory._collection.count(),
                "semantic_count": self.semantic_memory._collection.count(),
                "stakeholders_count": self.stakeholder_profiles._collection.count(),
            }
            return stats
        except Exception as e:
            logger.error(f"Failed to get collection stats: {e}")
            return {"episodic_count": 0, "semantic_count": 0, "stakeholders_count": 0}


def create_memory_agent(
    llm_client: LangChainOllamaClient,
    persist_directory: str = "data/vector_store",
) -> MemoryAgent:
    """
    Create a MemoryAgent instance.

    Args:
        llm_client: LangChain Ollama client
        persist_directory: Directory for ChromaDB persistence

    Returns:
        MemoryAgent instance
    """
    return MemoryAgent(
        llm_client=llm_client,
        persist_directory=persist_directory
    )
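

# Illustrative usage (a sketch, not part of the original module). It assumes
# LangChainOllamaClient can be constructed without arguments, which may not
# match the real constructor:
#
#   import asyncio
#
#   async def _demo() -> None:
#       client = LangChainOllamaClient()
#       agent = create_memory_agent(client, persist_directory="data/vector_store")
#       await agent.store_knowledge(
#           documents=["GDPR Article 6 sets out the lawful bases for processing."],
#           metadatas=[{"source": "gdpr"}],
#           category="legal_frameworks",
#       )
#       docs = await agent.get_domain_knowledge(
#           "lawful basis for processing personal data",
#           category="legal_frameworks",
#       )
#       print([d.page_content for d in docs])
#
#   asyncio.run(_demo())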