!pip install -q langchain langchain-openai langchain-community langgraph langchain-text-splitters faiss-cpu tavily-python pydantic

Agentic RAG: When Retrieval Needs Reasoning
Building RAG agents that plan queries, route to tools, self-reflect, and iteratively refine answers with LangGraph and LlamaIndex Workflows
Table of Contents
1. Setup & Installation
import os
# os.environ["OPENAI_API_KEY"] = "your-api-key-here" # Uncomment and set
# os.environ["TAVILY_API_KEY"] = "your-tavily-key-here" # Uncomment for web search

2. Why Standard RAG Breaks Down
Standard RAG follows a fixed linear pipeline: embed → retrieve → generate. No feedback loop, no self-correction.
# Catalogue of concrete ways the fixed retrieve→generate pipeline goes wrong.
failure_modes = {
    "Irrelevant retrieval": "Query about 'Python decorators' retrieves docs about 'Python snakes'",
    "Incomplete retrieval": "Multi-part question, only first part answered",
    "Wrong data source": "Question needs SQL data but system only searches vector store",
    "Stale data": "'What's the latest version?' retrieves outdated chunk",
    "Hallucination from noise": "LLM generates confidently from marginally relevant chunks",
}

print("Standard RAG Failure Modes:")
print("=" * 60)
for failure, illustration in failure_modes.items():
    print(f"\n{failure}:")
    print(f" Example: {illustration}")

# 3. Define the Agent State
The state flows through the agentic RAG graph, carrying the question, documents, generation, and control flags.
from typing import Literal
from typing_extensions import TypedDict
class AgentState(TypedDict):
    """Shared state passed between nodes of the agentic RAG graph."""

    question: str            # current (possibly rewritten) user question
    documents: list[str]     # retrieved document texts that survived grading
    generation: str          # latest LLM-produced answer
    web_search_needed: bool  # set by grading when no retrieved doc is relevant
    retry_count: int         # number of query rewrites performed so far

print("AgentState fields:", [*AgentState.__annotations__])

# 4. Build the Retriever
Create a FAISS vector store with sample documents for retrieval.
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# Small in-memory corpus about LLM training and serving techniques.
sample_docs = [
    Document(page_content="RLHF (Reinforcement Learning from Human Feedback) is a technique for aligning LLMs with human preferences. It involves training a reward model on human comparisons, then using PPO to optimize the language model against that reward."),
    Document(page_content="DPO (Direct Preference Optimization) simplifies RLHF by directly optimizing the language model on preference data without needing a separate reward model. DPO reformulates the RLHF objective as a classification loss."),
    Document(page_content="vLLM is a high-throughput inference engine for LLMs that uses PagedAttention to efficiently manage KV cache memory. It achieves 2-4x higher throughput compared to HuggingFace Transformers."),
    Document(page_content="Ollama is a tool for running LLMs locally. It supports models like Llama, Mistral, and Gemma. Ollama wraps llama.cpp and provides a simple API for local inference."),
    Document(page_content="The attention mechanism allows transformers to weigh the importance of different tokens. Self-attention computes queries, keys, and values for each position in the sequence."),
]

# Chunk the corpus, embed each chunk, and index the vectors in FAISS.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
chunks = text_splitter.split_documents(sample_docs)
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = FAISS.from_documents(chunks, embedding_model)
# Top-3 similarity search; `retriever` is reused by the graph's retrieve node.
retriever = vectorstore.as_retriever(search_kwargs={"k": 3})
print(f"Indexed {len(chunks)} chunks")

# Sanity-check retrieval before wiring it into the graph.
hits = retriever.invoke("What are the differences between RLHF and DPO?")
print(f"Retrieved {len(hits)} documents for test query")

# 5. Define Graph Nodes
Each node is a function that processes the agent state: retrieve, grade documents, rewrite query, web search, and generate.
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
# Single deterministic (temperature=0) model shared by every node and grader.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
# --- Node: Retrieve ---
def retrieve(state: AgentState) -> AgentState:
    """Fetch candidate documents from the vector store for the current question."""
    docs = retriever.invoke(state["question"])
    return {**state, "documents": [d.page_content for d in docs]}
# --- Node: Grade Documents ---
class RelevanceGrade(BaseModel):
    """Binary relevance grade for a retrieved document."""

    # True when the document helps answer the question; the structured-output
    # wrapper below forces the LLM's reply into this schema.
    is_relevant: bool = Field(
        description="Whether the document is relevant to the question"
    )

# Model constrained to emit a RelevanceGrade instance instead of free text.
grader_llm = llm.with_structured_output(RelevanceGrade)
GRADE_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are a grader assessing whether a retrieved document "
     "is relevant to a user question. Answer with is_relevant=true or false."),
    ("human", "Document:\n{document}\n\nQuestion: {question}"),
])
# prompt → structured grader; takes {"document", "question"}, returns RelevanceGrade.
grader_chain = GRADE_PROMPT | grader_llm
def grade_documents(state: AgentState) -> AgentState:
    """Keep only documents the grader judges relevant; flag web search if none survive."""
    question = state["question"]
    relevant_docs = []
    for doc in state["documents"]:
        verdict = grader_chain.invoke({"document": doc, "question": question})
        print(f" Doc relevant: {verdict.is_relevant} — {doc[:60]}...")
        if verdict.is_relevant:
            relevant_docs.append(doc)
    return {
        **state,
        "documents": relevant_docs,
        # Empty survivor list triggers the web-search fallback route.
        "web_search_needed": not relevant_docs,
    }

print("Nodes defined: retrieve, grade_documents")

# --- Node: Rewrite Query ---
# Query-rewrite prompt: asks for only the reformulated question back,
# so the chain's string output can replace state["question"] directly.
REWRITE_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are a query rewriter. Given a question that did not "
     "retrieve good results, rewrite it to be more specific and "
     "search-friendly. Return only the rewritten question."),
    ("human", "Original question: {question}"),
])
# prompt → chat model → plain string.
rewrite_chain = REWRITE_PROMPT | llm | StrOutputParser()
def rewrite_query(state: AgentState) -> AgentState:
    """Reformulate the question for a better retrieval pass and bump the retry counter."""
    rewritten = rewrite_chain.invoke({"question": state["question"]})
    print(f" Rewritten: {rewritten}")
    return {
        **state,
        "question": rewritten,
        # retry_count bounds the rewrite loop in check_generation.
        "retry_count": state.get("retry_count", 0) + 1,
    }
# --- Node: Web Search (placeholder) ---
def web_search(state: AgentState) -> AgentState:
    """Supplement retrieval with (placeholder) web search results."""
    # In production, use TavilySearchResults:
    # from langchain_community.tools.tavily_search import TavilySearchResults
    # web_search_tool = TavilySearchResults(max_results=3)
    # results = web_search_tool.invoke({"query": state["question"]})
    print(" Web search triggered (placeholder)")
    extra_docs = [f"Web result for: {state['question']}"]
    return {**state, "documents": [*state["documents"], *extra_docs]}
# --- Node: Generate ---
# Answer-synthesis prompt: restricts the model to the retrieved context and
# tells it to admit when that context is insufficient.
RAG_PROMPT = ChatPromptTemplate.from_messages([
    ("system", "You are an assistant answering questions based on "
     "provided context. Answer only from the context. If the context "
     "is insufficient, say so."),
    ("human", "Context:\n{context}\n\nQuestion: {question}"),
])
# prompt → chat model → plain answer string.
generate_chain = RAG_PROMPT | llm | StrOutputParser()
def generate(state: AgentState) -> AgentState:
    """Produce an answer grounded in the documents currently in state."""
    answer = generate_chain.invoke({
        "context": "\n\n".join(state["documents"]),
        "question": state["question"],
    })
    return {**state, "generation": answer}

print("Nodes defined: rewrite_query, web_search, generate")

# 6. Define Conditional Edges
Conditional edges route the graph based on document relevance, hallucination checks, and answer quality.
class GradeHallucination(BaseModel):
    """Check if generation is grounded in documents."""

    # True when the answer's claims are supported by the provided documents.
    is_grounded: bool = Field(
        description="Whether the answer is grounded in the provided documents"
    )

class GradeAnswer(BaseModel):
    """Check if generation answers the question."""

    # True when the answer actually addresses what the user asked.
    answers_question: bool = Field(
        description="Whether the answer addresses the user's question"
    )

# Structured-output graders built on the shared llm; used by check_generation.
hallucination_grader = llm.with_structured_output(GradeHallucination)
answer_grader = llm.with_structured_output(GradeAnswer)
def should_search_web(state: AgentState) -> Literal["web_search", "generate"]:
    """Route to web search when grading left no relevant documents, else generate."""
    if not state["web_search_needed"]:
        print(" → Routing to generate")
        return "generate"
    print(" → Routing to web_search")
    return "web_search"
def check_generation(state: AgentState) -> Literal["end", "rewrite_query", "generate"]:
    """Check if generation is grounded and answers the question.

    Routing:
      - "end" when the retry budget (3 rewrites) is exhausted or the answer is good,
      - "generate" when the answer is not grounded in the documents,
      - "rewrite_query" when the answer does not address the question.
    """
    if state.get("retry_count", 0) >= 3:
        print(" Max retries reached → end")
        return "end"
    context = "\n\n".join(state["documents"])
    # BUG FIX: hallucination_grader/answer_grader are chat-model runnables
    # (from llm.with_structured_output), so they take a message list (or a
    # string) directly. The previous {"messages": [...]} dict wrapper is not
    # a valid model input and raised at runtime.
    hallucination = hallucination_grader.invoke([
        ("system", "Check if the answer is grounded "
         "in the provided documents."),
        ("human", f"Documents:\n{context}\n\n"
         f"Answer: {state['generation']}"),
    ])
    if not hallucination.is_grounded:
        # NOTE(review): this path does not increment retry_count, so repeated
        # ungrounded generations are only bounded by LangGraph's recursion limit.
        print(" Not grounded → re-generate")
        return "generate"
    answer_check = answer_grader.invoke([
        ("system", "Check if the answer addresses "
         "the user's question."),
        ("human", f"Question: {state['question']}\n\n"
         f"Answer: {state['generation']}"),
    ])
    if not answer_check.answers_question:
        print(" Does not answer → rewrite query")
        return "rewrite_query"
    print(" Answer is good → end")
    return "end"

print("Conditional edges defined: should_search_web, check_generation")

# 7. Build and Run the Agentic RAG Graph
Assemble the state graph with nodes, edges, and conditional routing.
from langgraph.graph import StateGraph, END

# Assemble the agentic RAG graph.
workflow = StateGraph(AgentState)

# Register every node function under its routing name.
for node_name, node_fn in [
    ("retrieve", retrieve),
    ("grade_documents", grade_documents),
    ("rewrite_query", rewrite_query),
    ("web_search", web_search),
    ("generate", generate),
]:
    workflow.add_node(node_name, node_fn)

# Wiring: retrieve → grade → (web_search | generate) → self-check loop.
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    should_search_web,
    {"web_search": "web_search", "generate": "generate"},
)
workflow.add_edge("web_search", "generate")
workflow.add_conditional_edges(
    "generate",
    check_generation,
    {"end": END, "rewrite_query": "rewrite_query", "generate": "generate"},
)
workflow.add_edge("rewrite_query", "retrieve")

# Compile into a runnable app.
app = workflow.compile()
print("Agentic RAG graph compiled successfully")

# Run the agentic RAG pipeline
# Seed the graph with an empty state and the user's question.
initial_state = {
    "question": "What are the key differences between RLHF and DPO?",
    "documents": [],
    "generation": "",
    "web_search_needed": False,
    "retry_count": 0,
}
result = app.invoke(initial_state)

print("\n" + "=" * 60)
print("FINAL ANSWER:")
print("=" * 60)
print(result["generation"])

# 8. Adding Tool Routing
Extend the graph to route queries to different tools based on query type.
class RouteQuery(BaseModel):
    """Route a query to the most appropriate data source."""

    # Closed set of destinations the router may choose from.
    source: Literal["vectorstore", "web_search", "direct"] = Field(
        description="The data source to route the query to"
    )

# Constrain the LLM to emit a RouteQuery instance.
router_llm = llm.with_structured_output(RouteQuery)
# Routing prompt: one bullet per destination gives the model explicit criteria.
ROUTE_PROMPT = ChatPromptTemplate.from_messages([
    ("system",
     "You are a query router. Route the query to the best data source:\n"
     "- vectorstore: for questions about specific technical topics in our docs\n"
     "- web_search: for questions about recent events or general knowledge\n"
     "- direct: for simple questions the LLM can answer without retrieval"),
    ("human", "{question}"),
])
# Takes {"question": ...}, returns a RouteQuery with the chosen source.
route_chain = ROUTE_PROMPT | router_llm
# Exercise the router on one query per expected destination.
queries = [
    "What are the differences between RLHF and DPO?",
    "What is the capital of France?",
    "What happened in the news today?",
]
for user_query in queries:
    routing = route_chain.invoke({"question": user_query})
    print(f"Query: {user_query}")
    print(f" → Route: {routing.source}\n")