!pip install -q langchain-openai langgraph langchain-core

Planning and Query Decomposition for Complex Retrieval
Plan-and-execute, sub-question decomposition, and iterative retrieval-and-reasoning
Table of Contents
- Setup
- Plan-and-Execute Pattern
- Sub-Question Decomposition
- Parallel Retrieval
- Iterative Retrieval-then-Reason (IRCoT)
import os
# os.environ["OPENAI_API_KEY"] = "your-key"

2. Plan-and-Execute Pattern
The agent first creates a plan of sub-steps, then executes each step sequentially, and re-plans if needed.
from typing import TypedDict, Annotated, Literal
from langgraph.graph import StateGraph, END, START
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
import json
# Shared deterministic (temperature=0) chat model client reused by every cell below.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
class PlanExecuteState(TypedDict):
    """Shared state threaded through the plan → execute → compose graph."""
    # Conversation history; add_messages appends new messages instead of overwriting.
    messages: Annotated[list, add_messages]
    # The original user question being decomposed.
    question: str
    # Ordered sub-questions produced by plan_step.
    plan: list[str]
    # Index into `plan` of the next step to execute.
    current_step: int
    # Maps each completed sub-question to its answer text.
    step_results: dict[str, str]
    # Synthesized answer produced by compose_answer.
    final_answer: str
def plan_step(state: PlanExecuteState) -> dict:
    """Create an ordered list of 2-4 sub-questions for the user's question.

    Returns a state update containing the plan, a reset step counter, and an
    empty results dict. Falls back to a single-step plan (the original
    question) whenever the model output is not a valid JSON list of strings.
    """
    response = llm.invoke([{
        "role": "system",
        "content": """Break the user question into 2-4 simple sub-questions that can be answered independently.
Return a JSON list of strings. Example: ["sub-q1", "sub-q2", "sub-q3"]""",
    }, {"role": "user", "content": state["question"]}])
    # Models frequently wrap JSON in ``` fences; strip them before parsing.
    raw = response.content.strip()
    if raw.startswith("```"):
        raw = raw.strip("`")
        raw = raw.removeprefix("json").strip()
    try:
        plan = json.loads(raw)
    except json.JSONDecodeError:
        plan = [state["question"]]
    # Guard against output that is valid JSON but not a non-empty list of strings.
    if not isinstance(plan, list) or not plan or not all(isinstance(p, str) for p in plan):
        plan = [state["question"]]
    print(f"📋 Plan: {plan}")
    return {"plan": plan, "current_step": 0, "step_results": {}}
def execute_step(state: PlanExecuteState) -> dict:
    """Answer the current sub-question, with earlier sub-answers as context."""
    sub_question = state["plan"][state["current_step"]]
    prior = [f"- {q}: {a}" for q, a in state["step_results"].items()]
    prior_context = "\n".join(prior)
    response = llm.invoke([{
        "role": "system",
        "content": f"Answer this sub-question concisely.\nPrevious results:\n{prior_context}",
    }, {"role": "user", "content": sub_question}])
    # Copy-then-assign so the previous state dict is never mutated in place.
    results = dict(state["step_results"])
    results[sub_question] = response.content
    print(f" Step {state['current_step']}: {sub_question[:60]}... → done")
    return {"step_results": results, "current_step": state["current_step"] + 1}
def should_continue(state: PlanExecuteState) -> Literal["execute", "compose"]:
    """Route back to 'execute' while plan steps remain, otherwise to 'compose'."""
    if state["current_step"] >= len(state["plan"]):
        return "compose"
    return "execute"
def compose_answer(state: PlanExecuteState) -> dict:
    """Merge all sub-answers into one final response to the original question."""
    sections = []
    for sub_q, sub_a in state["step_results"].items():
        sections.append(f"### {sub_q}\n{sub_a}")
    joined = "\n\n".join(sections)
    response = llm.invoke([{
        "role": "system",
        "content": "Combine sub-answers into a coherent final answer.",
    }, {"role": "user", "content": f"Original: {state['question']}\n\nSub-answers:\n{joined}"}])
    return {"final_answer": response.content}
# Build graph: plan once, loop on execute until the plan is done, then compose.
graph = StateGraph(PlanExecuteState)
graph.add_node("plan", plan_step)
graph.add_node("execute", execute_step)
graph.add_node("compose", compose_answer)
graph.add_edge(START, "plan")
graph.add_edge("plan", "execute")
# should_continue routes back to "execute" while steps remain.
graph.add_conditional_edges("execute", should_continue, {"execute": "execute", "compose": "compose"})
graph.add_edge("compose", END)
plan_execute_app = graph.compile()
print("✅ Plan-and-Execute graph compiled")

# Demo run — every state key must be present in the initial input.
result = plan_execute_app.invoke({
    "messages": [],
    "question": "Compare transformer and RNN architectures for sequence modeling, including training speed and memory use.",
    "plan": [], "current_step": 0, "step_results": {}, "final_answer": "",
})
print("\n" + "="*60)
print(result["final_answer"])

3. Sub-Question Decomposition
For quick decomposition without a full graph — just break and fan out.
import asyncio
async def decompose_and_retrieve(question: str) -> dict:
    """Decompose a complex question and answer the sub-questions in parallel.

    Returns a dict mapping each sub-question to its answer. Falls back to
    treating the whole question as a single sub-question when the model does
    not return a valid JSON list of strings.
    """
    # 1. Decompose — await the async client so the event loop is not blocked
    #    by a synchronous llm.invoke() inside this coroutine.
    response = await llm.ainvoke([{
        "role": "system",
        "content": "Break this into 2-4 independent sub-questions. Return JSON list of strings.",
    }, {"role": "user", "content": question}])
    try:
        sub_questions = json.loads(response.content)
    except json.JSONDecodeError:
        sub_questions = [question]
    # Guard against output that is valid JSON but not a non-empty list of strings.
    if not isinstance(sub_questions, list) or not sub_questions or not all(isinstance(s, str) for s in sub_questions):
        sub_questions = [question]
    print(f"Sub-questions: {sub_questions}")

    # 2. Retrieve in parallel — one concurrent model call per sub-question.
    async def answer_sub(sq: str) -> tuple[str, str]:
        resp = await llm.ainvoke([
            {"role": "system", "content": "Answer this sub-question concisely."},
            {"role": "user", "content": sq},
        ])
        return sq, resp.content

    results = await asyncio.gather(*[answer_sub(sq) for sq in sub_questions])
    return dict(results)
# Run it
# NOTE: top-level `await` works in Jupyter/IPython, which already runs an
# event loop per cell; in a plain script wrap this in asyncio.run(...).
results = await decompose_and_retrieve(
    "What are the key differences between FAISS and Pinecone for vector search?"
)
for q, a in results.items():
    print(f"\n❓ {q}\n💬 {a[:200]}...")

4. Parallel Retrieval
Fan-out multiple retrieval calls simultaneously for speed.
import time
import asyncio
async def mock_vector_search(query: str) -> str:
    """Fake vector-store lookup (~0.5 s simulated latency)."""
    await asyncio.sleep(0.5)  # Simulate latency
    return f"[vector] Relevant chunk for: {query}"

async def mock_web_search(query: str) -> str:
    """Fake web search (~0.8 s simulated latency)."""
    await asyncio.sleep(0.8)
    return f"[web] Recent results for: {query}"

async def mock_sql_query(query: str) -> str:
    """Fake SQL lookup (~0.3 s simulated latency)."""
    await asyncio.sleep(0.3)
    return f"[sql] Data: 1234 rows for: {query}"

async def parallel_retrieve(query: str) -> dict:
    """Fan out to vector, web, and SQL sources at once.

    Total wall time is roughly the slowest single source rather than the
    sum of all three, since the coroutines run concurrently.
    """
    started = time.time()
    vector_hit, web_hit, sql_hit = await asyncio.gather(
        mock_vector_search(query),
        mock_web_search(query),
        mock_sql_query(query),
    )
    print(f"⚡ All 3 sources retrieved in {time.time() - started:.2f}s (parallel)")
    return {"vector": vector_hit, "web": web_hit, "sql": sql_hit}
# Top-level `await` works in notebooks; use asyncio.run(...) in plain scripts.
results = await parallel_retrieve("transformer architecture performance")
for source, result in results.items():
    print(f" {source}: {result}")

5. Iterative Retrieval-then-Reason (IRCoT)
Interleave chain-of-thought with retrieval — each reasoning step can trigger a new retrieval round.
def iterative_retrieval_reason(question: str, max_rounds: int = 3) -> str:
    """IRCoT-style loop: interleave chain-of-thought reasoning with retrieval.

    Each round the model either requests more context ("SEARCH: <query>") or
    commits to a final answer ("ANSWER: <answer>"). Retrieval is simulated
    here. Returns the extracted answer, or the last recorded reasoning step
    if `max_rounds` is exhausted without an explicit answer.
    """
    reasoning_chain: list[str] = []
    retrieved_so_far: list[str] = []
    for round_no in range(1, max_rounds + 1):
        # Reason step: show everything gathered so far.
        context = "\n".join(retrieved_so_far) or "(no context yet)"
        chain_so_far = "\n".join(reasoning_chain) or "(start)"
        reason_resp = llm.invoke([{
            "role": "system",
            "content": """Think step by step. Based on your reasoning, decide if you need more info.
If YES, write SEARCH: <query>.
If NO, write ANSWER: <your answer>.""",
        }, {
            "role": "user",
            "content": f"Question: {question}\nContext: {context}\nReasoning so far: {chain_so_far}",
        }])
        thought = reason_resp.content.strip()
        reasoning_chain.append(f"Round {round_no}: {thought[:200]}")
        print(f"\n🔄 Round {round_no}: {thought[:150]}...")
        # Detect "ANSWER:" anywhere in the output (consistent with the
        # "SEARCH:" check below) — models often emit reasoning text before
        # the directive, which a startswith() check would miss.
        if "ANSWER:" in thought:
            return thought.split("ANSWER:", 1)[-1].strip()
        if "SEARCH:" in thought:
            query = thought.split("SEARCH:")[-1].strip()
            # Simulate retrieval
            fake_result = f"[Retrieved for '{query}']: Relevant information about {query}."
            retrieved_so_far.append(fake_result)
            print(f" 📥 {fake_result}")
    # Final synthesis if max rounds reached without an ANSWER directive.
    return reasoning_chain[-1] if reasoning_chain else "Unable to answer."
# Demo: run the IRCoT loop end-to-end (answer truncated for display).
answer = iterative_retrieval_reason("How does the attention mechanism in transformers differ from LSTM gating?")
print(f"\n✅ Final answer: {answer[:300]}...")