!pip install -q sentence-transformers datasets langchain langchain-openai langchain-community faiss-cpu transformers

Fine-tuning RAG Components: Embeddings, Retrievers, and Generators
Domain-adaptive training for each RAG stage — from contrastive embedding fine-tuning to retrieval-aware LLM training with RAFT
Table of Contents
1. Setup & Installation
import os

# Provide your OpenAI API key before running the LLM-backed cells below.
# os.environ["OPENAI_API_KEY"] = "your-api-key-here"  # Uncomment and set

# 2. Generating Synthetic Training Data
Use an LLM to generate query-document pairs from your corpus for training.
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import re

# Sample domain documents (toy diabetes-care corpus)
domain_passages = [
    "Metformin is a first-line medication for type 2 diabetes. It works by decreasing glucose production in the liver and increasing insulin sensitivity.",
    "Common side effects of metformin include nausea, diarrhea, and abdominal pain. Lactic acidosis is a rare but serious side effect.",
    "Insulin therapy is used when oral medications fail to control blood glucose. Types include rapid-acting, short-acting, intermediate, and long-acting.",
    "HbA1c measures average blood glucose over 2-3 months. A target of less than 7% is recommended for most adults with diabetes.",
    "Diabetic retinopathy is a complication of diabetes that affects the eyes. Regular eye exams are recommended for early detection.",
]

# Generate synthetic queries with an LLM
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.7)
prompt = ChatPromptTemplate.from_template(
    "Given the following passage, generate {n} diverse questions "
    "that this passage would answer. Return only the questions, one per line.\n\n"
    "Passage: {passage}\n\nQuestions:"
)
chain = prompt | llm | StrOutputParser()


def _parse_questions(response):
    """Split an LLM response into clean question strings.

    Models often prefix list markers ("1.", "2)", "-", "*") even when
    asked for bare questions, so strip those along with surrounding
    whitespace, and drop empty lines.
    """
    questions = []
    for line in response.strip().split("\n"):
        text = re.sub(r"^\s*(?:\d+[.)]|[-*])\s*", "", line.strip())
        if text:
            questions.append(text)
    return questions


training_pairs = []
for passage in domain_passages:
    response = chain.invoke({"passage": passage, "n": 2})
    for q in _parse_questions(response):
        training_pairs.append({"query": q, "document": passage})

print(f"Generated {len(training_pairs)} training pairs:\n")
for pair in training_pairs[:4]:
    print(f"Q: {pair['query']}")
    print(f"D: {pair['document'][:80]}...\n")

# 3. Mining Hard Negatives
Find documents that are similar but not relevant — critical for training embedding models.
from sentence_transformers import SentenceTransformer
import numpy as np
# Load a base embedding model
base_model = SentenceTransformer("all-MiniLM-L6-v2")

# Embed all passages once; hoist the passage norms out of the per-query loop.
passage_embeddings = base_model.encode(domain_passages)
passage_norms = np.linalg.norm(passage_embeddings, axis=1)

# Batch-encode every query in one forward pass instead of one model
# call per query — much faster on any real corpus.
query_embeddings = base_model.encode([p["query"] for p in training_pairs])

# For each query, the hard negative is the most similar passage that is
# NOT the positive — similar-but-irrelevant docs are the informative ones.
triplets = []
for pair, query_emb in zip(training_pairs, query_embeddings):
    # Cosine similarity of this query against every passage
    similarities = np.dot(passage_embeddings, query_emb) / (
        passage_norms * np.linalg.norm(query_emb)
    )
    pos_idx = domain_passages.index(pair["document"])
    # Highest-ranked index that isn't the positive document
    neg_idx = next(i for i in np.argsort(similarities)[::-1] if i != pos_idx)
    triplets.append({
        "query": pair["query"],
        "positive": pair["document"],
        "negative": domain_passages[neg_idx],
    })

print(f"Created {len(triplets)} triplets (query, positive, hard negative)\n")
print(f"Example:")
print(f" Query: {triplets[0]['query']}")
print(f" Positive: {triplets[0]['positive'][:80]}...")
print(f" Negative: {triplets[0]['negative'][:80]}...")

# 4. Fine-Tuning Embeddings with Sentence Transformers
Contrastive learning with MultipleNegativesRankingLoss to adapt embeddings to your domain.
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
from sentence_transformers.trainer import SentenceTransformerTrainer
from datasets import Dataset
# Prepare training dataset.
# MultipleNegativesRankingLoss accepts (anchor, positive) pairs — using
# in-batch examples as negatives — or (anchor, positive, negative)
# triplets. Include the mined hard negatives so they are actually used.
train_dataset = Dataset.from_dict({
    "anchor": [t["query"] for t in triplets],
    "positive": [t["positive"] for t in triplets],
    "negative": [t["negative"] for t in triplets],
})

print(f"Training dataset: {len(train_dataset)} examples")
print(f"Columns: {train_dataset.column_names}")
print(f"\nExample:")
print(f" Anchor: {train_dataset[0]['anchor']}")
print(f" Positive: {train_dataset[0]['positive'][:80]}...")

# Fine-tuning setup (run on GPU for best results)
model = SentenceTransformer("all-MiniLM-L6-v2")

# Contrastive loss: pulls positive pairs together, pushes negatives apart
loss = losses.MultipleNegativesRankingLoss(model)

# Training arguments
args = SentenceTransformerTrainingArguments(
    output_dir="models/domain-embedding-finetuned",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    logging_steps=10,
)

# Note: For a real fine-tuning run, uncomment the trainer.train() call
# trainer = SentenceTransformerTrainer(
#     model=model,
#     args=args,
#     train_dataset=train_dataset,
#     loss=loss,
# )
# trainer.train()
# model.save_pretrained("models/domain-embedding-finetuned/final")

print("Fine-tuning setup complete (uncomment trainer.train() to run)")
print(f"Model: {model.get_sentence_embedding_dimension()} dimensions")
print(f"Loss: MultipleNegativesRankingLoss")
print(f"Epochs: {args.num_train_epochs}, LR: {args.learning_rate}")

# 5. Evaluating Embedding Fine-Tuning
Measure hit rate: what fraction of queries retrieve the correct document in top-k?
def evaluate_hit_rate(model, queries, positives, corpus, top_k=3):
    """Compute hit rate: fraction of queries whose gold doc is in the top-k.

    Args:
        model: embedding model exposing ``encode`` (accepts a str or a
            list of str and returns vector(s)).
        queries: list of query strings.
        positives: gold document for each query (parallel to ``queries``).
        corpus: candidate documents to retrieve from.
        top_k: how many top-ranked documents count as a hit.

    Returns:
        Hit rate in [0, 1]; 0.0 when ``queries`` is empty.
    """
    if not queries:
        return 0.0  # avoid ZeroDivisionError on an empty evaluation set
    corpus_embeddings = np.asarray(model.encode(corpus))
    corpus_norms = np.linalg.norm(corpus_embeddings, axis=1)
    # Batch-encode all queries in one call instead of one call per query.
    query_embeddings = np.asarray(model.encode(list(queries)))
    hits = 0
    for query_emb, positive in zip(query_embeddings, positives):
        # Cosine similarity of this query against the whole corpus
        similarities = np.dot(corpus_embeddings, query_emb) / (
            corpus_norms * np.linalg.norm(query_emb)
        )
        # Indices of the top-k most similar documents, best first
        top_k_indices = np.argsort(similarities)[-top_k:][::-1]
        top_k_docs = [corpus[i] for i in top_k_indices]
        if positive in top_k_docs:
            hits += 1
    return hits / len(queries)
# Score the un-finetuned base model on the synthetic query/passage pairs
queries = [t["query"] for t in triplets]
positives = [t["positive"] for t in triplets]

hit_rate = evaluate_hit_rate(base_model, queries, positives, domain_passages, top_k=2)
hit_rate_3 = evaluate_hit_rate(base_model, queries, positives, domain_passages, top_k=3)

print(f"Base model hit rate @2: {hit_rate:.2%}")
print(f"Base model hit rate @3: {hit_rate_3:.2%}")

# 6. Fine-Tuning a Cross-Encoder Reranker
Train a cross-encoder to better distinguish relevant from irrelevant documents in your domain.
from sentence_transformers import CrossEncoder
# Load base cross-encoder
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Build labeled rows for reranker fine-tuning: each mined triplet yields
# one relevant pair (label=1) and one hard-negative pair (label=0).
reranker_train_data = [
    {"query": t["query"], "passage": t[field], "label": label}
    for t in triplets
    for field, label in (("positive", 1.0), ("negative", 0.0))
]

positive_count = sum(d["label"] == 1.0 for d in reranker_train_data)
negative_count = sum(d["label"] == 0.0 for d in reranker_train_data)
print(f"Reranker training data: {len(reranker_train_data)} examples")
print(f" Positive pairs: {positive_count}")
print(f" Negative pairs: {negative_count}")

# Sanity-check the base reranker on a domain query before any training
test_query = "What are the side effects of metformin?"
test_pairs = [(test_query, passage) for passage in domain_passages]
scores = reranker.predict(test_pairs)

print(f"\nBase reranker scores for: '{test_query}'")
for score, passage in sorted(zip(scores, domain_passages), reverse=True):
    print(f" {score:.2f} — {passage[:80]}...")

# 7. RAFT: Retrieval Augmented Fine Tuning
Train the generator LLM to be a better “open-book exam taker” — identifying the right document, extracting evidence, and citing sources.
import json
import random
def prepare_raft_training_example(question, oracle_doc, distractor_docs, answer,
                                  p_oracle=0.8, rng=None):
    """Build one RAFT training example (oracle + distractor documents).

    With probability ``p_oracle`` the oracle document is shuffled into the
    context alongside the distractors; otherwise the context holds only
    distractors, which forces the model to rely on memorized knowledge.
    The target output always follows the RAFT citation format: a reasoning
    chain quoting the oracle, then a final ``##Answer:`` line.

    Args:
        question: the question the model must answer.
        oracle_doc: the gold document that contains the answer.
        distractor_docs: similar-but-irrelevant documents.
        answer: the gold answer string.
        p_oracle: probability of including the oracle in the context.
        rng: optional ``random.Random`` instance for reproducible sampling;
            defaults to the module-level ``random`` generator.

    Returns:
        dict with ``instruction`` (documents + question) and ``output``
        (RAFT-formatted answer with ##begin_quote##/##end_quote## citation).
    """
    rng = rng or random
    if rng.random() < p_oracle:
        # Oracle mixed in with the distractors, position randomized
        all_docs = [oracle_doc] + list(distractor_docs)
        rng.shuffle(all_docs)
    else:
        # Distractor-only context (forces memorization)
        all_docs = list(distractor_docs)
    # Format context
    context = "\n\n".join(f"Document {i+1}: {doc}" for i, doc in enumerate(all_docs))
    # RAFT answer format with citations (quote capped at 100 chars)
    raft_answer = (
        f"##Reason: Based on the documents, "
        f"##begin_quote## {oracle_doc[:100]} ##end_quote## "
        f"Therefore, {answer}\n"
        f"##Answer: {answer}"
    )
    return {
        "instruction": f"Answer the question using the provided documents.\n\n{context}\n\nQuestion: {question}",
        "output": raft_answer,
    }
# Build RAFT training examples from the synthetic query/document pairs
raft_examples = []
for pair in training_pairs[:5]:
    oracle = pair["document"]  # the positive doc is the oracle
    # Distractors: any other passages from the corpus (capped at 2)
    distractors = [p for p in domain_passages if p != oracle][:2]
    raft_examples.append(
        prepare_raft_training_example(
            question=pair["query"],
            oracle_doc=oracle,
            distractor_docs=distractors,
            answer=f"Based on the medical literature, the answer relates to {pair['query'].lower()}",
        )
    )

print(f"Created {len(raft_examples)} RAFT training examples\n")
print("Example instruction (truncated):")
print(raft_examples[0]["instruction"][:300])
print("\nExample output:")
print(raft_examples[0]["output"][:200])

# 8. Decision Guide: When to Fine-Tune
A practical decision tree for choosing which RAG components to fine-tune.
# Maps an observed RAG failure mode to the fine-tuning intervention
# (or cheaper fix) most likely to address it.
decision_guide = {
    "Retrieval recall < 80%": {
        "Have labeled pairs (>1k)?": {
            "Yes": "Fine-tune embeddings with contrastive loss",
            "No": "Generate synthetic queries, then fine-tune",
        }
    },
    "Precision low (wrong docs in top-k)": {
        "Action": "Fine-tune cross-encoder reranker on domain hard negatives",
    },
    "LLM ignores context / hallucinates": {
        "Action": "RAFT generator training (oracle + distractors + CoT)",
    },
    "General performance OK": {
        "Action": "Improve chunking strategy or prompt engineering first",
    },
}

print("RAG Fine-Tuning Decision Guide")
print("=" * 60)
for problem, solution in decision_guide.items():
    print(f"\nProblem: {problem}")
    # Flat entries carry a direct "Action"; otherwise descend one level.
    if isinstance(solution, dict) and "Action" in solution:
        print(f" → {solution['Action']}")
        continue
    for condition, action in solution.items():
        if not isinstance(action, dict):
            print(f" → {action}")
            continue
        # Nested question → per-answer recommendation
        for answer, recommendation in action.items():
            print(f" {condition} {answer} → {recommendation}")

print("\n" + "=" * 60)
print("Key takeaway: Start with prompt engineering and chunking.")
print("Fine-tune only when you've exhausted cheaper interventions.")