Experience Replay
Neural networks learn best from independent, identically distributed (i.i.d.) samples. But reinforcement learning generates highly correlated data: consecutive states are nearly identical, and the agent spends long stretches in similar parts of the state space.
Experience replay solves this by storing transitions in a buffer and sampling randomly for training. This simple technique transforms unstable neural network training into reliable learning.
The Correlation Problem
Sample correlation occurs when consecutive training samples are not independent. In RL, this happens because states at time $t$ and $t+1$ are nearly identical, and the agent’s trajectory visits similar states in sequence.
Imagine learning to play a video game. In one episode, you spend 10 minutes walking through a corridor before finding the exit. If you train your network on each frame as you go:
- Frame 1: Corridor, left side
- Frame 2: Corridor, left side (almost identical)
- Frame 3: Corridor, left side
- … (1000 more corridor frames)
- Frame 1001: Exit door!
Your network will receive 1000 gradient updates about “being in a corridor.” It will become very confident about corridor Q-values, while completely forgetting what it learned about other game areas.
This is catastrophic forgetting: the network overfits to recent experience and loses previously learned knowledge.
The updates are also correlated: frame $t$ and frame $t+1$ are nearly identical, so their gradients point in almost the same direction. This violates the assumption of stochastic gradient descent that samples are i.i.d.
Standard SGD assumes samples are drawn i.i.d. from some distribution $p$:

$$x_1, x_2, \ldots, x_N \overset{\text{i.i.d.}}{\sim} p$$

But in online RL, samples come from a trajectory:

$$x_t = (s_t, a_t, r_t, s_{t+1}), \quad t = 0, 1, 2, \ldots$$

These samples are correlated in two ways:
- Temporal correlation: $s_{t+1}$ is very similar to $s_t$
- Policy correlation: The distribution of states depends on the current policy

The sample covariance is high:

$$\mathrm{Cov}(x_t, x_{t+1}) \gg 0$$

This correlation causes the gradient variance to increase, potentially leading to divergence or slow convergence.
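To make the problem concrete, here is a small sketch (the random-walk trajectory is illustrative, not an actual environment): consecutive samples from a trajectory are strongly correlated, while the same samples in random order — which is what replay sampling delivers — are not.

```python
import numpy as np

rng = np.random.default_rng(0)

# Simulate a trajectory as a slow random walk: consecutive states barely differ
states = np.cumsum(0.01 * rng.standard_normal(10_000))

# Lag-1 correlation of consecutive (online) samples
sequential = np.corrcoef(states[:-1], states[1:])[0, 1]

# Same data after random shuffling, as uniform replay sampling would see it
shuffled = rng.permutation(states)
decorrelated = np.corrcoef(shuffled[:-1], shuffled[1:])[0, 1]

print(f"sequential lag-1 correlation: {sequential:.3f}")
print(f"shuffled   lag-1 correlation: {decorrelated:.3f}")
```

The sequential correlation is close to 1, while the shuffled correlation is close to 0 — shuffling alone restores the near-independence that SGD expects.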
The Replay Buffer
A replay buffer (or experience replay memory) is a data structure that stores past transitions $(s_t, a_t, r_t, s_{t+1}, d_t)$. During training, transitions are sampled randomly from the buffer rather than used immediately.
The solution is elegant: instead of learning from each transition immediately, we store it in a buffer. When it is time to train, we sample a random batch from the buffer.
Think of it like a studying strategy. Instead of cramming the same topic for hours, you:
- Keep flashcards with different topics
- Shuffle them randomly
- Study a random mix each session
This prevents overfitting to any single topic and keeps all knowledge fresh.
The replay buffer works the same way:
- Store each transition (state, action, reward, next_state, done) as an “experience”
- When training, sample a random batch of 32-64 experiences
- The batch contains diverse states from different times and situations
The replay buffer is typically implemented as a circular buffer with capacity $N$:

$$\mathcal{D} = \{(s_i, a_i, r_i, s'_i, d_i)\}_{i=1}^{N}$$

where $d_i \in \{0, 1\}$ indicates whether $s'_i$ is terminal.

During training, we sample a minibatch uniformly:

$$\{(s_j, a_j, r_j, s'_j, d_j)\}_{j=1}^{B} \sim \mathrm{Uniform}(\mathcal{D})$$

where $B$ is the batch size (typically 32).

The expected gradient becomes:

$$\mathbb{E}_{(s,a,r,s',d) \sim \mathrm{Uniform}(\mathcal{D})}\left[\nabla_\theta L(s, a, r, s', d; \theta)\right]$$

This expectation is over the replay buffer distribution, which approximates i.i.d. samples from the agent’s experience.
import numpy as np
from collections import deque
import random
class ReplayBuffer:
"""
Circular buffer for storing and sampling transitions.
Stores transitions as (state, action, reward, next_state, done).
When full, oldest transitions are overwritten.
"""
def __init__(self, capacity):
"""
Args:
capacity: Maximum number of transitions to store
"""
self.buffer = deque(maxlen=capacity)
def push(self, state, action, reward, next_state, done):
"""Add a transition to the buffer."""
self.buffer.append((state, action, reward, next_state, done))
def sample(self, batch_size):
"""
Sample a random batch of transitions.
Args:
batch_size: Number of transitions to sample
Returns:
Tuple of (states, actions, rewards, next_states, dones)
Each is a numpy array
"""
batch = random.sample(self.buffer, batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
return (
np.array(states, dtype=np.float32),
np.array(actions, dtype=np.int64),
np.array(rewards, dtype=np.float32),
np.array(next_states, dtype=np.float32),
np.array(dones, dtype=np.float32)
)
def __len__(self):
return len(self.buffer)
def is_ready(self, batch_size):
"""Check if buffer has enough samples for a batch."""
return len(self.buffer) >= batch_size
# Example usage
buffer = ReplayBuffer(capacity=10000)
# Simulate adding experiences
for i in range(100):
state = np.random.randn(4)
action = np.random.randint(0, 2)
reward = np.random.randn()
next_state = np.random.randn(4)
done = np.random.random() < 0.1
buffer.push(state, action, reward, next_state, done)
print(f"Buffer size: {len(buffer)}")
# Sample a batch
states, actions, rewards, next_states, dones = buffer.sample(32)
print(f"\nBatch shapes:")
print(f" States: {states.shape}")
print(f" Actions: {actions.shape}")
print(f" Rewards: {rewards.shape}")
print(f" Next states: {next_states.shape}")
print(f"  Dones: {dones.shape}")

Benefits of Experience Replay
Experience replay provides multiple benefits beyond breaking correlations:
1. Data Efficiency: Each experience can be used multiple times. Instead of learning from a transition once and discarding it, we can sample it again and again. A transition might be sampled 10 times before being evicted from the buffer.
2. Diverse Batches: A single batch might contain:
- A transition from 5 minutes ago
- A transition from the current episode
- A transition from a completely different game situation
This diversity forces the network to learn general patterns, not just recent ones.
3. Stable Learning: Random sampling creates a relatively stationary distribution. Even as the policy changes, old experiences in the buffer provide a stabilizing anchor.
4. Forgetting Prevention: The buffer contains experiences from the agent’s entire history (up to capacity). Even if the current policy never visits certain states, the network still trains on those states.
Data efficiency: With a buffer of size $N$ and uniform sampling at batch size $B$ (one update per environment step), a transition appears in a batch with probability $B/N$ per update and stays in the buffer for $N$ updates, so each transition is sampled on average:

$$\frac{B}{N} \times N = B \ \text{times}$$

For DQN with a buffer of 1 million and 10 million updates at batch size 32:

$$\frac{10^7 \times 32}{10^7} = 32 \ \text{uses per transition}$$

Distribution stability: Let $d_t$ be the state distribution at time $t$. With online learning, the sample at step $t$ is drawn from the current distribution:

$$x_t \sim d_t$$

With experience replay, samples are drawn from a mixture over the buffer’s contents:

$$x \sim \frac{1}{N} \sum_{k=t-N+1}^{t} d_k$$

The replay distribution changes slowly, even as $d_t$ changes rapidly with the policy.
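This stabilizing effect can be sketched with a toy drift model (the drifting Gaussian is hypothetical, standing in for the changing state distribution): while the online distribution moves at every step, the mean of a sliding replay buffer lags behind it and smooths out the changes.

```python
import numpy as np
from collections import deque

rng = np.random.default_rng(0)
buffer = deque(maxlen=2000)
online_means, replay_means = [], []

for t in range(10_000):
    # The "policy" drifts: states are drawn around a slowly moving mean (t / 1000)
    x = rng.normal(loc=t / 1000.0, scale=0.5)
    buffer.append(x)
    if t % 500 == 499:
        online_means.append(t / 1000.0)       # current online distribution mean
        replay_means.append(np.mean(buffer))  # mean of the replay distribution

# The replay mean lags the online mean: it averages the last 2000 steps
for o, r in zip(online_means[-3:], replay_means[-3:]):
    print(f"online mean = {o:5.2f}   replay mean = {r:5.2f}")
```

The replay mean consistently trails the online mean by roughly half the buffer window, illustrating how the training distribution changes more slowly than the policy itself.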
import numpy as np
import random
def demonstrate_data_efficiency():
"""
Compare data efficiency: online vs replay.
"""
n_transitions = 10000
buffer_size = 5000
batch_size = 32
n_updates = 20000
# Track how many times each transition is used
online_uses = np.zeros(n_transitions)
replay_uses = np.zeros(n_transitions)
# Online learning: each transition used exactly once
for i in range(min(n_transitions, n_updates)):
online_uses[i] = 1
# Replay: sample from buffer
buffer = ReplayBuffer(capacity=buffer_size)
for i in range(n_transitions):
# Add transition to buffer
buffer.push(np.zeros(4), 0, 0, np.zeros(4), False)
# Simulate training (after warmup)
if len(buffer) >= batch_size:
# Sample batch - track indices
indices = random.sample(range(len(buffer)), batch_size)
for idx in indices:
# Map buffer index to transition index
actual_idx = max(0, i - len(buffer) + 1) + idx
if actual_idx < n_transitions:
replay_uses[actual_idx] += 1
print("Data Efficiency Comparison")
print("=" * 40)
print(f"Online learning:")
print(f" Mean uses per transition: {online_uses.mean():.2f}")
print(f" Max uses: {online_uses.max():.0f}")
print(f"Replay buffer:")
print(f" Mean uses per transition: {replay_uses.mean():.2f}")
print(f" Max uses: {replay_uses.max():.0f}")
demonstrate_data_efficiency()

Buffer Size Considerations
The buffer size is a crucial hyperparameter. Too small, and you lose the benefits of diverse sampling. Too large, and you waste memory and train on outdated experiences.
Too small (e.g., 1,000 transitions):
- Buffer fills quickly and old experiences are evicted
- Less diversity in sampled batches
- Risk of catastrophic forgetting returns
Too large (e.g., 10 million transitions):
- Requires significant memory (especially for image observations)
- Old experiences may be from a very different policy
- Learning from outdated transitions can slow convergence
Just right (typically 100K - 1M for Atari):
- Good balance of diversity and relevance
- Old enough for decorrelation
- New enough to be useful for current policy
For Atari games with 84x84 grayscale frames, 4 stacked frames per state, and float32 storage, a buffer of 1 million transitions requires:

$$10^6 \times 4 \times 84 \times 84 \times 4 \ \text{bytes} \approx 113 \ \text{GB}$$

This is why efficient replay buffer implementations are important.
import numpy as np
class EfficientReplayBuffer:
"""
Memory-efficient replay buffer for image observations.
Stores images as uint8 instead of float32 to save 4x memory.
Converts to float32 only when sampling.
"""
def __init__(self, capacity, state_shape, n_frames=4):
"""
Args:
capacity: Maximum transitions
state_shape: Shape of a single frame (H, W)
n_frames: Number of stacked frames
"""
self.capacity = capacity
self.n_frames = n_frames
# Pre-allocate arrays (uint8 for memory efficiency)
# Store individual frames, reconstruct stacks when sampling
total_frames = capacity + n_frames - 1
self.frames = np.zeros((total_frames, *state_shape), dtype=np.uint8)
self.actions = np.zeros(capacity, dtype=np.int32)
self.rewards = np.zeros(capacity, dtype=np.float32)
self.dones = np.zeros(capacity, dtype=np.bool_)
self.position = 0
self.size = 0
self.frame_position = 0
def push(self, frame, action, reward, done):
"""
Add a transition (storing only the newest frame).
Args:
frame: Single frame (H, W), uint8
action: Action taken
reward: Reward received
done: Episode ended
"""
# Store the frame
self.frames[self.frame_position] = frame
self.actions[self.position] = action
self.rewards[self.position] = reward
self.dones[self.position] = done
# Update positions
self.position = (self.position + 1) % self.capacity
self.frame_position = (self.frame_position + 1) % (self.capacity + self.n_frames - 1)
self.size = min(self.size + 1, self.capacity)
def _get_state(self, idx):
"""Reconstruct stacked state for index."""
# Get n_frames ending at idx
frames = []
for i in range(self.n_frames):
frame_idx = (idx - self.n_frames + 1 + i) % (self.capacity + self.n_frames - 1)
frames.append(self.frames[frame_idx])
return np.stack(frames)
def sample(self, batch_size):
"""Sample batch and convert to float32."""
        # Sample valid indices: need n_frames of history and a valid next state
        valid_range = range(self.n_frames - 1, self.size - 1)
        indices = np.random.choice(list(valid_range), batch_size, replace=False)
# Reconstruct states
states = np.array([self._get_state(i) for i in indices], dtype=np.float32) / 255.0
next_states = np.array([self._get_state(i+1) for i in indices], dtype=np.float32) / 255.0
return (
states,
self.actions[indices],
self.rewards[indices],
next_states,
self.dones[indices]
)
def __len__(self):
return self.size
@property
def memory_usage_mb(self):
"""Estimate memory usage in MB."""
return (self.frames.nbytes + self.actions.nbytes +
self.rewards.nbytes + self.dones.nbytes) / (1024**2)
# Compare memory usage
naive_buffer_mb = 1_000_000 * 4 * 84 * 84 * 4 / (1024**2) # float32
efficient_buffer = EfficientReplayBuffer(1_000_000, (84, 84), n_frames=4)
efficient_mb = efficient_buffer.memory_usage_mb
print(f"Memory comparison for 1M transitions:")
print(f" Naive (float32 stacked frames): {naive_buffer_mb:.0f} MB")
print(f" Efficient (uint8 individual frames): {efficient_mb:.0f} MB")
print(f"  Savings: {naive_buffer_mb / efficient_mb:.1f}x")

Integration with Training
In DQN, experience replay integrates into the training loop as follows:
- Collect experience: Take action in environment, store transition
- Check buffer: Only train when buffer has enough samples
- Sample batch: Draw random batch from buffer
- Compute loss: Use sampled transitions for TD error
- Update network: Gradient descent on the loss
The key insight is that data collection and training are decoupled. We can take many environment steps between training updates, or many training updates between environment steps.
DQN typically does one training step per environment step once the buffer is ready.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class DQNWithReplay:
"""
DQN agent with experience replay.
This demonstrates how replay integrates into training.
Target network is omitted for clarity (see next section).
"""
def __init__(self, state_dim, n_actions, buffer_size=100000,
batch_size=32, gamma=0.99, lr=1e-4):
self.n_actions = n_actions
self.batch_size = batch_size
self.gamma = gamma
# Q-network
self.q_network = nn.Sequential(
nn.Linear(state_dim, 128),
nn.ReLU(),
nn.Linear(128, 128),
nn.ReLU(),
nn.Linear(128, n_actions)
)
self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
# Replay buffer
self.buffer = ReplayBuffer(capacity=buffer_size)
# Metrics
self.losses = []
def select_action(self, state, epsilon):
"""Epsilon-greedy action selection."""
if np.random.random() < epsilon:
return np.random.randint(self.n_actions)
with torch.no_grad():
state_t = torch.FloatTensor(state).unsqueeze(0)
q_values = self.q_network(state_t)
return q_values.argmax(dim=1).item()
def store_transition(self, state, action, reward, next_state, done):
"""Store transition in replay buffer."""
self.buffer.push(state, action, reward, next_state, done)
def train_step(self):
"""
Perform one training step using experience replay.
Returns:
loss value, or None if buffer not ready
"""
# Check if buffer has enough samples
if not self.buffer.is_ready(self.batch_size):
return None
# Sample random batch
states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)
# Convert to tensors
states_t = torch.FloatTensor(states)
actions_t = torch.LongTensor(actions)
rewards_t = torch.FloatTensor(rewards)
next_states_t = torch.FloatTensor(next_states)
dones_t = torch.FloatTensor(dones)
# Compute current Q-values
current_q = self.q_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
# Compute target Q-values
with torch.no_grad():
next_q = self.q_network(next_states_t).max(dim=1)[0]
target_q = rewards_t + self.gamma * next_q * (1 - dones_t)
# Compute loss
loss = nn.MSELoss()(current_q, target_q)
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
self.losses.append(loss.item())
return loss.item()
# Training loop example
def train_dqn(env, agent, n_episodes=100, warmup_steps=1000):
"""
Train DQN agent with experience replay.
"""
epsilon = 1.0
epsilon_decay = 0.995
epsilon_min = 0.01
episode_rewards = []
steps = 0
for episode in range(n_episodes):
state = env.reset()
episode_reward = 0
done = False
while not done:
# Select action
action = agent.select_action(state, epsilon)
# Take action in environment
next_state, reward, done, _ = env.step(action)
# Store in replay buffer
agent.store_transition(state, action, reward, next_state, done)
# Train (after warmup)
if steps >= warmup_steps:
loss = agent.train_step()
state = next_state
episode_reward += reward
steps += 1
# Decay epsilon
epsilon = max(epsilon_min, epsilon * epsilon_decay)
episode_rewards.append(episode_reward)
if (episode + 1) % 10 == 0:
avg_reward = np.mean(episode_rewards[-10:])
print(f"Episode {episode+1}, Avg Reward: {avg_reward:.2f}, "
f"Epsilon: {epsilon:.3f}, Buffer: {len(agent.buffer)}")
    return episode_rewards

Summary
Experience replay is fundamental to making DQN work:
- Breaking correlations: Random sampling provides approximately i.i.d. data for training
- Data efficiency: Each transition can be used multiple times
- Stability: The replay distribution changes slowly, providing a stable learning signal
- Preventing forgetting: Old experiences keep the network’s knowledge fresh
Key design choices:
- Buffer size: 100K - 1M for Atari (balance diversity vs relevance)
- Batch size: 32 is typical
- Warmup: Fill buffer with random actions before training starts
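The warmup step above can be sketched as a short loop that fills the buffer with random-action transitions before any gradient updates run. The environment here is a toy stand-in (real code would call `env.step(action)`); the shapes and hyperparameters are illustrative.

```python
import numpy as np
from collections import deque

rng = np.random.default_rng(0)
buffer = deque(maxlen=10_000)
warmup_steps = 1_000

# Toy stand-in for an environment: 4-dim states, 2 actions
state = rng.standard_normal(4)
for _ in range(warmup_steps):
    action = int(rng.integers(0, 2))      # uniform random action, no Q-network needed
    next_state = rng.standard_normal(4)   # stand-in for env.step(action)
    reward = float(rng.standard_normal())
    done = bool(rng.random() < 0.05)
    buffer.append((state, action, reward, next_state, done))
    # Start a fresh toy "episode" on done, otherwise continue from next_state
    state = rng.standard_normal(4) if done else next_state

print(f"Buffer after warmup: {len(buffer)} transitions")
```

Only after this loop completes does the agent start calling `train_step`, so the first sampled batches are already diverse.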
Experience replay addresses one half of the instability problem in neural network Q-learning. The next section covers the other half: target networks that prevent the “chasing a moving target” problem.
Experience replay can be combined with other sampling strategies. Prioritized Experience Replay samples important transitions more often, significantly improving learning efficiency. We will cover this in the DQN Improvements chapter.