Deep Reinforcement Learning • Part 2 of 4

Experience Replay

Breaking correlations through random sampling

Neural networks learn best from independent, identically distributed (i.i.d.) samples. But reinforcement learning generates highly correlated data: consecutive states are nearly identical, and the agent spends long stretches in similar parts of the state space.

Experience replay addresses this by storing transitions in a buffer and sampling from it at random for training. This simple technique makes neural network training far more stable, although it does not guarantee convergence on its own.

The Correlation Problem

📖Sample Correlation

Sample correlation occurs when consecutive training samples are not independent. In RL, this happens because states at times $t$ and $t+1$ are nearly identical, and the agent’s trajectory visits similar states in sequence.

Imagine learning to play a video game. In one episode, you spend 10 minutes walking through a corridor before finding the exit. If you train your network on each frame as you go:

  1. Frame 1: Corridor, left side
  2. Frame 2: Corridor, left side (almost identical)
  3. Frame 3: Corridor, left side
  4. … (1000 more corridor frames)
  5. Frame 1001: Exit door!

Your network will receive 1000 gradient updates about “being in a corridor.” It will become very confident about corridor Q-values, while completely forgetting what it learned about other game areas.

This is catastrophic forgetting: the network overfits to recent experience and loses previously learned knowledge.

The updates are also correlated: frame $t$ and frame $t+1$ are nearly identical, so the gradients point in almost the same direction. This violates the assumption of stochastic gradient descent that samples are i.i.d.

Mathematical Details

Standard SGD assumes samples are drawn i.i.d. from some distribution:

$$(s, a, r, s') \sim \mathcal{D}_{\text{i.i.d.}}$$

But in online RL, samples come from a trajectory:

$$(s_t, a_t, r_t, s_{t+1}), (s_{t+1}, a_{t+1}, r_{t+1}, s_{t+2}), \ldots$$

These samples are correlated in two ways:

  1. Temporal correlation: $s_{t+1}$ is very similar to $s_t$
  2. Policy correlation: The distribution of states depends on the current policy

The sample covariance is high:

$$\text{Cov}[(s_t, a_t), (s_{t+1}, a_{t+1})] \gg 0$$

This correlation causes the gradient variance to increase, potentially leading to divergence or slow convergence.
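
The effect of sample order on correlation can be checked numerically. The sketch below is a toy illustration, not an RL environment: it assumes a single scalar state feature that drifts slowly (an AR(1) process whose 0.99 coefficient is an arbitrary choice) and compares the lag-1 correlation of the samples in trajectory order with the same samples in random order. The helper names `pearson` and `lag1_correlation` are introduced here for illustration.

```python
import random


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5


def lag1_correlation(xs):
    """Correlation between each sample and the one that follows it."""
    return pearson(xs[:-1], xs[1:])


rng = random.Random(0)

# Toy stand-in for one slowly drifting state feature along a trajectory.
states = [0.0]
for _ in range(4999):
    states.append(0.99 * states[-1] + rng.gauss(0.0, 1.0))

# Online order: samples reach the learner in the order they were generated.
online_corr = lag1_correlation(states)

# Random order: the same samples, visited in a shuffled order.
shuffled = states[:]
rng.shuffle(shuffled)
replay_corr = lag1_correlation(shuffled)

print(f"lag-1 correlation, trajectory order: {online_corr:.3f}")
print(f"lag-1 correlation, random order:     {replay_corr:.3f}")
```

In trajectory order the correlation is close to 1, while shuffling the same data brings it close to 0. The marginal distribution of the samples is unchanged; only the dependence between neighbouring samples is removed.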

The Replay Buffer

📖Replay Buffer

A replay buffer (or experience replay memory) is a data structure that stores past transitions $(s, a, r, s', \text{done})$. During training, transitions are sampled randomly from the buffer rather than used immediately.

The solution is elegant: instead of learning from each transition immediately, we store it in a buffer. When it is time to train, we sample a random batch from the buffer.

Think of it like a studying strategy. Instead of cramming the same topic for hours, you:

  1. Keep flashcards with different topics
  2. Shuffle them randomly
  3. Study a random mix each session

This prevents overfitting to any single topic and keeps all knowledge fresh.

The replay buffer works the same way:

  1. Store each transition (state, action, reward, next_state, done) as an “experience”
  2. When training, sample a random batch of 32-64 experiences
  3. The batch contains diverse states from different times and situations
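
As a rough illustration of the three steps above, the sketch below stores only the time step of each transition (a stand-in for the full experience tuple) and compares how far apart in time the members of a batch are when the batch is taken from the most recent steps versus sampled at random. The buffer size of 1000 and the helper `time_span` are illustrative assumptions.

```python
import random
from collections import deque

random.seed(0)

# Store the time step of each transition as a stand-in for the full tuple.
buffer = deque(maxlen=1000)
for t in range(1000):
    buffer.append(t)

# "Online" batch: the 32 most recent transitions, all from adjacent steps.
online_batch = list(buffer)[-32:]

# Replay batch: 32 transitions drawn uniformly at random without replacement.
replay_batch = random.sample(list(buffer), 32)


def time_span(batch):
    """Distance in time between the oldest and newest transition in a batch."""
    return max(batch) - min(batch)


print(f"Online batch spans {time_span(online_batch)} steps")
print(f"Replay batch spans {time_span(replay_batch)} steps")
```

The most recent 32 transitions cover only 31 steps of the trajectory, whereas the random batch draws from across the whole buffer.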

Mathematical Details

The replay buffer $\mathcal{D}$ is typically implemented as a circular buffer with capacity $N$:

$$\mathcal{D} = \{(s_i, a_i, r_i, s'_i, d_i)\}_{i=1}^{N}$$

where $d_i$ indicates whether $s'_i$ is terminal.

During training, we sample a minibatch uniformly:

$$\{(s_j, a_j, r_j, s'_j, d_j)\}_{j=1}^{B} \sim \text{Uniform}(\mathcal{D})$$

where $B$ is the batch size (typically 32).

The expected gradient becomes:

$$\mathbb{E}_{(s,a,r,s') \sim \mathcal{D}}[\nabla_\theta L(s, a, r, s'; \theta)]$$

This expectation is over the replay buffer distribution, which approximates i.i.d. samples from the agent’s experience.
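
To make the circular-buffer idea concrete, here is a minimal sketch that uses an explicit write position instead of a `deque`. The class name `RingBuffer` and its attributes are hypothetical, chosen for illustration; integers stand in for transitions.

```python
import random


class RingBuffer:
    """Fixed-capacity buffer that overwrites its oldest entry when full."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.storage = [None] * capacity
        self.position = 0  # slot that the next push will write to
        self.size = 0      # number of slots currently filled

    def push(self, item):
        self.storage[self.position] = item
        # Wrap around to slot 0 after the last slot.
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """Draw batch_size distinct stored items uniformly at random."""
        indices = random.sample(range(self.size), batch_size)
        return [self.storage[i] for i in indices]


ring = RingBuffer(capacity=5)
for item in range(7):  # items 0..6; items 0 and 1 get overwritten
    ring.push(item)

print(ring.storage)    # [5, 6, 2, 3, 4]
print(ring.sample(3))  # three distinct items from the five stored
```

After seven pushes into five slots, the two oldest items have been replaced in place, so no memory is reallocated and lookup by index stays constant-time.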

</>Implementation
import numpy as np
from collections import deque
import random

class ReplayBuffer:
    """
    Circular buffer for storing and sampling transitions.

    Stores transitions as (state, action, reward, next_state, done).
    When full, oldest transitions are overwritten.
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: Maximum number of transitions to store
        """
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Add a transition to the buffer."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """
        Sample a random batch of transitions.

        Args:
            batch_size: Number of transitions to sample

        Returns:
            Tuple of (states, actions, rewards, next_states, dones)
            Each is a numpy array
        """
        batch = random.sample(self.buffer, batch_size)

        states, actions, rewards, next_states, dones = zip(*batch)

        return (
            np.array(states, dtype=np.float32),
            np.array(actions, dtype=np.int64),
            np.array(rewards, dtype=np.float32),
            np.array(next_states, dtype=np.float32),
            np.array(dones, dtype=np.float32)
        )

    def __len__(self):
        return len(self.buffer)

    def is_ready(self, batch_size):
        """Check if buffer has enough samples for a batch."""
        return len(self.buffer) >= batch_size


# Example usage
buffer = ReplayBuffer(capacity=10000)

# Simulate adding experiences
for i in range(100):
    state = np.random.randn(4)
    action = np.random.randint(0, 2)
    reward = np.random.randn()
    next_state = np.random.randn(4)
    done = np.random.random() < 0.1

    buffer.push(state, action, reward, next_state, done)

print(f"Buffer size: {len(buffer)}")

# Sample a batch
states, actions, rewards, next_states, dones = buffer.sample(32)
print(f"\nBatch shapes:")
print(f"  States: {states.shape}")
print(f"  Actions: {actions.shape}")
print(f"  Rewards: {rewards.shape}")
print(f"  Next states: {next_states.shape}")
print(f"  Dones: {dones.shape}")

Benefits of Experience Replay

Experience replay provides multiple benefits beyond breaking correlations:

1. Data Efficiency Each experience can be used multiple times. Instead of learning from a transition once and discarding it, we can sample it again and again. A transition might be sampled 10 times before being evicted from the buffer.

2. Diverse Batches A single batch might contain:

  • A transition from 5 minutes ago
  • A transition from the current episode
  • A transition from a completely different game situation

This diversity forces the network to learn general patterns, not just recent ones.

3. Stable Learning Random sampling creates a relatively stationary distribution. Even as the policy changes, old experiences in the buffer provide a stabilizing anchor.

4. Forgetting Prevention The buffer contains experiences from the agent’s entire history (up to capacity). Even if the current policy never visits certain states, the network still trains on those states.

Mathematical Details

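Data efficiency: Each update draws a batch of transitions, so the average number of times a transition is sampled is the total number of samples drawn divided by the number of transitions collected:

$$\text{Expected uses} = \frac{\text{Updates} \times \text{Batch size}}{\text{Transitions collected}}$$

For DQN with 10 million updates at batch size 32, assuming one update per environment step (so 10 million transitions are collected):

$$\text{Expected uses} = \frac{10^7 \times 32}{10^7} = 32 \text{ times}$$

The buffer size $N$ cancels out: a transition stays in the buffer for $N$ steps and is picked with probability $32/N$ at each update. Dividing by $N$ instead applies only when nothing is ever evicted, that is, when at most $N$ transitions are collected in total.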

Distribution stability: Let $\mu_t$ be the state distribution at time $t$. With online learning:

$$\text{Training distribution} = \mu_t$$

With experience replay:

$$\text{Training distribution} \approx \frac{1}{N}\sum_{i=t-N+1}^{t} \mu_i$$

The replay distribution changes slowly, even as $\mu_t$ changes rapidly with the policy.
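
A small numeric sketch can illustrate why the averaged distribution moves slowly. It assumes, purely for illustration, that $\mu_t$ can be summarised by a single number (for example a mean position) that jumps from 0 to 1 when the policy changes at step 500, and that the buffer holds the last 200 steps. The helpers `replay_mixture` and `largest_step` are hypothetical names.

```python
def replay_mixture(values, capacity):
    """Average of the most recent `capacity` values at every time step."""
    averaged = []
    window_sum = 0.0
    for t, value in enumerate(values):
        window_sum += value
        if t >= capacity:
            # Drop the value that has just left the window.
            window_sum -= values[t - capacity]
        averaged.append(window_sum / min(t + 1, capacity))
    return averaged


def largest_step(series):
    """Largest change between two consecutive entries."""
    return max(abs(b - a) for a, b in zip(series, series[1:]))


# Scalar summary of the online distribution: abrupt policy change at t = 500.
online = [0.0] * 500 + [1.0] * 500
replay = replay_mixture(online, capacity=200)

print(f"Largest one-step change, online: {largest_step(online):.3f}")
print(f"Largest one-step change, replay: {largest_step(replay):.3f}")
```

The online summary jumps by 1 in a single step, while the replay average moves by at most 1/200 per step and takes 200 steps to catch up. A larger buffer slows this drift further, at the cost of training on older data.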

</>Implementation
import random

import numpy as np

# Uses the ReplayBuffer class defined in the previous implementation.

def demonstrate_data_efficiency():
    """
    Compare data efficiency: online vs replay.
    """
    n_transitions = 10000
    buffer_size = 5000
    batch_size = 32
    n_updates = 20000

    # Track how many times each transition is used
    online_uses = np.zeros(n_transitions)
    replay_uses = np.zeros(n_transitions)

    # Online learning: each transition used exactly once
    for i in range(min(n_transitions, n_updates)):
        online_uses[i] = 1

    # Replay: sample from buffer
    buffer = ReplayBuffer(capacity=buffer_size)

    for i in range(n_transitions):
        # Add transition to buffer
        buffer.push(np.zeros(4), 0, 0, np.zeros(4), False)

        # Simulate training (after warmup)
        if len(buffer) >= batch_size:
            # Sample batch - track indices
            indices = random.sample(range(len(buffer)), batch_size)
            for idx in indices:
                # Map buffer index to transition index
                actual_idx = max(0, i - len(buffer) + 1) + idx
                if actual_idx < n_transitions:
                    replay_uses[actual_idx] += 1

    print("Data Efficiency Comparison")
    print("=" * 40)
    print(f"Online learning:")
    print(f"  Mean uses per transition: {online_uses.mean():.2f}")
    print(f"  Max uses: {online_uses.max():.0f}")
    print(f"Replay buffer:")
    print(f"  Mean uses per transition: {replay_uses.mean():.2f}")
    print(f"  Max uses: {replay_uses.max():.0f}")

demonstrate_data_efficiency()

Buffer Size Considerations

The buffer size is a crucial hyperparameter. Too small, and you lose the benefits of diverse sampling. Too large, and you waste memory and train on outdated experiences.

Too small (e.g., 1,000 transitions):

  • Buffer fills quickly and old experiences are evicted
  • Less diversity in sampled batches
  • Risk of catastrophic forgetting returns

Too large (e.g., 10 million transitions):

  • Requires significant memory (especially for image observations)
  • Old experiences may be from a very different policy
  • Learning from outdated transitions can slow convergence

Just right (typically 100K - 1M for Atari):

  • Good balance of diversity and relevance
  • Old enough for decorrelation
  • New enough to be useful for current policy

For Atari games with 84x84 grayscale frames, a buffer of 1 million transitions that stores just one float32 frame (4 bytes per pixel) per transition requires:

$$1\text{M} \times 84 \times 84 \times 4 \text{ bytes} \approx 28 \text{ GB}$$

This is why efficient replay buffer implementations are important.
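
The memory arithmetic can be reproduced with a few lines. This is a back-of-the-envelope sketch that counts frame storage only (actions, rewards and done flags are negligible by comparison); the function name `replay_memory_bytes` and its parameters are illustrative assumptions.

```python
def replay_memory_bytes(n_transitions, frame_shape, bytes_per_pixel,
                        frames_per_transition=1):
    """Bytes needed to store the frames of a replay buffer."""
    height, width = frame_shape
    return (n_transitions * frames_per_transition
            * height * width * bytes_per_pixel)


GB = 10**9  # decimal gigabytes

# One float32 frame per transition (4 bytes per pixel).
single_float32 = replay_memory_bytes(1_000_000, (84, 84), bytes_per_pixel=4)

# One uint8 frame per transition (1 byte per pixel).
single_uint8 = replay_memory_bytes(1_000_000, (84, 84), bytes_per_pixel=1)

# A stack of four float32 frames stored with every transition.
stacked_float32 = replay_memory_bytes(1_000_000, (84, 84), bytes_per_pixel=4,
                                      frames_per_transition=4)

print(f"float32, 1 frame:  {single_float32 / GB:.1f} GB")   # 28.2 GB
print(f"uint8, 1 frame:    {single_uint8 / GB:.1f} GB")     # 7.1 GB
print(f"float32, 4 frames: {stacked_float32 / GB:.1f} GB")  # 112.9 GB
```

Storing pixels as uint8 and keeping each frame only once, rather than once per stack it belongs to, each cut the footprint by a factor of four.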

</>Implementation
import numpy as np

class EfficientReplayBuffer:
    """
    Memory-efficient replay buffer for image observations.

    Stores images as uint8 instead of float32 to save 4x memory.
    Converts to float32 only when sampling.
    """

    def __init__(self, capacity, state_shape, n_frames=4):
        """
        Args:
            capacity: Maximum transitions
            state_shape: Shape of a single frame (H, W)
            n_frames: Number of stacked frames
        """
        self.capacity = capacity
        self.n_frames = n_frames

        # Pre-allocate arrays (uint8 frames for memory efficiency).
        # Frames and transitions share one index: slot i holds the frame
        # observed when action i was taken. Stacks are rebuilt when sampling.
        self.frames = np.zeros((capacity, *state_shape), dtype=np.uint8)
        self.actions = np.zeros(capacity, dtype=np.int32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.bool_)

        self.position = 0  # next slot to write
        self.size = 0      # number of slots currently filled
        self.rng = np.random.default_rng()

    def push(self, frame, action, reward, done):
        """
        Add a transition (storing only the newest frame).

        Args:
            frame: Single frame (H, W), uint8, observed before the action
            action: Action taken
            reward: Reward received
            done: Episode ended
        """
        self.frames[self.position] = frame
        self.actions[self.position] = action
        self.rewards[self.position] = reward
        self.dones[self.position] = done

        # Advance the write head, wrapping around when the buffer is full
        self.position = (self.position + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def _get_state(self, idx):
        """Reconstruct the stack of n_frames frames ending at slot idx."""
        slots = np.arange(idx - self.n_frames + 1, idx + 1) % self.capacity
        return self.frames[slots]

    def sample(self, batch_size):
        """Sample batch and convert to float32."""
        # Work in logical order (0 = oldest stored slot). A slot is valid if
        # n_frames - 1 older frames and one newer frame (for next_state) exist.
        # Note: stacks that span an episode boundary are not masked here.
        low, high = self.n_frames - 1, self.size - 1
        if high - low < batch_size:
            raise ValueError("Not enough stored transitions to sample a batch")
        logical = self.rng.choice(high - low, size=batch_size, replace=False) + low
        oldest = (self.position - self.size) % self.capacity
        indices = (oldest + logical) % self.capacity

        # Reconstruct states
        states = np.array([self._get_state(i) for i in indices], dtype=np.float32) / 255.0
        next_states = np.array([self._get_state(i + 1) for i in indices], dtype=np.float32) / 255.0

        return (
            states,
            self.actions[indices],
            self.rewards[indices],
            next_states,
            self.dones[indices]
        )

    def __len__(self):
        return self.size

    @property
    def memory_usage_mb(self):
        """Estimate memory usage in MB."""
        return (self.frames.nbytes + self.actions.nbytes +
                self.rewards.nbytes + self.dones.nbytes) / (1024**2)


# Compare memory usage
naive_buffer_mb = 1_000_000 * 4 * 84 * 84 * 4 / (1024**2)  # float32
efficient_buffer = EfficientReplayBuffer(1_000_000, (84, 84), n_frames=4)
efficient_mb = efficient_buffer.memory_usage_mb

print(f"Memory comparison for 1M transitions:")
print(f"  Naive (float32 stacked frames): {naive_buffer_mb:.0f} MB")
print(f"  Efficient (uint8 individual frames): {efficient_mb:.0f} MB")
print(f"  Savings: {naive_buffer_mb / efficient_mb:.1f}x")

Integration with Training

In DQN, experience replay integrates into the training loop as follows:

  1. Collect experience: Take action in environment, store transition
  2. Check buffer: Only train when buffer has enough samples
  3. Sample batch: Draw random batch from buffer
  4. Compute loss: Use sampled transitions for TD error
  5. Update network: Gradient descent on the loss

The key insight is that data collection and training are decoupled. We can take many environment steps between training updates, or many training updates between environment steps.

Many implementations do one training step per environment step once the buffer is ready; the original Atari DQN ran one update every four agent steps.
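
The decoupling of data collection and training can be expressed as a small schedule. The sketch below is one possible way to write it, under the assumption that updates happen in bursts at fixed intervals after a warmup period; the function and parameter names (`updates_for_step`, `train_every`, `updates_per_train`) are hypothetical.

```python
def updates_for_step(step, warmup_steps, train_every, updates_per_train):
    """Number of gradient updates to run after environment step `step` (0-based)."""
    if step < warmup_steps:
        return 0  # still filling the buffer
    if (step + 1) % train_every != 0:
        return 0  # not a training step
    return updates_per_train


def total_updates(n_steps, **schedule):
    """Total gradient updates over the first n_steps environment steps."""
    return sum(updates_for_step(step, **schedule) for step in range(n_steps))


# One update per environment step after a 1000-step warmup.
one_to_one = total_updates(10_000, warmup_steps=1000,
                           train_every=1, updates_per_train=1)

# Many environment steps between updates: one update every 4 steps.
sparse = total_updates(10_000, warmup_steps=1000,
                       train_every=4, updates_per_train=1)

# Many updates between environment steps: 4 updates after every step.
dense = total_updates(10_000, warmup_steps=1000,
                      train_every=1, updates_per_train=4)

print(one_to_one, sparse, dense)  # 9000 2250 36000
```

The ratio of updates to environment steps sets how often each stored transition is reused: more updates per step extract more from each transition but train on data that lags further behind the current policy.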

</>Implementation
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Uses the ReplayBuffer class defined earlier in this section.

class DQNWithReplay:
    """
    DQN agent with experience replay.

    This demonstrates how replay integrates into training.
    Target network is omitted for clarity (see next section).
    """

    def __init__(self, state_dim, n_actions, buffer_size=100000,
                 batch_size=32, gamma=0.99, lr=1e-4):
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = gamma

        # Q-network
        self.q_network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

        # Replay buffer
        self.buffer = ReplayBuffer(capacity=buffer_size)

        # Metrics
        self.losses = []

    def select_action(self, state, epsilon):
        """Epsilon-greedy action selection."""
        if np.random.random() < epsilon:
            return np.random.randint(self.n_actions)

        with torch.no_grad():
            state_t = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_network(state_t)
            return q_values.argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        """Store transition in replay buffer."""
        self.buffer.push(state, action, reward, next_state, done)

    def train_step(self):
        """
        Perform one training step using experience replay.

        Returns:
            loss value, or None if buffer not ready
        """
        # Check if buffer has enough samples
        if not self.buffer.is_ready(self.batch_size):
            return None

        # Sample random batch
        states, actions, rewards, next_states, dones = self.buffer.sample(self.batch_size)

        # Convert to tensors
        states_t = torch.FloatTensor(states)
        actions_t = torch.LongTensor(actions)
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(next_states)
        dones_t = torch.FloatTensor(dones)

        # Compute current Q-values
        current_q = self.q_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Compute target Q-values
        with torch.no_grad():
            next_q = self.q_network(next_states_t).max(dim=1)[0]
            target_q = rewards_t + self.gamma * next_q * (1 - dones_t)

        # Compute loss
        loss = nn.MSELoss()(current_q, target_q)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.losses.append(loss.item())
        return loss.item()


# Training loop example
def train_dqn(env, agent, n_episodes=100, warmup_steps=1000):
    """
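    Train DQN agent with experience replay.

    Assumes the classic Gym API: env.reset() returns an observation and
    env.step() returns (next_state, reward, done, info). Gymnasium's reset()
    and step() return extra values and would need unpacking accordingly.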
    """
    epsilon = 1.0
    epsilon_decay = 0.995
    epsilon_min = 0.01

    episode_rewards = []
    steps = 0

    for episode in range(n_episodes):
        state = env.reset()
        episode_reward = 0
        done = False

        while not done:
            # Select action
            action = agent.select_action(state, epsilon)

            # Take action in environment
            next_state, reward, done, _ = env.step(action)

            # Store in replay buffer
            agent.store_transition(state, action, reward, next_state, done)

            # Train (after warmup)
            if steps >= warmup_steps:
                loss = agent.train_step()

            state = next_state
            episode_reward += reward
            steps += 1

        # Decay epsilon
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        episode_rewards.append(episode_reward)

        if (episode + 1) % 10 == 0:
            avg_reward = np.mean(episode_rewards[-10:])
            print(f"Episode {episode+1}, Avg Reward: {avg_reward:.2f}, "
                  f"Epsilon: {epsilon:.3f}, Buffer: {len(agent.buffer)}")

    return episode_rewards

Summary

Experience replay is fundamental to making DQN work:

  1. Breaking correlations: Random sampling provides approximately i.i.d. data for training
  2. Data efficiency: Each transition can be used multiple times
  3. Stability: The replay distribution changes slowly, providing a stable learning signal
  4. Preventing forgetting: Old experiences keep the network’s knowledge fresh

Key design choices:

  • Buffer size: 100K - 1M for Atari (balance diversity vs relevance)
  • Batch size: 32 is typical
  • Warmup: Fill buffer with random actions before training starts

Experience replay addresses one half of the instability problem in neural network Q-learning. The next section covers the other half: target networks that prevent the “chasing a moving target” problem.

💡Tip

Experience replay can be combined with other sampling strategies. Prioritized Experience Replay samples important transitions more often, significantly improving learning efficiency. We will cover this in the DQN Improvements chapter.