Agentic RAG: When Retrieval Needs Reasoning

Building RAG agents that plan queries, route to tools, self-reflect, and iteratively refine answers with LangGraph and LlamaIndex Workflows

Open In Colab

📖 Read the full article


Table of Contents

  1. Setup & Installation
  2. Why Standard RAG Breaks Down
  3. Define the Agent State
  4. Build the Retriever
  5. Define Graph Nodes
  6. Define Conditional Edges
  7. Build and Run the Agentic RAG Graph
  8. Adding Tool Routing

1. Setup & Installation

!pip install -q langchain langchain-openai langchain-community langgraph langchain-text-splitters faiss-cpu tavily-python pydantic
import os
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"  # Uncomment and set
# os.environ["TAVILY_API_KEY"] = "your-tavily-key-here"  # Uncomment for web search

2. Why Standard RAG Breaks Down

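Standard RAG follows a fixed linear pipeline: embed → retrieve → generate. There is no feedback loop and no self-correction, so a bad retrieval flows straight into the final answer. The examples below show typical ways this goes wrong.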

# Typical ways a fixed embed → retrieve → generate pipeline fails, each
# mapped to a short, concrete illustration.
failure_modes = {
    # Semantic near-miss: a shared keyword pulls in the wrong topic.
    "Irrelevant retrieval": "Query about 'Python decorators' retrieves docs about 'Python snakes'",
    # A single retrieval pass cannot cover every sub-question.
    "Incomplete retrieval": "Multi-part question, only first part answered",
    # Only one backend is ever consulted.
    "Wrong data source": "Question needs SQL data but system only searches vector store",
    # The index may lag behind reality.
    "Stale data": "'What's the latest version?' retrieves outdated chunk",
    # Nothing checks whether the answer is grounded.
    "Hallucination from noise": "LLM generates confidently from marginally relevant chunks",
}

print("Standard RAG Failure Modes:")
print("=" * 60)
for failure_name, illustration in failure_modes.items():
    # One print with an embedded newline emits the same two lines as before.
    print(f"\n{failure_name}:\n  Example: {illustration}")

3. Define the Agent State

The state flows through the agentic RAG graph, carrying the question, documents, generation, and control flags.

from typing import Literal
from typing_extensions import TypedDict


class AgentState(TypedDict):
    """Shared state handed from node to node in the agentic RAG graph.

    Every node receives the current state and returns an updated dict.
    """

    # The user's question; rewrite_query may replace it with a rephrased one.
    question: str
    # Plain-text contents of the retrieved (and, after grading, relevant) documents.
    documents: list[str]
    # The most recent answer produced by the generate node.
    generation: str
    # Set by grade_documents when no relevant document survives grading.
    web_search_needed: bool
    # Retry counter; check_generation ends the run once it reaches 3.
    retry_count: int


print("AgentState fields:", [*AgentState.__annotations__])

4. Build the Retriever

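Create a FAISS vector store from a handful of sample documents and wrap it as a retriever. This cell calls the OpenAI embeddings API, so `OPENAI_API_KEY` must be set.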

from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# A tiny in-memory corpus: alignment (RLHF, DPO), inference serving
# (vLLM, Ollama) and transformer basics (attention).
_corpus_texts = (
    "RLHF (Reinforcement Learning from Human Feedback) is a technique for aligning LLMs with human preferences. It involves training a reward model on human comparisons, then using PPO to optimize the language model against that reward.",
    "DPO (Direct Preference Optimization) simplifies RLHF by directly optimizing the language model on preference data without needing a separate reward model. DPO reformulates the RLHF objective as a classification loss.",
    "vLLM is a high-throughput inference engine for LLMs that uses PagedAttention to efficiently manage KV cache memory. It achieves 2-4x higher throughput compared to HuggingFace Transformers.",
    "Ollama is a tool for running LLMs locally. It supports models like Llama, Mistral, and Gemma. Ollama wraps llama.cpp and provides a simple API for local inference.",
    "The attention mechanism allows transformers to weigh the importance of different tokens. Self-attention computes queries, keys, and values for each position in the sequence.",
)
sample_docs = [Document(page_content=text) for text in _corpus_texts]

# Chunk, embed and index the corpus; the retriever returns the top 3 hits.
splitter = RecursiveCharacterTextSplitter(chunk_overlap=50, chunk_size=512)
chunks = splitter.split_documents(sample_docs)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = FAISS.from_documents(chunks, embeddings)
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

print(f"Indexed {len(chunks)} chunks")

# Smoke-test the retriever with a question the corpus can answer.
test_results = retriever.invoke("What are the differences between RLHF and DPO?")
print(f"Retrieved {len(test_results)} documents for test query")

5. Define Graph Nodes

Each node is a function that takes the agent state and returns an updated copy of it. The nodes are: retrieve, grade documents, rewrite query, web search (a placeholder in this notebook), and generate.

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field

# Shared chat model used below by the graders, the query rewriter, the
# answer generator and the router.
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")


# --- Node: Retrieve ---
def retrieve(state: AgentState) -> AgentState:
    """Fetch candidate documents for the current question.

    Queries the module-level ``retriever`` with ``state["question"]`` and
    stores the plain text of each hit under ``documents``. All other state
    keys are carried over unchanged.
    """
    hits = retriever.invoke(state["question"])
    updated = dict(state)
    updated["documents"] = [hit.page_content for hit in hits]
    return updated


# --- Node: Grade Documents ---
# Structured-output schema for the per-document relevance check.
# NOTE(review): with_structured_output presumably forwards the class docstring
# and the Field description to the model as schema text, so treat edits to
# them as prompt changes rather than documentation changes.
class RelevanceGrade(BaseModel):
    """Binary relevance grade for a retrieved document."""
    # True when the grader judges the document relevant to the question.
    is_relevant: bool = Field(
        description="Whether the document is relevant to the question"
    )

# The shared chat model constrained to return a RelevanceGrade instance.
grader_llm = llm.with_structured_output(RelevanceGrade)

# Grading prompt for a single document; grade_documents fills the
# {document} and {question} variables.
GRADE_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are a grader assessing whether a retrieved document "
     "is relevant to a user question. Answer with is_relevant=true or false."),
    ("human", "Document:\n{document}\n\nQuestion: {question}"),
])

# prompt -> structured LLM: invoke({"document": ..., "question": ...})
# yields a RelevanceGrade.
grader_chain = GRADE_PROMPT | grader_llm


def grade_documents(state: AgentState) -> AgentState:
    """Keep only the retrieved documents that the grader judges relevant.

    Each entry of ``state["documents"]`` is graded against
    ``state["question"]`` with ``grader_chain``. Irrelevant documents are
    dropped. ``web_search_needed`` is set when none survive, so the
    downstream router can fall back to web search.

    Args:
        state: Current graph state; reads ``question`` and ``documents``.

    Returns:
        A copy of ``state`` with the filtered ``documents`` and an updated
        ``web_search_needed`` flag.
    """
    question = state["question"]

    relevant_docs = []
    for doc in state["documents"]:
        grade = grader_chain.invoke({"document": doc, "question": question})
        if grade.is_relevant:
            relevant_docs.append(doc)
        # Keep a visible separator between the verdict and the preview;
        # without it the output reads "TrueRLHF (...".
        print(f"  Doc relevant: {grade.is_relevant} — {doc[:60]}...")

    return {
        **state,
        "documents": relevant_docs,
        # No relevant documents left: route to web search rather than
        # generating from an empty context.
        "web_search_needed": not relevant_docs,
    }


print("Nodes defined: retrieve, grade_documents")
# --- Node: Rewrite Query ---
REWRITE_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are a query rewriter. Given a question that did not "
     "retrieve good results, rewrite it to be more specific and "
     "search-friendly. Return only the rewritten question."),
    ("human", "Original question: {question}"),
])

# prompt -> LLM -> plain string containing the rewritten question.
rewrite_chain = REWRITE_PROMPT | llm | StrOutputParser()


def rewrite_query(state: AgentState) -> AgentState:
    """Rephrase the question to improve the next retrieval attempt.

    Args:
        state: Current graph state; reads ``question`` and ``retry_count``.

    Returns:
        A copy of ``state`` with the rewritten ``question`` and
        ``retry_count`` incremented by one. When the model returns only
        whitespace, the previous question is kept.
    """
    original = state["question"]
    # LLM output can carry leading/trailing whitespace or newlines.
    new_question = rewrite_chain.invoke({"question": original}).strip()
    # An empty rewrite would send a blank query to the retriever and the
    # graders; fall back to the question we already had.
    if not new_question:
        new_question = original
    print(f"  Rewritten: {new_question}")
    return {
        **state,
        "question": new_question,
        # check_generation stops the loop once this reaches its limit.
        "retry_count": state.get("retry_count", 0) + 1,
    }


# --- Node: Web Search (placeholder) ---
def web_search(state: AgentState) -> AgentState:
    """Append web-search results to the document list (stubbed here).

    This demo version calls no search API. It appends one placeholder
    string so the rest of the graph can run offline. A production version
    could use Tavily, e.g.::

        from langchain_community.tools.tavily_search import TavilySearchResults
        web_search_tool = TavilySearchResults(max_results=3)
        results = web_search_tool.invoke({"query": state["question"]})

    Returns:
        A copy of ``state`` whose ``documents`` is a new list ending with the
        placeholder result; the input list is not mutated.
    """
    print("  Web search triggered (placeholder)")
    placeholder = f"Web result for: {state['question']}"
    updated = dict(state)
    updated["documents"] = state["documents"] + [placeholder]
    return updated


# --- Node: Generate ---
RAG_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are an assistant answering questions based on "
     "provided context. Answer only from the context. If the context "
     "is insufficient, say so."),
    ("human", "Context:\n{context}\n\nQuestion: {question}"),
])

# prompt -> LLM -> plain-string answer.
generate_chain = RAG_PROMPT | llm | StrOutputParser()


def generate(state: AgentState) -> AgentState:
    """Answer the question from the documents currently in the state.

    The documents are joined with blank lines into one context string and
    passed, with the question, to ``generate_chain``.

    check_generation can route straight back to this node when an answer is
    judged ungrounded. That self-loop never passes through rewrite_query,
    which also increments ``retry_count``. Without counting here, the loop
    would repeat until LangGraph's recursion limit aborts the run. So any
    call made while a previous answer is already present counts as one
    retry.

    Args:
        state: Current graph state; reads ``documents``, ``question``,
            ``generation`` and ``retry_count``.

    Returns:
        A copy of ``state`` with the new ``generation``. ``retry_count`` is
        incremented when this call replaces an earlier answer.
    """
    context = "\n\n".join(state["documents"])
    generation = generate_chain.invoke(
        {"context": context, "question": state["question"]}
    )
    retry_count = state.get("retry_count", 0)
    if state.get("generation"):
        # Replacing an existing answer: count it toward the retry budget.
        retry_count += 1
    return {**state, "generation": generation, "retry_count": retry_count}


print("Nodes defined: rewrite_query, web_search, generate")

6. Define Conditional Edges

Conditional edges route the graph based on document relevance, hallucination checks, and answer quality.

# Structured-output schema for the grounding check in check_generation.
# NOTE(review): with_structured_output presumably forwards the class docstring
# and Field description to the model as schema text, so edits to them are
# prompt changes, not just documentation changes.
class GradeHallucination(BaseModel):
    """Check if generation is grounded in documents."""
    # True when the grader judges the answer supported by the documents.
    is_grounded: bool = Field(
        description="Whether the answer is grounded in the provided documents"
    )


# Structured-output schema for the "does the answer address the question"
# check in check_generation (the same docstring/description caveat applies).
class GradeAnswer(BaseModel):
    """Check if generation answers the question."""
    # True when the grader judges the answer responsive to the question.
    answers_question: bool = Field(
        description="Whether the answer addresses the user's question"
    )


# The shared chat model constrained to each schema; invoking them yields a
# GradeHallucination / GradeAnswer instance respectively.
hallucination_grader = llm.with_structured_output(GradeHallucination)
answer_grader = llm.with_structured_output(GradeAnswer)


def should_search_web(state: AgentState) -> Literal["web_search", "generate"]:
    """Pick the edge that follows document grading.

    Returns ``"web_search"`` when ``state["web_search_needed"]`` is truthy
    (grading left no relevant documents), otherwise ``"generate"``. The
    decision is also printed for tracing.
    """
    destination = "web_search" if state["web_search_needed"] else "generate"
    print(f"  → Routing to {destination}")
    return destination


def check_generation(state: AgentState) -> Literal["end", "rewrite_query", "generate"]:
    """Decide what happens after an answer has been generated.

    Routing, checked in order:
      1. ``"end"`` once ``retry_count`` reaches 3, so corrective loops stop.
      2. ``"generate"`` when the hallucination grader finds the answer is not
         grounded in ``state["documents"]``.
      3. ``"rewrite_query"`` when the answer grader finds it does not
         address ``state["question"]``.
      4. ``"end"`` otherwise.

    Args:
        state: Current graph state; reads ``retry_count``, ``documents``,
            ``generation`` and ``question``.

    Returns:
        The name of the next edge in the conditional-edge mapping.
    """
    if state.get("retry_count", 0) >= 3:
        print("  Max retries reached → end")
        return "end"

    context = "\n\n".join(state["documents"])
    # A chat model (and its with_structured_output wrapper) accepts a string,
    # a PromptValue, or a list of messages. A {"messages": [...]} dict is
    # rejected with ValueError, so the message list is passed directly.
    hallucination = hallucination_grader.invoke([
        {"role": "system", "content": "Check if the answer is grounded "
         "in the provided documents."},
        {"role": "human", "content": f"Documents:\n{context}\n\n"
         f"Answer: {state['generation']}"},
    ])
    if not hallucination.is_grounded:
        print("  Not grounded → re-generate")
        return "generate"

    answer_check = answer_grader.invoke([
        {"role": "system", "content": "Check if the answer addresses "
         "the user's question."},
        {"role": "human", "content": f"Question: {state['question']}\n\n"
         f"Answer: {state['generation']}"},
    ])
    if not answer_check.answers_question:
        print("  Does not answer → rewrite query")
        return "rewrite_query"

    print("  Answer is good → end")
    return "end"


print("Conditional edges defined: should_search_web, check_generation")

7. Build and Run the Agentic RAG Graph

Assemble the state graph with nodes, edges, and conditional routing.

from langgraph.graph import StateGraph, END

# Wire the five nodes into a corrective-RAG state machine over AgentState.
workflow = StateGraph(AgentState)

for _node_name, _node_fn in (
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("rewrite_query", rewrite_query),
    ("web_search", web_search),
    ("generate", generate),
):
    workflow.add_node(_node_name, _node_fn)

# Every run starts by retrieving for the incoming question.
workflow.set_entry_point("retrieve")

# retrieve -> grade_documents -> [web_search ->] generate
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    should_search_web,
    {"web_search": "web_search", "generate": "generate"},
)
workflow.add_edge("web_search", "generate")

# After generating: finish, regenerate, or rewrite the query and retrieve again.
workflow.add_conditional_edges(
    "generate",
    check_generation,
    {"end": END, "rewrite_query": "rewrite_query", "generate": "generate"},
)
workflow.add_edge("rewrite_query", "retrieve")

app = workflow.compile()
print("Agentic RAG graph compiled successfully")
# Run the agentic RAG pipeline on a question the sample corpus covers,
# starting from an otherwise empty state.
initial_state = {
    "question": "What are the key differences between RLHF and DPO?",
    "documents": [],
    "generation": "",
    "web_search_needed": False,
    "retry_count": 0,
}
result = app.invoke(initial_state)

divider = "=" * 60
print(f"\n{divider}")
print("FINAL ANSWER:")
print(divider)
print(result["generation"])

8. Adding Tool Routing

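Add a query router that classifies each query by the data source it needs: the vector store, web search, or a direct LLM answer with no retrieval. The cell below defines the router and tests it on sample queries. It is not yet wired into the graph above; the natural next step is to call it from a conditional entry edge.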

# Structured-output schema for the query router: the model must choose exactly
# one of the three Literal sources.
# NOTE(review): the class docstring and Field description are presumably sent
# to the model as schema text via with_structured_output, so edit them as
# prompt changes, not as documentation.
class RouteQuery(BaseModel):
    """Route a query to the most appropriate data source."""
    # One of "vectorstore", "web_search" or "direct".
    source: Literal["vectorstore", "web_search", "direct"] = Field(
        description="The data source to route the query to"
    )

# The shared chat model constrained to return a RouteQuery.
router_llm = llm.with_structured_output(RouteQuery)

# Router instructions: one bullet per allowed source describing when to pick
# it. Filled with the {question} variable.
ROUTE_PROMPT = ChatPromptTemplate.from_messages([
    ("system",
     "You are a query router. Route the query to the best data source:\n"
     "- vectorstore: for questions about specific technical topics in our docs\n"
     "- web_search: for questions about recent events or general knowledge\n"
     "- direct: for simple questions the LLM can answer without retrieval"),
    ("human", "{question}"),
])

# prompt -> structured LLM: invoke({"question": ...}) yields a RouteQuery.
# In the code shown here it is only exercised by the test loop that follows,
# not wired into `workflow`.
route_chain = ROUTE_PROMPT | router_llm

# Smoke-test the router with one question aimed at each route described
# in ROUTE_PROMPT.
queries = [
    "What are the differences between RLHF and DPO?",
    "What is the capital of France?",
    "What happened in the news today?",
]

for query in queries:
    decision = route_chain.invoke({"question": query})
    # One print with an embedded newline emits the same lines as before.
    print(f"Query: {query}\n  → Route: {decision.source}\n")