Distribution Shift
Why can’t we just run standard Q-learning on offline data? Because Q-learning relies on trying actions to check and correct its Q-value estimates, and in offline RL we can’t try anything new. Without that correction, the values of actions that never appear in the dataset can be catastrophically overestimated. The root cause is called distribution shift.
The Core Problem
Distribution shift is the mismatch between the distribution of state-action pairs in the offline dataset (collected by the behavior policy $\pi_\beta$) and the distribution that would be visited by the learned policy $\pi$. When the learned policy selects actions that the data covers poorly, Q-value estimates become unreliable.
Imagine a driving dataset collected from careful, defensive drivers. They never:
- Drove 100 mph on city streets
- Ran red lights
- Drove the wrong way on highways
Now you train Q-learning on this data. The Q-network has never seen what happens when you do these dangerous things. So what Q-value does it assign to “drive 100 mph in a school zone”?
It has no idea. It might extrapolate and say “well, driving faster often gets you places quicker, so… this must be good!” The Q-value could be wildly optimistic because the network is guessing about something it’s never seen.
And here’s the deadly part: Q-learning picks the action with the highest Q-value. So if any out-of-distribution action has an overestimated Q-value, the policy will select it—even though it’s terrible.
Why Q-Learning Overestimates
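Standard Q-learning update (learning rate $\alpha$, discount $\gamma$):
$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$
The problem is $\max_{a'} Q(s', a')$. This maximum ranges over all actions, including those never seen in the dataset.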
For in-distribution actions (seen in the data), Q-values are trained on real transitions and converge to meaningful estimates.
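For out-of-distribution actions (never seen), Q-values are never corrected by real data. They start at arbitrary initialization values and drift only as a side effect of training on other pairs (in a neural network, through shared parameters), with no real transition ever pulling them toward their true values.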
Result: Out-of-distribution Q-values are essentially random numbers. And since we take the max, even one overestimated OOD action will be selected. This creates a systematic bias toward selecting overestimated, unsupported actions.
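The selection bias of the max shows up even without any learning. The sketch below is a toy setup, not taken from any particular paper: every action's true value is 0, two "in-distribution" actions have small, zero-mean estimation error, and eight hypothetical OOD actions have large, zero-mean error (the noise scales 0.1 and 1.0 are assumptions). Greedy selection then almost always lands on an OOD action, and the max estimate sits well above the true max.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions = 10_000, 10

# Assumed toy setup: every action is truly worth 0 in every state.
true_q = np.zeros((n_states, n_actions))

# Actions 0-1 are "in-distribution" (small error); actions 2-9 are "OOD" (large error).
noise_scale = np.array([0.1, 0.1] + [1.0] * 8)
est_q = true_q + rng.normal(0.0, noise_scale, size=(n_states, n_actions))

greedy = est_q.argmax(axis=1)
ood_rate = np.mean(greedy >= 2)
overestimate = np.mean(est_q.max(axis=1) - true_q.max(axis=1))

print(f"Greedy action is OOD in {ood_rate:.1%} of states")
print(f"Average overestimate of max_a Q(s, a): {overestimate:.2f}")
```

Each individual estimate here is unbiased; the positive bias comes entirely from taking the max, and it grows with the error level of the unsupported actions.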
The Extrapolation Error Cascade
It gets worse. When the policy selects an OOD action with overestimated Q-value:
- The transition leads to an unfamiliar state (also OOD)
- From that state, the policy again picks the max Q-value
- That Q-value is also unreliable
- The bootstrap target is wrong
- This wrong target propagates back to affect other Q-values
- The entire Q-function becomes corrupted
This is the extrapolation error cascade: one bad estimate pollutes many others.
Think of it like a house of cards. Each Q-value estimate depends on estimates of future Q-values (via the Bellman backup). In online RL, we can verify each card by actually trying the action. In offline RL, some cards are just guesses—and if any guess is wrong, cards built on top of it collapse.
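The cascade can be reproduced in a few lines of tabular code. This is a minimal sketch under stated assumptions: a hypothetical 5-state chain with two actions, where "step" (action 0) is the only action in the dataset and "jump" (action 1) is never logged. All rewards are 0, so the true value of every logged pair is 0. A single arbitrary estimate, Q(4, jump) = 10, is never corrected, and repeated Bellman backups with a max over all actions carry it backward.

```python
import numpy as np

gamma = 0.9
n_states = 5

# Hypothetical chain MDP with two actions:
#   action 0 = "step" (appears in the dataset), action 1 = "jump" (never appears, OOD).
Q = np.zeros((n_states, 2))
Q[n_states - 1, 1] = 10.0  # one arbitrary, never-corrected estimate for an unseen action

# Fixed dataset (s, a, r, s_next, done): every reward is 0, so the true value of
# every logged (s, a) pair is 0. The last transition is terminal.
dataset = [(s, 0, 0.0, s + 1, False) for s in range(n_states - 1)]
dataset.append((n_states - 1, 0, 0.0, n_states - 1, True))

for _ in range(50):  # repeated Bellman backups over the same fixed data
    for s, a, r, s_next, done in dataset:
        bootstrap = 0.0 if done else gamma * Q[s_next].max()  # max over ALL actions
        Q[s, a] = r + bootstrap  # tabular update with learning rate 1

print("Q(s, step):", np.round(Q[:, 0], 3))
```

Only one entry was wrong, yet every non-terminal logged pair ends with a positive estimate that decays by a factor of $\gamma$ per step away from the corrupted state.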
Visualizing the Problem
Consider a simple 1D continuous action problem. The behavior policy only takes actions in the range $[-1, 1]$, but the action space is $[-3, 3]$.
```
Action Space:  [-3]----[-1]=====[0]=====[ 1]----[ 3]
                 ^       ^                ^       ^
                 |       |                |       |
                OOD      |                |      OOD
                   data boundary    data boundary
```
Q-values after offline training:
- In-distribution [-1, 1]: Reasonable estimates
- Out-of-distribution: Random, often overestimated

When the policy takes the argmax over Q-values, it might select an action outside $[-1, 1]$ because the Q-network, never having seen what happens there, assigned it an optimistic value.
The deployed policy confidently takes an action that leads to catastrophe.
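Here is a small sketch of that failure in the 1D setting. Two simplifications are assumptions, not part of the original setup: a least-squares line stands in for the Q-network, and the ground-truth function `true_q` is hypothetical (larger actions help inside the data range, but anything beyond $|a| > 1.5$ is catastrophic, echoing the driving example). Real networks extrapolate less predictably than a line, but the failure mode is the same.

```python
import numpy as np

rng = np.random.default_rng(0)

def true_q(a):
    # Hypothetical ground truth: inside the data range, larger actions are better,
    # but beyond |a| > 1.5 the outcome is catastrophic.
    return np.where(np.abs(a) <= 1.5, a, -10.0)

# Offline data: the behavior policy only ever tried actions in [-1, 1].
a_data = rng.uniform(-1.0, 1.0, size=200)
q_data = true_q(a_data) + rng.normal(0.0, 0.1, size=200)

# Stand-in for a Q-network: a least-squares line, which extrapolates the in-data trend.
q_hat = np.poly1d(np.polyfit(a_data, q_data, deg=1))

# Greedy action over the FULL action space [-3, 3].
grid = np.linspace(-3.0, 3.0, 601)
a_star = float(grid[np.argmax(q_hat(grid))])
print(f"Greedy action: {a_star:.2f}")
print(f"Predicted Q: {q_hat(a_star):.2f}, true Q: {float(true_q(a_star)):.2f}")
```

The fit is good inside $[-1, 1]$, which is exactly why its confident extrapolation to the edge of the action space is dangerous.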
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class QNetwork(nn.Module):
    """Q-network for continuous states, discrete actions."""

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, state):
        return self.net(state)


def naive_offline_q_learning(dataset, q_net, target_net, optimizer,
                             batch_size=256, gamma=0.99, iterations=10000):
    """
    Naive offline Q-learning (will fail due to distribution shift).

    This demonstrates what goes wrong without accounting for OOD actions.
    `dataset.sample(n)` is assumed to return NumPy arrays
    (states, actions, rewards, next_states, dones).
    """
    losses = []
    for i in range(iterations):
        # Sample batch from the fixed dataset (no new data is ever collected)
        states, actions, rewards, next_states, dones = dataset.sample(batch_size)
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Current Q-values for the actions actually taken in the data
        q_values = q_net(states)
        q_taken = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target: r + gamma * max_a' Q(s', a')
        with torch.no_grad():
            next_q_values = target_net(next_states)
            max_next_q = next_q_values.max(dim=1)[0]  # PROBLEM: max over ALL actions
            targets = rewards + gamma * (1 - dones) * max_next_q

        # Loss
        loss = F.mse_loss(q_taken, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Periodically copy the online weights into the target network
        if i % 1000 == 0:
            target_net.load_state_dict(q_net.state_dict())

    return losses


def demonstrate_distribution_shift(q_net, dataset, n_actions):
    """
    Show how learned Q-values differ between in-distribution and OOD regions.
    """
    # Sample some states from the dataset
    states, actions, _, _, _ = dataset.sample(100)
    states_tensor = torch.FloatTensor(states)

    with torch.no_grad():
        q_values = q_net(states_tensor)

    # Which actions are in-distribution for these states?
    # (Simplistic check: treat the single logged action as the only supported one)
    for i, (state, action) in enumerate(zip(states[:5], actions[:5])):
        q_vals = q_values[i].numpy()
        print(f"\nState {i}:")
        print(f"  Dataset action: {action} with Q={q_vals[action]:.2f}")
        print(f"  Argmax action: {q_vals.argmax()} with Q={q_vals.max():.2f}")

        # Flag states where the greedy action differs from the logged one
        if q_vals.argmax() != action:
            print("  WARNING: learned policy prefers a different (possibly OOD) action!")
        print(f"  Q-values: {q_vals}")
```

Why This Doesn’t Happen in Online RL
In online RL, the max over actions in Q-learning is far less dangerous, because:
- If Q(s, a) is overestimated, the policy will try action a
- The agent sees what actually happens
- The real transition corrects the Q-value
- Overestimates are quickly fixed through experience
The feedback loop between exploration and learning keeps Q-values grounded in reality. In offline RL, this feedback loop is broken—we can’t verify our estimates.
In online RL, the data distribution tracks the policy (here $d^{\pi}(s)$ denotes the distribution of states visited by $\pi$):
$$(s, a) \sim d^{\pi}(s)\,\pi(a \mid s)$$
When $\pi$ selects a new action, data is collected to evaluate that action. This keeps Q-values accurate for the actions the policy actually takes.
In offline RL, the data distribution is fixed:
$$(s, a) \sim d^{\pi_\beta}(s)\,\pi_\beta(a \mid s)$$
The learned policy $\pi$ may want to take actions never selected by $\pi_\beta$. For those actions, we have no corrective signal, only unreliable extrapolation.
The distribution shift is exactly this gap: $\pi$ wants to go where $\pi_\beta$ never went.
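The contrast can be made concrete with a one-state problem (a three-armed bandit, so there is no bootstrapping at all). The values below are hypothetical: action 2 is truly terrible but its estimate starts at 8.0. The online agent acts greedily and updates from the rewards it observes; the offline learner only sees a fixed log from a behavior policy that chose uniformly between actions 0 and 1. Both use simple sample-average updates.

```python
import numpy as np

rng = np.random.default_rng(0)
true_value = np.array([1.0, 0.5, -5.0])  # action 2 is terrible...
initial_q = np.array([1.0, 0.5, 8.0])    # ...but its estimate starts wildly optimistic

# Online: the greedy agent tries whatever looks best and observes the real reward.
q_online, counts = initial_q.copy(), np.zeros(3)
for _ in range(100):
    a = int(np.argmax(q_online))
    r = true_value[a] + rng.normal(0.0, 0.1)
    counts[a] += 1
    q_online[a] += (r - q_online[a]) / counts[a]  # sample-average update

# Offline: a fixed log from a behavior policy that only used actions 0 and 1.
q_offline, counts_off = initial_q.copy(), np.zeros(3)
logged_actions = rng.choice([0, 1], size=100)
for a in logged_actions:
    r = true_value[a] + rng.normal(0.0, 0.1)
    counts_off[a] += 1
    q_offline[a] += (r - q_offline[a]) / counts_off[a]

print("online  Q:", np.round(q_online, 2), "-> greedy action", np.argmax(q_online))
print("offline Q:", np.round(q_offline, 2), "-> greedy action", np.argmax(q_offline))
```

Online, a single try of action 2 is enough to correct its estimate. Offline, no logged sample ever touches action 2, so its optimistic value survives and the greedy policy picks it.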
Quantifying Distribution Shift
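We can measure distribution shift using divergences between the two policies' action distributions at each state:
- Behavior policy: $\pi_\beta(a \mid s)$
- Learned policy: $\pi(a \mid s)$
$$D_{\mathrm{KL}}\big(\pi(\cdot \mid s)\,\|\,\pi_\beta(\cdot \mid s)\big) = \sum_a \pi(a \mid s) \log \frac{\pi(a \mid s)}{\pi_\beta(a \mid s)}$$
The KL divergence measures how much the learned policy differs from the behavior policy. High divergence means the learned policy wants to take actions rarely seen in the data.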
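A small worked example of $D_{\mathrm{KL}}(\pi \,\|\, \pi_\beta)$ at a single state with three discrete actions. The three distributions are made up for illustration: `pi_close` stays near the behavior policy, while `pi_far` moves most of its mass onto the action the behavior policy rarely took.

```python
import numpy as np

def kl_divergence(p, q):
    """KL(p || q) for discrete distributions, using the convention 0 * log(0 / q) = 0."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    support = p > 0
    return float(np.sum(p[support] * np.log(p[support] / q[support])))

# Hypothetical action distributions at a single state (3 discrete actions).
pi_beta = np.array([0.45, 0.45, 0.10])   # behavior policy rarely takes action 2
pi_close = np.array([0.50, 0.40, 0.10])  # learned policy that stays near the data
pi_far = np.array([0.10, 0.10, 0.80])    # learned policy that favors the rare action

print(f"KL(pi_close || pi_beta) = {kl_divergence(pi_close, pi_beta):.4f}")
print(f"KL(pi_far   || pi_beta) = {kl_divergence(pi_far, pi_beta):.4f}")
```

The direction matters: $D_{\mathrm{KL}}(\pi \,\|\, \pi_\beta)$ weights each log-ratio by the learned policy's own probabilities, so it is large exactly when $\pi$ commits to actions that $\pi_\beta$ rarely took. If $\pi_\beta(a \mid s) = 0$ for an action with $\pi(a \mid s) > 0$, the divergence is infinite (NumPy returns `inf` with a divide-by-zero warning).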
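Another perspective: the occupancy mismatch. Let $d^{\pi}(s)$ be the state distribution under policy $\pi$. The effective distribution shift is the divergence measured on the states the learned policy actually visits:
$$\mathbb{E}_{s \sim d^{\pi}}\Big[D_{\mathrm{KL}}\big(\pi(\cdot \mid s)\,\|\,\pi_\beta(\cdot \mid s)\big)\Big]$$
If this is high, the policy frequently chooses state-action pairs not covered by the dataset.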
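In a small tabular problem, a simple coverage proxy for occupancy mismatch is the fraction of the learned policy's discounted occupancy $d^{\pi}(s)\,\pi(a \mid s)$ that falls on $(s, a)$ pairs with zero dataset count. The sketch assumes a known, deterministic 2-state MDP (hypothetical: action $a$ moves the agent to state $a$) so that $d^{\pi}$ can be computed exactly from $(1-\gamma)\,\mu_0^\top (I - \gamma P_\pi)^{-1}$. It measures coverage rather than a divergence, which is a deliberate simplification.

```python
import numpy as np

gamma = 0.9
# Hypothetical 2-state, 2-action deterministic MDP: action a moves the agent to state a.
next_state = np.array([[0, 1],
                       [0, 1]])
mu0 = np.array([1.0, 0.0])  # every episode starts in state 0

def occupancy(pi):
    """Normalized discounted occupancy d^pi(s, a) = d^pi(s) * pi(a|s) for a tabular policy."""
    n_s, n_a = pi.shape
    P = np.zeros((n_s, n_s))  # state-to-state transition matrix under pi
    for s in range(n_s):
        for a in range(n_a):
            P[s, next_state[s, a]] += pi[s, a]
    # d^pi = (1 - gamma) * mu0^T (I - gamma P)^{-1}, computed with a linear solve
    d_s = (1 - gamma) * np.linalg.solve((np.eye(n_s) - gamma * P).T, mu0)
    return d_s[:, None] * pi

# Dataset counts N(s, a): in state 1 the behavior policy only ever took action 0.
counts = np.array([[50, 50],
                   [80, 0]])
pi_beta = counts / counts.sum(axis=1, keepdims=True)
pi_learned = np.array([[0.0, 1.0],
                       [0.0, 1.0]])  # always take action 1, including the unseen (1, 1)

for name, pi in [("behavior", pi_beta), ("learned", pi_learned)]:
    d = occupancy(pi)
    print(f"{name:8s}: occupancy mass on unsupported (s, a) pairs = {d[counts == 0].sum():.2f}")
```

The learned policy differs from the data at a single state-action pair, yet 90% of its discounted occupancy lands there, because that pair keeps it in state 1. Divergences measured only on dataset states can understate this kind of shift.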
```python
import torch
import torch.nn.functional as F


def estimate_distribution_shift(policy, behavior_model, dataset, n_samples=1000):
    """
    Estimate distribution shift between learned policy and behavior policy.

    Args:
        policy: Learned policy network (outputs action logits)
        behavior_model: Model of behavior policy (e.g., from behavior cloning),
            also outputting action logits
        dataset: Offline dataset
        n_samples: Number of dataset states to evaluate on

    Returns:
        Mean KL(pi || pi_beta) over dataset states, and the fraction of states
        where the two policies' most likely actions disagree
    """
    states, _, _, _, _ = dataset.sample(n_samples)
    states_tensor = torch.FloatTensor(states)

    with torch.no_grad():
        # Learned policy log-probabilities and probabilities
        pi_log_probs = F.log_softmax(policy(states_tensor), dim=-1)
        pi_probs = pi_log_probs.exp()

        # Behavior policy log-probabilities (from behavior cloning model);
        # log_softmax is numerically stable, so no epsilon is needed to avoid log(0)
        beta_log_probs = F.log_softmax(behavior_model(states_tensor), dim=-1)

        # KL divergence: sum_a pi(a|s) * (log pi(a|s) - log beta(a|s))
        kl_per_state = (pi_probs * (pi_log_probs - beta_log_probs)).sum(dim=-1)
        mean_kl = kl_per_state.mean().item()

        # Fraction of states where argmax differs
        pi_actions = pi_probs.argmax(dim=-1)
        beta_actions = beta_log_probs.argmax(dim=-1)
        action_disagreement = (pi_actions != beta_actions).float().mean().item()

    print(f"Mean KL divergence: {mean_kl:.4f}")
    print(f"Action disagreement rate: {action_disagreement:.2%}")
    return mean_kl, action_disagreement


def detect_ood_actions(q_net, dataset, threshold=2.0):
    """
    Detect when learned policy selects potentially OOD actions.

    Heuristic: if the argmax Q-value is much higher than the Q-value of the
    dataset action, the argmax action is likely OOD and overestimated.
    `threshold` is in Q-value units, so it should be scaled to the reward scale.
    """
    states, actions, _, _, _ = dataset.sample(500)
    states_tensor = torch.FloatTensor(states)
    actions_tensor = torch.LongTensor(actions)

    with torch.no_grad():
        q_values = q_net(states_tensor)
        q_max = q_values.max(dim=1)[0]
        q_dataset = q_values.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)

    # Gap between max Q and dataset action Q
    gap = q_max - q_dataset
    suspicious = (gap > threshold).sum().item()

    print(f"States with suspicious Q-gap > {threshold}: {suspicious}/{len(states)}")
    print(f"Mean Q-gap: {gap.mean().item():.2f}")
    return gap
```

The Deadly Cycle
Here’s the full deadly cycle in naive offline Q-learning:
- Initialize: Q-values start at random values
- Train: Q-learning fits in-distribution (s, a) pairs reasonably well
- But: Q-values for OOD actions are never corrected, remain at (or drift to) arbitrary values
- Some OOD actions get overestimated (by chance or extrapolation)
- Max operator: Policy selects the overestimated OOD action
- Bootstrap: The overestimated Q propagates to other states
- Cascade: Error spreads, more Q-values become unreliable
- Result: Policy confidently takes terrible actions
The more training, the worse it can get! Unlike online RL where more training helps, in naive offline RL, more training can compound errors.
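A minimal sketch of "more training, worse estimates", under loudly stated assumptions: a hypothetical linear Q-function $Q(s, a) = w\,\phi(s, a)$ with one shared weight $w$, a dataset containing a single transition with reward 0 (so the true value is 0), and made-up features at the next state, where the never-logged action has the largest feature. Each semi-gradient Q-learning step multiplies $w$ by $1 + \text{lr}\,(2\gamma - 1) \approx 1.098$, so the estimate grows without bound.

```python
import numpy as np

gamma, lr = 0.99, 0.1
w = 0.1  # single shared weight, tiny initial value

# Hypothetical linear Q-function Q(s, a) = w * phi(s, a) with one shared weight.
phi_logged = 1.0  # feature of the one (s, a) pair in the dataset (reward 0, true Q = 0)
phi_next = {"logged_action": 0.5, "ood_action": 2.0}  # features at the next state s'

history = [w]
for step in range(100):
    q_sa = w * phi_logged
    # Bootstrap target uses the max over ALL actions at s', including the OOD one.
    target = 0.0 + gamma * max(w * f for f in phi_next.values())
    w += lr * (target - q_sa) * phi_logged  # semi-gradient Q-learning step
    history.append(w)

print(f"Q(s, a) after   1 step:  {history[1] * phi_logged:.3f}")
print(f"Q(s, a) after 100 steps: {history[-1] * phi_logged:.1f}  (true value: 0)")
```

If the OOD action were excluded from the max, the same update would multiply $w$ by $1 + 0.1\,(0.99 \cdot 0.5 - 1) \approx 0.95$ per step and shrink the estimate toward its true value of 0.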
Summary
Distribution shift is the fundamental challenge of offline RL:
- Learned policies want to take actions not in the dataset
- Q-values for those actions are unreliable extrapolations
- Q-learning’s max operator systematically selects overestimated OOD actions
- Errors cascade through bootstrapping
- The result is policies that look good on paper but fail catastrophically in practice
The solution? We need algorithms that explicitly handle this distribution shift—either by staying close to the data or by being conservative about OOD actions. That’s what we’ll explore in the next section on Conservative Methods.