multi-agent-lab / src /core /memory.py
agharsallah
feat: Implement audience-only secret badge for Twenty Sprouts game
f6566bb
Raw
History Blame Contribute Delete
15.3 kB
"""Agent memory β€” episodic recall, salience scoring, and reflection.
Memory architecture (three layers):
1. EpisodicMemory — filtered view over the ledger; an agent sees its own events
   plus the public record everyone witnesses: world beats, verdicts, visitor
   pokes, reflections, and peers' spoken lines (``agent.spoke`` / ``oracle.spoke``).
   Private ``agent.thought`` stays out — minds aren't read by peers. Always-on.
2. SalienceMemory — ranks visible events by a composite score:
       salience(e) = w_rel·relevance(e,query) + w_rec·recency(e,turn) + w_imp·importance(e.kind)
   and returns the top-K rather than the most-recent K. This layer is
   optional (manifest.memory.use_salience=True) and adds ~0 latency.
3. ReflectionTracker — sits alongside either layer and signals when an
   ``agent.reflected`` event is due (every ``threshold`` counted events),
   compacting episodic memories into a high-level belief. Reflection events are
   themselves visible to the agent, so beliefs accumulate over time without
   blowing the window.
None of these layers maintain a separate *source of truth* — they are functions
over the shared append-only ledger. Memory is always consistent with the ledger
because it *is* the ledger.
The one optional accelerator is a semantic relevance index
(:class:`~src.core.memory_index.MemoryIndex`, attached via ``SalienceMemory.index``
and gated by ``MEMORY_INDEX`` β€” see ADR-0018). It is a *derived, rebuildable*
view: populated FROM ledger events, keyed by ``event.id`` (idempotent re-index),
and used only to score the relevance term over events the visibility filter
already admits. Wipe it and it rebuilds from the ledger; with it unattached
(the offline default) relevance is keyword overlap, exactly as below.
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from src import observability as obs
from src.core.events import Event
if TYPE_CHECKING: # pragma: no cover - typing only
from src.core.memory_index import MemoryIndex
logger = logging.getLogger(__name__)
def _displayable(event: Event) -> str:
"""A safe one-line rendering of an event for a prompt.
Prefers ``text``/``summary``; falls back to the shared ``goal`` carried by
``run.started``. Never ``str(payload)`` β€” that dumped whole payload dicts
(e.g. the run seed) into every agent's context, which is both noise and, for a
hidden-word game, a leak vector."""
payload = event.payload
return payload.get("text") or payload.get("summary") or payload.get("goal") or ""
# ── importance weights by event kind ─────────────────────────────────────────
_KIND_IMPORTANCE: dict[str, float] = {
"run.started": 0.3,
"world.observed": 0.7,
"agent.spoke": 0.5,
"oracle.spoke": 0.5, # a custom public-speech kind (oracle-grove); ranks like agent.spoke
"agent.thought": 0.4,
"agent.reflected": 0.85, # reflections are high-value compact memories
"judge.verdict": 0.9,
"user.injected": 0.95, # visitor events are always salient
"hypothesis.proposed": 0.75,
"clue.found": 0.8,
"verdict.final": 1.0,
}
# What an agent can RECALL of others: globally-visible kinds (plus its own events).
# The split is public vs. private SPEECH. A spoken line (``agent.spoke`` / the custom
# ``oracle.spoke``) is table talk β€” every mind hears it, so it must be recallable across
# the whole run, not just this round's blackboard tail. Without this a judge that fires
# late (it has no own events yet) recalls *none* of the discussion it must rule on, and a
# worker forgets every peer line older than the 6-line blackboard window. A private
# ``agent.thought`` is deliberately NOT here: it rides only its own event payload (the
# mind-reader UI), so peers never read another mind's thinking. Secrets ride non-``text``
# payload keys and ``_displayable`` shows ``text`` only, so sharing speech leaks nothing.
_GLOBALLY_VISIBLE: frozenset[str] = frozenset(
{
"world.observed",
"judge.verdict",
"user.injected",
"run.started",
"agent.reflected",
"agent.spoke",
"oracle.spoke",
}
)
# What COUNTS toward an agent's reflection cadence (ReflectionTracker). Deliberately the
# narrower set *without* peer speech: reflection compacts "what I have been through" β€”
# my own arc plus the world beats β€” so its rhythm shouldn't lurch just because the table
# got chatty this round. Keeping it separate from _GLOBALLY_VISIBLE leaves reflection
# timing exactly as tuned while recall gains the shared discussion.
_REFLECTION_VISIBLE: frozenset[str] = frozenset(
{"world.observed", "judge.verdict", "user.injected", "run.started", "agent.reflected"}
)
# ── layer 1: episodic memory ─────────────────────────────────────────────────
@dataclass
class EpisodicMemory:
    """Per-agent filtered view over the ledger — the always-on memory layer.

    An agent sees its own events plus globally-visible kinds (the public record —
    world beats, verdicts, visitor pokes, reflections, and peers' *spoken* lines;
    never peers' private thoughts). The window is capped at ``max_recent`` to stay
    within small-model context budgets.
    """

    agent_name: str
    max_recent: int = 8

    def visible(self, events: tuple[Event, ...], run_id: str | None = None) -> list[Event]:
        """Return the most recent ``max_recent`` events this agent may recall.

        Args:
            events: Ledger events, in ledger order.
            run_id: When given, only events whose ``run_id`` matches are considered.

        Returns:
            The tail of the admitted events, oldest first. Empty when
            ``max_recent`` is zero or negative.
        """
        # ``result[-0:]`` is the WHOLE list and a negative cap would drop the head
        # instead of the tail, so a non-positive window must short-circuit.
        if self.max_recent <= 0:
            return []
        if run_id is not None:
            events = tuple(e for e in events if e.run_id == run_id)
        result = [e for e in events if e.actor == self.agent_name or e.kind in _GLOBALLY_VISIBLE]
        return result[-self.max_recent :]

    def format_for_prompt(self, events: tuple[Event, ...], run_id: str | None = None) -> str:
        """Render the recalled events as the memory string placed in a prompt.

        Args:
            events: Ledger events, in ledger order.
            run_id: Optional run scope, forwarded to :meth:`visible`.

        Returns:
            One ``[turn NNN][kind] text`` line per recalled event that has
            displayable text, or ``"(no prior memory)"`` when there is none.
        """
        with obs.span("memory.recall", **{"mal.agent": self.agent_name, "memory.mode": "episodic"}):
            recalled = self.visible(events, run_id)
            # Events with nothing displayable are counted as recalled but not rendered.
            lines = [f"[turn {e.turn:03d}][{e.kind}] {text}" for e in recalled if (text := _displayable(e))]
            memory = "\n".join(lines) if lines else "(no prior memory)"
            obs.add_span_attrs(**{"memory.visible_count": len(recalled)})
            obs.observe("memory.visible_count", len(recalled), agent=self.agent_name)
            # DEBUG: the EXACT memory string this agent will receive (what it "sees").
            obs.log(
                "memory.recall",
                level="debug",
                agent=self.agent_name,
                mode="episodic",
                visible_count=len(recalled),
                memory=memory,
            )
            return memory
# ── layer 2: salience-scored memory ──────────────────────────────────────────
@dataclass
class SalienceMemory:
    """Ranks visible events by salience instead of pure recency.

        salience(e) = w_rel·relevance + w_rec·recency + w_imp·importance

    relevance:  semantic similarity between the event and the current scene when
                a :class:`~src.core.memory_index.MemoryIndex` is attached
                (``index`` set), else keyword (Jaccard) overlap between the event
                text and the scene. The index is a *derived* lens over the same
                ledger events — it changes only how the relevance term is scored,
                never which events are eligible (see ``visible``) nor the recency
                or importance terms.
    recency:    exponential decay — exp(−λ·Δturn). λ=0.1 gives half-life ≈7 turns.
    importance: event-kind weight from the _KIND_IMPORTANCE table.

    Attach an index via ``index=...`` to use semantic relevance; with ``index``
    left ``None`` (the default) the scoring is exactly the offline keyword path.
    """

    agent_name: str
    top_k: int = 8
    w_relevance: float = 0.3
    w_recency: float = 0.4
    w_importance: float = 0.3
    decay_lambda: float = 0.1
    index: "MemoryIndex | None" = None

    def _keyword_relevance(self, event: Event, query: str) -> float:
        """Jaccard overlap of lower-cased, whitespace-split words in [0, 1].

        Returns 0.0 when either the event text or the query has no words.
        """
        # NOTE(review): only payload["text"] is scored here, while _displayable
        # also falls back to "summary"/"goal" — confirm that is intended.
        event_words = set(str(event.payload.get("text", "")).lower().split())
        query_words = set(query.lower().split())
        if not query_words or not event_words:
            return 0.0
        return len(query_words & event_words) / len(query_words | event_words)

    def score(
        self,
        event: Event,
        current_turn: int,
        query: str,
        relevance: float | None = None,
    ) -> float:
        """Composite salience. *relevance* may be supplied (e.g. a semantic rank);
        when ``None`` it is computed from keyword overlap."""
        # Events "from the future" (turn > current_turn) are clamped to zero age.
        recency = math.exp(-self.decay_lambda * max(0, current_turn - event.turn))
        # Kinds missing from the table get a neutral 0.5.
        importance = _KIND_IMPORTANCE.get(event.kind, 0.5)
        if relevance is None:
            relevance = self._keyword_relevance(event, query)
        return self.w_relevance * relevance + self.w_recency * recency + self.w_importance * importance

    def _candidates(self, events: tuple[Event, ...]) -> list[Event]:
        """Ledger-derived visibility filter — unchanged whether or not an index
        is attached: an agent only ever recalls its own events plus globally
        visible kinds."""
        return [e for e in events if e.actor == self.agent_name or e.kind in _GLOBALLY_VISIBLE]

    def _relevance_map(self, candidates: list[Event], query: str) -> dict[str, float] | None:
        """When an index is attached, derive a semantic relevance score per
        candidate event (id → score in (0, 1] by descending rank); else ``None``
        so :meth:`score` uses keyword overlap.

        The index is populated from the candidate events first, then queried —
        derive, then read — so it never reports events the ledger has not
        produced, and re-indexing is idempotent (keyed by ``event.id``).

        Returns ``{}`` when the index answers but none of its hits is an
        eligible candidate; callers then treat every relevance as 0.0.
        """
        if self.index is None or not query or not candidates:
            return None
        # Scope the semantic search to the candidates' run: the index spans every
        # run in the shared store, and unscoped hits from other runs (other users'
        # shows) would crowd the recall budget out of this run's events. Derived
        # from the candidates so callers that already pass a single-run slice (the
        # conductor does) get scoping for free.
        run_ids = {e.run_id for e in candidates}
        run_id = next(iter(run_ids)) if len(run_ids) == 1 else None
        # The index is a derived, rebuildable lens (ADR-0018) — never load-bearing.
        # If it hiccups (a flaky hosted backend, a transient mem0 error), degrade to
        # keyword relevance rather than let one agent's recall crash its whole turn.
        try:
            self.index.index(tuple(candidates))
            hits = self.index.search(query, k=len(candidates), run_id=run_id)
        except Exception as exc:  # noqa: BLE001 — relevance is best-effort, never fatal
            logger.warning("memory index unavailable, using keyword relevance: %s", exc)
            obs.log("memory.index.fallback", level="warning", agent=self.agent_name, error=str(exc))
            return None
        eligible = {e.id for e in candidates}
        # Keep each id at its FIRST (best) rank. A backend that returns the same
        # id twice would otherwise inflate ``n`` and let the later, lower rank
        # overwrite the earlier one in the mapping below.
        ranked = list(dict.fromkeys(h.id for h in hits if h.id in eligible))
        if not ranked:
            return {}
        n = len(ranked)
        return {eid: (n - i) / n for i, eid in enumerate(ranked)}

    def _rank(
        self,
        events: tuple[Event, ...],
        current_turn: int,
        query: str,
        run_id: str | None = None,
    ) -> tuple[list[tuple[Event, float]], bool]:
        """Shared recall pipeline: filter, score once, keep top-K, order by turn.

        Returns ``(pairs, semantic)`` where *pairs* are ``(event, salience)``
        tuples in chronological order and *semantic* tells whether the relevance
        term came from the index rather than keyword overlap.
        """
        if run_id is not None:
            events = tuple(e for e in events if e.run_id == run_id)
        candidates = self._candidates(events)
        relevance = self._relevance_map(candidates, query)
        scored = [
            (
                e,
                self.score(
                    e,
                    current_turn,
                    query,
                    relevance=None if relevance is None else relevance.get(e.id, 0.0),
                ),
            )
            for e in candidates
        ]
        # Stable sort: equally-salient events keep their ledger order.
        scored.sort(key=lambda pair: pair[1], reverse=True)
        # A negative top_k would slice off the tail instead of returning nothing.
        top = scored[: max(self.top_k, 0)]
        # Return in chronological order so prompts read naturally.
        top.sort(key=lambda pair: pair[0].turn)
        return top, relevance is not None

    def visible(
        self, events: tuple[Event, ...], current_turn: int, query: str, run_id: str | None = None
    ) -> list[Event]:
        """Return the ``top_k`` most salient recallable events, oldest first.

        Args:
            events: Ledger events, in ledger order.
            current_turn: Turn used as "now" for the recency term.
            query: The current scene text used for the relevance term.
            run_id: When given, only events whose ``run_id`` matches are considered.
        """
        ranked, _ = self._rank(events, current_turn, query, run_id)
        return [event for event, _ in ranked]

    def format_for_prompt(
        self,
        events: tuple[Event, ...],
        current_turn: int,
        query: str,
        run_id: str | None = None,
    ) -> str:
        """Render the salience-ranked memory string placed in a prompt.

        Args:
            events: Ledger events, in ledger order.
            current_turn: Turn used as "now" for the recency term.
            query: The current scene text used for the relevance term.
            run_id: Optional run scope, applied exactly as in :meth:`visible`.

        Returns:
            One ``[turn NNN][kind][sal=S.SS] text`` line per recalled event that
            has displayable text, or ``"(no salient memories)"`` when there is none.
        """
        with obs.span(
            "memory.recall",
            **{"mal.agent": self.agent_name, "memory.mode": "salience", "memory.top_k": self.top_k},
        ):
            ranked, semantic = self._rank(events, current_turn, query, run_id)
            lines = [
                f"[turn {e.turn:03d}][{e.kind}][sal={sal:.2f}] {text}"
                for e, sal in ranked
                if (text := _displayable(e))
            ]
            memory = "\n".join(lines) if lines else "(no salient memories)"
            scores = {e.id: round(sal, 3) for e, sal in ranked}
            obs.add_span_attrs(
                **{
                    "memory.visible_count": len(ranked),
                    "memory.query": query,
                    "memory.semantic": semantic,
                }
            )
            obs.observe("memory.visible_count", len(ranked), agent=self.agent_name)
            # DEBUG: the EXACT salience-ranked memory this agent will receive, with scores.
            obs.log(
                "memory.recall",
                level="debug",
                agent=self.agent_name,
                mode="salience",
                query=query,
                visible_count=len(ranked),
                semantic=semantic,
                scores=scores,
                memory=memory,
            )
            return memory
# ── layer 3: reflection trigger ───────────────────────────────────────────────
@dataclass
class ReflectionTracker:
    """Tracks whether this agent is due to emit a reflection event.

    Reflection events compact recent episodic memories into a high-level
    belief ("the baker resents me") that is cheaper to carry than raw history
    and richer. The belief becomes an agent.reflected event in the ledger —
    which EpisodicMemory picks up in future turns because it is globally visible.
    """

    agent_name: str
    threshold: int  # emit reflection every N visible events; <= 0 disables reflection
    _seen_count: int = field(default=0, init=False, repr=False)

    def observe(self, events: tuple[Event, ...]) -> bool:
        """Return True when a reflection should be emitted this turn.

        Counts the agent's own events plus the ``_REFLECTION_VISIBLE`` kinds and
        reports "due" when that count is a positive multiple of ``threshold``
        and differs from the count seen on the previous call (so repeated calls
        over an unchanged ledger fire at most once).

        A zero or negative ``threshold`` never fires; previously zero raised
        ``ZeroDivisionError`` from the modulo below.
        """
        visible_count = sum(1 for e in events if e.actor == self.agent_name or e.kind in _REFLECTION_VISIBLE)
        # NOTE(review): if the count advances by more than one between calls it can
        # step over a multiple of ``threshold`` and skip that reflection — confirm
        # callers observe after every counted event, or that this is acceptable.
        due = (
            self.threshold > 0
            and visible_count > 0
            and visible_count != self._seen_count
            and visible_count % self.threshold == 0
        )
        self._seen_count = visible_count
        return due