Centralized Training, Decentralized Execution
The fundamental challenge of multi-agent RL is that agents must act independently (they can’t share their observations in real time) but would benefit from coordinating. Centralized Training with Decentralized Execution (CTDE) solves this elegantly: share information during training to learn better policies, then deploy agents that act based only on their local observations.
The CTDE Paradigm
Centralized Training, Decentralized Execution: During training, agents have access to additional information (other agents’ observations, actions, or a global state). During deployment, each agent acts using only its local observation. The extra training information helps learn better decentralized policies.
Think of a soccer team. During practice, the coach can give feedback to everyone, players discuss strategy openly, and they can watch video replay from all camera angles. During the actual game, each player makes split-second decisions based only on what they personally see and hear.
CTDE: practice together, play independently.
The key insight is that what you know during training doesn’t need to match what you know during deployment. Training is a privileged time for learning; deployment is when you apply what you’ve learned with limited information.
Why Centralized Training Helps
- Non-stationarity: other agents keep changing → training on joint experiences restores a stationary learning target
- Credit assignment: who’s responsible for team success? → a centralized critic assigns credit
- Coordination: how to learn complementary roles? → agents share information during learning
- Partial observability: agents can’t see what others see → the global state is available during training
Multi-Agent Actor-Critic (MAAC)
A foundational CTDE method: each agent has a decentralized actor (policy) and a centralized critic (value function).
Decentralized Actor: Agent $i$’s policy $\pi_{\theta_i}(a_i \mid o_i)$ depends only on its local observation $o_i$.
Centralized Critic: The critic $Q_\phi(s, a_1, \dots, a_n)$ sees the global state $s$ and all agents’ actions.
Policy gradient for agent $i$:
$$\nabla_{\theta_i} J(\theta_i) = \mathbb{E}\left[\nabla_{\theta_i} \log \pi_{\theta_i}(a_i \mid o_i)\, Q_\phi(s, a_1, \dots, a_n)\right]$$
The critic is centralized (sees everything), but the actor is decentralized (only uses local observation). During deployment, we discard the critic and keep only the actors.
The actor is the player on the field. The critic is the coach watching from the sidelines with full visibility. During training, the coach provides feedback (“that was a good move because it set up your teammate”). During the game, the player acts on their own—but they’ve internalized the coach’s lessons.
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class Actor(nn.Module):
"""Decentralized actor - uses only local observation."""
def __init__(self, obs_dim, n_actions, hidden_dim=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions)
)
def forward(self, obs):
return torch.softmax(self.net(obs), dim=-1)
def get_action(self, obs, deterministic=False):
probs = self.forward(obs)
if deterministic:
return torch.argmax(probs, dim=-1)
dist = torch.distributions.Categorical(probs)
return dist.sample()
class CentralizedCritic(nn.Module):
"""Centralized critic - sees global state and all actions."""
def __init__(self, state_dim, n_agents, n_actions, hidden_dim=256):
super().__init__()
# Input: global state + one-hot actions for all agents
input_dim = state_dim + n_agents * n_actions
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)
self.n_agents = n_agents
self.n_actions = n_actions
def forward(self, state, actions):
"""
Args:
state: Global state [batch, state_dim]
actions: List of per-agent action-index tensors, each [batch]
"""
# Convert actions to one-hot
batch_size = state.shape[0]
        actions_onehot = torch.zeros(batch_size, self.n_agents * self.n_actions)
        for i, action in enumerate(actions):
            offset = i * self.n_actions
            # One-hot per sample: row b gets a 1 at agent i's chosen action.
            # (Indexing only the column dimension would set entire columns
            # when `action` is a batched tensor.)
            actions_onehot[torch.arange(batch_size), offset + action] = 1
x = torch.cat([state, actions_onehot], dim=-1)
return self.net(x)
class MADDPG:
    """
    Centralized-critic actor-critic in the spirit of MADDPG.
    Simplified for discrete actions: actors are stochastic and updated
    with a policy gradient rather than DDPG's deterministic gradient.
    """
def __init__(self, n_agents, obs_dims, state_dim, n_actions,
actor_lr=1e-3, critic_lr=1e-3, gamma=0.99, tau=0.01):
self.n_agents = n_agents
self.gamma = gamma
self.tau = tau
# Each agent has its own actor (decentralized)
self.actors = [Actor(obs_dims[i], n_actions) for i in range(n_agents)]
self.actor_targets = [Actor(obs_dims[i], n_actions) for i in range(n_agents)]
# Shared centralized critic (could also be per-agent)
self.critic = CentralizedCritic(state_dim, n_agents, n_actions)
self.critic_target = CentralizedCritic(state_dim, n_agents, n_actions)
# Copy weights to targets
for i in range(n_agents):
self.actor_targets[i].load_state_dict(self.actors[i].state_dict())
self.critic_target.load_state_dict(self.critic.state_dict())
# Optimizers
self.actor_optimizers = [
optim.Adam(actor.parameters(), lr=actor_lr)
for actor in self.actors
]
self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)
def select_actions(self, observations, explore=True):
"""Select actions for all agents (decentralized)."""
actions = []
for i, obs in enumerate(observations):
obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
with torch.no_grad():
probs = self.actors[i](obs_tensor)
if explore:
dist = torch.distributions.Categorical(probs)
action = dist.sample().item()
else:
action = torch.argmax(probs).item()
actions.append(action)
return actions
def update(self, batch):
"""Update critics and actors using CTDE."""
        states, observations, actions, rewards, next_states, next_observations, dones = batch
        states = torch.FloatTensor(states)
        next_states = torch.FloatTensor(next_states)
        rewards = torch.FloatTensor(rewards)
        dones = torch.FloatTensor(dones)
        # Per-agent action-index tensors, as expected by the critic
        actions = [torch.as_tensor(a, dtype=torch.long) for a in actions]
# --- Update Critic ---
# Get target actions from target actors
target_actions = []
for i in range(self.n_agents):
next_obs = torch.FloatTensor(next_observations[i])
with torch.no_grad():
target_action = self.actor_targets[i].get_action(next_obs)
target_actions.append(target_action)
# Compute target Q
with torch.no_grad():
target_q = self.critic_target(next_states, target_actions)
# Use mean reward across agents for cooperative setting
mean_reward = rewards.mean(dim=1, keepdim=True)
y = mean_reward + self.gamma * (1 - dones.unsqueeze(1)) * target_q
# Current Q
current_q = self.critic(states, actions)
critic_loss = nn.MSELoss()(current_q, y)
self.critic_optimizer.zero_grad()
critic_loss.backward()
self.critic_optimizer.step()
# --- Update Actors ---
for i in range(self.n_agents):
obs = torch.FloatTensor(observations[i])
probs = self.actors[i](obs)
dist = torch.distributions.Categorical(probs)
sampled_actions = dist.sample()
# For policy gradient, need Q-value of current joint action
current_actions = list(actions)
current_actions[i] = sampled_actions
q_value = self.critic(states, current_actions)
actor_loss = -(dist.log_prob(sampled_actions) * q_value.detach()).mean()
self.actor_optimizers[i].zero_grad()
actor_loss.backward()
self.actor_optimizers[i].step()
# --- Soft update targets ---
self._soft_update()
def _soft_update(self):
"""Soft update target networks."""
for i in range(self.n_agents):
for param, target_param in zip(
self.actors[i].parameters(),
self.actor_targets[i].parameters()
):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
)
for param, target_param in zip(
self.critic.parameters(),
self.critic_target.parameters()
):
target_param.data.copy_(
self.tau * param.data + (1 - self.tau) * target_param.data
                )

Value Decomposition: QMIX
For cooperative games, we often want a single team Q-value that decomposes into individual agent utilities. QMIX learns to combine individual Q-values into a joint Q-value while ensuring that greedy action selection can be done independently.
QMIX maintains:
- Individual Q-values $Q_i(o_i, a_i)$ for each agent
- A mixing network that combines them: $Q_{\text{tot}}(s, \mathbf{a}) = f_{\text{mix}}(Q_1, \dots, Q_n; s)$
The key constraint is monotonicity:
$$\frac{\partial Q_{\text{tot}}}{\partial Q_i} \ge 0 \quad \forall i$$
This ensures that if each agent greedily maximizes its $Q_i$, the joint action maximizes $Q_{\text{tot}}$:
$$\arg\max_{\mathbf{a}} Q_{\text{tot}}(s, \mathbf{a}) = \big(\arg\max_{a_1} Q_1(o_1, a_1), \dots, \arg\max_{a_n} Q_n(o_n, a_n)\big)$$
This allows decentralized execution: each agent just picks its best individual action.
Think of the mixing network as combining “how good does this action look to me?” from each agent into “how good is this action for the team?” The monotonicity constraint means that if an action looks better to any agent, it should look at least as good (never worse) for the team.
This is reasonable for fully cooperative games: agents should never be penalized for individually beneficial actions. Note, though, that monotonicity restricts which joint value functions QMIX can represent; tasks whose optimal payoff depends on agents choosing specific combinations of actions may violate it.
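The argmax-factorization claim can be checked on a toy example. The mixer below is a hand-built stand-in with fixed positive weights (an illustrative assumption; QMIX generates state-dependent weights with hypernetworks). For two agents with three actions each, the joint argmax over all nine joint actions coincides with the pair of individual argmaxes:

```python
import itertools

import numpy as np

# Toy individual Q-values for 2 agents with 3 actions each
q1 = np.array([0.2, 1.5, 0.7])  # agent 1's utilities
q2 = np.array([0.9, 0.1, 1.1])  # agent 2's utilities

def q_tot(a1, a2):
    """Monotonic mixer: positive weights and an increasing nonlinearity,
    so dQ_tot/dQ_i >= 0 for every agent i."""
    w = np.array([0.8, 1.3])  # positive mixing weights
    return np.tanh(w[0] * q1[a1] + w[1] * q2[a2])

# Joint maximization (cost exponential in the number of agents)...
joint_best = max(itertools.product(range(3), range(3)),
                 key=lambda a: q_tot(*a))
# ...matches independent per-agent maximization (linear cost).
indep_best = (int(np.argmax(q1)), int(np.argmax(q2)))
assert joint_best == indep_best  # both are (1, 2)
```

With a non-monotonic mixer (e.g., a negative weight on one agent), the assertion would fail, which is exactly why QMIX constrains the mixing weights to be non-negative.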
class QMIXAgent(nn.Module):
"""Individual agent network for QMIX."""
def __init__(self, obs_dim, n_actions, hidden_dim=64):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, n_actions)
)
def forward(self, obs):
return self.net(obs)
class QMIXMixer(nn.Module):
"""
Mixing network for QMIX.
Combines individual Q-values into Q_tot with monotonicity constraint.
"""
def __init__(self, n_agents, state_dim, embed_dim=32):
super().__init__()
self.n_agents = n_agents
# Hypernetworks: state -> weights for mixing
self.hyper_w1 = nn.Linear(state_dim, n_agents * embed_dim)
self.hyper_w2 = nn.Linear(state_dim, embed_dim)
self.hyper_b1 = nn.Linear(state_dim, embed_dim)
self.hyper_b2 = nn.Sequential(
nn.Linear(state_dim, embed_dim),
nn.ReLU(),
nn.Linear(embed_dim, 1)
)
self.embed_dim = embed_dim
def forward(self, q_values, state):
"""
Args:
q_values: [batch, n_agents] - individual Q-values
state: [batch, state_dim] - global state
Returns:
Q_tot: [batch, 1]
"""
batch_size = q_values.shape[0]
# Generate mixing weights (abs ensures monotonicity)
w1 = torch.abs(self.hyper_w1(state)).view(batch_size, self.n_agents, self.embed_dim)
w2 = torch.abs(self.hyper_w2(state)).view(batch_size, self.embed_dim, 1)
b1 = self.hyper_b1(state).view(batch_size, 1, self.embed_dim)
b2 = self.hyper_b2(state).view(batch_size, 1, 1)
# Forward through mixing network
q_values = q_values.view(batch_size, 1, self.n_agents)
hidden = torch.relu(torch.bmm(q_values, w1) + b1)
q_tot = torch.bmm(hidden, w2) + b2
return q_tot.squeeze(-1)
class QMIX:
"""QMIX algorithm for cooperative multi-agent RL."""
def __init__(self, n_agents, obs_dims, state_dim, n_actions,
lr=1e-3, gamma=0.99):
self.n_agents = n_agents
self.n_actions = n_actions
self.gamma = gamma
# Individual Q-networks
self.agents = nn.ModuleList([
QMIXAgent(obs_dims[i], n_actions) for i in range(n_agents)
])
self.target_agents = nn.ModuleList([
QMIXAgent(obs_dims[i], n_actions) for i in range(n_agents)
])
# Mixing networks
self.mixer = QMIXMixer(n_agents, state_dim)
self.target_mixer = QMIXMixer(n_agents, state_dim)
# Copy to targets
self.target_agents.load_state_dict(self.agents.state_dict())
self.target_mixer.load_state_dict(self.mixer.state_dict())
# Single optimizer for all parameters
params = list(self.agents.parameters()) + list(self.mixer.parameters())
self.optimizer = optim.Adam(params, lr=lr)
def select_actions(self, observations, epsilon=0.1):
"""Decentralized action selection."""
actions = []
for i, obs in enumerate(observations):
obs_tensor = torch.FloatTensor(obs).unsqueeze(0)
with torch.no_grad():
q_values = self.agents[i](obs_tensor)
if np.random.random() < epsilon:
action = np.random.randint(self.n_actions)
else:
action = torch.argmax(q_values).item()
actions.append(action)
return actions
def update(self, batch):
"""Update Q-networks using QMIX loss."""
states, observations, actions, rewards, next_states, next_observations, dones = batch
states = torch.FloatTensor(states)
next_states = torch.FloatTensor(next_states)
rewards = torch.FloatTensor(rewards).mean(dim=1) # Team reward
dones = torch.FloatTensor(dones)
actions = torch.LongTensor(actions)
# Get current Q-values
q_values = []
for i in range(self.n_agents):
obs = torch.FloatTensor(observations[i])
q = self.agents[i](obs)
q_taken = q.gather(1, actions[:, i:i+1]).squeeze(1)
q_values.append(q_taken)
q_values = torch.stack(q_values, dim=1)
# Get Q_tot
q_tot = self.mixer(q_values, states)
# Get target Q-values
with torch.no_grad():
target_q_values = []
for i in range(self.n_agents):
next_obs = torch.FloatTensor(next_observations[i])
target_q = self.target_agents[i](next_obs)
target_q_max = target_q.max(dim=1)[0]
target_q_values.append(target_q_max)
target_q_values = torch.stack(target_q_values, dim=1)
target_q_tot = self.target_mixer(target_q_values, next_states)
y = rewards + self.gamma * (1 - dones) * target_q_tot.squeeze(1)
# QMIX loss
loss = nn.MSELoss()(q_tot.squeeze(1), y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
        return loss.item()

Self-Play for Competitive Games
For zero-sum competitive games, a powerful approach is self-play: train an agent against copies of itself.
Self-play creates an automatic curriculum:
- Initially, the agent plays against a weak version of itself
- As it improves, so does its opponent (it’s the same network!)
- The agent continuously faces challenges at its current skill level
- Over time, both the agent and its “opponent” become stronger
This is how AlphaGo, AlphaZero, and many game-playing AIs achieved superhuman performance.
In self-play, the agent trains against its own policy. The experience comes from games where both players use (possibly different snapshots of) the same policy:
$$\tau \sim (\pi_\theta, \pi_{\theta'})$$
where $\theta'$ is often a past version of $\theta$ (to add diversity and stability).
The goal is to find a Nash equilibrium: a policy that can’t be exploited by any opponent, including itself.
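Exploitability (the best payoff any opponent can earn against a policy) is the quantity a Nash policy drives to zero. A minimal sketch in the simplest zero-sum game, rock-paper-scissors, where the `exploitability` helper is a name introduced for this illustration:

```python
import numpy as np

# Row player's payoff in rock-paper-scissors:
# rows = our action, cols = opponent action; +1 win, -1 loss, 0 draw.
PAYOFF = np.array([[ 0, -1,  1],   # rock
                   [ 1,  0, -1],   # paper
                   [-1,  1,  0]])  # scissors

def exploitability(policy):
    """Best expected payoff an opposing best response earns vs `policy`.
    A Nash policy in a symmetric zero-sum game drives this to zero."""
    opponent_payoffs = -(policy @ PAYOFF)  # opponent payoff per pure action
    return float(opponent_payoffs.max())

uniform = np.full(3, 1 / 3)
rock_heavy = np.array([0.6, 0.2, 0.2])

print(exploitability(uniform))     # 0.0: unexploitable
print(exploitability(rock_heavy))  # ~0.4: opponent counters with paper
```

Any deviation from uniform play creates an exploitable bias; self-play pushes the policy toward the unexploitable mixture because the opponent (a copy of itself) immediately punishes such biases.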
class SelfPlayTrainer:
"""
Self-play training for two-player zero-sum games.
"""
def __init__(self, agent, opponent_pool_size=10):
self.agent = agent # The main agent being trained
self.opponent_pool = [] # Past versions for diversity
self.opponent_pool_size = opponent_pool_size
def get_opponent(self):
"""Select opponent - either current agent or past version."""
if len(self.opponent_pool) == 0 or np.random.random() < 0.5:
# Play against current self
return self.agent
else:
# Play against random past version
return np.random.choice(self.opponent_pool)
def save_checkpoint(self):
"""Save current agent to opponent pool."""
# Deep copy the agent
import copy
checkpoint = copy.deepcopy(self.agent)
if len(self.opponent_pool) >= self.opponent_pool_size:
# Remove oldest
self.opponent_pool.pop(0)
self.opponent_pool.append(checkpoint)
def train_episode(self, env):
"""Play one game of self-play and update."""
opponent = self.get_opponent()
state = env.reset()
episode_data = []
done = False
while not done:
# Agent plays as player 0
obs_agent = env.get_observation(0)
action_agent = self.agent.select_action(obs_agent)
# Opponent plays as player 1
obs_opp = env.get_observation(1)
with torch.no_grad():
action_opp = opponent.select_action(obs_opp)
# Environment step
next_state, rewards, done, _ = env.step([action_agent, action_opp])
episode_data.append({
'obs': obs_agent,
'action': action_agent,
'reward': rewards[0], # Agent's reward
'next_obs': env.get_observation(0),
'done': done
})
state = next_state
# Update agent with episode data
self.agent.update(episode_data)
        # The caller should invoke save_checkpoint() periodically to refresh the pool
        return episode_data

Population-Based Training
For even more diversity, use a population of agents that train and compete against each other.
OpenAI Five (Dota 2) trained a population of agents, used matchmaking to pair agents of similar skill, and periodically updated the population based on performance.
AlphaStar (StarCraft II) maintained a “league” of diverse agents: main agents trained against the league, while exploiter agents specifically targeted weaknesses. This created a diverse pool of strategies and prevented any single exploitable weakness.
These systems achieved superhuman performance by avoiding the “arms race” problem where agents co-evolve narrow counter-strategies.
A single self-play agent might develop blind spots—strategies it never learned to counter because it never used them. A population maintains diversity:
- Different agents develop different strategies
- Training against the population means learning to handle diverse opponents
- The resulting policy is more robust and general
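A minimal sketch of skill-based matchmaking within a population, loosely in the spirit of the systems above. The `PopulationTrainer` class and its Elo-style rating update are illustrative assumptions for this sketch, not details of OpenAI Five or AlphaStar:

```python
import random

class PopulationTrainer:
    """Toy population: pair agents of similar rating, update ratings Elo-style."""

    def __init__(self, agents, k=32):
        self.agents = list(agents)
        self.ratings = {id(a): 1200.0 for a in self.agents}
        self.k = k  # rating step size

    def matchmake(self):
        """Pick one agent, then the opponent with the closest rating."""
        a = random.choice(self.agents)
        others = [b for b in self.agents if b is not a]
        b = min(others, key=lambda o: abs(self.ratings[id(o)] - self.ratings[id(a)]))
        return a, b

    def record_result(self, a, b, score_a):
        """score_a: 1.0 win, 0.5 draw, 0.0 loss for agent a (zero-sum update)."""
        ra, rb = self.ratings[id(a)], self.ratings[id(b)]
        expected_a = 1.0 / (1.0 + 10 ** ((rb - ra) / 400))
        self.ratings[id(a)] = ra + self.k * (score_a - expected_a)
        self.ratings[id(b)] = rb + self.k * (expected_a - score_a)

pool = [f"agent{i}" for i in range(4)]   # stand-ins for trainable agents
trainer = PopulationTrainer(pool)
a, b = trainer.matchmake()
trainer.record_result(a, b, score_a=1.0)  # a beat b: a's rating rises
```

Exploiter-style agents, as in AlphaStar’s league, could be added by matchmaking some members specifically against the current best agent rather than against peers of similar rating.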
Summary
Centralized Training with Decentralized Execution (CTDE) is the dominant paradigm in modern multi-agent RL:
- Centralized critics provide stable learning signals using global information
- Decentralized actors ensure deployable policies that need only local observations
- Value decomposition (QMIX) enables efficient credit assignment in cooperative games
- Self-play creates automatic curricula for competitive games
- Population-based training provides diversity and robustness
The key insight: exploit all available information during training, but design systems that work with limited information during deployment. This separation of concerns makes multi-agent RL tractable while maintaining practical deployability.