Conservative Methods
If the problem is overestimating out-of-distribution actions, the solution is to be conservative: actively push down Q-values for OOD actions, and trust only what we’ve actually seen in the data.
The Conservative Approach
Algorithms that explicitly penalize or avoid out-of-distribution actions, ensuring the learned policy stays close to the behavior demonstrated in the dataset. The key principle: it’s better to be pessimistic about the unknown than to be confidently wrong.
Think of it like restaurant reviews. If you’ve never tried a dish, you shouldn’t assume it’s 5 stars. Conservative methods say: “If I haven’t seen it in the data, I’ll assume it’s worse than what I have seen.”
This pessimism keeps the learned policy close to behaviors actually supported by the data—behaviors we know work reasonably well.
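To make the principle concrete, here is a toy numeric sketch. The Q-values, visit counts, and count-based penalty below are purely illustrative (not taken from any specific algorithm): an action seen only once gets a large uncertainty discount, so a naive argmax and a pessimistic argmax can disagree.

```python
import numpy as np

# Illustrative only: estimated Q-values and dataset visit counts for 3 actions.
q_est = np.array([2.0, 1.8, 3.5])   # action 2 looks best, but...
counts = np.array([100, 80, 1])     # ...it was almost never seen in the data

# Discount each estimate by an uncertainty penalty that shrinks with more data.
pessimistic_q = q_est - 2.0 / np.sqrt(counts)

print(q_est.argmax())          # naive choice: the rarely-seen action
print(pessimistic_q.argmax())  # pessimistic choice: a well-supported action
```

The naive argmax picks the rarely-seen action 2; after the pessimism discount, the well-supported action 0 wins.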
Behavior Cloning: The Simplest Baseline
Before diving into sophisticated methods, let’s start with the simplest approach: just imitate the data.
Supervised learning on the offline dataset: train a policy to predict the action taken by the behavior policy given each state. No RL at all—just imitation.
Behavior cloning minimizes:

$$\mathcal{L}_{\text{BC}}(\theta) = -\mathbb{E}_{(s, a) \sim \mathcal{D}}\left[\log \pi_\theta(a \mid s)\right]$$
This is just cross-entropy loss for action prediction. The resulting policy imitates the behavior policy that collected the data.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np


class BCPolicy(nn.Module):
    """Behavior cloning policy network."""

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, state):
        return self.net(state)

    def get_action(self, state, deterministic=True):
        with torch.no_grad():
            logits = self.forward(torch.FloatTensor(state).unsqueeze(0))
            if deterministic:
                return logits.argmax(dim=-1).item()
            else:
                probs = F.softmax(logits, dim=-1)
                return torch.multinomial(probs, 1).item()


def train_behavior_cloning(dataset, policy, optimizer, epochs=100, batch_size=256):
    """Train behavior cloning policy."""
    policy.train()
    for epoch in range(epochs):
        total_loss = 0
        n_batches = 0
        # Shuffle and iterate through the dataset
        for _ in range(len(dataset.transitions) // batch_size):
            states, actions, _, _, _ = dataset.sample(batch_size)
            states = torch.FloatTensor(states)
            actions = torch.LongTensor(actions)
            # Cross-entropy loss
            logits = policy(states)
            loss = F.cross_entropy(logits, actions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / n_batches:.4f}")
    return policy
```

Why behavior cloning works (sometimes): It stays perfectly in-distribution. The policy only outputs actions that appear in the dataset for similar states.
Why behavior cloning fails (often): It can only be as good as the behavior policy. If the data came from mediocre behavior, BC learns mediocre behavior. No improvement is possible.
This is where RL methods shine—they can potentially improve over the behavior policy by stitching together good parts of different trajectories.
Conservative Q-Learning (CQL)
Conservative Q-Learning (CQL) is a landmark algorithm for offline RL. It adds a penalty that explicitly pushes down Q-values for OOD actions while keeping Q-values accurate for in-distribution actions.
CQL adds a regularization term to the standard Q-learning objective:

$$\min_Q \; \alpha \, \mathbb{E}_{s \sim \mathcal{D}}\left[\log \sum_{a} \exp Q(s, a) - \mathbb{E}_{a \sim \mathcal{D}}\big[Q(s, a)\big]\right] + \mathcal{L}_{\text{TD}}(Q)$$

where $\mathcal{L}_{\text{TD}}$ is the usual temporal-difference error. Let’s break down the CQL penalty:
- $\log \sum_a \exp Q(s, a)$: This log-sum-exp acts as a soft maximum over all Q-values. Minimizing it pushes down the Q-values of all actions.
- $-\mathbb{E}_{a \sim \mathcal{D}}[Q(s, a)]$: This pulls up the Q-values of actions seen in the dataset.
Net effect: Q-values for OOD actions get pushed down; Q-values for dataset actions stay accurate. The policy will prefer dataset actions.
The CQL penalty creates a “conservative cushion”:
- Dataset actions: Q-values trained normally on real transitions
- OOD actions: Q-values pushed down by the penalty
When the policy takes argmax, it selects from dataset actions (where Q-values are accurate) rather than OOD actions (where Q-values are artificially lowered).
This doesn’t mean the policy exactly copies the behavior policy—it can still select the best dataset actions. It just won’t hallucinate that unseen actions are better.
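Before the full implementation, a tiny gradient check makes the cushion visible. The numbers are a toy setup of our own: one state with three actions, where only action 0 appears in the dataset.

```python
import torch

# One state, three actions; only action 0 is in the "dataset".
q = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True)
dataset_action = torch.tensor([[0]])

# CQL penalty: soft maximum over all actions minus the dataset action's Q-value.
logsumexp = torch.logsumexp(q, dim=1, keepdim=True)
q_data = q.gather(1, dataset_action)
penalty = (logsumexp - q_data).mean()
penalty.backward()

# Gradient descent moves against the gradient: the dataset action's Q-value
# rises (negative gradient) while the OOD actions' Q-values fall (positive).
print(q.grad)
```

Starting from equal Q-values, a single penalty gradient already separates the dataset action from the OOD ones.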
```python
# QNetwork maps a state to one Q-value per action -- the same MLP
# architecture as BCPolicy, so we simply reuse it here.
QNetwork = BCPolicy


class CQL:
    """
    Conservative Q-Learning for offline RL.
    Adds a penalty that pushes down Q-values for OOD actions.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256,
                 lr=3e-4, gamma=0.99, alpha=1.0, tau=0.005):
        self.n_actions = n_actions
        self.gamma = gamma
        self.alpha = alpha  # CQL regularization strength
        self.tau = tau
        # Q-networks
        self.q_net1 = QNetwork(state_dim, n_actions, hidden_dim)
        self.q_net2 = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_q1 = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_q2 = QNetwork(state_dim, n_actions, hidden_dim)
        # Copy to targets
        self.target_q1.load_state_dict(self.q_net1.state_dict())
        self.target_q2.load_state_dict(self.q_net2.state_dict())
        self.optimizer = optim.Adam(
            list(self.q_net1.parameters()) + list(self.q_net2.parameters()),
            lr=lr
        )

    def compute_cql_penalty(self, q_values, actions):
        """
        Compute the CQL conservative penalty.

        Args:
            q_values: Q-values for all actions [batch, n_actions]
            actions: Actions from dataset [batch]

        Returns:
            Conservative penalty value
        """
        # Log-sum-exp over all actions (pushes down all Q-values)
        logsumexp = torch.logsumexp(q_values, dim=1, keepdim=True)
        # Q-values for dataset actions (pulls up dataset Q-values)
        q_dataset = q_values.gather(1, actions.unsqueeze(1))
        # Penalty: logsumexp minus dataset Q-values
        penalty = (logsumexp - q_dataset).mean()
        return penalty

    def update(self, batch):
        """CQL update step."""
        states, actions, rewards, next_states, dones = batch
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)
        # Current Q-values
        q1 = self.q_net1(states)
        q2 = self.q_net2(states)
        q1_taken = q1.gather(1, actions.unsqueeze(1)).squeeze(1)
        q2_taken = q2.gather(1, actions.unsqueeze(1)).squeeze(1)
        # Target Q-values (min of two targets, as in double Q-learning)
        with torch.no_grad():
            next_q1 = self.target_q1(next_states)
            next_q2 = self.target_q2(next_states)
            next_q = torch.min(next_q1, next_q2)
            max_next_q = next_q.max(dim=1)[0]
            targets = rewards + self.gamma * (1 - dones) * max_next_q
        # TD loss
        td_loss1 = F.mse_loss(q1_taken, targets)
        td_loss2 = F.mse_loss(q2_taken, targets)
        td_loss = td_loss1 + td_loss2
        # CQL penalty
        cql_penalty1 = self.compute_cql_penalty(q1, actions)
        cql_penalty2 = self.compute_cql_penalty(q2, actions)
        cql_penalty = cql_penalty1 + cql_penalty2
        # Total loss
        loss = td_loss + self.alpha * cql_penalty
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # Soft update targets
        self._soft_update()
        return {
            'loss': loss.item(),
            'td_loss': td_loss.item(),
            'cql_penalty': cql_penalty.item()
        }

    def _soft_update(self):
        """Soft update target networks."""
        for param, target_param in zip(self.q_net1.parameters(), self.target_q1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.q_net2.parameters(), self.target_q2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def get_action(self, state):
        """Select action using the Q-networks (deterministic)."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q1 = self.q_net1(state_tensor)
            q2 = self.q_net2(state_tensor)
            q = torch.min(q1, q2)
            return q.argmax(dim=-1).item()
```

Balancing Conservatism and Optimality
The parameter $\alpha$ in CQL controls the conservatism-optimality tradeoff:
- High $\alpha$: Very conservative. The policy stays very close to the behavior policy. Safe, but may not improve.
- Low $\alpha$: Less conservative. The policy can deviate more from the data. May improve, but risks selecting OOD actions.
The right $\alpha$ depends on:
- Dataset coverage (better coverage = can be less conservative)
- Behavior policy quality (poor behavior = need more freedom to improve)
- Safety requirements (safety-critical = more conservative)
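A quick toy experiment makes the tradeoff tangible. The setup is our own invention: one state, three actions, a dataset containing only action 0 with a fixed TD target of 1.0. Training with increasing $\alpha$ widens the gap between the dataset action's Q-value and the OOD Q-values.

```python
import torch

def train_toy_cql(alpha, steps=500, lr=0.1):
    """Minimize TD loss + alpha * CQL penalty on a single-state toy problem."""
    q = torch.zeros(1, 3, requires_grad=True)   # 3 actions; action 0 is in the data
    opt = torch.optim.SGD([q], lr=lr)
    for _ in range(steps):
        td_loss = (q[0, 0] - 1.0) ** 2          # fit the dataset action's target
        penalty = torch.logsumexp(q, dim=1).mean() - q[0, 0]
        loss = td_loss + alpha * penalty
        opt.zero_grad()
        loss.backward()
        opt.step()
    return q.detach()

for alpha in [0.0, 1.0, 5.0]:
    q = train_toy_cql(alpha)
    gap = (q[0, 0] - q[0, 1:].max()).item()     # dataset-action advantage
    print(f"alpha={alpha}: gap={gap:.2f}")      # the gap grows with alpha
```

With $\alpha = 0$ the gap comes only from the TD target; larger $\alpha$ drives the OOD Q-values far below the dataset action's.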
Under certain assumptions, CQL’s learned Q-function is a lower bound on the true Q-values of the learned policy:

$$\hat{Q}^{\pi}(s, a) \leq Q^{\pi}(s, a)$$
This means CQL’s policy is evaluated pessimistically. If the pessimistic evaluation says the policy is good, it’s actually at least that good (probably better) in reality.
This is the theoretical foundation for safe deployment: we underestimate performance during training, so real performance should exceed expectations.
Batch-Constrained Q-Learning (BCQ)
Another approach: explicitly restrict the policy to only consider actions similar to those in the dataset.
BCQ learns a generative model $G_\omega$ of the behavior policy $\pi_\beta$, then only considers actions that $\pi_\beta$ would plausibly produce:

$$\pi(s) = \arg\max_{a \,:\, \pi_\beta(a \mid s) \geq \tau} Q(s, a)$$

In practice, BCQ generates candidate actions from $G_\omega$ and picks the one with the highest Q-value. This ensures the policy never selects actions far from the data.
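A quick numeric illustration of the filtering rule (the probabilities, Q-values, and 0.3 threshold below are arbitrary toy numbers): an action with the highest Q-value is still rejected if the behavior policy almost never takes it.

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # behavior-policy probabilities
q = torch.tensor([1.0, 2.0, 5.0, 9.0])        # action 3 has the highest Q...
threshold = 0.3

allowed = probs >= threshold                  # ...but is unsupported by the data
masked_q = q.masked_fill(~allowed, float('-inf'))
print(masked_q.argmax().item())               # selects action 1, not action 3
```

The possibly overestimated Q-value of 9.0 never gets a vote: only actions the data supports compete in the argmax.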
```python
class BCQ:
    """
    Batch-Constrained Q-Learning.
    Restricts the policy to actions supported by the behavior policy.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256, threshold=0.3):
        self.threshold = threshold  # Minimum behavior-policy probability
        # Q-network
        self.q_net = QNetwork(state_dim, n_actions, hidden_dim)
        # Behavior cloning model (generative model of the behavior policy)
        self.bc_model = BCPolicy(state_dim, n_actions, hidden_dim)
        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=3e-4)
        self.bc_optimizer = optim.Adam(self.bc_model.parameters(), lr=3e-4)

    def train_bc(self, dataset, epochs=50, batch_size=256):
        """Pre-train the behavior cloning model."""
        for epoch in range(epochs):
            states, actions, _, _, _ = dataset.sample(batch_size * 10)
            states = torch.FloatTensor(states)
            actions = torch.LongTensor(actions)
            logits = self.bc_model(states)
            loss = F.cross_entropy(logits, actions)
            self.bc_optimizer.zero_grad()
            loss.backward()
            self.bc_optimizer.step()

    def get_action_mask(self, states):
        """
        Get a mask of allowed actions based on the behavior policy.
        Actions with probability below the threshold are masked out.
        """
        with torch.no_grad():
            logits = self.bc_model(states)
            probs = F.softmax(logits, dim=-1)
            mask = (probs >= self.threshold).float()
            # Ensure at least one action is allowed
            max_probs = probs.max(dim=-1, keepdim=True)[0]
            mask = mask + (probs >= max_probs * 0.9).float()
            mask = (mask > 0).float()
            return mask

    def get_action(self, state):
        """Select action: best Q-value among allowed actions."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_net(state_tensor)
            mask = self.get_action_mask(state_tensor)
            # Mask out disallowed actions with a very negative value
            masked_q = q_values - 1e8 * (1 - mask)
            return masked_q.argmax(dim=-1).item()

    def update(self, batch, gamma=0.99):
        """BCQ update step."""
        states, actions, rewards, next_states, dones = batch
        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)
        # Current Q-values
        q_values = self.q_net(states)
        q_taken = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
        # Target: max over ALLOWED actions only
        with torch.no_grad():
            next_q = self.q_net(next_states)
            next_mask = self.get_action_mask(next_states)
            masked_next_q = next_q - 1e8 * (1 - next_mask)
            max_next_q = masked_next_q.max(dim=1)[0]
            targets = rewards + gamma * (1 - dones) * max_next_q
        loss = F.mse_loss(q_taken, targets)
        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()
        return loss.item()
```

Decision Transformer: RL as Sequence Modeling
A recent paradigm shift: treat offline RL as sequence modeling. Decision Transformer uses a transformer to predict actions given past states, actions, and desired future returns.
Instead of learning Q-values, Decision Transformer learns: “What action would lead to return R from state S?”
At test time, you condition on a high desired return, and the model outputs actions that historically led to high returns. No explicit Q-values, no OOD action problem—just sequence prediction.
Decision Transformer models trajectories as sequences:

$$\tau = (\hat{R}_1, s_1, a_1, \hat{R}_2, s_2, a_2, \ldots, \hat{R}_T, s_T, a_T)$$

where $\hat{R}_t = \sum_{t'=t}^{T} r_{t'}$ is the return-to-go (the sum of future rewards from step $t$ onward).

The model is trained to predict $a_t$ given the preceding tokens $(\hat{R}_1, s_1, a_1, \ldots, \hat{R}_t, s_t)$. At test time, set $\hat{R}_1$ to a high target return, and the model outputs actions to achieve it.
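The return-to-go sequence is computed with a single backward pass over the rewards. A minimal helper (the function name is ours, not from the Decision Transformer codebase):

```python
def returns_to_go(rewards):
    """Return-to-go at each step: the sum of rewards from that step onward."""
    rtg, running = [], 0.0
    for r in reversed(rewards):
        running += r
        rtg.append(running)
    return rtg[::-1]

print(returns_to_go([1.0, 0.0, 2.0]))  # → [3.0, 2.0, 2.0]
```

These values are interleaved with states and actions to build the training sequences described above.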
Decision Transformer foreshadows how language models are trained with RL. In RLHF, we also use offline data (human preferences) and sequence models (transformers). The next chapter on RLHF will build on these ideas.
Comparing Conservative Methods
| Method | Approach | Pros | Cons |
|---|---|---|---|
| Behavior Cloning | Imitate data | Simple, safe | Can’t improve over data |
| CQL | Penalize OOD Q-values | Principled, flexible | Hyperparameter sensitive |
| BCQ | Restrict action space | Intuitive, effective | Requires good BC model |
| Decision Transformer | Sequence modeling | Simple, scalable | Needs trajectory data |
Summary
Conservative methods are the key to making offline RL work:
- Behavior cloning imitates the data but can’t improve
- CQL explicitly penalizes OOD actions in the Q-function
- BCQ restricts the policy to dataset-supported actions
- Decision Transformer reframes the problem as sequence modeling
The common theme: stay close to the data. In offline RL, overconfidence in unseen actions is the enemy. Conservative methods embrace pessimism—and that pessimism enables safe, practical offline learning.
These techniques are the foundation for training AI systems from human data, including the RLHF methods used to train modern language models.