Neural Network Approximators
Linear function approximation requires hand-crafted features. You must decide what aspects of the state matter and how to represent them. For simple problems, this works well. But for complex observations like images, designing good features is practically impossible.
Neural networks solve this by learning their own features. They can discover patterns in raw data that human engineers would never think of. This capability is what makes deep reinforcement learning possible.
The Limitations of Linear Methods
Consider playing Atari from raw pixels. The observation is an 84x84 grayscale image, meaning 7,056 pixel values. What features would you hand-craft?
- “Average brightness”? Not useful for most games.
- “Pixel at position (42, 50)”? Too specific.
- “Edge detector outputs”? Maybe, but which edges matter?
The fundamental problem is that useful features for game playing are high-level concepts:
- “Ball is moving left”
- “Enemy is approaching”
- “Paddle is aligned with ball”
These concepts exist in the pixels, but extracting them requires understanding the game. We need a system that can discover these concepts automatically from experience.
Linear methods cannot do this. If you give them raw pixels as features, they can only learn linear combinations of pixel values. They cannot learn “the ball is to the left of the paddle” because that requires nonlinear reasoning about pixel patterns.
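Linear function approximation is limited to functions of the form:
$$\hat{V}(s; \mathbf{w}) = \mathbf{w}^\top \boldsymbol{\phi}(s) = \sum_{i} w_i \, \phi_i(s)$$
The expressiveness of this model depends entirely on the features $\boldsymbol{\phi}(s)$. If $\boldsymbol{\phi}(s)$ only includes raw pixel values, we can only represent value functions that are linear in pixels.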
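Many important relationships are nonlinear. For example, detecting that a ball is at a specific location requires matching a pattern across multiple pixels, which is fundamentally a nonlinear operation. As an illustration, a detector that fires only when every pixel $p_k$ in a patch $P$ is brighter than some threshold $\tau$ is a conjunction, which can be written as a product of indicators:
$$\text{ball}_P(s) = \prod_{k \in P} \mathbb{1}[p_k > \tau]$$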
To represent such functions with linear methods, we would need features that already encode these patterns. But for complex domains, the number of possible patterns grows exponentially with the number of pixels involved, so enumerating them as features is infeasible.
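The limitation can be seen on a toy problem. The sketch below is an illustration under stated assumptions, not part of the original method: two binary "pixels" and a target that depends on their interaction (XOR). The helper `fit_linear` is a hypothetical name. A least-squares fit on the raw pixels cannot do better than predicting the average, while adding a single hand-crafted product feature makes the fit exact.

```python
import numpy as np

# Two binary "pixels" and a target that depends on their interaction (XOR).
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
y = np.array([0.0, 1.0, 1.0, 0.0])

def fit_linear(features, targets):
    """Least-squares fit of targets ~ features @ w; returns (w, max abs error)."""
    w, *_ = np.linalg.lstsq(features, targets, rcond=None)
    err = np.max(np.abs(features @ w - targets))
    return w, err

# Raw pixels plus a bias term: the best linear fit predicts 0.5 everywhere.
raw = np.hstack([X, np.ones((4, 1))])
_, raw_err = fit_linear(raw, y)

# Add one hand-crafted interaction feature (p1 * p2): now the fit is exact.
crafted = np.hstack([raw, (X[:, 0] * X[:, 1])[:, None]])
_, crafted_err = fit_linear(crafted, y)

print(f"max error, raw pixels:      {raw_err:.3f}")
print(f"max error, with p1*p2 term: {crafted_err:.3f}")
```

With two pixels, one interaction feature is enough. With thousands of pixels, the number of candidate interaction features is far too large to enumerate by hand.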
Neural Networks as Universal Approximators
A neural network with at least one hidden layer and nonlinear activation functions can approximate any continuous function on a compact domain to arbitrary precision, given enough hidden units. This is the universal approximation theorem.
Neural networks are function approximators that can learn their own features. They consist of layers of simple computations:
- Input layer: Takes the raw observation (e.g., pixel values)
- Hidden layers: Transform the input through learned weights and nonlinear activations
- Output layer: Produces the final value prediction
Each hidden layer learns features of the previous layer’s output. Early layers might learn simple patterns (edges, colors), while later layers combine these into complex concepts (objects, spatial relationships).
The key insight is that the features are learned from data, not hand-crafted. The network discovers what patterns are useful for predicting values.
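A neural network with one hidden layer computes:
$$\hat{V}(s; \theta) = \mathbf{w}_2^\top \, \sigma(W_1 s + \mathbf{b}_1) + b_2$$
where:
- $W_1, \mathbf{b}_1$ are the first layer weights and biases
- $\sigma$ is a nonlinear activation (e.g., ReLU: $\sigma(z) = \max(0, z)$)
- $\mathbf{w}_2, b_2$ are the output layer weights and biases
The hidden layer output $\mathbf{h} = \sigma(W_1 s + \mathbf{b}_1)$ serves as a learned feature vector. The output layer is a linear function of these features, just like linear function approximation, but the features themselves are learned.
For deeper networks with $L$ hidden layers:
$$\mathbf{h}_0 = s, \qquad \mathbf{h}_\ell = \sigma(W_\ell \mathbf{h}_{\ell-1} + \mathbf{b}_\ell) \quad \text{for } \ell = 1, \dots, L, \qquad \hat{V}(s; \theta) = \mathbf{w}_{L+1}^\top \mathbf{h}_L + b_{L+1}$$
Each layer transforms its input, progressively building more abstract representations.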
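A small experiment can make the "linear in the hidden features" view concrete. The sketch below is a simplification, not how networks are normally trained: it assumes the hidden-layer weights are fixed at random values and fits only the output layer by least squares, so it isolates the role of the hidden units as features. The target function (a sine on a bounded interval) and the name `fit_output_layer` are choices made for this illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Target: a smooth nonlinear function on a compact interval.
x = np.linspace(-np.pi, np.pi, 200)[:, None]   # shape (200, 1)
y = np.sin(x).ravel()

# One hidden layer of ReLU units with fixed random weights and biases.
max_units = 200
W1 = rng.normal(size=(1, max_units))
b1 = rng.uniform(-np.pi, np.pi, size=max_units)
H = np.maximum(0.0, x @ W1 + b1)               # hidden features, (200, max_units)

def fit_output_layer(n_units):
    """Fit the output weights and bias by least squares on the first n_units features."""
    features = np.hstack([H[:, :n_units], np.ones((len(x), 1))])
    coef, *_ = np.linalg.lstsq(features, y, rcond=None)
    return np.mean((features @ coef - y) ** 2)

errors = {n: fit_output_layer(n) for n in (2, 10, 50, 200)}
for n, mse in errors.items():
    print(f"{n:>3} hidden units -> training MSE {mse:.2e}")
```

Because each larger model reuses the hidden units of the smaller ones, the training error can only stay the same or shrink as units are added, which is the behaviour the universal approximation result leads us to expect. In a real network the hidden weights are also learned by backpropagation, which typically needs far fewer units than random features do.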
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleQNetwork(nn.Module):
    """
    A simple feedforward neural network for Q-value estimation.
    Takes a state as input and outputs Q-values for all actions.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        # Two hidden layers followed by a linear output layer
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_actions)

    def forward(self, state):
        """
        Forward pass: state -> Q-values for all actions.

        Args:
            state: Tensor of shape (batch_size, state_dim)
        Returns:
            Q-values of shape (batch_size, n_actions)
        """
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        q_values = self.fc3(x)
        return q_values

    def get_action(self, state, epsilon=0.0):
        """Select action using epsilon-greedy policy."""
        # Explore: pick a uniformly random action with probability epsilon
        if np.random.random() < epsilon:
            return np.random.randint(self.fc3.out_features)
        # Exploit: pick the action with the highest predicted Q-value
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            q_values = self.forward(state_tensor)
            return q_values.argmax(dim=1).item()


# Example: Create network for CartPole
state_dim = 4  # position, velocity, angle, angular velocity
n_actions = 2  # left, right
q_network = SimpleQNetwork(state_dim, n_actions, hidden_dim=64)

# Count parameters
n_params = sum(p.numel() for p in q_network.parameters())
print("Network architecture:")
print(q_network)
print(f"\nTotal parameters: {n_params:,}")

# Test forward pass
sample_state = torch.randn(1, state_dim)
q_values = q_network(sample_state)
print(f"\nSample Q-values: {q_values.detach().numpy()}")
Gradient Descent with Neural Networks
Training a neural network Q-function follows the same principle as linear function approximation: minimize the TD error using gradient descent.
The difference is that the gradient is now computed through backpropagation. Instead of the simple gradient from linear methods, we have a complex gradient flowing backward through the network layers.
Fortunately, modern deep learning frameworks compute these gradients automatically. We just need to define the loss and call .backward().
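For a neural network Q-function $Q(s, a; \theta)$, the semi-gradient TD update is:
$$\theta \leftarrow \theta + \alpha \left[ r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right] \nabla_\theta Q(s, a; \theta)$$
The gradient $\nabla_\theta Q(s, a; \theta)$ is computed via backpropagation through all layers.
Equivalently, we can frame this as minimizing the loss:
$$L(\theta) = \tfrac{1}{2} \left( y - Q(s, a; \theta) \right)^2$$
where the target $y = r + \gamma \max_{a'} Q(s', a'; \theta)$ is treated as a constant (semi-gradient).
The gradient of this loss is:
$$\nabla_\theta L(\theta) = -\left( y - Q(s, a; \theta) \right) \nabla_\theta Q(s, a; \theta)$$
so a gradient descent step $\theta \leftarrow \theta - \alpha \nabla_\theta L(\theta)$ recovers the update above. Dropping the factor $\tfrac{1}{2}$ only rescales the learning rate.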
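The claim that automatic differentiation produces the semi-gradient can be checked numerically. The sketch below assumes the squared-error loss with a factor of one half, L = 0.5 * (y - Q(s, a))^2, and a target y computed under `torch.no_grad()`. It compares the gradient autograd returns for the loss with the hand-written expression -(y - Q(s, a)) times the gradient of Q(s, a). The tiny network and the random transition are made up for the check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny Q-network: 4-dimensional state, 2 actions.
q_net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
gamma = 0.99

state = torch.randn(1, 4)
next_state = torch.randn(1, 4)
action, reward = 1, 1.0

# Target y is computed without tracking gradients (the "semi" in semi-gradient).
with torch.no_grad():
    y = reward + gamma * q_net(next_state).max()

# Gradient of L = 0.5 * (y - Q)^2 via autograd.
params = list(q_net.parameters())
q_sa = q_net(state)[0, action]
loss = 0.5 * (y - q_sa) ** 2
grad_loss = torch.autograd.grad(loss, params)

# Gradient of Q(s, a) alone, scaled by -(y - Q) as in the formula.
q_sa = q_net(state)[0, action]
grad_q = torch.autograd.grad(q_sa, params)
td_error = (y - q_sa).detach()
manual = [-td_error * g for g in grad_q]

max_diff = max((a - b).abs().max().item() for a, b in zip(grad_loss, manual))
print(f"largest difference between autograd and formula: {max_diff:.2e}")
```

The two agree up to floating-point error. If the target were computed with gradients enabled, autograd would also differentiate through y and the result would no longer be the semi-gradient.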
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class NeuralQLearning:
    """
    Q-learning with a neural network function approximator.

    WARNING: This naive implementation is unstable!
    It demonstrates the concepts but will likely fail to learn.
    See the DQN chapter for the stable version.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128, lr=1e-3, gamma=0.99):
        self.gamma = gamma
        self.n_actions = n_actions
        # Q-network
        self.q_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

    def get_q_values(self, state):
        """Get Q-values for a state as a NumPy array of shape (n_actions,)."""
        with torch.no_grad():
            state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
            return self.q_network(state_tensor).squeeze(0).numpy()

    def select_action(self, state, epsilon=0.1):
        """Epsilon-greedy action selection."""
        if np.random.random() < epsilon:
            return np.random.randint(self.n_actions)
        return int(self.get_q_values(state).argmax())

    def update(self, state, action, reward, next_state, done):
        """
        Single-step Q-learning update.
        This is the naive approach - it has stability issues!
        """
        # Convert to tensors
        state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
        next_state_t = torch.as_tensor(next_state, dtype=torch.float32).unsqueeze(0)

        # Current Q-value for the action taken
        current_q = self.q_network(state_t)[0, action]

        # Compute target (treating it as constant)
        with torch.no_grad():
            if done:
                target = reward
            else:
                next_q_values = self.q_network(next_state_t)
                target = reward + self.gamma * next_q_values.max()

        # TD error loss
        loss = (current_q - target) ** 2

        # Gradient descent step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()


# Example usage (will likely be unstable)
agent = NeuralQLearning(state_dim=4, n_actions=2)

# Simulated experience
state = np.array([0.1, 0.2, 0.05, 0.1])
action = 1
reward = 1.0
next_state = np.array([0.15, 0.25, 0.03, 0.08])
done = False

print(f"Q-values before: {agent.get_q_values(state)}")
loss = agent.update(state, action, reward, next_state, done)
print(f"TD loss: {loss:.4f}")
print(f"Q-values after: {agent.get_q_values(state)}")
Convolutional Networks for Images
A CNN is a neural network architecture designed for image processing. It uses convolutional layers that apply learnable filters across the input, detecting local patterns regardless of their position in the image.
For visual observations like Atari games, we need a network architecture suited for images. Fully connected networks would require too many parameters and would not exploit the structure of images.
Convolutional layers solve this by:
- Parameter sharing: The same filter is applied at every location, dramatically reducing parameters
- Local connectivity: Each neuron only looks at a small region, focusing on local patterns
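- Translation equivariance: The same pattern is detected wherever it appears, because the same filter slides over every position (pooling or striding then adds some invariance to small shifts)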
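The parameter savings from sharing one filter across all locations can be checked with simple arithmetic. The sketch below assumes a layer that maps 4 stacked 84x84 frames to 32 feature maps of size 20x20 using 8x8 filters, and compares it with a fully connected layer that has the same number of inputs and outputs. The helper names `conv_params` and `dense_params` are hypothetical.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus one bias per output channel for a square-kernel conv layer."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels

def dense_params(n_inputs, n_outputs):
    """Weights plus one bias per output unit for a fully connected layer."""
    return n_inputs * n_outputs + n_outputs

# Assumed layer shape: 4 stacked 84x84 frames in, 32 feature maps of 20x20 out
# (what an 8x8 filter with stride 4 produces).
conv = conv_params(in_channels=4, out_channels=32, kernel_size=8)
dense = dense_params(n_inputs=4 * 84 * 84, n_outputs=32 * 20 * 20)

print(f"convolutional layer: {conv:,} parameters")
print(f"fully connected layer with the same input/output sizes: {dense:,} parameters")
print(f"ratio: about {dense // conv:,}x")
```

The convolutional layer needs a few thousand parameters where the dense layer would need hundreds of millions, which is why fully connected networks are impractical on raw frames.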
A typical CNN for RL:
- Conv layers: Detect patterns (edges, textures, objects)
- Pooling/striding: Reduce spatial dimensions
- Fully connected layers: Combine patterns into action values
The network learns a hierarchy of features: pixels to edges to shapes to objects to action values.
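A 2D convolution operation computes:
$$(I * K)(i, j) = \sum_{m} \sum_{n} I(i + m, j + n) \, K(m, n)$$
where $I$ is the input image and $K$ is a learnable filter (kernel). Strictly speaking this is a cross-correlation, which is what deep learning frameworks implement under the name convolution.
For a CNN layer with input $X \in \mathbb{R}^{C_{\text{in}} \times H \times W}$ and filter $K \in \mathbb{R}^{C_{\text{out}} \times C_{\text{in}} \times k \times k}$:
$$Y_{c, i, j} = \sigma\!\left( \sum_{c'=1}^{C_{\text{in}}} \sum_{m=0}^{k-1} \sum_{n=0}^{k-1} K_{c, c', m, n} \, X_{c', \, i + m, \, j + n} + b_c \right)$$
The output $Y$ is then passed to the next layer. With stride $s$ (and no padding), the filter is applied only at every $s$-th position, so the output spatial dimensions are reduced:
$$H_{\text{out}} = \left\lfloor \frac{H - k}{s} \right\rfloor + 1, \qquad W_{\text{out}} = \left\lfloor \frac{W - k}{s} \right\rfloor + 1$$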
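To make the operation and the size reduction concrete, here is a naive single-channel implementation in NumPy. It is a teaching sketch with no padding and explicit loops, not an efficient implementation, and the function names are hypothetical. It first traces an 84-pixel side length through three assumed (kernel, stride) settings, then applies a small vertical-edge filter to a synthetic image.

```python
import numpy as np

def conv_output_size(size, kernel_size, stride):
    """Output length along one spatial axis for a convolution without padding."""
    return (size - kernel_size) // stride + 1

def conv2d_single(image, kernel, stride=1):
    """Naive single-channel 2D convolution (cross-correlation), no padding."""
    k = kernel.shape[0]
    out_h = conv_output_size(image.shape[0], k, stride)
    out_w = conv_output_size(image.shape[1], k, stride)
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i * stride:i * stride + k, j * stride:j * stride + k]
            out[i, j] = np.sum(patch * kernel)
    return out

# Spatial sizes through three assumed layers: (kernel, stride) = (8,4), (4,2), (3,1).
size = 84
sizes = [size]
for kernel_size, stride in [(8, 4), (4, 2), (3, 1)]:
    size = conv_output_size(size, kernel_size, stride)
    sizes.append(size)
print("spatial sizes:", sizes)

# A vertical-edge filter responds only where brightness changes left to right.
image = np.zeros((6, 6))
image[:, 3:] = 1.0
edge_filter = np.array([[-1.0, 1.0], [-1.0, 1.0]])
response = conv2d_single(image, edge_filter)
print(response)
```

The response is nonzero only in the column where the dark region meets the bright region, and it would move with the edge if the edge were shifted, since the same filter is applied at every position.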
import torch
import torch.nn as nn


class AtariQNetwork(nn.Module):
    """
    CNN architecture for Atari-like visual inputs.
    This is similar to the architecture from the DQN Nature paper.

    Input: 4 stacked grayscale frames of 84x84 pixels
    Output: Q-values for each action
    """

    def __init__(self, n_actions, in_channels=4):
        super().__init__()
        # Convolutional layers
        self.conv = nn.Sequential(
            # Input: (batch, 4, 84, 84)
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            # Output: (batch, 32, 20, 20)
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            # Output: (batch, 64, 9, 9)
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            # Output: (batch, 64, 7, 7)
        )
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (batch, 4, 84, 84)
               4 stacked grayscale frames, pixel values in [0, 1]
        Returns:
            Q-values of shape (batch, n_actions)
        """
        features = self.conv(x)
        q_values = self.fc(features)
        return q_values


# Example: Create network for Atari
n_actions = 4  # Typical Atari game has ~4-18 actions
q_net = AtariQNetwork(n_actions=n_actions)

# Count parameters
n_params = sum(p.numel() for p in q_net.parameters())
print(f"Total parameters: {n_params:,}")

# Test forward pass on random frames with pixel values in [0, 1)
batch_size = 8
frames = torch.rand(batch_size, 4, 84, 84)
q_values = q_net(frames)
print(f"\nInput shape: {frames.shape}")
print(f"Output shape: {q_values.shape}")
print(f"Sample Q-values: {q_values[0].detach().numpy()}")
The Deadly Triad Revisited
Neural networks combined with Q-learning exhibit the deadly triad: function approximation, bootstrapping, and off-policy learning. Naive neural network Q-learning is unstable and often diverges.
We saw in the linear approximation section that the combination of function approximation, bootstrapping, and off-policy learning can cause divergence. Neural networks make this worse for several reasons:
- High capacity: Neural networks can fit complex patterns, including spurious correlations
- Correlated updates: Consecutive states in an episode are similar, leading to correlated gradients
- Non-stationary targets: The target changes as the network updates
- Extrapolation errors: The network may predict wildly incorrect values for unseen states
The result is training instability. Q-values oscillate or explode. The agent forgets what it learned. Performance degrades catastrophically.
The instability can be understood through the lens of bootstrapping error amplification.
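Consider the TD target:
$$y = r + \gamma \max_{a'} Q(s', a'; \theta)$$
If $\max_{a'} Q(s', a'; \theta)$ overestimates the true value $\max_{a'} Q^*(s', a')$ by $\epsilon$, the target becomes:
$$y = r + \gamma \left( \max_{a'} Q^*(s', a') + \epsilon \right) = y^* + \gamma \epsilon$$
where $y^* = r + \gamma \max_{a'} Q^*(s', a')$ is the target computed from the true values.
This overestimation propagates and can compound across updates. With neural networks, the error is often correlated across similar states, amplifying the effect.
The feedback loop is:
- Overestimate $Q(s', a'; \theta)$ for some state-action pair
- Use this overestimate as target for the states that lead to $s'$
- Those states become overestimated
- The error propagates and grows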
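One common source of the initial overestimate is the max operator itself. The sketch below is an illustration under simplifying assumptions that are not from the text above: every action has a true value of exactly zero, and the estimates are the true values plus independent zero-mean Gaussian noise. Each individual estimate is unbiased, yet the maximum over actions is biased upward, and that bias is what enters the TD target.

```python
import numpy as np

rng = np.random.default_rng(0)

n_actions = 4
n_trials = 10_000
gamma = 0.99

# Assumption for illustration: every action's true value is exactly 0, and the
# estimates are the true values plus zero-mean Gaussian noise.
true_q = np.zeros(n_actions)
noisy_q = true_q + rng.normal(scale=1.0, size=(n_trials, n_actions))

# Individual estimates are unbiased...
mean_estimate = noisy_q.mean()
# ...but the max over actions, which the TD target uses, is biased upward.
mean_max = noisy_q.max(axis=1).mean()

print(f"average Q estimate:             {mean_estimate:+.3f}")
print(f"average max over actions:       {mean_max:+.3f}  (true max is 0)")
print(f"extra amount added to a target: {gamma * mean_max:+.3f}")
```

Even with no systematic error in any single estimate, the target is inflated on average, and each bootstrapped update passes part of that inflation on to predecessor states.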
def demonstrate_instability():
    """
    Summarize why naive neural Q-learning is unstable.

    Key issues:
    1. Correlated samples from sequential experience
    2. Moving targets as network updates
    3. No mechanism to break these correlations
    """
    print("=" * 60)
    print("Why Naive Neural Q-Learning Fails")
    print("=" * 60)

    print("\n1. CORRELATED SAMPLES")
    print("-" * 40)
    print("In a game, consecutive frames are almost identical.")
    print("Training on frame_t, frame_t+1, frame_t+2, ...")
    print("means gradients are highly correlated.")
    print("This violates the i.i.d. assumption of SGD,")
    print("causing the network to overfit to recent states.")

    print("\n2. NON-STATIONARY TARGETS")
    print("-" * 40)
    print("Target: y = r + gamma * max_a' Q(s', a'; theta)")
    print("As theta updates, the target changes!")
    print("We're chasing a moving target.")
    print("Like trying to hit a bullseye that moves")
    print("every time you throw.")

    print("\n3. FEEDBACK LOOPS")
    print("-" * 40)
    print("Overestimate Q(s, a)")
    print(" -> Higher target for states that lead to s")
    print(" -> Those states get overestimated")
    print(" -> Their predecessors get overestimated")
    print(" -> Errors can compound across updates!")

    print("\n4. CATASTROPHIC FORGETTING")
    print("-" * 40)
    print("Agent spends time in one part of state space")
    print("Network overfits to those states")
    print("Forgets about previously learned states")
    print("Performance oscillates wildly")

    print("\n" + "=" * 60)
    print("SOLUTION: Deep Q-Networks (DQN)")
    print("=" * 60)
    print("\nDQN addresses these issues with two key innovations:")
    print("1. Experience Replay: Store and sample past experiences")
    print(" -> Breaks correlation between consecutive updates")
    print("2. Target Network: Separate network for computing targets")
    print(" -> Stabilizes the target during learning")
    print("\nSee the DQN chapter for the full algorithm!")


demonstrate_instability()
The Path to DQN
Neural networks give us the power to learn from raw pixels and discover complex features. But they also introduce instability that makes naive Q-learning fail.
The breakthrough came in 2013 when DeepMind introduced Deep Q-Networks (DQN). DQN is still Q-learning with a neural network, but with two crucial additions:
- Experience Replay: Store transitions in a buffer and sample randomly for training
  - Breaks the correlation between consecutive samples
  - Each experience can be used multiple times
- Target Network: Use a separate, slowly-updated network for computing targets
  - Stabilizes the target during learning
  - Prevents the “chasing a moving target” problem
These two techniques address the deadly triad well enough for DQN to reach human-level or better play on many Atari games from raw pixels.
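The first of these ideas fits in a few lines. The following is a minimal sketch of a replay buffer, assuming uniform sampling and a fixed capacity. The class and field names (`ReplayBuffer`, `Transition`) are hypothetical, and the integer "states" are placeholders for real observations.

```python
import random
from collections import deque, namedtuple

# Hypothetical names for this sketch; not a fixed API.
Transition = namedtuple("Transition", ["state", "action", "reward", "next_state", "done"])

class ReplayBuffer:
    """Fixed-capacity store of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # deque with maxlen silently drops the oldest transition when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        self.buffer.append(Transition(state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Return batch_size distinct transitions chosen uniformly at random."""
        return random.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)

random.seed(0)
buffer = ReplayBuffer(capacity=1000)
for t in range(50):                      # fake sequential experience
    buffer.push(state=t, action=t % 2, reward=1.0, next_state=t + 1, done=False)

batch = buffer.sample(batch_size=8)
print("sampled time steps:", [tr.state for tr in batch])
```

The sampled time steps are scattered across the stored history rather than consecutive, which is what breaks the correlation between successive updates.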
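DQN modifies the Q-learning loss to:
$$L(\theta) = \mathbb{E}_{(s, a, r, s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 \right]$$
where:
- $\mathcal{D}$ is the replay buffer of past experiences
- $\theta^-$ is the target network parameters, updated infrequently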
The replay buffer provides i.i.d.-like samples (breaking correlation). The target network provides a stable target (breaking the feedback loop).
Together, these make neural network Q-learning stable enough to work in practice.
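As a rough sketch of how the two pieces combine in one training step, the code below computes a mean squared TD loss on a minibatch, using a frozen copy of the network for the targets. Several things here are assumptions made for illustration: the minibatch is random data standing in for samples from a replay buffer, the network sizes are arbitrary, and plain squared error is used to match the loss described in the text. The full algorithm, with its exact loss and update schedule, belongs to the DQN chapter.

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(0)

state_dim, n_actions, batch_size, gamma = 4, 2, 32, 0.99

q_net = nn.Sequential(nn.Linear(state_dim, 64), nn.ReLU(), nn.Linear(64, n_actions))
target_net = copy.deepcopy(q_net)          # target parameters start as a copy
for p in target_net.parameters():
    p.requires_grad_(False)                # the target network is never trained directly
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

# Stand-in for a minibatch sampled uniformly from a replay buffer (random data).
states = torch.randn(batch_size, state_dim)
actions = torch.randint(0, n_actions, (batch_size,))
rewards = torch.randn(batch_size)
next_states = torch.randn(batch_size, state_dim)
dones = torch.zeros(batch_size)            # 1.0 where the episode ended

def dqn_loss():
    # Q-values of the online network for the actions actually taken.
    q_sa = q_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    # Targets use the frozen target network; no bootstrap after terminal states.
    with torch.no_grad():
        max_next_q = target_net(next_states).max(dim=1).values
        targets = rewards + gamma * (1.0 - dones) * max_next_q
    return ((targets - q_sa) ** 2).mean()

loss = dqn_loss()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Periodic hard update: copy the online parameters into the target network.
target_net.load_state_dict(q_net.state_dict())
print(f"minibatch loss: {loss.item():.4f}")
```

In a training loop the final copy would run only every fixed number of steps, so the targets stay put between copies while the online network learns.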
Summary
Neural networks enable function approximation without hand-crafted features:
- Universal approximation: With enough hidden units, neural networks can approximate any continuous function on a compact domain
- Learned features: The network discovers useful patterns from raw data
- Convolutional networks: Efficient architecture for image observations
- Gradient-based learning: Backpropagation computes gradients automatically
However, neural networks combined with Q-learning are unstable:
- Correlated samples violate SGD assumptions
- Moving targets create feedback loops
- High capacity enables overfitting to recent experience
The solution is DQN, which adds:
- Experience replay to break sample correlation
- Target networks to stabilize learning
The next chapter presents DQN in full detail, showing how these innovations enabled the deep RL revolution.
The gap between “neural network Q-learning” and “DQN” is two ideas, but combining those two ideas made the difference between failure and human-level or better performance on many Atari games.