Deep Reinforcement Learning • Part 3 of 4
📝Draft

Target Networks

Stabilizing training with frozen targets

Experience replay breaks the correlation between consecutive samples. But there is another source of instability in neural network Q-learning: the target we are trying to match keeps moving.

Target networks solve this by maintaining a separate, slowly-updated copy of the Q-network. The target network provides a stable reference point, allowing the online network to make progress without chasing its own tail.

The Moving Target Problem

📖Moving Target Problem

In Q-learning, the target includes our own Q-value estimate: $y = r + \gamma \max_{a'} Q(s', a')$. When we update $Q$, the target changes. We are simultaneously defining the goal and trying to reach it, like trying to hit a moving target.

Consider what happens in Q-learning when we update the network:

  1. Compute target: $y = r + \gamma \max_{a'} Q(s', a'; \theta)$
  2. Update $\theta$ to make $Q(s, a; \theta)$ closer to $y$
  3. But wait, $y$ depends on $\theta$! So $y$ has now changed
  4. The target we were aiming for has moved

This creates a feedback loop. The network tries to match a target, but every update changes that target. It is like trying to catch your own shadow: every step you take moves the shadow too.

In the worst case, this leads to:

  • Oscillation: Q-values swing back and forth
  • Divergence: Q-values explode to infinity
  • Instability: Training loss spikes unpredictably
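
The loop described above can be traced with a few lines of arithmetic. The sketch below is a toy illustration rather than part of any DQN implementation: it assumes a made-up linear Q-function with one weight per action (`theta[a] * s`), a single hand-picked transition, and arbitrary values for `GAMMA` and `ALPHA`. It computes the target, takes one update step, and recomputes the target with the updated weights.

```python
# Toy linear Q-function: Q(s, a; theta) = theta[a] * s, with scalar states.
# All numbers are invented for illustration.
GAMMA = 0.9
ALPHA = 0.1


def q_values(theta, s):
    """Q-values of every action in scalar state s."""
    return [w * s for w in theta]


def td_target(theta, r, s_next):
    """y = r + gamma * max_a' Q(s', a'; theta), using the SAME theta being trained."""
    return r + GAMMA * max(q_values(theta, s_next))


theta = [0.5, 1.0]                       # one weight per action
s, a, r, s_next = 1.0, 1, 1.0, 2.0       # a single transition

y_before = td_target(theta, r, s_next)   # 1 + 0.9 * max(1.0, 2.0) = 2.8
q_before = q_values(theta, s)[a]         # 1.0
td_error = y_before - q_before           # 1.8

# One update step on Q(s, a); for this Q, dQ(s, a)/d theta[a] = s.
theta[a] += ALPHA * td_error * s         # theta[1] becomes 1.18

y_after = td_target(theta, r, s_next)    # 1 + 0.9 * max(1.0, 2.36) = 3.124
q_after = q_values(theta, s)[a]          # 1.18

print(f"target before update: {y_before:.3f}, gap: {y_before - q_before:.3f}")
print(f"target after update:  {y_after:.3f}, gap: {y_after - q_after:.3f}")
```

In this toy setting the update raises the estimate by 0.18 but raises the target by 0.324, so the gap between estimate and target grows from 1.8 to 1.944 instead of shrinking.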

Mathematical Details

In standard Q-learning with function approximation, the loss is:

$$L(\theta) = \left( r + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2$$

The gradient includes two terms:

$$\nabla_\theta L = -2 \delta \left( \nabla_\theta Q(s, a; \theta) - \gamma \nabla_\theta \max_{a'} Q(s', a'; \theta) \right)$$

where $\delta$ is the TD error.

The second term represents how the target changes with $\theta$. This creates a coupling between the current estimate and its target. Changes to $\theta$ affect both sides of the equation simultaneously.

This coupling can create positive feedback loops. If $Q(s, a)$ is too high, and $s$ leads to $s'$, then $Q(s', a')$ may also be too high, creating an even higher target for $Q(s, a)$.
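
To make the two gradient terms concrete, the sketch below evaluates them by hand for a toy case and checks the full gradient against a finite-difference estimate. It assumes a hypothetical linear Q-function `theta[a] * s` with scalar states and an invented transition in which the greedy next action shares its weight with the action taken, so both terms depend on the same parameter.

```python
GAMMA = 0.9


def td_error(theta, s, a, r, s_next):
    """delta = r + gamma * max_a' Q(s', a'; theta) - Q(s, a; theta), with Q = theta[a] * s."""
    return r + GAMMA * max(w * s_next for w in theta) - theta[a] * s


def loss(theta, s, a, r, s_next):
    """Squared TD error for a single transition."""
    return td_error(theta, s, a, r, s_next) ** 2


theta = [0.5, 1.0]
s, a, r, s_next = 1.0, 1, 1.0, 2.0
delta = td_error(theta, s, a, r, s_next)     # 1.8

# The greedy action in s' is action 1, so both Q terms depend on theta[1].
grad_q = s              # d Q(s, a; theta) / d theta[1]
grad_max_q = s_next     # d max_a' Q(s', a'; theta) / d theta[1]

# Both terms of the gradient, versus the first term only (target held constant).
full_grad = -2 * delta * (grad_q - GAMMA * grad_max_q)
semi_grad = -2 * delta * grad_q

# Central finite difference of the loss with respect to theta[1].
eps = 1e-6
loss_plus = loss([theta[0], theta[1] + eps], s, a, r, s_next)
loss_minus = loss([theta[0], theta[1] - eps], s, a, r, s_next)
numeric_grad = (loss_plus - loss_minus) / (2 * eps)

print(f"full gradient:      {full_grad:.4f}")
print(f"finite difference:  {numeric_grad:.4f}")
print(f"first term only:    {semi_grad:.4f}")
```

In this example the two versions even disagree in sign (about 2.88 versus -3.6), which shows how strongly the target's dependence on $\theta$ can shape the update when estimate and target share parameters.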

The Target Network Solution

📖Target Network

A target network is a separate copy of the Q-network with parameters $\theta^-$. It is used to compute TD targets but is updated less frequently than the online network. This provides a stable target during training.

The solution is simple: use two copies of the network.

Online network ($\theta$):

  • Used to select actions
  • Updated every step
  • Computes current Q-values for loss

Target network ($\theta^-$):

  • Used only for computing targets
  • Updated rarely (e.g., every 10,000 steps)
  • Provides stable targets

Now the target does not move every step:

$$y = r + \gamma \max_{a'} Q(s', a'; \theta^-)$$

The target stays fixed until we explicitly update $\theta^-$. This gives the online network thousands of steps to make progress toward a stable goal before the goal moves.

Think of it like this: instead of chasing a moving target, you freeze the target in place, take many shots at it, then move it to a new location.

Mathematical Details

With a target network, the DQN loss becomes:

$$L(\theta) = \mathbb{E}_{(s,a,r,s') \sim \mathcal{D}} \left[ \left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2 \right]$$

Note that $\theta^-$ is treated as a constant when computing gradients:

$$\nabla_\theta L = -2 \, \mathbb{E} \left[ \delta \cdot \nabla_\theta Q(s, a; \theta) \right]$$

where $\delta = r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta)$.

The gradient no longer includes the target's dependence on parameters. This is the semi-gradient approach, but now the target is held fixed not just within one update but across thousands of updates.
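
A short numerical sketch may help here. It assumes a toy linear Q-function (`theta[a] * s`), one invented transition replayed 50 times, and arbitrary values for `GAMMA` and `ALPHA`. The target is computed from a frozen copy `theta_target` (standing in for $\theta^-$), and only the online weights `theta` are updated.

```python
GAMMA, ALPHA = 0.9, 0.1


def q_values(theta, s):
    """Q(s, a; theta) = theta[a] * s for every action a."""
    return [w * s for w in theta]


theta = [0.5, 1.0]             # online parameters (trained)
theta_target = list(theta)     # frozen copy, playing the role of theta^-
s, a, r, s_next = 1.0, 1, 1.0, 2.0

targets = []
for step in range(50):
    # The target is computed from theta_target, never from theta.
    y = r + GAMMA * max(q_values(theta_target, s_next))
    delta = y - q_values(theta, s)[a]
    theta[a] += ALPHA * delta * s    # update the online weights only
    targets.append(y)

print(f"target range over 50 updates: {min(targets):.3f} to {max(targets):.3f}")
print(f"online Q(s, a) after 50 updates: {q_values(theta, s)[a]:.3f}")
```

The target stays at 2.8 for all 50 updates while the online estimate closes in on it. The target only changes when `theta_target` is explicitly refreshed from `theta`.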

</>Implementation
import torch
import torch.nn as nn
import copy

def create_target_network(online_network):
    """
    Create a target network as a copy of the online network.

    The target network has the same architecture and initial weights,
    but gradients are disabled.
    """
    target_network = copy.deepcopy(online_network)

    # Disable gradients for target network (it's never trained directly)
    for param in target_network.parameters():
        param.requires_grad = False

    return target_network


class DQNWithTargetNetwork(nn.Module):
    """
    DQN with separate online and target networks.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()

        # Online network (trained)
        self.online_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

        # Target network (copied from online, updated periodically)
        self.target_network = create_target_network(self.online_network)

        self.update_count = 0

    def forward(self, state):
        """Get Q-values from online network."""
        return self.online_network(state)

    def get_target_q(self, state):
        """Get Q-values from target network (for computing TD targets)."""
        with torch.no_grad():
            return self.target_network(state)

    def update_target(self):
        """Copy online network weights to target network."""
        self.target_network.load_state_dict(self.online_network.state_dict())


# Example usage
dqn = DQNWithTargetNetwork(state_dim=4, n_actions=2)

# Online and target start with same weights
state = torch.randn(1, 4)
print(f"Initial Q-values match: {torch.allclose(dqn(state), dqn.get_target_q(state))}")

# After training, online network changes but target doesn't
optimizer = torch.optim.Adam(dqn.online_network.parameters())
for _ in range(100):
    optimizer.zero_grad()
    loss = dqn(torch.randn(32, 4)).mean()  # Dummy loss
    loss.backward()
    optimizer.step()

print(f"After training, Q-values match: {torch.allclose(dqn(state), dqn.get_target_q(state))}")

# Update target network
dqn.update_target()
print(f"After update, Q-values match: {torch.allclose(dqn(state), dqn.get_target_q(state))}")

Hard vs Soft Updates

📖Hard Update

A hard update copies all weights from the online network to the target network at once, every $C$ steps.

📖Soft Update

A soft update (Polyak averaging) blends the target network weights with the online network weights every step: $\theta^- \leftarrow \tau \theta + (1 - \tau) \theta^-$, where $\tau$ is a small constant.

Hard Updates (used in original DQN):

  • Update every 10,000 steps: $\theta^- \leftarrow \theta$
  • Target is completely stable between updates
  • Sudden jumps when target is updated
  • Simple to implement and understand

Soft Updates (used in DDPG, SAC, and newer algorithms):

  • Update every step: $\theta^- \leftarrow 0.005\,\theta + 0.995\,\theta^-$
  • Target changes smoothly over time
  • No sudden jumps
  • Requires tuning the blending factor $\tau$ (in place of the update interval $C$)

Both approaches work. Hard updates are simpler and were used in the original DQN paper. Soft updates are more popular in modern algorithms because they avoid sudden changes in the target.

Mathematical Details

Hard update (every $C$ steps):

$$\theta^- \leftarrow \theta$$

The target network is an exact copy of the online network from at most $C$ steps ago.

Soft update (every step):

$$\theta^- \leftarrow \tau \theta + (1 - \tau) \theta^-$$

With soft updates, the target network is an exponential moving average of the online network. The weight that the target at time $T$ places on the online parameters $\theta_t$ from an earlier step $t$ decays geometrically:

$$\text{Weight of } \theta_t \text{ at time } T: \quad \tau (1-\tau)^{T-t}$$

For $\tau = 0.005$, the decay factor after 1000 steps is $(0.995)^{1000} \approx 0.007$, so old weights are mostly forgotten.
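
The decay figures can be reproduced directly. The sketch below also checks that repeated soft updates unroll into an exponential moving average in which each online snapshot $\theta_t$ enters with weight $\tau (1-\tau)^{T-t}$, so its contribution decays as $(1-\tau)^{T-t}$. The scalar "weight trajectory" is invented purely for the check, and the half-life is an extra derived quantity.

```python
import math

tau = 0.005

# Decay factor applied to information that is k soft updates old.
for k in (100, 1000, 5000):
    print(f"(1 - tau)^{k} = {(1 - tau) ** k:.5f}")

# Number of soft updates after which an old contribution has halved.
half_life = math.log(0.5) / math.log(1 - tau)
print(f"half-life: {half_life:.1f} updates")

# Check the moving-average form on a made-up scalar weight trajectory.
T = 300
online = [math.sin(0.05 * t) for t in range(1, T + 1)]
target_init = 1.0

target = target_init
for theta_t in online:
    target = tau * theta_t + (1 - tau) * target     # soft update

closed_form = (1 - tau) ** T * target_init + sum(
    tau * (1 - tau) ** (T - t) * theta_t
    for t, theta_t in enumerate(online, start=1)
)
print(f"recursive: {target:.6f}, closed form: {closed_form:.6f}")
```

With $\tau = 0.005$ the half-life comes out at roughly 138 updates, a useful rule of thumb when comparing a soft-update setting against a hard-update interval.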

</>Implementation
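import copy

import torch
import torch.nn as nn


def create_target_network(online_network):
    """Return a frozen deep copy of the online network (same helper as in the earlier listing)."""
    target_network = copy.deepcopy(online_network)
    for param in target_network.parameters():
        param.requires_grad = False
    return target_network
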

class TargetNetworkMixin:
    """
    Mixin class providing target network functionality.

    Add to your DQN class via multiple inheritance.
    """

    def hard_update_target(self):
        """
        Hard update: copy all weights at once.

        Called periodically (e.g., every 10,000 steps).
        """
        self.target_network.load_state_dict(self.online_network.state_dict())

    def soft_update_target(self, tau=0.005):
        """
        Soft update: blend weights with exponential moving average.

        Called every step.

        Args:
            tau: Blending factor (0.001 - 0.01 typical)
        """
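        # Update in place without tracking gradients (avoids the legacy .data attribute)
        with torch.no_grad():
            for target_param, online_param in zip(
                self.target_network.parameters(),
                self.online_network.parameters()
            ):
                # theta_target <- tau * theta_online + (1 - tau) * theta_target
                target_param.mul_(1.0 - tau).add_(online_param, alpha=tau)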


class DQNAgent(TargetNetworkMixin):
    """
    Complete DQN agent with target network.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128,
                 update_mode='hard', tau=0.005, target_update_freq=10000):
        self.update_mode = update_mode
        self.tau = tau
        self.target_update_freq = target_update_freq
        self.step_count = 0

        # Create networks
        self.online_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
        self.target_network = create_target_network(self.online_network)

    def update_target_if_needed(self):
        """Update target network based on update mode."""
        self.step_count += 1

        if self.update_mode == 'hard':
            if self.step_count % self.target_update_freq == 0:
                self.hard_update_target()
                return True
        elif self.update_mode == 'soft':
            self.soft_update_target(self.tau)
            return True

        return False


# Compare update modes
def compare_update_modes():
    """Visualize difference between hard and soft updates."""
    n_steps = 1000

    # Initialize agents
    hard_agent = DQNAgent(4, 2, update_mode='hard', target_update_freq=200)
    soft_agent = DQNAgent(4, 2, update_mode='soft', tau=0.01)

    # Track target network changes
    state = torch.randn(1, 4)

    hard_history = []
    soft_history = []

    for step in range(n_steps):
        # Simulate online network change
        with torch.no_grad():
            hard_agent.online_network[0].weight.data += 0.01 * torch.randn_like(
                hard_agent.online_network[0].weight.data
            )
            soft_agent.online_network[0].weight.data += 0.01 * torch.randn_like(
                soft_agent.online_network[0].weight.data
            )

        # Update targets
        hard_agent.update_target_if_needed()
        soft_agent.update_target_if_needed()

        # Record target Q-values
        hard_history.append(hard_agent.target_network(state).mean().item())
        soft_history.append(soft_agent.target_network(state).mean().item())

    print("Target network Q-value evolution:")
    print(f"  Hard update (every 200 steps):")
    print(f"    Start: {hard_history[0]:.4f}")
    print(f"    End: {hard_history[-1]:.4f}")
    print(f"    Max change between steps: {max(abs(hard_history[i+1] - hard_history[i]) for i in range(len(hard_history)-1)):.4f}")
    print(f"  Soft update (tau=0.01):")
    print(f"    Start: {soft_history[0]:.4f}")
    print(f"    End: {soft_history[-1]:.4f}")
    print(f"    Max change between steps: {max(abs(soft_history[i+1] - soft_history[i]) for i in range(len(soft_history)-1)):.4f}")

compare_update_modes()

Why Target Networks Work

Target networks work by breaking the feedback loop between Q-values and their targets:

Without target network:

  1. Q-value increases for state A
  2. Target for state B (which leads to A) increases
  3. Q-value for B increases
  4. Target for state C (which leads to B) increases
  5. This chain of increases can propagate and amplify

With target network:

  1. Q-value increases for state A
  2. Target for state B stays the same (uses frozen target network)
  3. Q-value for B increases toward stable target
  4. After 10,000 steps, target network is updated
  5. Now targets shift, but online network has had time to stabilize

The key insight is temporal separation: the target network reflects the network from many steps ago, breaking the immediate feedback loop.

Mathematical Details

Consider the sequence of updates in Q-learning. Without a target network, each gradient-descent update is:

$$\theta_{t+1} = \theta_t - \alpha \nabla_\theta L(\theta_t)$$

where $L$ depends on $Q(\cdot; \theta_t)$ through the target.

The gradient of the target with respect to the parameters,

$$\frac{\partial y}{\partial \theta} = \gamma \nabla_\theta \max_{a'} Q(s', a'; \theta)$$

can be large in magnitude, especially when states are highly connected. This leads to unstable updates.

With a target network, this gradient is zero:

$$\frac{\partial y}{\partial \theta} = \gamma \nabla_\theta \max_{a'} Q(s', a'; \theta^-) = 0$$

since $\theta^-$ is treated as constant.

Between target updates, the update becomes a supervised learning problem where we regress toward fixed targets. This is much more stable, as the dynamics are governed by the loss landscape rather than the complex interdependencies of Q-values.
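
The claim that the target contributes no gradient can be checked with autograd. The following is a minimal sketch, assuming a tiny linear Q-network and a random batch invented for the check. It builds the target from a frozen copy of the network, backpropagates the squared TD error, and inspects which parameters received gradients.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

online = nn.Linear(4, 2)              # tiny stand-in for the online Q-network
target = copy.deepcopy(online)        # frozen copy, standing in for theta^-
for p in target.parameters():
    p.requires_grad_(False)

# Made-up batch of 8 transitions.
s, s_next = torch.randn(8, 4), torch.randn(8, 4)
a = torch.randint(0, 2, (8,))
r = torch.randn(8)
gamma = 0.99

# The target is built only from the frozen parameters.
y = r + gamma * target(s_next).max(dim=1).values
q_sa = online(s).gather(1, a.unsqueeze(1)).squeeze(1)

loss = ((y - q_sa) ** 2).mean()
loss.backward()

target_has_no_grad = all(p.grad is None for p in target.parameters())
online_has_grad = all(p.grad is not None for p in online.parameters())

print(f"target tensor tracks gradients: {y.requires_grad}")
print(f"target network received no gradients: {target_has_no_grad}")
print(f"online network received gradients: {online_has_grad}")
```

Because `y` carries no autograd history, the backward pass treats it exactly like a fixed regression label.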

Update Frequency Considerations

How often should we update the target network?

Too frequent (e.g., every 10 steps):

  • Target is almost the same as online network
  • Little benefit from having a separate target
  • Instability returns

Too infrequent (e.g., every 1 million steps):

  • Target becomes very stale
  • Online network may overshoot
  • Learning becomes very slow

Just right (10,000 - 50,000 steps for Atari):

  • Target is stable enough for meaningful progress
  • Regular updates prevent target from becoming too stale
  • Balances stability and learning speed

The original DQN paper used 10,000 steps, corresponding to 40,000 frames with a frame skip of 4. At 60 frames per second, the target therefore reflects the network from roughly 11 minutes of gameplay ago.
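
The conversion from update interval to gameplay time is simple arithmetic, sketched below. The frame skip of 4 follows from the 10,000-step and 40,000-frame figures above. The 60 frames per second emulator rate is an assumption.

```python
TARGET_UPDATE_STEPS = 10_000   # agent steps between hard updates
FRAME_SKIP = 4                 # emulator frames per agent step
FPS = 60                       # assumed emulator frame rate

frames = TARGET_UPDATE_STEPS * FRAME_SKIP
seconds = frames / FPS
minutes = seconds / 60

print(f"{frames:,} frames = {seconds:.0f} s = {minutes:.1f} min of gameplay")
```

Under these assumptions the interval works out to about 11 minutes, in line with the rough figure quoted above.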

</>Implementation
def analyze_update_frequency():
    """
    Analyze the effect of different target update frequencies.
    """
    frequencies = [10, 100, 1000, 10000, 100000]

    print("Target Update Frequency Analysis")
    print("=" * 50)

    for freq in frequencies:
        # How many updates between target syncs?
        updates_between_syncs = freq

        # How stale is the target?
        # Assuming 1 update per environment step, a frame skip of 4
        # (each step spans 4 emulator frames), and 60 FPS
        staleness_seconds = freq * 4 / 60

        print(f"\nUpdate every {freq:,} steps:")
        print(f"  Updates between syncs: {updates_between_syncs:,}")
        print(f"  Target staleness: ~{staleness_seconds:.1f} seconds of gameplay")

        # Stability vs learning speed tradeoff
        if freq < 100:
            print(f"  Assessment: Too frequent - limited stability benefit")
        elif freq < 5000:
            print(f"  Assessment: Moderate - some stability, fast adaptation")
        elif freq < 50000:
            print(f"  Assessment: Good balance (DQN default range)")
        else:
            print(f"  Assessment: Very stable but slow to adapt")

analyze_update_frequency()

Integration with DQN Training

</>Implementation
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from collections import deque
import random
import copy

class CompleteDQN:
    """
    Complete DQN with experience replay and target network.
    """

    def __init__(self, state_dim, n_actions, hidden_dim=128,
                 buffer_size=100000, batch_size=32, gamma=0.99,
                 lr=1e-4, target_update_freq=1000):

        self.n_actions = n_actions
        self.batch_size = batch_size
        self.gamma = gamma
        self.target_update_freq = target_update_freq

        # Networks
        self.online_network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
        self.target_network = copy.deepcopy(self.online_network)
        for p in self.target_network.parameters():
            p.requires_grad = False

        self.optimizer = optim.Adam(self.online_network.parameters(), lr=lr)

        # Replay buffer
        self.buffer = deque(maxlen=buffer_size)

        self.step_count = 0

    def select_action(self, state, epsilon):
        if np.random.random() < epsilon:
            return np.random.randint(self.n_actions)
        with torch.no_grad():
            q_values = self.online_network(torch.FloatTensor(state).unsqueeze(0))
            return q_values.argmax(dim=1).item()

    def store_transition(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def train_step(self):
        if len(self.buffer) < self.batch_size:
            return None

        # Sample batch
        batch = random.sample(self.buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

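        # Stack into arrays first, then convert with explicit dtypes
        states_t = torch.as_tensor(np.array(states), dtype=torch.float32)
        actions_t = torch.as_tensor(np.array(actions), dtype=torch.int64)
        rewards_t = torch.as_tensor(np.array(rewards), dtype=torch.float32)
        next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32)
        dones_t = torch.as_tensor(np.array(dones), dtype=torch.float32)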

        # Current Q-values (from online network)
        current_q = self.online_network(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)

        # Target Q-values (from target network)
        with torch.no_grad():
            next_q = self.target_network(next_states_t).max(dim=1)[0]
            target_q = rewards_t + self.gamma * next_q * (1 - dones_t)

        # Loss and update
        loss = nn.MSELoss()(current_q, target_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Update target network periodically
        self.step_count += 1
        if self.step_count % self.target_update_freq == 0:
            self.target_network.load_state_dict(self.online_network.state_dict())

        return loss.item()


# Example training
print("Complete DQN Training Example")
print("=" * 40)

dqn = CompleteDQN(state_dim=4, n_actions=2, target_update_freq=100)

# Simulate training
for i in range(1000):
    state = np.random.randn(4)
    action = dqn.select_action(state, epsilon=0.1)
    next_state = np.random.randn(4)
    reward = np.random.randn()
    done = np.random.random() < 0.05

    dqn.store_transition(state, action, reward, next_state, done)
    loss = dqn.train_step()

    if i % 200 == 0 and loss is not None:
        print(f"Step {i}: Loss = {loss:.4f}, Buffer = {len(dqn.buffer)}")

Summary

Target networks stabilize DQN training by providing a fixed target:

  1. Moving target problem: Without target networks, updating Q-values changes the target simultaneously, causing instability

  2. Solution: Maintain a separate target network $\theta^-$ that is updated less frequently

  3. Hard updates: Copy weights every $C$ steps (original DQN)

  4. Soft updates: Blend weights every step with $\tau$ (modern algorithms)

  5. Why it works: Breaks the immediate feedback loop between Q-values and targets, so that between target updates training resembles supervised regression toward fixed targets

Key hyperparameters:

  • Update frequency: 10,000 steps for Atari (with hard updates)
  • Tau: 0.005 is typical for soft updates

Together with experience replay (from the previous section), target networks make deep Q-learning stable enough to learn complex tasks from high-dimensional observations.

ℹ️Note

The combination of experience replay and target networks transforms unstable neural network Q-learning into reliable DQN. These two innovations, simple in hindsight, were the key breakthroughs that launched deep reinforcement learning.