Hybrid and Corrective RAG Architectures

Self-RAG, CRAG, Adaptive RAG, and query routing — building RAG systems that know when to retrieve, when to skip, and when to self-correct

Open In Colab

📖 Read the full article


Table of Contents

  1. Setup & Installation
  2. The Problem with Standard RAG
  3. Self-RAG: Reflection Tokens
  4. Self-RAG Implementation with LangGraph
  5. CRAG: Corrective Retrieval Augmented Generation
  6. CRAG Implementation with LangGraph
  7. Adaptive RAG: Routing by Query Complexity
  8. Comparing Architectures

1. Setup & Installation

# Install the LangChain ecosystem, LangGraph, the FAISS vector store, and the
# Tavily web-search client. The leading "!" is notebook shell magic (Colab/Jupyter).
!pip install -q langchain langchain-openai langchain-community langgraph langchain-text-splitters faiss-cpu tavily-python
import os
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"  # Uncomment and set
# os.environ["TAVILY_API_KEY"] = "your-tavily-key-here"  # Uncomment for web search

2. The Problem with Standard RAG

Standard RAG retrieves every time (even when unnecessary) and trusts every document (even irrelevant ones).

from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# Sample knowledge base: a handful of diabetes facts shared by every example
# below. Kept tiny on purpose so retrieval behavior is easy to inspect.
docs = [
    Document(page_content="Metformin is the first-line treatment for type 2 diabetes. It reduces glucose production in the liver."),
    Document(page_content="Common side effects of metformin include nausea, diarrhea, and stomach pain."),
    Document(page_content="HbA1c levels below 7% indicate good diabetes control."),
    Document(page_content="Insulin therapy is needed when oral medications fail to control blood sugar."),
    Document(page_content="Regular exercise helps improve insulin sensitivity and blood glucose control."),
]

# Embed the documents and index them in an in-memory FAISS store.
# NOTE: requires OPENAI_API_KEY to be set (embedding happens here, at build time).
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = FAISS.from_documents(docs, embeddings)
# temperature=0 keeps the grading/routing judgments below as deterministic as possible.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Failure modes of vanilla RAG — the motivation for the corrective
# architectures implemented in the rest of this notebook.
failures = {
    "Unnecessary retrieval": "'What is 2+2?' → retrieves 5 documents (wasted latency)",
    "Irrelevant retrieval": "Query about quantum computing → retrieves cooking recipes",
    "Unfaithful generation": "Correct docs retrieved but LLM fabricates details",
}

print("Standard RAG Failure Modes:")
print("\n".join(f"  {mode}: {example}" for mode, example in failures.items()))

3. Self-RAG: Reflection Tokens

Self-RAG (Asai et al., 2023) teaches the LLM to decide: Should I retrieve? Are docs relevant? Is my answer grounded?

# Self-RAG reflection tokens: the special tokens the model emits to control
# and critique its own retrieval/generation loop (name, input, output, purpose).
_token_specs = [
    ("[Retrieve]", "Question (+ generation)", "yes, no, continue",
     "Decides whether to retrieve documents"),
    ("[ISREL]", "Question + document", "relevant, irrelevant",
     "Grades document relevance"),
    ("[ISSUP]", "Question + document + generation", "fully supported, partial, no support",
     "Checks if generation is grounded"),
    ("[ISUSE]", "Question + generation", "Score 1-5",
     "Rates overall answer utility"),
]
reflection_tokens = {
    name: {"input": token_input, "output": token_output, "purpose": purpose}
    for name, token_input, token_output, purpose in _token_specs
}

print("Self-RAG Reflection Tokens:")
print("=" * 60)
for token_name, spec in reflection_tokens.items():
    print(f"\n{token_name}:")
    for field, description in spec.items():
        print(f"  {field}: {description}")

4. Self-RAG Implementation with LangGraph

Since the original Self-RAG model is fine-tuned to emit reflection tokens, we approximate each token with an LLM-as-judge prompt and orchestrate the resulting decisions with LangGraph.

from typing import TypedDict, Literal
from langgraph.graph import StateGraph, END
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser


# Shared state flowing through every node of the Self-RAG graph.
RAGState = TypedDict(
    "RAGState",
    {
        "question": str,         # the user's query
        "documents": list[str],  # retrieved (and later filtered) passages
        "generation": str,       # current draft answer
        "retries": int,          # regeneration attempts after failed grounding checks
    },
)


# Node: Decide whether retrieval is needed
# Node: Decide whether retrieval is needed ([Retrieve] token, approximated
# with an LLM-as-judge prompt).
def route_question(state: RAGState) -> Literal["retrieve", "generate_direct"]:
    """Route the question to retrieval or to a direct answer."""
    router_prompt = ChatPromptTemplate.from_template(
        "Does this question require external knowledge retrieval, "
        "or can it be answered from general knowledge?\n\n"
        "Question: {question}\n\n"
        "Answer with ONLY 'retrieve' or 'generate_direct'."
    )
    router = router_prompt | llm | StrOutputParser()
    verdict = router.invoke({"question": state["question"]}).strip().lower()
    # Default to direct generation unless the judge clearly asked to retrieve.
    route = "generate_direct" if "retrieve" not in verdict else "retrieve"
    print(f"  Route decision: {route}")
    return route


def retrieve(state: RAGState) -> RAGState:
    """Fetch the top-3 most similar documents for the question."""
    top_docs = vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(state["question"])
    print(f"  Retrieved {len(top_docs)} documents")
    return {**state, "documents": [doc.page_content for doc in top_docs]}


def grade_documents(state: RAGState) -> RAGState:
    """Filter retrieved documents to those the LLM judges relevant (ISREL)."""
    grader_prompt = ChatPromptTemplate.from_template(
        "Is this document relevant?\nQuestion: {question}\nDocument: {document}\n"
        "Answer ONLY 'relevant' or 'irrelevant'."
    )
    grader = grader_prompt | llm | StrOutputParser()

    def _is_relevant(doc: str) -> bool:
        verdict = grader.invoke({"question": state["question"], "document": doc}).strip().lower()
        # "irrelevant" contains "relevant", so both checks are needed.
        return "relevant" in verdict and "irrelevant" not in verdict

    kept = [doc for doc in state["documents"] if _is_relevant(doc)]
    print(f"  {len(kept)}/{len(state['documents'])} documents relevant")
    return {**state, "documents": kept}


def generate(state: RAGState) -> RAGState:
    """Generate an answer grounded in the (possibly empty) retrieved context.

    Also increments the ``retries`` counter. ``check_hallucination`` routes
    "not_supported" answers back into this node and caps the loop with
    ``retries >= 2`` — but no node ever incremented ``retries``, so the cap
    was unreachable and a persistently unsupported answer would loop until
    LangGraph's recursion limit. Counting attempts here makes the retry
    budget effective.
    """
    # Fall back to an explicit placeholder so the prompt is never empty.
    context = "\n\n".join(state["documents"]) if state["documents"] else "No context available."
    prompt = ChatPromptTemplate.from_template(
        "Answer using ONLY the context.\nContext:\n{context}\nQuestion: {question}"
    )
    chain = prompt | llm | StrOutputParser()
    answer = chain.invoke({"context": context, "question": state["question"]})
    # Count this attempt so the hallucination-check retry cap can trigger.
    return {**state, "generation": answer, "retries": state.get("retries", 0) + 1}


def generate_direct(state: RAGState) -> RAGState:
    """Answer from the model's own knowledge, skipping retrieval entirely."""
    direct_chain = (
        ChatPromptTemplate.from_template("Answer concisely: {question}")
        | llm
        | StrOutputParser()
    )
    answer = direct_chain.invoke({"question": state["question"]})
    return {**state, "generation": answer}


def check_hallucination(state: RAGState) -> Literal["supported", "not_supported"]:
    """Grounding check (ISSUP): is the generation supported by the documents?"""
    # Nothing to check against, or retry budget exhausted — accept the answer.
    if not state["documents"] or state.get("retries", 0) >= 2:
        return "supported"
    judge = ChatPromptTemplate.from_template(
        "Is this answer supported by the documents?\n"
        "Documents:\n{context}\nAnswer: {generation}\n"
        "Respond ONLY 'supported' or 'not_supported'."
    ) | llm | StrOutputParser()
    raw = judge.invoke({
        "context": "\n".join(state["documents"]),
        "generation": state["generation"],
    }).strip().lower()
    # "not_supported" contains "supported", so the presence of "not" wins.
    verdict = "not_supported" if ("supported" not in raw or "not" in raw) else "supported"
    print(f"  Hallucination check: {verdict}")
    return verdict


def route_after_grading(state: RAGState) -> Literal["generate", "generate_direct"]:
    """If any relevant document survived grading, generate with context; else answer directly."""
    if state["documents"]:
        return "generate"
    return "generate_direct"


print("Self-RAG nodes defined")
# Build Self-RAG graph: retrieval and grading feed a generate node that is
# re-entered whenever the grounding check rejects the answer.
workflow = StateGraph(RAGState)
workflow.add_node("retrieve", retrieve)
workflow.add_node("grade_documents", grade_documents)
workflow.add_node("generate", generate)
workflow.add_node("generate_direct", generate_direct)

# Conditional entry: route_question decides between retrieval and a direct answer.
workflow.set_conditional_entry_point(
    route_question,
    {"retrieve": "retrieve", "generate_direct": "generate_direct"},
)
workflow.add_edge("retrieve", "grade_documents")
# If grading discarded every document, fall back to answering without context.
workflow.add_conditional_edges(
    "grade_documents", route_after_grading,
    {"generate": "generate", "generate_direct": "generate_direct"},
)
# Self-correction loop: unsupported answers are regenerated.
workflow.add_conditional_edges(
    "generate", check_hallucination,
    {"supported": END, "not_supported": "generate"},
)
workflow.add_edge("generate_direct", END)

self_rag = workflow.compile()

# Test with domain question (should route to retrieval and use the knowledge base)
print("\n--- Domain Question ---")
result = self_rag.invoke({"question": "What are the side effects of metformin?", "documents": [], "generation": "", "retries": 0})
print(f"\nAnswer: {result['generation']}")

# Test with general knowledge question (should skip retrieval entirely)
print("\n--- General Question ---")
result = self_rag.invoke({"question": "What is the capital of France?", "documents": [], "generation": "", "retries": 0})
print(f"\nAnswer: {result['generation']}")

5. CRAG: Corrective Retrieval Augmented Generation

CRAG (Yan et al., 2024) evaluates retrieval quality and triggers corrective actions: refine, supplement with web search, or replace entirely.

# CRAG three-action framework
crag_actions = {
    "Correct (high confidence)": {
        "action": "Knowledge Refinement",
        "description": "Decompose docs into knowledge strips, filter irrelevant ones",
    },
    "Ambiguous (medium confidence)": {
        "action": "Refine + Web Search",
        "description": "Keep refined local docs AND supplement with web search",
    },
    "Incorrect (low confidence)": {
        "action": "Web Search Only",
        "description": "Discard all retrieved docs, query the web",
    },
}

print("CRAG Action Framework:")
print("=" * 60)
for confidence, details in crag_actions.items():
    print(f"\n{confidence}:")
    print(f"  Action: {details['action']}")
    print(f"  Description: {details['description']}")

6. CRAG Implementation with LangGraph

# State carried through the CRAG graph.
CRAGState = TypedDict(
    "CRAGState",
    {
        "question": str,         # the user's query
        "documents": list[str],  # retrieved / refined / web-sourced passages
        "confidence": str,       # retrieval grade: correct | ambiguous | incorrect
        "generation": str,       # final answer
    },
)


def crag_retrieve(state: CRAGState) -> CRAGState:
    """Fetch the top-3 candidate documents for the question."""
    hits = vectorstore.as_retriever(search_kwargs={"k": 3}).invoke(state["question"])
    return {**state, "documents": [hit.page_content for hit in hits]}


def evaluate_retrieval(state: CRAGState) -> CRAGState:
    """Grade overall retrieval quality as correct / ambiguous / incorrect."""
    evaluator = ChatPromptTemplate.from_template(
        "Rate retrieval quality as 'correct', 'ambiguous', or 'incorrect'.\n\n"
        "Question: {question}\nDocuments:\n{documents}\n\n"
        "Respond with ONLY one word."
    ) | llm | StrOutputParser()
    verdict = evaluator.invoke({
        "question": state["question"],
        "documents": "\n---\n".join(state["documents"]),
    }).strip().lower()
    # Any off-menu response is treated as the middle ground.
    if verdict not in {"correct", "ambiguous", "incorrect"}:
        verdict = "ambiguous"
    print(f"  Retrieval confidence: {verdict}")
    return {**state, "confidence": verdict}


def route_on_confidence(state: CRAGState) -> Literal["refine", "web_search", "web_only"]:
    """Map the retrieval confidence grade to its corrective action node."""
    actions = {"correct": "refine", "ambiguous": "web_search"}
    return actions.get(state["confidence"], "web_only")


def refine_knowledge(state: CRAGState) -> CRAGState:
    """Knowledge refinement: keep only question-relevant sentences from the docs."""
    extractor = ChatPromptTemplate.from_template(
        "Extract ONLY relevant sentences for the question.\n"
        "Question: {question}\nDocuments:\n{documents}\nRelevant extracts:"
    ) | llm | StrOutputParser()
    extracts = extractor.invoke({
        "question": state["question"],
        "documents": "\n---\n".join(state["documents"]),
    })
    print("  Knowledge refined")
    # All documents collapse into a single refined extract.
    return {**state, "documents": [extracts]}


def web_search_supplement(state: CRAGState) -> CRAGState:
    """Ambiguous case: keep the local documents and append a (stubbed) web result."""
    print("  Web search supplementing (placeholder)")
    stub = f"Web search result for: {state['question']}"
    return {**state, "documents": [*state["documents"], stub]}


def web_search_replace(state: CRAGState) -> CRAGState:
    """Incorrect case: discard every retrieved document, keep only the (stubbed) web result."""
    print("  Web search replacing all docs (placeholder)")
    stub = f"Web search result for: {state['question']}"
    return {**state, "documents": [stub]}


def crag_generate(state: CRAGState) -> CRAGState:
    """Generate the final answer from whatever documents survived correction."""
    answer_chain = ChatPromptTemplate.from_template(
        "Answer using the context.\nContext:\n{context}\nQuestion: {question}"
    ) | llm | StrOutputParser()
    answer = answer_chain.invoke({
        "context": "\n\n".join(state["documents"]),
        "question": state["question"],
    })
    return {**state, "generation": answer}


# Build CRAG graph: always retrieve, evaluate quality, then branch into one
# of the three corrective actions before generating.
crag_workflow = StateGraph(CRAGState)
crag_workflow.add_node("retrieve", crag_retrieve)
crag_workflow.add_node("evaluate", evaluate_retrieval)
crag_workflow.add_node("refine", refine_knowledge)
crag_workflow.add_node("web_search", web_search_supplement)
crag_workflow.add_node("web_only", web_search_replace)
crag_workflow.add_node("generate", crag_generate)

crag_workflow.set_entry_point("retrieve")
crag_workflow.add_edge("retrieve", "evaluate")
# Branch on the confidence grade: refine (correct), supplement (ambiguous),
# or replace with web results entirely (incorrect).
crag_workflow.add_conditional_edges(
    "evaluate", route_on_confidence,
    {"refine": "refine", "web_search": "web_search", "web_only": "web_only"},
)
# All three corrective paths converge on generation.
crag_workflow.add_edge("refine", "generate")
crag_workflow.add_edge("web_search", "generate")
crag_workflow.add_edge("web_only", "generate")
crag_workflow.add_edge("generate", END)

crag_app = crag_workflow.compile()

# Test CRAG on a question the local knowledge base can answer.
result = crag_app.invoke({"question": "What are the side effects of metformin?", "documents": [], "confidence": "", "generation": ""})
print(f"\nAnswer: {result['generation']}")

7. Adaptive RAG: Routing by Query Complexity

Adaptive RAG (Jeong et al., 2024) classifies queries by complexity: simple (no retrieval), medium (single-step RAG), complex (multi-step).

from pydantic import BaseModel, Field


class ComplexityLevel(BaseModel):
    """Structured-output schema for the Adaptive RAG query-complexity router."""

    level: Literal["simple", "medium", "complex"] = Field(
        description="Query complexity: simple (LLM only), medium (single RAG), complex (multi-step RAG)"
    )
    # Free-text justification; useful for debugging routing decisions.
    reasoning: str = Field(description="Why this complexity level")


# Router LLM that emits a validated ComplexityLevel object instead of free text.
classifier = llm.with_structured_output(ComplexityLevel)
classify_prompt = ChatPromptTemplate.from_template(
    "Classify query complexity:\n"
    "- simple: general knowledge, no specialized retrieval needed\n"
    "- medium: needs single retrieval pass from documents\n"
    "- complex: needs multi-step retrieval, comparison, or reasoning\n\n"
    "Query: {query}"
)

classify_chain = classify_prompt | classifier

# One query per complexity tier, to exercise each routing branch.
test_queries = [
    "What is 2 + 2?",
    "What are the side effects of metformin?",
    "Compare metformin and insulin therapy for type 2 diabetes, including side effects and when to switch",
]

print("Adaptive RAG Query Routing:")
print("=" * 60)
for q in test_queries:
    result = classify_chain.invoke({"query": q})
    print(f"\nQ: {q}")
    print(f"  Level: {result.level}")
    print(f"  Reason: {result.reasoning}")

8. Comparing Architectures

Summary of Self-RAG, CRAG, and Adaptive RAG.

# Side-by-side summary of the three architectures covered above.
comparison = {
    "Self-RAG": {
        "Key Idea": "LLM generates reflection tokens to control retrieval",
        "Retrieval Decision": "Adaptive (retrieve only when needed)",
        "Quality Control": "ISREL, ISSUP, ISUSE tokens",
        "Fallback": "Re-generate with different parameters",
        "Complexity": "Medium (requires training or LLM-as-judge)",
    },
    "CRAG": {
        "Key Idea": "Lightweight evaluator grades retrieval, triggers corrections",
        "Retrieval Decision": "Always retrieve, then evaluate",
        "Quality Control": "Correct/Ambiguous/Incorrect grading",
        "Fallback": "Web search supplement or replacement",
        "Complexity": "Low (plug-and-play evaluator)",
    },
    "Adaptive RAG": {
        "Key Idea": "Route queries to different strategies by complexity",
        "Retrieval Decision": "Classifier-based routing",
        "Quality Control": "Strategy selection prevents over/under-retrieval",
        "Fallback": "Escalate to more complex strategy",
        "Complexity": "Low (trained classifier or LLM router)",
    },
}

print("Architecture Comparison:")
print("=" * 70)
for architecture, properties in comparison.items():
    print(f"\n{architecture}:")
    for aspect, note in properties.items():
        print(f"  {aspect}: {note}")

print("\n" + "=" * 70)
print("Recommendation:")
for tip in (
    "  Start with CRAG (simplest to implement).",
    "  Add Adaptive RAG routing if latency matters.",
    "  Use Self-RAG for highest quality with full self-correction.",
):
    print(tip)