Putting It Together
We have covered the three pillars of DQN: a convolutional architecture for visual input, experience replay for breaking correlations, and target networks for stable learning. Now we combine them into the complete algorithm that learned to play 49 Atari games, reaching superhuman level on many of them.
The Complete DQN Algorithm
DQN follows a simple loop:
- Observe: Get the current state (stacked frames)
- Act: Choose action with epsilon-greedy policy
- Store: Save transition to replay buffer
- Sample: Draw random batch from buffer
- Learn: Update network to minimize TD error
- Sync: Periodically update target network
The magic is not in any single step but in how they work together. Experience replay breaks the correlation between consecutive experiences. The target network provides a stable learning signal. The CNN extracts useful features from pixels. Together, they overcome the deadly triad.
DQN Algorithm:
Initialize replay buffer D with capacity N
Initialize action-value function Q with random weights θ
Initialize target action-value function Q̂ with weights θ⁻ = θ
For episode = 1, ..., M:
- Initialize state s_1 (preprocessed starting frame stack)
- For t = 1, ..., T:
- With probability ε select random action a_t
- Otherwise select a_t = argmax_a Q(s_t, a; θ)
- Execute action a_t, observe reward r_t and next state s_{t+1}
- Store transition (s_t, a_t, r_t, s_{t+1}) in D
- Sample random minibatch of transitions (s_j, a_j, r_j, s_{j+1}) from D
- For each transition, compute target y_j:
- If terminal: y_j = r_j
- Otherwise: y_j = r_j + γ max_{a'} Q̂(s_{j+1}, a'; θ⁻)
- Perform gradient descent on (y_j − Q(s_j, a_j; θ))² with respect to θ
- Every C steps: set θ⁻ = θ
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from collections import deque
import random
import copy
class DQN(nn.Module):
"""
DQN convolutional network for Atari.
"""
def __init__(self, n_actions, in_channels=4):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
)
self.fc = nn.Sequential(
nn.Linear(64 * 7 * 7, 512),
nn.ReLU(),
nn.Linear(512, n_actions),
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
return self.fc(x)
class ReplayBuffer:
"""Experience replay buffer."""
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (
np.array(states, dtype=np.float32),
np.array(actions, dtype=np.int64),
np.array(rewards, dtype=np.float32),
np.array(next_states, dtype=np.float32),
np.array(dones, dtype=np.float32),
)
def __len__(self):
return len(self.buffer)
class DQNAgent:
"""
Complete DQN agent with all components.
"""
def __init__(
self,
n_actions,
state_shape=(4, 84, 84),
buffer_size=1000000,
batch_size=32,
gamma=0.99,
lr=1e-4,
target_update_freq=10000,
epsilon_start=1.0,
epsilon_end=0.1,
epsilon_decay_steps=1000000,
):
self.n_actions = n_actions
self.batch_size = batch_size
self.gamma = gamma
self.target_update_freq = target_update_freq
# Epsilon schedule
self.epsilon_start = epsilon_start
self.epsilon_end = epsilon_end
self.epsilon_decay_steps = epsilon_decay_steps
# Networks
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.q_network = DQN(n_actions).to(self.device)
self.target_network = copy.deepcopy(self.q_network)
self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
# Replay buffer
self.buffer = ReplayBuffer(buffer_size)
# Counters
self.step_count = 0
self.episode_count = 0
def get_epsilon(self):
"""Linear epsilon decay schedule."""
progress = min(1.0, self.step_count / self.epsilon_decay_steps)
return self.epsilon_start + progress * (self.epsilon_end - self.epsilon_start)
def select_action(self, state):
"""Epsilon-greedy action selection."""
epsilon = self.get_epsilon()
if np.random.random() < epsilon:
return np.random.randint(self.n_actions)
with torch.no_grad():
state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
q_values = self.q_network(state_t)
return q_values.argmax(dim=1).item()
def store_transition(self, state, action, reward, next_state, done):
"""Store transition in replay buffer."""
self.buffer.push(state, action, reward, next_state, done)
def train_step(self):
"""Perform one training step."""
if len(self.buffer) < self.batch_size:
return None
# Sample batch
states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
# Convert to tensors
states_t = torch.FloatTensor(states).to(self.device)
actions_t = torch.LongTensor(actions).to(self.device)
rewards_t = torch.FloatTensor(rewards).to(self.device)
next_states_t = torch.FloatTensor(next_states).to(self.device)
dones_t = torch.FloatTensor(dones).to(self.device)
# Current Q-values
current_q = self.q_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
# Target Q-values
with torch.no_grad():
next_q = self.target_network(next_states_t).max(dim=1)[0]
target_q = rewards_t + self.gamma * next_q * (1 - dones_t)
# Huber loss (more stable than MSE)
loss = F.smooth_l1_loss(current_q, target_q)
# Optimize
self.optimizer.zero_grad()
loss.backward()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 10)
self.optimizer.step()
# Update target network
self.step_count += 1
if self.step_count % self.target_update_freq == 0:
self.target_network.load_state_dict(self.q_network.state_dict())
return loss.item()
Epsilon Decay Schedule
Epsilon decay gradually reduces the exploration rate from a high value (e.g., 1.0) to a low value (e.g., 0.1) over training. This balances exploration early on with exploitation of learned knowledge later.
At the start of training, the Q-network outputs essentially random values. Following its greedy policy would be no better than acting randomly, so we use high exploration (ε ≈ 1.0) to gather diverse experience.
As training progresses, the Q-values become more accurate. We can trust them more and explore less. By the end, we still keep some exploration (ε = 0.1) to avoid getting stuck.
The original DQN paper used linear decay:
- Start at ε = 1.0
- Linearly decay to ε = 0.1 over 1 million steps
- Stay at ε = 0.1 thereafter
Linear decay schedule: ε(t) = ε_start + min(1, t / T) · (ε_end − ε_start)
For DQN default values: ε_start = 1.0, ε_end = 0.1, decay steps T = 1,000,000
At step 500,000: ε = 1.0 + 0.5 · (0.1 − 1.0) = 0.55
class EpsilonSchedule:
"""
Epsilon decay schedules for exploration.
"""
@staticmethod
def linear(step, start=1.0, end=0.1, decay_steps=1000000):
"""Linear decay from start to end."""
progress = min(1.0, step / decay_steps)
return start + progress * (end - start)
@staticmethod
def exponential(step, start=1.0, end=0.1, decay_rate=0.99999):
"""Exponential decay."""
return max(end, start * (decay_rate ** step))
@staticmethod
def piecewise(step, schedule):
"""
Piecewise linear schedule.
Args:
schedule: List of (step, epsilon) tuples
"""
for i, (s, e) in enumerate(schedule):
if step < s:
if i == 0:
return e
prev_s, prev_e = schedule[i-1]
progress = (step - prev_s) / (s - prev_s)
return prev_e + progress * (e - prev_e)
return schedule[-1][1]
# Example schedules
steps = [0, 250000, 500000, 750000, 1000000, 1500000]
print("Epsilon Schedules:")
print("-" * 50)
for step in steps:
linear = EpsilonSchedule.linear(step)
exp = EpsilonSchedule.exponential(step)
print(f"Step {step:>10}: Linear = {linear:.3f}, Exponential = {exp:.3f}")
Hyperparameters from the Nature Paper
The DQN Nature paper used carefully tuned hyperparameters. Remarkably, the same settings worked across all 49 Atari games:
Network:
- 3 convolutional layers (32, 64, 64 filters)
- 1 hidden fully connected layer (512 units)
- ReLU activations
Training:
- Replay buffer size: 1,000,000 transitions
- Minibatch size: 32
- Discount factor: 0.99
- Learning rate: 0.00025 (RMSprop)
- Target network update frequency: 10,000 steps
Exploration:
- Initial epsilon: 1.0
- Final epsilon: 0.1
- Epsilon decay over: 1,000,000 steps
Preprocessing:
- 84x84 grayscale frames
- 4 stacked frames
- Frame skip: 4 (action repeated for 4 frames)
- Reward clipping: [-1, +1]
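The preprocessing steps above can be sketched in a few lines of numpy. This is an illustrative version only: the names `resize_nearest`, `FrameStack`, and `clip_reward` are my own, the nearest-neighbour resize stands in for the proper interpolation a real wrapper would use (e.g. `cv2.resize`), and frame skip (repeating each action for 4 emulator frames) lives in the environment wrapper and is omitted here.

```python
import numpy as np
from collections import deque

def resize_nearest(img, h=84, w=84):
    """Nearest-neighbour resize via index mapping (illustration only)."""
    rows = np.arange(h) * img.shape[0] // h
    cols = np.arange(w) * img.shape[1] // w
    return img[rows][:, cols]

def preprocess_frame(frame):
    """210x160x3 RGB frame -> 84x84 grayscale in [0, 1]."""
    gray = frame @ np.array([0.299, 0.587, 0.114])  # RGB -> luminance
    return (resize_nearest(gray) / 255.0).astype(np.float32)

class FrameStack:
    """Keep the last k preprocessed frames as the agent's state."""
    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        # Fill the stack by repeating the first frame.
        for _ in range(self.frames.maxlen):
            self.frames.append(preprocess_frame(frame))
        return np.stack(self.frames)

    def step(self, frame):
        self.frames.append(preprocess_frame(frame))
        return np.stack(self.frames)

def clip_reward(r):
    """DQN clips rewards to {-1, 0, +1} so one loss scale fits all games."""
    return float(np.sign(r))

# Usage with a fake frame in place of an emulator screen
frame = np.random.randint(0, 256, (210, 160, 3)).astype(np.float32)
stack = FrameStack(k=4)
state = stack.reset(frame)
print(state.shape)        # (4, 84, 84) -- matches the network's input
print(clip_reward(7.5))   # 1.0
```

Stacking four frames gives the network enough history to infer motion (e.g. ball velocity in Pong), which a single frame cannot convey.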
class DQNConfig:
"""
DQN hyperparameters from the Nature paper.
"""
# Network architecture
CONV_FILTERS = [32, 64, 64]
CONV_KERNELS = [8, 4, 3]
CONV_STRIDES = [4, 2, 1]
FC_HIDDEN = 512
# Training
REPLAY_BUFFER_SIZE = 1_000_000
BATCH_SIZE = 32
GAMMA = 0.99
LEARNING_RATE = 0.00025
TARGET_UPDATE_FREQ = 10_000
TRAIN_START = 50_000 # Start training after this many steps
# Exploration
EPSILON_START = 1.0
EPSILON_END = 0.1
EPSILON_DECAY_STEPS = 1_000_000
# Preprocessing
FRAME_HEIGHT = 84
FRAME_WIDTH = 84
FRAME_STACK = 4
FRAME_SKIP = 4
# Optimizer (original paper used RMSprop)
OPTIMIZER = 'rmsprop'
RMS_EPSILON = 0.01
RMS_ALPHA = 0.95 # Decay rate
# Gradient clipping
GRAD_CLIP = 10
# Training duration
TOTAL_FRAMES = 50_000_000 # 50 million frames
def create_dqn_from_config(n_actions, config=DQNConfig):
"""Create DQN agent from configuration."""
agent = DQNAgent(
n_actions=n_actions,
buffer_size=config.REPLAY_BUFFER_SIZE,
batch_size=config.BATCH_SIZE,
gamma=config.GAMMA,
lr=config.LEARNING_RATE,
target_update_freq=config.TARGET_UPDATE_FREQ,
epsilon_start=config.EPSILON_START,
epsilon_end=config.EPSILON_END,
epsilon_decay_steps=config.EPSILON_DECAY_STEPS,
)
return agent
# Print configuration summary
print("DQN Nature Paper Configuration")
print("=" * 50)
for attr in dir(DQNConfig):
if not attr.startswith('_'):
value = getattr(DQNConfig, attr)
if not callable(value):
print(f" {attr}: {value}")
The Training Loop
def train_dqn_atari(env, agent, config=DQNConfig, verbose=True):
"""
Complete DQN training loop for Atari.
Args:
env: Preprocessed Atari environment
agent: DQNAgent instance
config: Hyperparameter configuration
verbose: Print progress
Returns:
Dictionary with training statistics
"""
stats = {
'episode_rewards': [],
'episode_lengths': [],
'losses': [],
'epsilons': [],
}
total_steps = 0
episode = 0
while total_steps < config.TOTAL_FRAMES:
episode += 1
state = env.reset()
episode_reward = 0
episode_length = 0
done = False
while not done:
# Select and execute action
action = agent.select_action(state)
next_state, reward, done, info = env.step(action)
# Store transition
agent.store_transition(state, action, reward, next_state, done)
# Train (after warmup)
if total_steps >= config.TRAIN_START:
loss = agent.train_step()
if loss is not None:
stats['losses'].append(loss)
state = next_state
episode_reward += reward
episode_length += 1
total_steps += 1
# Periodic logging
if total_steps % 100000 == 0 and verbose:
avg_reward = np.mean(stats['episode_rewards'][-100:]) if stats['episode_rewards'] else 0
avg_loss = np.mean(stats['losses'][-1000:]) if stats['losses'] else 0
epsilon = agent.get_epsilon()
print(f"Step {total_steps:,} | "
f"Episodes: {episode} | "
f"Avg Reward: {avg_reward:.2f} | "
f"Avg Loss: {avg_loss:.4f} | "
f"Epsilon: {epsilon:.3f}")
# Episode finished
stats['episode_rewards'].append(episode_reward)
stats['episode_lengths'].append(episode_length)
stats['epsilons'].append(agent.get_epsilon())
return stats
def evaluate_dqn(env, agent, n_episodes=10, render=False):
"""
Evaluate trained DQN agent.
Uses greedy policy (epsilon=0).
"""
rewards = []
for episode in range(n_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
# Greedy action selection
with torch.no_grad():
state_t = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
q_values = agent.q_network(state_t)
action = q_values.argmax(dim=1).item()
state, reward, done, _ = env.step(action)
episode_reward += reward
if render:
env.render()
rewards.append(episode_reward)
print(f"Evaluation over {n_episodes} episodes:")
print(f" Mean reward: {np.mean(rewards):.2f}")
print(f" Std reward: {np.std(rewards):.2f}")
print(f" Min reward: {np.min(rewards):.2f}")
print(f" Max reward: {np.max(rewards):.2f}")
return rewards
What Learning Looks Like
Training DQN on Atari takes a long time. With the standard settings:
- 50 million frames of gameplay
- About 10-14 hours on a modern GPU
- ~200 episodes of Breakout before seeing improvement
- Superhuman performance on many games after full training
The learning curve is often bumpy:
- Early on, performance improves slowly as the network learns basic patterns
- Middle training shows rapid improvement as strategies crystallize
- Late training plateaus as the agent approaches optimal play
Different games show different patterns:
- Pong: Learns relatively quickly (a few hundred episodes)
- Breakout: Shows the famous “tunnel” strategy discovery
- Montezuma’s Revenge: Fails almost completely (sparse rewards)
def plot_training_progress(stats, window=100):
"""
Visualize DQN training progress.
"""
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# Episode rewards
ax = axes[0, 0]
rewards = stats['episode_rewards']
ax.plot(rewards, alpha=0.3, label='Raw')
if len(rewards) >= window:
smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
ax.plot(range(window-1, len(rewards)), smoothed, label=f'{window}-episode average')
ax.set_xlabel('Episode')
ax.set_ylabel('Reward')
ax.set_title('Episode Rewards')
ax.legend()
# Episode lengths
ax = axes[0, 1]
lengths = stats['episode_lengths']
ax.plot(lengths, alpha=0.3)
if len(lengths) >= window:
smoothed = np.convolve(lengths, np.ones(window)/window, mode='valid')
ax.plot(range(window-1, len(lengths)), smoothed)
ax.set_xlabel('Episode')
ax.set_ylabel('Steps')
ax.set_title('Episode Lengths')
# Training loss
ax = axes[1, 0]
losses = stats['losses']
# Sample for plotting (too many points)
sample_rate = max(1, len(losses) // 10000)
sampled_losses = losses[::sample_rate]
ax.plot(sampled_losses, alpha=0.5)
ax.set_xlabel('Training Step (sampled)')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
# Epsilon decay
ax = axes[1, 1]
epsilons = stats['epsilons']
ax.plot(epsilons)
ax.set_xlabel('Episode')
ax.set_ylabel('Epsilon')
ax.set_title('Exploration Rate')
plt.tight_layout()
return fig
# Example: what training stats might look like
def simulate_training_stats():
"""Generate example training statistics."""
n_episodes = 1000
# Simulate learning curve (reward increases over time)
base_reward = np.linspace(-21, 21, n_episodes)
noise = np.random.randn(n_episodes) * 5
rewards = base_reward + noise + np.sin(np.arange(n_episodes) / 20) * 3
# Episode lengths tend to increase as agent survives longer
lengths = 100 + np.cumsum(np.random.randn(n_episodes) * 0.5)
lengths = np.clip(lengths, 50, 5000)
# Loss decreases then fluctuates
n_steps = 100000
losses = 10 * np.exp(-np.arange(n_steps) / 20000) + 0.5 + np.random.randn(n_steps) * 0.3
losses = np.clip(losses, 0.1, 20)
# Epsilon decays
epsilons = [max(0.1, 1.0 - i/500) for i in range(n_episodes)]
return {
'episode_rewards': list(rewards),
'episode_lengths': list(lengths),
'losses': list(losses),
'epsilons': epsilons,
}
stats = simulate_training_stats()
print("Simulated Training Statistics:")
print(f" Final avg reward: {np.mean(stats['episode_rewards'][-100:]):.2f}")
print(f" Final avg length: {np.mean(stats['episode_lengths'][-100:]):.0f}")
print(f" Final loss: {np.mean(stats['losses'][-1000:]):.4f}")
Summary
The complete DQN algorithm combines:
- CNN architecture for processing visual observations
- Experience replay for breaking sample correlations
- Target network for stable learning targets
- Epsilon-greedy exploration with decay schedule
- Careful hyperparameter choices that generalize across games
Key implementation details:
- Huber loss instead of MSE for robustness
- Gradient clipping to prevent exploding gradients
- Warmup period before training starts
- Frame preprocessing (grayscale, resize, stack)
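To see why the Huber loss is more robust than MSE, compare their gradients on a large TD error. The numpy sketch below is my own illustration, using the δ = 1 form that matches PyTorch's `smooth_l1_loss` default:

```python
import numpy as np

def huber(error, delta=1.0):
    """Quadratic near zero, linear in the tails."""
    a = np.abs(error)
    return np.where(a <= delta,
                    0.5 * error ** 2,
                    delta * (a - 0.5 * delta))

def huber_grad(error, delta=1.0):
    """d/d(error) of the Huber loss: capped at +/- delta."""
    return np.clip(error, -delta, delta)

# For MSE written as 0.5 * error**2, the gradient is the error itself.
for err in [0.5, 2.0, 10.0]:
    print(f"error={err}: mse_grad={err}, huber_grad={float(huber_grad(err))}")
```

A TD error of 10 pushes an MSE update 10x harder than an error of 1, so a single outlier transition can dominate a minibatch; the Huber gradient saturates at δ, which (together with gradient clipping) keeps updates bounded.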
Training characteristics:
- Slow but steady: Millions of frames for good performance
- Same settings work across games: Remarkable generality
- Some games are hard: Sparse rewards (Montezuma) remain challenging
DQN was a breakthrough that showed deep learning and reinforcement learning could work together. The next chapter explores improvements that pushed performance even further.
DQN was published in 2013 (NIPS) and 2015 (Nature). It demonstrated superhuman performance on 29 of 49 Atari games, using the same algorithm and hyperparameters for all games. This generality was as remarkable as the raw performance.