Prioritized Experience Replay
In standard experience replay, all transitions are sampled with equal probability. But some transitions are more informative than others. A surprising outcome, a rare event, or a transition where our prediction was very wrong all contain more learning signal than routine experiences.
Prioritized Experience Replay (PER) samples transitions based on how much we can learn from them, measured by the TD error. This focuses learning on the most important experiences.
Not All Transitions Are Equal
The TD error measures how much our prediction differed from the observed outcome. Large TD errors indicate surprising or poorly-learned transitions.
Imagine playing a video game. Most of your experience consists of:
- Walking down corridors (boring, predictable)
- Collecting common items (routine)
- Taking familiar actions (well-understood)
But occasionally:
- You discover a hidden treasure (surprising reward)
- You fall into an unexpected trap (prediction error)
- You find a shortcut (changes value estimates)
With uniform sampling, you will train on corridor-walking 100 times for every hidden treasure. But the treasure discovery is far more informative. You already know what happens in corridors.
PER fixes this imbalance by sampling surprising transitions more often.
The TD error for a transition is:

$$\delta_i = r_i + \gamma \max_{a'} Q(s'_i, a') - Q(s_i, a_i)$$

A large $|\delta_i|$ means:
- Our prediction was wrong (need to update)
- The transition contains information we have not fully learned
- Sampling this transition provides more learning signal

With uniform sampling, each transition has probability:

$$P(i) = \frac{1}{N}$$

With prioritized sampling:

$$P(i) = \frac{p_i}{\sum_k p_k}$$

where $p_i$ is the priority (TD error magnitude plus a small constant).
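As a quick numeric check (the TD-error values here are made up for illustration), a single large-error transition dominates the sampling distribution:

```python
import numpy as np

# Toy buffer: three routine transitions and one surprising one (illustrative values)
td_errors = np.array([0.1, 0.1, 0.1, 2.0])
epsilon = 1e-6

priorities = np.abs(td_errors) + epsilon   # priority = |TD error| + small constant
probs = priorities / priorities.sum()      # normalize into sampling probabilities

print(probs)  # the surprising transition gets roughly 87% of the probability mass
```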
import numpy as np

def compare_sampling_strategies():
    """
    Demonstrate why prioritized sampling helps.
    """
    np.random.seed(42)

    # Simulate a buffer with mostly boring transitions and a few important ones
    n_transitions = 10000
    td_errors = np.abs(np.random.randn(n_transitions) * 0.1)  # Mostly small errors

    # Add some high-error transitions (rare but important)
    important_indices = np.random.choice(n_transitions, size=100, replace=False)
    td_errors[important_indices] = np.abs(np.random.randn(100) * 2.0)

    # Uniform sampling
    batch_size = 32
    n_batches = 1000
    uniform_important = 0
    for _ in range(n_batches):
        indices = np.random.choice(n_transitions, batch_size)
        uniform_important += np.sum(np.isin(indices, important_indices))

    # Prioritized sampling
    priorities = td_errors + 1e-6
    probs = priorities / priorities.sum()
    prioritized_important = 0
    for _ in range(n_batches):
        indices = np.random.choice(n_transitions, batch_size, p=probs)
        prioritized_important += np.sum(np.isin(indices, important_indices))

    print("Sampling Strategy Comparison")
    print("=" * 50)
    print(f"Important transitions: {len(important_indices)} / {n_transitions} = {100*len(important_indices)/n_transitions:.1f}%")
    print("\nUniform sampling:")
    print(f"  Important samples: {uniform_important} / {n_batches * batch_size}")
    print(f"  Rate: {100*uniform_important/(n_batches*batch_size):.2f}%")
    print("\nPrioritized sampling:")
    print(f"  Important samples: {prioritized_important} / {n_batches * batch_size}")
    print(f"  Rate: {100*prioritized_important/(n_batches*batch_size):.2f}%")
    print(f"  Improvement: {prioritized_important/max(1,uniform_important):.1f}x more important samples")

compare_sampling_strategies()

Priority Assignment
There are two common ways to assign priorities:
1. Proportional prioritization: Priority equals TD error magnitude plus a small constant:

$$p_i = |\delta_i| + \epsilon$$

The sampling probability is proportional to this priority raised to a power:

$$P(i) = \frac{p_i^\alpha}{\sum_k p_k^\alpha}$$

2. Rank-based prioritization: Priority equals the reciprocal of the transition's rank in the sorted TD errors:

$$p_i = \frac{1}{\text{rank}(i)}$$

Rank-based is more robust to outliers. A transition with TD error 1000 does not dominate everything.

The parameter $\alpha$ controls how much prioritization matters:
- $\alpha = 0$: Uniform sampling (no prioritization)
- $\alpha = 1$: Full prioritization (sample proportionally to TD error)
- $\alpha \approx 0.6$ (typical): Moderate prioritization

Proportional prioritization:

$$P(i) = \frac{(|\delta_i| + \epsilon)^\alpha}{\sum_k (|\delta_k| + \epsilon)^\alpha}$$

Rank-based prioritization:

$$P(i) = \frac{(1/\text{rank}(i))^\alpha}{\sum_k (1/\text{rank}(k))^\alpha}$$

where $\text{rank}(i)$ is the position of transition $i$ when sorted by $|\delta_i|$ in decreasing order.

For rank-based, the normalizing sum is approximately:

$$\sum_{k=1}^{N} k^{-\alpha} \approx \zeta(\alpha)$$

where $\zeta$ is the Riemann zeta function (the approximation holds for large $N$ when $\alpha > 1$).
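To see the robustness difference concretely, here is a small sketch (the TD-error values are made up) comparing the two schemes when one transition has an extreme error:

```python
import numpy as np

alpha = 0.6
td_errors = np.array([0.1, 0.2, 0.3, 0.5, 1000.0])  # one extreme outlier

# Proportional: priority is the error magnitude (plus epsilon), raised to alpha
prop = (np.abs(td_errors) + 1e-6) ** alpha
prop_probs = prop / prop.sum()

# Rank-based: priority is 1/rank when sorted by |TD error| (rank 1 = largest error)
ranks = np.empty_like(td_errors)
order = np.argsort(-np.abs(td_errors))
ranks[order] = np.arange(1, len(td_errors) + 1)
rank_pri = (1.0 / ranks) ** alpha
rank_probs = rank_pri / rank_pri.sum()

print(f"Outlier probability, proportional: {prop_probs[-1]:.3f}")
print(f"Outlier probability, rank-based:   {rank_probs[-1]:.3f}")
```

The outlier absorbs almost all of the proportional probability mass, while under rank-based prioritization it is merely the most likely of five transitions.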
import numpy as np

class PrioritizedReplayBuffer:
    """
    Prioritized Experience Replay buffer.
    Uses proportional prioritization based on TD error.
    """
    def __init__(self, capacity, alpha=0.6, epsilon=1e-6):
        """
        Args:
            capacity: Maximum buffer size
            alpha: Prioritization exponent (0=uniform, 1=full prioritization)
            epsilon: Small constant to ensure non-zero priority
        """
        self.capacity = capacity
        self.alpha = alpha
        self.epsilon = epsilon
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0

    def push(self, state, action, reward, next_state, done, td_error=None):
        """
        Add transition with priority.
        If td_error is None, use max priority (for new transitions).
        """
        if td_error is None:
            # New transitions get max priority to ensure they're sampled
            priority = self.priorities.max() if self.buffer else 1.0
        else:
            priority = (abs(td_error) + self.epsilon) ** self.alpha

        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.position] = (state, action, reward, next_state, done)

        self.priorities[self.position] = priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """
        Sample batch with prioritized probabilities.

        Args:
            batch_size: Number of transitions to sample
            beta: Importance sampling exponent

        Returns:
            batch: List of transitions
            indices: Buffer indices (for priority updates)
            weights: Importance sampling weights
        """
        n = len(self.buffer)
        priorities = self.priorities[:n]

        # Compute sampling probabilities
        probs = priorities / priorities.sum()

        # Sample indices
        indices = np.random.choice(n, batch_size, p=probs, replace=False)

        # Compute importance sampling weights
        weights = (n * probs[indices]) ** (-beta)
        weights = weights / weights.max()  # Normalize

        batch = [self.buffer[i] for i in indices]
        return batch, indices, weights

    def update_priorities(self, indices, td_errors):
        """Update priorities after learning step."""
        for idx, td_error in zip(indices, td_errors):
            self.priorities[idx] = (abs(td_error) + self.epsilon) ** self.alpha

    def __len__(self):
        return len(self.buffer)

# Example usage
buffer = PrioritizedReplayBuffer(capacity=10000, alpha=0.6)

# Add some transitions with varying TD errors
for i in range(1000):
    td_error = np.random.exponential(0.5)  # Some high, some low
    buffer.push(
        state=np.random.randn(4),
        action=np.random.randint(0, 4),
        reward=np.random.randn(),
        next_state=np.random.randn(4),
        done=False,
        td_error=td_error
    )

# Sample a batch
batch, indices, weights = buffer.sample(batch_size=32, beta=0.4)

print("Prioritized Replay Buffer")
print("=" * 40)
print(f"Buffer size: {len(buffer)}")
print(f"Sampled indices: {indices[:5]}...")
print(f"Importance weights: {weights[:5]}...")
print(f"Weight range: [{weights.min():.3f}, {weights.max():.3f}]")

Importance Sampling Correction
Importance sampling corrects for the bias introduced by non-uniform sampling. When we sample important transitions more often, we must down-weight their gradient contribution to maintain unbiased updates.
Prioritized sampling changes the distribution we learn from. Instead of learning from the true experience distribution, we learn from a biased distribution that over-represents surprising transitions.
This can cause problems. If we always sample the same surprising transitions, we might overfit to them. The solution is importance sampling: multiply each gradient by a weight that corrects for the sampling bias.
Transitions sampled with high probability get lower weights. Transitions sampled with low probability get higher weights. This rebalances the gradients toward what we would get with uniform sampling while still benefiting from focused learning.
With prioritized sampling, the expected gradient is:

$$\mathbb{E}_{i \sim P}[\nabla L_i] = \sum_i P(i) \, \nabla L_i$$

This is biased because $P(i) \neq \frac{1}{N}$.

To correct, we use importance sampling weights:

$$w_i = \left(\frac{1}{N} \cdot \frac{1}{P(i)}\right)^\beta$$

The corrected gradient:

$$\mathbb{E}_{i \sim P}[w_i \, \nabla L_i] = \sum_i P(i) \, w_i \, \nabla L_i$$

For $\beta = 1$, this exactly recovers the uniform-distribution gradient. For $\beta < 1$, we have partial correction, trading off bias for variance reduction.

Annealing $\beta$: We start with $\beta < 1$ (more bias, less variance) and anneal toward $\beta = 1$ (no bias) during training. This allows aggressive prioritization early on while converging to correct updates.
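The correction can be checked algebraically: with full correction (beta = 1) and unnormalized weights, the weighted expectation under the prioritized distribution equals the plain uniform average exactly. A minimal sketch with arbitrary made-up values:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
values = rng.normal(size=n)             # per-transition quantity (e.g. a gradient term)
priorities = rng.exponential(size=n) + 1e-6

probs = priorities / priorities.sum()   # prioritized sampling distribution P(i)
beta = 1.0
weights = (n * probs) ** (-beta)        # unnormalized importance sampling weights

# E_{i~P}[w_i * x_i] = sum_i P(i) * x_i / (N * P(i)) = mean(x): exactly unbiased
corrected = np.sum(probs * weights * values)
print(corrected, values.mean())
```

In practice the weights are additionally divided by their maximum, which rescales the effective learning rate but preserves the relative correction between transitions.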
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import copy

class DQNWithPER:
    """
    DQN agent with Prioritized Experience Replay.
    """
    def __init__(
        self,
        state_dim,
        n_actions,
        buffer_size=100000,
        batch_size=32,
        gamma=0.99,
        lr=1e-4,
        alpha=0.6,
        beta_start=0.4,
        beta_frames=100000,
        target_update_freq=1000,
    ):
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = gamma
        self.alpha = alpha
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.target_update_freq = target_update_freq

        # Networks
        self.online_net = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )
        self.target_net = copy.deepcopy(self.online_net)
        self.optimizer = optim.Adam(self.online_net.parameters(), lr=lr)

        # Prioritized replay buffer
        self.buffer = PrioritizedReplayBuffer(capacity=buffer_size, alpha=alpha)
        self.step_count = 0

    def get_beta(self):
        """Anneal beta from beta_start to 1.0."""
        progress = min(1.0, self.step_count / self.beta_frames)
        return self.beta_start + progress * (1.0 - self.beta_start)

    def store(self, state, action, reward, next_state, done):
        """Store transition with max priority (will be updated after first use)."""
        self.buffer.push(state, action, reward, next_state, done)

    def train_step(self):
        if len(self.buffer) < self.batch_size:
            return None

        # Sample with priorities
        batch, indices, weights = self.buffer.sample(self.batch_size, beta=self.get_beta())

        # Unpack batch
        states, actions, rewards, next_states, dones = zip(*batch)
        states_t = torch.FloatTensor(np.array(states))
        actions_t = torch.LongTensor(actions)
        rewards_t = torch.FloatTensor(rewards)
        next_states_t = torch.FloatTensor(np.array(next_states))
        dones_t = torch.FloatTensor(dones)
        weights_t = torch.FloatTensor(weights)

        # Compute current Q-values
        current_q = self.online_net(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Compute target Q-values (Double DQN style)
        with torch.no_grad():
            best_actions = self.online_net(next_states_t).argmax(dim=1)
            next_q = self.target_net(next_states_t).gather(1, best_actions.unsqueeze(1)).squeeze(1)
            target_q = rewards_t + self.gamma * next_q * (1 - dones_t)

        # Compute TD errors for priority update
        td_errors = (target_q - current_q).detach().numpy()

        # Weighted loss: importance weights scale each transition's squared error
        losses = (current_q - target_q) ** 2
        weighted_loss = (weights_t * losses).mean()

        # Optimize
        self.optimizer.zero_grad()
        weighted_loss.backward()
        self.optimizer.step()

        # Update priorities
        self.buffer.update_priorities(indices, td_errors)

        self.step_count += 1
        # Periodically sync the target network with the online network
        if self.step_count % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.online_net.state_dict())
        return weighted_loss.item()

# Example training loop
agent = DQNWithPER(state_dim=4, n_actions=2)

print("DQN with Prioritized Experience Replay")
print("=" * 50)

for i in range(1000):
    state = np.random.randn(4)
    action = np.random.randint(0, 2)
    reward = np.random.randn()
    next_state = np.random.randn(4)
    done = np.random.random() < 0.05

    agent.store(state, action, reward, next_state, done)
    loss = agent.train_step()

    if i % 200 == 0 and loss is not None:
        print(f"Step {i}: Loss = {loss:.4f}, Beta = {agent.get_beta():.3f}")

Sum Tree for Efficient Sampling
Naive prioritized sampling requires $O(n)$ time per sample to compute probabilities. With millions of transitions, this is too slow.

A Sum Tree is a binary tree data structure that enables $O(\log n)$ sampling:
- Leaf nodes store priorities
- Internal nodes store the sum of their children
- The root stores the total priority

To sample, we:
- Generate a random number $s$ in $[0, p_{\text{total}}]$
- Traverse the tree, going left or right based on cumulative sums
- Reach a leaf in $O(\log n)$ time

Updates are also $O(\log n)$: change a leaf and update its ancestors.
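Before the full implementation, a hand-computed walkthrough of the descent (the leaf priorities are arbitrary example values): for leaves [3, 1, 2, 4], the internal sums are 4 (= 3 + 1), 6 (= 2 + 4), and root 10. A draw of s = 5 goes right at the root (5 > 4, so s becomes 1), then left (1 <= 2), landing on the leaf with priority 2:

```python
# Array-backed sum tree of capacity 4: indices 0-2 are internal nodes, 3-6 are leaves.
tree = [10, 4, 6, 3, 1, 2, 4]  # root = 4 + 6; left subtree = 3 + 1; right = 2 + 4

def retrieve(idx, s):
    """Descend from node idx to the leaf whose cumulative-sum segment contains s."""
    left, right = 2 * idx + 1, 2 * idx + 2
    if left >= len(tree):        # no children: this is a leaf
        return idx
    if s <= tree[left]:
        return retrieve(left, s)
    return retrieve(right, s - tree[left])

leaf = retrieve(0, 5)
print(leaf, tree[leaf])  # leaf at array index 5, priority 2 (the third stored transition)
```

Each leaf "owns" a segment of the cumulative sum proportional to its priority, so a uniform draw over [0, 10] selects leaves with exactly the prioritized probabilities.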
import numpy as np

class SumTree:
    """
    Sum Tree data structure for O(log n) prioritized sampling.
    Leaf nodes store priorities. Internal nodes store sums of children.
    """
    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)  # Binary tree
        self.data = np.zeros(capacity, dtype=object)  # Leaf data
        self.write_position = 0
        self.n_entries = 0

    def _propagate(self, idx, change):
        """Propagate priority change up the tree."""
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        """Find leaf index for cumulative sum s."""
        left = 2 * idx + 1
        right = left + 1
        if left >= len(self.tree):
            return idx
        if s <= self.tree[left]:
            return self._retrieve(left, s)
        else:
            return self._retrieve(right, s - self.tree[left])

    def total(self):
        """Return total priority."""
        return self.tree[0]

    def add(self, priority, data):
        """Add data with priority."""
        idx = self.write_position + self.capacity - 1
        self.data[self.write_position] = data
        self.update(idx, priority)
        self.write_position = (self.write_position + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def update(self, idx, priority):
        """Update priority at tree index."""
        change = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, change)

    def get(self, s):
        """
        Get leaf index and data for cumulative sum s.

        Returns:
            idx: Tree index
            priority: Priority value
            data: Stored data
        """
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return idx, self.tree[idx], self.data[data_idx]

class EfficientPrioritizedBuffer:
    """
    Prioritized replay buffer with O(log n) operations using Sum Tree.
    """
    def __init__(self, capacity, alpha=0.6, epsilon=1e-6):
        self.capacity = capacity
        self.alpha = alpha
        self.epsilon = epsilon
        self.tree = SumTree(capacity)
        self.max_priority = 1.0

    def push(self, transition, td_error=None):
        """Add transition."""
        if td_error is None:
            priority = self.max_priority
        else:
            priority = (abs(td_error) + self.epsilon) ** self.alpha
            self.max_priority = max(self.max_priority, priority)
        self.tree.add(priority, transition)

    def sample(self, batch_size, beta=0.4):
        """Sample batch in O(batch_size * log n) time."""
        batch = []
        indices = []
        priorities = []

        segment = self.tree.total() / batch_size
        for i in range(batch_size):
            # Sample from each segment for better coverage
            a = segment * i
            b = segment * (i + 1)
            s = np.random.uniform(a, b)
            idx, priority, data = self.tree.get(s)
            batch.append(data)
            indices.append(idx)
            priorities.append(priority)

        # Compute importance sampling weights
        priorities = np.array(priorities)
        probs = priorities / self.tree.total()
        weights = (self.tree.n_entries * probs) ** (-beta)
        weights = weights / weights.max()

        return batch, indices, weights

    def update_priorities(self, indices, td_errors):
        """Update priorities after learning."""
        for idx, td_error in zip(indices, td_errors):
            priority = (abs(td_error) + self.epsilon) ** self.alpha
            self.max_priority = max(self.max_priority, priority)
            self.tree.update(idx, priority)

    def __len__(self):
        return self.tree.n_entries

# Compare performance
import time

def benchmark_buffers():
    """Compare naive vs efficient prioritized replay."""
    n_transitions = 100000
    batch_size = 32

    # Naive buffer
    naive = PrioritizedReplayBuffer(capacity=n_transitions)
    for i in range(n_transitions):
        naive.push(np.zeros(4), 0, 0, np.zeros(4), False, np.random.rand())

    start = time.time()
    for _ in range(100):
        naive.sample(batch_size)
    naive_time = time.time() - start

    # Efficient buffer
    efficient = EfficientPrioritizedBuffer(capacity=n_transitions)
    for i in range(n_transitions):
        efficient.push((np.zeros(4), 0, 0, np.zeros(4), False), np.random.rand())

    start = time.time()
    for _ in range(100):
        efficient.sample(batch_size)
    efficient_time = time.time() - start

    print("Buffer Performance Comparison")
    print("=" * 40)
    print(f"Buffer size: {n_transitions:,}")
    print(f"Naive (O(n)): {naive_time:.3f}s for 100 samples")
    print(f"Sum Tree (O(log n)): {efficient_time:.3f}s for 100 samples")
    print(f"Speedup: {naive_time/efficient_time:.1f}x")

benchmark_buffers()

Summary
Prioritized Experience Replay improves learning efficiency by sampling important transitions more often:
- TD error as priority: Transitions with large prediction errors are more informative
- Proportional or rank-based: Two ways to assign priorities
- Importance sampling: Corrects the bias from non-uniform sampling
- Beta annealing: Start biased, end unbiased
- Sum Tree: Enables $O(\log n)$ sampling and priority updates
Key hyperparameters:
- $\alpha$: Prioritization exponent (0.6 typical)
- $\beta$: Importance sampling exponent (anneal from 0.4 to 1.0)
- $\epsilon$: Small constant for numerical stability
PER adds complexity and hyperparameters. For simple problems, uniform replay may be sufficient. PER shines when experience importance varies greatly, such as in sparse reward environments or when some state-action pairs are rarely visited.