Fine-tuning RAG Components: Embeddings, Retrievers, and Generators

Domain-adaptive training for each RAG stage — from contrastive embedding fine-tuning to retrieval-aware LLM training with RAFT

Open In Colab

📖 Read the full article


Table of Contents

  1. Setup & Installation
  2. Generating Synthetic Training Data
  3. Mining Hard Negatives
  4. Fine-Tuning Embeddings with Sentence Transformers
  5. Evaluating Embedding Fine-Tuning
  6. Fine-Tuning a Cross-Encoder Reranker
  7. RAFT: Retrieval Augmented Fine Tuning
  8. Decision Guide: When to Fine-Tune

1. Setup & Installation

!pip install -q sentence-transformers datasets langchain langchain-openai langchain-community faiss-cpu transformers
import os
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"  # Uncomment and set

2. Generating Synthetic Training Data

Use an LLM to generate query-document pairs from your corpus for training.

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

# Sample domain documents
domain_passages = [
    "Metformin is a first-line medication for type 2 diabetes. It works by decreasing glucose production in the liver and increasing insulin sensitivity.",
    "Common side effects of metformin include nausea, diarrhea, and abdominal pain. Lactic acidosis is a rare but serious side effect.",
    "Insulin therapy is used when oral medications fail to control blood glucose. Types include rapid-acting, short-acting, intermediate, and long-acting.",
    "HbA1c measures average blood glucose over 2-3 months. A target of less than 7% is recommended for most adults with diabetes.",
    "Diabetic retinopathy is a complication of diabetes that affects the eyes. Regular eye exams are recommended for early detection.",
]

# Generate synthetic queries
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.7)
prompt = ChatPromptTemplate.from_template(
    "Given the following passage, generate {n} diverse questions "
    "that this passage would answer. Return only the questions, one per line.\n\n"
    "Passage: {passage}\n\nQuestions:"
)
chain = prompt | llm | StrOutputParser()

training_pairs = []
for passage in domain_passages:
    response = chain.invoke({"passage": passage, "n": 2})
    questions = [q.strip() for q in response.strip().split("\n") if q.strip()]
    for q in questions:
        training_pairs.append({"query": q, "document": passage})

print(f"Generated {len(training_pairs)} training pairs:\n")
for pair in training_pairs[:4]:
    print(f"Q: {pair['query']}")
    print(f"D: {pair['document'][:80]}...\n")

3. Mining Hard Negatives

Find documents that are similar but not relevant — critical for training embedding models.

from sentence_transformers import SentenceTransformer
import numpy as np

# Load a base embedding model
base_model = SentenceTransformer("all-MiniLM-L6-v2")

# Embed all passages
passage_embeddings = base_model.encode(domain_passages)

# For each query, find hard negatives
triplets = []
for pair in training_pairs:
    query_emb = base_model.encode(pair["query"])
    
    # Compute similarity to all passages
    similarities = np.dot(passage_embeddings, query_emb) / (
        np.linalg.norm(passage_embeddings, axis=1) * np.linalg.norm(query_emb)
    )
    
    # Find the positive document index
    pos_idx = domain_passages.index(pair["document"])
    
    # Hard negative: most similar passage that is NOT the positive
    sorted_indices = np.argsort(similarities)[::-1]
    for idx in sorted_indices:
        if idx != pos_idx:
            hard_neg = domain_passages[idx]
            triplets.append({
                "query": pair["query"],
                "positive": pair["document"],
                "negative": hard_neg,
            })
            break

print(f"Created {len(triplets)} triplets (query, positive, hard negative)\n")
print(f"Example:")
print(f"  Query: {triplets[0]['query']}")
print(f"  Positive: {triplets[0]['positive'][:80]}...")
print(f"  Negative: {triplets[0]['negative'][:80]}...")

4. Fine-Tuning Embeddings with Sentence Transformers

Use contrastive learning with MultipleNegativesRankingLoss to adapt embeddings to your domain.

from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.trainer import SentenceTransformerTrainer
from datasets import Dataset

# Prepare training dataset; include the mined hard negatives so the loss
# uses them as explicit negatives on top of the in-batch negatives
train_dataset = Dataset.from_dict({
    "anchor": [t["query"] for t in triplets],
    "positive": [t["positive"] for t in triplets],
    "negative": [t["negative"] for t in triplets],
})

print(f"Training dataset: {len(train_dataset)} examples")
print(f"Columns: {train_dataset.column_names}")
print(f"\nExample:")
print(f"  Anchor: {train_dataset[0]['anchor']}")
print(f"  Positive: {train_dataset[0]['positive'][:80]}...")
# Fine-tuning setup (run on GPU for best results)
model = SentenceTransformer("all-MiniLM-L6-v2")

# Contrastive loss: pulls positive pairs together, pushes negatives apart
loss = losses.MultipleNegativesRankingLoss(model)

# Training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="models/domain-embedding-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    logging_steps=10,
)

# Note: For a real fine-tuning run, uncomment the trainer.train() call
# trainer = SentenceTransformerTrainer(
#     model=model,
#     args=args,
#     train_dataset=train_dataset,
#     loss=loss,
# )
# trainer.train()
# model.save_pretrained("models/domain-embedding-finetuned/final")

print("Fine-tuning setup complete (uncomment trainer.train() to run)")
print(f"Model: {model.get_sentence_embedding_dimension()} dimensions")
print(f"Loss: MultipleNegativesRankingLoss")
print(f"Epochs: {args.num_train_epochs}, LR: {args.learning_rate}")

5. Evaluating Embedding Fine-Tuning

Measure hit rate: what fraction of queries retrieve the correct document in top-k?

def evaluate_hit_rate(model, queries, positives, corpus, top_k=3):
    """Compute hit rate: fraction of queries where correct doc is in top-k."""
    corpus_embeddings = model.encode(corpus)
    
    hits = 0
    for query, positive in zip(queries, positives):
        query_emb = model.encode(query)
        
        # Compute similarities
        similarities = np.dot(corpus_embeddings, query_emb) / (
            np.linalg.norm(corpus_embeddings, axis=1) * np.linalg.norm(query_emb)
        )
        
        # Get top-k indices
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]
        top_k_docs = [corpus[i] for i in top_k_indices]
        
        if positive in top_k_docs:
            hits += 1
    
    return hits / len(queries)


# Evaluate the base model (note: these are the training queries, so this
# measures in-domain fit rather than generalization; hold out a split in practice)
queries = [t["query"] for t in triplets]
positives = [t["positive"] for t in triplets]

hit_rate = evaluate_hit_rate(base_model, queries, positives, domain_passages, top_k=2)
print(f"Base model hit rate @2: {hit_rate:.2%}")

hit_rate_3 = evaluate_hit_rate(base_model, queries, positives, domain_passages, top_k=3)
print(f"Base model hit rate @3: {hit_rate_3:.2%}")

6. Fine-Tuning a Cross-Encoder Reranker

Train a cross-encoder to better distinguish relevant from irrelevant documents in your domain.

from sentence_transformers import CrossEncoder

# Load base cross-encoder
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Prepare labeled data for reranker fine-tuning
reranker_train_data = []
for t in triplets:
    # Positive pair (label=1)
    reranker_train_data.append({
        "query": t["query"],
        "passage": t["positive"],
        "label": 1.0,
    })
    # Negative pair (label=0)
    reranker_train_data.append({
        "query": t["query"],
        "passage": t["negative"],
        "label": 0.0,
    })

print(f"Reranker training data: {len(reranker_train_data)} examples")
print(f"  Positive pairs: {sum(1 for d in reranker_train_data if d['label'] == 1.0)}")
print(f"  Negative pairs: {sum(1 for d in reranker_train_data if d['label'] == 0.0)}")

# Test base reranker on domain data
test_query = "What are the side effects of metformin?"
test_pairs = [(test_query, p) for p in domain_passages]
scores = reranker.predict(test_pairs)

print(f"\nBase reranker scores for: '{test_query}'")
for score, passage in sorted(zip(scores, domain_passages), reverse=True):
    print(f"  {score:.2f}{passage[:80]}...")

7. RAFT: Retrieval Augmented Fine Tuning

Train the generator LLM to be a better “open-book exam taker” — identifying the right document, extracting evidence, and citing sources.

import json
import random


def prepare_raft_training_example(question, oracle_doc, distractor_docs, answer, p_oracle=0.8):
    """Prepare a RAFT training example with oracle + distractors."""
    # With probability p, include the oracle document
    if random.random() < p_oracle:
        # Include oracle + distractors
        all_docs = [oracle_doc] + distractor_docs
        random.shuffle(all_docs)
    else:
        # Only distractors: trains the model to fall back on memorized
        # domain knowledge when the oracle document is absent
        all_docs = distractor_docs
    
    # Format context
    context = "\n\n".join(f"Document {i+1}: {doc}" for i, doc in enumerate(all_docs))
    
    # RAFT answer format with citations
    raft_answer = (
        f"##Reason: Based on the documents, "
        f"##begin_quote## {oracle_doc[:100]} ##end_quote## "
        f"Therefore, {answer}\n"
        f"##Answer: {answer}"
    )
    
    return {
        "instruction": f"Answer the question using the provided documents.\n\n{context}\n\nQuestion: {question}",
        "output": raft_answer,
    }


# Generate RAFT training examples
raft_examples = []
for pair in training_pairs[:5]:
    # Oracle = the positive document
    oracle = pair["document"]
    # Distractors = other passages
    distractors = [p for p in domain_passages if p != oracle][:2]
    
    # Placeholder answer for illustration; in a real pipeline, generate the
    # chain-of-thought answer with an LLM from the oracle document
    example = prepare_raft_training_example(
        question=pair["query"],
        oracle_doc=oracle,
        distractor_docs=distractors,
        answer=f"Based on the medical literature, the answer relates to {pair['query'].lower()}",
    )
    raft_examples.append(example)

print(f"Created {len(raft_examples)} RAFT training examples\n")
print("Example instruction (truncated):")
print(raft_examples[0]["instruction"][:300])
print("\nExample output:")
print(raft_examples[0]["output"][:200])

8. Decision Guide: When to Fine-Tune

A practical decision tree for choosing which RAG components to fine-tune.

decision_guide = {
    "Retrieval recall < 80%": {
        "Have labeled pairs (>1k)?": {
            "Yes": "Fine-tune embeddings with contrastive loss",
            "No": "Generate synthetic queries, then fine-tune",
        }
    },
    "Precision low (wrong docs in top-k)": {
        "Action": "Fine-tune cross-encoder reranker on domain hard negatives",
    },
    "LLM ignores context / hallucinates": {
        "Action": "RAFT generator training (oracle + distractors + CoT)",
    },
    "General performance OK": {
        "Action": "Improve chunking strategy or prompt engineering first",
    },
}

print("RAG Fine-Tuning Decision Guide")
print("=" * 60)
for problem, solution in decision_guide.items():
    print(f"\nProblem: {problem}")
    if isinstance(solution, dict) and "Action" in solution:
        print(f"  → {solution['Action']}")
    else:
        for condition, action in solution.items():
            if isinstance(action, dict):
                for answer, recommendation in action.items():
                    print(f"  {condition} {answer}{recommendation}")
            else:
                print(f"  → {action}")

print("\n" + "=" * 60)
print("Key takeaway: Start with prompt engineering and chunking.")
print("Fine-tune only when you've exhausted cheaper interventions.")