Memory Systems for Long-Running Retrieval Agents

Conversation buffers, working scratchpads, and episodic vector recall

Open In Colab

📖 Read the full article


Table of Contents

  1. Setup
  2. Sliding Window Buffer
  3. Summary Buffer Memory
  4. Token-Aware Trimming
  5. Working Scratchpad (Research State)
  6. Episodic Vector Recall
!pip install -q langchain-openai tiktoken
import os
# os.environ["OPENAI_API_KEY"] = "your-key"

2. Sliding Window Buffer

Keep only the last k conversation turns to bound context length.

from dataclasses import dataclass, field


@dataclass
class SlidingWindowBuffer:
    """Keep the last `k` turns (user + assistant pairs) in memory.

    A turn is counted as two messages, so at most ``2 * k`` messages are
    retained. A `k` of zero or less retains nothing.
    """
    k: int = 5  # Number of turns to retain
    messages: list = field(default_factory=list)  # {"role", "content"} dicts, oldest first

    def add(self, role: str, content: str):
        """Append one message and discard whatever falls outside the window.

        Args:
            role: Speaker label stored under the "role" key.
            content: Message text stored under the "content" key.
        """
        self.messages.append({"role": role, "content": content})
        # Each turn = 2 messages (user + assistant)
        max_msgs = max(self.k, 0) * 2
        # Slice from a positive start index: the negative form
        # `messages[-max_msgs:]` returns the *entire* list when max_msgs is 0
        # (because -0 == 0), which would disable trimming for k=0.
        excess = len(self.messages) - max_msgs
        if excess > 0:
            self.messages = self.messages[excess:]

    def get_messages(self) -> list:
        """Return a shallow copy of the retained messages, oldest first."""
        return list(self.messages)

    def __len__(self):
        """Return the number of messages currently retained."""
        return len(self.messages)


# Demo: feed six full turns through a three-turn window.
buf = SlidingWindowBuffer(k=3)  # Keep last 3 turns
for i in range(6):
    question, answer = f"Question {i}", f"Answer {i}"
    buf.add("user", question)
    buf.add("assistant", answer)

# Only the three most recent turns should be listed.
print(f"Buffer has {len(buf)} messages (last 3 turns):")
for msg in buf.get_messages():
    print(f"  {msg['role']}: {msg['content']}")

3. Summary Buffer Memory

Summarize older messages to compress history while retaining key facts.

from langchain_openai import ChatOpenAI

# Shared chat model; SummaryBuffer._compress calls `llm.invoke` on it to write summaries.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)


@dataclass
class SummaryBuffer:
    """Maintain a running summary + recent messages window.

    Once more than `recent_k` messages are held, the overflow is folded into
    `summary` by the module-level `llm` and removed from `messages`. A
    `recent_k` of zero or less keeps no verbatim messages.
    """
    recent_k: int = 4  # Keep last k messages verbatim
    messages: list = field(default_factory=list)  # {"role", "content"} dicts, oldest first
    summary: str = ""  # Running summary of everything already dropped

    def add(self, role: str, content: str):
        """Append a message, compressing as soon as the window is exceeded.

        Args:
            role: Speaker label stored under the "role" key.
            content: Message text stored under the "content" key.
        """
        self.messages.append({"role": role, "content": content})
        if len(self.messages) > max(self.recent_k, 0):
            self._compress()

    def _compress(self):
        """Summarize older messages and keep only recent ones.

        Does nothing (and makes no LLM call) when every held message already
        fits inside the verbatim window.
        """
        # Split on a positive index. The negative-slice form misbehaves at
        # recent_k == 0: `[:-0]` is empty and `[-0:]` is the whole list, so
        # nothing would ever be summarized or dropped.
        cut = len(self.messages) - max(self.recent_k, 0)
        if cut <= 0:
            return
        old = self.messages[:cut]
        old_text = "\n".join(f"{m['role']}: {m['content']}" for m in old)

        # The previous summary is passed along so earlier facts carry forward.
        response = llm.invoke([{
            "role": "system",
            "content": "Summarize this conversation history in 2-3 sentences. Preserve key facts, decisions, and entities.",
        }, {
            "role": "user",
            "content": f"Previous summary: {self.summary}\n\nNew messages:\n{old_text}",
        }])
        self.summary = response.content
        self.messages = self.messages[cut:]

    def get_prompt_messages(self) -> list:
        """Return the messages to send: a summary system message (if any), then the recent window."""
        result = []
        if self.summary:
            result.append({"role": "system", "content": f"Conversation summary: {self.summary}"})
        result.extend(self.messages)
        return result


# Demo: six messages through a four-message window, so older ones get summarized.
sbuf = SummaryBuffer(recent_k=4)
for role, content in [
    ("user", "I'm researching RAG architectures for production use."),
    ("assistant", "RAG combines retrieval with generation. Key choices: vector DB, chunking strategy, reranker."),
    ("user", "What databases do you recommend?"),
    ("assistant", "Pinecone for managed, Weaviate for self-hosted, pgvector for Postgres integration."),
    ("user", "Let's focus on pgvector."),
    ("assistant", "pgvector is great for teams already on PostgreSQL. Supports IVFFlat and HNSW indexes."),
]:
    sbuf.add(role, content)

print(f"Summary: {sbuf.summary}")
print(f"\nRecent messages ({len(sbuf.messages)}):")
# The listing comes from get_prompt_messages(), so it includes the summary entry when one exists.
for msg in sbuf.get_prompt_messages():
    print(f"  {msg['role']}: {msg['content'][:80]}...")

4. Token-Aware Trimming

Drop the oldest messages one at a time until the conversation fits within the token budget.

import tiktoken


def count_tokens(text: str, model: str = "gpt-4o-mini") -> int:
    """Return the number of tokens `text` encodes to under tiktoken's encoding for `model`."""
    token_ids = tiktoken.encoding_for_model(model).encode(text)
    return len(token_ids)


def trim_messages_by_tokens(messages: list, max_tokens: int = 2000, token_counter=None) -> list:
    """Drop oldest messages until total fits within token budget.

    The newest message is always kept, even if it alone exceeds the budget.
    The input list is not modified. Progress is printed as messages are
    dropped.

    Args:
        messages: Dicts with "role" and "content" keys, oldest first.
        max_tokens: Token budget for the summed message contents.
        token_counter: Optional callable mapping a content string to its
            token count. Defaults to `count_tokens`, which uses that
            function's default model.

    Returns:
        A new list holding the surviving messages in their original order.
    """
    counter = count_tokens if token_counter is None else token_counter
    messages = list(messages)

    # Tokenize each message exactly once; the counts feed the initial total,
    # the running subtraction, and the progress output.
    counts = [counter(m["content"]) for m in messages]
    total = sum(counts)

    # Advance a start index instead of pop(0) so dropping stays linear.
    start = 0
    last = len(messages) - 1
    while total > max_tokens and start < last:
        removed = messages[start]
        dropped_tokens = counts[start]
        total -= dropped_tokens
        print(f"  Trimmed: {removed['role']}: {removed['content'][:50]}... ({dropped_tokens} tokens)")
        start += 1

    trimmed = messages[start:]
    print(f"Final: {len(trimmed)} messages, ~{total} tokens")
    return trimmed


# Demo: two long messages followed by two short ones.
test_msgs = [
    dict(role="user", content="Tell me about attention mechanisms in detail. " * 20),
    dict(role="assistant", content="Attention allows models to focus on relevant parts. " * 20),
    dict(role="user", content="How about multi-head attention?"),
    dict(role="assistant", content="Multi-head attention runs several attention functions in parallel."),
]

# Show each message's token count before anything is dropped.
print("Before trimming:")
for msg in test_msgs:
    n_tokens = count_tokens(msg["content"])
    print(f"  {msg['role']}: {n_tokens} tokens")

print("\nTrimming to 200 tokens:")
trimmed = trim_messages_by_tokens(test_msgs, max_tokens=200)

5. Working Scratchpad (Research State)

An agent-writable scratchpad for ongoing multi-step research tasks.

from dataclasses import dataclass, field
from typing import Optional
from datetime import datetime


@dataclass
class ResearchState:
    """Scratchpad a research agent updates while working through one question."""
    question: str
    hypothesis: str = ""
    evidence: list[str] = field(default_factory=list)
    gaps: list[str] = field(default_factory=list)
    confidence: float = 0.0
    iteration: int = 0

    def add_evidence(self, finding: str, source: str):
        """Record a source-tagged finding and refresh the confidence score."""
        self.evidence.append(f"[{source}] {finding}")
        # Heuristic: 0.2 per collected item, capped at 1.0.
        self.confidence = min(1.0, len(self.evidence) * 0.2)

    def add_gap(self, gap: str):
        """Note an open question the evidence does not yet cover."""
        self.gaps.append(gap)

    def to_prompt(self) -> str:
        """Render the state as a markdown block; only the last five evidence items are listed."""
        recent_evidence = self.evidence[-5:]
        bullet_lines = "\n".join("- " + item for item in recent_evidence)
        hypothesis_text = self.hypothesis or "Not yet formed"
        gaps_text = ", ".join(self.gaps) or "None identified"
        sections = [
            "## Research Scratchpad",
            f"**Question**: {self.question}",
            f"**Hypothesis**: {hypothesis_text}",
            f"**Evidence** ({len(self.evidence)} items):",
            bullet_lines,
            f"**Gaps**: {gaps_text}",
            f"**Confidence**: {self.confidence:.0%}",
            f"**Iteration**: {self.iteration}",
        ]
        return "\n".join(sections)

    def should_continue(self) -> bool:
        """Return True while confidence is below 0.8 and fewer than five iterations have run."""
        if not self.confidence < 0.8:
            return False
        return self.iteration < 5


# Demo: a scratchpad two iterations into a chunking-strategy investigation.
state = ResearchState(
    question="What's the best chunking strategy for legal documents?",
    hypothesis="Semantic chunking outperforms fixed-size for legal text",
)
for finding, source in [
    ("Semantic chunking preserves clause boundaries", "legal-rag-paper"),
    ("Fixed 512-token chunks split mid-sentence 23% of the time", "benchmark-2024"),
    ("Recursive splitter with section headers works well", "langchain-docs"),
]:
    state.add_evidence(finding, source)
state.add_gap("No comparison with document-structure-aware chunking")
state.iteration = 2

print(state.to_prompt())
print(f"\nShould continue? {state.should_continue()}")

6. Episodic Vector Recall

Store completed episodes (past research sessions) in a vector store for long-term recall.

from dataclasses import dataclass
from datetime import datetime
import hashlib
import json


@dataclass
class Episode:
    """Record of one finished research session, ready to be stored and recalled."""
    question: str
    answer: str
    evidence: list[str]
    confidence: float
    timestamp: str = ""

    def __post_init__(self):
        # An empty timestamp means "not supplied": stamp it with the current time.
        self.timestamp = self.timestamp or datetime.now().isoformat()

    def to_text(self) -> str:
        """Flatten question, answer and evidence into one searchable string."""
        joined_evidence = "; ".join(self.evidence)
        return "\n".join((
            f"Q: {self.question}",
            f"A: {self.answer}",
            f"Evidence: {joined_evidence}",
        ))

    def to_dict(self) -> dict:
        """Return the fields as a plain dict, in declaration order."""
        field_names = ("question", "answer", "evidence", "confidence", "timestamp")
        return {name: getattr(self, name) for name in field_names}


class EpisodicMemory:
    """Simple in-memory episodic store (replace with vector DB in production)."""

    # Characters trimmed from both ends of every word before matching, so
    # that "RAG?" in a query matches "RAG" in a stored episode.
    _EDGE_PUNCT = ".,;:!?\"'()[]{}"

    def __init__(self):
        self.episodes: list[Episode] = []

    def store(self, episode: Episode):
        """Append an episode and print a short confirmation."""
        self.episodes.append(episode)
        print(f"📝 Stored episode: {episode.question[:60]}...")

    @classmethod
    def _keywords(cls, text: str) -> set:
        """Return the lowercase words of `text` with edge punctuation removed."""
        stripped = (word.strip(cls._EDGE_PUNCT) for word in text.lower().split())
        return {word for word in stripped if word}

    def recall(self, query: str, top_k: int = 3) -> list[Episode]:
        """Simple keyword matching (replace with embedding similarity).

        Episodes are ranked by how many distinct query words appear in their
        `to_text()` rendering; ties keep insertion order. Episodes with no
        overlap are still returned when fewer than `top_k` better matches
        exist.

        Args:
            query: Free-text query.
            top_k: Maximum number of episodes to return; zero or less
                returns an empty list.

        Returns:
            Up to `top_k` episodes, best match first.
        """
        # A negative top_k would otherwise slice from the end of the ranking.
        if top_k <= 0:
            return []
        query_words = self._keywords(query)
        scored = [
            (len(query_words & self._keywords(ep.to_text())), ep)
            for ep in self.episodes
        ]
        # list.sort is stable, also with reverse=True, so ties stay in insertion order.
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [ep for _, ep in scored[:top_k]]


# Demo: store three finished episodes, then query them.
memory = EpisodicMemory()

for episode_fields in [
    dict(
        question="Best vector DB for production RAG?",
        answer="Pinecone for managed, Weaviate for self-hosted.",
        evidence=["Pinecone 99.9% SLA", "Weaviate supports hybrid search"],
        confidence=0.85,
    ),
    dict(
        question="How to chunk legal documents?",
        answer="Use section-aware chunking with overlap.",
        evidence=["Section headers as split points", "200-token overlap preserves context"],
        confidence=0.75,
    ),
    dict(
        question="RAG vs fine-tuning for domain adaptation?",
        answer="RAG for factual recall, fine-tuning for style/format.",
        evidence=["RAG better for knowledge-intensive tasks", "Fine-tuning for consistent output format"],
        confidence=0.9,
    ),
]:
    # Build each Episode right before storing it, one at a time.
    memory.store(Episode(**episode_fields))

# Recall
results = memory.recall("Which vector database should I use for RAG?")
print(f"\n🔍 Recalled {len(results)} episodes:")
for recalled in results:
    print(f"  Q: {recalled.question}")
    print(f"  A: {recalled.answer}")
    print(f"  Confidence: {recalled.confidence:.0%}\n")