MuZero and Beyond
Traditional model-based RL learns to predict raw observations: given state and action, predict the next state. But do we really need to predict every detail of the world to plan well? MuZero (Schrittwieser et al., 2020) shows that we don’t—we can learn abstract models optimized for planning, even without knowing the rules of the game.
The Evolution: From AlphaGo to MuZero
The journey to MuZero tells a story of progressive abstraction:
AlphaGo (2016): Used Monte Carlo Tree Search with a known game model. Knew the rules of Go perfectly. Could simulate any game position.
AlphaZero (2017): Generalized to chess and shogi. Still needed the rules. Planning worked because it could perfectly simulate game outcomes.
MuZero (2020): Dropped the need for rules entirely. Learned a model from scratch—not to predict observations, but to plan effectively. Achieved superhuman performance in Go, chess, shogi, AND Atari games.
The key insight: you don’t need to predict what the next frame of an Atari game looks like. You just need to predict what matters for making decisions: rewards and values.
Why Abstract Models?
An abstract model learns to predict in a latent space rather than observation space. Instead of predicting “what will I see next?”, it predicts “what will the value and policy look like from here?”
The model doesn’t try to reconstruct pixels—it learns representations useful for planning.
When you plan your commute, you don’t simulate every traffic light and pedestrian. You think in abstract terms: “The highway is probably congested, so I’ll take surface streets. That route usually takes 25 minutes.”
Your mental model is abstract—optimized for making decisions, not for reconstructing reality. MuZero learns this kind of model automatically.
Traditional models that predict raw observations face two challenges:
- Wasted capacity: Most pixels don’t matter for decision-making. The color of the sky in a driving game doesn’t affect whether you should brake.
- Compounding errors: Small observation errors compound over multi-step planning. After 20 steps, the predicted game state might look nothing like reality.
Abstract models sidestep both problems by only learning what matters for planning.
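The compounding-error point above can be made concrete with a toy calculation. The multiplicative-error model and the helper below are illustrative only, not something from the MuZero paper:

```python
def compounded_error(per_step_error, horizon):
    """Growth of a relative error that compounds multiplicatively each step."""
    return (1 + per_step_error) ** horizon

# A modest 5% per-step model error balloons over a 20-step rollout:
growth = compounded_error(0.05, 20)
print(f"{growth:.2f}x")  # roughly 2.65x
```

Even a model that is 95% accurate per step ends up dominated by its own errors over a moderate planning horizon, which is why predicting only decision-relevant quantities is attractive.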
The MuZero Architecture
MuZero learns three neural networks that work together:
1. Representation function h: Maps observations to latent states
2. Dynamics function g: Predicts the next latent state and reward
3. Prediction function f: Predicts policy and value from a latent state
The key insight: the dynamics function operates entirely in latent space. It never needs to reconstruct observations—only to produce latent states that support accurate policy and value predictions.
Think of it like this:
- Representation: “Here’s what I see. Let me encode it into something useful.”
- Dynamics: “If I take this action, here’s what my internal state becomes, and here’s the reward.”
- Prediction: “From this internal state, here’s what I should do and how good things are.”
The dynamics function is the learned model—but it predicts internal states, not observations. Those internal states are only trained to be useful for predicting values and policies.
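In the notation of the MuZero paper (with shared parameters $\theta$), an unrolled plan composes the three functions as:

$$s^0 = h_\theta(o_t), \qquad r^k, s^k = g_\theta(s^{k-1}, a^k), \qquad p^k, v^k = f_\theta(s^k)$$

The latent states $s^k$ appear only as inputs to $g_\theta$ and $f_\theta$; no term ever asks them to resemble observations.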
Planning with MCTS
MuZero uses Monte Carlo Tree Search (MCTS) to plan. But instead of simulating in the real environment, it simulates in its learned latent space.
At each decision point:
- Encode the current observation into a root latent state with the representation function
- Run MCTS for a fixed budget of simulations:
- Traverse the tree using UCB-style selection
- Expand nodes using the dynamics function
- Evaluate leaf nodes using the prediction function
- Backup values through the tree
- Select action proportional to visit counts at the root
The key difference from traditional MCTS: the game rules are replaced by the learned dynamics function.
import torch
import torch.nn as nn
import numpy as np
class MuZeroNetworks(nn.Module):
"""
Simplified MuZero networks.
In practice, these would be much larger and use residual blocks.
"""
def __init__(self, obs_dim, action_dim, latent_dim=128, hidden_dim=256):
super().__init__()
self.action_dim = action_dim
self.latent_dim = latent_dim
# Representation: observation -> latent state
self.representation = nn.Sequential(
nn.Linear(obs_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim)
)
# Dynamics: (latent, action) -> (next_latent, reward)
self.dynamics = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
self.dynamics_state = nn.Linear(hidden_dim, latent_dim)
self.dynamics_reward = nn.Linear(hidden_dim, 1)
# Prediction: latent -> (policy, value)
self.prediction_base = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
)
self.policy_head = nn.Linear(hidden_dim, action_dim)
self.value_head = nn.Linear(hidden_dim, 1)
def initial_inference(self, observation):
"""First step: encode observation and predict policy/value."""
latent = self.representation(observation)
pred_features = self.prediction_base(latent)
policy_logits = self.policy_head(pred_features)
value = self.value_head(pred_features)
return latent, policy_logits, value
def recurrent_inference(self, latent, action_onehot):
"""Subsequent steps: apply dynamics and predict policy/value."""
# Dynamics
x = torch.cat([latent, action_onehot], dim=-1)
features = self.dynamics(x)
next_latent = self.dynamics_state(features)
reward = self.dynamics_reward(features)
# Prediction
pred_features = self.prediction_base(next_latent)
policy_logits = self.policy_head(pred_features)
value = self.value_head(pred_features)
return next_latent, reward, policy_logits, value
class MCTSNode:
"""A node in the MCTS tree."""
def __init__(self, prior, latent_state=None):
self.prior = prior # P(a) from parent
self.latent_state = latent_state
self.visit_count = 0
self.value_sum = 0
self.reward = 0
self.children = {}
def value(self):
if self.visit_count == 0:
return 0
return self.value_sum / self.visit_count
def expanded(self):
return len(self.children) > 0
def mcts_search(networks, observation, num_simulations=50, gamma=0.99):
"""
Run MCTS from the current observation.
Returns action probabilities based on visit counts.
"""
# Initial inference
with torch.no_grad():
obs_tensor = torch.FloatTensor(observation).unsqueeze(0)
latent, policy_logits, value = networks.initial_inference(obs_tensor)
# Create root node
policy = torch.softmax(policy_logits, dim=-1).squeeze().numpy()
root = MCTSNode(prior=0, latent_state=latent)
# Initialize children from policy prior
for a in range(networks.action_dim):
root.children[a] = MCTSNode(prior=policy[a])
# Run simulations
for _ in range(num_simulations):
node = root
search_path = [node]
action_history = []
# Selection: traverse tree using UCB
while node.expanded():
action, child = select_child(node)
action_history.append(action)
search_path.append(child)
node = child
# Expansion: use dynamics to expand
parent = search_path[-2]
action = action_history[-1]
with torch.no_grad():
action_onehot = torch.zeros(1, networks.action_dim)
action_onehot[0, action] = 1
next_latent, reward, policy_logits, value = networks.recurrent_inference(
parent.latent_state, action_onehot
)
node.latent_state = next_latent
node.reward = reward.item()
# Create children for new node
policy = torch.softmax(policy_logits, dim=-1).squeeze().numpy()
for a in range(networks.action_dim):
node.children[a] = MCTSNode(prior=policy[a])
# Backup
backup(search_path, value.item(), gamma)
# Return action probabilities from visit counts
visits = np.array([root.children[a].visit_count for a in range(networks.action_dim)])
return visits / visits.sum()
def select_child(node, c_puct=1.0):
    """Select the child with the highest PUCT score."""
    total_visits = sum(child.visit_count for child in node.children.values())
    best_score = -float('inf')
    best_action, best_child = 0, None
    for action, child in node.children.items():
        # PUCT: mean value estimate plus a prior-weighted exploration bonus.
        # The +1 keeps the bonus nonzero before any child has been visited,
        # so the priors (rather than an arbitrary tie-break) drive the first picks.
        # Note child.value() already returns 0 for unvisited children.
        exploration = c_puct * child.prior * np.sqrt(total_visits + 1) / (1 + child.visit_count)
        score = child.value() + exploration
        if score > best_score:
            best_score = score
            best_action, best_child = action, child
    return best_action, best_child
def backup(search_path, value, gamma):
"""Backup value through the search path."""
for node in reversed(search_path):
node.value_sum += value
node.visit_count += 1
        value = node.reward + gamma * value
Training MuZero
MuZero is trained to match its predictions to observed outcomes. For each trajectory in the replay buffer:
- Unroll the model for K steps using the dynamics function
- Predict policy, value, and reward at each step
- Compare to targets:
- Policy target: MCTS search policy
- Value target: n-step return
- Reward target: observed reward
Loss function:

$$\mathcal{L}_t(\theta) = \sum_{k=0}^{K} \Big[ \ell^p\big(\pi_{t+k}, p^k_t\big) + \ell^v\big(z_{t+k}, v^k_t\big) + \ell^r\big(u_{t+k}, r^k_t\big) \Big]$$

where $\ell^p$ is a cross-entropy loss for the policy, and $\ell^v$ and $\ell^r$ are mean squared errors for the value and reward ($\pi$, $z$, and $u$ are the search-policy, return, and reward targets above).
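This unrolled loss can be sketched in a few lines. The snippet assumes a `networks` object exposing the `initial_inference`/`recurrent_inference` interface from the code above; `TinyMuZero` is a deliberately minimal stand-in so the sketch runs on its own, and the batch layout (`target_policies`, `target_values`, `target_rewards`) is an illustrative choice, not the paper's exact pipeline:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMuZero(nn.Module):
    """Minimal stand-in with the same interface as MuZeroNetworks above."""
    def __init__(self, obs_dim=4, action_dim=2, latent_dim=8):
        super().__init__()
        self.action_dim = action_dim
        self.repr_net = nn.Linear(obs_dim, latent_dim)
        self.dyn_net = nn.Linear(latent_dim + action_dim, latent_dim + 1)
        self.pred_net = nn.Linear(latent_dim, action_dim + 1)

    def initial_inference(self, obs):
        latent = self.repr_net(obs)
        out = self.pred_net(latent)
        return latent, out[:, :self.action_dim], out[:, self.action_dim:]

    def recurrent_inference(self, latent, action_onehot):
        out = self.dyn_net(torch.cat([latent, action_onehot], dim=-1))
        next_latent, reward = out[:, :-1], out[:, -1:]
        pred = self.pred_net(next_latent)
        return next_latent, reward, pred[:, :self.action_dim], pred[:, self.action_dim:]

def muzero_loss(networks, batch, K=5):
    """Unrolled training loss: cross-entropy for policy, MSE for value/reward."""
    latent, policy_logits, value = networks.initial_inference(batch["observations"])
    loss = (F.cross_entropy(policy_logits, batch["target_policies"][:, 0])
            + F.mse_loss(value.squeeze(-1), batch["target_values"][:, 0]))
    for k in range(K):
        a = F.one_hot(batch["actions"][:, k], networks.action_dim).float()
        latent, reward, policy_logits, value = networks.recurrent_inference(latent, a)
        loss = loss + (F.cross_entropy(policy_logits, batch["target_policies"][:, k + 1])
                       + F.mse_loss(value.squeeze(-1), batch["target_values"][:, k + 1])
                       + F.mse_loss(reward.squeeze(-1), batch["target_rewards"][:, k]))
        # MuZero halves the gradient flowing back through each unroll step
        # (forward value is unchanged; only the backward signal is scaled)
        latent = 0.5 * latent + 0.5 * latent.detach()
    return loss / (K + 1)
```

Note that the latent state is never compared to an observation anywhere in this loss; the representation is shaped entirely by the policy, value, and reward targets.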
Notice that MuZero never explicitly trains the latent states to match observations. The latent space is only shaped by the requirement that it supports accurate policy, value, and reward predictions.
This is what makes the model “abstract”—it learns whatever representation is useful for planning, not whatever is useful for reconstruction.
MuZero Results
MuZero achieved remarkable results across diverse domains:
- Matched AlphaZero in Go, chess, and shogi—without knowing the rules.
- Set a new state of the art on the Atari-57 benchmark, surpassing model-free methods.
- Achieved strong performance with far fewer environment interactions than model-free alternatives.
When to Use Model-Based Methods
Model-based RL isn’t always the right choice. Here’s a framework for deciding:
Use model-based when:
- Real interactions are expensive, slow, or risky
- The environment has exploitable structure
- You need sample efficiency
- Planning at test time is acceptable
Prefer model-free when:
- You have unlimited simulation access
- Environment dynamics are extremely complex or chaotic
- Real-time decisions are critical (planning is slow)
- Model errors would compound unacceptably
MuZero-style methods are best when:
- You want model-based benefits without knowing the rules
- The observation space is complex but the underlying dynamics are learnable
- You can afford computational cost at decision time
Beyond MuZero: Current Frontiers
Model-based RL continues to evolve:
Dreamer (Hafner et al., 2020, 2021): Learns world models for continuous control, training policies entirely in imagination.
Decision Transformer (Chen et al., 2021): Treats RL as sequence modeling, learning to predict actions conditioned on desired returns.
IRIS (Micheli et al., 2023): Combines transformers with discrete latent states for efficient world modeling.
Genie (Bruce et al., 2024): Learns generative world models from video that can be prompted to create new environments.
The field is moving toward more expressive, more efficient, and more general world models.
Summary
MuZero represents a paradigm shift in model-based RL:
- No rules required: Learns models that enable planning without knowing environment rules
- Abstract representations: Latent dynamics optimized for planning, not observation prediction
- MCTS planning: Combines learned models with principled search
- Strong results: Achieves superhuman performance across board games and video games
The key insight: you don’t need to model the world accurately—you need to model what matters for decision-making. This principle underlies not just MuZero but much of the frontier of model-based RL research.
Model-based RL offers a path to sample-efficient learning. As models become more powerful and compute becomes cheaper, expect these methods to play an increasingly important role in real-world RL applications.