!pip install -q torch transformers datasets unsloth peft accelerate bitsandbytes

Training LLMs with Mixture of Experts
A deep dive into MoE architectures — from routing mechanics to building expert layers from scratch
Table of Contents
1. Setup & Installation
Install the required libraries for building and experimenting with MoE models.
2. Dense vs Sparse
| Property | Dense Model | Sparse (MoE) Model |
|---|---|---|
| Parameters Used | All parameters for every token | Only top-K experts per token |
| Training Speed | Slower per step | Faster per step (fewer active params) |
| Inference FLOPs | Proportional to total params | Proportional to active params only |
| Memory | Total params × bytes | Total params × bytes (all loaded) |
| Scaling | Linear cost increase | Sub-linear compute increase |
Key insight: MoE models store more knowledge (more total parameters) while keeping compute cost low (only a subset is active per token).
3. MoE Architecture
A Mixture of Experts layer has two main components:
- Router (Gate): A learned linear layer that decides which experts process each token
- Experts: N independent feed-forward networks (FFNs)
How it works:
- The router computes a probability distribution over all experts for each token
- The top-K experts are selected based on highest probabilities
- Each selected expert processes the token independently
- Outputs are combined as a weighted sum:
y = \sum_{i \in \text{Top-K}} g_i \cdot E_i(x)
where g_i is the normalized gate weight and E_i(x) is expert i’s output.
Where MoE layers are placed:
- MoE replaces the FFN (feed-forward) layer in transformer blocks
- Attention layers remain dense (shared across all tokens)
- Typically applied to every other layer or every layer
4. Top-K Router Implementation
Implement a Top-K gating mechanism that routes tokens to the best experts.
import torch
import torch.nn as nn
import torch.nn.functional as F
class TopKRouter(nn.Module):
    """Learned gate that picks the top-k experts for every token.

    Projects each hidden state to one logit per expert, softmaxes over the
    expert dimension, and keeps only the k largest probabilities,
    renormalized so the kept weights sum to 1 per token.
    """

    def __init__(self, hidden_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        # Bias-free linear projection: hidden state -> per-expert score.
        self.gate = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        expert_scores = self.gate(x)                       # (batch, seq_len, num_experts)
        expert_probs = F.softmax(expert_scores, dim=-1)    # distribution over experts
        # Keep the k most likely experts for each token.
        selected_probs, selected_idx = torch.topk(expert_probs, self.top_k, dim=-1)
        # Renormalize the surviving weights so they sum to one per token.
        selected_probs = selected_probs / selected_probs.sum(dim=-1, keepdim=True)
        # Return (weights, indices, full distribution); the full distribution
        # is needed by the load-balancing loss.
        return selected_probs, selected_idx, expert_probs
# Test the router
# Smoke test: push a random batch through the gate and inspect the routing
# decision for a single token. `all_probs` and `top_k_indices` are reused
# by the load-balancing-loss test further down.
router = TopKRouter(hidden_dim=512, num_experts=8, top_k=2)
x = torch.randn(2, 10, 512) # batch=2, seq_len=10, hidden_dim=512
top_k_probs, top_k_indices, all_probs = router(x)
print(f"Input shape: {x.shape}")
print(f"Top-K probabilities shape: {top_k_probs.shape}")
print(f"Top-K indices shape: {top_k_indices.shape}")
print(f"\nSample routing for first token:")
print(f" Selected experts: {top_k_indices[0, 0].tolist()}")
print(f" Weights: {top_k_probs[0, 0].tolist()}")

5. Load Balancing Loss
Without balancing, routers tend to collapse — sending all tokens to a few experts. The auxiliary load balancing loss encourages uniform expert utilization.
def load_balancing_loss(router_probs, top_k_indices, num_experts):
    """
    Compute the auxiliary load balancing loss (Switch Transformer style).

    Loss = N * sum(f_i * P_i) for i in experts
    where:
        - f_i = fraction of routing assignments sent to expert i
        - P_i = average router probability for expert i
        - N = number of experts

    f_i is normalized by the total number of assignments
    (tokens * top_k), not by the token count, so that f_i sums to 1 for
    any top_k and a perfectly uniform router yields a loss of exactly
    1.0. (Dividing by tokens alone would inflate the minimum to top_k.)

    Args:
        router_probs: (batch, seq_len, num_experts) full softmax distribution.
        top_k_indices: (batch, seq_len, top_k) selected expert indices.
        num_experts: total number of experts N.

    Returns:
        Scalar tensor. Gradients flow only through router_probs (P_i);
        the assignment counts (f_i) come from bincount and are treated
        as constants, as in the Switch Transformer formulation.
    """
    # f_i: fraction of assignments routed to each expert.
    flat_indices = top_k_indices.reshape(-1)  # all (token, slot) selections
    expert_counts = torch.bincount(flat_indices, minlength=num_experts).float()
    fraction_tokens = expert_counts / flat_indices.numel()  # sums to 1
    # P_i: average router probability for each expert over batch and seq.
    fraction_probs = router_probs.mean(dim=[0, 1])
    # Load balancing loss: N * sum(f_i * P_i); minimum 1.0 at uniform routing.
    loss = num_experts * torch.sum(fraction_tokens * fraction_probs)
    return loss
# Test load balancing loss
# Uses `all_probs` and `top_k_indices` produced by the router test above.
num_experts = 8
loss = load_balancing_loss(all_probs, top_k_indices, num_experts)
print(f"Load balancing loss: {loss.item():.4f}")
print(f"Perfect balance would give loss ≈ {num_experts * (1/num_experts) * (1/num_experts) * num_experts:.4f}")

6. Building a MoE Layer from Scratch
Combine the router with multiple expert FFNs using a SwiGLU activation.
class Expert(nn.Module):
    """One feed-forward expert using a SwiGLU gated unit.

    Computes w2(SiLU(w1(x)) * w3(x)) — the gated-linear-unit FFN used in
    LLaMA-style transformers. All projections are bias-free.
    """

    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)  # down projection
        self.w3 = nn.Linear(hidden_dim, intermediate_dim, bias=False)  # up projection

    def forward(self, x):
        # SwiGLU: nonlinear gate branch multiplied into the linear branch,
        # then projected back down to hidden_dim.
        gate_branch = F.silu(self.w1(x))
        linear_branch = self.w3(x)
        return self.w2(gate_branch * linear_branch)
class MoELayer(nn.Module):
    """Sparse Mixture-of-Experts layer: a top-k router plus N expert FFNs.

    Each token's output is the gate-weighted sum of its top-k experts'
    outputs. The auxiliary load-balancing loss from the most recent
    forward pass is stored in ``self.aux_loss`` so the training loop can
    add it to the main objective.
    """

    def __init__(self, hidden_dim, intermediate_dim, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        # Learned gate that scores experts per token.
        self.gate = TopKRouter(hidden_dim, num_experts, top_k)
        # Independent SwiGLU feed-forward experts.
        self.experts = nn.ModuleList(
            Expert(hidden_dim, intermediate_dim) for _ in range(num_experts)
        )
        self.aux_loss = 0.0  # refreshed on every forward()

    def forward(self, x):
        # x: (batch, seq_len, hidden_dim)
        gate_weights, gate_indices, full_probs = self.gate(x)
        # Stash the balancing loss for the caller.
        self.aux_loss = load_balancing_loss(full_probs, gate_indices, self.num_experts)

        combined = torch.zeros_like(x)
        for expert_id, expert in enumerate(self.experts):
            # Tokens whose top-k selection includes this expert.
            chosen = (gate_indices == expert_id).any(dim=-1)  # (batch, seq_len)
            if not chosen.any():
                continue
            # Per-token gate weight for this expert (zero where unselected).
            weight = torch.where(
                gate_indices == expert_id,
                gate_weights,
                torch.zeros_like(gate_weights),
            ).sum(dim=-1)  # (batch, seq_len)
            # NOTE: for pedagogical clarity the expert runs on ALL tokens
            # and the mask zeroes out unselected positions afterwards; a
            # production kernel would gather only the routed tokens.
            combined += weight.unsqueeze(-1) * expert(x) * chosen.unsqueeze(-1).float()
        return combined
# Test MoE layer
# Forward a random batch and confirm the layer is shape-preserving and
# records an auxiliary loss.
moe = MoELayer(hidden_dim=512, intermediate_dim=1024, num_experts=8, top_k=2)
x = torch.randn(2, 10, 512)
output = moe(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Aux loss: {moe.aux_loss.item():.4f}")
print(f"Total params: {sum(p.numel() for p in moe.parameters()):,}")

7. MoE Transformer Block
Integrate the MoE layer into a transformer block, replacing the standard FFN.
class MoETransformerBlock(nn.Module):
    """Pre-norm transformer block whose FFN is a Mixture-of-Experts layer.

    Attention stays dense (shared by every token); only the position-wise
    feed-forward sublayer is replaced by sparse experts.
    """

    def __init__(self, hidden_dim, num_heads, intermediate_dim, num_experts, top_k=2):
        super().__init__()
        # Dense multi-head self-attention with its pre-norm.
        self.attn_norm = nn.RMSNorm(hidden_dim)
        self.attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        # Sparse MoE feed-forward with its pre-norm.
        self.ffn_norm = nn.RMSNorm(hidden_dim)
        self.moe = MoELayer(hidden_dim, intermediate_dim, num_experts, top_k)

    def forward(self, x):
        # Residual branch 1: pre-norm self-attention.
        h = self.attn_norm(x)
        attn_out, _ = self.attention(h, h, h)
        x = x + attn_out
        # Residual branch 2: pre-norm MoE feed-forward.
        x = x + self.moe(self.ffn_norm(x))
        return x
# Test MoE Transformer Block
# Forward a random batch through a full block (attention + MoE FFN) and
# check the residual stream keeps its shape.
block = MoETransformerBlock(
    hidden_dim=512,
    num_heads=8,
    intermediate_dim=1024,
    num_experts=8,
    top_k=2,
)
x = torch.randn(2, 10, 512)
output = block(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Block params: {sum(p.numel() for p in block.parameters()):,}")
print(f"MoE aux loss: {block.moe.aux_loss.item():.4f}")

8. Notable MoE Models
| Model | Year | Total Params | Active Params | Experts | Top-K | Key Innovation |
|---|---|---|---|---|---|---|
| Switch Transformer | 2021 | 1.6T | ~1B | 2048 | 1 | Simplified top-1 routing |
| Mixtral 8x7B | 2023 | 46.7B | 12.9B | 8 | 2 | Open-source MoE, strong quality |
| DeepSeekMoE | 2024 | 16B | 2.8B | 64 | 6 | Fine-grained experts + shared experts |
| DeepSeek-V2/V3 | 2024 | 236B/671B | 21B/37B | 160/257 | 6/8 | Multi-head latent attention |
| OLMoE | 2024 | 6.9B | 1.3B | 64 | 8 | Fully open (data + code + weights) |
| Qwen3-30B-A3B | 2025 | 30B | 3B | 128 | 8 | Efficient hybrid thinking model |
9. Sparse Upcycling
Sparse upcycling converts a pre-trained dense model into a MoE model, avoiding training from scratch.
Process:
- Copy FFN weights: Each expert is initialized as a copy of the original FFN weights
- Add random router: A randomly initialized gating network is added
- Continue training: Fine-tune the MoE model, allowing experts to specialize
Advantages:
- Reuses expensive pre-training compute
- Experts start from a strong baseline
- Faster convergence than training MoE from scratch
Typical recipe:
Dense 7B model (FFN) → Copy FFN × 8 experts → Add router → MoE 8x7B
Active params: ~12.9B Total params: ~46.7B
Note: The router needs careful warm-up — start with a higher learning rate for the router and lower for experts to avoid catastrophic forgetting.
10. Small MoE Configuration
Define and estimate the parameter count for a small MoE model.
# Small MoE configuration
config = {
    "hidden_dim": 2048,
    "intermediate_dim": 5504,
    "num_experts": 8,
    "top_k": 2,
    "num_layers": 24,
    "num_heads": 16,
    "vocab_size": 32000,
    "max_seq_len": 4096,
}

# --- Parameter-count estimate (weights only; all linears are bias-free) ---
hidden = config["hidden_dim"]
inter = config["intermediate_dim"]
n_exp = config["num_experts"]
n_layers = config["num_layers"]
n_heads = config["num_heads"]
vocab = config["vocab_size"]

# Per-layer pieces.
attn_params = 4 * hidden * hidden    # Q, K, V, O projections
expert_params = 3 * hidden * inter   # SwiGLU expert: w1, w2, w3
moe_params = n_exp * expert_params   # all experts are stored in memory
router_params = hidden * n_exp       # gating linear layer
norm_params = 2 * hidden             # two RMSNorm weight vectors per block
layer_params = attn_params + moe_params + router_params + norm_params

# Whole model: stacked layers plus the token-embedding table.
embedding_params = vocab * hidden
total_params = n_layers * layer_params + embedding_params
# "Active" counts only the top_k experts actually executed per token;
# attention, router, and norms are always active.
active_params = (
    n_layers * (attn_params + config["top_k"] * expert_params + router_params + norm_params)
    + embedding_params
)

print("MoE Model Configuration")
print("=" * 40)
for k, v in config.items():
    print(f"  {k}: {v}")
print()
print(f"Per-layer breakdown:")
print(f"  Attention:         {attn_params:>12,} params")
print(f"  MoE ({n_exp} experts):   {moe_params:>12,} params")
print(f"  Router:            {router_params:>12,} params")
print(f"  Norms:             {norm_params:>12,} params")
print(f"  Layer total:       {layer_params:>12,} params")
print()
print(f"Total parameters:  {total_params:>12,} ({total_params/1e9:.2f}B)")
print(f"Active parameters: {active_params:>12,} ({active_params/1e9:.2f}B)")
print(f"Active ratio: {active_params/total_params:.1%}")