| | """ |
| | Enhanced Memory System for GAIA-Ready AI Agent |
| | |
| | This module provides an advanced memory system for the AI agent, |
| | including short-term, long-term, and working memory components, |
| | as well as semantic retrieval capabilities. |
| | """ |
| |
|
| | import os |
| | import json |
| | from typing import List, Dict, Any, Optional, Union |
| | from datetime import datetime |
| | import re |
| | import numpy as np |
| | from collections import defaultdict |
| |
|
# Optional dependency: sentence-transformers powers semantic retrieval.
# If it is missing we make one attempt to install it, and otherwise degrade
# gracefully instead of making the whole module unimportable.
try:
    from sentence_transformers import SentenceTransformer
except ImportError:
    import subprocess
    import sys

    try:
        # Run pip through the current interpreter so the package is installed
        # into the environment that is actually executing this module; a bare
        # "pip" on PATH may belong to a different Python installation.
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "sentence-transformers"]
        )
        from sentence_transformers import SentenceTransformer
    except (subprocess.CalledProcessError, OSError, ImportError) as install_error:
        # EnhancedMemoryManager.__init__ builds the model inside a try/except,
        # so a None placeholder makes it switch semantic search off and use
        # keyword search instead of crashing at import time.
        print(f"Warning: sentence-transformers is unavailable: {install_error}")
        SentenceTransformer = None
| |
|
| |
|
class EnhancedMemoryManager:
    """
    Advanced memory manager for the agent that maintains short-term, long-term,
    and working memory with semantic retrieval capabilities.

    Short-term memory is a bounded FIFO list, long-term memory is a bounded
    list kept sorted by descending ``importance``, and working memory is a
    plain key/value scratchpad. Short- and long-term memories are persisted
    to ``memory_file`` as JSON; working memory and embeddings are not.

    ``memory_embeddings`` holds ``(embedding, index, memory_type)`` tuples,
    where ``index`` is the position of the item in the list named by
    ``memory_type`` (``"short_term"`` or ``"long_term"``). Every operation
    that reorders or removes items re-maps these indices so they never go
    stale.
    """

    # Weight applied per memory type when deriving a default importance.
    _TYPE_IMPORTANCE = {
        "final_answer": 0.9,
        "key_fact": 0.8,
        "reasoning": 0.7,
        "general": 0.5,
    }

    # Score multiplier applied per memory type by the keyword search.
    _TYPE_BOOST = {
        "final_answer": 2.0,
        "key_fact": 1.5,
        "reasoning": 1.2,
        "general": 1.0,
    }

    # Compiled once; used to tokenize queries and memory content.
    _WORD_PATTERN = re.compile(r'\b\w+\b')

    # Timestamp assumed for memories that carry none (treated as very old).
    _DEFAULT_TIMESTAMP = "2000-01-01T00:00:00"

    def __init__(self, use_semantic_search=True, memory_file="agent_memory.json"):
        """
        Create the manager and load any previously persisted memories.

        Args:
            use_semantic_search: Whether to build an embedding model for
                similarity search. Switched off automatically if the model
                cannot be created.
            memory_file: Path of the JSON file used for persistence.
        """
        self.short_term_memory = []
        self.long_term_memory = []
        self.working_memory = {}
        self.max_short_term_items = 15
        self.max_long_term_items = 100
        self.use_semantic_search = use_semantic_search
        # Always defined so helper methods can rely on it, even when
        # semantic search is disabled.
        self.memory_embeddings = []
        self.memory_file = memory_file

        if self.use_semantic_search:
            try:
                self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
            except Exception as e:
                # Third-party boundary: any failure (missing package, no
                # network, bad cache) simply disables semantic search.
                print(f"Warning: Could not initialize semantic search: {str(e)}")
                self.use_semantic_search = False

        self.load_memories()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _prepare_item(item: Dict[str, Any]) -> None:
        """Validate *item* and fill in default ``timestamp`` and ``type`` in place.

        Raises:
            ValueError: If the item has no ``content`` field.
        """
        if "content" not in item:
            raise ValueError("Memory item must have 'content' field")
        if "timestamp" not in item:
            item["timestamp"] = datetime.now().isoformat()
        if "type" not in item:
            item["type"] = "general"

    def _embed(self, item: Dict[str, Any]) -> Optional[Any]:
        """Return the embedding of the item's content, or None if encoding fails."""
        try:
            return self.embedding_model.encode(item.get("content", ""))
        except Exception as e:
            # Third-party boundary: an embedding failure must never prevent
            # the memory itself from being stored.
            print(f"Warning: Could not create embedding for memory item: {str(e)}")
            return None

    def _index_item(self, item: Dict[str, Any], index: int, mem_type: str) -> None:
        """Record an embedding for *item* stored at *index* of the *mem_type* list."""
        if not self.use_semantic_search:
            return
        embedding = self._embed(item)
        if embedding is not None:
            self.memory_embeddings.append((embedding, index, mem_type))

    def _remap_embeddings(self, mem_type: str, index_map: Dict[int, int]) -> None:
        """Re-point embeddings of *mem_type* after its list changed.

        Args:
            mem_type: ``"short_term"`` or ``"long_term"``.
            index_map: Maps each surviving item's old index to its new index.
                Embeddings whose old index is absent are dropped.
        """
        remapped = []
        for embedding, idx, kind in self.memory_embeddings:
            if kind != mem_type:
                remapped.append((embedding, idx, kind))
            elif idx in index_map:
                remapped.append((embedding, index_map[idx], kind))
        self.memory_embeddings = remapped

    def _rebuild_embeddings(self) -> None:
        """Recompute every embedding from the current memory lists."""
        self.memory_embeddings = []
        if not self.use_semantic_search:
            return
        stores = (
            ("short_term", self.short_term_memory),
            ("long_term", self.long_term_memory),
        )
        for mem_type, store in stores:
            for idx, memory in enumerate(store):
                self._index_item(memory, idx, mem_type)

    @classmethod
    def _age_seconds(cls, memory: Dict[str, Any], now: datetime) -> Optional[float]:
        """Return the memory's age in seconds, or None if its timestamp is unusable."""
        try:
            stamp = datetime.fromisoformat(memory.get("timestamp", cls._DEFAULT_TIMESTAMP))
            return (now - stamp).total_seconds()
        except (AttributeError, TypeError, ValueError):
            # Non-string or malformed timestamps, or a timezone-aware value
            # that cannot be subtracted from the naive ``now``.
            return None

    @staticmethod
    def _preview(text: Any, limit: int) -> str:
        """Return *text* cut to *limit* characters, with "..." only when truncated."""
        text = str(text)
        return f"{text[:limit]}..." if len(text) > limit else text

    # ------------------------------------------------------------------
    # Short-term / long-term memory
    # ------------------------------------------------------------------

    def add_to_short_term(self, item: Dict[str, Any]) -> None:
        """Add an item to short-term memory, maintaining size limit.

        The item dict is stored as-is and gains ``timestamp`` and ``type``
        keys if they are missing. The oldest items are evicted once the list
        exceeds ``max_short_term_items``.

        Raises:
            ValueError: If the item has no ``content`` field.
        """
        self._prepare_item(item)

        self.short_term_memory.append(item)
        self._index_item(item, len(self.short_term_memory) - 1, "short_term")

        # Evict from the front (oldest first) and shift embedding indices by
        # the number of evicted items.
        previous_length = len(self.short_term_memory)
        overflow = previous_length - self.max_short_term_items
        if overflow > 0:
            del self.short_term_memory[:overflow]
            self._remap_embeddings(
                "short_term",
                {old: old - overflow for old in range(overflow, previous_length)},
            )

        self.save_memories()

    def add_to_long_term(self, item: Dict[str, Any]) -> None:
        """Add an important item to long-term memory, maintaining size limit.

        The item dict is stored as-is and gains ``timestamp``, ``type`` and
        ``importance`` keys if they are missing. The list is kept sorted by
        descending importance, and the least important items are dropped once
        it exceeds ``max_long_term_items``.

        Raises:
            ValueError: If the item has no ``content`` field.
        """
        self._prepare_item(item)

        if "importance" not in item:
            # Heuristic: longer content and "stronger" types rank higher,
            # capped at 1.0.
            content_length = len(str(item["content"]))
            weight = self._TYPE_IMPORTANCE.get(item["type"], 0.5)
            item["importance"] = min(1.0, (content_length / 1000) * weight)

        self.long_term_memory.append(item)
        self._index_item(item, len(self.long_term_memory) - 1, "long_term")

        # Sort through an index permutation so each embedding can follow its
        # item to the new position. sorted() is stable, so the ordering is
        # the same as list.sort(key=..., reverse=True).
        order = sorted(
            range(len(self.long_term_memory)),
            key=lambda i: self.long_term_memory[i].get("importance", 0),
            reverse=True,
        )
        kept = order[:max(0, self.max_long_term_items)]
        self.long_term_memory[:] = [self.long_term_memory[i] for i in kept]
        self._remap_embeddings("long_term", {old: new for new, old in enumerate(kept)})

        self.save_memories()

    # ------------------------------------------------------------------
    # Working memory
    # ------------------------------------------------------------------

    def store_in_working_memory(self, key: str, value: Any) -> None:
        """Store a value in working memory under the specified key"""
        self.working_memory[key] = value

    def get_from_working_memory(self, key: str) -> Optional[Any]:
        """Retrieve a value from working memory by key (None if absent)"""
        return self.working_memory.get(key)

    def clear_working_memory(self) -> None:
        """Clear the working memory"""
        self.working_memory = {}

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def get_relevant_memories(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Retrieve memories relevant to the current query

        Uses embedding similarity when semantic search is enabled and falls
        back to keyword matching if it is disabled or fails.

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of copies of the relevant memory items, best match first,
            each with an added ``relevance_score`` field
        """
        if max_results <= 0:
            return []

        if self.use_semantic_search:
            try:
                return self._semantic_search(query, max_results)
            except Exception as e:
                print(f"Warning: Semantic search failed: {str(e)}. Falling back to keyword search.")

        return self._keyword_search(query, max_results)

    def _semantic_search(self, query: str, max_results: int) -> List[Dict[str, Any]]:
        """Rank stored memories by cosine similarity to the query embedding."""
        query_embedding = self.embedding_model.encode(query)
        query_norm = np.linalg.norm(query_embedding)

        scored = []
        for embedding, idx, mem_type in self.memory_embeddings:
            denominator = query_norm * np.linalg.norm(embedding)
            # A zero-norm vector has no direction; score it 0 instead of
            # dividing by zero and producing NaN, which would corrupt the sort.
            if denominator:
                similarity = float(np.dot(query_embedding, embedding) / denominator)
            else:
                similarity = 0.0
            scored.append((similarity, idx, mem_type))

        scored.sort(key=lambda entry: entry[0], reverse=True)

        results = []
        for similarity, idx, mem_type in scored[:max_results]:
            store = self.short_term_memory if mem_type == "short_term" else self.long_term_memory
            # Return a copy so the score does not leak into stored memories.
            hit = store[idx].copy()
            hit["relevance_score"] = similarity
            results.append(hit)
        return results

    def _keyword_score(self, memory: Dict[str, Any], query_keywords: set, now: datetime) -> float:
        """Score one memory: keyword overlap x type boost x recency factor."""
        content_words = set(self._WORD_PATTERN.findall(str(memory.get("content", "")).lower()))
        matches = len(query_keywords & content_words)

        # Recency decays linearly over 24 hours and is clamped to [0.5, 1.0];
        # unusable timestamps get the minimum factor.
        age = self._age_seconds(memory, now)
        if age is None:
            recency_factor = 0.5
        else:
            recency_factor = min(1.0, max(0.5, 1.0 - (age / 3600) / 24))

        boost = self._TYPE_BOOST.get(memory.get("type", "general"), 1.0)
        return matches * boost * recency_factor

    def _keyword_search(self, query: str, max_results: int = 10) -> List[Dict[str, Any]]:
        """
        Fallback keyword-based search for relevant memories

        Args:
            query: The query to find relevant memories for
            max_results: Maximum number of results to return

        Returns:
            List of copies of the memory items that share at least one word
            with the query, best match first, each with an added
            ``relevance_score`` field
        """
        if max_results <= 0:
            return []

        query_keywords = set(self._WORD_PATTERN.findall(query.lower()))
        now = datetime.now()

        # Long-term memories come first so they win ties (the sort is stable).
        scored = []
        for memory in self.long_term_memory + self.short_term_memory:
            score = self._keyword_score(memory, query_keywords, now)
            if score > 0:
                hit = memory.copy()
                hit["relevance_score"] = score
                scored.append((score, hit))

        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [hit for _, hit in scored[:max_results]]

    def get_memory_summary(self) -> str:
        """Get a summary of the current memory state for the agent

        Lists the five most recent short-term memories, the five most
        important long-term memories and every working-memory entry.
        """
        short_term_summary = "\n".join(
            f"- [{m.get('type', 'general')}] {self._preview(m.get('content', ''), 100)}"
            for m in self.short_term_memory[-5:]
        )

        most_important = sorted(
            self.long_term_memory,
            key=lambda m: m.get("importance", 0),
            reverse=True,
        )[:5]
        long_term_summary = "\n".join(
            f"- [{m.get('type', 'general')}] {self._preview(m.get('content', ''), 100)}"
            for m in most_important
        )

        # Only string values are shortened; other values are shown in full.
        working_memory_summary = "\n".join(
            f"- {k}: {self._preview(v, 50)}" if isinstance(v, str) else f"- {k}: {v}"
            for k, v in self.working_memory.items()
        )

        short_section = short_term_summary or "No recent short-term memories."
        long_section = long_term_summary or "No important long-term memories."
        working_section = working_memory_summary or "Working memory is empty."

        return (
            "\n"
            "MEMORY SUMMARY:\n"
            "--------------\n"
            "Recent Short-Term Memory:\n"
            f"{short_section}\n"
            "\n"
            "Important Long-Term Memory:\n"
            f"{long_section}\n"
            "\n"
            "Working Memory:\n"
            f"{working_section}\n"
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_memories(self) -> None:
        """Save memories to disk for persistence

        Failures are reported as a warning and never raised.
        """
        memories = {
            "short_term": self.short_term_memory,
            "long_term": self.long_term_memory,
            "last_updated": datetime.now().isoformat(),
        }
        try:
            # Serialize fully before opening the file: a non-serializable
            # item must not truncate the existing, valid store.
            serialized = json.dumps(memories, indent=2)
            with open(self.memory_file, 'w') as f:
                f.write(serialized)
        except (OSError, TypeError, ValueError, RecursionError) as e:
            print(f"Warning: Could not save memories: {str(e)}")

    def load_memories(self) -> None:
        """Load memories from disk if available

        A missing file is silently ignored. An unreadable or malformed file
        produces a warning and leaves the current memories unchanged.
        """
        if not os.path.exists(self.memory_file):
            return

        try:
            with open(self.memory_file, 'r', encoding='utf-8') as f:
                memories = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both JSON and text-decoding errors.
            print(f"Warning: Could not load memories: {str(e)}")
            return

        if not isinstance(memories, dict):
            print("Warning: Could not load memories: unexpected file format")
            return

        def as_items(value):
            # Tolerate hand-edited files: keep only dict entries of a list.
            if not isinstance(value, list):
                return []
            return [entry for entry in value if isinstance(entry, dict)]

        self.short_term_memory = as_items(memories.get("short_term", []))
        self.long_term_memory = as_items(memories.get("long_term", []))

        # Embeddings are not persisted, so recompute them for loaded items.
        self._rebuild_embeddings()

        print(f"Loaded {len(self.short_term_memory)} short-term and {len(self.long_term_memory)} long-term memories.")

    # ------------------------------------------------------------------
    # Forgetting
    # ------------------------------------------------------------------

    def _fresh_indices(self, store: List[Dict[str, Any]], now: datetime,
                       threshold: float, scale_by_importance: bool) -> List[int]:
        """Return the indices of the items in *store* that are young enough to keep.

        Items whose age cannot be determined are always kept. With
        *scale_by_importance*, an item's allowed age is
        ``threshold * (1 + importance)``.
        """
        kept = []
        for idx, memory in enumerate(store):
            age = self._age_seconds(memory, now)
            if age is None:
                kept.append(idx)
                continue
            limit = threshold
            if scale_by_importance:
                try:
                    limit = threshold * (1 + memory.get("importance", 0.5))
                except TypeError:
                    # Non-numeric importance: keep rather than guess.
                    kept.append(idx)
                    continue
            if age < limit:
                kept.append(idx)
        return kept

    def forget_old_memories(self, days_threshold: int = 30) -> None:
        """
        Remove memories older than the specified threshold

        Long-term memories get a longer lifetime in proportion to their
        importance. Memories with an unparsable timestamp are kept; memories
        without any timestamp are treated as very old.

        Args:
            days_threshold: Age threshold in days
        """
        now = datetime.now()
        threshold = days_threshold * 24 * 60 * 60

        removed = {}
        stores = (
            ("short_term", self.short_term_memory, False),
            ("long_term", self.long_term_memory, True),
        )
        for mem_type, store, scale_by_importance in stores:
            kept = self._fresh_indices(store, now, threshold, scale_by_importance)
            removed[mem_type] = len(store) - len(kept)
            store[:] = [store[i] for i in kept]
            # Existing embeddings follow their items; nothing is re-encoded.
            self._remap_embeddings(mem_type, {old: new for new, old in enumerate(kept)})

        self.save_memories()

        print(f"Forgot {removed['short_term']} short-term and {removed['long_term']} long-term memories older than {days_threshold} days.")
| |
|
| |
|
| | |
if __name__ == "__main__":
    # Small demonstration: store one memory of each kind, run a retrieval
    # query and print the resulting memory summary.
    manager = EnhancedMemoryManager(use_semantic_search=True)

    question_entry = {
        "type": "query",
        "content": "What is the capital of France?",
        "timestamp": datetime.now().isoformat(),
    }
    manager.add_to_short_term(question_entry)

    fact_entry = {
        "type": "key_fact",
        "content": "Paris is the capital of France with a population of about 2.2 million people.",
        "timestamp": datetime.now().isoformat(),
    }
    manager.add_to_long_term(fact_entry)

    manager.store_in_working_memory("current_task", "Finding information about France")

    # Look up memories related to a follow-up question.
    matches = manager.get_relevant_memories("What is the population of Paris?")
    print("\nRelevant memories for 'What is the population of Paris?':")
    for hit in matches:
        score = hit.get('relevance_score', 0)
        text = hit.get('content', '')
        print(f"- Score: {score:.2f}, Content: {text}")

    # Show the overall state of all three memory stores.
    print("\nMemory Summary:")
    print(manager.get_memory_summary())
| |
|