Advanced Topics • Part 3 of 3

Centralized Training, Decentralized Execution

Sharing information during training only

The fundamental challenge of multi-agent RL is that agents must act independently (at execution time they typically can't share their observations) but would benefit from coordinating. Centralized Training with Decentralized Execution (CTDE) addresses this elegantly: share information during training to learn better policies, then deploy agents that act based only on their local observations.

The CTDE Paradigm

📖CTDE

Centralized Training, Decentralized Execution: During training, agents have access to additional information (other agents’ observations, actions, or a global state). During deployment, each agent acts using only its local observation. The extra training information helps learn better decentralized policies.

Think of a soccer team. During practice, the coach can give feedback to everyone, players discuss strategy openly, and they can watch video replay from all camera angles. During the actual game, each player makes split-second decisions based only on what they personally see and hear.

CTDE: practice together, play independently.

The key insight is that what you know during training doesn’t need to match what you know during deployment. Training is a privileged time for learning; deployment is when you apply what you’ve learned with limited information.
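To make the split concrete, here is a minimal sketch, in plain Python, of what each component reads from one stored multi-agent time step. The `JointTransition` container and both helper functions are hypothetical names chosen for illustration. The critic's input layout (global state followed by one one-hot block per agent) mirrors the `CentralizedCritic` used later in this section.

```python
from dataclasses import dataclass
from typing import List


# Hypothetical container for illustration; not part of any library.
@dataclass
class JointTransition:
    state: List[float]               # global state: available only during training
    observations: List[List[float]]  # one local observation per agent
    actions: List[int]               # one discrete action index per agent


def actor_input(step: JointTransition, agent: int) -> List[float]:
    """What agent `agent`'s policy sees: its own observation (training AND deployment)."""
    return list(step.observations[agent])


def critic_input(step: JointTransition, n_actions: int) -> List[float]:
    """What a centralized critic sees (training only): global state + one-hot joint action."""
    one_hot: List[float] = []
    for a in step.actions:
        one_hot.extend(1.0 if k == a else 0.0 for k in range(n_actions))
    return list(step.state) + one_hot


step = JointTransition(state=[0.5, -1.0, 2.0],
                       observations=[[0.5], [2.0]],
                       actions=[1, 0])
print(actor_input(step, 0))             # [0.5]
print(critic_input(step, n_actions=2))  # [0.5, -1.0, 2.0, 0.0, 1.0, 1.0, 0.0]
```

Only `actor_input` is needed at deployment; `critic_input`, and the global state it reads, can be discarded once training ends.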

Why Centralized Training Helps

Independent Training Problems
  • Non-stationarity: other agents keep changing
  • Credit assignment: who’s responsible for team success?
  • Coordination: how to learn complementary roles?
  • Partial observability: can’t see what others see
CTDE Solutions
  • Stationarity: train on joint experiences
  • Credit: centralized critic assigns credit
  • Coordination: shared information during learning
  • Observability: global state available for training
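The non-stationarity problem, and why conditioning on the joint action helps, shows up even in a one-step cooperative matrix game. This is a minimal sketch under simplifying assumptions: there is no state, the payoff table is made up for illustration, and agent 2's policy is changed by hand to mimic learning. An independent learner's value for its own action averages over its teammate's current policy, so its learning target moves as the teammate learns; a critic that sees both actions has a fixed target.

```python
# Illustrative one-step cooperative game: both agents receive PAYOFF[a1][a2].
PAYOFF = [
    [10.0, 0.0],
    [0.0, 5.0],
]


def independent_q(a1, pi2):
    """Agent 1's value for a1 when agent 2 is treated as part of the environment.
    It averages over agent 2's *current* policy pi2, so it moves whenever pi2 changes."""
    return sum(p * PAYOFF[a1][a2] for a2, p in enumerate(pi2))


def centralized_q(a1, a2):
    """A critic conditioned on the joint action: it does not depend on pi2 at all."""
    return PAYOFF[a1][a2]


early_pi2 = [0.1, 0.9]  # agent 2 mostly plays action 1 early in training
late_pi2 = [0.9, 0.1]   # ... and mostly action 0 later on
print([independent_q(a, early_pi2) for a in (0, 1)])  # approximately [1.0, 4.5]: action 1 looks best
print([independent_q(a, late_pi2) for a in (0, 1)])   # approximately [9.0, 0.5]: now action 0 looks best
print(centralized_q(0, 0), centralized_q(1, 1))       # 10.0 5.0, whatever pi2 is
```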

Multi-Agent Actor-Critic (MAAC)

A foundational CTDE method: each agent has a decentralized actor (policy) and a centralized critic (value function).

Mathematical Details

Decentralized Actor: Agent $i$'s policy $\pi_{\theta_i}(a_i \mid o_i)$ depends only on its local observation $o_i$.

Centralized Critic: The critic $Q_\phi(s, a_1, \ldots, a_N)$ sees the global state and all agents' actions.

Policy gradient for agent $i$:

$$\nabla_{\theta_i} J_i = \mathbb{E}\left[ \nabla_{\theta_i} \log \pi_{\theta_i}(a_i \mid o_i) \cdot Q_\phi(s, a_1, \ldots, a_N) \right]$$

The critic is centralized (sees everything), but the actor is decentralized (only uses local observation). During deployment, we discard the critic and keep only the actors.
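For a discrete softmax actor like the `Actor` class below, the per-sample term inside the expectation has a simple closed form with respect to the policy's logits $z$: $\nabla_z \log \pi(a) = \mathbf{1}_a - \pi$, after which backpropagation carries the gradient from the logits into $\theta_i$. The sketch below is plain Python with hypothetical helper names, and a scalar stands in for the critic's $Q_\phi$ value. It computes this gradient and checks it against finite differences.

```python
import math


# Hypothetical helpers for illustration (plain Python, no autograd).
def log_softmax(logits):
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def actor_grad_wrt_logits(logits, action, q_value):
    """Gradient of log pi(action) * Q with respect to the logits of a softmax policy.

    d log pi(a) / d z_k = 1[k == a] - pi(k): the taken action's logit is pushed up
    in proportion to Q (down if Q < 0) and every other logit moves the opposite way.
    """
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return [((1.0 if k == action else 0.0) - p) * q_value for k, p in enumerate(probs)]


def finite_difference_grad(logits, action, q_value, eps=1e-6):
    """Central-difference estimate of the same gradient, as a sanity check."""
    grad = []
    for k in range(len(logits)):
        up = list(logits)
        up[k] += eps
        down = list(logits)
        down[k] -= eps
        diff = log_softmax(up)[action] - log_softmax(down)[action]
        grad.append(diff * q_value / (2 * eps))
    return grad


print(actor_grad_wrt_logits([0.0, 0.0], action=0, q_value=2.0))   # approximately [1.0, -1.0]
print(finite_difference_grad([0.0, 0.0], action=0, q_value=2.0))  # approximately [1.0, -1.0]
```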

The actor is the player on the field. The critic is the coach watching from the sidelines with full visibility. During training, the coach provides feedback (“that was a good move because it set up your teammate”). During the game, the player acts on their own—but they’ve internalized the coach’s lessons.

Implementation
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class Actor(nn.Module):
    """Decentralized actor - uses only local observation."""

    def __init__(self, obs_dim, n_actions, hidden_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, obs):
        return torch.softmax(self.net(obs), dim=-1)

    def get_action(self, obs, deterministic=False):
        probs = self.forward(obs)
        if deterministic:
            return torch.argmax(probs, dim=-1)
        dist = torch.distributions.Categorical(probs)
        return dist.sample()


class CentralizedCritic(nn.Module):
    """Centralized critic - sees global state and all actions."""

    def __init__(self, state_dim, n_agents, n_actions, hidden_dim=256):
        super().__init__()
        # Input: global state + one-hot actions for all agents
        input_dim = state_dim + n_agents * n_actions
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.n_agents = n_agents
        self.n_actions = n_actions

    def forward(self, state, actions):
        """
        Args:
            state: Global state [batch, state_dim]
            actions: Action indices for all agents, either a tensor/array of
                shape [batch, n_agents] or a list with one [batch] tensor per agent
        """
        if isinstance(actions, (list, tuple)):
            actions = torch.stack([torch.as_tensor(a) for a in actions], dim=1)
        actions = torch.as_tensor(actions, dtype=torch.long)

        # One-hot encode per row and agent: [batch, n_agents, n_actions],
        # then flatten to [batch, n_agents * n_actions]
        actions_onehot = nn.functional.one_hot(actions, self.n_actions).float()
        actions_onehot = actions_onehot.view(state.shape[0], -1)

        x = torch.cat([state, actions_onehot], dim=-1)
        return self.net(x)


class MADDPG:
    """
    Decentralized actors with a shared centralized critic.

    A simplified, MADDPG-style sketch. The original MADDPG uses deterministic
    policies and a separate centralized critic per agent; this version uses
    stochastic discrete policies, a score-function (log-prob) actor gradient,
    and one shared critic for a cooperative team.
    """

    def __init__(self, n_agents, obs_dims, state_dim, n_actions,
                 actor_lr=1e-3, critic_lr=1e-3, gamma=0.99, tau=0.01):
        self.n_agents = n_agents
        self.gamma = gamma
        self.tau = tau

        # Each agent has its own actor (decentralized)
        self.actors = [Actor(obs_dims[i], n_actions) for i in range(n_agents)]
        self.actor_targets = [Actor(obs_dims[i], n_actions) for i in range(n_agents)]

        # Shared centralized critic (could also be per-agent)
        self.critic = CentralizedCritic(state_dim, n_agents, n_actions)
        self.critic_target = CentralizedCritic(state_dim, n_agents, n_actions)

        # Copy weights to targets
        for i in range(n_agents):
            self.actor_targets[i].load_state_dict(self.actors[i].state_dict())
        self.critic_target.load_state_dict(self.critic.state_dict())

        # Optimizers
        self.actor_optimizers = [
            optim.Adam(actor.parameters(), lr=actor_lr)
            for actor in self.actors
        ]
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def select_actions(self, observations, explore=True):
        """Select actions for all agents (decentralized)."""
        actions = []
        for i, obs in enumerate(observations):
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
            with torch.no_grad():
                probs = self.actors[i](obs_tensor)
                if explore:
                    dist = torch.distributions.Categorical(probs)
                    action = dist.sample().item()
                else:
                    action = torch.argmax(probs).item()
            actions.append(action)
        return actions

    def update(self, batch):
        """Update the critic and actors using CTDE.

        Assumed batch layout: states/next_states [B, state_dim]; observations and
        next_observations are per-agent lists of [B, obs_dim_i] arrays;
        actions and rewards are [B, n_agents]; dones is [B].
        """
        states, observations, actions, rewards, next_states, next_observations, dones = batch

        states = torch.FloatTensor(states)
        next_states = torch.FloatTensor(next_states)
        actions = torch.as_tensor(actions, dtype=torch.long)  # [B, n_agents]
        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)

        # --- Update Critic ---
        with torch.no_grad():
            # Sample next actions from the target actors (local observations only)
            target_actions = torch.stack([
                self.actor_targets[i].get_action(torch.FloatTensor(next_observations[i]))
                for i in range(self.n_agents)
            ], dim=1)  # [B, n_agents]

            # Compute target Q
            target_q = self.critic_target(next_states, target_actions)  # [B, 1]
            # Use mean reward across agents for cooperative setting
            mean_reward = rewards.mean(dim=1, keepdim=True)
            y = mean_reward + self.gamma * (1 - dones.unsqueeze(1)) * target_q

        # Current Q for the joint action actually taken
        current_q = self.critic(states, actions)
        critic_loss = nn.MSELoss()(current_q, y)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # --- Update Actors ---
        for i in range(self.n_agents):
            obs = torch.FloatTensor(observations[i])
            probs = self.actors[i](obs)
            dist = torch.distributions.Categorical(probs)
            sampled_actions = dist.sample()

            # Q-value of the joint action with agent i's action replaced by a fresh sample
            current_actions = actions.clone()
            current_actions[:, i] = sampled_actions
            with torch.no_grad():
                q_value = self.critic(states, current_actions).squeeze(-1)  # [B]

            # Squeezing Q to [B] avoids a [B] x [B, 1] broadcast to [B, B]
            actor_loss = -(dist.log_prob(sampled_actions) * q_value).mean()

            self.actor_optimizers[i].zero_grad()
            actor_loss.backward()
            self.actor_optimizers[i].step()

        # --- Soft update targets ---
        self._soft_update()

    def _soft_update(self):
        """Soft update target networks."""
        for i in range(self.n_agents):
            for param, target_param in zip(
                self.actors[i].parameters(),
                self.actor_targets[i].parameters()
            ):
                target_param.data.copy_(
                    self.tau * param.data + (1 - self.tau) * target_param.data
                )

        for param, target_param in zip(
            self.critic.parameters(),
            self.critic_target.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )
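The `tau` parameter controls how slowly the target networks in `_soft_update` track the online networks. One way to see its effect is to hold an online weight fixed and apply the same update rule repeatedly. This sketch works on plain Python lists rather than PyTorch parameters, and `soft_update` is a hypothetical helper mirroring that method.

```python
# Hypothetical helper mirroring the update rule in MADDPG._soft_update.
def soft_update(target, online, tau):
    """Polyak averaging: target <- tau * online + (1 - tau) * target, elementwise."""
    return [tau * p + (1.0 - tau) * t for p, t in zip(online, target)]


target = [0.0]
online = [1.0]  # held fixed so only the averaging dynamics are visible
for _ in range(100):
    target = soft_update(target, online, tau=0.01)

# After k updates the remaining gap is (1 - tau) ** k times the initial gap.
print(target[0])                  # approximately 0.634
print(1.0 - (1.0 - 0.01) ** 100)  # closed form, also approximately 0.634
```

With `tau=0.01` (the default in the `MADDPG` class above), the gap shrinks by a factor of $(1-\tau)^k$ after $k$ updates, so closing half of it takes roughly 69 updates. A larger `tau` makes the targets follow the online networks faster but less smoothly.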

Value Decomposition: QMIX

For cooperative games, we often want a single team Q-value that decomposes into individual agent utilities. QMIX learns to combine individual Q-values into a joint Q-value while ensuring that greedy action selection can be done independently.

Mathematical Details

QMIX maintains:

  • Individual Q-values $Q_i(o_i, a_i)$ for each agent
  • A mixing network that combines them: $Q_{tot}(s, a_1, \ldots, a_N) = f(Q_1, \ldots, Q_N, s)$

The key constraint is monotonicity:

$$\frac{\partial Q_{tot}}{\partial Q_i} \geq 0 \quad \forall i$$

This ensures that if each agent greedily maximizes its $Q_i$, the joint action maximizes $Q_{tot}$:

$$\arg\max_{a} Q_{tot}(s, a) = \left(\arg\max_{a_1} Q_1(o_1, a_1), \ldots, \arg\max_{a_N} Q_N(o_N, a_N)\right)$$

This allows decentralized execution: each agent just picks its best individual action.
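The argmax identity can be checked by brute force. The sketch below assumes a single fixed state, so the mixing weights are drawn at random instead of being produced by hypernetworks, and it uses a two-layer mixer with a ReLU hidden layer like the `QMIXMixer` in this section. All function names are hypothetical. With non-negative (absolute-valued) weights, per-agent greedy choices always reach the best $Q_{tot}$; allowing a negative weight breaks the guarantee.

```python
import itertools
import random


# Hypothetical helpers for illustration: one fixed state, random mixer weights.
def relu(x):
    return x if x > 0.0 else 0.0


def mix(qs, w1, b1, w2, b2, enforce_monotonic=True):
    """Q_tot = sum_j |w2[j]| * relu(sum_i |w1[i][j]| * q_i + b1[j]) + b2.

    abs() keeps every weight non-negative, so Q_tot is non-decreasing in each q_i.
    """
    f = abs if enforce_monotonic else (lambda v: v)
    hidden = [relu(sum(f(w1[i][j]) * q for i, q in enumerate(qs)) + b1[j])
              for j in range(len(b1))]
    return sum(f(w2[j]) * h for j, h in enumerate(hidden)) + b2


def greedy_is_optimal(q_tables, params, enforce_monotonic=True):
    """Does per-agent greedy selection reach the best Q_tot over all joint actions?"""
    def q_tot(joint):
        return mix([q[a] for q, a in zip(q_tables, joint)], *params,
                   enforce_monotonic=enforce_monotonic)

    greedy = [max(range(len(q)), key=q.__getitem__) for q in q_tables]
    all_joint = itertools.product(*(range(len(q)) for q in q_tables))
    return q_tot(greedy) >= max(q_tot(joint) for joint in all_joint)


def random_problem(seed, n_agents=3, n_actions=4, embed=8):
    rng = random.Random(seed)
    q_tables = [[rng.uniform(-1, 1) for _ in range(n_actions)] for _ in range(n_agents)]
    params = ([[rng.gauss(0, 1) for _ in range(embed)] for _ in range(n_agents)],  # w1
              [rng.gauss(0, 1) for _ in range(embed)],                             # b1
              [rng.gauss(0, 1) for _ in range(embed)],                             # w2
              rng.gauss(0, 1))                                                     # b2
    return q_tables, params


print(all(greedy_is_optimal(*random_problem(seed)) for seed in range(20)))  # True

# Without abs(), a negative weight breaks the guarantee:
bad_q = [[1.0, 0.0], [0.0, 1.0]]                   # greedy joint action is (0, 1)
bad_params = ([[-1.0], [1.0]], [5.0], [1.0], 0.0)  # hidden = relu(-q_1 + q_2 + 5)
print(greedy_is_optimal(bad_q, bad_params, enforce_monotonic=False))  # False: (1, 1) scores higher
```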

Think of the mixing network as combining “how good does this action look to me?” from each agent into “how good is this joint action for the team?” The monotonicity constraint means that if any one agent's own estimate goes up while the others stay fixed, the team's estimate can only go up or stay the same, never down.

This is a natural fit for fully cooperative games, where agents should never be penalized for individually beneficial actions. It is also a real restriction: QMIX cannot exactly represent payoffs in which an agent's best action depends on which actions its teammates choose.

Implementation
class QMIXAgent(nn.Module):
    """Individual agent network for QMIX."""

    def __init__(self, obs_dim, n_actions, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, obs):
        return self.net(obs)


class QMIXMixer(nn.Module):
    """
    Mixing network for QMIX.

    Combines individual Q-values into Q_tot with monotonicity constraint.
    """

    def __init__(self, n_agents, state_dim, embed_dim=32):
        super().__init__()
        self.n_agents = n_agents

        # Hypernetworks: state -> weights for mixing
        self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
        self.hyper_w2 = nn.Linear(state_dim, embed_dim)

        self.hyper_b1 = nn.Linear(state_dim, embed_dim)
        self.hyper_b2 = nn.Sequential(
            nn.Linear(state_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, 1)
        )

        self.embed_dim = embed_dim

    def forward(self, q_values, state):
        """
        Args:
            q_values: [batch, n_agents] - individual Q-values
            state: [batch, state_dim] - global state

        Returns:
            Q_tot: [batch, 1]
        """
        batch_size = q_values.shape[0]

        # Generate mixing weights (abs ensures monotonicity)
        w1 = torch.abs(self.hyper_w1(state)).view(batch_size, self.n_agents, self.embed_dim)
        w2 = torch.abs(self.hyper_w2(state)).view(batch_size, self.embed_dim, 1)

        b1 = self.hyper_b1(state).view(batch_size, 1, self.embed_dim)
        b2 = self.hyper_b2(state).view(batch_size, 1, 1)

        # Forward through mixing network
        q_values = q_values.view(batch_size, 1, self.n_agents)
        hidden = torch.relu(torch.bmm(q_values, w1) + b1)
        q_tot = torch.bmm(hidden, w2) + b2

        return q_tot.squeeze(-1)


class QMIX:
    """QMIX algorithm for cooperative multi-agent RL."""

    def __init__(self, n_agents, obs_dims, state_dim, n_actions,
                 lr=1e-3, gamma=0.99):
        self.n_agents = n_agents
        self.n_actions = n_actions
        self.gamma = gamma

        # Individual Q-networks
        self.agents = nn.ModuleList([
            QMIXAgent(obs_dims[i], n_actions) for i in range(n_agents)
        ])
        self.target_agents = nn.ModuleList([
            QMIXAgent(obs_dims[i], n_actions) for i in range(n_agents)
        ])

        # Mixing networks
        self.mixer = QMIXMixer(n_agents, state_dim)
        self.target_mixer = QMIXMixer(n_agents, state_dim)

        # Copy to targets
        self.target_agents.load_state_dict(self.agents.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())

        # Single optimizer for all parameters
        params = list(self.agents.parameters()) + list(self.mixer.parameters())
        self.optimizer = optim.Adam(params, lr=lr)

    def select_actions(self, observations, epsilon=0.1):
        """Decentralized action selection."""
        actions = []
        for i, obs in enumerate(observations):
            obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
            with torch.no_grad():
                q_values = self.agents[i](obs_tensor)

            if np.random.random() < epsilon:
                action = np.random.randint(self.n_actions)
            else:
                action = torch.argmax(q_values).item()
            actions.append(action)
        return actions

    def update(self, batch):
        """Update Q-networks using QMIX loss."""
        states, observations, actions, rewards, next_states, next_observations, dones = batch

        states = torch.FloatTensor(states)
        next_states = torch.FloatTensor(next_states)
        rewards = torch.FloatTensor(rewards).mean(dim=1)  # Team reward
        dones = torch.FloatTensor(dones)
        actions = torch.LongTensor(actions)

        # Get current Q-values
        q_values = []
        for i in range(self.n_agents):
            obs = torch.FloatTensor(observations[i])
            q = self.agents[i](obs)
            q_taken = q.gather(1, actions[:, i:i+1]).squeeze(1)
            q_values.append(q_taken)
        q_values = torch.stack(q_values, dim=1)

        # Get Q_tot
        q_tot = self.mixer(q_values, states)

        # Get target Q-values
        with torch.no_grad():
            target_q_values = []
            for i in range(self.n_agents):
                next_obs = torch.FloatTensor(next_observations[i])
                target_q = self.target_agents[i](next_obs)
                target_q_max = target_q.max(dim=1)[0]
                target_q_values.append(target_q_max)
            target_q_values = torch.stack(target_q_values, dim=1)

            target_q_tot = self.target_mixer(target_q_values, next_states)
            y = rewards + self.gamma * (1 - dones) * target_q_tot.squeeze(1)

        # QMIX loss
        loss = nn.MSELoss()(q_tot.squeeze(1), y)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

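        return loss.item()

    def update_targets(self):
        """Copy the online networks into the target networks.

        Call this periodically during training (the interval is a tuning
        choice); otherwise the targets keep their initial weights.
        """
        self.target_agents.load_state_dict(self.agents.state_dict())
        self.target_mixer.load_state_dict(self.mixer.state_dict())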
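The target `y` in `update` takes each agent's maximum target Q-value first and only then mixes them. That shortcut is valid only because the mixer is monotonic: maximizing each input maximizes the mixed output. Below is a single-transition sketch of the same arithmetic with hypothetical helper names. To keep it hand-checkable, a plain sum (the VDN factorization, a special case of a monotonic mixer) stands in for the target mixer.

```python
# Hypothetical single-transition helpers mirroring the arithmetic in QMIX.update.
def qmix_td_target(agent_rewards, next_q_tables, mix, gamma=0.99, done=False):
    """y = team_reward + gamma * (1 - done) * mix(per-agent max of next Q-values).

    agent_rewards: one reward per agent, averaged into a team reward as in QMIX.update
    next_q_tables: each agent's target-network Q-values at its next observation
    mix:           a monotonic mixing function standing in for the target mixer
    """
    team_reward = sum(agent_rewards) / len(agent_rewards)
    # Taking each agent's max before mixing is valid because mix() is monotonic.
    next_max = [max(q) for q in next_q_tables]
    bootstrap = 0.0 if done else gamma * mix(next_max)
    return team_reward + bootstrap


def vdn_mix(qs):
    """The simplest monotonic mixer: a plain sum (the VDN factorization, not QMIX's mixer)."""
    return sum(qs)


y = qmix_td_target([1.0, 0.0], [[0.2, 0.6], [1.0, -0.4]], vdn_mix)
print(y)  # approximately 0.5 + 0.99 * (0.6 + 1.0) = 2.084
print(qmix_td_target([1.0, 0.0], [[0.2, 0.6], [1.0, -0.4]], vdn_mix, done=True))  # 0.5
```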

Self-Play for Competitive Games

For zero-sum competitive games, a powerful approach is self-play: train an agent against copies of itself.

Self-play creates an automatic curriculum:

  1. Initially, the agent plays against a weak version of itself
  2. As it improves, so does its opponent (it’s the same network!)
  3. The agent continuously faces challenges at its current skill level
  4. Over time, both the agent and its “opponent” become stronger

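Self-play was central to AlphaGo and AlphaZero, and to many other game-playing AIs that reached superhuman performance. AlphaZero, in particular, learned from self-play alone, without human game data.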

Mathematical Details

In self-play, the agent trains against its own policy. The experience comes from games where both players use (possibly different snapshots of) the same policy:

$$D = \{(s_t, a_t, r_t, s_{t+1}) : a_t \sim \pi_\theta,\ a_t^{opp} \sim \pi_{\theta'}\}$$

where $\theta'$ is often a past version of $\theta$ (to add diversity and stability).

In two-player zero-sum games, the goal is a Nash equilibrium: a policy that no opponent can exploit, because it guarantees at least the game's value against any opponent strategy.
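Why mix in past versions rather than always playing the latest one? A small rock-paper-scissors experiment illustrates it. The sketch below swaps in classical fictitious play run as self-play; it is not the `SelfPlayTrainer` method that follows, and the function names are hypothetical. Each iteration best-responds to the average of all past versions. The latest policy keeps cycling rock → paper → scissors and stays fully exploitable, while the average over past versions approaches the Nash equilibrium of playing each action with probability 1/3.

```python
# Rock-paper-scissors payoff for the row player: PAYOFF[i][j] = reward of i against j.
PAYOFF = [
    [0, -1, 1],   # rock loses to paper, beats scissors
    [1, 0, -1],   # paper beats rock, loses to scissors
    [-1, 1, 0],   # scissors loses to rock, beats paper
]


def expected_payoffs(mix):
    """Expected reward of each pure action against a mixed strategy `mix`."""
    return [sum(PAYOFF[i][j] * mix[j] for j in range(3)) for i in range(3)]


def exploitability(mix):
    """What a best-responding opponent gains against `mix` (0 at the Nash equilibrium)."""
    return max(expected_payoffs(mix))


def fictitious_self_play(n_iters):
    """Hypothetical helper: each iteration, best-respond to the average of all past versions."""
    counts = [0, 0, 0]  # how often past versions played each action
    latest = 0          # arbitrary first policy: always play rock
    for t in range(n_iters):
        if t > 0:
            pool = [c / t for c in counts]  # the pool of past versions, as one mixed strategy
            payoffs = expected_payoffs(pool)
            latest = max(range(3), key=payoffs.__getitem__)  # pure best response
        counts[latest] += 1
    average = [c / n_iters for c in counts]
    return latest, average


latest, average = fictitious_self_play(5000)
print(exploitability([1.0 if k == latest else 0.0 for k in range(3)]))  # 1.0
print(average)  # each entry close to 1/3
```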

Implementation
class SelfPlayTrainer:
    """
    Self-play training for two-player zero-sum games.
    """

    def __init__(self, agent, opponent_pool_size=10, checkpoint_interval=100):
        self.agent = agent  # The main agent being trained
        self.opponent_pool = []  # Past versions for diversity
        self.opponent_pool_size = opponent_pool_size
        self.checkpoint_interval = checkpoint_interval  # Episodes between snapshots (illustrative default)
        self.episodes_played = 0

    def get_opponent(self):
        """Select opponent - either current agent or past version."""
        if len(self.opponent_pool) == 0 or np.random.random() < 0.5:
            # Play against current self
            return self.agent
        # Play against a uniformly random past version (index directly so numpy
        # never tries to convert the agent objects into an array)
        return self.opponent_pool[np.random.randint(len(self.opponent_pool))]

    def save_checkpoint(self):
        """Save a copy of the current agent to the opponent pool."""
        import copy
        checkpoint = copy.deepcopy(self.agent)

        if len(self.opponent_pool) >= self.opponent_pool_size:
            # Remove oldest
            self.opponent_pool.pop(0)
        self.opponent_pool.append(checkpoint)

    def train_episode(self, env):
        """Play one game of self-play and update.

        Assumes a two-player env exposing reset(), get_observation(player) and a
        Gym-style step() that returns (next_state, rewards, done, info).
        """
        opponent = self.get_opponent()

        state = env.reset()
        episode_data = []

        done = False
        while not done:
            # Agent plays as player 0
            obs_agent = env.get_observation(0)
            action_agent = self.agent.select_action(obs_agent)

            # Opponent plays as player 1
            obs_opp = env.get_observation(1)
            with torch.no_grad():
                action_opp = opponent.select_action(obs_opp)

            # Environment step
            next_state, rewards, done, _ = env.step([action_agent, action_opp])

            episode_data.append({
                'obs': obs_agent,
                'action': action_agent,
                'reward': rewards[0],  # Agent's reward
                'next_obs': env.get_observation(0),
                'done': done
            })

            state = next_state

        # Update agent with episode data
        self.agent.update(episode_data)

        # Periodically snapshot the current agent into the opponent pool
        self.episodes_played += 1
        if self.episodes_played % self.checkpoint_interval == 0:
            self.save_checkpoint()

        return episode_data

Population-Based Training

For even more diversity, use a population of agents that train and compete against each other.

📌OpenAI Five and AlphaStar

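OpenAI Five (Dota 2): Trained mainly by self-play, playing most games against its latest self and the rest (about 20%) against a pool of past versions. Past versions were sampled according to a quality score, so opponents the current agent still had trouble beating came up more often.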

AlphaStar (StarCraft II): Maintained a “league” of diverse agents. Main agents trained against the league, while exploiter agents specifically targeted weaknesses in the main agents and the league. This kept a diverse pool of strategies in play and made it harder for the main agents to settle on an exploitable one.

These systems reached top-level play (OpenAI Five beat the reigning Dota 2 world champions; AlphaStar reached Grandmaster level on the public ladder) in part by avoiding the “arms race” problem, where agents co-evolve narrow counter-strategies.

A single self-play agent might develop blind spots—strategies it never learned to counter because it never used them. A population maintains diversity:

  • Different agents develop different strategies
  • Training against the population means learning to handle diverse opponents
  • The resulting policy is more robust and general
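One way to make training against a population focus on the learner's weaknesses is to sample opponents in proportion to how often the learner loses to them. The sketch below is hypothetical: the weighting `(1 - win_rate) ** p`, the function names, and the strategy labels are illustrative. The scheme is similar in spirit to the prioritized opponent sampling used in AlphaStar's league, but it is not taken from either system's code.

```python
import random


# Hypothetical helpers for illustration; not taken from OpenAI Five or AlphaStar.
def opponent_probabilities(win_rates, p=2.0):
    """Sampling probability for each population member, weighted by (1 - win_rate) ** p.

    win_rates[k] is the learner's estimated win rate against member k, so members
    the learner already beats get little weight and hard opponents get the most.
    """
    weights = [(1.0 - w) ** p for w in win_rates]
    total = sum(weights)
    if total == 0.0:  # the learner beats everyone: fall back to uniform sampling
        return [1.0 / len(win_rates)] * len(win_rates)
    return [w / total for w in weights]


def sample_opponent(population, win_rates, p=2.0, rng=random):
    """Draw one opponent from the population using the weights above."""
    probs = opponent_probabilities(win_rates, p)
    return rng.choices(population, weights=probs, k=1)[0]


population = ["rusher", "turtler", "expander"]  # hypothetical strategy archetypes
win_rates = [0.9, 0.5, 0.1]                     # learner's win rate against each
print(opponent_probabilities(win_rates))        # approximately [0.009, 0.234, 0.757]
print(sample_opponent(population, win_rates, rng=random.Random(0)))
```

The exponent `p` sets how aggressively sampling concentrates on hard opponents; `p=0` recovers uniform sampling over the population.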

Summary

Centralized Training with Decentralized Execution (CTDE) is the dominant paradigm in modern multi-agent RL:

  • Centralized critics provide more stable learning signals by using global information
  • Decentralized actors ensure deployable policies that need only local observations
  • Value decomposition (QMIX) enables efficient credit assignment in cooperative games
  • Self-play creates automatic curricula for competitive games
  • Population-based training provides diversity and robustness

The key insight: exploit all available information during training, but design systems that work with limited information during deployment. This separation of concerns makes multi-agent RL tractable while maintaining practical deployability.