Deep Q-Networks
What You'll Learn
- Understand why tabular methods fail for large or continuous state spaces
- Learn how neural networks can approximate Q-functions
- Identify the deadly triad and why naive deep Q-learning fails
- Master the two key innovations: experience replay and target networks
- Implement a complete DQN agent that learns to balance a pole
When Tables Aren’t Enough
Q-learning is elegant. Store a value for every state-action pair, update them over time, and the optimal policy emerges. But there’s a problem.
In GridWorld, we had maybe 16 states and 4 actions—64 Q-values to store. Easy.
Now consider Atari games. A single frame is 210×160 pixels with 128 possible colors. That's $128^{210 \times 160}$ possible states, a number with tens of thousands of digits. There are only about $10^{80}$ atoms in the observable universe.
Even simpler problems explode quickly:
- 10×10 GridWorld: 400 state-action pairs ✓
- 100×100 GridWorld: 40,000 state-action pairs (manageable)
- Robot with 6 joints, 100 positions each: $100^6 = 10^{12}$ states (impossible)
- Continuous states: Infinite (truly impossible for tables)
This is the curse of dimensionality: storage and data requirements grow exponentially with state dimensions.
But there’s a deeper problem than just storage.
Even if we could store all those Q-values, we'd never visit most states. Learning Q(s, a) requires experiencing state s and taking action a there; if we never see it, we never learn it.
This is where generalization becomes essential. We need to:
- Learn something useful from the states we do see
- Apply that knowledge to states we haven’t seen
A chess position we’ve never encountered might be similar to ones we have. We need to recognize “this feels like a winning position” without having seen this exact board before.
Function Approximation: From Tables to Functions
The solution: instead of storing Q-values in a table, approximate them with a function.
Instead of:
Q[state1][action1] = 0.5
Q[state1][action2] = 0.3
Q[state2][action1] = 0.8
...
We learn a function:
Q(state, action) ≈ some_function(state, action; θ)
where θ are the learnable parameters. Given any state-action pair, the function outputs an estimated Q-value. States it hasn't seen get reasonable estimates based on similar states it has seen.
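To make "function with learnable parameters" concrete, here is a tiny illustrative sketch using a hand-built linear form (this is not the approach the chapter uses; the rest of the chapter replaces it with neural networks):

```python
import numpy as np

def q_linear(state, action, theta):
    """A linear Q-function approximator: Q(s, a) ≈ θ · φ(s, a).

    φ(s, a) concatenates the state with a one-hot encoding of the action.
    """
    num_actions = 2  # e.g., CartPole's push-left / push-right
    action_one_hot = np.eye(num_actions)[action]
    phi = np.concatenate([state, action_one_hot])
    return float(theta @ phi)

# 4 state dimensions + 2 action indicators = 6 learnable parameters
theta = np.zeros(6)
print(q_linear(np.array([0.0, 0.1, -0.02, 0.3]), action=1, theta=theta))
```

Adjusting θ changes the Q-values for every state at once, which is exactly what lets the function generalize to states it has never seen.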
What function should we use? We need something that:
- Can represent complex patterns
- Can be trained from data
- Generalizes to unseen inputs
Enter neural networks.
Mathematical Details
We parameterize the Q-function with a neural network:
$$Q(s, a; \theta) \approx Q^*(s, a)$$
where $\theta$ represents the network weights.
Architecture choices:
- State-action input: Network takes $(s, a)$ as input, outputs a single Q-value
- State input, action outputs (more common): Network takes $s$ as input, outputs Q-values for all actions simultaneously
The second approach is more efficient—one forward pass gives us all action values.
For a game like CartPole with 4 state dimensions (position, velocity, angle, angular velocity) and 2 actions (push left, push right):
Input: [position, velocity, angle, angular_velocity]
↓
[Hidden Layer 1 - 64 neurons]
↓
[Hidden Layer 2 - 64 neurons]
↓
Output: [Q(s, left), Q(s, right)]
To act greedily: take the action with the highest output.
Training the Network
How do we train this network? The same idea as tabular Q-learning, but with gradient descent.
Mathematical Details
Recall the Q-learning update:
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$
The term in brackets is the TD error: how wrong our prediction was.
For neural networks, we minimize the squared TD error:
$$L(\theta) = \left( r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2$$
We compute the gradient $\nabla_\theta L(\theta)$ and update:
$$\theta \leftarrow \theta - \alpha \, \nabla_\theta L(\theta)$$
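As a minimal sketch, here is what that naive single-transition update might look like in PyTorch (it assumes a `q_network` and `optimizer` like the ones defined later in this chapter; note that no replay buffer or target network is involved yet):

```python
import torch
import torch.nn.functional as F

def naive_q_update(q_network, optimizer, state, action, reward,
                   next_state, done, gamma=0.99):
    """One gradient step on the squared TD error for a single transition."""
    state_t = torch.FloatTensor(state).unsqueeze(0)           # shape (1, state_dim)
    next_state_t = torch.FloatTensor(next_state).unsqueeze(0)

    # Prediction: Q(s, a; θ)
    q_pred = q_network(state_t)[0, action]

    # TD target: r + γ max_a' Q(s', a'; θ), computed with the SAME network
    with torch.no_grad():
        q_next = q_network(next_state_t).max().item()
    target = reward if done else reward + gamma * q_next

    loss = F.mse_loss(q_pred, torch.tensor(target))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```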
This looks straightforward. But there’s a catch—several catches, actually.
The Deadly Triad: Why Naive Deep Q-Learning Fails
Here’s an uncomfortable truth: combining function approximation with Q-learning often diverges catastrophically.
Try this naive approach:
- Use a neural network for Q
- Collect experience by acting in the environment
- Update the network on each transition using the TD error
Run it. Watch the Q-values explode to infinity. Watch the agent perform terribly.
This isn’t a bug in your code. It’s a fundamental problem called the deadly triad:
- Function approximation: We’re approximating Q, not storing exact values
- Bootstrapping: Our target uses our own Q estimates (the $\gamma \max_{a'} Q(s', a')$ term)
- Off-policy learning: We learn about the greedy policy while acting with exploration
Any two of these are fine. All three together can cause instability.
Mathematical Details
Why does bootstrapping cause problems?
The TD target is:
$$y = r + \gamma \max_{a'} Q(s', a'; \theta)$$
When we update $\theta$ to make $Q(s, a; \theta)$ closer to $y$, we also change $Q(s', a'; \theta)$, the very thing we used to compute $y$.
It's like chasing a moving target. If the target moves faster than we can catch it, we never converge. Worse, errors can compound: an overestimate in $Q(s', a'; \theta)$ raises the target, causing more overestimation.
In tabular Q-learning, this stabilizes because each state-action has its own independent entry. With function approximation, everything is coupled through shared weights.
The Correlation Problem
There’s another issue: correlated samples.
When we act in an environment, consecutive transitions are highly correlated:
- Step 1: State $s_t$, action $a_t$, next state $s_{t+1}$
- Step 2: State $s_{t+1}$, action $a_{t+1}$, next state $s_{t+2}$
- Step 3: State $s_{t+2}$, action $a_{t+2}$, next state $s_{t+3}$
If we train on each transition as it happens, we’re fitting to batches of very similar data. Neural networks overfit badly to this—they forget what they learned earlier while memorizing recent experience.
Imagine studying for an exam by reading only one chapter repeatedly, then switching to another chapter. You’d ace questions about the last chapter and bomb everything else.
Experience Replay: Breaking the Correlation
The first DQN innovation: experience replay.
Instead of training on each transition immediately, we:
- Store transitions in a replay buffer (memory)
- Sample random batches from the buffer for training
This breaks the temporal correlation. A single training batch might contain:
- A transition from episode 1
- A transition from episode 47
- A transition from episode 23
- …
Each batch is diverse—the network sees varied experiences and learns general patterns.
Implementation
import random
from collections import deque
import numpy as np
class ReplayBuffer:
"""
Store transitions and sample random batches for training.
Each transition is (state, action, reward, next_state, done).
"""
def __init__(self, capacity=10000):
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
"""Add a transition to the buffer."""
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
"""Sample a random batch of transitions."""
batch = random.sample(self.buffer, batch_size)
states = np.array([t[0] for t in batch])
actions = np.array([t[1] for t in batch])
rewards = np.array([t[2] for t in batch])
next_states = np.array([t[3] for t in batch])
dones = np.array([t[4] for t in batch])
return states, actions, rewards, next_states, dones
def __len__(self):
        return len(self.buffer)
Mathematical Details
With experience replay, our loss function becomes:
$$L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2 \right]$$
where $\mathcal{D}$ is the replay buffer. The expectation is over uniformly sampled transitions, not consecutive ones.
Benefits:
- Decorrelation: Batches contain diverse, unrelated transitions
- Data efficiency: Each transition can be used multiple times
- Stability: Learning from a distribution of experiences, not a single trajectory
Experience replay requires off-policy learning—we must be able to learn from transitions generated by an old policy. This is why Q-learning (off-policy) works with replay, but on-policy methods like SARSA would not.
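As a quick usage sketch of the ReplayBuffer above (the transition values here are made up; during training they come from the environment):

```python
import numpy as np

buffer = ReplayBuffer(capacity=10000)

# Store a few fake transitions (4-dimensional states, as in CartPole)
for _ in range(200):
    state = np.random.randn(4)
    next_state = np.random.randn(4)
    buffer.push(state, action=0, reward=1.0, next_state=next_state, done=False)

# Once the buffer has enough transitions, sample a decorrelated batch
states, actions, rewards, next_states, dones = buffer.sample(batch_size=64)
print(states.shape)   # (64, 4)
print(actions.shape)  # (64,)
```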
Target Networks: Stabilizing the Target
Experience replay helps, but we still have the moving target problem. The second DQN innovation: target networks.
The problem: when we update $\theta$, both our prediction $Q(s, a; \theta)$ and our target $r + \gamma \max_{a'} Q(s', a'; \theta)$ change simultaneously.
The solution: use a separate copy of the network for computing targets.
We maintain two networks:
- Online network ($\theta$): Used for action selection, updated frequently
- Target network ($\theta^-$): Used for computing TD targets, updated slowly
The target becomes:
$$y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$$
Now when we update $\theta$, the target stays fixed (until we sync the networks).
Mathematical Details
The DQN loss with target network:
$$L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 \right]$$
Note: $\theta^-$ is treated as a constant; we don't compute gradients through it.
Update strategies:
- Hard update: Copy $\theta^- \leftarrow \theta$ every $C$ steps
- Soft update: Slowly blend $\theta^- \leftarrow \tau \theta + (1 - \tau)\, \theta^-$ each step
Hard updates (e.g., $C$ in the hundreds of steps) are simpler. Soft updates (e.g., $\tau = 0.005$) are smoother.
Implementation
import torch
import torch.nn as nn
import copy
class QNetwork(nn.Module):
"""Neural network for Q-function approximation."""
def __init__(self, state_dim, action_dim, hidden_dim=64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
def forward(self, state):
"""Return Q-values for all actions given state."""
return self.net(state)
def create_target_network(online_network):
"""Create a copy of the online network for target computation."""
target_network = copy.deepcopy(online_network)
# Freeze target network - we don't compute gradients for it
for param in target_network.parameters():
param.requires_grad = False
return target_network
def hard_update(online_network, target_network):
"""Copy online network weights to target network."""
target_network.load_state_dict(online_network.state_dict())
def soft_update(online_network, target_network, tau=0.005):
"""Slowly blend online network weights into target network."""
for target_param, online_param in zip(
target_network.parameters(), online_network.parameters()
):
target_param.data.copy_(
tau * online_param.data + (1 - tau) * target_param.data
        )
For CartPole and simple environments, hard updates every 100-1000 steps work well. For more complex tasks, soft updates with τ = 0.001 to 0.01 provide smoother learning.
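As a usage sketch of the helpers above (the gradient step is elided; the update period of 100 simply matches the agent defined later in this chapter):

```python
online = QNetwork(state_dim=4, action_dim=2)
target = create_target_network(online)

for step in range(1, 1001):
    # ... one gradient step on the online network would go here ...

    # Option A: hard update every C steps
    if step % 100 == 0:
        hard_update(online, target)

    # Option B (alternative): soft update on every step instead
    # soft_update(online, target, tau=0.005)
```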
The Complete DQN Algorithm
Let’s put it all together.
Mathematical Details
Deep Q-Network (DQN) Algorithm:
- Initialize replay buffer $\mathcal{D}$ with capacity $N$
- Initialize Q-network with random weights $\theta$
- Initialize target network with weights $\theta^- \leftarrow \theta$
- For each episode:
  - Initialize state $s_1$
  - For each step $t$:
    - Select action $a_t$ using ε-greedy: with probability ε pick a random action, otherwise $a_t = \arg\max_a Q(s_t, a; \theta)$
    - Execute $a_t$, observe reward $r_t$ and next state $s_{t+1}$
    - Store transition $(s_t, a_t, r_t, s_{t+1})$ in $\mathcal{D}$
    - Sample a random minibatch of transitions $(s_j, a_j, r_j, s_{j+1})$ from $\mathcal{D}$
    - Compute targets: $y_j = r_j$ if $s_{j+1}$ is terminal, otherwise $y_j = r_j + \gamma \max_{a'} Q(s_{j+1}, a'; \theta^-)$
    - Update $\theta$ by gradient descent on $\left( y_j - Q(s_j, a_j; \theta) \right)^2$
    - Every $C$ steps: $\theta^- \leftarrow \theta$
Implementation
import torch
import torch.nn.functional as F
import numpy as np
class DQNAgent:
"""Complete DQN agent with experience replay and target network."""
def __init__(
self,
state_dim,
action_dim,
hidden_dim=64,
lr=1e-3,
gamma=0.99,
epsilon_start=1.0,
epsilon_end=0.01,
epsilon_decay=0.995,
buffer_size=10000,
batch_size=64,
target_update_freq=100
):
self.action_dim = action_dim
self.gamma = gamma
self.batch_size = batch_size
self.target_update_freq = target_update_freq
# Epsilon for exploration
self.epsilon = epsilon_start
self.epsilon_end = epsilon_end
self.epsilon_decay = epsilon_decay
# Networks
self.q_network = QNetwork(state_dim, action_dim, hidden_dim)
self.target_network = create_target_network(self.q_network)
# Optimizer
self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
# Replay buffer
self.replay_buffer = ReplayBuffer(buffer_size)
# Step counter for target updates
self.steps = 0
def select_action(self, state, training=True):
"""Select action using epsilon-greedy policy."""
if training and np.random.random() < self.epsilon:
return np.random.randint(self.action_dim)
state_tensor = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
q_values = self.q_network(state_tensor)
return q_values.argmax().item()
def store_transition(self, state, action, reward, next_state, done):
"""Store a transition in the replay buffer."""
self.replay_buffer.push(state, action, reward, next_state, done)
def train_step(self):
"""Perform one training step on a batch from replay buffer."""
if len(self.replay_buffer) < self.batch_size:
return None # Not enough samples yet
# Sample batch
states, actions, rewards, next_states, dones = \
self.replay_buffer.sample(self.batch_size)
# Convert to tensors
states = torch.FloatTensor(states)
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
next_states = torch.FloatTensor(next_states)
dones = torch.FloatTensor(dones)
# Current Q-values: Q(s, a)
current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
# Target Q-values: r + γ * max_a' Q(s', a'; θ⁻)
with torch.no_grad():
next_q = self.target_network(next_states).max(dim=1)[0]
target_q = rewards + self.gamma * next_q * (1 - dones)
# Compute loss
loss = F.mse_loss(current_q, target_q)
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
# Update target network periodically
self.steps += 1
if self.steps % self.target_update_freq == 0:
hard_update(self.q_network, self.target_network)
# Decay epsilon
self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
        return loss.item()
Training Loop
Implementation
import gymnasium as gym
def train_dqn(env_name="CartPole-v1", episodes=500):
"""Train a DQN agent on a Gymnasium environment."""
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = DQNAgent(state_dim, action_dim)
episode_rewards = []
for episode in range(episodes):
state, _ = env.reset()
total_reward = 0
done = False
while not done:
# Select and execute action
action = agent.select_action(state)
next_state, reward, terminated, truncated, _ = env.step(action)
done = terminated or truncated
# Store and learn
agent.store_transition(state, action, reward, next_state, done)
agent.train_step()
state = next_state
total_reward += reward
episode_rewards.append(total_reward)
# Logging
if (episode + 1) % 20 == 0:
avg_reward = np.mean(episode_rewards[-20:])
print(f"Episode {episode + 1}, Avg Reward: {avg_reward:.1f}, "
f"Epsilon: {agent.epsilon:.3f}")
return agent, episode_rewards
# Train the agent
agent, rewards = train_dqn()
DQN in Action
When you run DQN on CartPole, you’ll see a characteristic learning curve:
Early training (episodes 0-50):
- High epsilon means lots of exploration
- Replay buffer filling up
- Q-values start random, slowly improve
- Rewards low and variable
Middle training (episodes 50-200):
- Epsilon decreasing, more exploitation
- Q-values becoming meaningful
- Agent starts balancing longer
- Rewards climbing
Late training (episodes 200+):
- Low epsilon, mostly greedy actions
- Q-values accurate
- Agent consistently balances for full episode
- Rewards plateau at maximum
The transition from chaos to competence is remarkably quick once the Q-network locks in the right patterns.
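To see the curve from your own run, you can plot the episode rewards returned by train_dqn above (a minimal sketch; it assumes matplotlib is installed):

```python
import numpy as np
import matplotlib.pyplot as plt

# `rewards` is the list returned by train_dqn() above
window = 20
smoothed = np.convolve(rewards, np.ones(window) / window, mode="valid")

plt.plot(rewards, alpha=0.3, label="Episode reward")
plt.plot(np.arange(window - 1, len(rewards)), smoothed, label="20-episode average")
plt.xlabel("Episode")
plt.ylabel("Total reward")
plt.legend()
plt.show()
```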
DQN is sensitive to hyperparameters. If training diverges (Q-values explode, rewards crash), try:
- Lower learning rate (1e-4 instead of 1e-3)
- Larger replay buffer
- More frequent target updates
- Gradient clipping
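Gradient clipping, for example, is a small change inside train_step. A sketch of where it would go (the max_norm of 10.0 is an illustrative value, not one prescribed in this chapter):

```python
# Inside DQNAgent.train_step, replacing the existing optimize block:
self.optimizer.zero_grad()
loss.backward()
# Clip the global gradient norm to keep updates bounded when TD errors are large
torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), max_norm=10.0)
self.optimizer.step()
```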
Advanced Variants: Beyond Vanilla DQN
DQN was a breakthrough, but researchers quickly found ways to improve it.
Deep Dive
Double DQN: Fixing Overestimation
The max operator in Q-learning causes systematic overestimation. If $Q$ is noisy, $\max_{a'} Q(s', a')$ picks the action with the highest noise, biasing the target upward.
Double DQN decouples action selection from evaluation:
$$y = r + \gamma \, Q\!\left(s', \arg\max_{a'} Q(s', a'; \theta);\ \theta^-\right)$$
The online network selects the best action; the target network evaluates it. This reduces overestimation significantly.
# Double DQN target computation
with torch.no_grad():
# Online network selects best action
best_actions = self.q_network(next_states).argmax(dim=1)
# Target network evaluates that action
next_q = self.target_network(next_states).gather(1, best_actions.unsqueeze(1)).squeeze()
    target_q = rewards + self.gamma * next_q * (1 - dones)
Dueling DQN: Separating Value and Advantage
Some states are good or bad regardless of action. Dueling DQN splits the Q-function into a state-value term and an advantage term:
$$Q(s, a) = V(s) + A(s, a) - \frac{1}{|\mathcal{A}|} \sum_{a'} A(s, a')$$
- $V(s)$: How good is this state overall?
- $A(s, a)$: How much better is action $a$ than average?
The network has two heads sharing early layers. This architecture learns faster because it can update from any action, not just the one taken.
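A minimal sketch of such a network, mirroring the QNetwork above (the layer sizes are illustrative, not taken from the original paper):

```python
import torch.nn as nn

class DuelingQNetwork(nn.Module):
    """Q-network with shared features and separate value/advantage heads."""

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        super().__init__()
        # Shared early layers
        self.features = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
        )
        # V(s): one scalar per state
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )
        # A(s, a): one value per action
        self.advantage_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, state):
        features = self.features(state)
        value = self.value_head(features)          # shape (batch, 1)
        advantage = self.advantage_head(features)  # shape (batch, action_dim)
        # Subtract the mean advantage so V and A are identifiable
        return value + advantage - advantage.mean(dim=1, keepdim=True)
```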
Prioritized Experience Replay
Not all experiences are equally valuable. Prioritized replay samples transitions based on their TD error—we learn more from surprising experiences.
Transitions with high TD error (predictions were wrong) get sampled more often. This focuses learning on what the network doesn’t yet understand.
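A sketch of proportional prioritization in the spirit of the ReplayBuffer above (it omits the importance-sampling weights that the full algorithm also uses; alpha and eps are common conventions, not values from this chapter):

```python
import numpy as np

class PrioritizedReplayBuffer:
    """Replay buffer that samples transitions proportionally to their TD error."""

    def __init__(self, capacity=10000, alpha=0.6, eps=1e-5):
        self.capacity = capacity
        self.alpha = alpha      # how strongly to prioritize (0 = uniform sampling)
        self.eps = eps          # keeps every priority strictly positive
        self.buffer = []
        self.priorities = []
        self.pos = 0

    def push(self, state, action, reward, next_state, done):
        # New transitions get the current maximum priority so they are seen at least once
        max_priority = max(self.priorities, default=1.0)
        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
            self.priorities.append(max_priority)
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)
            self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        # Sampling probability proportional to priority^alpha
        probs = np.array(self.priorities) ** self.alpha
        probs /= probs.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        batch = [self.buffer[i] for i in indices]
        states, actions, rewards, next_states, dones = map(np.array, zip(*batch))
        return states, actions, rewards, next_states, dones, indices

    def update_priorities(self, indices, td_errors):
        # Called after a training step with the fresh TD errors of the sampled batch
        for i, td_error in zip(indices, td_errors):
            self.priorities[i] = abs(td_error) + self.eps

    def __len__(self):
        return len(self.buffer)
```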
Historical Context
Deep Dive
DQN was published by Mnih et al. at DeepMind in 2015 in the paper “Human-level control through deep reinforcement learning” (Nature).
Key achievements:
- Single algorithm: Same network architecture and hyperparameters across 49 Atari games
- Raw pixels: Learned directly from screen images (84×84×4 frames)
- Superhuman performance: Beat human experts on many games
- Minimal game-specific knowledge: Only the score as reward
This paper demonstrated that deep learning + RL could tackle complex, high-dimensional problems. It launched the modern era of deep reinforcement learning.
The architecture for Atari:
- 3 convolutional layers processing 4 stacked frames
- 2 fully connected layers
- Output: Q-value for each of 18 possible actions
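In PyTorch, that architecture looks roughly like this (the filter counts and sizes follow the Nature paper; treat this as a sketch rather than a reproduction of the original implementation):

```python
import torch.nn as nn

class AtariQNetwork(nn.Module):
    """Convolutional Q-network in the style of the Nature DQN."""

    def __init__(self, num_actions=18):
        super().__init__()
        self.net = nn.Sequential(
            # Input: 4 stacked 84x84 grayscale frames
            nn.Conv2d(4, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),  # one Q-value per action
        )

    def forward(self, frames):
        return self.net(frames)
```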
Summary
Key Takeaways
- Tabular Q-learning fails at scale because we can’t store or visit all state-action pairs
- Function approximation with neural networks lets Q-learning generalize to unseen states
- The deadly triad (function approx + bootstrapping + off-policy) can cause divergence
- Experience replay breaks correlation between samples by storing and randomly sampling transitions
- Target networks stabilize learning by using a slowly-updated copy for computing TD targets
- DQN = Q-learning + neural network + experience replay + target network
- Double DQN fixes overestimation; Dueling DQN separates value and advantage
Exercises
Conceptual Questions
-
Why can’t we use experience replay with on-policy algorithms like SARSA? What property of Q-learning makes replay possible?
-
What would happen if we updated the target network every single step? How would this affect training stability?
-
The max operator causes overestimation. Why doesn’t the min operator cause underestimation? (Hint: think about how we select actions)
-
How does experience replay help with sample efficiency? Explain why DQN can learn from fewer environment interactions than tabular Q-learning.
Coding Challenges
-
Implement DQN for CartPole using the code in this chapter. Tune hyperparameters until you can consistently solve the environment (for CartPole-v1, an average reward of at least 475 over 100 episodes).
-
Add Double DQN to your implementation. Compare learning curves with and without it. Do you see reduced Q-value overestimation?
-
Implement prioritized experience replay:
- Store TD errors with each transition
- Sample proportionally to TD error magnitude
- Update priorities after training
- Compare with uniform sampling
Open-Ended Exploration
- Experiment with network architecture: How does the number of layers and neurons affect learning? What happens with very small networks (8 neurons)? Very deep networks (5+ layers)?