
Learning World Models

Predicting transitions and rewards

What separates model-based RL from model-free RL is one thing: the model. But what exactly is a model, and how do we learn one?

What is a Model?

📖Environment Model

An environment model consists of two components:

  • Transition model: $\hat{P}(s' \mid s, a)$ - predicts the probability of reaching next state $s'$ after taking action $a$ in state $s$
  • Reward model: $\hat{R}(s, a)$ - predicts the reward for taking action $a$ in state $s$

Together, these allow an agent to simulate what would happen if it took an action, without actually taking it.
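To make the two components concrete, here is a minimal sketch of a hand-written model for a hypothetical 5-cell corridor. The environment and the names `transition_model`, `reward_model` and `simulate` are illustrative assumptions, not part of any library.

Because the corridor is deterministic, $\hat{P}$ collapses to a function that returns the single next state. `simulate` scores a whole action sequence without ever touching a real environment.

```python
# Hypothetical deterministic model of a 5-cell corridor: states 0..4,
# actions -1 (left) and +1 (right). Entering cell 4 gives reward 1, otherwise 0.

def transition_model(s, a):
    """Transition model P-hat: deterministic here, so it returns the single next state."""
    return min(max(s + a, 0), 4)


def reward_model(s, a):
    """Reward model R-hat: reward for taking action a in state s."""
    return 1.0 if transition_model(s, a) == 4 else 0.0


def simulate(s, actions):
    """Roll an action sequence through the model; no real environment is involved."""
    total = 0.0
    for a in actions:
        total += reward_model(s, a)
        s = transition_model(s, a)
    return s, total


print(simulate(0, [1, 1, 1, 1]))   # plan that walks right to the goal
print(simulate(2, [-1, -1, -1]))   # plan that walks away from it
```

An agent holding such a model can compare candidate plans purely "in its head". This is exactly the chess-player style of reasoning described next.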

Imagine you’re learning to play a new video game. After playing for a while, you start to understand the rules: “If I press jump near a ledge, my character will land on the platform.” This mental model lets you plan ahead without actually trying every action.

That’s exactly what model-based RL does: it learns the rules of the environment so it can plan.

Think about a chess player considering a move. They don’t physically make the move to see what happens. Instead, they simulate it in their head: “If I move my knight here, my opponent will probably respond with…” This is planning with a model.

Model-Free vs Model-Based: The Fundamental Distinction

Model-Free
1. Take an action in the environment
2. Observe the reward and next state
3. Update the value function or policy directly
4. Repeat

Examples: Q-learning, SARSA, PPO

Model-Based
1. Take an action in the environment
2. Observe the reward and next state
3. Learn/update the model
4. Use the model to plan or simulate experiences
5. Update the value function or policy

Examples: Dyna-Q, MBPO, MuZero

The key insight: model-based methods can reuse each real experience multiple times by using the model to generate additional simulated experiences. This is why they are typically more sample-efficient, provided the model is accurate enough.
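The loop below is a minimal tabular Dyna-Q sketch on a hypothetical 6-cell corridor. It is written to show where steps 1 to 5 of the model-based list fit. The environment, the lookup-table model and all hyperparameters (`k`, `alpha`, `gamma`, `epsilon`) are illustrative assumptions.

Each real transition triggers one direct Q-update plus `k` simulated updates replayed from remembered transitions.

```python
import random

# Hypothetical environment: a deterministic 6-cell corridor. The agent starts in
# cell 0; stepping into cell 5 gives reward 1 and ends the episode.
N_STATES, GOAL, ACTIONS = 6, 5, (-1, 1)


def env_step(s, a):
    """The *real* environment (the agent only sees its outputs)."""
    s_next = min(max(s + a, 0), GOAL)
    return s_next, (1.0 if s_next == GOAL else 0.0), s_next == GOAL


def dyna_q(n_episodes=20, k=10, alpha=0.5, gamma=0.95, epsilon=0.1, seed=0):
    rng = random.Random(seed)
    Q = {(s, a): 0.0 for s in range(N_STATES) for a in ACTIONS}
    model = {}  # learned model: (s, a) -> (r, s_next, done), remembered from real steps
    real_steps = 0

    def q_update(s, a, r, s_next, done):
        target = r if done else r + gamma * max(Q[(s_next, b)] for b in ACTIONS)
        Q[(s, a)] += alpha * (target - Q[(s, a)])

    def choose_action(s):
        # Epsilon-greedy with random tie-breaking (all Q-values start equal).
        if rng.random() < epsilon:
            return rng.choice(ACTIONS)
        best = max(Q[(s, b)] for b in ACTIONS)
        return rng.choice([b for b in ACTIONS if Q[(s, b)] == best])

    for _ in range(n_episodes):
        s, done = 0, False
        while not done:
            a = choose_action(s)
            s_next, r, done = env_step(s, a)   # steps 1-2: act, observe r and s'
            real_steps += 1
            q_update(s, a, r, s_next, done)    # direct update from real experience
            model[(s, a)] = (r, s_next, done)  # step 3: update the model
            for _ in range(k):                 # steps 4-5: k simulated updates
                sim_s, sim_a = rng.choice(list(model))
                q_update(sim_s, sim_a, *model[(sim_s, sim_a)])
            s = s_next
    return Q, real_steps


Q, real_steps = dyna_q()
print("real steps:", real_steps)
print("greedy action per cell:", [max(ACTIONS, key=lambda b: Q[(s, b)]) for s in range(GOAL)])
```

Passing `k=0` turns off planning and leaves plain Q-learning. That gives a simple way to compare how many real steps each variant needs to learn the corridor.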

The Sample Efficiency Advantage

Why bother learning a model? The answer is sample efficiency.

📌Why Sample Efficiency Matters

Consider training a robot to walk:

  • Model-free approach: The robot must physically attempt thousands of steps, falling over repeatedly. Each fall costs hardware wear, time spent resetting, and risk of damage.

  • Model-based approach: The robot attempts a few hundred real steps, builds a model of its dynamics, then practices thousands of steps inside that learned model. It only returns to the real robot when it has a promising policy.

In robotics, healthcare, and other real-world domains, sample efficiency isn’t just convenient—it’s essential.

Mathematical Details

The sample efficiency advantage can be quantified. If we take $n$ real environment steps, a model-free method gets $n$ learning updates. With a model, we can also generate $k$ simulated experiences per real step. That yields $n \cdot k$ additional simulated updates (on top of the original $n$ real ones) from the same real data.

This amplification factor $k$ can be dramatic. In Dyna-Q, typical values are $k = 5$ to $k = 50$ simulated steps per real step.
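Here is a worked example with illustrative numbers, chosen for round arithmetic rather than taken from any experiment. With $n = 1000$ real steps and $k = 50$ simulated steps per real step, the agent gets fifty times as many simulated updates as real ones:

```latex
n = 1000,\quad k = 50:\qquad
\underbrace{n}_{\text{real updates}} = 1000,\qquad
\underbrace{n \cdot k}_{\text{simulated updates}} = 50{,}000,\qquad
n\,(k + 1) = 51{,}000 \ \text{updates in total}
```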

Of course, this comes with a caveat: simulated experiences are only as good as the model. If the model is wrong, we’re learning from garbage.

Learning a Tabular Model

The simplest approach to model learning works in tabular settings where states and actions are discrete.

The idea is straightforward: just remember what happened. Every time you take action $a$ in state $s$ and end up in state $s'$ with reward $r$, record it. After enough observations, you can predict what will happen for any $(s, a)$ pair you've seen.

Mathematical Details

For a deterministic environment, we simply remember the transitions:

$$\hat{P}(s' \mid s, a) = \begin{cases} 1 & \text{if we observed } (s, a) \rightarrow s' \\ 0 & \text{otherwise} \end{cases}$$

$$\hat{R}(s, a) = r \quad \text{if we observed reward } r \text{ for } (s, a)$$
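For the deterministic case, "remembering" can be as simple as a dictionary keyed by $(s, a)$. The sketch below uses a hypothetical class name, `DeterministicTabularModel`. It keeps only the most recent outcome, which is enough when the environment is truly deterministic:

```python
class DeterministicTabularModel:
    """Lookup-table model for a deterministic environment (hypothetical helper)."""

    def __init__(self):
        self.table = {}  # (s, a) -> (r, s_prime): the last observed outcome

    def update(self, s, a, r, s_prime):
        """Record what happened after taking a in s."""
        self.table[(s, a)] = (r, s_prime)

    def predict(self, s, a):
        """Return (s_prime, r), or (None, None) for an unseen (s, a) pair."""
        if (s, a) not in self.table:
            return None, None
        r, s_prime = self.table[(s, a)]
        return s_prime, r


model = DeterministicTabularModel()
model.update(0, 1, 0.0, 1)
print(model.predict(0, 1))  # a remembered transition
print(model.predict(2, 0))  # never observed, so no prediction
```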

For stochastic environments, we estimate transition probabilities from counts:

$$\hat{P}(s' \mid s, a) = \frac{\text{count}(s, a, s')}{\text{count}(s, a)}$$

$$\hat{R}(s, a) = \frac{\sum_{\text{visits to } (s,a)} r}{\text{count}(s, a)}$$

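As the number of visits to each $(s, a)$ pair grows, these estimates converge to the true dynamics for that pair, assuming the environment is stationary. Pairs that are never visited get no estimate at all.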
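The convergence claim can be checked numerically. This sketch assumes a hypothetical state-action pair whose next state is 1 with probability 0.7 and 2 otherwise. Rewards are drawn from a normal distribution with mean 0.5. The sketch applies the two count-based formulas for increasing numbers of visits:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical true dynamics for one (s, a) pair.
TRUE_P = {1: 0.7, 2: 0.3}
TRUE_MEAN_REWARD = 0.5


def estimate_from_counts(n_visits):
    """Simulate n_visits to (s, a) and return the count-based estimates."""
    outcomes = rng.choice([1, 2], size=n_visits, p=[TRUE_P[1], TRUE_P[2]])
    rewards = rng.normal(TRUE_MEAN_REWARD, 1.0, size=n_visits)
    # P-hat(s'|s,a) = count(s, a, s') / count(s, a)
    p_hat = {s: np.count_nonzero(outcomes == s) / n_visits for s in (1, 2)}
    # R-hat(s,a) = (sum of observed rewards) / count(s, a)
    r_hat = rewards.sum() / n_visits
    return p_hat, r_hat


for n in (10, 1_000, 100_000):
    p_hat, r_hat = estimate_from_counts(n)
    print(f"n={n:>7}: P-hat(1|s,a)={p_hat[1]:.3f}, R-hat(s,a)={r_hat:.3f}")
```

With few visits the estimates can be far off. As $n$ grows they settle near the true values 0.7 and 0.5.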

Implementation
import numpy as np

class TabularModel:
    """Tabular environment model learned from experience."""

    def __init__(self, n_states, n_actions):
        self.n_states = n_states
        self.n_actions = n_actions

        # Count-based transition model
        self.transition_counts = np.zeros((n_states, n_actions, n_states))
        self.reward_sum = np.zeros((n_states, n_actions))
        self.visit_counts = np.zeros((n_states, n_actions))

    def update(self, s, a, r, s_prime):
        """Update model from a single transition."""
        self.transition_counts[s, a, s_prime] += 1
        self.reward_sum[s, a] += r
        self.visit_counts[s, a] += 1

    def predict(self, s, a):
        """Predict next state and reward."""
        if self.visit_counts[s, a] == 0:
            return None, None  # No data for this state-action pair

        # Transition probabilities
        probs = self.transition_counts[s, a] / self.visit_counts[s, a]
        s_prime = np.random.choice(self.n_states, p=probs)

        # Expected reward
        r = self.reward_sum[s, a] / self.visit_counts[s, a]

        return s_prime, r

    def sample_experienced(self):
        """Sample a previously experienced (s, a) pair for planning."""
        experienced = np.where(self.visit_counts > 0)
        if len(experienced[0]) == 0:
            return None
        idx = np.random.randint(len(experienced[0]))
        return experienced[0][idx], experienced[1][idx]

    def get_uncertainty(self, s, a):
        """Return uncertainty (inverse of visit count) for a state-action pair."""
        if self.visit_counts[s, a] == 0:
            return float('inf')
        return 1.0 / self.visit_counts[s, a]

Neural Network Models

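For complex environments with continuous states, tabular models won't work: there are infinitely many states, and the agent will almost never visit exactly the same one twice. We need function approximation, typically neural networks.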

Instead of storing counts, we train a neural network to predict the next state and reward given the current state and action. The network generalizes across similar states, so we can make predictions even for states we haven’t visited before.

This is powerful but dangerous: the network might make confident predictions about states that are nothing like what it’s seen in training.

Mathematical Details

We train two networks (or a single network with two heads):

Transition model: $\hat{s}' = f_\theta(s, a)$, predicting the next state

Reward model: $\hat{r} = g_\phi(s, a)$, predicting the reward

The loss functions are mean squared prediction errors over a dataset $D$ of observed transitions:

$$L_{\text{transition}} = \mathbb{E}_{(s,a,s') \sim D}\left[ \| \hat{s}' - s' \|^2 \right]$$

$$L_{\text{reward}} = \mathbb{E}_{(s,a,r) \sim D}\left[ (\hat{r} - r)^2 \right]$$

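For stochastic environments, we might predict a distribution over next states. A common choice is a Gaussian with a learned mean and a learned (typically diagonal) covariance: $\hat{P}(s' \mid s, a) = \mathcal{N}\big(\mu_\theta(s, a), \operatorname{diag}(\sigma_\theta^2(s, a))\big)$. Such a model is trained by minimizing the negative log-likelihood of the observed $s'$ rather than the squared error.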
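For the Gaussian case, one option is to have the network output a mean and a log standard deviation per state dimension. It is then trained by minimizing the negative log-likelihood instead of squared error. Below is a minimal NumPy sketch of that loss, assuming a diagonal covariance. `gaussian_nll` is a hypothetical helper name, not a library function:

```python
import numpy as np


def gaussian_nll(mu, log_std, target):
    """Mean negative log-likelihood of `target` under N(mu, diag(exp(log_std)**2)).

    mu, log_std, target: arrays of shape (batch, state_dim).
    Per dimension: 0.5 * ((x - mu)^2 / sigma^2 + 2 log sigma + log 2*pi).
    """
    var = np.exp(2.0 * log_std)
    per_dim = 0.5 * ((target - mu) ** 2 / var + 2.0 * log_std + np.log(2.0 * np.pi))
    return per_dim.sum(axis=-1).mean()


mu = np.zeros((1, 1))
target = np.array([[2.0]])  # the predicted mean is off by 2
print(gaussian_nll(mu, np.zeros((1, 1)), target))               # sigma = 1
print(gaussian_nll(mu, np.full((1, 1), np.log(2.0)), target))   # sigma = 2: lower loss here
```

With `log_std = 0` the loss is half the squared error plus a constant, so it recovers the MSE objective. When the error is large, predicting a wider distribution lowers the loss. This is how the model learns to report the noise level of each transition.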

Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralModel(nn.Module):
    """Neural network environment model."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()

        # Shared encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # Transition prediction head (deterministic)
        self.transition_head = nn.Linear(hidden_dim, state_dim)

        # Reward prediction head
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Predict next state and reward."""
        # Concatenate state and action
        x = torch.cat([state, action], dim=-1)
        features = self.encoder(x)

        next_state = self.transition_head(features)
        reward = self.reward_head(features)

        return next_state, reward.squeeze(-1)

    def predict_next_state(self, state, action):
        """Predict just the next state."""
        next_state, _ = self.forward(state, action)
        return next_state


def train_model(model, replay_buffer, optimizer, batch_size=256, epochs=5):
    """Train the model on data from the replay buffer.

    Each "epoch" here is one gradient step on a freshly sampled mini-batch,
    not a full pass over the buffer. Returns the loss of the last step.
    """
    model.train()
    loss = None

    for _ in range(epochs):
        # Sample batch from replay buffer (done flags are unused by this loss)
        states, actions, rewards, next_states, _dones = replay_buffer.sample(batch_size)

        # Convert to float32 tensors (avoids a copy when the input is already float32 NumPy)
        states = torch.as_tensor(states, dtype=torch.float32)
        actions = torch.as_tensor(actions, dtype=torch.float32)
        rewards = torch.as_tensor(rewards, dtype=torch.float32)
        next_states = torch.as_tensor(next_states, dtype=torch.float32)

        # Forward pass
        pred_next_states, pred_rewards = model(states, actions)

        # Compute losses
        transition_loss = F.mse_loss(pred_next_states, next_states)
        reward_loss = F.mse_loss(pred_rewards, rewards)
        loss = transition_loss + reward_loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item() if loss is not None else None
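`train_model` only assumes that `replay_buffer.sample(batch_size)` returns five arrays: states, actions, rewards, next states and done flags. That buffer is not defined here. Below is one minimal, hypothetical NumPy implementation with that interface: a fixed-capacity ring buffer that overwrites the oldest entries.

```python
import numpy as np


class ReplayBuffer:
    """Hypothetical fixed-capacity ring buffer matching the sample() call in train_model."""

    def __init__(self, capacity, state_dim, action_dim, seed=0):
        self.capacity = capacity
        self.states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros(capacity, dtype=np.float32)
        self.next_states = np.zeros((capacity, state_dim), dtype=np.float32)
        self.dones = np.zeros(capacity, dtype=np.float32)
        self.size = 0  # number of valid entries
        self.ptr = 0   # next write position (wraps around)
        self.rng = np.random.default_rng(seed)

    def add(self, s, a, r, s_prime, done):
        """Store one transition, overwriting the oldest once the buffer is full."""
        self.states[self.ptr] = s
        self.actions[self.ptr] = a
        self.rewards[self.ptr] = r
        self.next_states[self.ptr] = s_prime
        self.dones[self.ptr] = float(done)
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size):
        """Uniformly sample (with replacement) from the stored transitions."""
        idx = self.rng.integers(0, self.size, size=batch_size)
        return (self.states[idx], self.actions[idx], self.rewards[idx],
                self.next_states[idx], self.dones[idx])


buffer = ReplayBuffer(capacity=100, state_dim=3, action_dim=2)
for i in range(150):
    buffer.add(np.full(3, i), np.zeros(2), float(i), np.full(3, i + 1), False)
states, actions, rewards, next_states, dones = buffer.sample(32)
print(states.shape, actions.shape, rewards.shape)
```

Calling `sample` on an empty buffer raises an error, so collect some real transitions before the first call to `train_model`.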

Model Uncertainty

A critical challenge with learned models is knowing when to trust them.

Think about predicting where a ball will land. If you’ve seen many similar throws, you can predict confidently. But if someone throws the ball in a completely new way, your prediction might be wildly wrong—and you might not even realize it.

Good model-based methods need to track their uncertainty: “I’m confident about this prediction” vs “I’ve never seen anything like this state.”

Ensemble Models for Uncertainty

A practical approach is to train an ensemble of models and use their disagreement as an uncertainty measure.

Mathematical Details

Train $M$ models $\{f_1, \ldots, f_M\}$ on different bootstrap samples of the data. For a prediction:

  • Mean prediction: $\bar{f}(s,a) = \frac{1}{M}\sum_{i=1}^{M} f_i(s,a)$
  • Uncertainty: $\sigma(s,a) = \text{std}\left(\{f_1(s,a), \ldots, f_M(s,a)\}\right)$

High disagreement between models indicates high uncertainty—the models haven’t seen enough data from this region to agree.
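The bootstrap-ensemble idea can be seen on a toy problem. This sketch assumes hypothetical 1-D dynamics $s' = 2s + \text{noise}$, observed only for $s \in [0, 1]$. It uses `numpy.polyfit` linear fits as stand-ins for the neural networks $f_i$. The ensemble mean plays the role of $\bar{f}$, and the standard deviation across members plays the role of $\sigma$:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 1-D dynamics s' = 2s + noise, observed only for s in [0, 1].
s = rng.uniform(0.0, 1.0, size=50)
s_next = 2.0 * s + rng.normal(0.0, 0.1, size=50)

M = 20
ensemble = []  # each member is a (slope, intercept) pair
for _ in range(M):
    idx = rng.integers(0, len(s), size=len(s))  # bootstrap: resample with replacement
    slope, intercept = np.polyfit(s[idx], s_next[idx], deg=1)
    ensemble.append((slope, intercept))


def predict_with_uncertainty(x):
    """Return the ensemble mean f_bar(x) and the std across members sigma(x)."""
    preds = np.array([slope * x + intercept for slope, intercept in ensemble])
    return preds.mean(), preds.std()


mean_in, sigma_in = predict_with_uncertainty(0.5)    # inside the training range
mean_out, sigma_out = predict_with_uncertainty(5.0)  # far outside it
print(f"s=0.5: mean={mean_in:.3f}, sigma={sigma_in:.4f}")
print(f"s=5.0: mean={mean_out:.3f}, sigma={sigma_out:.4f}")
```

Inside the training range the members nearly agree. At $s = 5$ their fitted slopes fan out and $\sigma$ grows, which flags the prediction as untrustworthy.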

Implementation
class EnsembleModel:
    """Ensemble of neural network models for uncertainty estimation."""

    def __init__(self, n_models, state_dim, action_dim, hidden_dim=256):
        self.models = nn.ModuleList([
            NeuralModel(state_dim, action_dim, hidden_dim)
            for _ in range(n_models)
        ])
        self.n_models = n_models

    def predict_with_uncertainty(self, state, action):
        """Predict next state with uncertainty estimate."""
        predictions = []

        for model in self.models:
            with torch.no_grad():
                next_state, _ = model(state, action)
                predictions.append(next_state)

        # Stack predictions: (n_models, batch_size, state_dim)
        predictions = torch.stack(predictions)

        # Mean prediction
        mean = predictions.mean(dim=0)

        # Uncertainty: standard deviation across models
        uncertainty = predictions.std(dim=0).mean(dim=-1)  # Average over state dims

        return mean, uncertainty

    def should_trust_prediction(self, state, action, threshold=0.5):
        """Check if prediction is reliable based on ensemble disagreement."""
        _, uncertainty = self.predict_with_uncertainty(state, action)
        return uncertainty < threshold

When Models Fail

Consider the weather forecast analogy:

  • Tomorrow’s forecast: usually pretty accurate
  • Next week’s forecast: somewhat reliable
  • Next month’s forecast: basically a guess

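The same applies to learned world models. One-step predictions are usually reliable, but long-horizon predictions degrade quickly. In a multi-step rollout each predicted state is fed back in as the next input, so small one-step errors compound.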
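A small numerical sketch shows why errors compound. It assumes hypothetical linear dynamics $s_{t+1} = 0.9\,s_t + 1$ and a learned model whose coefficient is slightly off ($0.95$ instead of $0.9$). The one-step error, where the model is applied to the true state, stays small. The open-loop rollout error, where the model feeds on its own predictions, keeps growing:

```python
import numpy as np


def rollout(a, b, s0, horizon):
    """Roll out the linear dynamics s_{t+1} = a * s_t + b for `horizon` steps."""
    states = [s0]
    for _ in range(horizon):
        states.append(a * states[-1] + b)
    return np.array(states)


H = 20
true_traj = rollout(0.90, 1.0, s0=0.0, horizon=H)   # hypothetical real dynamics
model_traj = rollout(0.95, 1.0, s0=0.0, horizon=H)  # learned model, coefficient off by 0.05

# One-step error: the model is applied to the *true* state at every step.
one_step_err = np.abs((0.95 * true_traj[:-1] + 1.0) - true_traj[1:])
# Rollout error: the model is fed its own predictions, so errors compound.
rollout_err = np.abs(model_traj - true_traj)

print(f"max one-step error: {one_step_err.max():.3f}")
print(f"rollout error after {H} steps: {rollout_err[-1]:.3f}")
```

Here the rollout error after 20 steps (about 4.05) is more than nine times the largest one-step error (about 0.43). The model never makes a big mistake on any single step.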

ℹ️The Model Bias Problem

If you train a policy purely on model rollouts, any systematic errors in the model will be exploited. The policy might learn to take actions that look good in the model but fail catastrophically in reality.

This is called model exploitation or model bias. Solutions include:

  • Limiting the planning horizon
  • Mixing real and simulated experience (Dyna)
  • Penalizing actions with high model uncertainty
  • Using ensembles to avoid confident wrong predictions
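The uncertainty-penalty idea can be sketched as subtracting a multiple of the ensemble disagreement from the predicted reward: $\tilde{r}(s,a) = \hat{r}(s,a) - \lambda\,u(s,a)$. This discourages the policy from exploiting regions the model does not know. The helper below is hypothetical, NumPy only. The choice of $u$ (mean ensemble standard deviation of the next-state prediction) and the coefficient `lam` are assumptions; published methods differ in exactly how they measure uncertainty.

```python
import numpy as np


def penalized_reward(reward_preds, next_state_preds, lam=1.0):
    """Uncertainty-penalized reward r_tilde = mean ensemble reward - lam * u(s, a).

    reward_preds:     shape (M, batch), reward predictions from M ensemble members
    next_state_preds: shape (M, batch, state_dim), next-state predictions
    u(s, a) is taken (as an assumption) to be the std across members of the
    next-state prediction, averaged over state dimensions.
    """
    mean_reward = reward_preds.mean(axis=0)
    uncertainty = next_state_preds.std(axis=0).mean(axis=-1)
    return mean_reward - lam * uncertainty


# Two transitions with the same predicted reward: the members agree on the first
# next state but disagree on the second.
reward_preds = np.ones((3, 2))
next_state_preds = np.zeros((3, 2, 1))
next_state_preds[:, 1, 0] = [-1.0, 0.0, 1.0]
print(penalized_reward(reward_preds, next_state_preds, lam=1.0))
```

Larger `lam` makes the agent more conservative. Setting `lam=0` recovers the plain model reward.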

Summary

Model learning is the foundation of model-based RL. By learning how the environment works, we gain the ability to simulate experiences and plan ahead—dramatically improving sample efficiency.

But learned models are imperfect. The art of model-based RL lies in using models effectively while being robust to their errors. In the next section, we’ll see how to actually use these models for planning.