Advanced Topics • Part 4 of 4

MuZero and Beyond

Learning models for planning without rules

Traditional model-based RL learns to predict raw observations: given state and action, predict the next state. But do we really need to predict every detail of the world to plan well? MuZero (Schrittwieser et al., 2020) shows that we don’t—we can learn abstract models optimized for planning, even without knowing the rules of the game.

The Evolution: From AlphaGo to MuZero

The journey to MuZero tells a story of progressive abstraction:

AlphaGo (2016): Used Monte Carlo Tree Search with a known game model. Knew the rules of Go perfectly. Could simulate any game position.

AlphaZero (2017): Generalized to chess and shogi. Still needed the rules. Planning worked because it could perfectly simulate game outcomes.

MuZero (2020): Dropped the need for rules entirely. Learned a model from scratch, not to predict observations but to support effective planning. Matched AlphaZero's superhuman performance in Go, chess, and shogi, and set a new state of the art on Atari games.

The key insight: you don’t need to predict what the next frame of an Atari game looks like. You only need to predict what matters for making decisions: rewards, values, and policies.

Why Abstract Models?

📖Abstract Model

An abstract model learns to predict in a latent space rather than observation space. Instead of predicting “what will I see next?”, it predicts “what will the value and policy look like from here?”

The model doesn’t try to reconstruct pixels—it learns representations useful for planning.

When you plan your commute, you don’t simulate every traffic light and pedestrian. You think in abstract terms: “The highway is probably congested, so I’ll take surface streets. That route usually takes 25 minutes.”

Your mental model is abstract—optimized for making decisions, not for reconstructing reality. MuZero learns this kind of model automatically.

The MuZero Architecture

MuZero learns three neural networks that work together:

Mathematical Details

1. Representation function $h$: maps observations to a latent state, $s^0 = h_\theta(o_1, \ldots, o_t)$

2. Dynamics function $g$: predicts the next latent state and reward, $r^k, s^k = g_\theta(s^{k-1}, a^k)$

3. Prediction function $f$: predicts policy and value from a latent state, $p^k, v^k = f_\theta(s^k)$

Crucially, the dynamics function $g$ operates entirely in latent space. It never needs to reconstruct observations, only to produce latent states that support accurate reward, policy, and value predictions.
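To make the composition concrete, here is a minimal sketch of how the three functions chain together when the model imagines a trajectory. The tiny random linear maps, the dimensions, and the names `W_h`, `W_g`, `unroll`, etc. are hypothetical stand-ins for trained networks; only the wiring $s^0 = h(o)$, $(r^k, s^k) = g(s^{k-1}, a^k)$, $(p^k, v^k) = f(s^k)$ follows the definitions above.

```python
import numpy as np

rng = np.random.default_rng(0)
OBS_DIM, LATENT_DIM, NUM_ACTIONS = 4, 8, 3

# Hypothetical random weights standing in for trained networks
W_h = rng.normal(size=(LATENT_DIM, OBS_DIM))
W_g = rng.normal(size=(LATENT_DIM, LATENT_DIM + NUM_ACTIONS))
w_r = rng.normal(size=LATENT_DIM)
W_p = rng.normal(size=(NUM_ACTIONS, LATENT_DIM))
w_v = rng.normal(size=LATENT_DIM)


def h(obs):
    """Representation: observation -> initial latent state s^0."""
    return np.tanh(W_h @ obs)


def g(state, action):
    """Dynamics: (s^{k-1}, a^k) -> (r^k, s^k), entirely in latent space."""
    x = np.concatenate([state, np.eye(NUM_ACTIONS)[action]])  # one-hot action
    next_state = np.tanh(W_g @ x)
    reward = float(w_r @ next_state)
    return reward, next_state


def f(state):
    """Prediction: s^k -> (p^k, v^k)."""
    logits = W_p @ state
    policy = np.exp(logits - logits.max())  # stable softmax
    policy /= policy.sum()
    return policy, float(w_v @ state)


def unroll(obs, actions):
    """Imagine a trajectory for a fixed action sequence; no environment calls."""
    state = h(obs)
    policy, value = f(state)
    outputs = [(None, policy, value)]  # no reward is predicted at k = 0
    for a in actions:
        reward, state = g(state, a)
        policy, value = f(state)
        outputs.append((reward, policy, value))
    return outputs


steps = unroll(np.ones(OBS_DIM), actions=[0, 2, 1])
for k, (r, p, v) in enumerate(steps):
    print(k, r, p.round(3), round(v, 3))
```

Note that no step ever maps a latent state back to an observation: the only outputs are rewards, policies, and values, which is exactly what the training loss constrains.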

Think of it like this:

  • Representation: “Here’s what I see. Let me encode it into something useful.”
  • Dynamics: “If I take this action, here’s what my internal state becomes, and here’s the reward.”
  • Prediction: “From this internal state, here’s what I should do and how good things are.”

The dynamics function is the learned model—but it predicts internal states, not observations. Those internal states are only trained to be useful for predicting values and policies.

MuZero Architecture (figure): observation $o_t$ is encoded by $h$ into latent state $s^0$; given action $a$, $g$ produces the next latent state $s^1$ and reward $r^1$; $f$ maps $s^1$ to policy $p^1$ and value $v^1$.

Planning with MCTS

MuZero uses Monte Carlo Tree Search (MCTS) to plan. But instead of simulating in the real environment, it simulates in its learned latent space.

Mathematical Details

At each decision point:

  1. Encode the current observation: $s^0 = h_\theta(o_t)$
  2. Run MCTS for $N$ simulations:
    • Traverse the tree using UCB-style selection
    • Expand nodes using the dynamics function $g$
    • Evaluate leaf nodes using the prediction function $f$
    • Back up values through the tree
  3. Select an action in proportion to the visit counts at the root

The key difference from traditional MCTS: the game rules are replaced by the learned dynamics function $g$.
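Step 3 turns root visit counts into the action actually played. A common way to do this, and the one the MuZero paper describes, is to sample from the visit-count distribution raised to a temperature $1/T$, lowering $T$ over training in Atari. The sketch below assumes that scheme; the function names `visit_count_policy` and `select_action`, and treating $T = 0$ as "pick the most-visited action", are illustrative choices rather than a fixed API.

```python
import numpy as np


def visit_count_policy(visit_counts, temperature=1.0):
    """Distribution over root actions proportional to N(a)^(1/T)."""
    counts = np.asarray(visit_counts, dtype=np.float64)
    if temperature == 0:
        # Greedy limit: all mass on the most-visited action(s)
        policy = (counts == counts.max()).astype(np.float64)
    else:
        policy = counts ** (1.0 / temperature)
    return policy / policy.sum()


def select_action(visit_counts, temperature=1.0, rng=None):
    """Sample the action to execute in the real environment."""
    rng = np.random.default_rng() if rng is None else rng
    policy = visit_count_policy(visit_counts, temperature)
    return int(rng.choice(len(policy), p=policy))


print(visit_count_policy([10, 30, 10]))        # T = 1: plain visit fractions
print(visit_count_policy([10, 30, 10], 0.5))   # sharper: counts are squared
print(visit_count_policy([10, 30, 10], 0))     # greedy
```

At $T = 1$ this reproduces the `visits / visits.sum()` distribution returned by `mcts_search`; smaller temperatures concentrate play on the most-visited action, trading exploration for exploitation.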

</>Implementation
import torch
import torch.nn as nn
import numpy as np

class MuZeroNetworks(nn.Module):
    """
    Simplified MuZero networks.

    In practice, these would be much larger and use residual blocks.
    """

    def __init__(self, obs_dim, action_dim, latent_dim=128, hidden_dim=256):
        super().__init__()
        self.action_dim = action_dim
        self.latent_dim = latent_dim

        # Representation: observation -> latent state
        self.representation = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )

        # Dynamics: (latent, action) -> (next_latent, reward)
        self.dynamics = nn.Sequential(
            nn.Linear(latent_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.dynamics_state = nn.Linear(hidden_dim, latent_dim)
        self.dynamics_reward = nn.Linear(hidden_dim, 1)

        # Prediction: latent -> (policy, value)
        self.prediction_base = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
        )
        self.policy_head = nn.Linear(hidden_dim, action_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

    def initial_inference(self, observation):
        """First step: encode observation and predict policy/value."""
        latent = self.representation(observation)
        pred_features = self.prediction_base(latent)
        policy_logits = self.policy_head(pred_features)
        value = self.value_head(pred_features)
        return latent, policy_logits, value

    def recurrent_inference(self, latent, action_onehot):
        """Subsequent steps: apply dynamics and predict policy/value."""
        # Dynamics
        x = torch.cat([latent, action_onehot], dim=-1)
        features = self.dynamics(x)
        next_latent = self.dynamics_state(features)
        reward = self.dynamics_reward(features)

        # Prediction
        pred_features = self.prediction_base(next_latent)
        policy_logits = self.policy_head(pred_features)
        value = self.value_head(pred_features)

        return next_latent, reward, policy_logits, value


class MCTSNode:
    """A node in the MCTS tree."""

    def __init__(self, prior, latent_state=None):
        self.prior = prior  # P(a) from parent
        self.latent_state = latent_state
        self.visit_count = 0
        self.value_sum = 0
        self.reward = 0
        self.children = {}

    def value(self):
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expanded(self):
        return len(self.children) > 0


def mcts_search(networks, observation, num_simulations=50, gamma=0.99):
    """
    Run MCTS from the current observation.

    Returns action probabilities based on visit counts.
    """
    # Initial inference
    with torch.no_grad():
        obs_tensor = torch.FloatTensor(observation).unsqueeze(0)
        latent, policy_logits, value = networks.initial_inference(obs_tensor)

    # Create root node
    policy = torch.softmax(policy_logits, dim=-1).squeeze().numpy()
    root = MCTSNode(prior=0, latent_state=latent)

    # Initialize children from policy prior
    for a in range(networks.action_dim):
        root.children[a] = MCTSNode(prior=policy[a])

    # Run simulations
    for _ in range(num_simulations):
        node = root
        search_path = [node]
        action_history = []

        # Selection: descend with the pUCT rule until reaching an unexpanded node
        while node.expanded():
            action, child = select_child(node, gamma)
            action_history.append(action)
            search_path.append(child)
            node = child

        # Expansion: the root is always expanded, so search_path has at least two nodes
        parent = search_path[-2]
        action = action_history[-1]

        with torch.no_grad():
            action_onehot = torch.zeros(1, networks.action_dim)
            action_onehot[0, action] = 1
            next_latent, reward, policy_logits, value = networks.recurrent_inference(
                parent.latent_state, action_onehot
            )

        node.latent_state = next_latent
        node.reward = reward.item()

        # Create children for the new node, using the predicted policy as priors
        policy = torch.softmax(policy_logits, dim=-1).squeeze(0).numpy()
        for a in range(networks.action_dim):
            node.children[a] = MCTSNode(prior=policy[a])

        # Backup: propagate the predicted leaf value up the search path
        backup(search_path, value.item(), gamma)

    # Return action probabilities from visit counts at the root
    visits = np.array([root.children[a].visit_count for a in range(networks.action_dim)])
    return visits / visits.sum()


def select_child(node, gamma=0.99, c_puct=1.0):
    """Select the child with the highest pUCT score."""
    best_score = -float('inf')
    best_action = 0
    best_child = None

    # Total child visits, used in place of the parent's visit count N(s)
    total_visits = sum(child.visit_count for child in node.children.values())

    for action, child in node.children.items():
        # Q(s, a) = r + gamma * V(s'): include the reward predicted on this edge.
        # (The MuZero paper also min-max normalizes Q; omitted here for simplicity.)
        if child.visit_count == 0:
            q_value = 0
        else:
            q_value = child.reward + gamma * child.value()

        # Exploration bonus: large for high-prior, rarely visited children
        exploration = c_puct * child.prior * np.sqrt(total_visits) / (1 + child.visit_count)
        score = q_value + exploration

        if score > best_score:
            best_score = score
            best_action = action
            best_child = child

    return best_action, best_child


def backup(search_path, value, gamma):
    """Backup value through the search path."""
    for node in reversed(search_path):
        node.value_sum += value
        node.visit_count += 1
        value = node.reward + gamma * value
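The simplified `select_child` above adds Q values directly to the exploration bonus, so in environments with large or unbounded rewards the Q term can swamp exploration. The MuZero paper addresses this by normalizing Q estimates to [0, 1] using the minimum and maximum values observed in the search tree so far. Below is a minimal sketch of such a helper; the class name `MinMaxStats` and the suggested wiring (update it with each Q estimate during backup, normalize Q before adding the exploration bonus) are illustrative assumptions, not the paper's exact code.

```python
class MinMaxStats:
    """Track the range of Q values seen in the tree and rescale them to [0, 1]."""

    def __init__(self):
        self.minimum = float("inf")
        self.maximum = float("-inf")

    def update(self, value):
        # Widen the observed range to include this value
        self.minimum = min(self.minimum, value)
        self.maximum = max(self.maximum, value)

    def normalize(self, value):
        # Until two distinct values have been seen, leave the value unchanged
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


stats = MinMaxStats()
for q in [2.0, 4.0]:
    stats.update(q)
print(stats.normalize(3.0))  # midway through the observed range
```

One instance would be created per search, so the normalization adapts to the reward scale of whatever position is being searched.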

Training MuZero

Mathematical Details

MuZero is trained to match its predictions to observed outcomes. For each position $t$ sampled from a trajectory in the replay buffer:

  1. Unroll the model for $K$ steps using the dynamics function
  2. Predict policy, value, and reward at each step
  3. Compare to targets:
    • Policy target: the MCTS search policy $\pi_{t+k}$
    • Value target: the $n$-step return $z_{t+k}$
    • Reward target: the observed reward $u_{t+k}$
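The $n$-step value target sums the next $n$ observed rewards and then bootstraps from the root value that MCTS estimated $n$ steps later (for board games, the paper instead uses the final game outcome). Index conventions differ between write-ups; this sketch assumes `rewards[i]` is the reward received after the action taken at step $i$ and `search_values[i]` is the root search value at step $i$. The function name is hypothetical.

```python
def n_step_value_target(rewards, search_values, t, n, gamma):
    """
    z_t = sum_{i < n} gamma^i * rewards[t + i] + gamma^n * search_values[t + n]

    If the episode ends before step t + n, the sum is truncated and no
    bootstrap term is added.
    """
    end = min(t + n, len(rewards))
    target = sum(gamma ** (i - t) * rewards[i] for i in range(t, end))
    if t + n < len(search_values):
        target += gamma ** n * search_values[t + n]
    return target


rewards = [1.0, 2.0, 3.0, 4.0]
search_values = [10.0, 20.0, 30.0, 40.0]
print(n_step_value_target(rewards, search_values, t=0, n=2, gamma=0.5))
# 1 + 0.5 * 2 + 0.25 * 30 = 9.5
```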

Loss function: $L = \sum_{k=0}^{K} \left[ l^p(p^k, \pi_{t+k}) + l^v(v^k, z_{t+k}) + l^r(r^k, u_{t+k}) \right]$

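where $l^p$ is the cross-entropy loss for the policy, and $l^v$ and $l^r$ are mean squared error for value and reward. (For Atari, the paper instead represents value and reward as categorical distributions and trains them with cross-entropy.)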
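The loss can be sketched numerically as follows. This version takes the outputs of a $K$-step unroll as plain arrays and uses squared error for value and reward, matching the board-game setting; in real training these would be autograd tensors so gradients flow back through $f$, $g$, and $h$, and the paper also adds an L2 weight penalty, omitted here. The function and argument names are hypothetical.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def muzero_loss(policy_logits, target_policies, values, target_values,
                rewards, target_rewards):
    """
    Sum over unroll steps k = 0..K of
    cross-entropy(policy) + squared error(value) + squared error(reward).

    policy_logits, target_policies: shape (K + 1, num_actions)
    values, target_values, rewards, target_rewards: shape (K + 1,)
    """
    policy_logits = np.asarray(policy_logits, dtype=np.float64)
    target_policies = np.asarray(target_policies, dtype=np.float64)
    # l^p: cross-entropy between the search policy and the predicted policy
    policy_loss = -(target_policies * log_softmax(policy_logits)).sum(axis=-1)
    # l^v and l^r: squared error per unroll step
    value_loss = (np.asarray(values) - np.asarray(target_values)) ** 2
    reward_loss = (np.asarray(rewards) - np.asarray(target_rewards)) ** 2
    return float((policy_loss + value_loss + reward_loss).sum())


# One step (K = 0): uniform prediction vs. one-hot target gives log(2) for the policy
print(muzero_loss(np.zeros((1, 2)), [[1.0, 0.0]], [1.0], [0.0], [0.5], [0.5]))
```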

ℹ️The Clever Trick

Notice that MuZero never explicitly trains the latent states to match observations. The latent space is only shaped by the requirement that it supports accurate policy, value, and reward predictions.

This is what makes the model “abstract”—it learns whatever representation is useful for planning, not whatever is useful for reconstruction.

MuZero Results

MuZero achieved remarkable results across diverse domains:

Board Games

Matched AlphaZero in Go, chess, and shogi—without knowing the rules.

Atari Games

Set a new state of the art on the 57-game Atari benchmark, surpassing the previous best model-free agent, R2D2.

Sample Efficiency

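A variant, MuZero Reanalyze, which revisits stored trajectories to recompute search targets with the latest network, set a new state of the art in the 200-million-frame Atari setting, outperforming model-free methods trained on the same number of frames.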

When to Use Model-Based Methods

Model-based RL isn’t always the right choice. Here’s a framework for deciding:

Use model-based when:

  • Real interactions are expensive, slow, or risky
  • The environment has exploitable structure
  • You need sample efficiency
  • Planning at test time is acceptable

Prefer model-free when:

  • You have unlimited simulation access
  • Environment dynamics are extremely complex or chaotic
  • Real-time decisions are critical (planning is slow)
  • Model errors would compound unacceptably

MuZero-style methods are best when:

  • You want model-based benefits without knowing the rules
  • The observation space is complex but the underlying dynamics are learnable
  • You can afford computational cost at decision time

Beyond MuZero: Current Frontiers

📌Recent Developments

Model-based RL, and the broader idea of learning predictive models for decision-making, continues to evolve:

Dreamer (Hafner et al., 2020, 2021): Learns world models for continuous control, training policies entirely in imagination.

Decision Transformer (Chen et al., 2021): Treats RL as sequence modeling, learning to predict actions conditioned on desired returns.

IRIS (Micheli et al., 2023): Combines transformers with discrete latent states for efficient world modeling.

Genie (Bruce et al., 2024): Learns generative world models from video that can be prompted to create new environments.

The field is moving toward more expressive, more efficient, and more general world models.

Summary

MuZero represents a paradigm shift in model-based RL:

  • No rules required: Learns models that enable planning without knowing environment rules
  • Abstract representations: Latent dynamics optimized for planning, not observation prediction
  • MCTS planning: Combines learned models with principled search
  • Strong results: Achieves superhuman performance across board games and video games

The key insight: you don’t need to model every detail of the world; you need to model what matters for decision-making. This principle underlies not just MuZero but much of the frontier of model-based RL research.

Model-based RL offers a path to sample-efficient learning. As models become more powerful and compute becomes cheaper, expect these methods to play an increasingly important role in real-world RL applications.