Training LLMs with Mixture of Experts

A deep dive into MoE architectures — from routing mechanics to building expert layers from scratch

Table of Contents

  1. Setup & Installation
  2. Dense vs Sparse
  3. MoE Architecture
  4. Top-K Router Implementation
  5. Load Balancing Loss
  6. Building a MoE Layer from Scratch
  7. MoE Transformer Block
  8. Notable MoE Models
  9. Sparse Upcycling
  10. Small MoE Configuration

1. Setup & Installation

Install the required libraries for building and experimenting with MoE models. The from-scratch code below only needs torch (PyTorch 2.4 or later, which provides nn.RMSNorm).

!pip install -q torch transformers datasets unsloth peft accelerate bitsandbytes

2. Dense vs Sparse

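| Property | Dense Model | Sparse (MoE) Model |
|---|---|---|
| Parameters used | All parameters for every token | Only the top-K experts (plus shared attention and embeddings) per token |
| Training speed (equal total params) | Slower per step | Faster per step (fewer active params) |
| Inference FLOPs | Proportional to total params | Proportional to active params only |
| Memory | Total params × bytes per param | Total params × bytes per param (all experts must be loaded) |
| Scaling | Compute per token grows linearly with parameter count | Adding experts grows total params at roughly constant compute per token |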

Key insight: MoE models decouple capacity from compute. They can hold far more total parameters than a dense model with the same per-token cost, because only a small subset of those parameters is active for any given token.
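A rough back-of-the-envelope check of that claim, assuming the common rule of thumb of about 2 FLOPs (one multiply, one add) per active weight per token. The sketch compares one dense SwiGLU FFN with an 8-expert, top-2 MoE built from FFNs of the same size; the dimensions are illustrative assumptions, and attention cost is ignored.

```python
def swiglu_ffn_params(hidden_dim, intermediate_dim):
    # Three bias-free projections: w1, w3 (hidden -> intermediate) and w2 (intermediate -> hidden)
    return 3 * hidden_dim * intermediate_dim


def flops_per_token(active_params):
    # Rule of thumb: ~2 FLOPs (multiply + add) per active weight per token
    return 2 * active_params


hidden_dim, intermediate_dim = 4096, 14336  # illustrative sizes (assumption)
num_experts, top_k = 8, 2

expert = swiglu_ffn_params(hidden_dim, intermediate_dim)
dense_total, dense_active = expert, expert
moe_total, moe_active = num_experts * expert, top_k * expert

print(f"Dense FFN: {dense_total / 1e6:.0f}M params, {flops_per_token(dense_active) / 1e9:.2f} GFLOPs/token")
print(f"MoE FFN:   {moe_total / 1e6:.0f}M params, {flops_per_token(moe_active) / 1e9:.2f} GFLOPs/token")
```

The MoE layer stores num_experts times the FFN parameters but costs only top_k times the dense FFN's compute per token, while memory must still hold every expert.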

3. MoE Architecture

A Mixture of Experts layer has two main components:

  1. Router (Gate): A learned linear layer that decides which experts process each token
  2. Experts: N independent feed-forward networks (FFNs)

How it works:

  1. The router computes a probability distribution over all experts for each token
  2. The top-K experts are selected based on highest probabilities
  3. Each selected expert processes the token independently
  4. Outputs are combined as a weighted sum:

y = \sum_{i \in \text{Top-K}} g_i \cdot E_i(x)

where g_i = p_i / \sum_{j \in \text{Top-K}} p_j is the router probability p_i of expert i renormalized over the selected experts, and E_i(x) is expert i’s output.
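The combination rule can be checked on a single token. In this minimal sketch the experts are plain nn.Linear layers (an assumption standing in for the FFN experts), and the gate weights g_i are the top-K router probabilities renormalized to sum to 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, num_experts, top_k = 16, 4, 2

# Toy experts and router (stand-ins for the real FFN experts and gate)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
router = nn.Linear(hidden_dim, num_experts, bias=False)

x = torch.randn(hidden_dim)               # one token's hidden state
probs = F.softmax(router(x), dim=-1)      # distribution over all experts
top_p, top_i = torch.topk(probs, top_k)   # top-K probabilities and expert ids
g = top_p / top_p.sum()                   # normalized gate weights g_i

# y = sum over selected experts of g_i * E_i(x)
y = sum(g[j] * experts[int(top_i[j])](x) for j in range(top_k))

print("selected experts:", top_i.tolist(), "gate weights:", g.tolist())
print("output shape:", tuple(y.shape))
```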

Where MoE layers are placed:

  • MoE replaces the FFN (feed-forward) layer in transformer blocks
  • Attention layers remain dense (shared across all tokens)
  • Typically applied to every other block (e.g. GShard) or to every block (e.g. Mixtral)
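A small helper makes the two placement patterns concrete. The name ffn_layer_types and the convention of making the second block of each pair sparse are illustrative assumptions; real models differ in exactly which blocks use MoE.

```python
def ffn_layer_types(num_layers, moe_every=1):
    """Label each block's FFN slot as 'moe' or 'dense' (hypothetical helper).

    moe_every=1 -> every block uses MoE
    moe_every=2 -> every other block uses MoE (dense, moe, dense, moe, ...)
    """
    return ["moe" if (i + 1) % moe_every == 0 else "dense" for i in range(num_layers)]


print(ffn_layer_types(6, moe_every=1))
print(ffn_layer_types(6, moe_every=2))
```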

4. Top-K Router Implementation

Implement a Top-K gating mechanism that routes tokens to the best experts.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKRouter(nn.Module):
    """Top-K router that selects the best experts for each token."""

    def __init__(self, hidden_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        # Linear gate: projects hidden states to expert scores
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        logits = self.gate(x)  # (batch, seq_len, num_experts)
        probs = F.softmax(logits, dim=-1)  # Softmax over experts

        # Select top-k experts
        top_k_probs, top_k_indices = torch.topk(probs, self.top_k, dim=-1)

        # Normalize selected probabilities to sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        return top_k_probs, top_k_indices, probs


# Test the router
router = TopKRouter(hidden_dim=512, num_experts=8, top_k=2)
x = torch.randn(2, 10, 512)  # batch=2, seq_len=10, hidden_dim=512
top_k_probs, top_k_indices, all_probs = router(x)

print(f"Input shape: {x.shape}")
print(f"Top-K probabilities shape: {top_k_probs.shape}")
print(f"Top-K indices shape: {top_k_indices.shape}")
print(f"\nSample routing for first token:")
print(f"  Selected experts: {top_k_indices[0, 0].tolist()}")
print(f"  Weights: {top_k_probs[0, 0].tolist()}")

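Renormalizing the top-K softmax probabilities, as TopKRouter does, gives the same weights as applying softmax only to the top-K logits, because the shared softmax denominator cancels. The quick check below uses random logits as an assumed stand-in for gate outputs. The full distribution is still needed for the load balancing loss, which is why the router returns probs as well.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 10, 8)  # (batch, seq_len, num_experts), random stand-in for gate outputs
top_k = 2

# Variant A: softmax over all experts, keep top-k, renormalize
probs = F.softmax(logits, dim=-1)
a_weights, a_indices = torch.topk(probs, top_k, dim=-1)
a_weights = a_weights / a_weights.sum(dim=-1, keepdim=True)

# Variant B: keep the top-k logits, then softmax over just those k values
b_logits, b_indices = torch.topk(logits, top_k, dim=-1)
b_weights = F.softmax(b_logits, dim=-1)

# Softmax is monotonic, so both pick the same experts; the denominators cancel, so weights match
print("same experts:", torch.equal(a_indices, b_indices))
print("same weights:", torch.allclose(a_weights, b_weights, atol=1e-6))
```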
5. Load Balancing Loss

Without balancing, routers tend to collapse — sending all tokens to a few experts. The auxiliary load balancing loss encourages uniform expert utilization.

def load_balancing_loss(router_probs, top_k_indices, num_experts):
    """
    Compute the auxiliary load balancing loss.

    Loss = N * sum(f_i * P_i) for i in experts

    where:
    - f_i = fraction of routing assignments (one per token per selected expert) sent to expert i
    - P_i = average router probability for expert i
    - N = number of experts

    With perfectly uniform routing, f_i = P_i = 1/N and the loss equals 1.
    """
    # router_probs: (batch, seq_len, num_experts) - full probability distribution
    # top_k_indices: (batch, seq_len, top_k) - selected expert indices

    batch_size, seq_len, top_k = top_k_indices.shape
    total_assignments = batch_size * seq_len * top_k  # each token makes top_k assignments

    # f_i: fraction of assignments per expert (sums to 1).
    # Counts are not differentiable; gradients reach the router through P_i.
    flat_indices = top_k_indices.reshape(-1)  # Flatten all selections
    expert_counts = torch.bincount(flat_indices, minlength=num_experts).float()
    fraction_tokens = expert_counts / total_assignments  # f_i

    # P_i: average router probability for each expert
    fraction_probs = router_probs.mean(dim=[0, 1])  # Average over batch and seq

    # Load balancing loss: N * sum(f_i * P_i)
    loss = num_experts * torch.sum(fraction_tokens * fraction_probs)

    return loss


# Test load balancing loss
num_experts = 8
loss = load_balancing_loss(all_probs, top_k_indices, num_experts)
print(f"Load balancing loss: {loss.item():.4f}")
print("Perfect balance gives loss = 1.0000 (f_i = P_i = 1/N for every expert)")
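To see what the loss rewards, the sketch below evaluates N · Σ f_i · P_i on two hand-built routings of 4 tokens over 4 experts with top_k = 2: one perfectly balanced, one collapsed onto two experts. Here f_i is normalized by the number of routing assignments (tokens × top_k) so that it sums to 1; that normalization and the helper name balance_loss are assumptions of the sketch.

```python
import torch


def balance_loss(router_probs, top_k_indices, num_experts):
    # router_probs: (tokens, num_experts); top_k_indices: (tokens, top_k)
    counts = torch.bincount(top_k_indices.reshape(-1), minlength=num_experts).float()
    f = counts / top_k_indices.numel()   # fraction of assignments per expert (sums to 1)
    p = router_probs.mean(dim=0)         # mean router probability per expert
    return num_experts * torch.sum(f * p)


num_experts = 4

# Balanced: uniform probabilities, every expert receives 2 of the 8 assignments
balanced_probs = torch.full((4, num_experts), 0.25)
balanced_idx = torch.tensor([[0, 1], [2, 3], [0, 1], [2, 3]])

# Collapsed: all probability mass and all assignments on experts 0 and 1
collapsed_probs = torch.tensor([[0.5, 0.5, 0.0, 0.0]] * 4)
collapsed_idx = torch.tensor([[0, 1]] * 4)

print(f"balanced:  {balance_loss(balanced_probs, balanced_idx, num_experts).item():.4f}")
print(f"collapsed: {balance_loss(collapsed_probs, collapsed_idx, num_experts).item():.4f}")
```

The balanced routing gives 1.0 and the collapsed one 2.0, so minimizing the loss pushes the router toward spreading tokens. During training it is added to the main loss with a small coefficient; the Switch Transformer paper used 0.01.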

6. Building a MoE Layer from Scratch

Combine the router with multiple expert FFNs using a SwiGLU activation.

class Expert(nn.Module):
    """Single expert using SwiGLU activation."""

    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, intermediate_dim, bias=False)

    def forward(self, x):
        # SwiGLU: w2(SiLU(w1(x)) * w3(x))
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoELayer(nn.Module):
    """Mixture of Experts layer with top-k routing."""

    def __init__(self, hidden_dim, intermediate_dim, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Router
        self.gate = TopKRouter(hidden_dim, num_experts, top_k)

        # Expert list
        self.experts = nn.ModuleList([
            Expert(hidden_dim, intermediate_dim)
            for _ in range(num_experts)
        ])

        self.aux_loss = 0.0  # Store auxiliary loss

    def forward(self, x):
        batch, seq_len, hidden_dim = x.shape

        # Route tokens to experts
        top_k_probs, top_k_indices, all_probs = self.gate(x)

        # Compute load balancing loss
        self.aux_loss = load_balancing_loss(all_probs, top_k_indices, self.num_experts)

        # Flatten to (num_tokens, ...) so tokens can be gathered per expert
        flat_x = x.reshape(-1, hidden_dim)
        flat_probs = top_k_probs.reshape(-1, self.top_k)
        flat_indices = top_k_indices.reshape(-1, self.top_k)
        output = torch.zeros_like(flat_x)

        # Process each expert only on the tokens routed to it
        for i, expert in enumerate(self.experts):
            # token_idx: which tokens chose expert i; slot_idx: in which top-k slot
            token_idx, slot_idx = torch.where(flat_indices == i)
            if token_idx.numel() == 0:
                continue

            # Gate weight g_i for each of those tokens
            weights = flat_probs[token_idx, slot_idx].unsqueeze(-1)  # (n_i, 1)

            # Run the expert on its tokens only, then scatter-add the weighted result
            expert_out = expert(flat_x[token_idx])  # (n_i, hidden_dim)
            output.index_add_(0, token_idx, weights * expert_out)

        return output.reshape(batch, seq_len, hidden_dim)


# Test MoE layer
moe = MoELayer(hidden_dim=512, intermediate_dim=1024, num_experts=8, top_k=2)
x = torch.randn(2, 10, 512)
output = moe(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Aux loss:     {moe.aux_loss.item():.4f}")
print(f"Total params: {sum(p.numel() for p in moe.parameters()):,}")
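Running every expert on every token and masking afterward produces the right output but spends full dense compute, which defeats the purpose of sparsity. The self-contained sketch below compares that masked approach with a gather/scatter dispatch in which each expert runs only on the tokens routed to it; the toy nn.Linear experts and random routing are assumptions for the check.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden_dim, num_experts, top_k = 12, 16, 4, 2
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
x = torch.randn(num_tokens, hidden_dim)

# Random routing decisions (stand-in for a trained router)
probs = F.softmax(torch.randn(num_tokens, num_experts), dim=-1)
weights, indices = torch.topk(probs, top_k, dim=-1)
weights = weights / weights.sum(dim=-1, keepdim=True)

# Masked approach: every expert processes every token; weight is 0 where expert i was not chosen
dense_out = torch.zeros_like(x)
for i, expert in enumerate(experts):
    w = (weights * (indices == i)).sum(dim=-1, keepdim=True)
    dense_out = dense_out + w * expert(x)

# Sparse dispatch: gather each expert's tokens, run the expert, scatter-add back
sparse_out = torch.zeros_like(x)
tokens_per_expert = []
for i, expert in enumerate(experts):
    token_idx, slot_idx = torch.where(indices == i)
    tokens_per_expert.append(token_idx.numel())
    if token_idx.numel() > 0:
        w = weights[token_idx, slot_idx].unsqueeze(-1)
        sparse_out.index_add_(0, token_idx, w * expert(x[token_idx]))

print("tokens per expert:", tokens_per_expert)  # sums to num_tokens * top_k
print("outputs match:", torch.allclose(dense_out, sparse_out, atol=1e-5))
```

Both paths give the same output, but in the sparse path each expert processes only its share of the num_tokens × top_k assignments.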

7. MoE Transformer Block

Integrate the MoE layer into a transformer block, replacing the standard FFN.

class MoETransformerBlock(nn.Module):
    """Transformer block with MoE replacing the FFN."""

    def __init__(self, hidden_dim, num_heads, intermediate_dim, num_experts, top_k=2):
        super().__init__()
        # Pre-attention layer norm
        self.attn_norm = nn.RMSNorm(hidden_dim)

        # Multi-head self-attention (dense, shared)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True,
        )

        # Pre-FFN layer norm
        self.ffn_norm = nn.RMSNorm(hidden_dim)

        # MoE replaces the standard FFN
        self.moe = MoELayer(hidden_dim, intermediate_dim, num_experts, top_k)

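    def forward(self, x):
        seq_len = x.size(1)
        # Causal mask for autoregressive LMs: True = position may NOT be attended to
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )

        # Self-attention with residual
        normed = self.attn_norm(x)
        attn_out, _ = self.attention(normed, normed, normed, attn_mask=causal_mask, need_weights=False)
        x = x + attn_out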

        # MoE FFN with residual
        normed = self.ffn_norm(x)
        moe_out = self.moe(normed)
        x = x + moe_out

        return x


# Test MoE Transformer Block
block = MoETransformerBlock(
    hidden_dim=512,
    num_heads=8,
    intermediate_dim=1024,
    num_experts=8,
    top_k=2,
)

x = torch.randn(2, 10, 512)
output = block(x)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Block params: {sum(p.numel() for p in block.parameters()):,}")
print(f"MoE aux loss: {block.moe.aux_loss.item():.4f}")
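The aux_loss stored by each MoE layer only has an effect once it is added to the training objective. A minimal sketch of that step, assuming a hypothetical ToyMoE module that (like MoELayer) stores its balancing loss in an aux_loss attribute during forward, a coefficient of 0.01 (the value used in the Switch Transformer paper), and a regression target standing in for the language-modeling loss.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyMoE(nn.Module):
    """Hypothetical stand-in: routes tokens top-k and records a balancing loss."""

    def __init__(self, hidden_dim, num_experts=4, top_k=2):
        super().__init__()
        self.num_experts, self.top_k = num_experts, top_k
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)
        self.experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)])
        self.aux_loss = torch.tensor(0.0)

    def forward(self, x):  # x: (tokens, hidden_dim)
        probs = F.softmax(self.gate(x), dim=-1)
        weights, indices = torch.topk(probs, self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        counts = torch.bincount(indices.reshape(-1), minlength=self.num_experts).float()
        self.aux_loss = self.num_experts * torch.sum(counts / indices.numel() * probs.mean(dim=0))
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            w = (weights * (indices == i)).sum(dim=-1, keepdim=True)
            out = out + w * expert(x)  # dense compute kept for brevity
        return out


torch.manual_seed(0)
model = nn.Sequential(ToyMoE(16), ToyMoE(16))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
aux_coef = 0.01

x, target = torch.randn(32, 16), torch.randn(32, 16)
task_loss = F.mse_loss(model(x), target)  # stand-in for the LM cross-entropy
aux_loss = sum(m.aux_loss for m in model.modules() if isinstance(m, ToyMoE))
loss = task_loss + aux_coef * aux_loss

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"task {task_loss.item():.4f}  aux {aux_loss.item():.4f}  total {loss.item():.4f}")
```

With a stack of MoETransformerBlock modules, the same pattern would filter model.modules() for MoELayer instances.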

8. Notable MoE Models

| Model | Year | Total Params | Active Params | Experts | Top-K | Key Innovation |
|---|---|---|---|---|---|---|
| Switch Transformer | 2021 | 1.6T | ~1B | 2048 | 1 | Simplified top-1 routing |
| Mixtral 8x7B | 2023 | 46.7B | 12.9B | 8 | 2 | Open-weight MoE, strong quality |
| DeepSeekMoE | 2024 | 16B | 2.8B | 64 routed + 2 shared | 6 routed | Fine-grained experts + shared experts |
| DeepSeek-V2 / V3 | 2024 | 236B / 671B | 21B / 37B | 160 / 256 routed (+2 / +1 shared) | 6 / 8 routed | Multi-head latent attention |
| OLMoE | 2024 | 6.9B | 1.3B | 64 | 8 | Fully open (data + code + weights) |
| Qwen3-30B-A3B | 2025 | 30.5B | 3.3B | 128 | 8 | Efficient hybrid thinking model |

9. Sparse Upcycling

Sparse upcycling converts a pre-trained dense model into a MoE model, avoiding training from scratch.

Process:

  1. Copy FFN weights: Each expert is initialized as a copy of the original FFN weights
  2. Add random router: A randomly initialized gating network is added
  3. Continue training: Fine-tune the MoE model, allowing experts to specialize

Advantages:

  • Reuses expensive pre-training compute
  • Experts start from a strong baseline
  • Faster convergence than training MoE from scratch

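Typical recipe (top-2 routing):

Dense 7B model (FFN) → Copy FFN × 8 experts → Add router → MoE 8x7B
Active params: ~12.9B   Total params: ~46.7B

The total is well below 8 × 7B = 56B because only the FFN weights are replicated; attention, embeddings, and norms stay shared, so the active count is roughly two FFNs' worth plus the shared layers.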

Note: The router needs careful warm-up — start with a higher learning rate for the router and lower for experts to avoid catastrophic forgetting.
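The three upcycling steps can be sketched in a few lines. The SwiGLUFFN class, the helper names upcycle and moe_forward, the sizes, and the learning rates are illustrative assumptions, not a published recipe. Because every expert starts as an exact copy and the top-K gate weights are renormalized to sum to 1, the upcycled layer initially reproduces the dense FFN's output regardless of what the random router picks; experts only diverge once training updates them differently.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_dim, bias=False)
        self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, intermediate_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def upcycle(dense_ffn, hidden_dim, num_experts):
    # Step 1: every expert starts as an exact copy of the pre-trained dense FFN
    experts = nn.ModuleList([copy.deepcopy(dense_ffn) for _ in range(num_experts)])
    # Step 2: add a freshly initialized router
    router = nn.Linear(hidden_dim, num_experts, bias=False)
    return experts, router


def moe_forward(x, experts, router, top_k=2):
    probs = F.softmax(router(x), dim=-1)
    weights, indices = torch.topk(probs, top_k, dim=-1)
    weights = weights / weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(x)
    for i, expert in enumerate(experts):
        w = (weights * (indices == i)).sum(dim=-1, keepdim=True)
        out = out + w * expert(x)
    return out


torch.manual_seed(0)
dense = SwiGLUFFN(hidden_dim=32, intermediate_dim=64)  # pretend this is pre-trained
experts, router = upcycle(dense, hidden_dim=32, num_experts=8)

x = torch.randn(5, 32)
print("matches dense FFN at init:", torch.allclose(moe_forward(x, experts, router), dense(x), atol=1e-5))

# Step 3: continue training, with separate learning rates for router and experts
optimizer = torch.optim.AdamW([
    {"params": router.parameters(), "lr": 1e-3},   # higher LR for the new router (illustrative)
    {"params": experts.parameters(), "lr": 1e-5},  # lower LR for copied experts (illustrative)
])
```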

10. Small MoE Configuration

Define and estimate the parameter count for a small MoE model.

# Small MoE configuration
config = {
    "hidden_dim": 2048,
    "intermediate_dim": 5504,
    "num_experts": 8,
    "top_k": 2,
    "num_layers": 24,
    "num_heads": 16,
    "vocab_size": 32000,
    "max_seq_len": 4096,
}

# Parameter count estimate
hidden = config["hidden_dim"]
inter = config["intermediate_dim"]
n_exp = config["num_experts"]
n_layers = config["num_layers"]
n_heads = config["num_heads"]
vocab = config["vocab_size"]

# Per-layer parameters
attn_params = 4 * hidden * hidden  # Q, K, V, O projections (full multi-head attention, no biases)
expert_params = 3 * hidden * inter  # w1, w2, w3 (SwiGLU)
moe_params = n_exp * expert_params  # All experts
router_params = hidden * n_exp  # Gate
norm_params = 2 * hidden  # Two RMSNorm layers
layer_params = attn_params + moe_params + router_params + norm_params

# Total
embedding_params = vocab * hidden  # Token embedding (assumes a tied output head; add vocab * hidden if untied)
total_params = n_layers * layer_params + embedding_params
active_params = n_layers * (attn_params + config["top_k"] * expert_params + router_params + norm_params) + embedding_params

print("MoE Model Configuration")
print("=" * 40)
for k, v in config.items():
    print(f"  {k}: {v}")
print()
print(f"Per-layer breakdown:")
print(f"  Attention:     {attn_params:>12,} params")
print(f"  MoE ({n_exp} experts): {moe_params:>12,} params")
print(f"  Router:        {router_params:>12,} params")
print(f"  Norms:         {norm_params:>12,} params")
print(f"  Layer total:   {layer_params:>12,} params")
print()
print(f"Total parameters:  {total_params:>12,} ({total_params/1e9:.2f}B)")
print(f"Active parameters: {active_params:>12,} ({active_params/1e9:.2f}B)")
print(f"Active ratio:      {active_params/total_params:.1%}")
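The Memory row of the dense-vs-sparse comparison is easy to quantify for this configuration. A rough sketch, assuming 2 bytes per parameter (bf16 weights) and counting weights only (no activations, KV cache, or optimizer state); the configuration values are repeated so the snippet runs on its own, and the parameter formulas match the ones above.

```python
# Configuration from above, repeated so this cell is self-contained
hidden, inter, n_exp, top_k, n_layers, vocab = 2048, 5504, 8, 2, 24, 32000

attn = 4 * hidden * hidden
expert = 3 * hidden * inter
shared = attn + hidden * n_exp + 2 * hidden  # attention + router + two norms
total_params = n_layers * (shared + n_exp * expert) + vocab * hidden
active_params = n_layers * (shared + top_k * expert) + vocab * hidden

BYTES_PER_PARAM = 2  # bf16 / fp16 weights (assumption)
total_gb = total_params * BYTES_PER_PARAM / 1e9
active_gb = active_params * BYTES_PER_PARAM / 1e9

print(f"Weights to load (all experts): {total_gb:.1f} GB")
print(f"Weights touched per token:     {active_gb:.1f} GB")
```

Although only about 30% of the parameters are active for any token, all experts must stay resident in memory, because different tokens select different experts.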