Training LLMs for Reasoning

A deep dive into techniques for building language models that can think step by step

Table of Contents

  1. Setup & Installation
  2. The Reasoning Training Landscape
  3. Chain-of-Thought SFT (Distillation)
  4. GRPO for Math Reasoning
  5. The DeepSeek-R1 Recipe
  6. Self-Improvement & Rejection Sampling
  7. Reward Design for Reasoning
  8. Building a Dual-Mode Model

1. Setup & Installation

!pip install -q trl transformers datasets peft accelerate bitsandbytes

2. The Reasoning Training Landscape

There are three main paths to training LLMs for reasoning:

| Path | Method | Data Source | Pros | Cons |
|------|--------|-------------|------|------|
| Distillation | SFT on CoT traces | Teacher model (e.g., DeepSeek-R1) | Simple, reliable | Bounded by teacher quality |
| RL from Scratch | GRPO / PPO with reward | Self-generated solutions | Can exceed teacher | Unstable, prone to reward hacking |
| Hybrid | Distillation + RL | Both | Best of both worlds | More complex pipeline |

The choice depends on your budget, base model quality, and target domain.

3. Chain-of-Thought SFT (Distillation)

Fine-tune a smaller model on chain-of-thought reasoning traces from a stronger teacher model. This is the simplest and most reliable approach.

from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig
from datasets import load_dataset

# Load model and tokenizer
model_name = "Qwen/Qwen2.5-1.5B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load chain-of-thought dataset (distilled from DeepSeek-R1)
dataset = load_dataset("bespokelabs/Bespoke-Stratos-17k", split="train")

# LoRA configuration for efficient fine-tuning
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM"
)

# SFT training configuration
training_config = SFTConfig(
    output_dir="./cot-sft-model",
    max_seq_length=4096,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=2e-5,
    logging_steps=10,
    save_strategy="epoch",
)

# Create trainer and train
trainer = SFTTrainer(
    model=model,
    args=training_config,
    train_dataset=dataset,
    peft_config=peft_config,
)

# trainer.train()

4. GRPO for Math Reasoning

Group Relative Policy Optimization (GRPO) trains the model with reinforcement learning against verifiable rewards. For each prompt it samples a group of completions, scores them, and computes each completion's advantage relative to the group average, so no separate value model is needed. For math, the reward can simply check whether the final answer matches the ground truth.
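Concretely, the "group relative" part turns raw rewards into within-group advantages. The helper below is a minimal sketch of that normalization (a common GRPO formulation), not TRL's internal implementation:

```python
import statistics


def group_relative_advantages(rewards):
    """Normalize rewards within one prompt's group of sampled completions.

    Each completion's advantage is its reward minus the group mean, scaled
    by the group standard deviation.
    """
    mean = statistics.mean(rewards)
    std = statistics.pstdev(rewards)
    if std == 0:
        # All completions scored the same: no learning signal for this group
        return [0.0 for _ in rewards]
    return [(r - mean) / std for r in rewards]


# 4 sampled completions for one prompt: two correct, two wrong
print(group_relative_advantages([1.0, 0.0, 1.0, 0.0]))  # [1.0, -1.0, 1.0, -1.0]
```

Correct completions are pushed up and incorrect ones pushed down, relative only to siblings sampled from the same prompt.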

import re
from trl import GRPOTrainer, GRPOConfig


def accuracy_reward(completions, **kwargs):
    """Reward based on whether the extracted answer matches the ground truth."""
    solutions = kwargs.get("answer", [])
    rewards = []
    for completion, solution in zip(completions, solutions):
        content = completion[0]["content"]
        # Try to extract answer from <answer>...</answer> tags
        match = re.search(r"<answer>(.*?)</answer>", content)
        if not match:
            # Try to extract from \boxed{...}
            match = re.search(r"\\boxed\{(.*?)\}", content)
        if match and match.group(1).strip() == solution.strip():
            rewards.append(1.0)
        else:
            rewards.append(0.0)
    return rewards


def format_reward(completions, **kwargs):
    """Reward for following the expected format with <think> and <answer> tags."""
    rewards = []
    for completion in completions:
        content = completion[0]["content"]
        has_think = "<think>" in content and "</think>" in content
        has_answer = "<answer>" in content and "</answer>" in content
        if has_think and has_answer:
            rewards.append(1.0)
        elif has_answer:
            rewards.append(0.5)
        else:
            rewards.append(0.0)
    return rewards


# GRPO configuration
grpo_config = GRPOConfig(
    output_dir="./grpo-math-model",
    num_generations=8,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=1,
    learning_rate=5e-6,
    logging_steps=10,
    save_strategy="steps",
    save_steps=100,
)

# Load math dataset
math_dataset = load_dataset("trl-lib/DeepMath-103K", split="train")

# Create GRPO trainer
grpo_trainer = GRPOTrainer(
    model=model_name,
    args=grpo_config,
    train_dataset=math_dataset,
    reward_funcs=[accuracy_reward, format_reward],
)

# grpo_trainer.train()

5. The DeepSeek-R1 Recipe

DeepSeek-R1 used a sophisticated 6-stage pipeline to achieve strong reasoning performance:

| Stage | Name | Description |
|-------|------|-------------|
| 1 | R1-Zero | Pure RL (GRPO) on the base model; emergent CoT, but messy formatting |
| 2 | Cold Start SFT | SFT on thousands of curated long-CoT examples to bootstrap reasoning |
| 3 | RL Phase 2 | GRPO with rule-based rewards on math, code, and reasoning tasks |
| 4 | Rejection Sampling | Generate many solutions; keep only verified-correct ones |
| 5 | Final SFT | SFT on rejection-sampled data plus general instruction data |
| 6 | Distillation | Distill R1 reasoning into smaller models (1.5B–70B) |

Key insight: combining RL with distillation yields strong reasoning models across a wide range of scales (1.5B–70B), since the expensive RL stages only need to run once on the large model.
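The stages above can be written down as a declarative stage list, which makes the ordering and method per stage explicit. This is an illustrative sketch only; the stage configs and the `run_pipeline` dispatcher are hypothetical placeholders, not DeepSeek's actual tooling.

```python
# Hypothetical sketch of the R1-style recipe as an ordered stage list.
# (stage name, training method, config) — all names are illustrative.
PIPELINE = [
    ("r1_zero", "grpo", {"data": "base_prompts", "reward": "rule_based"}),
    ("cold_start_sft", "sft", {"data": "curated_long_cot"}),
    ("rl_phase_2", "grpo", {"data": "math_code_reasoning", "reward": "rule_based"}),
    ("rejection_sampling", "filter", {"n_samples": 16, "keep": "verified_correct"}),
    ("final_sft", "sft", {"data": "filtered_plus_general_instructions"}),
    ("distillation", "sft", {"data": "r1_traces", "students": ["1.5B", "7B", "70B"]}),
]


def run_pipeline(pipeline):
    """Walk the stages in order, dispatching on the training method."""
    for name, method, config in pipeline:
        print(f"Stage {name}: {method} with {config}")


run_pipeline(PIPELINE)
```

Each stage consumes artifacts from the previous one (e.g., rejection sampling filters outputs of the RL-phase-2 model), so the order matters.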

6. Self-Improvement & Rejection Sampling

Generate multiple candidate solutions and filter to keep only the ones with verified-correct answers. This creates high-quality training data from the model itself.

import re
from transformers import pipeline


def generate_and_filter(model, tokenizer, problems, n_samples=16):
    """Generate multiple solutions per problem and keep only correct ones."""
    generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
    filtered_dataset = []

    for problem in problems:
        question = problem["question"]
        ground_truth = problem["answer"]

        # Generate n_samples candidate solutions
        outputs = generator(
            question,
            max_new_tokens=2048,
            num_return_sequences=n_samples,
            temperature=0.7,
            do_sample=True,
        )

        for output in outputs:
            text = output["generated_text"]
            # Extract answer from \boxed{...}
            match = re.search(r"\\boxed\{(.*?)\}", text)
            if match and match.group(1).strip() == ground_truth.strip():
                filtered_dataset.append({
                    "question": question,
                    "response": text,
                    "answer": ground_truth,
                })
                break  # Keep one correct solution per problem

    return filtered_dataset


# Example usage (commented out - requires GPU and model):
# problems = [{"question": "What is 2+2?", "answer": "4"}]
# filtered = generate_and_filter(model, tokenizer, problems, n_samples=16)
# print(f"Filtered {len(filtered)} correct solutions from {len(problems)} problems")

7. Reward Design for Reasoning

Effective reward functions are critical for RL-based reasoning training. Combine multiple reward signals for robust learning.

import re
import subprocess


def outcome_reward(completion, ground_truth):
    """Binary reward: 1.0 if the final answer matches, 0.0 otherwise."""
    match = re.search(r"\\boxed\{(.*?)\}", completion)
    if match and match.group(1).strip() == ground_truth.strip():
        return 1.0
    return 0.0


def code_execution_reward(completion):
    """Reward based on whether generated code executes successfully."""
    code_match = re.search(r"```python\n(.*?)```", completion, re.DOTALL)
    if not code_match:
        return 0.0
    code = code_match.group(1)
    try:
        result = subprocess.run(
            ["python", "-c", code],
            capture_output=True, text=True, timeout=10
        )
        return 1.0 if result.returncode == 0 else 0.0
    except Exception:  # TimeoutExpired, missing interpreter, etc.
        return 0.0


# Multi-reward composition
reward_weights = {
    "accuracy": 0.5,
    "format": 0.2,
    "code_execution": 0.3,
}


def combined_reward(completion, ground_truth):
    """Weighted combination of multiple reward signals."""
    acc = outcome_reward(completion, ground_truth)
    fmt = 1.0 if "<think>" in completion and "</think>" in completion else 0.0
    code = code_execution_reward(completion)

    total = (
        reward_weights["accuracy"] * acc +
        reward_weights["format"] * fmt +
        reward_weights["code_execution"] * code
    )
    return total


# Example: correct boxed answer + think tags, no code block -> 0.5 + 0.2 = 0.7
sample = "<think>Let me solve this.</think> The answer is \\boxed{42}"
print(f"Combined reward: {combined_reward(sample, '42')}")  # Combined reward: 0.7

8. Building a Dual-Mode Model

A dual-mode model can operate in two modes:

- /think mode: the model shows its chain-of-thought reasoning (slower, more accurate)
- /no_think mode: the model gives a direct answer (faster, for simple queries)

This approach, used in models like SmolLM3, lets users control the reasoning behavior at inference time.

from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_with_mode(model, tokenizer, prompt, mode="think", max_tokens=1024):
    """Generate a response in either think or no_think mode."""
    if mode == "think":
        system_prompt = (
            "You are a helpful assistant. Think step by step before answering. "
            "Show your reasoning inside <think>...</think> tags, then give your "
            "final answer."
        )
    else:
        system_prompt = (
            "You are a helpful assistant. Give a direct, concise answer "
            "without showing your reasoning process."
        )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    )

    outputs = model.generate(
        inputs,
        max_new_tokens=max_tokens,
        temperature=0.7,
        do_sample=True,
    )

    response = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
    return response


# Example usage (requires a model loaded):
# /think mode — shows reasoning
# response = generate_with_mode(model, tokenizer, "What is 15% of 240?", mode="think")
# print("Think mode:", response)

# /no_think mode — direct answer
# response = generate_with_mode(model, tokenizer, "What is 15% of 240?", mode="no_think")
# print("No-think mode:", response)