Planning and Query Decomposition for Complex Retrieval

Plan-and-execute, sub-question decomposition, and iterative retrieval-and-reasoning

Open In Colab

📖 Read the full article


Table of Contents

  1. Setup
  2. Plan-and-Execute Pattern
  3. Sub-Question Decomposition
  4. Parallel Retrieval
  5. Iterative Retrieval-then-Reason (IRCoT)
!pip install -q langchain-openai langgraph langchain-core
import os
# os.environ["OPENAI_API_KEY"] = "your-key"

2. Plan-and-Execute Pattern

The agent first creates a plan of sub-steps, then executes each step sequentially, feeding earlier results forward as context. (A re-planning loop is a natural extension, but this graph runs the initial plan to completion.)

from typing import TypedDict, Annotated, Literal
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
import json

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)


class PlanExecuteState(TypedDict):
    """Shared state threaded through every node of the plan-and-execute graph."""

    # Chat history; the add_messages reducer appends instead of overwriting.
    messages: Annotated[list, add_messages]
    # The original user question, set once at invoke time.
    question: str
    # Ordered sub-questions produced by plan_step.
    plan: list[str]
    # Index into `plan` of the next step to execute.
    current_step: int
    # Maps each completed sub-question to its answer text.
    step_results: dict[str, str]
    # Synthesized answer written by compose_answer at the end.
    final_answer: str


def plan_step(state: PlanExecuteState) -> dict:
    """Create an ordered list of sub-tasks for the user question.

    Asks the LLM for a JSON list of 2-4 independent sub-questions and falls
    back to a single-step plan (the question itself) whenever the response
    cannot be parsed into a non-empty list of strings.

    Returns:
        Partial state update: the plan, a reset step counter, and an empty
        results dict.
    """
    response = llm.invoke([{
        "role": "system",
        "content": """Break the user question into 2-4 simple sub-questions that can be answered independently.
Return a JSON list of strings. Example: ["sub-q1", "sub-q2", "sub-q3"]""",
    }, {"role": "user", "content": state["question"]}])

    # Models frequently wrap JSON in ```json ... ``` fences even when told
    # not to — strip them before parsing.
    raw = response.content.strip()
    if raw.startswith("```"):
        raw = raw.strip("`").removeprefix("json").strip()

    try:
        plan = json.loads(raw)
    except json.JSONDecodeError:
        plan = [state["question"]]

    # A parseable-but-wrong payload (dict, empty list, non-string items)
    # would crash execute_step later; degrade to a one-step plan instead.
    if not isinstance(plan, list) or not plan or not all(isinstance(p, str) for p in plan):
        plan = [state["question"]]

    print(f"📋 Plan: {plan}")
    return {"plan": plan, "current_step": 0, "step_results": {}}


def execute_step(state: PlanExecuteState) -> dict:
    """Answer the sub-question at the current plan index and record the result."""
    idx = state["current_step"]
    sub_question = state["plan"][idx]

    # Give the model everything answered so far as lightweight context.
    prior_lines = [f"- {q}: {a}" for q, a in state["step_results"].items()]
    prior = "\n".join(prior_lines)

    response = llm.invoke([
        {
            "role": "system",
            "content": f"Answer this sub-question concisely.\nPrevious results:\n{prior}",
        },
        {"role": "user", "content": sub_question},
    ])

    # Copy-and-extend rather than mutating state in place.
    merged = dict(state["step_results"])
    merged[sub_question] = response.content
    print(f"  Step {idx}: {sub_question[:60]}... → done")
    return {"step_results": merged, "current_step": idx + 1}


def should_continue(state: PlanExecuteState) -> Literal["execute", "compose"]:
    """Route back to 'execute' while plan steps remain; otherwise to 'compose'."""
    if state["current_step"] < len(state["plan"]):
        return "execute"
    return "compose"


def compose_answer(state: PlanExecuteState) -> dict:
    """Synthesize all recorded sub-answers into one final answer."""
    sections = []
    for sub_q, sub_a in state["step_results"].items():
        sections.append(f"### {sub_q}\n{sub_a}")
    body = "\n\n".join(sections)

    response = llm.invoke([
        {"role": "system", "content": "Combine sub-answers into a coherent final answer."},
        {"role": "user", "content": f"Original: {state['question']}\n\nSub-answers:\n{body}"},
    ])
    return {"final_answer": response.content}


# Build graph
# Topology: START → plan → execute → (loop on execute until the plan is
# exhausted) → compose → END.
graph = StateGraph(PlanExecuteState)
graph.add_node("plan", plan_step)
graph.add_node("execute", execute_step)
graph.add_node("compose", compose_answer)

graph.add_edge(START, "plan")
graph.add_edge("plan", "execute")
# should_continue returns "execute" (more steps left) or "compose" (done);
# the mapping below resolves those labels to node names.
graph.add_conditional_edges("execute", should_continue, {"execute": "execute", "compose": "compose"})
graph.add_edge("compose", END)

plan_execute_app = graph.compile()
print("✅ Plan-and-Execute graph compiled")
# Demo run: every state key must be supplied up front for a TypedDict state.
result = plan_execute_app.invoke({
    "messages": [],
    "question": "Compare transformer and RNN architectures for sequence modeling, including training speed and memory use.",
    "plan": [], "current_step": 0, "step_results": {}, "final_answer": "",
})

print("\n" + "="*60)
print(result["final_answer"])

3. Sub-Question Decomposition

For quick decomposition without a full graph — just break and fan out.

import asyncio


async def decompose_and_retrieve(question: str) -> dict:
    """Decompose a complex question and answer the parts in parallel.

    Args:
        question: The complex user question to break down.

    Returns:
        Dict mapping each sub-question to its answer. Falls back to treating
        the whole question as a single sub-question when decomposition fails.
    """
    # 1. Decompose — use the async client; the original sync llm.invoke
    # would block the running event loop for the duration of the call.
    response = await llm.ainvoke([{
        "role": "system",
        "content": "Break this into 2-4 independent sub-questions. Return JSON list of strings.",
    }, {"role": "user", "content": question}])

    try:
        sub_questions = json.loads(response.content)
    except json.JSONDecodeError:
        sub_questions = [question]

    # Guard against a parseable-but-wrong payload (dict, empty list, ...).
    if not isinstance(sub_questions, list) or not sub_questions:
        sub_questions = [question]

    print(f"Sub-questions: {sub_questions}")

    # 2. Retrieve in parallel
    async def answer_sub(sq: str) -> tuple[str, str]:
        resp = await llm.ainvoke([
            {"role": "system", "content": "Answer this sub-question concisely."},
            {"role": "user", "content": sq},
        ])
        return sq, resp.content

    # gather preserves input order, so results align with sub_questions.
    results = await asyncio.gather(*[answer_sub(sq) for sq in sub_questions])
    return dict(results)


# Run it
# Top-level await works in notebooks; in a plain script use asyncio.run(...).
results = await decompose_and_retrieve(
    "What are the key differences between FAISS and Pinecone for vector search?"
)
for q, a in results.items():
    print(f"\n{q}\n💬 {a[:200]}...")

4. Parallel Retrieval

Fan-out multiple retrieval calls simultaneously for speed.

import time
import asyncio


async def mock_vector_search(query: str) -> str:
    """Stand-in for a vector-store lookup; simulates 0.5 s of latency."""
    latency_s = 0.5
    await asyncio.sleep(latency_s)
    return "[vector] Relevant chunk for: " + query


async def mock_web_search(query: str) -> str:
    """Stand-in for a live web search; simulates 0.8 s of latency."""
    latency_s = 0.8
    await asyncio.sleep(latency_s)
    return "[web] Recent results for: " + query


async def mock_sql_query(query: str) -> str:
    """Stand-in for a SQL lookup; simulates 0.3 s of latency."""
    latency_s = 0.3
    await asyncio.sleep(latency_s)
    return "[sql] Data: 1234 rows for: " + query


async def parallel_retrieve(query: str) -> dict:
    """Fan-out to vector, web, and SQL simultaneously.

    Runs the three mock retrievers concurrently with asyncio.gather, so the
    total latency is the slowest source (~0.8 s) rather than the sum (~1.6 s).

    Args:
        query: The search string passed to all three sources.

    Returns:
        Dict with one entry per source, keyed "vector", "web", "sql".
    """
    # perf_counter is a monotonic clock, so the elapsed measurement cannot
    # go wrong if the wall clock (time.time) is adjusted mid-run.
    start = time.perf_counter()
    vector_res, web_res, sql_res = await asyncio.gather(
        mock_vector_search(query),
        mock_web_search(query),
        mock_sql_query(query),
    )
    elapsed = time.perf_counter() - start
    print(f"⚡ All 3 sources retrieved in {elapsed:.2f}s (parallel)")
    return {"vector": vector_res, "web": web_res, "sql": sql_res}


# Top-level await (notebook); use asyncio.run(...) in a plain script.
results = await parallel_retrieve("transformer architecture performance")
for source, result in results.items():
    print(f"  {source}: {result}")

5. Iterative Retrieval-then-Reason (IRCoT)

Interleave chain-of-thought with retrieval — each reasoning step can trigger a new retrieval round.

def iterative_retrieval_reason(question: str, max_rounds: int = 3) -> str:
    """IRCoT-style loop: interleave chain-of-thought reasoning with retrieval.

    Each round the model either asks for more context ("SEARCH: <query>") or
    commits to an answer ("ANSWER: <text>"). Retrieval here is simulated with
    a fake result string.

    Args:
        question: The user question to answer.
        max_rounds: Upper bound on reason/retrieve rounds before giving up.

    Returns:
        The text after ANSWER:, or the last recorded reasoning step when no
        answer was produced within max_rounds.
    """
    reasoning_chain = []
    retrieved_so_far = []

    for i in range(max_rounds):
        # Reason step: show everything retrieved and thought so far.
        context = "\n".join(retrieved_so_far) or "(no context yet)"
        chain_so_far = "\n".join(reasoning_chain) or "(start)"

        reason_resp = llm.invoke([{
            "role": "system",
            "content": """Think step by step. Based on your reasoning, decide if you need more info.
If YES, write SEARCH: <query>.
If NO, write ANSWER: <your answer>.""",
        }, {
            "role": "user",
            "content": f"Question: {question}\nContext: {context}\nReasoning so far: {chain_so_far}",
        }])

        thought = reason_resp.content.strip()
        reasoning_chain.append(f"Round {i+1}: {thought[:200]}")
        print(f"\n🔄 Round {i+1}: {thought[:150]}...")

        # Models often prepend free-form reasoning before the marker, so look
        # for ANSWER: anywhere in the response (mirroring the SEARCH: check
        # below) instead of only at the very start; maxsplit=1 keeps any
        # later colons inside the answer text intact.
        if "ANSWER:" in thought:
            return thought.split("ANSWER:", 1)[1].strip()

        if "SEARCH:" in thought:
            query = thought.split("SEARCH:")[-1].strip()
            # Simulate retrieval
            fake_result = f"[Retrieved for '{query}']: Relevant information about {query}."
            retrieved_so_far.append(fake_result)
            print(f"  📥 {fake_result}")

    # Final synthesis if max rounds reached
    return reasoning_chain[-1] if reasoning_chain else "Unable to answer."


# Demo: run the IRCoT loop on a comparison question and preview the answer.
answer = iterative_retrieval_reason("How does the attention mechanism in transformers differ from LSTM gating?")
print(f"\n✅ Final answer: {answer[:300]}...")