Advanced Topics • Part 3 of 3

Conservative Methods

Staying close to the data

If the core problem is that Q-functions overestimate the value of out-of-distribution (OOD) actions, the solution is to be conservative: actively push down Q-values for OOD actions and trust only what we’ve actually seen in the data.

The Conservative Approach

📖Conservative Offline RL

Algorithms that explicitly penalize or avoid out-of-distribution actions, ensuring the learned policy stays close to the behavior demonstrated in the dataset. The key principle: it’s better to be pessimistic about the unknown than to be confidently wrong.

Think of it like restaurant reviews. If you’ve never tried a dish, you shouldn’t assume it’s 5 stars. Conservative methods say: “If I haven’t seen it in the data, I’ll assume it’s worse than what I have seen.”

This pessimism keeps the learned policy close to behaviors actually supported by the data: behaviors whose outcomes we can estimate from real experience rather than by extrapolation.

Behavior Cloning: The Simplest Baseline

Before diving into sophisticated methods, let’s start with the simplest approach: just imitate the data.

📖Behavior Cloning

Supervised learning on the offline dataset: train a policy to predict the action taken by the behavior policy given each state. No RL at all—just imitation.

Mathematical Details

Behavior cloning minimizes:

$$L_{BC}(\theta) = \mathbb{E}_{(s, a) \sim D}\left[ -\log \pi_\theta(a \mid s) \right]$$

This is just the cross-entropy loss for action prediction. The resulting policy imitates the behavior policy $\pi_\beta$ that collected the data.
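
To make the loss concrete, here is a minimal sketch that computes $L_{BC}$ by hand for a small batch of logits, in plain Python rather than PyTorch so each step is visible. The function name `bc_loss` is illustrative; for the same inputs it should agree with `F.cross_entropy` using its default mean reduction.

```python
import math

def bc_loss(logits_batch, actions):
    """Mean negative log-likelihood of dataset actions under a softmax policy.

    logits_batch: one list of action logits per state.
    actions: index of the action the behavior policy took in each state.
    """
    total = 0.0
    for logits, a in zip(logits_batch, actions):
        m = max(logits)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += -(logits[a] - log_z)  # -log pi_theta(a|s)
    return total / len(actions)

# Uniform logits over 2 actions: every sample costs -log(1/2) = log 2.
print(bc_loss([[0.0, 0.0], [0.0, 0.0]], [0, 1]))
# A confident, correct prediction costs less than a confident, wrong one.
print(bc_loss([[5.0, 0.0]], [0]), bc_loss([[0.0, 5.0]], [0]))
```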

</>Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

class BCPolicy(nn.Module):
    """Behavior cloning policy network."""

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, state):
        return self.net(state)

    def get_action(self, state, deterministic=True):
        with torch.no_grad():
            logits = self.forward(torch.FloatTensor(state).unsqueeze(0))
            if deterministic:
                return logits.argmax(dim=-1).item()
            else:
                probs = F.softmax(logits, dim=-1)
                return torch.multinomial(probs, 1).item()


def train_behavior_cloning(dataset, policy, optimizer, epochs=100, batch_size=256):
    """Train behavior cloning policy."""
    policy.train()

    for epoch in range(epochs):
        total_loss = 0
        n_batches = 0

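        # Draw random minibatches via dataset.sample; about one dataset's worth of samples per epoch
        # (max(1, ...) avoids dividing by zero below when the dataset is smaller than one batch)
        n_iters = max(1, len(dataset.transitions) // batch_size)
        for _ in range(n_iters):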
            states, actions, _, _, _ = dataset.sample(batch_size)

            states = torch.FloatTensor(states)
            actions = torch.LongTensor(actions)

            # Cross-entropy loss
            logits = policy(states)
            loss = F.cross_entropy(logits, actions)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {total_loss / n_batches:.4f}")

    return policy

Why behavior cloning works (sometimes): It sidesteps the OOD-action problem. The policy is trained only to reproduce actions that appear in the dataset for similar states, and there is no learned Q-function for it to exploit.

Why behavior cloning fails (often): It can only be as good as the behavior policy. If the data came from mediocre behavior, BC learns mediocre behavior. No improvement is possible.
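
A tabular special case shows why. If the policy has a separate categorical distribution for each state, minimizing $L_{BC}$ yields the empirical action frequencies in each state, so a greedy BC policy simply copies the behavior policy's most common action, good or bad. The sketch below assumes a toy dataset of hypothetical `(state, action)` pairs.

```python
from collections import Counter, defaultdict

def tabular_bc(pairs):
    """Maximum-likelihood tabular BC: per-state empirical action frequencies."""
    counts = defaultdict(Counter)
    for state, action in pairs:
        counts[state][action] += 1
    policy = {}
    for state, c in counts.items():
        n = sum(c.values())
        policy[state] = {a: k / n for a, k in c.items()}
    return policy

# Hypothetical data: in state "s0" the behavior policy usually picks the
# mediocre action "slow" and only sometimes the better action "fast".
data = [("s0", "slow")] * 7 + [("s0", "fast")] * 3
policy = tabular_bc(data)
greedy = max(policy["s0"], key=policy["s0"].get)
print(policy["s0"], greedy)
```

Rewards never enter the BC loss, so nothing tells the learner that "fast" is the better action.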

This is where RL methods shine—they can potentially improve over the behavior policy by stitching together good parts of different trajectories.
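
A toy example of stitching, as a minimal sketch: tabular Q-iteration run only on dataset transitions, with the bootstrap max taken over actions that appear in the data for the next state (a simple way to stay in-distribution, not a specific published algorithm). It assumes deterministic transitions, and all state and action names are hypothetical. Trajectory 1 goes A → B → a dead end; trajectory 2 goes C → B → the goal. Neither trajectory goes from A to the goal, but the learned Q-values do.

```python
from collections import defaultdict

def offline_q_iteration(transitions, gamma=0.9, n_iters=50):
    """Tabular Q-iteration on a fixed dataset of (s, a, r, s_next, done) tuples."""
    # Actions observed in each state; the bootstrap max only ranges over these.
    seen = defaultdict(set)
    for s, a, _, _, _ in transitions:
        seen[s].add(a)
    q = {}
    for _ in range(n_iters):
        new_q = {}
        for s, a, r, s_next, done in transitions:
            next_v = 0.0 if done else max(
                (q.get((s_next, b), 0.0) for b in seen[s_next]), default=0.0)
            new_q[(s, a)] = r + gamma * next_v
        q = new_q
    return q

dataset = [
    ("A", "to_B", 0.0, "B", False), ("B", "left", 0.0, "dead_end", True),  # trajectory 1
    ("C", "to_B", 0.0, "B", False), ("B", "right", 1.0, "goal", True),     # trajectory 2
]
q = offline_q_iteration(dataset)
best_at_b = max(["left", "right"], key=lambda a: q[("B", a)])
print(q[("A", "to_B")], best_at_b)
```

Behavior cloning at B would pick "left" and "right" equally often (each appears once), while the Q-values route A through B to the goal.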

Conservative Q-Learning (CQL)

Conservative Q-Learning (CQL) is one of the landmark algorithms for offline RL. It adds a penalty that explicitly pushes down Q-values for OOD actions while keeping Q-values for in-distribution actions anchored to the data.

Mathematical Details

CQL adds a regularization term to the standard Q-learning objective:

$$L_{CQL} = L_{TD} + \alpha \cdot \mathbb{E}_{s \sim D}\left[ \log \sum_a \exp(Q(s,a)) - \mathbb{E}_{a \sim D(a|s)}[Q(s,a)] \right]$$

Let’s break down the CQL penalty:

  • $\log \sum_a \exp(Q(s,a))$: a soft maximum over all Q-values. Minimizing it pushes down the Q-values of all actions, hardest on the actions whose Q-values are currently highest.

  • $\mathbb{E}_{a \sim D(a|s)}[Q(s,a)]$: because this term is subtracted, minimizing the loss pulls up the Q-values of actions seen in the dataset.

Net effect: Q-values for OOD actions get pushed down, while Q-values for dataset actions stay anchored by the TD loss on real transitions. The policy will therefore prefer dataset actions.
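
The push-down/pull-up behavior can be checked directly. For a single state, the gradient of the penalty $\log \sum_a \exp(Q(s,a)) - Q(s, a_{\text{data}})$ with respect to $Q(s,a)$ is $\text{softmax}(Q)(a) - \mathbb{1}[a = a_{\text{data}}]$. The sketch below, in plain Python with made-up Q-values, computes both the penalty and this gradient.

```python
import math

def cql_penalty_and_grad(q_values, data_action):
    """CQL penalty for one state, plus its gradient w.r.t. each Q-value."""
    m = max(q_values)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(q - m) for q in q_values))
    penalty = lse - q_values[data_action]
    # d(logsumexp)/dQ(a) = softmax(Q)(a); d(-Q(a_data))/dQ(a) = -1 if a == a_data
    softmax = [math.exp(q - lse) for q in q_values]
    grad = [p - (1.0 if a == data_action else 0.0) for a, p in enumerate(softmax)]
    return penalty, grad

# Action 0 is in the dataset; action 1 is an OOD action with an inflated Q-value.
penalty, grad = cql_penalty_and_grad([1.0, 3.0, 0.5], data_action=0)
print(penalty, grad)
```

Gradient descent moves each Q-value against its gradient: the dataset action (negative gradient) goes up, every other action goes down, and the highest-valued OOD action goes down fastest. The penalty is also never negative, because log-sum-exp is at least as large as the largest Q-value.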

The CQL penalty creates a “conservative cushion”:

  • Dataset actions: Q-values trained normally on real transitions
  • OOD actions: Q-values pushed down by the penalty

When the policy takes the argmax, it tends to select dataset actions (whose Q-values are grounded in real transitions) rather than OOD actions (whose Q-values have been pushed down).

This doesn’t mean the policy exactly copies the behavior policy—it can still select the best dataset actions. It just won’t hallucinate that unseen actions are better.

</>Implementation
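class QNetwork(nn.Module):
    """Q-network: maps a state to one Q-value per action.

    Assumed to match the Q-network used earlier in this series; defined here
    so that CQL and BCQ below run standalone.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, state):
        return self.net(state)


class CQL: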
    """
    Conservative Q-Learning for offline RL.

    Adds a penalty that pushes down Q-values for OOD actions.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256,
                 lr=3e-4, gamma=0.99, alpha=1.0, tau=0.005):
        self.n_actions = n_actions
        self.gamma = gamma
        self.alpha = alpha  # CQL regularization strength
        self.tau = tau

        # Q-networks
        self.q_net1 = QNetwork(state_dim, n_actions, hidden_dim)
        self.q_net2 = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_q1 = QNetwork(state_dim, n_actions, hidden_dim)
        self.target_q2 = QNetwork(state_dim, n_actions, hidden_dim)

        # Copy to targets
        self.target_q1.load_state_dict(self.q_net1.state_dict())
        self.target_q2.load_state_dict(self.q_net2.state_dict())

        self.optimizer = optim.Adam(
            list(self.q_net1.parameters()) + list(self.q_net2.parameters()),
            lr=lr
        )

    def compute_cql_penalty(self, q_values, actions):
        """
        Compute the CQL conservative penalty.

        Args:
            q_values: Q-values for all actions [batch, n_actions]
            actions: Actions from dataset [batch]

        Returns:
            Conservative penalty value
        """
        # Log-sum-exp over all actions (pushes down all Q-values)
        logsumexp = torch.logsumexp(q_values, dim=1, keepdim=True)

        # Q-values for dataset actions (pulls up dataset Q-values)
        q_dataset = q_values.gather(1, actions.unsqueeze(1))

        # Penalty: logsumexp - dataset actions
        penalty = (logsumexp - q_dataset).mean()
        return penalty

    def update(self, batch):
        """CQL update step."""
        states, actions, rewards, next_states, dones = batch

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Current Q-values
        q1 = self.q_net1(states)
        q2 = self.q_net2(states)
        q1_taken = q1.gather(1, actions.unsqueeze(1)).squeeze(1)
        q2_taken = q2.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target Q-values: element-wise min of the two target networks (clipped double-Q), then max over actions
        with torch.no_grad():
            next_q1 = self.target_q1(next_states)
            next_q2 = self.target_q2(next_states)
            next_q = torch.min(next_q1, next_q2)
            max_next_q = next_q.max(dim=1)[0]
            targets = rewards + self.gamma * (1 - dones) * max_next_q

        # TD loss
        td_loss1 = F.mse_loss(q1_taken, targets)
        td_loss2 = F.mse_loss(q2_taken, targets)
        td_loss = td_loss1 + td_loss2

        # CQL penalty
        cql_penalty1 = self.compute_cql_penalty(q1, actions)
        cql_penalty2 = self.compute_cql_penalty(q2, actions)
        cql_penalty = cql_penalty1 + cql_penalty2

        # Total loss
        loss = td_loss + self.alpha * cql_penalty

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Soft update targets
        self._soft_update()

        return {
            'loss': loss.item(),
            'td_loss': td_loss.item(),
            'cql_penalty': cql_penalty.item()
        }

    def _soft_update(self):
        """Soft update target networks."""
        for param, target_param in zip(self.q_net1.parameters(), self.target_q1.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
        for param, target_param in zip(self.q_net2.parameters(), self.target_q2.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

    def get_action(self, state):
        """Select action using Q-network (deterministic)."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q1 = self.q_net1(state_tensor)
            q2 = self.q_net2(state_tensor)
            q = torch.min(q1, q2)
            return q.argmax(dim=-1).item()

Balancing Conservatism and Optimality

The $\alpha$ parameter in CQL controls the conservatism-optimality tradeoff:

  • High $\alpha$: Very conservative. The policy stays very close to the behavior policy. Safe, but may not improve on it.
  • Low $\alpha$: Less conservative. The policy can deviate more from the data. It may improve, but risks selecting OOD actions.

The right $\alpha$ depends on:

  • Dataset coverage (better coverage = can be less conservative)
  • Behavior policy quality (poor behavior = need more freedom to improve)
  • Safety requirements (safety-critical = more conservative)
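
A one-state toy problem makes the role of $\alpha$ visible. The sketch below (plain Python, hypothetical numbers) treats the two Q-values as free parameters and runs gradient descent on a squared TD error for the dataset action plus $\alpha$ times the CQL penalty. The OOD action starts overestimated and never receives a TD update, so only the penalty can correct it.

```python
import math

def train_toy_cql(alpha, q_init=(0.0, 2.0), data_action=0, target=1.0,
                  lr=0.1, steps=500):
    """Gradient descent on a one-state CQL loss with free Q-values.

    Action 0 is in the dataset with TD target 1.0; action 1 is OOD and
    starts overestimated at 2.0. No TD loss ever touches action 1.
    """
    q = list(q_init)
    for _ in range(steps):
        m = max(q)
        lse = m + math.log(sum(math.exp(v - m) for v in q))
        softmax = [math.exp(v - lse) for v in q]
        grad = [alpha * p for p in softmax]                   # d(alpha * logsumexp)/dQ
        grad[data_action] -= alpha                            # d(-alpha * Q(a_data))/dQ
        grad[data_action] += 2.0 * (q[data_action] - target)  # squared TD error
        q = [v - lr * g for v, g in zip(q, grad)]
    return q

for alpha in (0.0, 1.0):
    q = train_toy_cql(alpha)
    greedy = max(range(len(q)), key=lambda a: q[a])
    print(f"alpha={alpha}: Q={[round(v, 2) for v in q]}, greedy action={greedy}")
```

With $\alpha = 0$ the overestimated OOD action is never corrected and the greedy policy picks it; with $\alpha = 1$ its Q-value is driven below the dataset action's. A larger $\alpha$ pushes harder, which is what makes the policy more conservative.
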
Mathematical Details

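Under certain assumptions (including a sufficiently large $\alpha$), CQL's learned Q-function is pessimistic. For the objective above, which includes the dataset term, the guarantee is a lower bound on the expected value of the policy being evaluated:

$$\mathbb{E}_{a \sim \pi(a|s)}\left[Q_{CQL}(s, a)\right] \leq V^\pi(s)$$

A variant that only pushes Q-values down, without the dataset term, gives the stronger pointwise bound $Q_{CQL}(s, a) \leq Q^\pi(s, a)$.

This means CQL’s policy is evaluated pessimistically. If the pessimistic evaluation says the policy is good, it’s at least that good in reality, as long as the assumptions hold.

This is the theoretical foundation for safer deployment: we underestimate performance during training, so real performance should match or exceed the estimate.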

Batch-Constrained Q-Learning (BCQ)

Another approach: explicitly restrict the policy to only consider actions similar to those in the dataset.

Mathematical Details

BCQ learns a generative model $G(s)$ of the behavior policy, then only considers actions that $G$ would produce:

$$\pi(s) = \arg\max_{a \,:\, G(s) \text{ would output } a} Q(s, a)$$

In practice, BCQ generates candidate actions from $G$ and picks the one with the highest Q-value (in continuous action spaces, it also applies a small learned perturbation to each candidate). This keeps the policy from selecting actions far from the data.
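
For discrete actions, the rule described for discrete-action BCQ compares each action's behavior probability to that of the most likely action, keeps actions whose ratio exceeds a threshold $\tau$, and takes the argmax of $Q$ among them. The minimal sketch below uses that relative rule; the implementation that follows uses a slightly different filter (an absolute probability threshold plus a near-maximum fallback). The names `bcq_select` and `tau` are illustrative.

```python
def bcq_select(q_values, behavior_probs, tau=0.3):
    """Discrete BCQ-style action selection.

    An action is allowed if its behavior probability is more than tau times
    the probability of the behavior policy's most likely action.
    """
    max_p = max(behavior_probs)
    allowed = [a for a, p in enumerate(behavior_probs) if p / max_p > tau]
    # The most likely action has ratio 1.0, so `allowed` is non-empty for tau < 1.
    return max(allowed, key=lambda a: q_values[a])

q = [1.0, 5.0, 2.0]        # action 1 looks best...
probs = [0.6, 0.05, 0.35]  # ...but the behavior policy almost never takes it
print(bcq_select(q, probs, tau=0.3), bcq_select(q, probs, tau=0.0))
```

Setting `tau` to 0 recovers plain greedy Q-learning action selection; raising it toward 1 approaches behavior cloning's most likely action.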

</>Implementation
class BCQ:
    """
    Batch-Constrained Q-Learning.

    Restricts policy to actions supported by the behavior policy.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=256, threshold=0.3):
        self.threshold = threshold  # Minimum behavior-policy probability for an action to be allowed

        # Q-network
        self.q_net = QNetwork(state_dim, n_actions, hidden_dim)

        # Behavior cloning model (generative model of behavior policy)
        self.bc_model = BCPolicy(state_dim, n_actions, hidden_dim)

        self.q_optimizer = optim.Adam(self.q_net.parameters(), lr=3e-4)
        self.bc_optimizer = optim.Adam(self.bc_model.parameters(), lr=3e-4)

    def train_bc(self, dataset, epochs=50, batch_size=256):
        """Pre-train the behavior cloning model (one gradient step per iteration on a large random batch)."""
        for epoch in range(epochs):
            states, actions, _, _, _ = dataset.sample(batch_size * 10)
            states = torch.FloatTensor(states)
            actions = torch.LongTensor(actions)

            logits = self.bc_model(states)
            loss = F.cross_entropy(logits, actions)

            self.bc_optimizer.zero_grad()
            loss.backward()
            self.bc_optimizer.step()

    def get_action_mask(self, states):
        """
        Get mask of allowed actions based on behavior policy.

        Actions with probability below threshold are masked out.
        """
        with torch.no_grad():
            logits = self.bc_model(states)
            probs = F.softmax(logits, dim=-1)
            mask = (probs >= self.threshold).float()

            # Also allow actions within 90% of the top probability, so at least one action is always allowed
            max_probs = probs.max(dim=-1, keepdim=True)[0]
            mask = mask + (probs >= max_probs * 0.9).float()
            mask = (mask > 0).float()

        return mask

    def get_action(self, state):
        """Select action: best Q among allowed actions."""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            q_values = self.q_net(state_tensor)
            mask = self.get_action_mask(state_tensor)

            # Mask out disallowed actions with very negative value
            masked_q = q_values - 1e8 * (1 - mask)
            return masked_q.argmax(dim=-1).item()

    def update(self, batch, gamma=0.99):
        """BCQ update step."""
        states, actions, rewards, next_states, dones = batch

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Current Q-values
        q_values = self.q_net(states)
        q_taken = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Target: max over ALLOWED actions only
        with torch.no_grad():
            next_q = self.q_net(next_states)
            next_mask = self.get_action_mask(next_states)
            masked_next_q = next_q - 1e8 * (1 - next_mask)
            max_next_q = masked_next_q.max(dim=1)[0]
            targets = rewards + gamma * (1 - dones) * max_next_q

        loss = F.mse_loss(q_taken, targets)

        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()

        return loss.item()

Decision Transformer: RL as Sequence Modeling

A recent paradigm shift: treat offline RL as sequence modeling. Decision Transformer uses a transformer to predict actions given past states, actions, and desired future returns.

Instead of learning Q-values, Decision Transformer learns: “What action would lead to return R from state S?”

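At test time, you condition on a high desired return, and the model outputs actions that, in the data, tended to precede high returns. There are no explicit Q-values and no bootstrapping, so there is no learned value function for OOD actions to exploit. A remaining risk is conditioning on returns far above anything in the dataset, which is itself out of distribution.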

Mathematical Details

Decision Transformer models trajectories as sequences:

$$\tau = (R_1, s_1, a_1, R_2, s_2, a_2, \dots, R_T, s_T, a_T)$$

where $R_t = \sum_{t'=t}^{T} r_{t'}$ is the return-to-go: the sum of rewards from step $t$ to the end of the episode.
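
Returns-to-go for a whole trajectory can be computed in one backward pass, since $R_t = r_t + R_{t+1}$. A minimal sketch, undiscounted to match the definition above:

```python
def returns_to_go(rewards):
    """R_t = sum of rewards from step t to the end (undiscounted)."""
    rtg = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running += rewards[t]  # R_t = r_t + R_{t+1}
        rtg[t] = running
    return rtg

print(returns_to_go([1.0, 0.0, 2.0]))
```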

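The model is trained to predict $a_t$ given $(R_1, s_1, a_1, \dots, R_t, s_t)$ (in practice, only a fixed-length window of the most recent timesteps). At test time, set $R_1$ to a high target return; after each step, subtract the observed reward, $R_{t+1} = R_t - r_t$, so the model is always conditioned on the return still to be earned.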
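
The test-time bookkeeping can be sketched as follows. This shows only how the return-to-go and context window are maintained; `predict_action` and `env_step` are hypothetical stand-ins for a trained Decision Transformer and an environment, and `context_len` is an assumed window size.

```python
def dt_rollout(predict_action, env_step, first_state, target_return,
               context_len=20, max_steps=1000):
    """Return-conditioned rollout bookkeeping for a Decision Transformer.

    predict_action(context) -> action, where context is a list of
    (return_to_go, state, action) triples; the last triple's action is None.
    env_step(state, action) -> (next_state, reward, done).
    """
    history = []
    rtg, state = target_return, first_state
    for _ in range(max_steps):
        history.append((rtg, state, None))
        action = predict_action(history[-context_len:])  # only the recent window
        history[-1] = (rtg, state, action)
        state, reward, done = env_step(state, action)
        rtg -= reward  # R_{t+1} = R_t - r_t
        if done:
            break
    return history

# Toy stand-ins: always take action 0; reward 1 per step; episode ends after 3 steps.
history = dt_rollout(lambda ctx: 0,
                     lambda s, a: (s + 1, 1.0, s + 1 >= 3),
                     first_state=0, target_return=10.0)
print([rtg for rtg, _, _ in history])
```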

ℹ️Preview: Connection to RLHF

Decision Transformer foreshadows how language models are trained with RL. In RLHF, we also use offline data (human preferences) and sequence models (transformers). The next chapter on RLHF will build on these ideas.

Comparing Conservative Methods

| Method | Approach | Pros | Cons |
|---|---|---|---|
| Behavior Cloning | Imitate data | Simple, safe | Can’t improve over data |
| CQL | Penalize OOD Q-values | Principled, flexible | Hyperparameter sensitive |
| BCQ | Restrict action space | Intuitive, effective | Requires good BC model |
| Decision Transformer | Sequence modeling | Simple, scalable | Needs trajectory data |

Summary

Conservative methods are central to making offline RL work:

  • Behavior cloning imitates the data but can’t improve
  • CQL explicitly penalizes OOD actions in the Q-function
  • BCQ restricts the policy to dataset-supported actions
  • Decision Transformer reframes the problem as sequence modeling

The common theme: stay close to the data. In offline RL, overconfidence in unseen actions is the enemy. Conservative methods embrace pessimism—and that pessimism enables safe, practical offline learning.

These ideas carry over to training AI systems from human data, including the RLHF methods used to train modern language models.