|
|
import time
|
|
|
from typing import Dict, List, Any, Optional, Union
|
|
|
|
|
|
class MemoryManager:
    """Memory management for the autonomous AI agent.

    This class provides capabilities for:

    1. Storing and retrieving conversation history
    2. Managing context windows
    3. Implementing forgetting mechanisms
    4. Prioritizing important information

    All state lives in three in-memory containers:

    * ``conversation_history`` -- list of ``{"role", "content", "timestamp"}``
      dicts, oldest first.
    * ``important_facts`` -- list of ``{"fact", "source", "timestamp"}``
      dicts, oldest first.
    * ``session_data`` -- dict mapping a key to ``{"value", "timestamp"}``.
    """

    def __init__(self, max_history_length: int = 20, max_facts: int = 50):
        """Initialize the memory manager.

        Args:
            max_history_length: Maximum number of conversation turns to
                store. A turn is counted as two messages, so up to
                ``max_history_length * 2`` messages are retained.
            max_facts: Maximum number of important facts to retain.
        """
        self.conversation_history: List[Dict[str, Any]] = []
        self.max_history_length = max_history_length
        self.important_facts: List[Dict[str, Any]] = []
        self.max_facts = max_facts
        self.session_data: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _keep_latest(items: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
        """Return ``items`` reduced to its newest ``limit`` entries.

        The excess is computed explicitly instead of slicing with
        ``items[-limit:]``: for ``limit == 0`` that slice is ``items[-0:]``,
        which is the whole list, so nothing would ever be evicted.
        Negative limits are treated as zero. When nothing has to be
        dropped the very same list object is returned.
        """
        excess = len(items) - max(limit, 0)
        if excess > 0:
            return items[excess:]
        return items

    def add_message(self, role: str, content: str) -> None:
        """Add a message to the conversation history.

        The message is stamped with the current ``time.time()`` value and
        the history is then trimmed to the newest
        ``max_history_length * 2`` messages.

        Args:
            role: The role of the message sender (user or assistant)
            content: The content of the message
        """
        self.conversation_history.append({
            "role": role,
            "content": content,
            "timestamp": time.time(),
        })
        self.conversation_history = self._keep_latest(
            self.conversation_history, self.max_history_length * 2
        )

    def get_conversation_history(self, max_turns: Optional[int] = None) -> List[Dict[str, Any]]:
        """Get the conversation history.

        Args:
            max_turns: Maximum number of turns to retrieve (None for all).
                Each turn is counted as two messages. Zero or a negative
                value yields an empty list.

        Returns:
            List of conversation messages, oldest first. With
            ``max_turns=None`` this is the internal list itself; otherwise
            it is a new list holding the most recent messages.
        """
        if max_turns is None:
            return self.conversation_history
        if max_turns <= 0:
            # A plain ``[-0:]`` slice would return the full history here.
            return []
        return self.conversation_history[-max_turns * 2:]

    def format_conversation_for_prompt(self, max_turns: Optional[int] = None) -> str:
        """Format the conversation history for inclusion in a prompt.

        Args:
            max_turns: Maximum number of turns to include

        Returns:
            Formatted conversation string with one ``"role: content"``
            line per message, each terminated by a newline. Empty when
            there are no messages.
        """
        return "".join(
            f"{msg['role']}: {msg['content']}\n"
            for msg in self.get_conversation_history(max_turns)
        )

    def add_important_fact(self, fact: str, source: str) -> None:
        """Add an important fact to memory.

        The fact is stamped with the current ``time.time()`` value and the
        fact list is then trimmed to the newest ``max_facts`` entries.

        Args:
            fact: The important fact to remember
            source: The source of the fact (e.g., user, inference)
        """
        self.important_facts.append({
            "fact": fact,
            "source": source,
            "timestamp": time.time(),
        })
        self.important_facts = self._keep_latest(self.important_facts, self.max_facts)

    def get_important_facts(self) -> List[Dict[str, Any]]:
        """Get the list of important facts.

        Returns:
            List of important facts (the internal list, oldest first)
        """
        return self.important_facts

    def format_facts_for_prompt(self) -> str:
        """Format important facts for inclusion in a prompt.

        Facts are ordered newest first within each group and numbered
        consecutively across groups: facts from the user come first, then
        inferred facts, then facts from any other source so that none are
        silently left out.

        Returns:
            Formatted facts string, or an empty string when no facts are
            stored.
        """
        if not self.important_facts:
            return ""

        newest_first = sorted(
            self.important_facts,
            key=lambda entry: entry.get("timestamp", 0),
            reverse=True,
        )

        user_facts = [f for f in newest_first if f.get("source") == "user"]
        inferred_facts = [f for f in newest_first if f.get("source") == "inference"]
        other_facts = [
            f for f in newest_first if f.get("source") not in ("user", "inference")
        ]

        def describe(entry: Dict[str, Any]) -> str:
            """Return the parenthesised origin label for one fact."""
            source = entry.get("source")
            if source == "user":
                return "from user"
            if source == "inference":
                return "inferred"
            return f"from {source}"

        lines = ["Important information I know about the user and context:"]
        lines.extend(
            f"{number}. {entry['fact']} ({describe(entry)})"
            for number, entry in enumerate(
                user_facts + inferred_facts + other_facts, start=1
            )
        )
        return "\n".join(lines) + "\n"

    def store_session_data(self, key: str, value: Any) -> None:
        """Store data for the current session.

        Any value previously stored under ``key`` is replaced.

        Args:
            key: The key to store the data under
            value: The data to store
        """
        self.session_data[key] = {
            "value": value,
            "timestamp": time.time(),
        }

    def get_session_data(self, key: str) -> Optional[Any]:
        """Retrieve data from the current session.

        Args:
            key: The key to retrieve data for

        Returns:
            The stored data, or None if not found
        """
        entry = self.session_data.get(key)
        if entry is None:
            return None
        return entry["value"]

    def clear_conversation_history(self) -> None:
        """Clear the conversation history."""
        self.conversation_history = []

    def clear_all_memory(self) -> None:
        """Clear all memory (conversation history, facts, and session data)."""
        self.conversation_history = []
        self.important_facts = []
        self.session_data = {}

    def get_memory_stats(self) -> Dict[str, Any]:
        """Get statistics about the agent's memory usage.

        Returns:
            Dictionary containing memory statistics. ``conversation_turns``
            is the message count divided by two (rounded down), and the
            ``memory_usage`` figures are the lengths of each container's
            ``str()`` representation, not byte counts.
        """
        return {
            "conversation_turns": len(self.conversation_history) // 2,
            "important_facts": len(self.important_facts),
            "session_data_keys": list(self.session_data.keys()),
            "memory_usage": {
                "conversation": len(str(self.conversation_history)),
                "facts": len(str(self.important_facts)),
                "session": len(str(self.session_data)),
            },
        }