
Distribution Shift

The core challenge of offline RL

Why can’t we just run standard Q-learning on offline data? Because Q-learning relies on trying actions to find out whether its Q-value estimates are right, and in offline RL we can’t try anything new. This leads to catastrophic overestimation of the values of actions that don’t appear in the dataset. The underlying problem is called distribution shift.

The Core Problem

Definition: Distribution Shift

The mismatch between the distribution of state-action pairs in the offline dataset (from the behavior policy $\pi_\beta$) and the distribution that would be visited by the learned policy $\pi$. When the learned policy selects actions not well covered by the data, Q-value estimates become unreliable.

Imagine a driving dataset collected from careful, defensive drivers. They never:

  • Drove 100 mph on city streets
  • Ran red lights
  • Drove the wrong way on highways

Now you run Q-learning on this data. The Q-network has never seen what happens when you do these dangerous things. So what Q-value does it assign to “drive 100 mph in a school zone”?

It has no idea. It might extrapolate and say “well, driving faster often gets you places quicker, so… this must be good!” The Q-value could be wildly optimistic because the network is guessing about something it’s never seen.

And here’s the deadly part: Q-learning picks the action with the highest Q-value. So if any out-of-distribution action has a Q-value overestimated enough to beat every in-distribution action, the policy will select it, even though it’s terrible.

Why Q-Learning Overestimates

Mathematical Details

Standard Q-learning update:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

The problem is the $\max_{a'} Q(s', a')$ term. This maximum ranges over all actions, including those never taken at $s'$ (or at similar states) in the dataset.
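
To make the role of the max concrete, here is a minimal tabular sketch with hypothetical numbers (not taken from any real task). At the next state $s'$, the dataset action has been trained to $Q(s', a_{\text{data}}) = 2.0$, while an action that never appears in the data still holds an arbitrary initial value of $10.0$. One Q-learning update from a perfectly ordinary logged transition then pulls $Q(s, a)$ toward the OOD number.

```python
import numpy as np

GAMMA = 0.9  # discount factor (illustrative)
ALPHA = 0.5  # learning rate (illustrative)

# Hypothetical Q-table: rows are states (0 = s, 1 = s'), columns are actions.
# Action 0 appears in the dataset; action 1 never does, so Q[1, 1] is just
# whatever value it was initialized to.
Q = np.array([
    [0.0, 0.0],   # state s
    [2.0, 10.0],  # state s': trained dataset action, untrained OOD action
])

def q_learning_update(Q, s, a, r, s_next, alpha=ALPHA, gamma=GAMMA):
    """One standard Q-learning step, in place; returns the TD target used."""
    target = r + gamma * Q[s_next].max()  # the max ranges over ALL actions at s'
    Q[s, a] += alpha * (target - Q[s, a])
    return target

# A single logged transition: (s=0, a=0, r=1.0, s'=1).
target = q_learning_update(Q, s=0, a=0, r=1.0, s_next=1)

# For comparison: the target if the max only considered the dataset action.
data_only_target = 1.0 + GAMMA * Q[1, 0]

print(target)    # 1.0 + 0.9 * 10.0 = 10.0
print(Q[0, 0])   # 0.0 + 0.5 * (10.0 - 0.0) = 5.0
```

Restricting the backup to the dataset action would give a target of $1.0 + 0.9 \times 2.0 = 2.8$; the never-trained action inflates it to $10.0$, and that inflated number is now baked into $Q(s, a)$.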

For in-distribution actions (seen in the data), Q-values are trained on real transitions and converge to meaningful estimates.

For out-of-distribution actions (never seen), Q-values are never corrected by real data. They start at arbitrary initialization values and drift as the shared network weights are updated to fit other state-action pairs (whose own targets are bootstrapped), with no data pulling them back toward the truth.

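Result: Out-of-distribution Q-values are essentially unconstrained guesses. Since we take the max, a single OOD action whose estimate happens to exceed the in-distribution ones is enough: it gets selected by the policy and also sets the bootstrap target. Overestimated actions are preferentially picked while underestimated ones are simply ignored, so the errors don’t average out. This creates a systematic bias toward selecting overestimated, unsupported actions.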
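
The systematic bias comes from a simple statistical fact: the maximum of several noisy estimates is, on average, at least as large as the maximum of their expected values, $\mathbb{E}[\max_a \hat{Q}(s, a)] \ge \max_a \mathbb{E}[\hat{Q}(s, a)]$ (this holds for any noise because the max is convex). The sketch below assumes, purely for illustration, that 10 actions all have true value 0 and that each estimate carries independent zero-mean Gaussian noise; real OOD errors need not look like this.

```python
import numpy as np

rng = np.random.default_rng(0)

n_actions = 10
n_trials = 100_000

# Illustrative assumption: every action is truly worth 0.0, and each
# Q-estimate is the true value plus independent zero-mean noise (std 1.0).
true_q = np.zeros(n_actions)
noisy_q = true_q + rng.normal(0.0, 1.0, size=(n_trials, n_actions))

mean_single_estimate = noisy_q.mean()     # close to 0: each estimate is unbiased
mean_of_max = noisy_q.max(axis=1).mean()  # clearly above the true max of 0

print(f"mean single estimate: {mean_single_estimate:.3f}")
print(f"mean of max estimate: {mean_of_max:.3f}  (true max is 0.0)")
```

Each individual estimate is unbiased, yet their maximum averages roughly 1.5 above the true maximum here. Q-learning’s target contains exactly such a max, and in offline RL nothing corrects the actions that win it.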

The Extrapolation Error Cascade

Think of it like a house of cards. Each Q-value estimate depends on estimates of future Q-values (via the Bellman backup). In online RL, we can verify each card by actually trying the action. In offline RL, some cards are just guesses—and if any guess is wrong, cards built on top of it collapse.

Visualizing the Problem

Example: Simple Illustration

Consider a simple 1D continuous action problem. The behavior policy only takes actions in the range $[-1, 1]$, but the action space is $[-3, 3]$.

Action Space:   [-3]----[-1]=====[0]=====[ 1]----[ 3]
                  ^       ^                ^       ^
                  |       |                |       |
                 OOD   Data boundary    Data boundary   OOD

Q-values after offline training:
  - In-distribution [-1, 1]: Reasonable estimates
  - Out-of-distribution: Random, often overestimated

When the policy takes the argmax over Q-values, it might select an action at $a = 2.5$ because the Q-network, never having seen what happens there, assigned it an optimistic value.

The deployed policy confidently takes an action that leads to catastrophe.
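
The 1D picture can be reproduced in a few lines of NumPy. The true Q-function below is a hypothetical choice for illustration: it rises linearly with $a$ inside the data range and collapses beyond $a = 1.5$. A least-squares line fitted to data from $[-1, 1]$ stands in for the Q-network; this is a deliberate simplification, since a real network’s extrapolation is less predictable, but it is just as unconstrained by data.

```python
import numpy as np

def true_q(a):
    """Hypothetical ground truth: value rises with a, then collapses past a = 1.5."""
    a = np.asarray(a, dtype=float)
    return a - 10.0 * np.maximum(0.0, a - 1.5) ** 2

# Behavior data only covers actions in [-1, 1], where true_q(a) == a.
data_actions = np.linspace(-1.0, 1.0, 21)
data_q = true_q(data_actions)

# Stand-in for the Q-network: a least-squares linear fit to the logged data.
coeffs = np.polyfit(data_actions, data_q, deg=1)

# The policy maximizes the *fitted* Q over the full action space [-3, 3].
grid = np.linspace(-3.0, 3.0, 61)
fitted = np.polyval(coeffs, grid)
best_a = grid[np.argmax(fitted)]

print(f"chosen action: {best_a:.2f}")
print(f"predicted Q:   {float(np.polyval(coeffs, best_a)):.2f}")
print(f"true Q:        {float(true_q(best_a)):.2f}")
```

The fit is perfect inside $[-1, 1]$, yet the argmax lands at the edge of the action space, $a = 3.0$, where the predicted value is $3.0$ and the true value is $-19.5$.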

Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class QNetwork(nn.Module):
    """Q-network for continuous states, discrete actions."""

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, state):
        return self.net(state)


def naive_offline_q_learning(dataset, q_net, target_net, optimizer,
                              batch_size=256, gamma=0.99, iterations=10000):
    """
    Naive offline Q-learning (will fail due to distribution shift).

    This demonstrates what goes wrong without accounting for OOD actions.
    """
    losses = []

    for i in range(iterations):
        # Sample batch from fixed dataset
        states, actions, rewards, next_states, dones = dataset.sample(batch_size)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Current Q-values for taken actions
        q_values = q_net(states)
        q_taken = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target: r + gamma * max_a' Q(s', a')
        with torch.no_grad():
            next_q_values = target_net(next_states)
            max_next_q = next_q_values.max(dim=1)[0]  # PROBLEM: max over ALL actions
            targets = rewards + gamma * (1 - dones) * max_next_q

        # Loss
        loss = F.mse_loss(q_taken, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        # Periodically update target network
        if i % 1000 == 0:
            target_net.load_state_dict(q_net.state_dict())

    return losses


def demonstrate_distribution_shift(q_net, dataset, n_actions):
    """
    Show how learned Q-values differ between in-distribution and OOD regions.
    """
    # Sample some states from the dataset
    states, actions, _, _, _ = dataset.sample(100)
    states_tensor = torch.FloatTensor(states)

    with torch.no_grad():
        q_values = q_net(states_tensor)

    # Rough proxy for OOD: compare the greedy action with the single action
    # logged for each sampled state. A mismatch does not prove the greedy
    # action is OOD (similar states may have logged it), but it is a red flag.
    for i, action in enumerate(actions[:5]):
        q_vals = q_values[i].numpy()
        greedy = int(q_vals.argmax())
        print(f"\nState {i}:")
        print(f"  Dataset action: {action} with Q={q_vals[action]:.2f}")
        print(f"  Argmax action: {greedy} with Q={q_vals[greedy]:.2f}")

        if greedy != action:
            print("  WARNING: greedy action differs from the logged action (possibly OOD)")
            print(f"  Q-values: {q_vals}")
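
The functions above assume a `dataset` object whose `sample(batch_size)` method returns `(states, actions, rewards, next_states, dones)` as arrays, but the section doesn’t define one. A minimal hypothetical version might look like this; the class name `OfflineDataset`, its constructor, and uniform sampling with replacement are illustrative assumptions.

```python
import numpy as np

class OfflineDataset:
    """Hypothetical fixed dataset exposing the sample() interface assumed above.

    Transitions are stored as NumPy arrays and minibatches are drawn uniformly
    with replacement. Nothing is ever added (offline setting).
    """

    def __init__(self, states, actions, rewards, next_states, dones, seed=0):
        self.states = np.asarray(states, dtype=np.float32)
        self.actions = np.asarray(actions, dtype=np.int64)
        self.rewards = np.asarray(rewards, dtype=np.float32)
        self.next_states = np.asarray(next_states, dtype=np.float32)
        self.dones = np.asarray(dones, dtype=np.float32)
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.actions)

    def sample(self, batch_size):
        idx = self.rng.integers(0, len(self), size=batch_size)
        return (self.states[idx], self.actions[idx], self.rewards[idx],
                self.next_states[idx], self.dones[idx])


# Usage: random toy data with 4-dim states and 3 discrete actions, where the
# (hypothetical) behavior policy only ever took actions 0 and 1.
rng = np.random.default_rng(42)
N, STATE_DIM = 1000, 4
dataset = OfflineDataset(
    states=rng.normal(size=(N, STATE_DIM)),
    actions=rng.integers(0, 2, size=N),  # action 2 never appears
    rewards=rng.normal(size=N),
    next_states=rng.normal(size=(N, STATE_DIM)),
    dones=rng.random(N) < 0.05,
)
states, actions, rewards, next_states, dones = dataset.sample(256)
print(states.shape, actions.shape, sorted(set(actions.tolist())))
```

With this toy data, a `QNetwork(STATE_DIM, 3)` trained by `naive_offline_q_learning` would never receive a direct training signal for action 2, yet action 2 still takes part in every `max` over next-state Q-values.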

Why This Doesn’t Happen in Online RL

In online RL, the max over actions is much less of a problem, because:

  1. If $Q(s, a)$ is overestimated, the greedy policy will try action $a$
  2. The agent sees what actually happens
  3. The real transition $(s, a, r, s')$ corrects the Q-value
  4. Overestimates are quickly fixed through experience

The feedback loop between exploration and learning keeps Q-values grounded in reality. In offline RL, this feedback loop is broken—we can’t verify our estimates.
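
The feedback loop can be seen in a stripped-down setting. The sketch below is a hypothetical 3-armed bandit (no states, so no bootstrapping), chosen to isolate the one difference that matters here: whether the agent can try the action it currently believes is best. Arm 2 is actually terrible, but its estimate starts at an optimistic 10.0, and the behavior policy never pulled it.

```python
import numpy as np

# Hypothetical bandit: true mean rewards and (badly) initialized estimates.
TRUE_REWARD = np.array([1.0, 0.5, -5.0])
INIT_Q = np.array([0.0, 0.0, 10.0])  # arm 2 starts overestimated
LR = 0.5

def online_greedy(steps=20):
    """Online: pull the best-looking arm, observe its real reward, update."""
    q = INIT_Q.copy()
    for _ in range(steps):
        a = int(np.argmax(q))                 # try the arm we believe is best
        q[a] += LR * (TRUE_REWARD[a] - q[a])  # real feedback corrects the estimate
    return q

def offline_fit(logged):
    """Offline: update only from logged (arm, reward) pairs; never act."""
    q = INIT_Q.copy()
    for a, r in logged:
        q[a] += LR * (r - q[a])
    return q

q_online = online_greedy()
# The behavior policy only ever pulled arms 0 and 1.
q_offline = offline_fit([(0, 1.0), (1, 0.5)] * 10)

print("online greedy arm: ", int(np.argmax(q_online)))   # 0
print("offline greedy arm:", int(np.argmax(q_offline)))  # 2 (estimate still 10.0)
```

Online, the first two pulls of arm 2 drag its estimate from 10.0 down to -1.25, and the agent switches to arm 0. Offline, arm 2’s estimate never moves, so the greedy policy still picks it.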

Mathematical Details

In online RL, the data distribution tracks the policy:

$$D_t \sim \text{trajectories from } \pi_{\theta_t}$$

When $\pi$ selects a new action, data is collected that evaluates that action. Over time, this keeps Q-values accurate for the actions the policy actually takes.

In offline RL, the data distribution is fixed:

$$D \sim \text{trajectories from } \pi_\beta$$

The learned policy $\pi$ may want to take actions never selected by $\pi_\beta$. For those actions, we have no corrective signal, only unreliable extrapolation.

The distribution shift is exactly this gap: $\pi$ wants to go where $\pi_\beta$ never went.

Quantifying Distribution Shift

Mathematical Details

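We can measure distribution shift with a divergence between the two policies’ action distributions at each state:

  • Behavior policy: $\pi_\beta(a \mid s)$
  • Learned policy: $\pi(a \mid s)$

The KL divergence $D_{KL}(\pi \,\|\, \pi_\beta)$ measures how much the learned policy differs from the behavior policy. For discrete actions, averaged over dataset states, it is $\mathbb{E}_{s \sim D}\left[ \sum_a \pi(a \mid s) \log \frac{\pi(a \mid s)}{\pi_\beta(a \mid s)} \right]$. High divergence means the learned policy puts substantial probability on actions the behavior policy rarely took, i.e., actions rarely seen in the data.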
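
A small worked example of the KL formula, using made-up action probabilities at a single state. The divergence grows as the learned policy concentrates on an action the behavior policy rarely takes, and grows further as that action becomes rarer in the data.

```python
import math

def kl_divergence(p, q):
    """D_KL(p || q) = sum_a p(a) * log(p(a) / q(a)) for discrete distributions.

    Terms with p(a) = 0 contribute nothing; q(a) must be > 0 wherever p(a) > 0.
    """
    return sum(pa * math.log(pa / qa) for pa, qa in zip(p, q) if pa > 0)

# Hypothetical action probabilities at one state (3 actions).
pi_beta = [0.45, 0.45, 0.10]   # behavior policy rarely takes action 2
pi_close = [0.45, 0.45, 0.10]  # learned policy matching the data
pi_far = [0.05, 0.05, 0.90]    # learned policy that prefers action 2

kl_close = kl_divergence(pi_close, pi_beta)
kl_far = kl_divergence(pi_far, pi_beta)

# Action 2 is even rarer in the data, and the learned policy always picks it.
pi_beta_rare = [0.50, 0.49, 0.01]
pi_greedy = [0.0, 0.0, 1.0]
kl_rare = kl_divergence(pi_greedy, pi_beta_rare)

print(f"{kl_close:.3f}")  # 0.000
print(f"{kl_far:.3f}")    # 1.758  (= 0.8 * ln 9)
print(f"{kl_rare:.3f}")   # 4.605  (= ln 100)
```

Because each log-ratio is weighted by $\pi(a \mid s)$, $D_{KL}(\pi \,\|\, \pi_\beta)$ is dominated by the actions the learned policy actually takes, and it becomes infinite if $\pi$ puts any probability on an action with $\pi_\beta(a \mid s) = 0$. That is why the implementation below adds a small constant before taking logs.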

Another perspective: the occupancy mismatch. Let $d^\pi(s)$ be the state distribution under policy $\pi$. The effective distribution shift is:

$$\text{shift} = \mathbb{E}_{s \sim d^\pi,\, a \sim \pi(a \mid s)}\left[ \mathbb{1}[(s, a) \notin \text{supp}(D)] \right]$$

If this is high, the policy frequently chooses state-action pairs not covered by the dataset.
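
For discrete states and actions, this indicator form can be estimated by checking set membership. The sketch below makes two simplifying assumptions worth flagging: the policy is deterministic, and the states it is queried at stand in for samples from $d^\pi$. In a real offline setting we can’t roll out $\pi$ to sample $d^\pi$, so dataset states are a common but imperfect substitute. For continuous actions, exact membership in $\text{supp}(D)$ is ill-defined and would need a density- or distance-based proxy instead.

```python
def ood_fraction(states, policy, dataset_pairs):
    """Fraction of queried (s, policy(s)) pairs that never appear in the dataset.

    Assumptions (for illustration): discrete states and actions, a
    deterministic policy, and `states` standing in for samples from d^pi.
    """
    support = set(dataset_pairs)  # empirical supp(D): every logged (s, a) pair
    unsupported = [(s, policy(s)) not in support for s in states]
    return sum(unsupported) / len(unsupported)


# Toy data: 3 states, actions {0, 1, 2}; action 2 was never logged at state 1.
dataset_pairs = [(0, 0), (0, 0), (1, 1), (2, 0), (2, 1)]
greedy = {0: 0, 1: 2, 2: 1}.get  # hypothetical learned policy: state -> action

frac = ood_fraction([0, 1, 2], greedy, dataset_pairs)
print(f"{frac:.3f}")  # 0.333: the pair (1, 2) is outside the data's support
```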

Implementation
import torch
import torch.nn.functional as F


def estimate_distribution_shift(policy, behavior_model, dataset, n_samples=1000):
    """
    Estimate distribution shift between learned policy and behavior policy.

    Args:
        policy: Learned policy network (outputs logits over discrete actions)
        behavior_model: Model of behavior policy (e.g., from behavior cloning)
        dataset: Offline dataset
        n_samples: Number of dataset states to evaluate

    Returns:
        Mean KL divergence D_KL(pi || pi_beta) and the argmax disagreement rate
    """
    states, _, _, _, _ = dataset.sample(n_samples)
    states_tensor = torch.FloatTensor(states)

    with torch.no_grad():
        # Learned policy probabilities
        pi_logits = policy(states_tensor)
        pi_probs = F.softmax(pi_logits, dim=-1)

        # Behavior policy probabilities (from behavior cloning model)
        beta_logits = behavior_model(states_tensor)
        beta_probs = F.softmax(beta_logits, dim=-1) + 1e-8  # Avoid log(0)

        # KL divergence: sum_a pi(a) * log(pi(a) / beta(a))
        kl_per_state = (pi_probs * (torch.log(pi_probs + 1e-8) - torch.log(beta_probs))).sum(dim=-1)
        mean_kl = kl_per_state.mean().item()

        # Fraction of states where argmax differs
        pi_actions = pi_probs.argmax(dim=-1)
        beta_actions = beta_probs.argmax(dim=-1)
        action_disagreement = (pi_actions != beta_actions).float().mean().item()

    print(f"Mean KL divergence: {mean_kl:.4f}")
    print(f"Action disagreement rate: {action_disagreement:.2%}")

    return mean_kl, action_disagreement


def detect_ood_actions(q_net, dataset, threshold=2.0):
    """
    Detect when learned policy selects potentially OOD actions.

    Heuristic: if argmax Q-value is much higher than dataset actions' Q-values,
    the argmax action is likely OOD and overestimated.
    """
    states, actions, _, _, _ = dataset.sample(500)
    states_tensor = torch.FloatTensor(states)
    actions_tensor = torch.LongTensor(actions)

    with torch.no_grad():
        q_values = q_net(states_tensor)
        q_max = q_values.max(dim=1)[0]
        q_dataset = q_values.gather(1, actions_tensor.unsqueeze(1)).squeeze(1)

        # Gap between max Q and dataset action Q
        gap = q_max - q_dataset
        suspicious = (gap > threshold).sum().item()

    print(f"States with suspicious Q-gap > {threshold}: {suspicious}/{len(states)}")
    print(f"Mean Q-gap: {gap.mean().item():.2f}")

    return gap

The Deadly Cycle

Here’s the full deadly cycle in naive offline Q-learning:

  1. Initialize: Q-values start at random values
  2. Train: Q-learning fits in-distribution (s, a) pairs reasonably well
  3. But: Q-values for OOD actions are never corrected; they remain at (or drift to) arbitrary values
  4. Some OOD actions get overestimated (by chance or extrapolation)
  5. Max operator: Policy selects the overestimated OOD action
  6. Bootstrap: The overestimated Q propagates to other states
  7. Cascade: Error spreads, more Q-values become unreliable
  8. Result: Policy confidently takes terrible actions

The more training, the worse it can get! Unlike in online RL, where more training usually helps, in naive offline RL more training can compound errors.
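
The bootstrap and cascade steps can be watched in a hypothetical tabular chain. The behavior policy always stepped right (action 0, reward 0); action 1 never appears in the data, so its Q-values keep their arbitrary initial values, and one of them (at the last state) is badly overestimated at 5.0. The last state has no logged outgoing transitions, so nothing there is ever trained. Each synchronous sweep of naive backups stands in for more training.

```python
import numpy as np

N, GAMMA = 10, 0.9
rng = np.random.default_rng(0)

# Q[s, 0]: dataset action ("step right"); Q[s, 1]: never-taken OOD action.
Q = np.zeros((N, 2))
Q[:, 1] = rng.uniform(-1.0, 1.0, size=N)  # arbitrary init, never corrected
Q[N - 1, 1] = 5.0                         # one badly overestimated guess

# Logged transitions (s, a=0, r=0.0, s'=s+1); none start from state N-1.
transitions = [(s, 0, 0.0, s + 1) for s in range(N - 1)]

def backup_sweep(Q):
    """One synchronous sweep of naive Q-learning backups with step size 1."""
    new_Q = Q.copy()
    for s, a, r, s_next in transitions:
        new_Q[s, a] = r + GAMMA * Q[s_next].max()  # max includes OOD action 1
    return new_Q

inflated_counts = []
for sweep in range(1, N):
    Q = backup_sweep(Q)
    # Every logged reward is 0 and the other inits are below 1.0, so a
    # dataset-action value above 1.0 can only come from the 5.0 guess.
    inflated_counts.append(int((Q[:, 0] > 1.0).sum()))

print(inflated_counts)             # [1, 2, 3, 4, 5, 6, 7, 8, 9]
print(f"Q[0, 0] = {Q[0, 0]:.3f}")  # 1.937 (= 5 * 0.9**9), though all rewards are 0
```

In this tabular toy the damage is bounded: the inflated value shrinks by a factor of $\gamma$ per step and stops spreading once it reaches the start of the chain. With a neural Q-function, generalization can also raise OOD values at other states as training proceeds, which this sketch does not model.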

Summary

Distribution shift is the fundamental challenge of offline RL:

  • Learned policies want to take actions not in the dataset
  • Q-values for those actions are unreliable extrapolations
  • Q-learning’s max operator systematically selects overestimated OOD actions
  • Errors cascade through bootstrapping
  • The result is policies whose estimated values look excellent but that fail catastrophically when deployed

The solution? We need algorithms that explicitly handle this distribution shift—either by staying close to the data or by being conservative about OOD actions. That’s what we’ll explore in the next section on Conservative Methods.