NeerajCodz committed on
Commit
bb3ee41
·
1 Parent(s): ca1fd98

feat: implement hierarchical memory system

Browse files
backend/app/memory/__init__.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Memory package for ScrapeRL agent memory management.

ScrapeRL agents use a layered memory architecture, where each layer covers a
distinct retention scope:

- **ShortTermMemory**: episode-scoped key/value storage, cleared per episode
- **WorkingMemory**: bounded LRU scratch space for intermediate reasoning
- **LongTermMemory**: ChromaDB-backed persistent vector store with semantic search
- **SharedMemory**: thread-safe pub/sub plus shared state for multi-agent coordination
- **MemoryManager**: a single facade over all of the layers above

Example:
    >>> from app.config import get_settings
    >>> from app.memory import MemoryManager, MemoryType
    >>>
    >>> settings = get_settings()
    >>> memory = MemoryManager(settings)
    >>> await memory.initialize()
    >>>
    >>> # Store in short-term memory
    >>> await memory.store("key", "value", MemoryType.SHORT_TERM)
    >>>
    >>> # Semantic search in long-term memory
    >>> results = await memory.search("query", MemoryType.LONG_TERM)
    >>>
    >>> # Cleanup
    >>> await memory.shutdown()
"""

from app.memory.long_term import Document, LongTermMemory, SearchResult
from app.memory.manager import MemoryManager, MemoryStats, MemoryType
from app.memory.shared import Channel, Message, SharedMemory, Subscription
from app.memory.short_term import MemoryEntry, ShortTermMemory
from app.memory.working import WorkingMemory, WorkingMemoryItem

# Public API, grouped by the layer each name belongs to.
__all__ = [
    # Manager
    "MemoryManager",
    "MemoryStats",
    "MemoryType",
    # Short-term
    "ShortTermMemory",
    "MemoryEntry",
    # Working
    "WorkingMemory",
    "WorkingMemoryItem",
    # Long-term
    "LongTermMemory",
    "Document",
    "SearchResult",
    # Shared
    "SharedMemory",
    "Channel",
    "Message",
    "Subscription",
]
backend/app/memory/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (1.79 kB). View file
 
backend/app/memory/__pycache__/long_term.cpython-314.pyc ADDED
Binary file (24.4 kB). View file
 
backend/app/memory/__pycache__/manager.cpython-314.pyc ADDED
Binary file (20.2 kB). View file
 
backend/app/memory/__pycache__/shared.cpython-314.pyc ADDED
Binary file (25.3 kB). View file
 
backend/app/memory/__pycache__/short_term.cpython-314.pyc ADDED
Binary file (13.3 kB). View file
 
backend/app/memory/__pycache__/working.cpython-314.pyc ADDED
Binary file (15.4 kB). View file
 
backend/app/memory/long_term.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Long-term memory with persistent vector storage using ChromaDB."""
2
+
3
from __future__ import annotations

import asyncio
import hashlib
import logging
from datetime import datetime, timezone
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class Document(BaseModel):
    """A document stored in long-term memory.

    Attributes:
        id: Unique identifier; a random UUID4 string unless supplied.
        content: Raw text content of the document.
        embedding: Optional pre-computed embedding vector.
        metadata: Arbitrary metadata attached to the document.
        created_at: Creation timestamp (UTC, timezone-aware).
        updated_at: Last-modification timestamp (UTC, timezone-aware).
    """

    id: str = Field(default_factory=lambda: str(uuid4()))
    content: str
    embedding: list[float] | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    # datetime.utcnow() is deprecated (Python 3.12+) and returns naive
    # datetimes; use timezone-aware UTC timestamps instead.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))

    model_config = {"arbitrary_types_allowed": True}
28
+
29
+
30
class SearchResult(BaseModel):
    """A single hit returned by a long-term memory search.

    Attributes:
        document: The matched document.
        score: Similarity score derived from the backend distance;
            higher means more similar.
        distance: Raw vector distance reported by the backend, when available.
    """

    document: Document
    score: float
    distance: float | None = None

    model_config = {"arbitrary_types_allowed": True}
38
+
39
+
40
class LongTermMemory:
    """
    Long-term persistent memory using ChromaDB for vector storage.

    This memory layer provides semantic search capabilities using embeddings.
    It persists across episodes and sessions, storing knowledge that should
    be retained long-term. When ChromaDB is not installed or fails to start,
    it degrades to a process-local in-memory dict with substring search
    (non-persistent, no embeddings).

    Attributes:
        collection_name: Name of the ChromaDB collection.
        persist_directory: Directory for persistent storage.
        top_k: Default number of results to return from search.
    """

    def __init__(
        self,
        collection_name: str = "scraperl_memory",
        persist_directory: str = "./data/chroma",
        top_k: int = 10,
        embedding_function: Any | None = None,
    ) -> None:
        """
        Initialize long-term memory.

        Args:
            collection_name: Name of the ChromaDB collection.
            persist_directory: Directory for persistent storage.
            top_k: Default number of results to return from search.
            embedding_function: Optional custom embedding function.
        """
        self.collection_name = collection_name
        self.persist_directory = persist_directory
        self.top_k = top_k
        self._embedding_function = embedding_function
        self._client: Any = None
        self._collection: Any = None
        self._initialized = False
        # Serializes lazy initialization and all collection access.
        self._lock = asyncio.Lock()

    async def initialize(self) -> None:
        """
        Initialize the ChromaDB client and collection.

        Safe to call multiple times; subsequent calls are no-ops. Falls back
        to an in-memory store when ChromaDB is missing or fails to start.
        """
        if self._initialized:
            return

        async with self._lock:
            # Double-checked: another task may have finished while we waited.
            if self._initialized:
                return

            try:
                import chromadb

                self._client = self._build_client(chromadb)

                # Get or create the collection, using cosine distance so that
                # search() can convert distances into similarity scores.
                self._collection = self._client.get_or_create_collection(
                    name=self.collection_name,
                    embedding_function=self._embedding_function,
                    metadata={"hnsw:space": "cosine"},
                )

                self._initialized = True
                logger.info(
                    f"Initialized long-term memory: collection={self.collection_name}"
                )

            except ImportError:
                logger.warning(
                    "ChromaDB not available. Long-term memory will use in-memory fallback."
                )
                self._use_fallback()
            except Exception as e:
                logger.warning(
                    f"Failed to initialize ChromaDB: {e}. Using in-memory fallback."
                )
                self._use_fallback()

    def _build_client(self, chromadb: Any) -> Any:
        """Create a persistent ChromaDB client across API generations.

        chromadb >= 0.4 removed the ``chroma_db_impl`` setting and introduced
        ``PersistentClient``; calling the legacy ``Client(Settings(...))`` form
        on a modern install raises, which previously forced the non-persistent
        fallback every time. Prefer the new API and keep the legacy form only
        for old installs, so on-disk persistence works either way.
        """
        from chromadb.config import Settings

        if hasattr(chromadb, "PersistentClient"):
            # Modern API (chromadb >= 0.4).
            return chromadb.PersistentClient(
                path=self.persist_directory,
                settings=Settings(anonymized_telemetry=False),
            )

        # Legacy API (chromadb < 0.4).
        return chromadb.Client(
            Settings(
                chroma_db_impl="duckdb+parquet",
                persist_directory=self.persist_directory,
                anonymized_telemetry=False,
            )
        )

    def _use_fallback(self) -> None:
        """Switch to the in-memory fallback when ChromaDB is unavailable."""
        self._client = None
        self._collection = None
        self._fallback_store: dict[str, Document] = {}
        self._initialized = True

    @property
    def is_initialized(self) -> bool:
        """Check if memory is initialized."""
        return self._initialized

    @property
    def _using_fallback(self) -> bool:
        """Check if using the in-memory fallback (no ChromaDB collection)."""
        return self._collection is None

    def _generate_id(self, content: str) -> str:
        """Generate a deterministic 16-hex-char ID from content (SHA-256 prefix)."""
        return hashlib.sha256(content.encode()).hexdigest()[:16]

    async def store(
        self,
        content: str,
        document_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        embedding: list[float] | None = None,
    ) -> Document:
        """
        Store a document in long-term memory.

        Args:
            content: Text content to store.
            document_id: Optional custom ID. Generated from content if not provided.
            metadata: Optional metadata dictionary.
            embedding: Optional pre-computed embedding vector.

        Returns:
            The stored document.

        Raises:
            Exception: Re-raises any ChromaDB upsert failure.
        """
        if not self._initialized:
            await self.initialize()

        async with self._lock:
            doc_id = document_id or self._generate_id(content)
            # Timezone-aware UTC; datetime.utcnow() is deprecated (3.12+).
            now = datetime.now(timezone.utc)

            document = Document(
                id=doc_id,
                content=content,
                embedding=embedding,
                metadata=metadata or {},
                created_at=now,
                updated_at=now,
            )

            if self._using_fallback:
                self._fallback_store[doc_id] = document
            else:
                try:
                    self._collection.upsert(
                        ids=[doc_id],
                        documents=[content],
                        metadatas=[
                            {
                                **document.metadata,
                                "created_at": now.isoformat(),
                                "updated_at": now.isoformat(),
                            }
                        ],
                        # `is not None`: truthiness would silently drop a
                        # provided (possibly empty) vector.
                        embeddings=[embedding] if embedding is not None else None,
                    )
                except Exception as e:
                    logger.error(f"Failed to store document: {e}")
                    raise

            return document

    async def search(
        self,
        query: str,
        top_k: int | None = None,
        where: dict[str, Any] | None = None,
        query_embedding: list[float] | None = None,
    ) -> list[SearchResult]:
        """
        Search for similar documents using semantic search.

        In fallback mode this degrades to case-insensitive substring matching.

        Args:
            query: Search query text.
            top_k: Number of results to return. Uses default if not specified.
            where: Optional metadata filter.
            query_embedding: Optional pre-computed query embedding.

        Returns:
            List of search results with scores (empty list on backend errors).
        """
        if not self._initialized:
            await self.initialize()

        k = top_k or self.top_k

        async with self._lock:
            if self._using_fallback:
                # Simple substring matching for the in-memory fallback.
                query_lower = query.lower()
                results = [
                    SearchResult(document=doc, score=1.0, distance=0.0)
                    for doc in self._fallback_store.values()
                    if query_lower in doc.content.lower()
                ]
                return results[:k]

            try:
                # Build ChromaDB query; prefer a pre-computed embedding when given.
                query_params: dict[str, Any] = {"n_results": k}

                if query_embedding is not None:
                    query_params["query_embeddings"] = [query_embedding]
                else:
                    query_params["query_texts"] = [query]

                if where:
                    query_params["where"] = where

                results = self._collection.query(**query_params)

                # Parse the columnar ChromaDB response into SearchResult objects.
                search_results = []
                if results and results.get("ids"):
                    for i, doc_id in enumerate(results["ids"][0]):
                        content = (
                            results["documents"][0][i]
                            if results.get("documents")
                            else ""
                        )
                        metadata = (
                            results["metadatas"][0][i]
                            if results.get("metadatas")
                            else {}
                        )
                        distance = (
                            results["distances"][0][i]
                            if results.get("distances")
                            else None
                        )

                        doc = Document(
                            id=doc_id,
                            content=content,
                            metadata=metadata,
                        )

                        # Cosine distance -> similarity score.
                        score = 1 - distance if distance is not None else 1.0

                        search_results.append(
                            SearchResult(
                                document=doc,
                                score=score,
                                distance=distance,
                            )
                        )

                return search_results

            except Exception as e:
                logger.error(f"Search failed: {e}")
                return []

    async def get(self, document_id: str) -> Document | None:
        """
        Retrieve a document by ID.

        Args:
            document_id: The document ID to retrieve.

        Returns:
            The document or None if not found (or on backend error).
        """
        if not self._initialized:
            await self.initialize()

        async with self._lock:
            if self._using_fallback:
                return self._fallback_store.get(document_id)

            try:
                result = self._collection.get(ids=[document_id])
                if result and result["ids"]:
                    return Document(
                        id=result["ids"][0],
                        content=result["documents"][0] if result.get("documents") else "",
                        metadata=result["metadatas"][0] if result.get("metadatas") else {},
                    )
                return None
            except Exception as e:
                logger.error(f"Failed to get document: {e}")
                return None

    async def delete(self, document_id: str) -> bool:
        """
        Delete a document from long-term memory.

        Args:
            document_id: The document ID to delete.

        Returns:
            True if the delete was issued successfully, False otherwise.
        """
        if not self._initialized:
            await self.initialize()

        async with self._lock:
            if self._using_fallback:
                if document_id in self._fallback_store:
                    del self._fallback_store[document_id]
                    return True
                return False

            try:
                self._collection.delete(ids=[document_id])
                return True
            except Exception as e:
                logger.error(f"Failed to delete document: {e}")
                return False

    async def delete_where(self, where: dict[str, Any]) -> int:
        """
        Delete documents matching a metadata filter.

        Args:
            where: Metadata filter; in fallback mode every key/value pair
                must match the document's metadata exactly.

        Returns:
            Number of documents deleted (0 on backend error).
        """
        if not self._initialized:
            await self.initialize()

        async with self._lock:
            if self._using_fallback:
                to_delete = [
                    doc_id
                    for doc_id, doc in self._fallback_store.items()
                    if all(doc.metadata.get(k) == v for k, v in where.items())
                ]
                for doc_id in to_delete:
                    del self._fallback_store[doc_id]
                return len(to_delete)

            try:
                # Look up matching IDs first so we can report how many we removed.
                result = self._collection.get(where=where)
                if result and result["ids"]:
                    self._collection.delete(ids=result["ids"])
                    return len(result["ids"])
                return 0
            except Exception as e:
                logger.error(f"Failed to delete documents: {e}")
                return 0

    async def count(self) -> int:
        """
        Get the total number of documents stored.

        Returns:
            Document count (0 on backend error).
        """
        if not self._initialized:
            await self.initialize()

        async with self._lock:
            if self._using_fallback:
                return len(self._fallback_store)

            try:
                return self._collection.count()
            except Exception as e:
                logger.error(f"Failed to count documents: {e}")
                return 0

    async def clear(self) -> int:
        """
        Clear all documents from memory.

        Returns:
            Number of documents that were cleared (0 on backend error).
        """
        if not self._initialized:
            await self.initialize()

        async with self._lock:
            if self._using_fallback:
                count = len(self._fallback_store)
                self._fallback_store.clear()
                return count

            try:
                count = self._collection.count()
                # Dropping and recreating the collection is the cheapest way
                # to wipe everything while keeping the same configuration.
                self._client.delete_collection(self.collection_name)
                self._collection = self._client.create_collection(
                    name=self.collection_name,
                    embedding_function=self._embedding_function,
                    metadata={"hnsw:space": "cosine"},
                )
                return count
            except Exception as e:
                logger.error(f"Failed to clear memory: {e}")
                return 0

    async def persist(self) -> None:
        """Persist changes to disk (no-op on clients without explicit persist)."""
        # Modern PersistentClient writes through automatically; only legacy
        # clients expose an explicit persist() hook.
        if self._client and hasattr(self._client, "persist"):
            try:
                self._client.persist()
            except Exception as e:
                logger.error(f"Failed to persist memory: {e}")

    async def shutdown(self) -> None:
        """Shutdown long-term memory and persist data."""
        if self._initialized and not self._using_fallback:
            await self.persist()
        self._initialized = False
        logger.info("Long-term memory shutdown complete")

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about long-term memory.

        Returns:
            Dictionary with memory statistics.
        """
        count = await self.count()
        return {
            "initialized": self._initialized,
            "using_fallback": self._using_fallback,
            "collection_name": self.collection_name,
            "persist_directory": self.persist_directory,
            "document_count": count,
            "top_k": self.top_k,
        }
backend/app/memory/manager.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Unified memory manager providing access to all memory layers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from enum import Enum
7
+ from typing import Any
8
+
9
+ from pydantic import BaseModel, Field
10
+
11
+ from app.config import Settings
12
+ from app.memory.long_term import Document, LongTermMemory, SearchResult
13
+ from app.memory.shared import Message, SharedMemory
14
+ from app.memory.short_term import MemoryEntry, ShortTermMemory
15
+ from app.memory.working import WorkingMemory, WorkingMemoryItem
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
class MemoryType(str, Enum):
    """Enumeration of the available memory layers.

    Inherits from ``str`` so members compare equal to their string values
    and serialize cleanly (e.g. in API payloads).
    """

    SHORT_TERM = "short_term"
    WORKING = "working"
    LONG_TERM = "long_term"
    SHARED = "shared"
27
+
28
+
29
class MemoryStats(BaseModel):
    """Aggregated statistics across all memory layers.

    Each field holds the raw stats dictionary reported by the
    corresponding layer's ``get_stats()``.
    """

    short_term: dict[str, Any] = Field(default_factory=dict)
    working: dict[str, Any] = Field(default_factory=dict)
    long_term: dict[str, Any] = Field(default_factory=dict)
    shared: dict[str, Any] = Field(default_factory=dict)
36
+
37
+
38
class MemoryManager:
    """
    Unified interface to all memory layers.

    The MemoryManager provides a single entry point for interacting with
    different types of memory (short-term, working, long-term, shared).
    It handles initialization, coordination, and lifecycle management.

    Attributes:
        short_term: Episode-scoped dictionary memory.
        working: LRU-based reasoning scratch space.
        long_term: Persistent vector storage.
        shared: Multi-agent shared state.
    """

    def __init__(self, settings: Settings) -> None:
        """
        Initialize memory manager with settings.

        Args:
            settings: Application settings (layer sizes, ChromaDB collection
                name and persist directory, default long-term top-k).
        """
        self._settings = settings
        self._initialized = False

        # Initialize memory layers. Only long-term needs async setup (see
        # initialize()); the other layers are usable as soon as constructed.
        self.short_term = ShortTermMemory(
            max_size=settings.short_term_memory_size,
        )

        self.working = WorkingMemory(
            capacity=settings.working_memory_size,
        )

        self.long_term = LongTermMemory(
            collection_name=settings.chroma_collection_name,
            persist_directory=settings.chroma_persist_directory,
            top_k=settings.long_term_memory_top_k,
        )

        self.shared = SharedMemory()

    async def initialize(self) -> None:
        """
        Initialize all memory layers.

        This should be called during application startup.

        Raises:
            Exception: Re-raises any failure from long-term memory setup.
        """
        if self._initialized:
            return

        try:
            # Initialize long-term memory (ChromaDB); the only layer with
            # external resources to set up.
            await self.long_term.initialize()
            self._initialized = True
            logger.info("Memory manager initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize memory manager: {e}")
            raise

    async def shutdown(self) -> None:
        """
        Shutdown all memory layers gracefully.

        This should be called during application shutdown.

        Raises:
            Exception: Re-raises any failure during shutdown.
        """
        try:
            # Persist long-term memory
            await self.long_term.shutdown()

            # Clear working memory (scratch space is not meant to survive).
            await self.working.clear()

            self._initialized = False
            logger.info("Memory manager shutdown complete")
        except Exception as e:
            logger.error(f"Error during memory manager shutdown: {e}")
            raise

    @property
    def is_initialized(self) -> bool:
        """Check if memory manager is initialized."""
        return self._initialized

    # =========================================================================
    # Unified Store Interface
    # =========================================================================

    async def store(
        self,
        key: str,
        value: Any,
        memory_type: MemoryType = MemoryType.SHORT_TERM,
        **kwargs: Any,
    ) -> Any:
        """
        Store a value in the specified memory layer.

        Args:
            key: Key or identifier for the stored value.
            value: Value to store.
            memory_type: Which memory layer to use.
            **kwargs: Additional arguments passed to the specific layer:
                ``tags`` (short-term), ``priority``/``metadata`` (working),
                ``metadata``/``embedding`` (long-term).

        Returns:
            The created entry/document (varies by memory type).

        Raises:
            ValueError: If memory_type is invalid.
        """
        match memory_type:
            case MemoryType.SHORT_TERM:
                tags = kwargs.get("tags")
                return await self.short_term.set(key, value, tags=tags)

            case MemoryType.WORKING:
                priority = kwargs.get("priority", 0.0)
                metadata = kwargs.get("metadata")
                return await self.working.push(
                    content=value,
                    item_id=key,
                    priority=priority,
                    metadata=metadata,
                )

            case MemoryType.LONG_TERM:
                # Long-term storage is text-based; coerce non-strings.
                if not isinstance(value, str):
                    value = str(value)
                metadata = kwargs.get("metadata")
                embedding = kwargs.get("embedding")
                return await self.long_term.store(
                    content=value,
                    document_id=key,
                    metadata=metadata,
                    embedding=embedding,
                )

            case MemoryType.SHARED:
                await self.shared.set_state(key, value)
                return value

            case _:
                raise ValueError(f"Invalid memory type: {memory_type}")

    # =========================================================================
    # Unified Retrieve Interface
    # =========================================================================

    async def retrieve(
        self,
        key: str,
        memory_type: MemoryType = MemoryType.SHORT_TERM,
        default: Any = None,
    ) -> Any:
        """
        Retrieve a value from the specified memory layer.

        Args:
            key: Key or identifier to look up.
            memory_type: Which memory layer to query.
            default: Default value if not found.

        Returns:
            The stored value or default. For working/long-term layers this
            unwraps the item/document to its content.

        Raises:
            ValueError: If memory_type is invalid.
        """
        match memory_type:
            case MemoryType.SHORT_TERM:
                return await self.short_term.get(key, default=default)

            case MemoryType.WORKING:
                item = await self.working.peek_by_id(key)
                return item.content if item else default

            case MemoryType.LONG_TERM:
                doc = await self.long_term.get(key)
                return doc.content if doc else default

            case MemoryType.SHARED:
                return await self.shared.get_state(key, default=default)

            case _:
                raise ValueError(f"Invalid memory type: {memory_type}")

    # =========================================================================
    # Unified Search Interface
    # =========================================================================

    async def search(
        self,
        query: str,
        memory_type: MemoryType = MemoryType.LONG_TERM,
        top_k: int = 10,
        **kwargs: Any,
    ) -> list[Any]:
        """
        Search for values in the specified memory layer.

        Only long-term search is semantic; the other layers do
        case-insensitive substring matching on keys/content.

        Args:
            query: Search query.
            memory_type: Which memory layer to search.
            top_k: Maximum number of results.
            **kwargs: Additional arguments for specific layers:
                ``tag`` (short-term), ``where``/``query_embedding`` (long-term).

        Returns:
            List of matching entries/documents (result shape varies by layer).

        Raises:
            ValueError: If memory_type is invalid or doesn't support search.
        """
        match memory_type:
            case MemoryType.SHORT_TERM:
                # Search by tag or return all keys containing query
                tag = kwargs.get("tag")
                if tag:
                    return list((await self.short_term.get_by_tag(tag)).items())[:top_k]
                keys = await self.short_term.list_keys()
                matching = [k for k in keys if query.lower() in k.lower()]
                results = []
                for key in matching[:top_k]:
                    value = await self.short_term.get(key)
                    results.append((key, value))
                return results

            case MemoryType.WORKING:
                # Search working memory items by substring of their content.
                def matches(item: WorkingMemoryItem) -> bool:
                    content_str = str(item.content).lower()
                    return query.lower() in content_str

                items = await self.working.search(matches)
                return items[:top_k]

            case MemoryType.LONG_TERM:
                where = kwargs.get("where")
                query_embedding = kwargs.get("query_embedding")
                return await self.long_term.search(
                    query=query,
                    top_k=top_k,
                    where=where,
                    query_embedding=query_embedding,
                )

            case MemoryType.SHARED:
                # Search state keys (and stringified values).
                all_state = await self.shared.get_all_state()
                matching = [
                    (k, v)
                    for k, v in all_state.items()
                    if query.lower() in k.lower()
                    or query.lower() in str(v).lower()
                ]
                return matching[:top_k]

            case _:
                raise ValueError(f"Invalid memory type: {memory_type}")

    # =========================================================================
    # Unified Clear Interface
    # =========================================================================

    async def clear(
        self,
        memory_type: MemoryType | None = None,
    ) -> dict[str, int]:
        """
        Clear memory layers.

        Args:
            memory_type: Specific layer to clear, or None for all.

        Returns:
            Dictionary with counts of cleared items per layer. The shared
            layer reports two counts: ``shared_channels`` and ``shared_state``.
        """
        results: dict[str, int] = {}

        if memory_type is None or memory_type == MemoryType.SHORT_TERM:
            results["short_term"] = await self.short_term.clear()

        if memory_type is None or memory_type == MemoryType.WORKING:
            results["working"] = await self.working.clear()

        if memory_type is None or memory_type == MemoryType.LONG_TERM:
            results["long_term"] = await self.long_term.clear()

        if memory_type is None or memory_type == MemoryType.SHARED:
            shared_results = await self.shared.clear()
            results["shared_channels"] = shared_results["channels"]
            results["shared_state"] = shared_results["state_keys"]

        return results

    # =========================================================================
    # Episode Management
    # =========================================================================

    async def start_episode(self, episode_id: str) -> None:
        """
        Start a new episode, clearing episode-scoped memory.

        Args:
            episode_id: Unique identifier for the episode.
        """
        await self.short_term.set_episode(episode_id)
        await self.working.clear()
        logger.debug(f"Started episode: {episode_id}")

    async def end_episode(self) -> dict[str, int]:
        """
        End the current episode, clearing temporary memory.

        Long-term and shared memory deliberately survive episode boundaries.

        Returns:
            Counts of cleared items.
        """
        results = {
            "short_term": await self.short_term.clear(),
            "working": await self.working.clear(),
        }
        logger.debug(f"Ended episode: {results}")
        return results

    # =========================================================================
    # Statistics
    # =========================================================================

    async def get_stats(self) -> MemoryStats:
        """
        Get statistics from all memory layers.

        Returns:
            MemoryStats with info from each layer.
        """
        return MemoryStats(
            short_term=await self.short_term.get_stats(),
            working=await self.working.get_stats(),
            long_term=await self.long_term.get_stats(),
            shared=await self.shared.get_stats(),
        )

    # =========================================================================
    # Convenience Methods
    # =========================================================================

    async def remember(
        self,
        content: str,
        metadata: dict[str, Any] | None = None,
    ) -> Document:
        """
        Store content in long-term memory for later retrieval.

        This is a convenience method for storing knowledge.

        Args:
            content: Text content to remember.
            metadata: Optional metadata.

        Returns:
            The stored document.
        """
        return await self.long_term.store(content=content, metadata=metadata)

    async def recall(
        self,
        query: str,
        top_k: int = 5,
    ) -> list[SearchResult]:
        """
        Recall relevant memories based on a query.

        This is a convenience method for semantic search.

        Args:
            query: Search query.
            top_k: Number of results to return.

        Returns:
            List of relevant search results.
        """
        return await self.long_term.search(query=query, top_k=top_k)

    async def think(
        self,
        thought: str,
        priority: float = 0.0,
    ) -> WorkingMemoryItem:
        """
        Add a thought to working memory.

        This is a convenience method for reasoning steps.

        Args:
            thought: The thought content.
            priority: Priority score.

        Returns:
            The working memory item.
        """
        return await self.working.push(content=thought, priority=priority)

    async def broadcast(
        self,
        channel: str,
        message: Any,
        sender: str | None = None,
    ) -> Message:
        """
        Broadcast a message to a shared channel.

        This is a convenience method for multi-agent communication.

        Args:
            channel: Channel name.
            message: Message payload.
            sender: Optional sender identifier.

        Returns:
            The published message.
        """
        return await self.shared.publish(
            channel=channel,
            payload=message,
            sender=sender,
        )
backend/app/memory/shared.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared memory for multi-agent communication and state sharing."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import logging
7
+ from collections import defaultdict
8
+ from datetime import datetime
9
+ from typing import Any, Callable, Awaitable
10
+ from uuid import uuid4
11
+
12
+ from pydantic import BaseModel, Field
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Type alias for async callback functions
17
+ MessageCallback = Callable[[Any], Awaitable[None]]
18
+
19
+
20
class Message(BaseModel):
    """A message published to a channel."""

    # Unique message identifier (random UUID string).
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Name of the channel the message was published to.
    channel: str
    # Arbitrary message body; any Python object is allowed.
    payload: Any
    # Optional identifier of the publishing agent.
    sender: str | None = None
    # Naive UTC creation time (datetime.utcnow).
    timestamp: datetime = Field(default_factory=datetime.utcnow)
    # Free-form extra data supplied by the publisher.
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Allow non-pydantic objects in `payload` to pass validation.
    model_config = {"arbitrary_types_allowed": True}
31
+
32
+
33
class Subscription(BaseModel):
    """A subscription to a channel."""

    # Unique subscription identifier (random UUID string).
    id: str = Field(default_factory=lambda: str(uuid4()))
    # Name of the subscribed channel.
    channel: str
    # Identifier of the subscribing agent.
    subscriber_id: str
    # Naive UTC creation time (datetime.utcnow).
    created_at: datetime = Field(default_factory=datetime.utcnow)

    # Kept for symmetry with Message; all fields here are standard types.
    model_config = {"arbitrary_types_allowed": True}
42
+
43
+
44
class Channel:
    """A named channel for pub/sub communication."""

    def __init__(self, name: str, max_history: int = 100) -> None:
        """
        Initialize a channel.

        Args:
            name: Channel name.
            max_history: Maximum number of messages to retain in history.
        """
        self.name = name
        self.max_history = max_history
        # subscriber_id -> async callback
        self._subscribers: dict[str, MessageCallback] = {}
        self._history: list[Message] = []
        # Guards _subscribers and _history. asyncio.Lock is NOT reentrant,
        # so nothing may await user code while holding it.
        self._lock = asyncio.Lock()

    @property
    def subscriber_count(self) -> int:
        """Get the number of subscribers."""
        return len(self._subscribers)

    async def publish(self, message: Message) -> int:
        """
        Publish a message to all subscribers.

        The message is appended to the bounded history, then delivered to
        each subscriber callback. Callbacks are awaited *outside* the
        channel lock: holding the lock during delivery (as the previous
        implementation did) would deadlock any callback that calls
        subscribe/unsubscribe/get_history on this same channel.

        Args:
            message: Message to publish.

        Returns:
            Number of subscribers that received the message without raising.
        """
        async with self._lock:
            self._history.append(message)
            if len(self._history) > self.max_history:
                # Drop the oldest entries beyond the retention window.
                self._history = self._history[-self.max_history:]
            # Snapshot subscribers so delivery happens lock-free.
            recipients = list(self._subscribers.items())

        notified = 0
        for sub_id, callback in recipients:
            try:
                await callback(message)
                notified += 1
            except Exception as e:
                # A failing subscriber must not block delivery to the rest.
                logger.error(f"Error notifying subscriber {sub_id}: {e}")

        return notified

    async def subscribe(
        self,
        subscriber_id: str,
        callback: MessageCallback,
    ) -> Subscription:
        """
        Subscribe to the channel.

        Args:
            subscriber_id: Unique identifier for the subscriber. Re-using
                an existing id replaces that subscriber's callback.
            callback: Async callback function to receive messages.

        Returns:
            Subscription object.
        """
        async with self._lock:
            self._subscribers[subscriber_id] = callback
        return Subscription(
            channel=self.name,
            subscriber_id=subscriber_id,
        )

    async def unsubscribe(self, subscriber_id: str) -> bool:
        """
        Unsubscribe from the channel.

        Args:
            subscriber_id: Subscriber to remove.

        Returns:
            True if subscriber was found and removed.
        """
        async with self._lock:
            return self._subscribers.pop(subscriber_id, None) is not None

    async def get_history(
        self,
        limit: int | None = None,
        since: datetime | None = None,
    ) -> list[Message]:
        """
        Get channel message history.

        Args:
            limit: Maximum number of (most recent) messages to return.
            since: Only return messages after this timestamp.

        Returns:
            List of historical messages. Always a fresh list — the
            previous implementation could return the internal history
            list itself, letting callers mutate channel state.
        """
        async with self._lock:
            messages = list(self._history)

        if since:
            messages = [m for m in messages if m.timestamp > since]
        if limit:
            messages = messages[-limit:]
        return messages

    async def clear_history(self) -> int:
        """
        Clear the channel's message history.

        Returns:
            Number of messages cleared.
        """
        async with self._lock:
            count = len(self._history)
            self._history.clear()
            return count
168
+
169
+
170
class SharedMemory:
    """
    Shared memory for multi-agent coordination.

    This memory layer provides pub/sub messaging and shared state storage
    for coordination between multiple agents. All operations are guarded
    by asyncio locks, making them safe for concurrent tasks on a single
    event loop.

    NOTE(review): asyncio locks do not protect against access from other
    OS threads — confirm that "thread-safe" here means task-safe.

    Attributes:
        _channels: Dictionary of channels by name.
        _state: Shared key-value state store.
        _queues: Per-channel subscriber queues for queue-based delivery.
    """

    def __init__(self, max_channel_history: int = 100) -> None:
        """
        Initialize shared memory.

        Args:
            max_channel_history: Maximum history per channel.
        """
        self.max_channel_history = max_channel_history
        self._channels: dict[str, Channel] = {}
        self._state: dict[str, Any] = {}
        # Separate locks so state access never blocks on channel work.
        self._state_lock = asyncio.Lock()
        self._channel_lock = asyncio.Lock()
        # channel name -> {subscriber_id -> bounded message queue}
        self._queues: dict[str, dict[str, asyncio.Queue]] = defaultdict(dict)

    async def get_channel(self, name: str) -> Channel:
        """
        Get or create a channel by name.

        Args:
            name: Channel name.

        Returns:
            The channel object.
        """
        async with self._channel_lock:
            if name not in self._channels:
                self._channels[name] = Channel(
                    name=name,
                    max_history=self.max_channel_history,
                )
            return self._channels[name]

    async def publish(
        self,
        channel: str,
        payload: Any,
        sender: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Message:
        """
        Publish a message to a channel.

        The message is delivered to callback subscribers via the channel
        and to queue subscribers via their per-subscriber queues.

        Args:
            channel: Channel name to publish to.
            payload: Message payload.
            sender: Optional sender identifier.
            metadata: Optional message metadata.

        Returns:
            The published message.
        """
        ch = await self.get_channel(channel)

        message = Message(
            channel=channel,
            payload=payload,
            sender=sender,
            metadata=metadata or {},
        )

        await ch.publish(message)

        # Also put in subscriber queues (non-blocking; oldest dropped on
        # overflow so slow consumers never stall the publisher).
        async with self._channel_lock:
            if channel in self._queues:
                for queue in self._queues[channel].values():
                    try:
                        queue.put_nowait(message)
                    except asyncio.QueueFull:
                        # Remove oldest and add new
                        try:
                            queue.get_nowait()
                            queue.put_nowait(message)
                        except asyncio.QueueEmpty:
                            pass

        return message

    async def subscribe(
        self,
        channel: str,
        subscriber_id: str,
        callback: MessageCallback,
    ) -> Subscription:
        """
        Subscribe to a channel with a callback.

        Args:
            channel: Channel name to subscribe to.
            subscriber_id: Unique subscriber identifier.
            callback: Async callback for received messages.

        Returns:
            Subscription object.
        """
        ch = await self.get_channel(channel)
        return await ch.subscribe(subscriber_id, callback)

    async def subscribe_queue(
        self,
        channel: str,
        subscriber_id: str,
        max_size: int = 100,
    ) -> asyncio.Queue[Message]:
        """
        Subscribe to a channel and receive messages via a queue.

        This is an alternative to callback-based subscriptions. Calling
        again with the same ids returns the existing queue.

        Args:
            channel: Channel name to subscribe to.
            subscriber_id: Unique subscriber identifier.
            max_size: Maximum queue size.

        Returns:
            Queue that will receive messages.
        """
        async with self._channel_lock:
            if subscriber_id not in self._queues[channel]:
                self._queues[channel][subscriber_id] = asyncio.Queue(maxsize=max_size)
            return self._queues[channel][subscriber_id]

    async def unsubscribe(self, channel: str, subscriber_id: str) -> bool:
        """
        Unsubscribe from a channel.

        Removes both callback-based and queue-based subscriptions. The
        previous implementation discarded the callback-removal result and
        returned False when only a callback subscription existed.

        Args:
            channel: Channel name.
            subscriber_id: Subscriber to remove.

        Returns:
            True if any subscription (callback or queue) was removed.
        """
        async with self._channel_lock:
            removed = False

            # Remove from callback subscriptions
            if channel in self._channels:
                removed = await self._channels[channel].unsubscribe(subscriber_id)

            # Remove from queue subscriptions
            if channel in self._queues and subscriber_id in self._queues[channel]:
                del self._queues[channel][subscriber_id]
                removed = True

            return removed

    async def set_state(self, key: str, value: Any) -> None:
        """
        Set a shared state value.

        Args:
            key: State key.
            value: Value to store.
        """
        async with self._state_lock:
            self._state[key] = value

    async def get_state(self, key: str, default: Any = None) -> Any:
        """
        Get a shared state value.

        Args:
            key: State key.
            default: Default value if key not found.

        Returns:
            The stored value or default.
        """
        async with self._state_lock:
            return self._state.get(key, default)

    async def delete_state(self, key: str) -> bool:
        """
        Delete a shared state value.

        Args:
            key: State key to delete.

        Returns:
            True if key was found and deleted.
        """
        async with self._state_lock:
            if key in self._state:
                del self._state[key]
                return True
            return False

    async def update_state(self, key: str, updater: Callable[[Any], Any]) -> Any:
        """
        Atomically update a state value.

        The updater runs under the state lock, so no other task can
        modify the key between read and write.

        Args:
            key: State key.
            updater: Function that takes current value (None when the key
                is absent) and returns the new value.

        Returns:
            The new value after update.
        """
        async with self._state_lock:
            current = self._state.get(key)
            new_value = updater(current)
            self._state[key] = new_value
            return new_value

    async def get_all_state(self) -> dict[str, Any]:
        """
        Get all shared state values.

        Returns:
            Shallow copy of the state dictionary.
        """
        async with self._state_lock:
            return dict(self._state)

    async def clear_state(self) -> int:
        """
        Clear all shared state.

        Returns:
            Number of keys cleared.
        """
        async with self._state_lock:
            count = len(self._state)
            self._state.clear()
            return count

    async def list_channels(self) -> list[str]:
        """
        List all active channels.

        Returns:
            List of channel names.
        """
        async with self._channel_lock:
            return list(self._channels.keys())

    async def delete_channel(self, name: str) -> bool:
        """
        Delete a channel and all its subscriptions.

        Args:
            name: Channel name to delete.

        Returns:
            True if channel was found and deleted.
        """
        async with self._channel_lock:
            if name in self._channels:
                del self._channels[name]
                if name in self._queues:
                    del self._queues[name]
                return True
            return False

    async def clear(self) -> dict[str, int]:
        """
        Clear all channels and state.

        Returns:
            Dictionary with counts of cleared items.
        """
        async with self._channel_lock:
            channel_count = len(self._channels)
            self._channels.clear()
            self._queues.clear()

        async with self._state_lock:
            state_count = len(self._state)
            self._state.clear()

        return {
            "channels": channel_count,
            "state_keys": state_count,
        }

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about shared memory.

        Returns:
            Dictionary with memory statistics.
        """
        async with self._channel_lock:
            channel_stats = {}
            for name, channel in self._channels.items():
                channel_stats[name] = {
                    "subscribers": channel.subscriber_count,
                    # Reaches into Channel internals; read-only snapshot.
                    "history_size": len(channel._history),
                }

        async with self._state_lock:
            state_keys = list(self._state.keys())

        return {
            "channel_count": len(channel_stats),
            "channels": channel_stats,
            "state_key_count": len(state_keys),
            "state_keys": state_keys,
            "max_channel_history": self.max_channel_history,
        }
backend/app/memory/short_term.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Short-term memory for episode-scoped data storage."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from collections import OrderedDict
7
+ from datetime import datetime
8
+ from typing import Any, Generic, TypeVar
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
class MemoryEntry(BaseModel, Generic[T]):
    """A single memory entry with metadata.

    NOTE(review): the type parameter ``T`` is declared but ``value`` is
    annotated ``Any`` — ``T`` is currently unused; confirm whether
    ``value: T`` was intended.
    """

    # Unique key the entry is stored under.
    key: str
    # The stored payload; any Python object is allowed.
    value: Any
    # Naive UTC timestamp set at creation (datetime.utcnow).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    # Naive UTC timestamp refreshed whenever the entry is overwritten.
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # Incremented on each read of the entry.
    access_count: int = 0
    # Free-form labels used for tag-based lookups.
    tags: list[str] = Field(default_factory=list)

    # Allow non-pydantic objects in `value` to pass validation.
    model_config = {"arbitrary_types_allowed": True}
26
+
27
+
28
class ShortTermMemory:
    """
    Episode-scoped memory using dictionary-based storage.

    This memory layer is designed for transient data that should persist
    only within a single episode. It automatically clears when the episode
    resets, and silently evicts the oldest entry once ``max_size`` is
    reached.

    Attributes:
        max_size: Maximum number of entries allowed.
        _store: Internal insertion-ordered storage dictionary.
        _episode_id: Current episode identifier.
    """

    def __init__(self, max_size: int = 100) -> None:
        """
        Initialize short-term memory.

        Args:
            max_size: Maximum number of entries to store. Defaults to 100.
        """
        self.max_size = max_size
        self._store: OrderedDict[str, MemoryEntry] = OrderedDict()
        self._episode_id: str | None = None
        self._lock = asyncio.Lock()

    @property
    def episode_id(self) -> str | None:
        """Get the current episode ID."""
        return self._episode_id

    @property
    def size(self) -> int:
        """Get the current number of entries."""
        return len(self._store)

    async def set_episode(self, episode_id: str) -> None:
        """
        Set the current episode ID and clear existing memory.

        Memory is cleared only when the episode actually changes; setting
        the same ID again is a no-op.

        Args:
            episode_id: Unique identifier for the new episode.
        """
        async with self._lock:
            if self._episode_id != episode_id:
                self._store.clear()
                self._episode_id = episode_id

    async def set(
        self,
        key: str,
        value: Any,
        tags: list[str] | None = None,
    ) -> MemoryEntry:
        """
        Store a value in short-term memory.

        Existing keys are updated in place and marked most-recent. When a
        new key would exceed ``max_size``, the oldest entry is silently
        evicted. (The previous docstring incorrectly claimed a
        ``ValueError`` was raised at capacity.)

        Args:
            key: Unique key for the entry.
            value: Value to store.
            tags: Optional tags for categorization; when provided on an
                update, replaces the entry's existing tags.

        Returns:
            The created or updated memory entry.
        """
        async with self._lock:
            now = datetime.utcnow()

            if key in self._store:
                entry = self._store[key]
                entry.value = value
                entry.updated_at = now
                if tags is not None:
                    entry.tags = tags
                # Move to end (most recent)
                self._store.move_to_end(key)
            else:
                # At capacity: evict the oldest entry, never raise.
                if len(self._store) >= self.max_size:
                    self._store.popitem(last=False)

                entry = MemoryEntry(
                    key=key,
                    value=value,
                    created_at=now,
                    updated_at=now,
                    tags=tags or [],
                )
                self._store[key] = entry

            return entry

    async def get(self, key: str, default: Any = None) -> Any:
        """
        Retrieve a value from short-term memory.

        Increments the entry's access count as a side effect.

        Args:
            key: Key to look up.
            default: Default value if key not found.

        Returns:
            The stored value or default.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return default
            entry.access_count += 1
            return entry.value

    async def get_entry(self, key: str) -> MemoryEntry | None:
        """
        Retrieve a full memory entry with metadata.

        Increments the entry's access count as a side effect.

        Args:
            key: Key to look up.

        Returns:
            The memory entry or None if not found.
        """
        async with self._lock:
            entry = self._store.get(key)
            if entry:
                entry.access_count += 1
            return entry

    async def delete(self, key: str) -> bool:
        """
        Delete an entry from memory.

        Args:
            key: Key to delete.

        Returns:
            True if the key was found and deleted, False otherwise.
        """
        async with self._lock:
            if key in self._store:
                del self._store[key]
                return True
            return False

    async def clear(self) -> int:
        """
        Clear all entries from memory.

        Returns:
            Number of entries that were cleared.
        """
        async with self._lock:
            count = len(self._store)
            self._store.clear()
            return count

    async def list_keys(self, tag: str | None = None) -> list[str]:
        """
        List all keys in memory, optionally filtered by tag.

        Args:
            tag: Optional tag to filter by.

        Returns:
            List of matching keys.
        """
        async with self._lock:
            if tag is None:
                return list(self._store.keys())
            return [k for k, v in self._store.items() if tag in v.tags]

    async def get_by_tag(self, tag: str) -> dict[str, Any]:
        """
        Retrieve all entries with a specific tag.

        Args:
            tag: Tag to filter by.

        Returns:
            Dictionary of key-value pairs matching the tag.
        """
        async with self._lock:
            return {
                k: v.value for k, v in self._store.items() if tag in v.tags
            }

    async def exists(self, key: str) -> bool:
        """
        Check if a key exists in memory.

        Args:
            key: Key to check.

        Returns:
            True if key exists, False otherwise.
        """
        async with self._lock:
            return key in self._store

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about the memory store.

        Returns:
            Dictionary with memory statistics.
        """
        async with self._lock:
            return {
                "size": len(self._store),
                "max_size": self.max_size,
                "episode_id": self._episode_id,
                "keys": list(self._store.keys()),
                "utilization": len(self._store) / self.max_size if self.max_size > 0 else 0,
            }
backend/app/memory/working.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Working memory for reasoning and scratch space with LRU eviction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from collections import OrderedDict
7
+ from datetime import datetime
8
+ from typing import Any
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+
13
class WorkingMemoryItem(BaseModel):
    """A single item in working memory."""

    # Identifier; auto-generated as "wm_<counter>" when not supplied.
    id: str
    # The stored payload; any Python object is allowed.
    content: Any
    # Priority score carried on the item (stored, not used for eviction
    # in the current LRU policy).
    priority: float = 0.0
    # Naive UTC timestamp set at creation (datetime.utcnow).
    created_at: datetime = Field(default_factory=datetime.utcnow)
    # Naive UTC timestamp refreshed on each access.
    last_accessed: datetime = Field(default_factory=datetime.utcnow)
    # Number of times the item has been read or updated.
    access_count: int = 0
    # Free-form extra data attached to the item.
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Allow non-pydantic objects in `content` to pass validation.
    model_config = {"arbitrary_types_allowed": True}
25
+
26
+
27
class WorkingMemory:
    """
    Working memory for reasoning and scratch computations.

    This memory layer provides a limited-capacity buffer with LRU (Least Recently Used)
    eviction policy. It's designed for temporary reasoning steps, intermediate results,
    and scratch space during agent deliberation.

    Attributes:
        capacity: Maximum number of items in working memory.
        _items: Internal LRU-ordered storage.
    """

    def __init__(self, capacity: int = 20) -> None:
        """
        Initialize working memory.

        Args:
            capacity: Maximum number of items to store. Defaults to 20.
        """
        self.capacity = capacity
        self._items: OrderedDict[str, WorkingMemoryItem] = OrderedDict()
        # Monotonic counter used for auto-generated item IDs.
        self._counter = 0
        self._lock = asyncio.Lock()

    @property
    def size(self) -> int:
        """Get current number of items in memory."""
        return len(self._items)

    @property
    def is_full(self) -> bool:
        """Check if memory is at capacity."""
        return len(self._items) >= self.capacity

    async def push(
        self,
        content: Any,
        item_id: str | None = None,
        priority: float = 0.0,
        metadata: dict[str, Any] | None = None,
    ) -> WorkingMemoryItem:
        """
        Push a new item into working memory.

        If capacity is reached, the least recently used item is evicted.
        Pushing an existing ``item_id`` updates that item in place and
        marks it most-recent.

        Args:
            content: The content to store.
            item_id: Optional custom ID. Auto-generated if not provided.
            priority: Priority score for potential prioritized eviction.
                NOTE(review): on update, a priority of exactly 0.0 is
                ignored, so an item's priority cannot be reset to 0 via
                push — confirm intent.
            metadata: Optional metadata dictionary (merged on update).

        Returns:
            The created or updated working memory item.
        """
        async with self._lock:
            # Generate ID if not provided
            if item_id is None:
                self._counter += 1
                item_id = f"wm_{self._counter}"

            now = datetime.utcnow()

            # Check if item already exists (update it)
            if item_id in self._items:
                item = self._items[item_id]
                item.content = content
                item.last_accessed = now
                item.access_count += 1
                if metadata:
                    item.metadata.update(metadata)
                if priority != 0.0:
                    item.priority = priority
                # Move to end (most recent)
                self._items.move_to_end(item_id)
                return item

            # Evict LRU item if at capacity
            if len(self._items) >= self.capacity:
                self._evict_lru()

            # Create new item
            item = WorkingMemoryItem(
                id=item_id,
                content=content,
                priority=priority,
                created_at=now,
                last_accessed=now,
                metadata=metadata or {},
            )
            self._items[item_id] = item
            return item

    def _evict_lru(self) -> WorkingMemoryItem | None:
        """
        Evict the least recently used item.

        Caller must hold ``self._lock``.

        Returns:
            The evicted item, or None if memory was empty.
        """
        if not self._items:
            return None
        # Pop first item (least recently used)
        _, item = self._items.popitem(last=False)
        return item

    async def pop(self) -> WorkingMemoryItem | None:
        """
        Remove and return the most recently used item.

        Returns:
            The most recent item, or None if memory is empty.
        """
        async with self._lock:
            if not self._items:
                return None
            _, item = self._items.popitem(last=True)
            return item

    async def pop_by_id(self, item_id: str) -> WorkingMemoryItem | None:
        """
        Remove and return an item by its ID.

        Args:
            item_id: The ID of the item to remove.

        Returns:
            The removed item, or None if not found.
        """
        async with self._lock:
            return self._items.pop(item_id, None)

    async def peek(self) -> WorkingMemoryItem | None:
        """
        Return the most recently used item without removing it.

        Refreshes the item's access metadata.

        Returns:
            The most recent item, or None if memory is empty.
        """
        async with self._lock:
            if not self._items:
                return None
            # Get last item
            item_id = next(reversed(self._items))
            item = self._items[item_id]
            item.last_accessed = datetime.utcnow()
            item.access_count += 1
            return item

    async def peek_by_id(self, item_id: str) -> WorkingMemoryItem | None:
        """
        Return an item by ID without removing it.

        Marks the item most-recently-used.

        Args:
            item_id: The ID of the item to peek.

        Returns:
            The item, or None if not found.
        """
        async with self._lock:
            item = self._items.get(item_id)
            if item:
                item.last_accessed = datetime.utcnow()
                item.access_count += 1
                # Move to end (mark as recently accessed)
                self._items.move_to_end(item_id)
            return item

    async def get_all(self) -> list[WorkingMemoryItem]:
        """
        Get all items in memory, ordered by recency.

        Returns:
            List of items from least to most recent.
        """
        async with self._lock:
            return list(self._items.values())

    async def get_recent(self, n: int = 5) -> list[WorkingMemoryItem]:
        """
        Get the N most recently accessed items.

        Args:
            n: Number of items to return. Non-positive values return an
                empty list — the previous ``items[-n:]`` slice returned
                the ENTIRE buffer when ``n == 0``.

        Returns:
            List of at most ``n`` items, least to most recent.
        """
        async with self._lock:
            if n <= 0:
                return []
            items = list(self._items.values())
            return items[-n:]

    async def clear(self) -> int:
        """
        Clear all items from working memory.

        Also resets the auto-ID counter.

        Returns:
            Number of items that were cleared.
        """
        async with self._lock:
            count = len(self._items)
            self._items.clear()
            self._counter = 0
            return count

    async def search(self, predicate: Any) -> list[WorkingMemoryItem]:
        """
        Search items using a predicate function.

        Args:
            predicate: Callable that takes a WorkingMemoryItem and returns bool.

        Returns:
            List of matching items.
        """
        async with self._lock:
            return [item for item in self._items.values() if predicate(item)]

    async def update_priority(self, item_id: str, priority: float) -> bool:
        """
        Update the priority of an item.

        Args:
            item_id: ID of the item to update.
            priority: New priority value.

        Returns:
            True if item was found and updated, False otherwise.
        """
        async with self._lock:
            if item_id in self._items:
                self._items[item_id].priority = priority
                return True
            return False

    async def get_stats(self) -> dict[str, Any]:
        """
        Get statistics about working memory.

        Returns:
            Dictionary with memory statistics.
        """
        async with self._lock:
            return {
                "size": len(self._items),
                "capacity": self.capacity,
                "is_full": len(self._items) >= self.capacity,
                "utilization": len(self._items) / self.capacity if self.capacity > 0 else 0,
                "item_ids": list(self._items.keys()),
            }