Deep Reinforcement Learning • Part 2 of 4

Prioritized Experience Replay

Learning more from important transitions


In standard experience replay, all transitions are sampled with equal probability. But some transitions are more informative than others. A surprising outcome, a rare event, or a transition where our prediction was very wrong all contain more learning signal than routine experiences.

Prioritized Experience Replay (PER) samples transitions based on how much we can learn from them, measured by the TD error. This focuses learning on the most important experiences.

Not All Transitions Are Equal

📖TD Error as Surprise

The TD error $\delta = r + \gamma \max_{a'} Q(s', a') - Q(s, a)$ measures how much our prediction differed from the observed outcome. Large TD errors indicate surprising or poorly-learned transitions.

Imagine playing a video game. Most of your experience consists of:

  • Walking down corridors (boring, predictable)
  • Collecting common items (routine)
  • Taking familiar actions (well-understood)

But occasionally:

  • You discover a hidden treasure (surprising reward)
  • You fall into an unexpected trap (prediction error)
  • You find a shortcut (changes value estimates)

With uniform sampling, you will train on corridor-walking 100 times for every hidden treasure. But the treasure discovery is far more informative. You already know what happens in corridors.

PER fixes this imbalance by sampling surprising transitions more often.

Mathematical Details

The TD error for a transition $(s, a, r, s')$ is:

$$\delta = r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta)$$

A large $|\delta|$ means:

  • Our prediction was wrong (need to update)
  • The transition contains information we have not fully learned
  • Sampling this transition provides more learning signal

With uniform sampling, each transition has probability:

$$P_{\text{uniform}}(i) = \frac{1}{N}$$

With prioritized sampling:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$

where $p_i = |\delta_i| + \epsilon$ is the priority (TD error magnitude plus a small constant).

</>Implementation
import numpy as np

def compare_sampling_strategies():
    """
    Demonstrate why prioritized sampling helps.
    """
    np.random.seed(42)

    # Simulate a buffer with mostly boring transitions and a few important ones
    n_transitions = 10000
    td_errors = np.abs(np.random.randn(n_transitions) * 0.1)  # Mostly small errors

    # Add some high-error transitions (rare but important)
    important_indices = np.random.choice(n_transitions, size=100, replace=False)
    td_errors[important_indices] = np.abs(np.random.randn(100) * 2.0)

    # Uniform sampling
    batch_size = 32
    n_batches = 1000

    uniform_important = 0
    for _ in range(n_batches):
        indices = np.random.choice(n_transitions, batch_size)
        uniform_important += np.sum(np.isin(indices, important_indices))

    # Prioritized sampling
    priorities = td_errors + 1e-6
    probs = priorities / priorities.sum()

    prioritized_important = 0
    for _ in range(n_batches):
        indices = np.random.choice(n_transitions, batch_size, p=probs)
        prioritized_important += np.sum(np.isin(indices, important_indices))

    print("Sampling Strategy Comparison")
    print("=" * 50)
    print(f"Important transitions: {len(important_indices)} / {n_transitions} = {100*len(important_indices)/n_transitions:.1f}%")
    print(f"\nUniform sampling:")
    print(f"  Important samples: {uniform_important} / {n_batches * batch_size}")
    print(f"  Rate: {100*uniform_important/(n_batches*batch_size):.2f}%")
    print(f"\nPrioritized sampling:")
    print(f"  Important samples: {prioritized_important} / {n_batches * batch_size}")
    print(f"  Rate: {100*prioritized_important/(n_batches*batch_size):.2f}%")
    print(f"  Improvement: {prioritized_important/max(1,uniform_important):.1f}x more important samples")

compare_sampling_strategies()

Priority Assignment

There are two common ways to assign priorities:

1. Proportional prioritization: Priority equals TD error magnitude plus a small constant: $p_i = |\delta_i| + \epsilon$

The sampling probability is proportional to this priority raised to a power: $P(i) \propto p_i^\alpha$

2. Rank-based prioritization: Priority equals the inverse rank in the sorted TD errors: $p_i = \frac{1}{\text{rank}(i)}$

Rank-based prioritization is more robust to outliers: a transition with a TD error of 1000 cannot dominate the entire distribution.
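A quick sketch of this difference, using made-up TD errors with one extreme outlier and a typical $\alpha = 0.6$:

```python
import numpy as np

# Hypothetical TD errors, one extreme outlier
td_errors = np.array([0.1, 0.5, 100.0, 0.2, 0.3])
alpha = 0.6

# Proportional: priority is |TD error| + epsilon
prop = (np.abs(td_errors) + 1e-6) ** alpha
prop_probs = prop / prop.sum()

# Rank-based: rank 1 is the largest |TD error|, priority is 1/rank
ranks = np.empty(len(td_errors), dtype=int)
ranks[np.argsort(-np.abs(td_errors))] = np.arange(1, len(td_errors) + 1)
rank = (1.0 / ranks) ** alpha
rank_probs = rank / rank.sum()

print("proportional:", np.round(prop_probs, 3))
print("rank-based:  ", np.round(rank_probs, 3))
```

With these numbers the outlier soaks up roughly 90% of the probability mass under proportional prioritization, but only about a third under rank-based.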

The parameter $\alpha$ controls how much prioritization matters:

  • $\alpha = 0$: Uniform sampling (no prioritization)
  • $\alpha = 1$: Full prioritization (sample proportionally to TD error)
  • $\alpha = 0.6$ (typical): Moderate prioritization
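These regimes are easy to check directly; a tiny sketch with three hypothetical TD errors:

```python
import numpy as np

td_errors = np.array([0.1, 0.4, 2.0])

# Sweep the prioritization exponent alpha
for alpha in (0.0, 0.6, 1.0):
    p = (td_errors + 1e-6) ** alpha
    probs = p / p.sum()
    print(f"alpha={alpha}: {np.round(probs, 3)}")
```

$\alpha = 0$ gives the uniform 1/3 each, $\alpha = 1$ gives probabilities proportional to the TD errors (0.04, 0.16, 0.80), and $\alpha = 0.6$ lands in between.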

Mathematical Details

Proportional prioritization:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \quad p_i = |\delta_i| + \epsilon$$

Rank-based prioritization:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}, \quad p_i = \frac{1}{\text{rank}(i)}$$

where $\text{rank}(i)$ is the position of transition $i$ when the transitions are sorted by $|\delta|$ in decreasing order.

For rank-based prioritization, the normalizing sum is

$$\sum_{k=1}^{N} \frac{1}{k^\alpha}$$

which converges to the Riemann zeta function $\zeta(\alpha)$ as $N \to \infty$ only when $\alpha > 1$. For the typical $\alpha \le 1$ the sum grows with $N$, so in practice it is computed (or approximated) for the actual buffer size.
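A quick numerical check of this normalizing sum (partial sums computed directly):

```python
# Partial sums of 1/k^alpha for growing N
for alpha in (0.6, 1.5):
    for n in (10_000, 1_000_000):
        partial_sum = sum(1.0 / k ** alpha for k in range(1, n + 1))
        print(f"alpha={alpha}, N={n:>9,}: sum = {partial_sum:.3f}")
```

For $\alpha = 1.5$ the partial sums settle near $\zeta(1.5) \approx 2.612$; for the typical $\alpha = 0.6$ the sum keeps growing with $N$, so it must be tracked for the actual buffer size rather than replaced by $\zeta(\alpha)$.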

</>Implementation
import numpy as np

class PrioritizedReplayBuffer:
    """
    Prioritized Experience Replay buffer.

    Uses proportional prioritization based on TD error.
    """

    def __init__(self, capacity, alpha=0.6, epsilon=1e-6):
        """
        Args:
            capacity: Maximum buffer size
            alpha: Prioritization exponent (0=uniform, 1=full prioritization)
            epsilon: Small constant to ensure non-zero priority
        """
        self.capacity = capacity
        self.alpha = alpha
        self.epsilon = epsilon

        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def push(self, state, action, reward, next_state, done, td_error=None):
        """
        Add transition with priority.

        If td_error is None, use max priority (for new transitions).
        """
        if td_error is None:
            # New transitions get max priority to ensure they're sampled
            priority = self.priorities.max() if self.buffer else 1.0
        else:
            priority = (abs(td_error) + self.epsilon) ** self.alpha

        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)

        self.priorities[self.position] = priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """
        Sample batch with prioritized probabilities.

        Args:
            batch_size: Number of transitions to sample
            beta: Importance sampling exponent

        Returns:
            batch: List of transitions
            indices: Buffer indices (for priority updates)
            weights: Importance sampling weights
        """
        n = len(self.buffer)
        priorities = self.priorities[:n]

        # Compute sampling probabilities
        probs = priorities / priorities.sum()

        # Sample indices
        indices = np.random.choice(n, batch_size, p=probs, replace=False)

        # Compute importance sampling weights
        weights = (n * probs[indices]) ** (-beta)
        weights = weights / weights.max()  # Normalize

        batch = [self.buffer[i] for i in indices]

        return batch, indices, weights

    def update_priorities(self, indices, td_errors):
        """Update priorities after learning step."""
        for idx, td_error in zip(indices, td_errors):
            self.priorities[idx] = (abs(td_error) + self.epsilon) ** self.alpha

    def __len__(self):
        return len(self.buffer)


# Example usage
buffer = PrioritizedReplayBuffer(capacity=10000, alpha=0.6)

# Add some transitions with varying TD errors
for i in range(1000):
    td_error = np.random.exponential(0.5)  # Some high, some low
    buffer.push(
        state=np.random.randn(4),
        action=np.random.randint(0, 4),
        reward=np.random.randn(),
        next_state=np.random.randn(4),
        done=False,
        td_error=td_error
    )

# Sample a batch
batch, indices, weights = buffer.sample(batch_size=32, beta=0.4)

print("Prioritized Replay Buffer")
print("=" * 40)
print(f"Buffer size: {len(buffer)}")
print(f"Sampled indices: {indices[:5]}...")
print(f"Importance weights: {weights[:5]}...")
print(f"Weight range: [{weights.min():.3f}, {weights.max():.3f}]")

Importance Sampling Correction

📖Importance Sampling

Importance sampling corrects for the bias introduced by non-uniform sampling. When we sample important transitions more often, we must down-weight their gradient contribution to maintain unbiased updates.

Prioritized sampling changes the distribution we learn from. Instead of learning from the true experience distribution, we learn from a biased distribution that over-represents surprising transitions.

This can cause problems. If we always sample the same surprising transitions, we might overfit to them. The solution is importance sampling: multiply each gradient by a weight that corrects for the sampling bias.

Transitions sampled with high probability get lower weights. Transitions sampled with low probability get higher weights. This rebalances the gradients toward what we would get with uniform sampling while still benefiting from focused learning.

Mathematical Details

With prioritized sampling, the expected gradient is:

$$\mathbb{E}_{i \sim P}[\nabla L_i] = \sum_i P(i) \, \nabla L_i$$

This is biased because $P(i) \neq \frac{1}{N}$.

To correct, we use importance sampling weights:

$$w_i = \left( \frac{1}{N \cdot P(i)} \right)^\beta$$

The corrected gradient:

$$\mathbb{E}_{i \sim P}[w_i \nabla L_i] = \sum_i P(i) \cdot \frac{1}{(N \cdot P(i))^\beta} \cdot \nabla L_i$$

For $\beta = 1$, this exactly recovers the uniform-distribution gradient. For $\beta < 1$, the correction is partial, trading bias for variance reduction.

Annealing $\beta$: we start with $\beta = 0.4$ (more bias, less variance) and anneal toward $\beta = 1$ (no bias) during training. This allows aggressive prioritization early on while converging to correct updates.
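The $\beta = 1$ case can be verified with a quick Monte Carlo sketch; the "gradients" and priorities below are made up purely for the check:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000

# Hypothetical per-transition gradients and priorities
grads = rng.normal(size=n)
priorities = rng.exponential(1.0, size=n) + 0.1
probs = priorities / priorities.sum()

uniform_mean = grads.mean()  # what uniform replay would estimate

# Prioritized sampling with full importance-sampling correction (beta = 1)
idx = rng.choice(n, size=200_000, p=probs)
weights = 1.0 / (n * probs[idx])
corrected_mean = (weights * grads[idx]).mean()

print(f"uniform mean:      {uniform_mean:+.4f}")
print(f"IS-corrected mean: {corrected_mean:+.4f}")
```

The two means agree up to Monte Carlo noise; with $\beta < 1$ the weights would only partially undo the sampling bias.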

</>Implementation
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import copy

class DQNWithPER:
    """
    DQN agent with Prioritized Experience Replay.
    """

    def __init__(
        self,
        state_dim,
        n_actions,
        buffer_size=100000,
        batch_size=32,
        gamma=0.99,
        lr=1e-4,
        alpha=0.6,
        beta_start=0.4,
        beta_frames=100000,
    ):
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = gamma
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames

        # Networks
        self.online_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )
        self.target_net = copy.deepcopy(self.online_net)

        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)

        # Prioritized replay buffer
        self.buffer = PrioritizedReplayBuffer(capacity=buffer_size, alpha=alpha)

        self.step_count = 0

    def get_beta(self):
        """Anneal beta from beta_start to 1.0."""
        progress = min(1.0, self.step_count / self.beta_frames)
        return self.beta_start + progress * (1.0 - self.beta_start)

    def store(self, state, action, reward, next_state, done):
        """Store transition with max priority (will be updated after first use)."""
        self.buffer.push(state, action, reward, next_state, done)

    def train_step(self):
        if len(self.buffer) < self.batch_size:
            return None

        # Sample with priorities
        batch, indices, weights = self.buffer.sample(self.batch_size, beta=self.get_beta())

        # Unpack batch
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.FloatTensor(np.array(states))
        actions_t = torch.LongTensor(actions)
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(np.array(next_states))
        dones_t = torch.FloatTensor(dones)
        weights_t = torch.FloatTensor(weights)

        # Compute current Q-values
        current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Compute target Q-values (Double DQN style)
        with torch.no_grad():
            best_actions = self.online_net(next_states_t).argmax(dim=1)
            next_q = self.target_net(next_states_t).gather(1, best_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards_t + self.gamma * next_q * (1 - dones_t)

        # Compute TD errors for priority update
        td_errors = (target_q - current_q).detach().numpy()

        # Weighted loss
        losses = (current_q - target_q) ** 2
        weighted_loss = (weights_t * losses).mean()

        # Optimize
        self.optimizer.zero_grad()
        weighted_loss.backward()
        self.optimizer.step()

        # Update priorities
        self.buffer.update_priorities(indices, td_errors)

        self.step_count += 1

        # Periodically sync the target network with the online network
        if self.step_count % 1000 == 0:
            self.target_net.load_state_dict(self.online_net.state_dict())

        return weighted_loss.item()


# Example training loop
agent = DQNWithPER(state_dim=4, n_actions=2)

print("DQN with Prioritized Experience Replay")
print("=" * 50)

for i in range(1000):
    state = np.random.randn(4)
    action = np.random.randint(0, 2)
    reward = np.random.randn()
    next_state = np.random.randn(4)
    done = np.random.random() < 0.05

    agent.store(state, action, reward, next_state, done)
    loss = agent.train_step()

    if i % 200 == 0 and loss is not None:
        print(f"Step {i}: Loss = {loss:.4f}, Beta = {agent.get_beta():.3f}")

Sum Tree for Efficient Sampling

Naive prioritized sampling requires $O(N)$ time to compute probabilities. With millions of transitions, this is too slow.

A Sum Tree is a binary tree data structure that enables $O(\log N)$ sampling:

  • Leaf nodes store priorities
  • Internal nodes store the sum of their children
  • The root stores the total priority

To sample, we:

  1. Generate a random number in $[0, \text{total priority}]$
  2. Traverse the tree, going left or right based on cumulative sums
  3. Reach a leaf in $O(\log N)$ time

Updates are also $O(\log N)$: change a leaf and update its ancestors.

</>Implementation
import numpy as np

class SumTree:
    """
    Sum Tree data structure for O(log n) prioritized sampling.

    Leaf nodes store priorities. Internal nodes store sums of children.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # Binary tree
        self.data = np.zeros(capacity, dtype=object)  # Leaf data
        self.write_position = 0
        self.n_entries = 0

    def _propagate(self, idx, change):
        """Propagate priority change up the tree."""
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        """Find leaf index for cumulative sum s."""
        left = 2 * idx + 1
        right = left + 1

        if left >= len(self.tree):
            return idx

        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        """Return total priority."""
        return self.tree[0]

    def add(self, priority, data):
        """Add data with priority."""
        idx = self.write_position + self.capacity - 1

        self.data[self.write_position] = data
        self.update(idx, priority)

        self.write_position = (self.write_position + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def update(self, idx, priority):
        """Update priority at tree index."""
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, change)

    def get(self, s):
        """
        Get leaf index and data for cumulative sum s.

        Returns:
            idx: Tree index
            priority: Priority value
            data: Stored data
        """
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return idx, self.tree[idx], self.data[data_idx]


class EfficientPrioritizedBuffer:
    """
    Prioritized replay buffer with O(log n) operations using Sum Tree.
    """

    def __init__(self, capacity, alpha=0.6, epsilon=1e-6):
        self.capacity = capacity
        self.alpha = alpha
        self.epsilon = epsilon

        self.tree = SumTree(capacity)
        self.max_priority = 1.0

    def push(self, transition, td_error=None):
        """Add transition."""
        if td_error is None:
            priority = self.max_priority
        else:
            priority = (abs(td_error) + self.epsilon) ** self.alpha
            self.max_priority = max(self.max_priority, priority)

        self.tree.add(priority, transition)

    def sample(self, batch_size, beta=0.4):
        """Sample batch in O(batch_size * log n) time."""
        batch = []
        indices = []
        priorities = []

        segment = self.tree.total() / batch_size

        for i in range(batch_size):
            # Sample from each segment for better coverage
            a = segment * i
            b = segment * (i + 1)
            s = np.random.uniform(a, b)

            idx, priority, data = self.tree.get(s)
            batch.append(data)
            indices.append(idx)
            priorities.append(priority)

        # Compute importance sampling weights
        priorities = np.array(priorities)
        probs = priorities / self.tree.total()
        weights = (self.tree.n_entries * probs) ** (-beta)
        weights = weights / weights.max()

        return batch, indices, weights

    def update_priorities(self, indices, td_errors):
        """Update priorities after learning."""
        for idx, td_error in zip(indices, td_errors):
            priority = (abs(td_error) + self.epsilon) ** self.alpha
            self.max_priority = max(self.max_priority, priority)
            self.tree.update(idx, priority)

    def __len__(self):
        return self.tree.n_entries


# Compare performance
import time

def benchmark_buffers():
    """Compare naive vs efficient prioritized replay."""
    n_transitions = 100000
    batch_size = 32

    # Naive buffer
    naive = PrioritizedReplayBuffer(capacity=n_transitions)
    for i in range(n_transitions):
        naive.push(np.zeros(4), 0, 0, np.zeros(4), False, np.random.rand())

    start = time.time()
    for _ in range(100):
        naive.sample(batch_size)
    naive_time = time.time() - start

    # Efficient buffer
    efficient = EfficientPrioritizedBuffer(capacity=n_transitions)
    for i in range(n_transitions):
        efficient.push((np.zeros(4), 0, 0, np.zeros(4), False), np.random.rand())

    start = time.time()
    for _ in range(100):
        efficient.sample(batch_size)
    efficient_time = time.time() - start

    print("Buffer Performance Comparison")
    print("=" * 40)
    print(f"Buffer size: {n_transitions:,}")
    print(f"Naive (O(n)): {naive_time:.3f}s for 100 samples")
    print(f"Sum Tree (O(log n)): {efficient_time:.3f}s for 100 samples")
    print(f"Speedup: {naive_time/efficient_time:.1f}x")

benchmark_buffers()

Summary

Prioritized Experience Replay improves learning efficiency by sampling important transitions more often:

  1. TD error as priority: Transitions with large prediction errors are more informative
  2. Proportional or rank-based: Two ways to assign priorities
  3. Importance sampling: Corrects the bias from non-uniform sampling
  4. Beta annealing: Start biased, end unbiased
  5. Sum Tree: Enables efficient $O(\log N)$ sampling

Key hyperparameters:

  • $\alpha$: Prioritization exponent (0.6 typical)
  • $\beta$: Importance sampling exponent (annealed from 0.4 to 1.0)
  • $\epsilon$: Small constant that keeps zero-error transitions sampleable

ℹ️Note

PER adds complexity and hyperparameters. For simple problems, uniform replay may be sufficient. PER shines when experience importance varies greatly, such as in sparse reward environments or when some state-action pairs are rarely visited.