Embedding Models and Reranking for RAG

Selecting, fine-tuning, and combining embedding models with cross-encoder rerankers for production retrieval pipelines

Table of Contents

  1. Setup & Installation
  2. Bi-Encoders vs Cross-Encoders
  3. Embedding Models in Practice
  4. Building a Vector Store
  5. Cross-Encoder Reranking
  6. Retrieve-and-Rerank Pipeline
  7. Hybrid Search: Dense + Sparse
  8. Comparing Retrieval Strategies

1. Setup & Installation

!pip install -q langchain langchain-openai langchain-community sentence-transformers faiss-cpu rank-bm25 numpy
import os
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"  # Uncomment and set

2. Bi-Encoders vs Cross-Encoders

  • Bi-Encoder: Encodes query and document independently into vectors. Fast, scalable.
  • Cross-Encoder: Encodes query+document jointly. Slower but more accurate.
  • Standard pattern: Bi-encoder retrieves top-k, cross-encoder reranks them.
import numpy as np

# Demonstrate bi-encoder: encode independently, compare with cosine similarity
from langchain_openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="text-embedding-3-small")

query = "What is the attention mechanism?"
documents = [
    "The attention mechanism allows the model to focus on relevant parts of the input.",
    "Berlin is the capital of Germany.",
    "Transformers use self-attention to capture long-range dependencies.",
]

query_vec = embeddings.embed_query(query)
doc_vecs = embeddings.embed_documents(documents)

# Cosine similarity (OpenAI embeddings are unit-normalized, so the dot product alone gives the same ranking)
def cosine_sim(a, b):
    a, b = np.array(a), np.array(b)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

print("Bi-Encoder Similarity Scores:")
for i, doc in enumerate(documents):
    score = cosine_sim(query_vec, doc_vecs[i])
    print(f"  {score:.4f}{doc[:60]}...")
# Cross-encoder: encode query+document jointly
from sentence_transformers import CrossEncoder

cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

pairs = [(query, doc) for doc in documents]
scores = cross_encoder.predict(pairs)

print("Cross-Encoder Relevance Scores:")
for i, doc in enumerate(documents):
    print(f"  {scores[i]:.4f}{doc[:60]}...")

3. Embedding Models in Practice

Using OpenAI and local embedding models with LangChain.

# OpenAI Embeddings
query_embedding = embeddings.embed_query("What is RLHF?")
print(f"OpenAI text-embedding-3-small dimensions: {len(query_embedding)}")
print(f"First 5 values: {query_embedding[:5]}")

# Batch embedding
batch_embeddings = embeddings.embed_documents(documents)
print(f"\nBatch embedded {len(batch_embeddings)} documents")
print(f"Each vector: {len(batch_embeddings[0])} dimensions")

4. Building a Vector Store

Create a FAISS index from sample documents for retrieval.

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

# Sample corpus
corpus = [
    Document(page_content="RLHF uses a reward model trained on human preferences to align language models via reinforcement learning with PPO."),
    Document(page_content="DPO eliminates the need for a reward model by directly optimizing the policy on preference data using a classification loss."),
    Document(page_content="The attention mechanism in transformers computes weighted sums of value vectors, using query-key dot products as weights."),
    Document(page_content="LoRA reduces fine-tuning costs by adding low-rank decomposition matrices to transformer layers, training only a fraction of parameters."),
    Document(page_content="Quantization compresses model weights from 16-bit to 4-bit integers, reducing memory by 4x with minimal quality loss."),
    Document(page_content="vLLM achieves high throughput for LLM serving through PagedAttention, which manages KV cache memory like virtual memory pages."),
    Document(page_content="RAG retrieves relevant documents from a knowledge base and provides them as context to the LLM for grounded generation."),
    Document(page_content="Chain-of-thought prompting improves LLM reasoning by instructing the model to show its step-by-step reasoning process."),
    Document(page_content="Constitutional AI trains models to be helpful, harmless, and honest by using a set of principles for self-improvement."),
    Document(page_content="Mixture of Experts (MoE) architectures route different tokens to different expert networks, allowing larger model capacity."),
]

vectorstore = FAISS.from_documents(corpus, embeddings)
print(f"Built FAISS index with {len(corpus)} documents")

# Similarity search (with default settings FAISS returns L2 distance, so lower = more similar)
results = vectorstore.similarity_search_with_score("How does RLHF work?", k=5)
print("\nTop-5 results for 'How does RLHF work?':")
for doc, score in results:
    print(f"  {score:.4f}  {doc.page_content[:80]}...")

5. Cross-Encoder Reranking

Use a cross-encoder to rerank the top-k results from the bi-encoder retrieval.

query = "How does RLHF work?"

# Retrieve top-10 with bi-encoder
bi_encoder_results = vectorstore.similarity_search(query, k=10)

# Rerank with cross-encoder
passages = [doc.page_content for doc in bi_encoder_results]
pairs = [(query, p) for p in passages]
ce_scores = cross_encoder.predict(pairs)

# Sort by cross-encoder score
reranked = sorted(zip(ce_scores, passages), key=lambda x: x[0], reverse=True)

print("Before reranking (bi-encoder order):")
for i, doc in enumerate(bi_encoder_results[:5]):
    print(f"  {i+1}. {doc.page_content[:80]}...")

print("\nAfter reranking (cross-encoder order):")
for i, (score, text) in enumerate(reranked[:5]):
    print(f"  {i+1}. [{score:.2f}] {text[:80]}...")

6. Retrieve-and-Rerank Pipeline

Build a complete pipeline that retrieves a wide candidate set with the fast bi-encoder, then reranks it down to a few top passages with the more accurate cross-encoder.

def retrieve_and_rerank(query: str, vectorstore, cross_encoder, retrieve_k=10, rerank_k=3):
    """Retrieve top-k with bi-encoder, rerank with cross-encoder, return top-n."""
    # Stage 1: Retrieve candidates
    candidates = vectorstore.similarity_search(query, k=retrieve_k)
    
    # Stage 2: Rerank
    passages = [doc.page_content for doc in candidates]
    pairs = [(query, p) for p in passages]
    scores = cross_encoder.predict(pairs)
    
    # Sort and return top-n
    scored = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
    return [(score, doc) for score, doc in scored[:rerank_k]]


# Test the pipeline
queries = [
    "How does RLHF work?",
    "What is model quantization?",
    "How does RAG improve LLM answers?",
]

for q in queries:
    results = retrieve_and_rerank(q, vectorstore, cross_encoder)
    print(f"\nQuery: {q}")
    for score, doc in results:
        print(f"  [{score:.2f}] {doc.page_content[:80]}...")

7. Hybrid Search: Dense + Sparse

Combine dense (semantic) and sparse (BM25 keyword) retrieval with Reciprocal Rank Fusion.

from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever

# Sparse retriever (BM25)
bm25_retriever = BM25Retriever.from_documents(corpus, k=5)

# Dense retriever (FAISS)
dense_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

# Hybrid: combine with Reciprocal Rank Fusion
hybrid_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, dense_retriever],
    weights=[0.3, 0.7],  # weights follow retriever order: [BM25, dense], favoring dense
)

# Test: keyword-heavy query (BM25 excels)
keyword_query = "PPO reinforcement learning reward model"
hybrid_results = hybrid_retriever.invoke(keyword_query)
print(f"Hybrid results for '{keyword_query}':")
for i, doc in enumerate(hybrid_results[:5]):
    print(f"  {i+1}. {doc.page_content[:80]}...")

# Test: semantic query (dense excels)
semantic_query = "How to make language models follow instructions better?"
hybrid_results = hybrid_retriever.invoke(semantic_query)
print(f"\nHybrid results for '{semantic_query}':")
for i, doc in enumerate(hybrid_results[:5]):
    print(f"  {i+1}. {doc.page_content[:80]}...")

8. Comparing Retrieval Strategies

Run the same query through all strategies and compare.

test_query = "How does RLHF work?"

print(f"Query: '{test_query}'\n")

# Dense only
dense_results = dense_retriever.invoke(test_query)
print("Dense (Bi-Encoder) Only:")
for i, doc in enumerate(dense_results[:3]):
    print(f"  {i+1}. {doc.page_content[:80]}...")

# Sparse only
sparse_results = bm25_retriever.invoke(test_query)
print("\nSparse (BM25) Only:")
for i, doc in enumerate(sparse_results[:3]):
    print(f"  {i+1}. {doc.page_content[:80]}...")

# Hybrid
hybrid_results = hybrid_retriever.invoke(test_query)
print("\nHybrid (Dense + Sparse):")
for i, doc in enumerate(hybrid_results[:3]):
    print(f"  {i+1}. {doc.page_content[:80]}...")

# Hybrid + Rerank (reuse the hybrid results retrieved above)
hybrid_docs = hybrid_results
passages = [doc.page_content for doc in hybrid_docs]
pairs = [(test_query, p) for p in passages]
ce_scores = cross_encoder.predict(pairs)
reranked = sorted(zip(ce_scores, hybrid_docs), key=lambda x: x[0], reverse=True)
print("\nHybrid + Rerank:")
for i, (score, doc) in enumerate(reranked[:3]):
    print(f"  {i+1}. [{score:.2f}] {doc.page_content[:80]}...")