The DQN Architecture
In 2013, DeepMind published a paper that changed the trajectory of artificial intelligence, demonstrating that a single neural network architecture could learn to play Atari games from raw pixels. The 2015 follow-up in Nature scaled this to 49 games, many played at superhuman level. The key was not just using a neural network for Q-learning, but designing an architecture suited to visual processing and stabilizing training with novel techniques.
This section covers the neural network architecture at the heart of DQN: a convolutional neural network that transforms raw game frames into action values.
From Q-Table to Q-Network
A Q-network is a neural network that takes a state as input and outputs Q-values for all actions. Instead of storing Q-values in a table, we learn a function that computes them from the state representation.
In tabular Q-learning, we maintain a table with an entry for every state-action pair. To get the Q-value, we simply look it up:
Q-value for (state, action) = Q_table[state][action]
With a Q-network, we replace the table with a neural network forward pass:
Q-values for all actions = network(state)
Q-value for specific action = network(state)[action]
This has several advantages:
- Generalization: Similar states get similar Q-values automatically
- Scalability: Works for continuous or high-dimensional states
- Efficiency: Computes Q-values for all actions in one forward pass
The architecture of the network determines what kinds of state representations it can handle effectively.
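For contrast, the tabular lookup described above can be sketched as a dictionary keyed by (state, action) pairs; the state and action names here are purely illustrative:

```python
from collections import defaultdict

# Tabular Q-learning: one stored value per (state, action) pair.
Q_table = defaultdict(float)

def greedy_action(state, actions):
    """Look up every action's Q-value and pick the best."""
    return max(actions, key=lambda a: Q_table[(state, a)])

# Works only for small, discrete state spaces: every state
# must be visited to get a meaningful entry.
Q_table[("s0", "left")] = 0.5
Q_table[("s0", "right")] = 1.2
print(greedy_action("s0", ["left", "right"]))  # -> right
```

Nothing generalizes here: a state that was never visited keeps the default value of 0, which is exactly the limitation the Q-network removes.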
A Q-network is a parameterized function:
Q(s, a; θ) ≈ Q*(s, a)
implemented as:
Q(s, ·; θ) = network(s)
where the network takes a state s and outputs a vector of Q-values, one for each action. The parameters θ include all weights and biases of the network.
For a specific action a, we index into the output:
Q(s, a; θ) = network(s)[a]
This vectorized output is more efficient than computing each action's value separately, since we often need max_a Q(s, a; θ), which requires the values of all actions.
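Because the network emits all Q-values at once, greedy and ε-greedy action selection reduce to an argmax over the output vector. A minimal sketch (the single linear layer here is just a stand-in for a real Q-network):

```python
import random
import torch
import torch.nn as nn

net = nn.Linear(4, 2)  # stand-in Q-network: 4-dim state, 2 actions

def select_action(state, epsilon=0.1):
    """Epsilon-greedy selection from a single forward pass."""
    if random.random() < epsilon:
        return random.randrange(2)          # explore: random action
    with torch.no_grad():
        q_values = net(state)               # Q-values for all actions
    return q_values.argmax(dim=-1).item()   # exploit: best action

state = torch.randn(4)
action = select_action(state, epsilon=0.0)
print(action)  # 0 or 1
```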
import torch
import torch.nn as nn
import torch.nn.functional as F
class QNetwork(nn.Module):
"""
Basic Q-network for low-dimensional state spaces.
Takes a state vector and outputs Q-values for all actions.
"""
def __init__(self, state_dim, n_actions, hidden_dims=[64, 64]):
"""
Args:
state_dim: Dimension of state vector
n_actions: Number of possible actions
hidden_dims: List of hidden layer dimensions
"""
super().__init__()
layers = []
prev_dim = state_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, n_actions))
self.network = nn.Sequential(*layers)
def forward(self, state):
"""
Args:
state: Tensor of shape (batch_size, state_dim)
Returns:
Q-values of shape (batch_size, n_actions)
"""
return self.network(state)
# Example: CartPole
cartpole_net = QNetwork(state_dim=4, n_actions=2)
print("CartPole Q-Network:")
print(cartpole_net)
# Test forward pass
state = torch.randn(1, 4)
q_values = cartpole_net(state)
print(f"\nInput shape: {state.shape}")
print(f"Output shape: {q_values.shape}")
print(f"Q-values: {q_values.detach().numpy()}")
Convolutional Architecture for Pixels
Atari games present the agent with raw pixel observations. The screen is 210x160 pixels with RGB color, totaling over 100,000 values per frame. A fully connected network would require millions of parameters just for the first layer.
Convolutional Neural Networks (CNNs) are designed for image data. They exploit two key properties:
- Local patterns: Useful features (edges, objects) are local patterns that can be detected by small filters
- Translation invariance: The same pattern should be detected regardless of where it appears on screen
The DQN architecture uses three convolutional layers to progressively extract higher-level features:
- Layer 1: Detects simple patterns (edges, gradients)
- Layer 2: Combines patterns into shapes (paddles, balls)
- Layer 3: Detects complex objects and spatial relationships
After the convolutional layers, fully connected layers combine these features to estimate action values.
The DQN architecture from the Nature paper processes 84x84 grayscale images through three convolutional layers:
Layer 1: 32 filters of size 8x8 with stride 4
Layer 2: 64 filters of size 4x4 with stride 2
Layer 3: 64 filters of size 3x3 with stride 1
This is followed by:
- A fully connected layer with 512 units
- An output layer with one unit per action
Total parameters: approximately 1.7 million, far fewer than a fully connected architecture would require.
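The spatial sizes (and the 3,136-unit flattened vector that feeds the first fully connected layer) follow from the standard formula for a convolution without padding, out = (in - kernel) // stride + 1, applied layer by layer:

```python
def conv_out(size, kernel, stride):
    """Output spatial size of a convolution without padding."""
    return (size - kernel) // stride + 1

size = 84
for kernel, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_out(size, kernel, stride)
    print(size)          # 20, then 9, then 7

print(64 * size * size)  # 3136: input size of the first FC layer
```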
import torch
import torch.nn as nn
import torch.nn.functional as F
class DQNNetwork(nn.Module):
"""
The DQN convolutional architecture from the Nature paper.
Input: 4 stacked 84x84 grayscale frames (batch, 4, 84, 84)
Output: Q-values for each action (batch, n_actions)
This is the architecture that learned to play 49 Atari games.
"""
def __init__(self, n_actions, in_channels=4):
"""
Args:
n_actions: Number of possible actions
in_channels: Number of stacked frames (default: 4)
"""
super().__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=8, stride=4)
self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)
# Fully connected layers
self.fc1 = nn.Linear(64 * 7 * 7, 512)
self.fc2 = nn.Linear(512, n_actions)
def forward(self, x):
"""
Forward pass through the network.
Args:
x: Tensor of shape (batch, 4, 84, 84)
Pixel values should be in [0, 1]
Returns:
Q-values of shape (batch, n_actions)
"""
# Convolutional layers with ReLU activation
x = F.relu(self.conv1(x)) # -> (batch, 32, 20, 20)
x = F.relu(self.conv2(x)) # -> (batch, 64, 9, 9)
x = F.relu(self.conv3(x)) # -> (batch, 64, 7, 7)
# Flatten for fully connected layers
x = x.view(x.size(0), -1) # -> (batch, 3136)
# Fully connected layers
x = F.relu(self.fc1(x)) # -> (batch, 512)
q_values = self.fc2(x) # -> (batch, n_actions)
return q_values
# Create network for a game with 4 actions
network = DQNNetwork(n_actions=4)
# Count parameters
n_params = sum(p.numel() for p in network.parameters())
print(f"Total parameters: {n_params:,}")
# Detailed parameter count
print("\nParameters by layer:")
for name, param in network.named_parameters():
print(f" {name}: {param.numel():,}")
# Test forward pass
batch = torch.randn(8, 4, 84, 84) # 8 samples of 4 stacked frames
q_values = network(batch)
print(f"\nInput shape: {batch.shape}")
print(f"Output shape: {q_values.shape}")
Frame Stacking for Temporal Information
Frame stacking is a technique where multiple consecutive frames are stacked as channels in the network input. This provides temporal information, allowing the network to perceive motion and velocity from static images.
A single frame of Pong shows the ball and paddles, but not their velocities. You cannot tell if the ball is moving left or right from one static image. Yet velocity is crucial for deciding how to move.
Frame stacking solves this by providing the last 4 frames as input:
- Frame t-3: Ball at position (50, 100)
- Frame t-2: Ball at position (52, 98)
- Frame t-1: Ball at position (54, 96)
- Frame t: Ball at position (56, 94)
Now the network can “see” that the ball is moving right and up. The velocity is implicit in the differences between frames.
The 4 frames are stacked as channels (like RGB, but with 4 channels instead of 3). The convolutional filters learn to detect motion patterns across these channels.
With frame stacking, the input to the network is:
s_t = (x_{t-3}, x_{t-2}, x_{t-1}, x_t)
where each x_i is an 84x84 grayscale frame. The resulting tensor has shape:
(4, 84, 84)
or, for a batch of N samples:
(N, 4, 84, 84)
The first convolutional layer applies filters of shape 4x8x8, where the 4 corresponds to the number of stacked frames. Each filter learns to detect spatio-temporal patterns across all 4 frames simultaneously.
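This can be checked directly on the first layer: with 4 input channels, each of conv1's 32 filters spans all stacked frames at once. A quick sketch:

```python
import torch.nn as nn

conv1 = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
print(conv1.weight.shape)  # torch.Size([32, 4, 8, 8])
# Each of the 32 filters is 4x8x8: an 8x8 spatial window applied
# across all 4 stacked frames simultaneously, so a single filter
# can respond to motion between frames, not just static shapes.
```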
import numpy as np
from collections import deque
class FrameStack:
"""
Maintains a stack of the last n frames.
Used to provide temporal information to the network.
"""
def __init__(self, n_frames=4, frame_shape=(84, 84)):
"""
Args:
n_frames: Number of frames to stack
frame_shape: Shape of each frame (height, width)
"""
self.n_frames = n_frames
self.frame_shape = frame_shape
self.frames = deque(maxlen=n_frames)
def reset(self, initial_frame):
"""
Reset stack with initial frame repeated n times.
Called at the start of each episode.
"""
self.frames.clear()
for _ in range(self.n_frames):
self.frames.append(initial_frame)
def push(self, frame):
"""Add a new frame to the stack."""
self.frames.append(frame)
def get(self):
"""
Get the stacked frames as a single array.
Returns:
Array of shape (n_frames, height, width)
"""
return np.array(self.frames)
class AtariPreprocessor:
"""
Preprocessing pipeline for Atari frames.
1. Convert to grayscale
2. Resize to 84x84
3. Stack last 4 frames
"""
def __init__(self, n_frames=4):
self.frame_stack = FrameStack(n_frames=n_frames)
def preprocess_frame(self, frame):
"""
Convert raw Atari frame to 84x84 grayscale.
Args:
frame: RGB frame of shape (210, 160, 3)
Returns:
Grayscale frame of shape (84, 84)
"""
# Convert to grayscale
gray = np.mean(frame, axis=2).astype(np.uint8)
# Resize to 84x84 (simple downsampling)
# In practice, use cv2.resize for better quality
h, w = gray.shape
new_h, new_w = 84, 84
# Simple resize by taking evenly spaced pixels
row_indices = np.linspace(0, h - 1, new_h).astype(int)
col_indices = np.linspace(0, w - 1, new_w).astype(int)
resized = gray[row_indices][:, col_indices]
# Normalize to [0, 1]
normalized = resized.astype(np.float32) / 255.0
return normalized
def reset(self, initial_frame):
"""Reset with initial observation."""
processed = self.preprocess_frame(initial_frame)
self.frame_stack.reset(processed)
return self.frame_stack.get()
def step(self, frame):
"""Process new frame and return stacked state."""
processed = self.preprocess_frame(frame)
self.frame_stack.push(processed)
return self.frame_stack.get()
# Example usage
preprocessor = AtariPreprocessor()
# Simulate an Atari frame (210, 160, 3)
raw_frame = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)
# Reset with initial frame
state = preprocessor.reset(raw_frame)
print(f"Initial state shape: {state.shape}")
# Step with new frame
new_frame = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)
next_state = preprocessor.step(new_frame)
print(f"Next state shape: {next_state.shape}")
# Check that frames shifted correctly
print(f"\nFrame 2 same as previous frame 3: {np.allclose(next_state[2], state[3])}")
Atari Preprocessing Details
The raw Atari output is a 210x160 RGB image, but DQN uses extensive preprocessing:
- Grayscale conversion: Color is not essential for most games, and this reduces input size by 3x
- Resize to 84x84: Reduces computation while retaining important detail
- Frame skipping: The agent selects an action every 4 frames, and that action is repeated on the skipped frames. This speeds up training and roughly matches human reaction time
- Max over last 2 frames: Some Atari games render sprites only on alternating frames. Taking the pixel-wise maximum of the last 2 raw frames before preprocessing eliminates this flickering
- Reward clipping: All rewards are clipped to the range [-1, 1] (in practice, to their sign: -1, 0, or +1). This normalizes the learning signal across games with different score scales
These preprocessing steps are crucial for DQN performance. Without them, training is significantly slower or fails entirely.
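The max-over-two-frames step can be seen in isolation: if a sprite is rendered only on alternating frames, the pixel-wise maximum restores it. The tiny frames below are synthetic, just to make the effect visible:

```python
import numpy as np

frame_a = np.zeros((4, 4), dtype=np.uint8)
frame_b = np.zeros((4, 4), dtype=np.uint8)
frame_a[1, 1] = 255  # this sprite is drawn only on frame A
frame_b[2, 2] = 255  # this sprite is drawn only on frame B

# Either frame alone is missing one sprite; the pixel-wise
# maximum of the two shows both, removing the flicker.
merged = np.maximum(frame_a, frame_b)
print(merged[1, 1], merged[2, 2])  # 255 255
```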
import numpy as np
class AtariWrapper:
"""
Full Atari preprocessing wrapper.
Implements all preprocessing from the DQN Nature paper:
- Frame skipping (action repeat)
- Max pooling over last 2 frames
- Grayscale conversion
- Resize to 84x84
- Frame stacking
- Reward clipping
"""
def __init__(self, env, frame_skip=4, n_stack=4):
"""
Args:
env: Base Atari environment
frame_skip: Number of times to repeat each action
n_stack: Number of frames to stack
"""
self.env = env
self.frame_skip = frame_skip
self.frame_stack = FrameStack(n_frames=n_stack)
# Buffer for max pooling over last 2 frames
self.frame_buffer = [None, None]
def _get_observation(self):
"""Get max of last 2 frames, preprocessed."""
max_frame = np.maximum(self.frame_buffer[0], self.frame_buffer[1])
return self._preprocess(max_frame)
def _preprocess(self, frame):
"""Convert to grayscale and resize."""
# Grayscale
gray = np.mean(frame, axis=2)
# Resize (simplified - use cv2.resize in practice)
h, w = gray.shape
rows = np.linspace(0, h-1, 84).astype(int)
cols = np.linspace(0, w-1, 84).astype(int)
resized = gray[rows][:, cols]
return resized.astype(np.float32) / 255.0
def reset(self):
"""Reset environment and return initial stacked state."""
frame = self.env.reset()
self.frame_buffer = [frame, frame]
processed = self._get_observation()
self.frame_stack.reset(processed)
return self.frame_stack.get()
def step(self, action):
"""
Take action with frame skipping.
Returns:
state: Stacked frames (n_stack, 84, 84)
reward: Clipped total reward
done: Episode terminated
info: Additional info
"""
total_reward = 0
for i in range(self.frame_skip):
frame, reward, done, info = self.env.step(action)
# Update frame buffer (for max pooling)
self.frame_buffer[i % 2] = frame
# Accumulate reward
total_reward += reward
if done:
break
# Get preprocessed observation
processed = self._get_observation()
self.frame_stack.push(processed)
state = self.frame_stack.get()
# Clip reward to {-1, 0, +1}
clipped_reward = np.sign(total_reward)
return state, clipped_reward, done, info
@property
def action_space(self):
return self.env.action_space
# Note: This is a simplified version. The full implementation would use
# gymnasium's wrappers for better integration:
#
# from gymnasium.wrappers import AtariPreprocessing, FrameStack
# env = gymnasium.make("ALE/Breakout-v5")
# env = AtariPreprocessing(env, grayscale_obs=True, scale_obs=True)
# env = FrameStack(env, 4)
Architecture Variations
The DQN architecture described is the one from the Nature paper, but variations exist:
- Larger networks: More convolutional layers or larger hidden layers for complex games
- Different input sizes: Some implementations use 84x84, others use 80x80 or different resolutions
- Batch normalization: Can help stabilize training, though the original DQN did not use it
- Different activations: ReLU is standard, but some implementations use LeakyReLU or ELU
For non-visual domains, the convolutional layers are replaced with fully connected layers. The key principles remain:
- Sufficient capacity to represent the Q-function
- Architecture suited to the input modality
- Not so large that it overfits
import torch
import torch.nn as nn
class FlexibleDQN(nn.Module):
"""
Flexible DQN architecture for different input types.
Automatically selects CNN or MLP based on input shape.
"""
def __init__(self, input_shape, n_actions, hidden_dim=512):
"""
Args:
input_shape: Tuple describing input shape
(C, H, W) for images or (D,) for vectors
n_actions: Number of actions
hidden_dim: Size of hidden fully connected layer
"""
super().__init__()
if len(input_shape) == 3:
# Image input: use CNN
self.features = self._make_cnn(input_shape)
# Calculate flattened size
with torch.no_grad():
dummy = torch.zeros(1, *input_shape)
flat_size = self.features(dummy).view(1, -1).size(1)
else:
# Vector input: use MLP
self.features = nn.Identity()
flat_size = input_shape[0]
self.fc = nn.Sequential(
nn.Linear(flat_size, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions)
)
def _make_cnn(self, input_shape):
"""Create CNN for image input."""
in_channels = input_shape[0]
return nn.Sequential(
nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
nn.ReLU(),
nn.Conv2d(32, 64, kernel_size=4, stride=2),
nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1),
nn.ReLU(),
nn.Flatten()
)
def forward(self, x):
features = self.features(x)
if len(features.shape) > 2:
features = features.view(features.size(0), -1)
return self.fc(features)
# Example: Atari (image input)
atari_net = FlexibleDQN(input_shape=(4, 84, 84), n_actions=4)
print("Atari network parameters:", sum(p.numel() for p in atari_net.parameters()))
# Example: CartPole (vector input)
cartpole_net = FlexibleDQN(input_shape=(4,), n_actions=2)
print("CartPole network parameters:", sum(p.numel() for p in cartpole_net.parameters()))
Summary
The DQN architecture transforms raw pixels into action values through:
- Convolutional layers that extract spatial features from images
- Fully connected layers that combine features into action values
- Frame stacking that provides temporal information
- Preprocessing that normalizes and simplifies the input
Key architectural choices from the Nature paper:
- 3 convolutional layers (32, 64, 64 filters)
- 1 hidden fully connected layer (512 units)
- 4 stacked 84x84 grayscale frames as input
- Approximately 1.7 million parameters
This architecture, combined with experience replay and target networks (covered in the next sections), enabled the first successful application of deep reinforcement learning to complex visual domains.
The same architecture, with the same hyperparameters, was used to learn 49 different Atari games. This generality was a key contribution of the DQN paper: one algorithm, one architecture, many tasks.