Deploying Retrieval Agents in Production

Durable execution, checkpointing, task queues, async workers, streaming, and cost monitoring

Open In Colab

📖 Read the full article


Table of Contents

  1. Setup
  2. Checkpoint Pattern (Durable Execution)
  3. Event Sourcing for Agent Runs
  4. Async Worker Pool
  5. SSE Streaming for Agent Steps
  6. Cost Monitoring & Observability
!pip install -q langchain-openai
import os
# os.environ["OPENAI_API_KEY"] = "your-key"

2. Checkpoint Pattern (Durable Execution)

Save agent state after each step so runs can be resumed after crashes.

import json
import time
from dataclasses import dataclass, field, asdict
from typing import Optional
from pathlib import Path


@dataclass
class AgentCheckpoint:
    """Snapshot of an agent run, serializable for crash recovery.

    Mutable fields use ``default_factory`` so every checkpoint owns its
    own containers rather than sharing one module-level instance.
    """
    run_id: str
    step: int
    # One of: "running", "completed", "failed", "paused".
    status: str
    messages: list = field(default_factory=list)
    tool_results: dict = field(default_factory=dict)
    metadata: dict = field(default_factory=dict)
    timestamp: str = ""

    def __post_init__(self):
        # Stamp creation time only when the caller did not supply one
        # (a checkpoint rehydrated from JSON keeps its stored timestamp).
        if self.timestamp:
            return
        from datetime import datetime
        self.timestamp = datetime.now().isoformat()


class CheckpointStore:
    """Simple file-based checkpoint store (use Redis/DB in production).

    Each run is persisted as one JSON file named ``<run_id>.json``
    under ``base_dir``.
    """

    def __init__(self, base_dir: str = "/tmp/agent_checkpoints"):
        """Create the store, making ``base_dir`` if it does not exist."""
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(parents=True, exist_ok=True)

    def _path(self, run_id: str) -> Path:
        # Single place that maps a run id to its checkpoint file.
        return self.base_dir / f"{run_id}.json"

    def save(self, checkpoint: AgentCheckpoint):
        """Persist ``checkpoint`` as pretty-printed JSON, overwriting any prior state."""
        path = self._path(checkpoint.run_id)
        # Explicit UTF-8 avoids platform-dependent default encodings.
        path.write_text(json.dumps(asdict(checkpoint), indent=2), encoding="utf-8")
        print(f"  💾 Saved checkpoint: step={checkpoint.step}, status={checkpoint.status}")

    def load(self, run_id: str) -> Optional[AgentCheckpoint]:
        """Return the stored checkpoint for ``run_id``, or None if none exists."""
        path = self._path(run_id)
        if not path.exists():
            return None
        data = json.loads(path.read_text(encoding="utf-8"))
        return AgentCheckpoint(**data)

    def delete(self, run_id: str):
        """Remove the checkpoint file for ``run_id`` if present."""
        # missing_ok makes deletion idempotent and drops the
        # exists()-then-unlink race of the check-before-delete pattern.
        self._path(run_id).unlink(missing_ok=True)


def run_durable_agent(run_id: str, question: str, store: CheckpointStore) -> str:
    """Agent with checkpoint-based crash recovery.

    Resumes an in-flight run when a "running" checkpoint exists for
    ``run_id``; otherwise starts a fresh run from step 0. Progress is
    persisted after every step so a crash loses at most one step.
    """
    checkpoint = store.load(run_id)
    resumable = checkpoint is not None and checkpoint.status == "running"

    if resumable:
        start_step = checkpoint.step
        messages = checkpoint.messages
        print(f"♻️ Resuming from step {start_step}")
    else:
        messages = [{"role": "user", "content": question}]
        start_step = 0
        checkpoint = AgentCheckpoint(run_id=run_id, step=0, status="running", messages=messages)

    steps = ["plan", "search", "analyze", "synthesize", "respond"]
    total = len(steps)

    for idx, step_name in enumerate(steps[start_step:], start=start_step):
        print(f"\n📍 Step {idx+1}/{total}: {step_name}")

        # Simulate step execution
        result = f"Result of {step_name} for: {question[:30]}..."
        messages.append({"role": "assistant", "content": f"[{step_name}] {result}"})

        # Persist progress after each step for crash recovery.
        checkpoint.step = idx + 1
        checkpoint.messages = messages
        checkpoint.tool_results[step_name] = result
        store.save(checkpoint)

        time.sleep(0.1)  # Simulate work

    checkpoint.status = "completed"
    store.save(checkpoint)
    return messages[-1]["content"]


# Demo: run the durable agent once, then prove the checkpoint reloads.
checkpoint_store = CheckpointStore()
answer = run_durable_agent("run-001", "Best practices for production RAG?", checkpoint_store)
print(f"\n✅ Final: {answer}")

# Show we can load the completed checkpoint
final_ckpt = checkpoint_store.load("run-001")
print(f"\n📂 Loaded checkpoint: step={final_ckpt.step}, status={final_ckpt.status}")

3. Event Sourcing for Agent Runs

Log every action as an immutable event for full audit trail and replay.

from dataclasses import dataclass
from datetime import datetime
from typing import Any


@dataclass
class AgentEvent:
    """One immutable record in an agent's execution history."""
    run_id: str
    # Expected values: "step_start", "tool_call", "tool_result",
    # "llm_call", "error", "complete".
    event_type: str
    data: dict
    timestamp: str = ""

    def __post_init__(self):
        # Default the timestamp to "now" when the caller left it blank.
        self.timestamp = self.timestamp or datetime.now().isoformat()


class EventStore:
    """In-memory append-only event log with per-run filtering and replay."""

    def __init__(self):
        # Events are only ever appended, never mutated or removed.
        self.events: list[AgentEvent] = []

    def append(self, event: AgentEvent):
        """Record one event at the end of the log."""
        self.events.append(event)

    def get_run(self, run_id: str) -> list[AgentEvent]:
        """Return all events belonging to ``run_id``, in arrival order."""
        return [evt for evt in self.events if evt.run_id == run_id]

    def replay(self, run_id: str):
        """Print full event timeline for a run."""
        run_events = self.get_run(run_id)
        print(f"\n📜 Event replay for {run_id} ({len(run_events)} events):")
        for evt in run_events:
            print(f"  [{evt.timestamp[-12:]}] {evt.event_type:15s} | {str(evt.data)[:60]}")


# Demo: record a synthetic agent run as events, then replay the timeline.
event_store = EventStore()
run_id = "run-evt-001"

# Simulate an agent run with events
for evt_type, payload in [
    ("step_start", {"step": "plan", "input": "RAG best practices"}),
    ("llm_call", {"model": "gpt-4o-mini", "tokens": 150}),
    ("tool_call", {"tool": "search", "query": "RAG production"}),
    ("tool_result", {"tool": "search", "docs": 5, "latency_ms": 230}),
    ("llm_call", {"model": "gpt-4o-mini", "tokens": 450}),
    ("complete", {"total_steps": 3, "total_tokens": 600}),
]:
    event_store.append(AgentEvent(run_id, evt_type, payload))

event_store.replay(run_id)

4. Async Worker Pool

Process multiple agent runs concurrently with controlled parallelism.

import asyncio
from dataclasses import dataclass


@dataclass
class AgentTask:
    """A single unit of work submitted to the worker pool."""
    task_id: str
    question: str
    priority: int = 0  # relative priority hint; defaults to 0


class AgentWorkerPool:
    """Async worker pool that runs agent tasks with bounded concurrency.

    A semaphore caps how many tasks execute simultaneously. Results are
    accumulated in ``self.results`` keyed by task id — note the dict is
    shared across ``run_batch`` calls on the same pool instance.
    """

    def __init__(self, max_workers: int = 3):
        """max_workers: maximum number of tasks allowed to run concurrently."""
        self.max_workers = max_workers
        self.semaphore = asyncio.Semaphore(max_workers)
        self.results: dict[str, str] = {}

    async def process_task(self, task: AgentTask) -> str:
        """Run one task under the concurrency limit and return its answer."""
        async with self.semaphore:
            print(f"  🔧 Worker started: {task.task_id}")
            # Simulate agent work; longer questions take proportionally longer.
            await asyncio.sleep(0.5 + len(task.question) * 0.01)
            result = f"Answer for '{task.question[:30]}...'"
            self.results[task.task_id] = result
            print(f"  ✅ Worker done: {task.task_id}")
            return result

    async def run_batch(self, tasks: list[AgentTask]) -> dict[str, str]:
        """Process ``tasks`` concurrently and return the pool's results dict."""
        print(f"🚀 Processing {len(tasks)} tasks with {self.max_workers} workers")
        # asyncio.get_event_loop() is deprecated inside a coroutine;
        # get_running_loop() is the supported way to reach the active loop.
        loop = asyncio.get_running_loop()
        start = loop.time()

        await asyncio.gather(*[self.process_task(t) for t in tasks])

        elapsed = loop.time() - start
        print(f"\n⏱️ All {len(tasks)} tasks completed in {elapsed:.1f}s")
        return self.results


# Demo
pool = AgentWorkerPool(max_workers=3)

tasks = [
    AgentTask("t1", "What is RAG?"),
    AgentTask("t2", "Explain vector databases for production use"),
    AgentTask("t3", "How to evaluate retrieval agents?"),
    AgentTask("t4", "Best chunking strategies?"),
    AgentTask("t5", "Compare Pinecone vs Weaviate"),
]

results = await pool.run_batch(tasks)
for tid, result in results.items():
    print(f"  {tid}: {result}")

5. SSE Streaming for Agent Steps

Stream agent progress to clients in real-time (Server-Sent Events pattern).

import json
import asyncio
from typing import AsyncIterator


async def stream_agent_steps(question: str) -> AsyncIterator[str]:
    """Yield SSE-formatted progress events for a simulated agent run.

    Each event is emitted in Server-Sent Events wire format
    ("data: <json>\\n\\n"), followed by a final "[DONE]" sentinel.
    """
    script = [
        {"type": "thinking", "content": "Analyzing the question..."},
        {"type": "tool_call", "content": "Searching knowledge base...", "tool": "vector_search"},
        {"type": "tool_result", "content": "Found 5 relevant documents"},
        {"type": "thinking", "content": "Synthesizing findings..."},
        {"type": "answer", "content": f"Based on my research, here's the answer to: {question}"},
    ]

    for payload in script:
        await asyncio.sleep(0.3)  # Simulate processing
        yield f"data: {json.dumps(payload)}\n\n"

    # Sentinel so clients know the stream has finished.
    yield "data: [DONE]\n\n"


# Demo: consume the stream
print("📡 Streaming agent steps:\n")
async for event in stream_agent_steps("How to deploy RAG in production?"):
    if event.strip() == "data: [DONE]":
        print("\n✅ Stream complete")
    else:
        data = json.loads(event.replace("data: ", "").strip())
        icon = {"thinking": "💭", "tool_call": "🔧", "tool_result": "📦", "answer": "💬"}.get(data["type"], "•")
        print(f"{icon} [{data['type']}] {data['content']}")

6. Cost Monitoring & Observability

Track tokens, latency, and cost per run for operational visibility.

from dataclasses import dataclass, field
from datetime import datetime
import time


@dataclass
class RunMetrics:
    """Metrics for a single agent run: timing, token usage, call counts."""
    run_id: str
    start_time: float = 0.0
    end_time: float = 0.0
    input_tokens: int = 0
    output_tokens: int = 0
    llm_calls: int = 0
    tool_calls: int = 0
    errors: int = 0

    # Pricing in USD per 1K tokens. Unannotated class attributes, so the
    # dataclass machinery does not treat them as instance fields.
    INPUT_PRICE = 0.00015
    OUTPUT_PRICE = 0.0006

    @property
    def duration_s(self) -> float:
        """Wall-clock duration; 0 while the run has not recorded an end time."""
        if not self.end_time:
            return 0
        return self.end_time - self.start_time

    @property
    def total_tokens(self) -> int:
        """Input plus output tokens."""
        return self.input_tokens + self.output_tokens

    @property
    def cost_usd(self) -> float:
        """Estimated spend from the per-1K-token prices."""
        input_cost = self.input_tokens / 1000 * self.INPUT_PRICE
        output_cost = self.output_tokens / 1000 * self.OUTPUT_PRICE
        return input_cost + output_cost

    def summary(self) -> str:
        """One-line human-readable report for this run."""
        return (f"Run {self.run_id}: {self.duration_s:.1f}s | "
                f"{self.total_tokens:,} tokens | ${self.cost_usd:.4f} | "
                f"{self.llm_calls} LLM + {self.tool_calls} tool calls | "
                f"{self.errors} errors")


class MetricsDashboard:
    """Aggregate cost/latency/error metrics across agent runs."""

    def __init__(self):
        self.runs: list[RunMetrics] = []

    def add(self, run: RunMetrics):
        """Register one finished run."""
        self.runs.append(run)

    def summary(self) -> str:
        """Return a fleet-level rollup: cost, tokens, avg latency, error rate."""
        if not self.runs:
            return "No runs recorded"

        n = len(self.runs)
        total_cost = sum(run.cost_usd for run in self.runs)
        total_tokens = sum(run.total_tokens for run in self.runs)
        avg_duration = sum(run.duration_s for run in self.runs) / n
        # A run counts as failed if it recorded at least one error.
        failed = sum(1 for run in self.runs if run.errors > 0)
        error_rate = failed / n

        return (f"\n📊 Dashboard ({n} runs):\n"
                f"  Total cost: ${total_cost:.4f}\n"
                f"  Total tokens: {total_tokens:,}\n"
                f"  Avg duration: {avg_duration:.1f}s\n"
                f"  Error rate: {error_rate:.0%}")


# Demo: synthesize five runs with growing token counts and one failure.
dashboard = MetricsDashboard()

for i in range(5):
    metrics = RunMetrics(
        run_id=f"run-{i+1:03d}",
        start_time=time.time(),
        input_tokens=500 + i * 200,
        output_tokens=200 + i * 100,
        llm_calls=2 + i,
        tool_calls=1 + i,
        errors=1 if i == 3 else 0,
    )
    # Fake an end time so durations vary without actually sleeping.
    metrics.end_time = time.time() + 0.5 + i * 0.3
    dashboard.add(metrics)
    print(metrics.summary())

print(dashboard.summary())