Deep Reinforcement Learning • Part 4 of 4

Rainbow: Combining Improvements

The sum is greater than its parts

We’ve explored several DQN improvements in isolation: Double DQN reduces overestimation, Prioritized Experience Replay focuses on important transitions, and Dueling Networks separate state value from action advantage. Each provides meaningful gains over vanilla DQN.

But what happens when we combine them all?

Rainbow answers this question definitively. The 2017 DeepMind paper integrated six orthogonal improvements into a single agent, demonstrating that the combination achieves far more than any individual component. Rainbow became the new state-of-the-art on Atari, crushing both vanilla DQN and each improvement in isolation.

The Six Components

📖Rainbow DQN

Rainbow combines six DQN improvements:

  1. Double DQN: Decouple action selection from value estimation
  2. Prioritized Experience Replay: Sample important transitions more often
  3. Dueling Networks: Separate value and advantage streams
  4. Multi-step Learning: Use n-step returns instead of single-step TD
  5. Distributional RL: Learn the distribution of returns, not just the mean
  6. Noisy Networks: Replace epsilon-greedy with parametric noise for exploration

Think of each improvement as addressing a different weakness of DQN:

| Problem | Solution |
| --- | --- |
| Q-values are systematically too high | Double DQN |
| Wasting time on uninformative samples | Prioritized Replay |
| Not separating “good state” from “good action” | Dueling Networks |
| Slow credit assignment over long episodes | Multi-step Learning |
| Ignoring uncertainty in value estimates | Distributional RL |
| Crude epsilon-greedy exploration | Noisy Networks |

Rainbow doesn’t introduce new ideas. It simply asks: “What if we used all the good ideas at once?”

Component 1: Double DQN (Review)

Standard DQN uses the target network for both selecting and evaluating the best action: $y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$

This creates overestimation: noisy Q-values cause us to consistently pick overestimated actions.

Double DQN fix: Use the online network to select the action and the target network to evaluate it: $y = r + \gamma Q(s', \arg\max_{a'} Q(s', a'; \theta); \theta^-)$

Mathematical Details

With noisy estimates, the expected maximum is at least the maximum of the expectations (Jensen's inequality, since the max is convex): $\mathbb{E}[\max_a Q(s, a)] \geq \max_a \mathbb{E}[Q(s, a)]$

Double DQN breaks the correlation between selection and evaluation, reducing this bias.
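To see the overestimation concretely, here is a minimal simulation sketch. It assumes every action has a true value of zero and that estimates carry independent Gaussian noise; the function name `max_bias_demo` and its parameters are illustrative, not taken from the paper. Taking the max over one noisy set of estimates is biased upward, while selecting with one set and evaluating with an independent set is not.

```python
import random

def max_bias_demo(n_actions=5, noise_std=1.0, n_trials=20000, seed=0):
    """Compare single-estimator and double-estimator value estimates.

    All true Q-values are 0, so any nonzero average is pure estimation bias.
    """
    rng = random.Random(seed)
    single_total, double_total = 0.0, 0.0
    for _ in range(n_trials):
        # Two independent noisy estimates of the same (all-zero) Q-values
        q_a = [rng.gauss(0.0, noise_std) for _ in range(n_actions)]
        q_b = [rng.gauss(0.0, noise_std) for _ in range(n_actions)]

        # Single estimator: select and evaluate with the same noisy values
        single_total += max(q_a)

        # Double estimator: select with q_a, evaluate with q_b
        best = max(range(n_actions), key=lambda a: q_a[a])
        double_total += q_b[best]
    return single_total / n_trials, double_total / n_trials

single_bias, double_bias = max_bias_demo()
print(f"single estimator bias: {single_bias:.3f}")  # clearly positive
print(f"double estimator bias: {double_bias:.3f}")  # close to zero
```

In actual Double DQN the online and target networks are correlated rather than independent, so the bias is reduced, not eliminated.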

Component 2: Prioritized Experience Replay (Review)

Uniform random sampling wastes time replaying transitions we’ve already learned from. Prioritized replay samples proportionally to TD error:

$$P(i) \propto |\delta_i|^\alpha + \epsilon$$

High-error transitions are sampled more often, accelerating learning. Importance sampling weights correct for the bias this introduces:

$$w_i = \left( \frac{1}{N \cdot P(i)} \right)^\beta$$

Mathematical Details

The priority is typically: $p_i = |\delta_i|^\alpha + \epsilon$

where $\delta_i = r + \gamma Q(s', a^*; \theta^-) - Q(s, a; \theta)$ is the TD error.

Rainbow uses $\alpha = 0.5$ and anneals $\beta$ from 0.4 to 1.0 over training.
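The following sketch turns a handful of TD errors into sampling probabilities and importance sampling weights, using the priority formula as written above. The helper names (`per_probabilities`, `importance_weights`) and the sample TD errors are made up for illustration. Dividing the weights by their maximum is a common stabilization step that keeps every weight at or below 1.

```python
def per_probabilities(td_errors, alpha=0.5, eps=1e-6):
    """Sampling probabilities from priorities p_i = |delta_i|^alpha + eps."""
    priorities = [abs(d) ** alpha + eps for d in td_errors]
    total = sum(priorities)
    return [p / total for p in priorities]

def importance_weights(probs, beta):
    """Weights w_i = (1 / (N * P(i)))^beta, normalized by the largest weight."""
    n = len(probs)
    weights = [(1.0 / (n * p)) ** beta for p in probs]
    max_w = max(weights)
    return [w / max_w for w in weights]

td_errors = [0.1, 0.5, 2.0, 4.0]   # hypothetical TD errors
probs = per_probabilities(td_errors)
weights = importance_weights(probs, beta=0.4)

for d, p, w in zip(td_errors, probs, weights):
    print(f"|delta|={d:.1f}  P(i)={p:.3f}  w_i={w:.3f}")
```

Transitions with larger TD errors get a higher sampling probability and a smaller weight, so their more frequent updates are scaled down accordingly.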

Component 3: Dueling Networks (Review)

Instead of learning Q directly, decompose into value and advantage: $Q(s, a) = V(s) + A(s, a) - \frac{1}{|A|}\sum_{a'} A(s, a')$

This lets the network learn state value even when actions don’t matter much.
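A small numeric sketch of the aggregation formula, with a hypothetical helper `dueling_q` and made-up numbers. It shows why the mean is subtracted: adding a constant to every advantage leaves the resulting Q-values unchanged, so the value stream alone determines the average Q-value.

```python
def dueling_q(value, advantages):
    """Combine V(s) and A(s, a) into Q(s, a) with mean-advantage subtraction."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]

q1 = dueling_q(2.0, [1.0, 0.0, -1.0])
q2 = dueling_q(2.0, [11.0, 10.0, 9.0])  # same advantages shifted by +10

print(q1)  # [3.0, 2.0, 1.0]
print(q2)  # [3.0, 2.0, 1.0]
```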

Component 4: Multi-step Learning

📖n-step Returns

Instead of bootstrapping after one step, accumulate rewards for n steps before bootstrapping: $G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k+1} + \gamma^n \max_{a'} Q(s_{t+n}, a')$

Standard TD uses 1-step returns: $G_t^{(1)} = r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a')$

This heavily relies on the Q-estimate for bootstrapping. If Q is inaccurate (early in training), we propagate errors.

Multi-step returns use actual rewards for more steps before bootstrapping: $G_t^{(3)} = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \gamma^3 \max_{a'} Q(s_{t+3}, a')$

Benefits:

  • Faster credit assignment: rewards propagate n times faster
  • Less reliance on potentially inaccurate Q-estimates
  • Variance stays manageable when n is moderate (3-5 steps)

Trade-off:

  • More bias if the policy changes during the n steps
  • More complex to implement with replay buffers
  • Can increase variance if n is too large

Rainbow uses n=3, which provides a good balance.
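As a worked example of the formulas above, this sketch computes a 1-step and a 3-step return for the same starting point. The function name `n_step_return`, the rewards, and the bootstrap value of 5.0 are all made up for illustration.

```python
def n_step_return(rewards, bootstrap_value, gamma=0.99):
    """Discounted sum of len(rewards) rewards plus a discounted bootstrap value."""
    g = 0.0
    for k, r in enumerate(rewards):
        g += (gamma ** k) * r
    return g + (gamma ** len(rewards)) * bootstrap_value

# 1-step: r_{t+1} + gamma * max_a Q(s_{t+1}, a)
g1 = n_step_return([1.0], bootstrap_value=5.0, gamma=0.9)

# 3-step: r_{t+1} + gamma * r_{t+2} + gamma^2 * r_{t+3} + gamma^3 * max_a Q(s_{t+3}, a)
g3 = n_step_return([1.0, 0.0, 2.0], bootstrap_value=5.0, gamma=0.9)

print(f"1-step return: {g1:.3f}")  # 5.500
print(f"3-step return: {g3:.3f}")  # 6.265
```

In the 3-step case the bootstrap estimate is multiplied by 0.729 instead of 0.9, so an error in the Q-estimate contributes less to the target.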

Mathematical Details

The n-step return is: $G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k+1} + \gamma^n V(s_{t+n})$

For Q-learning with Double DQN: $G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k+1} + \gamma^n Q(s_{t+n}, \arg\max_{a'} Q(s_{t+n}, a'; \theta); \theta^-)$

Why n=3?

The Rainbow paper found n=3 gave the best results. Smaller values don’t propagate credit fast enough. Larger values introduce too much off-policy bias since the replay buffer contains old transitions.

Handling episode boundaries:

If the episode terminates before n steps, we truncate the return: $G_t^{(n)} = \sum_{k=0}^{T-t-1} \gamma^k r_{t+k+1}$

where T is the terminal timestep.

</>Implementation
import torch
import numpy as np
from collections import deque
from typing import List, Tuple, Dict, Optional

class NStepBuffer:
    """
    Buffer for computing n-step returns.

    Stores the last n transitions and computes n-step returns
    when the buffer is full or an episode ends.
    """

    def __init__(self, n_steps: int = 3, gamma: float = 0.99):
        self.n_steps = n_steps
        self.gamma = gamma
        self.buffer: deque = deque(maxlen=n_steps)

    def append(self, state, action, reward, done):
        """
        Add a transition to the n-step buffer.

        Returns a single n-step transition (tuple) once the buffer is full,
        a list of transitions when the episode ends, and None otherwise.
        """
        self.buffer.append((state, action, reward, done))

        # If episode ended, flush remaining transitions
        if done:
            return self._flush_buffer()

        # If buffer is full, compute n-step return
        if len(self.buffer) == self.n_steps:
            return self._compute_n_step_return()

        return None

    def _compute_n_step_return(self) -> Tuple:
        """Compute n-step return from full buffer."""

        # Get the first transition
        state, action, _, _ = self.buffer[0]

        # Compute discounted sum of rewards
        n_step_reward = 0.0
        for i, (_, _, r, d) in enumerate(self.buffer):
            n_step_reward += (self.gamma ** i) * r
            if d:  # Episode ended early
                return (state, action, n_step_reward, None, True, i + 1)

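        # buffer[-1][0] is s_{t+n-1}, not s_{t+n}: this buffer stores no next states.
        # Callers must substitute the true s_{t+n}, i.e. the next_state of the
        # newest transition (NStepReplayBuffer.push does this).
        last_state = self.buffer[-1][0]

        return (state, action, n_step_reward, last_state, False, self.n_steps)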

    def _flush_buffer(self) -> List[Tuple]:
        """Flush buffer at episode end, computing partial n-step returns."""

        transitions = []
        while len(self.buffer) > 0:
            state, action, _, _ = self.buffer[0]

            # Compute return for remaining steps
            n_step_reward = 0.0
            for i, (_, _, r, d) in enumerate(self.buffer):
                n_step_reward += (self.gamma ** i) * r

            transitions.append((state, action, n_step_reward, None, True, len(self.buffer)))
            self.buffer.popleft()

        return transitions

    def reset(self):
        """Clear the buffer."""
        self.buffer.clear()


class NStepReplayBuffer:
    """
    Replay buffer with n-step return computation.

    Stores transitions with pre-computed n-step returns.
    """

    def __init__(self, capacity: int, n_steps: int = 3, gamma: float = 0.99):
        self.capacity = capacity
        self.n_steps = n_steps
        self.gamma = gamma

        self.n_step_buffer = NStepBuffer(n_steps, gamma)

        # Main storage
        self.states = []
        self.actions = []
        self.n_step_rewards = []  # Accumulated n-step rewards
        self.next_states = []     # State n steps ahead
        self.dones = []
        self.actual_n = []        # Actual number of steps (may be less than n)

        self.position = 0
        self.size = 0

    def push(self, state, action, reward, next_state, done):
        """Add a transition, computing n-step return when ready."""

        result = self.n_step_buffer.append(state, action, reward, done)

        if result is not None:
            if isinstance(result, list):  # Flushed at episode end
                for trans in result:
                    self._store_transition(*trans)
            else:
                # Need to also store the next state (n steps ahead)
                state, action, n_step_reward, _, is_done, actual_n = result
                self._store_transition(state, action, n_step_reward,
                                       next_state, is_done, actual_n)

        if done:
            self.n_step_buffer.reset()

    def _store_transition(self, state, action, n_step_reward,
                          next_state, done, actual_n):
        """Store a complete n-step transition."""

        if self.size < self.capacity:
            self.states.append(state)
            self.actions.append(action)
            self.n_step_rewards.append(n_step_reward)
            self.next_states.append(next_state)
            self.dones.append(done)
            self.actual_n.append(actual_n)
            self.size += 1
        else:
            idx = self.position % self.capacity
            self.states[idx] = state
            self.actions[idx] = action
            self.n_step_rewards[idx] = n_step_reward
            self.next_states[idx] = next_state
            self.dones[idx] = done
            self.actual_n[idx] = actual_n

        self.position += 1

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        """Sample a batch of n-step transitions."""

        indices = np.random.choice(self.size, batch_size, replace=False)

        # Filter out terminal states from next_states (they'll be None)
        next_states = []
        for idx in indices:
            ns = self.next_states[idx]
            if ns is not None:
                next_states.append(ns)
            else:
                # Use zeros for terminal states (won't be used anyway)
                next_states.append(np.zeros_like(self.states[idx]))

        return {
            'states': torch.FloatTensor(np.array([self.states[i] for i in indices])),
            'actions': torch.LongTensor([self.actions[i] for i in indices]),
            'n_step_rewards': torch.FloatTensor([self.n_step_rewards[i] for i in indices]),
            'next_states': torch.FloatTensor(np.array(next_states)),
            'dones': torch.FloatTensor([self.dones[i] for i in indices]),
            'actual_n': torch.LongTensor([self.actual_n[i] for i in indices])
        }

    def __len__(self):
        return self.size


def compute_n_step_target(batch: Dict, online_net, target_net, gamma: float):
    """
    Compute n-step Double DQN targets.

    Target = n_step_reward + gamma^n * Q_target(s_{t+n}, argmax_a Q_online(s_{t+n}, a))
    """
    n_step_rewards = batch['n_step_rewards']
    next_states = batch['next_states']
    dones = batch['dones']
    actual_n = batch['actual_n']

    with torch.no_grad():
        # Double DQN: online selects, target evaluates
        next_q_online = online_net(next_states)
        best_actions = next_q_online.argmax(dim=1)

        next_q_target = target_net(next_states)
        next_q = next_q_target.gather(1, best_actions.unsqueeze(1)).squeeze(1)

        # Compute targets with variable discount based on actual n
        gamma_n = gamma ** actual_n.float()
        targets = n_step_rewards + gamma_n * next_q * (1 - dones)

    return targets

Component 5: Distributional RL

📖Distributional RL

Instead of learning expected returns $Q(s, a) = \mathbb{E}[G_t]$, learn the full distribution of returns $Z(s, a)$ where $Q(s, a) = \mathbb{E}[Z(s, a)]$.

Standard Q-learning treats returns as a single number. But returns are actually random variables. Two state-action pairs might have the same expected return but very different risk profiles:

Example:

  • Action A: Always returns exactly 10
  • Action B: Returns 0 or 20 with equal probability

Both have $Q(s, a) = 10$, but they’re fundamentally different! Distributional RL captures this difference.
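The two actions can be written as categorical distributions over three atoms. This sketch uses a hypothetical helper `mean_and_variance` to show that the means match while the spread does not.

```python
def mean_and_variance(atoms, probs):
    """Mean and variance of a categorical distribution over fixed atoms."""
    mean = sum(p * z for z, p in zip(atoms, probs))
    var = sum(p * (z - mean) ** 2 for z, p in zip(atoms, probs))
    return mean, var

atoms = [0.0, 10.0, 20.0]
action_a = [0.0, 1.0, 0.0]   # always returns exactly 10
action_b = [0.5, 0.0, 0.5]   # returns 0 or 20 with equal probability

mean_a, var_a = mean_and_variance(atoms, action_a)
mean_b, var_b = mean_and_variance(atoms, action_b)

print(mean_a, var_a)  # 10.0 0.0
print(mean_b, var_b)  # 10.0 100.0
```

A scalar Q-function stores only the first number of each pair; the distributional view keeps the whole probability vector.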

C51 Algorithm (used in Rainbow):

  • Model the return distribution with 51 atoms (discrete bins)
  • Each atom represents a possible return value
  • The network outputs probabilities for each atom
  • Use a categorical cross-entropy loss

Benefits:

  • Richer learning signal than scalar TD error
  • Better gradient properties
  • More stable optimization
  • Captures multimodality in returns
Mathematical Details

C51 parameterization:

Represent the distribution using 51 fixed atoms: $z_i = V_{\min} + i \cdot \frac{V_{\max} - V_{\min}}{N-1}, \quad i = 0, \ldots, N-1$

where $N = 51$, $V_{\min} = -10$, $V_{\max} = 10$ for Atari.

The network outputs logits $\ell(s, a) \in \mathbb{R}^{N}$, which are converted to probabilities: $p_i(s, a) = \frac{e^{\ell_i(s, a)}}{\sum_j e^{\ell_j(s, a)}}$

Distributional Bellman equation:

The distribution update is: $Z(s, a) \stackrel{D}{=} r + \gamma Z(s', a')$

where $\stackrel{D}{=}$ means equality in distribution.

Projection:

The target distribution may not align with our fixed atoms. We project it:

For target atom $\hat{z}_j = r + \gamma z_j$:

  1. Clip to $[V_{\min}, V_{\max}]$
  2. Distribute probability to neighboring atoms proportionally

The loss is the cross-entropy between projected target and predicted distribution: $L = -\sum_i m_i \log p_i(s, a)$

where $m_i$ is the projected target probability for atom $i$.
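The projection is easier to follow on a tiny support. The sketch below is a single-sample, pure-Python version with five atoms instead of 51; the function name `project_onto_support` and the numbers are illustrative assumptions. An atom at 0 is shifted to 0.5 by the reward, which falls halfway between the atoms at 0 and 1, so its mass is split evenly between them.

```python
import math

def project_onto_support(next_probs, reward, gamma, support, done=False):
    """Project the shifted distribution r + gamma * Z back onto `support`."""
    v_min, v_max = support[0], support[-1]
    n_atoms = len(support)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    projected = [0.0] * n_atoms

    for p, z in zip(next_probs, support):
        # Shift the atom through the Bellman update, then clip to the support
        tz = reward if done else reward + gamma * z
        tz = min(max(tz, v_min), v_max)

        # Fractional index of the shifted atom on the fixed grid
        b = (tz - v_min) / delta_z
        lower = min(int(math.floor(b)), n_atoms - 1)
        upper = min(lower + 1, n_atoms - 1)
        frac = b - lower

        # Split the mass between the two neighbouring atoms
        projected[lower] += p * (1.0 - frac)
        projected[upper] += p * frac
    return projected

support = [-2.0, -1.0, 0.0, 1.0, 2.0]
next_probs = [0.0, 0.0, 1.0, 0.0, 0.0]   # all mass on the atom z = 0

projected = project_onto_support(next_probs, reward=0.5, gamma=0.9, support=support)
print(projected)  # [0.0, 0.0, 0.5, 0.5, 0.0]

terminal = project_onto_support(next_probs, reward=1.0, gamma=0.9, support=support, done=True)
print(terminal)   # [0.0, 0.0, 0.0, 1.0, 0.0]
```

The projected distribution still sums to 1 and its mean is 0.5, matching the shifted atom. For a terminal transition the whole distribution collapses onto the reward.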

</>Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistributionalDQN(nn.Module):
    """
    C51 Distributional DQN.

    Outputs a probability distribution over returns for each action.
    """

    def __init__(self, state_dim: int, n_actions: int,
                 n_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0):
        super().__init__()

        self.n_actions = n_actions
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Support for the distribution
        self.register_buffer(
            'support',
            torch.linspace(v_min, v_max, n_atoms)
        )
        self.delta_z = (v_max - v_min) / (n_atoms - 1)

        # Network outputs n_actions * n_atoms values
        self.network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions * n_atoms)
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Returns log-probabilities over atoms for each action.

        Shape: [batch_size, n_actions, n_atoms]
        """
        batch_size = state.size(0)

        # Get raw outputs
        logits = self.network(state)

        # Reshape: [batch, n_actions * n_atoms] -> [batch, n_actions, n_atoms]
        logits = logits.view(batch_size, self.n_actions, self.n_atoms)

        # Apply log-softmax over atoms dimension
        log_probs = F.log_softmax(logits, dim=2)

        return log_probs

    def get_q_values(self, state: torch.Tensor) -> torch.Tensor:
        """Compute Q-values as expected value under distribution."""

        log_probs = self.forward(state)
        probs = log_probs.exp()  # [batch, n_actions, n_atoms]

        # Q(s,a) = sum_i p_i * z_i
        q_values = (probs * self.support).sum(dim=2)  # [batch, n_actions]

        return q_values


def project_distribution(next_dist: torch.Tensor,
                        rewards: torch.Tensor,
                        dones: torch.Tensor,
                        gamma: float,
                        support: torch.Tensor,
                        v_min: float,
                        v_max: float,
                        n_atoms: int) -> torch.Tensor:
    """
    Project the target distribution onto the fixed support.

    This handles the Bellman update: Z' = r + gamma * Z

    Args:
        next_dist: Probabilities over atoms for next state [batch, n_atoms]
        rewards: Rewards [batch]
        dones: Done flags [batch]
        gamma: Discount factor
        support: The fixed atoms [n_atoms]
        v_min, v_max: Support bounds
        n_atoms: Number of atoms

    Returns:
        Projected distribution [batch, n_atoms]
    """
    batch_size = rewards.size(0)
    delta_z = (v_max - v_min) / (n_atoms - 1)

    # Compute target support: T_z = r + gamma * z (clipped)
    # Shape: [batch, n_atoms]
    target_support = rewards.unsqueeze(1) + gamma * (1 - dones.unsqueeze(1)) * support.unsqueeze(0)
    target_support = target_support.clamp(v_min, v_max)

    # Compute the projection
    # b = (T_z - v_min) / delta_z gives the float index
    b = (target_support - v_min) / delta_z

    # Lower and upper atom indices
    lower = b.floor().long()
    upper = b.ceil().long()

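    # Clamp indices to guard against floating-point overshoot at the edges.
    # When b is exactly an integer, lower == upper; lower_prop below is then 1,
    # so the full mass lands on that atom and nothing is lost.
    lower = lower.clamp(0, n_atoms - 1)
    upper = upper.clamp(0, n_atoms - 1)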

    # Distribute probability proportionally
    # m_l = p * (u - b), m_u = p * (b - l)
    m = torch.zeros(batch_size, n_atoms, device=rewards.device)

    # Upper and lower proportions
    upper_prop = (b - lower.float())
    lower_prop = 1 - upper_prop

    # Add probability mass to lower atoms
    m.scatter_add_(1, lower, next_dist * lower_prop)
    # Add probability mass to upper atoms
    m.scatter_add_(1, upper, next_dist * upper_prop)

    return m


def compute_distributional_loss(online_net: DistributionalDQN,
                                target_net: DistributionalDQN,
                                batch: dict,
                                gamma: float) -> torch.Tensor:
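    """
    Compute the distributional RL loss: cross-entropy between the projected
    target and the predicted distribution (equal to the KL divergence up to
    a term that does not depend on the online network).
    """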
    states = batch['states']
    actions = batch['actions']
    rewards = batch['rewards']
    next_states = batch['next_states']
    dones = batch['dones']

    batch_size = states.size(0)

    # Get current distribution (log probabilities)
    log_probs = online_net(states)  # [batch, n_actions, n_atoms]

    # Select distribution for taken action
    actions_expanded = actions.unsqueeze(1).unsqueeze(2).expand(-1, -1, online_net.n_atoms)
    current_log_probs = log_probs.gather(1, actions_expanded).squeeze(1)  # [batch, n_atoms]

    with torch.no_grad():
        # Double DQN: online network selects best action
        next_q_values = online_net.get_q_values(next_states)
        best_actions = next_q_values.argmax(dim=1)

        # Target network evaluates distribution
        next_log_probs = target_net(next_states)
        best_actions_expanded = best_actions.unsqueeze(1).unsqueeze(2).expand(-1, -1, target_net.n_atoms)
        next_dist = next_log_probs.gather(1, best_actions_expanded).squeeze(1).exp()  # [batch, n_atoms]

        # Project target distribution
        target_dist = project_distribution(
            next_dist, rewards, dones, gamma,
            online_net.support, online_net.v_min, online_net.v_max, online_net.n_atoms
        )

    # Cross-entropy loss: -sum(target * log(current))
    loss = -(target_dist * current_log_probs).sum(dim=1).mean()

    return loss

Component 6: Noisy Networks

📖Noisy Networks

Replace epsilon-greedy exploration with learned parametric noise added to network weights. The network learns how much to explore in each state.

Epsilon-greedy exploration is crude: random actions with probability epsilon, regardless of state. This has problems:

  • Same exploration rate in all states
  • Random actions may not be informative
  • Need to tune epsilon schedule

Noisy Networks add learnable noise to the weights: $y = (W + \sigma \odot \epsilon) x + (b + \sigma_b \odot \epsilon_b)$

where:

  • $W, b$ are the mean weights/biases
  • $\sigma, \sigma_b$ are learnable noise scales
  • $\epsilon, \epsilon_b$ are random noise samples

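Key insight: The network learns how much to explore. The noise scales $\sigma$ are trained by gradient descent together with the weights, so noise can shrink where it hurts the loss and persist elsewhere. Because the noise enters through the weights, its effect on the Q-values still varies from state to state.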

Factorized Gaussian noise (used in Rainbow):

  • Full noise requires O(input * output) random numbers per layer
  • Factorized noise uses O(input + output) random numbers
  • Negligible performance difference, large computational savings
Mathematical Details

Factorized Gaussian noise:

Instead of sampling independent noise for each weight: $\epsilon_{ij} \sim \mathcal{N}(0, 1)$

Use factorized noise: $\epsilon_{ij} = f(\epsilon_i) f(\epsilon_j)$

where $\epsilon_i \sim \mathcal{N}(0, 1)$ and $f(x) = \text{sign}(x)\sqrt{|x|}$.

This reduces the number of random samples from $p \times q$ to $p + q$ for a layer with $p$ inputs and $q$ outputs.

Noisy linear layer: $y = (\mu^W + \sigma^W \odot \epsilon^W) x + \mu^b + \sigma^b \odot \epsilon^b$

The parameters $\mu^W, \mu^b, \sigma^W, \sigma^b$ are all learned.
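A pure-Python sketch of the factorization, assuming a hypothetical helper `factorized_noise` and small layer sizes chosen for readability. It builds the full $q \times p$ noise matrix from only $p + q$ Gaussian samples by taking an outer product, and reuses the output-side samples as the bias noise.

```python
import math
import random

def f(x):
    """Scaling function f(x) = sign(x) * sqrt(|x|)."""
    return math.copysign(math.sqrt(abs(x)), x)

def factorized_noise(p, q, rng):
    """Build a q x p weight-noise matrix and a length-q bias-noise vector."""
    eps_in = [f(rng.gauss(0.0, 1.0)) for _ in range(p)]    # one sample per input
    eps_out = [f(rng.gauss(0.0, 1.0)) for _ in range(q)]   # one sample per output
    weight_noise = [[eo * ei for ei in eps_in] for eo in eps_out]  # outer product
    bias_noise = list(eps_out)
    return weight_noise, bias_noise

rng = random.Random(0)
p, q = 4, 3
weight_noise, bias_noise = factorized_noise(p, q, rng)

print(f"independent noise needs {p * q} samples, factorized needs {p + q}")
print(f"weight noise shape: {len(weight_noise)} x {len(weight_noise[0])}")
```

For a 512-by-512 layer the same construction needs 1,024 samples instead of 262,144. The cost is that the entries of the noise matrix are no longer independent of each other.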

Initialization:

  • $\mu$ initialized uniformly in $[-\frac{1}{\sqrt{p}}, \frac{1}{\sqrt{p}}]$
  • $\sigma$ initialized to $\frac{0.5}{\sqrt{p}}$
</>Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class NoisyLinear(nn.Module):
    """
    Noisy linear layer with factorized Gaussian noise.

    Replaces standard linear layer with learnable noise for exploration.
    """

    def __init__(self, in_features: int, out_features: int, sigma_init: float = 0.5):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.sigma_init = sigma_init

        # Learnable parameters: mean weights and biases
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_mu = nn.Parameter(torch.empty(out_features))

        # Learnable parameters: noise scales
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))

        # Buffers for noise (not parameters, but saved in state_dict)
        self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        self.reset_parameters()
        self.reset_noise()

    def reset_parameters(self):
        """Initialize parameters."""
        mu_range = 1 / math.sqrt(self.in_features)
        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.bias_mu.data.uniform_(-mu_range, mu_range)

        sigma_init = self.sigma_init / math.sqrt(self.in_features)
        self.weight_sigma.data.fill_(sigma_init)
        self.bias_sigma.data.fill_(sigma_init)

    def reset_noise(self):
        """Sample new noise."""
        epsilon_in = self._factorized_noise(self.in_features)
        epsilon_out = self._factorized_noise(self.out_features)

        # Outer product for factorized noise
        self.weight_epsilon.copy_(epsilon_out.outer(epsilon_in))
        self.bias_epsilon.copy_(epsilon_out)

    def _factorized_noise(self, size: int) -> torch.Tensor:
        """Generate factorized Gaussian noise: f(x) = sign(x) * sqrt(|x|)"""
        x = torch.randn(size, device=self.weight_mu.device)
        return x.sign() * x.abs().sqrt()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with noisy weights."""
        if self.training:
            weight = self.weight_mu + self.weight_sigma * self.weight_epsilon
            bias = self.bias_mu + self.bias_sigma * self.bias_epsilon
        else:
            # No noise during evaluation
            weight = self.weight_mu
            bias = self.bias_mu

        return F.linear(x, weight, bias)


class NoisyDQN(nn.Module):
    """DQN with noisy layers for exploration."""

    def __init__(self, state_dim: int, n_actions: int, hidden_dim: int = 128):
        super().__init__()

        # Standard layers for feature extraction
        self.feature_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Noisy layers for value prediction
        self.noisy1 = NoisyLinear(hidden_dim, hidden_dim)
        self.noisy2 = NoisyLinear(hidden_dim, n_actions)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        features = self.feature_layer(state)
        x = F.relu(self.noisy1(features))
        q_values = self.noisy2(x)
        return q_values

    def reset_noise(self):
        """Reset noise in all noisy layers."""
        self.noisy1.reset_noise()
        self.noisy2.reset_noise()


def demonstrate_noisy_exploration():
    """Show how noise affects action selection."""

    net = NoisyDQN(state_dim=4, n_actions=3)
    state = torch.randn(1, 4)

    # Multiple forward passes with same state but different noise
    print("Q-values with different noise samples:")
    for i in range(5):
        net.reset_noise()
        q_values = net(state)
        action = q_values.argmax().item()
        print(f"  Trial {i+1}: Q = {q_values.detach().numpy().round(3)}, Action = {action}")

    # Evaluation mode (no noise)
    net.eval()
    q_values_eval = net(state)
    print(f"\nEvaluation mode (no noise): Q = {q_values_eval.detach().numpy().round(3)}")

Putting It All Together: Rainbow

Rainbow combines all six components:

  1. Architecture: Dueling network with noisy layers
  2. Target computation: Double DQN with n-step returns
  3. Value representation: Distributional (C51)
  4. Replay: Prioritized with importance sampling

The components are largely orthogonal: they address different aspects of the algorithm.

| Component | What it changes |
| --- | --- |
| Double DQN | Target computation |
| PER | Which samples to learn from |
| Dueling | Network architecture (Q = V + A) |
| Multi-step | Reward accumulation |
| Distributional | What we predict (distribution vs scalar) |
| Noisy Nets | How we explore |
</>Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, Tuple

class RainbowNetwork(nn.Module):
    """
    Rainbow network combining:
    - Dueling architecture
    - Noisy layers
    - Distributional output (C51)
    """

    def __init__(self, state_dim: int, n_actions: int,
                 n_atoms: int = 51, v_min: float = -10.0, v_max: float = 10.0,
                 hidden_dim: int = 128):
        super().__init__()

        self.n_actions = n_actions
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max

        # Support for distributional RL
        self.register_buffer('support', torch.linspace(v_min, v_max, n_atoms))

        # Shared feature extraction (standard layers)
        self.feature_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU()
        )

        # Value stream with noisy layers
        self.value_hidden = NoisyLinear(hidden_dim, hidden_dim)
        self.value_output = NoisyLinear(hidden_dim, n_atoms)

        # Advantage stream with noisy layers
        self.advantage_hidden = NoisyLinear(hidden_dim, hidden_dim)
        self.advantage_output = NoisyLinear(hidden_dim, n_actions * n_atoms)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """
        Returns log-probabilities over atoms for each action.
        Combines dueling architecture with distributional output.
        """
        batch_size = state.size(0)
        features = self.feature_layer(state)

        # Value stream: [batch, n_atoms]
        value = F.relu(self.value_hidden(features))
        value = self.value_output(value).view(batch_size, 1, self.n_atoms)

        # Advantage stream: [batch, n_actions, n_atoms]
        advantage = F.relu(self.advantage_hidden(features))
        advantage = self.advantage_output(advantage).view(batch_size, self.n_actions, self.n_atoms)

        # Dueling combination applied to the per-atom logits (not to Q-values):
        # logits(s,a) = V_logits(s) + A_logits(s,a) - mean over actions of A_logits
        q_atoms = value + advantage - advantage.mean(dim=1, keepdim=True)

        # Apply log-softmax over atoms
        log_probs = F.log_softmax(q_atoms, dim=2)

        return log_probs

    def get_q_values(self, state: torch.Tensor) -> torch.Tensor:
        """Compute Q-values as expected value of distribution."""
        log_probs = self.forward(state)
        probs = log_probs.exp()
        q_values = (probs * self.support).sum(dim=2)
        return q_values

    def reset_noise(self):
        """Reset noise in all noisy layers."""
        self.value_hidden.reset_noise()
        self.value_output.reset_noise()
        self.advantage_hidden.reset_noise()
        self.advantage_output.reset_noise()


class RainbowAgent:
    """
    Full Rainbow agent combining:
    1. Double DQN (target computation)
    2. Prioritized Experience Replay
    3. Dueling Networks
    4. Multi-step Learning
    5. Distributional RL (C51)
    6. Noisy Networks
    """

    def __init__(self,
                 state_dim: int,
                 n_actions: int,
                 # Distributional parameters
                 n_atoms: int = 51,
                 v_min: float = -10.0,
                 v_max: float = 10.0,
                 # Multi-step parameters
                 n_steps: int = 3,
                 # General parameters
                 gamma: float = 0.99,
                 lr: float = 6.25e-5,
                 tau: float = 0.005,
                 # PER parameters
                 buffer_size: int = 100000,
                 alpha: float = 0.5,
                 beta_start: float = 0.4,
                 beta_frames: int = 100000):

        self.n_actions = n_actions
        self.n_atoms = n_atoms
        self.v_min = v_min
        self.v_max = v_max
        self.n_steps = n_steps
        self.gamma = gamma
        self.tau = tau

        # PER beta annealing
        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame_count = 0

        # Networks
        self.online_net = RainbowNetwork(
            state_dim, n_actions, n_atoms, v_min, v_max
        )
        self.target_net = RainbowNetwork(
            state_dim, n_actions, n_atoms, v_min, v_max
        )
        self.target_net.load_state_dict(self.online_net.state_dict())

        self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=lr)

        # N-step buffer for computing returns
        self.n_step_buffer = NStepBuffer(n_steps, gamma)

        # Prioritized replay buffer (simplified version)
        self.buffer_size = buffer_size
        self.alpha = alpha
        self.buffer = []
        self.priorities = []
        self.position = 0

    def select_action(self, state: torch.Tensor) -> int:
        """
        Select action using noisy network (no epsilon-greedy needed).
        """
        with torch.no_grad():
            q_values = self.online_net.get_q_values(state.unsqueeze(0))
            return q_values.argmax(dim=1).item()

    def push_transition(self, state, action, reward, next_state, done):
        """Add transition to n-step buffer, then to replay buffer."""

        result = self.n_step_buffer.append(state, action, reward, done)

        if result is not None:
            if isinstance(result, list):
                for trans in result:
                    self._add_to_buffer(trans, next_state)
            else:
                self._add_to_buffer(result, next_state)

        if done:
            self.n_step_buffer.reset()

    def _add_to_buffer(self, transition, next_state):
        """Add n-step transition to prioritized buffer."""

        state, action, n_step_reward, _, is_done, actual_n = transition

        # Use max priority for new transitions
        max_priority = max(self.priorities) if self.priorities else 1.0

        if len(self.buffer) < self.buffer_size:
            self.buffer.append((state, action, n_step_reward, next_state, is_done, actual_n))
            self.priorities.append(max_priority)
        else:
            idx = self.position % self.buffer_size
            self.buffer[idx] = (state, action, n_step_reward, next_state, is_done, actual_n)
            self.priorities[idx] = max_priority

        self.position += 1

    def sample_batch(self, batch_size: int) -> Tuple[Dict, np.ndarray, torch.Tensor]:
        """Sample batch with prioritized replay."""

        # Compute sampling probabilities
        priorities = np.array(self.priorities[:len(self.buffer)])
        probs = priorities ** self.alpha
        probs /= probs.sum()

        # Sample indices
        indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=False)

        # Compute importance sampling weights
        self.frame_count += 1
        beta = min(1.0, self.beta_start + self.frame_count * (1.0 - self.beta_start) / self.beta_frames)
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = torch.FloatTensor(weights)

        # Extract batch
        batch = {
            'states': torch.FloatTensor(np.array([self.buffer[i][0] for i in indices])),
            'actions': torch.LongTensor([self.buffer[i][1] for i in indices]),
            'n_step_rewards': torch.FloatTensor([self.buffer[i][2] for i in indices]),
            'next_states': torch.FloatTensor(np.array([
                self.buffer[i][3] if self.buffer[i][3] is not None
                else np.zeros_like(self.buffer[i][0])
                for i in indices
            ])),
            'dones': torch.FloatTensor([self.buffer[i][4] for i in indices]),
            'actual_n': torch.LongTensor([self.buffer[i][5] for i in indices])
        }

        return batch, indices, weights

    def compute_loss(self, batch: Dict, weights: torch.Tensor) -> Tuple[torch.Tensor, np.ndarray]:
        """
        Compute Rainbow loss combining:
        - Distributional RL (C51)
        - Double DQN target selection
        - N-step returns
        - Importance sampling weights from PER
        """
        states = batch['states']
        actions = batch['actions']
        n_step_rewards = batch['n_step_rewards']
        next_states = batch['next_states']
        dones = batch['dones']
        actual_n = batch['actual_n']

        batch_size = states.size(0)

        # Reset noise for this forward pass
        self.online_net.reset_noise()
        self.target_net.reset_noise()

        # Get current distribution
        log_probs = self.online_net(states)
        actions_expanded = actions.unsqueeze(1).unsqueeze(2).expand(-1, -1, self.n_atoms)
        current_log_probs = log_probs.gather(1, actions_expanded).squeeze(1)

        with torch.no_grad():
            # Double DQN: online selects, target evaluates
            next_q_values = self.online_net.get_q_values(next_states)
            best_actions = next_q_values.argmax(dim=1)

            # Get target distribution for best action
            next_log_probs = self.target_net(next_states)
            best_expanded = best_actions.unsqueeze(1).unsqueeze(2).expand(-1, -1, self.n_atoms)
            next_dist = next_log_probs.gather(1, best_expanded).squeeze(1).exp()

            # Project with n-step discount
            gamma_n = self.gamma ** actual_n.float()
            target_dist = self._project_distribution(
                next_dist, n_step_rewards, dones, gamma_n
            )

        # Cross-entropy loss with importance sampling weights
        elementwise_loss = -(target_dist * current_log_probs).sum(dim=1)
        loss = (weights * elementwise_loss).mean()

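        # Per-sample cross-entropy loss, used as the priority signal.
        # A distributional agent has no scalar TD error, so the name
        # `td_errors` is kept only for consistency with the PER interface.
        td_errors = elementwise_loss.detach().cpu().numpy()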

        return loss, td_errors

    def _project_distribution(self, next_dist, rewards, dones, gamma_n):
        """Project target distribution onto support."""

        batch_size = rewards.size(0)
        delta_z = (self.v_max - self.v_min) / (self.n_atoms - 1)

        # Target support: T_z = r + gamma^n * z
        support = self.online_net.support
        target_support = rewards.unsqueeze(1) + gamma_n.unsqueeze(1) * (1 - dones.unsqueeze(1)) * support

        # Clip to valid range
        target_support = target_support.clamp(self.v_min, self.v_max)

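        # Compute projection indices
        b = (target_support - self.v_min) / delta_z
        lower = b.floor().long().clamp(0, self.n_atoms - 1)
        upper = b.ceil().long().clamp(0, self.n_atoms - 1)

        # When b lands exactly on an atom, lower == upper and both weights
        # below would be zero, silently dropping that probability mass.
        # Shift one index so the full mass is assigned to the correct atom.
        lower[(upper > 0) & (lower == upper)] -= 1
        upper[(lower < (self.n_atoms - 1)) & (lower == upper)] += 1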

        # Distribute probability
        m = torch.zeros(batch_size, self.n_atoms, device=rewards.device)
        offset = torch.arange(0, batch_size, device=rewards.device).unsqueeze(1) * self.n_atoms

        m.view(-1).index_add_(
            0, (lower + offset).view(-1),
            (next_dist * (upper.float() - b)).view(-1)
        )
        m.view(-1).index_add_(
            0, (upper + offset).view(-1),
            (next_dist * (b - lower.float())).view(-1)
        )

        return m

    def update_priorities(self, indices: np.ndarray, td_errors: np.ndarray):
        """Update priorities in replay buffer."""
        for idx, td_error in zip(indices, td_errors):
            self.priorities[idx] = abs(td_error) + 1e-6

    def soft_update_target(self):
        """Soft update target network."""
        for target_param, online_param in zip(
            self.target_net.parameters(),
            self.online_net.parameters()
        ):
            target_param.data.copy_(
                self.tau * online_param.data + (1 - self.tau) * target_param.data
            )

    def train_step(self, batch_size: int = 32) -> float:
        """Perform one training step."""

        if len(self.buffer) < batch_size:
            return 0.0

        # Sample batch
        batch, indices, weights = self.sample_batch(batch_size)

        # Compute loss
        loss, td_errors = self.compute_loss(batch, weights)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.online_net.parameters(), 10.0)
        self.optimizer.step()

        # Update priorities
        self.update_priorities(indices, td_errors)

        # Soft update target
        self.soft_update_target()

        return loss.item()
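
The projection step is the easiest part of the agent to get subtly wrong, so it helps to check it in isolation. The following is a minimal pure-Python sketch of the same categorical projection for a single transition; the function name `project_distribution` and the tiny 5-atom support are illustrative assumptions, not part of the agent above. It handles the case where a shifted atom lands exactly on a support point, so the projected distribution always sums to 1.

```python
import math


def project_distribution(next_probs, reward, done, gamma_n, v_min, v_max):
    """Project the shifted distribution r + gamma^n * z onto the fixed atoms.

    Hypothetical standalone helper for one transition (no batching).
    """
    n_atoms = len(next_probs)
    delta_z = (v_max - v_min) / (n_atoms - 1)
    support = [v_min + i * delta_z for i in range(n_atoms)]
    projected = [0.0] * n_atoms

    for z_j, p_j in zip(support, next_probs):
        # Bellman-shifted atom; terminal transitions drop the bootstrap term
        tz = reward + (0.0 if done else gamma_n * z_j)
        tz = min(max(tz, v_min), v_max)  # clip to the support range

        b = (tz - v_min) / delta_z  # fractional atom index
        lower = max(0, min(n_atoms - 1, math.floor(b)))
        upper = max(0, min(n_atoms - 1, math.ceil(b)))

        if lower == upper:
            # b sits exactly on an atom: that atom receives all of the mass
            projected[lower] += p_j
        else:
            # Split the mass between neighbours in proportion to proximity
            projected[lower] += p_j * (upper - b)
            projected[upper] += p_j * (b - lower)

    return projected


if __name__ == "__main__":
    uniform = [0.2] * 5
    # Terminal transition: every atom collapses onto the reward value
    print(project_distribution(uniform, 1.0, True, 0.5, -2.0, 2.0))
    # Non-terminal transition: the support is shrunk by gamma^n and shifted
    print(project_distribution(uniform, 0.5, False, 0.5, -2.0, 2.0))
```

In the terminal case every shifted atom equals the reward, which is exactly the situation where a projection that only uses the `upper - b` and `b - lower` weights would lose all of its mass.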

Rainbow Hyperparameters

Rainbow uses specific hyperparameters tuned on Atari:

| Parameter | Value | Notes |
|---|---|---|
| Learning rate | 6.25e-5 | Lower than standard DQN |
| Discount (gamma) | 0.99 | Standard |
| n-step | 3 | Multi-step returns |
| Atoms | 51 | C51 distribution |
| V_min, V_max | -10, 10 | Support bounds |
| PER alpha | 0.5 | Prioritization exponent |
| PER beta | 0.4 to 1.0 | Importance sampling, annealed during training |
| Target update | Every 32K frames (8,000 agent steps) | Hard update |
| Replay size | 1M transitions | Large buffer |
| Batch size | 32 | Standard |
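
To make the two PER rows concrete, here is a small pure-Python sketch of how alpha shapes the sampling probabilities and how beta is annealed for the importance sampling weights. It mirrors the formulas in `sample_batch` above; the helper names and the toy priorities are assumptions chosen for illustration.

```python
def per_probabilities(priorities, alpha=0.5):
    """Sampling probability of each transition: p_i^alpha / sum_k p_k^alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def anneal_beta(step, beta_start=0.4, beta_frames=100_000):
    """Linearly anneal beta from beta_start to 1.0 over beta_frames steps."""
    return min(1.0, beta_start + step * (1.0 - beta_start) / beta_frames)


def importance_weights(probs, indices, beta):
    """Weights (N * P(i))^(-beta), normalised by the largest weight in the batch."""
    n = len(probs)
    raw = [(n * probs[i]) ** (-beta) for i in indices]
    max_w = max(raw)
    return [w / max_w for w in raw]


if __name__ == "__main__":
    priorities = [1.0, 4.0, 9.0, 16.0]  # toy values, not from a real run
    probs = per_probabilities(priorities, alpha=0.5)
    print("sampling probabilities:", probs)

    for step in (0, 50_000, 100_000):
        beta = anneal_beta(step)
        weights = importance_weights(probs, [0, 1, 2, 3], beta)
        print(f"step={step} beta={beta:.2f} weights={weights}")
```

With alpha = 0.5 the priorities are square-rooted before normalising, so a transition with 16 times the priority is sampled only 4 times as often. As beta approaches 1.0 the weights fully compensate for that non-uniform sampling.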

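Key differences from vanilla DQN:

  • Lower learning rate (for distributional RL stability)
  • No epsilon-greedy (noisy nets handle exploration)

Key differences from the simplified implementation above:

  • Hard target updates, as in vanilla DQN (the code above uses soft updates via tau)
  • Larger replay buffer: 1M transitions, as in vanilla DQN (the code above defaults to 100,000)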
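
The list above mentions hard target updates, while the `soft_update_target` method shown earlier uses Polyak averaging. The difference can be sketched with plain dictionaries standing in for network parameters. The function names and the `period` argument are hypothetical; in PyTorch the hard copy would be the same `load_state_dict` call the constructor already uses.

```python
def soft_update(target, online, tau):
    """Polyak averaging: move each target parameter a fraction tau toward online."""
    return {name: tau * online[name] + (1.0 - tau) * target[name] for name in target}


def hard_update(target, online, step, period):
    """Copy the online parameters every `period` steps; otherwise keep the target."""
    if step % period == 0:
        return dict(online)
    return target


if __name__ == "__main__":
    target_soft = {"w": 0.0}
    target_hard = {"w": 0.0}

    for step in range(1, 9):
        online = {"w": float(step)}  # pretend the online weight grows each step
        target_soft = soft_update(target_soft, online, tau=0.5)
        target_hard = hard_update(target_hard, online, step, period=4)
        print(step, target_soft["w"], target_hard["w"])
```

The soft target trails the online value continuously, while the hard target stays frozen between copies and then jumps. A frozen target keeps the regression target fixed for a whole period, at the cost of being more stale just before each copy.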

Ablation Study: Which Components Matter Most?

The Rainbow paper performed ablation studies, removing one component at a time from the full agent:

Most impactful (removing hurts most):

  1. Prioritized replay: One of the two largest drops in median performance
  2. Multi-step learning: The other largest drop, hurting both early and final performance
  3. Distributional RL: Next most important; the gap only opens up after roughly the first 40M frames

Moderate impact:

  4. Noisy networks: Median performance is worse without them, although a few games improve slightly

Little aggregate impact:

  5. Dueling: No significant difference in aggregate when removed
  6. Double DQN: Limited difference in median performance when removed

Key insight: The components interact, and their contributions are uneven. An ablation measures what a component adds on top of the other five, so a small drop does not mean the idea is useless on its own. Double DQN is a good example: with rewards clipped and the value support bounded to [-10, 10], there is less overestimation left for it to correct.

📌Relative Importance

From the Atari results (median human-normalized score, no-op starts):

| Configuration | Median Score |
|---|---|
| Full Rainbow | 223% |
| Vanilla DQN | 79% |

The paper presents the ablations as learning curves, so only the two headline medians are listed here. Rainbow reaches close to 3x the median score of DQN, and within the combined agent prioritized replay and multi-step learning contribute the most.
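
The scores above are medians of human-normalized scores, which rescale each game so that a random policy scores 0% and the human reference scores 100%. A short sketch of that calculation follows. The per-game numbers are made up purely for illustration and are not results from the paper.

```python
from statistics import median


def human_normalized(agent, random_score, human):
    """Human-normalized score in percent: 100 * (agent - random) / (human - random)."""
    return 100.0 * (agent - random_score) / (human - random_score)


if __name__ == "__main__":
    # Hypothetical raw scores per game: (agent, random policy, human reference)
    games = {
        "game_a": (400.0, 0.0, 200.0),
        "game_b": (50.0, 10.0, 110.0),
        "game_c": (900.0, 100.0, 900.0),
    }

    normalized = {name: human_normalized(*scores) for name, scores in games.items()}
    for name, value in normalized.items():
        print(f"{name}: {value:.0f}%")

    print(f"median: {median(normalized.values()):.0f}%")
```

The median is used rather than the mean because a handful of games with very large normalized scores would otherwise dominate the aggregate.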

When to Use Rainbow

💡Tip

Rainbow is powerful but complex. Consider your needs:

Use full Rainbow when:

  • Maximizing sample efficiency is critical
  • You have engineering resources for implementation
  • Working on well-studied domains (Atari, similar games)

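Use simplified combinations when you want most of the gains without the full implementation cost:

  • Double DQN + PER is a common starting point (the "80% of benefits with 20% of complexity" figure is a rule of thumb, not a measured result)
  • Add Dueling for another easy win
  • Multi-step is a small extension if you are already modifying the replay buffer for PER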

Skip Rainbow when:

  • Simple environments (try vanilla DQN first)
  • Rapid prototyping (complexity slows iteration)
  • Limited compute (Rainbow needs more memory and computation)
</>Implementation
def create_rainbow_variant(complexity: str = "full"):
    """
    Create Rainbow variants of different complexity.

    Args:
        complexity: "minimal", "medium", or "full"
    """

    if complexity == "minimal":
        # Double DQN + PER only
        print("Minimal Rainbow: Double DQN + PER")
        print("- Easiest to implement")
        print("- Good sample efficiency improvement")
        print("- Works well for most applications")

        return {
            'double_dqn': True,
            'per': True,
            'dueling': False,
            'n_step': 1,
            'distributional': False,
            'noisy_nets': False
        }

    elif complexity == "medium":
        # Add Dueling + n-step
        print("Medium Rainbow: Double DQN + PER + Dueling + n-step")
        print("- Significant improvement over minimal")
        print("- Dueling is simple architectural change")
        print("- n-step requires buffer modification")

        return {
            'double_dqn': True,
            'per': True,
            'dueling': True,
            'n_step': 3,
            'distributional': False,
            'noisy_nets': False
        }

    else:  # full
        print("Full Rainbow: All six components")
        print("- Maximum sample efficiency")
        print("- Most complex to implement")
        print("- Requires careful hyperparameter tuning")

        return {
            'double_dqn': True,
            'per': True,
            'dueling': True,
            'n_step': 3,
            'distributional': True,
            'noisy_nets': True
        }


# Decision helper
def recommend_rainbow_variant(
    sample_efficiency_critical: bool,
    implementation_time_limited: bool,
    env_complexity: str  # "simple", "medium", "atari"
):
    """Recommend which Rainbow variant to use."""

    if env_complexity == "simple":
        return "Start with vanilla DQN. Add Double DQN if overestimation is a problem."

    if implementation_time_limited:
        return "Use minimal Rainbow (Double DQN + PER). Quick to implement, solid gains."

    if sample_efficiency_critical and env_complexity == "atari":
        return "Use full Rainbow. The complexity is worth it for Atari-scale problems."

    return "Use medium Rainbow. Good balance of complexity and performance."

Key Takeaways

ℹ️Note

Rainbow in a nutshell:

  1. Six orthogonal improvements: Each addresses a different DQN weakness

    • Double DQN: Overestimation
    • PER: Sample efficiency
    • Dueling: Architecture inductive bias
    • Multi-step: Credit assignment
    • Distributional: Richer learning signal
    • Noisy nets: Better exploration
  2. Sum greater than parts: Combined performance far exceeds individual components

  3. Ablation insights: The components contribute unequally; multi-step learning is among the most important, and the full combination performs best

  4. Practical guidance: Start simple (Double DQN + PER), add complexity as needed

  5. Not the end: Rainbow was state-of-the-art in 2017. Research continues with IQN, R2D2, Agent57, and beyond.

Looking Ahead

Rainbow consolidates the major value-based improvements to DQN into a single agent. But there’s another approach entirely: instead of learning values and deriving policies, we can learn policies directly.

Policy gradient methods offer different trade-offs:

  • Natural for continuous action spaces
  • Can learn stochastic policies
  • Different stability and sample efficiency characteristics

That’s the domain we’ll explore next, starting with the fundamentals of policy gradient methods.