Learning World Models
What separates model-based RL from model-free RL is one thing: the model. But what exactly is a model, and how do we learn one?
What is a Model?
An environment model consists of two components:
- Transition model: predicts the next state given the current state and action
- Reward model: predicts the reward for a state-action pair
Together, these allow an agent to simulate what would happen if it took an action, without actually taking it.
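In code, a model is just those two functions. Here is a minimal sketch on a made-up four-state chain (the dynamics and names are purely illustrative):

```python
# A model is two functions: transition T(s, a) -> s' and reward R(s, a) -> r.
# Both are hand-written here for a toy 1-D chain, purely for illustration.

def transition(state, action):
    """Toy deterministic dynamics: the action moves the agent by -1 or +1."""
    return min(3, max(0, state + action))

def reward(state, action):
    """Toy reward: reaching state 3 pays off."""
    return 1.0 if transition(state, action) == 3 else 0.0

# Simulate one step without touching any real environment
s, a = 2, 1
print(transition(s, a), reward(s, a))  # 3 1.0
```

An agent holding these two functions can ask "what if?" without acting, which is the whole point of the sections that follow.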
Imagine you’re learning to play a new video game. After playing for a while, you start to understand the rules: “If I press jump near a ledge, my character will land on the platform.” This mental model lets you plan ahead without actually trying every action.
That’s exactly what model-based RL does: it learns the rules of the environment so it can plan.
Think about a chess player considering a move. They don’t physically make the move to see what happens. Instead, they simulate it in their head: “If I move my knight here, my opponent will probably respond with…” This is planning with a model.
Model-Free vs Model-Based: The Fundamental Distinction
- Model-free methods learn a policy or value function directly from experience, never modeling the environment itself. Examples: Q-learning, SARSA, PPO
- Model-based methods learn an explicit model of the environment and use it to plan or to generate simulated experience. Examples: Dyna-Q, MBPO, MuZero
The key insight: model-based methods can reuse each real experience multiple times by using the model to generate additional simulated experiences. This is why they’re more sample-efficient.
The Sample Efficiency Advantage
Why bother learning a model? The answer is sample efficiency.
Consider training a robot to walk:
- Model-free approach: The robot must physically attempt thousands of steps, falling over repeatedly. Each fall means wear on the hardware, time spent resetting, and potential damage.
- Model-based approach: The robot attempts a few hundred real steps, builds a model of its dynamics, then practices thousands of steps in simulation. It returns to the real robot only when it has a promising policy.
In robotics, healthcare, and other real-world domains, sample efficiency isn’t just convenient—it’s essential.
The sample efficiency advantage can be quantified. If we take $N$ real environment steps, model-free methods get $N$ learning updates. But with a model, we can generate $k$ simulated experiences from each real experience, giving us $(1 + k)N$ learning updates from the same real data.
This amplification factor can be dramatic. In Dyna-Q, typical values range from 5 to 50 simulated steps per real step.
Of course, this comes with a caveat: simulated experiences are only as good as the model. If the model is wrong, we’re learning from garbage.
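The amplification loop itself is simple. Here is a minimal Dyna-style sketch on a toy five-state chain (the environment, constants, and helper names are all illustrative, not a canonical implementation):

```python
import random

n_states, n_actions, k = 5, 2, 10   # k simulated updates per real step
alpha, gamma = 0.5, 0.9
Q = [[0.0] * n_actions for _ in range(n_states)]
model = {}  # learned deterministic model: (s, a) -> (r, s')

def env_step(s, a):
    """Toy chain: action 1 moves right, action 0 moves left; rightmost state pays 1."""
    s2 = min(n_states - 1, s + 1) if a == 1 else max(0, s - 1)
    return (1.0 if s2 == n_states - 1 else 0.0), s2

def q_update(s, a, r, s2):
    Q[s][a] += alpha * (r + gamma * max(Q[s2]) - Q[s][a])

random.seed(0)
s = 0
for _ in range(200):                 # 200 real environment steps...
    a = random.randrange(n_actions)
    r, s2 = env_step(s, a)
    q_update(s, a, r, s2)            # one real update
    model[(s, a)] = (r, s2)          # learn the model as we go
    for _ in range(k):               # ...each funding k simulated updates
        ms, ma = random.choice(list(model))
        mr, ms2 = model[(ms, ma)]
        q_update(ms, ma, mr, ms2)
    s = 0 if s2 == n_states - 1 else s2

print(Q[0][1] > Q[0][0])  # moving right should look better from state 0
```

Every real step is replayed many times through the model, which is exactly where the $(1 + k)$ amplification comes from; if the stored model were wrong, those extra updates would push the values in the wrong direction.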
Learning a Tabular Model
The simplest approach to model learning works in tabular settings where states and actions are discrete.
The idea is straightforward: just remember what happened. Every time you take action $a$ in state $s$ and end up in state $s'$ with reward $r$, record it. After enough observations, you can predict what will happen for any $(s, a)$ pair you've seen.
For a deterministic environment, we simply remember the transitions:

$$\hat{T}(s, a) = s', \qquad \hat{R}(s, a) = r$$

For stochastic environments, we estimate transition probabilities from counts:

$$\hat{P}(s' \mid s, a) = \frac{N(s, a, s')}{N(s, a)}, \qquad \hat{R}(s, a) = \frac{1}{N(s, a)} \sum_{i=1}^{N(s, a)} r_i$$

where $N(s, a, s')$ counts observed transitions and $N(s, a)$ counts visits to the pair $(s, a)$.
As we collect more data, these estimates converge to the true dynamics.
import numpy as np

class TabularModel:
    """Tabular environment model learned from experience."""

    def __init__(self, n_states, n_actions):
        self.n_states = n_states
        self.n_actions = n_actions
        # Count-based transition model
        self.transition_counts = np.zeros((n_states, n_actions, n_states))
        self.reward_sum = np.zeros((n_states, n_actions))
        self.visit_counts = np.zeros((n_states, n_actions))

    def update(self, s, a, r, s_prime):
        """Update model from a single transition."""
        self.transition_counts[s, a, s_prime] += 1
        self.reward_sum[s, a] += r
        self.visit_counts[s, a] += 1

    def predict(self, s, a):
        """Predict next state and reward."""
        if self.visit_counts[s, a] == 0:
            return None, None  # No data for this state-action pair

        # Transition probabilities
        probs = self.transition_counts[s, a] / self.visit_counts[s, a]
        s_prime = np.random.choice(self.n_states, p=probs)

        # Expected reward
        r = self.reward_sum[s, a] / self.visit_counts[s, a]
        return s_prime, r

    def sample_experienced(self):
        """Sample a previously experienced (s, a) pair for planning."""
        experienced = np.where(self.visit_counts > 0)
        if len(experienced[0]) == 0:
            return None
        idx = np.random.randint(len(experienced[0]))
        return experienced[0][idx], experienced[1][idx]

    def get_uncertainty(self, s, a):
        """Return uncertainty (inverse of visit count) for a state-action pair."""
        if self.visit_counts[s, a] == 0:
            return float('inf')
        return 1.0 / self.visit_counts[s, a]

Neural Network Models
For complex environments with continuous states, tabular models won’t work. We need function approximation—typically neural networks.
Instead of storing counts, we train a neural network to predict the next state and reward given the current state and action. The network generalizes across similar states, so we can make predictions even for states we haven’t visited before.
This is powerful but dangerous: the network might make confident predictions about states that are nothing like what it’s seen in training.
We train two networks (or a single network with two heads):
Transition model: $\hat{f}_\theta(s, a) \approx s'$, predicting the next state

Reward model: $\hat{r}_\phi(s, a) \approx r$, predicting the reward

The loss functions are simply prediction errors:

$$\mathcal{L}_{\text{transition}} = \big\| \hat{f}_\theta(s, a) - s' \big\|^2, \qquad \mathcal{L}_{\text{reward}} = \big( \hat{r}_\phi(s, a) - r \big)^2$$

For stochastic environments, we might predict a distribution over next states, often using a Gaussian: $\hat{P}(s' \mid s, a) = \mathcal{N}\big(\mu_\theta(s, a), \sigma^2_\theta(s, a)\big)$.
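A Gaussian output head can be sketched like this (names are hypothetical; the full model below sticks to a deterministic head). The network emits a mean and log standard deviation, and training minimizes the negative log-likelihood of the observed next state:

```python
import torch
import torch.nn as nn

class GaussianTransitionHead(nn.Module):
    """Sketch: predict a Gaussian over next states instead of a point estimate."""

    def __init__(self, hidden_dim, state_dim):
        super().__init__()
        self.mean = nn.Linear(hidden_dim, state_dim)
        self.log_std = nn.Linear(hidden_dim, state_dim)

    def forward(self, features):
        mean = self.mean(features)
        # Clamp log-std for numerical stability
        log_std = self.log_std(features).clamp(-5.0, 2.0)
        return torch.distributions.Normal(mean, log_std.exp())

def nll_loss(dist, next_state):
    """Negative log-likelihood of the observed next state."""
    return -dist.log_prob(next_state).sum(dim=-1).mean()

# Usage sketch with random features standing in for an encoder's output
head = GaussianTransitionHead(hidden_dim=256, state_dim=4)
features = torch.randn(32, 256)
dist = head(features)
loss = nll_loss(dist, torch.randn(32, 4))
```

Unlike a plain MSE head, the learned standard deviation lets the model express how noisy the dynamics are at a given state, not just its best guess.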
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralModel(nn.Module):
    """Neural network environment model."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        # Shared encoder
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        # Transition prediction head (deterministic)
        self.transition_head = nn.Linear(hidden_dim, state_dim)
        # Reward prediction head
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Predict next state and reward."""
        # Concatenate state and action
        x = torch.cat([state, action], dim=-1)
        features = self.encoder(x)
        next_state = self.transition_head(features)
        reward = self.reward_head(features)
        return next_state, reward.squeeze(-1)

    def predict_next_state(self, state, action):
        """Predict just the next state."""
        next_state, _ = self.forward(state, action)
        return next_state

def train_model(model, replay_buffer, optimizer, batch_size=256, epochs=5):
    """Train the model on data from the replay buffer."""
    model.train()
    for _ in range(epochs):
        # Sample batch from replay buffer (dones are unused for model learning)
        states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)

        # Convert to tensors
        states = torch.FloatTensor(states)
        actions = torch.FloatTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)

        # Forward pass
        pred_next_states, pred_rewards = model(states, actions)

        # Compute losses
        transition_loss = F.mse_loss(pred_next_states, next_states)
        reward_loss = F.mse_loss(pred_rewards, rewards)
        loss = transition_loss + reward_loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return loss.item()

Model Uncertainty
A critical challenge with learned models is knowing when to trust them.
A model trained on data from some states may make confident but wrong predictions for states it hasn’t seen. This is especially dangerous because the agent might use these wrong predictions to plan, leading to catastrophic actions.
Think about predicting where a ball will land. If you’ve seen many similar throws, you can predict confidently. But if someone throws the ball in a completely new way, your prediction might be wildly wrong—and you might not even realize it.
Good model-based methods need to track their uncertainty: “I’m confident about this prediction” vs “I’ve never seen anything like this state.”
Ensemble Models for Uncertainty
A practical approach is to train an ensemble of models and use their disagreement as an uncertainty measure.
Train $K$ models on different bootstrap samples of the data. For a prediction:

- Mean prediction: $\bar{s}' = \frac{1}{K} \sum_{k=1}^{K} \hat{f}_{\theta_k}(s, a)$
- Uncertainty: $u(s, a) = \operatorname{std}\big( \hat{f}_{\theta_1}(s, a), \ldots, \hat{f}_{\theta_K}(s, a) \big)$
High disagreement between models indicates high uncertainty—the models haven’t seen enough data from this region to agree.
class EnsembleModel(nn.Module):
    """Ensemble of neural network models for uncertainty estimation."""

    def __init__(self, n_models, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.models = nn.ModuleList([
            NeuralModel(state_dim, action_dim, hidden_dim)
            for _ in range(n_models)
        ])
        self.n_models = n_models

    def predict_with_uncertainty(self, state, action):
        """Predict next state with uncertainty estimate."""
        predictions = []
        for model in self.models:
            with torch.no_grad():
                next_state, _ = model(state, action)
            predictions.append(next_state)

        # Stack predictions: (n_models, batch_size, state_dim)
        predictions = torch.stack(predictions)

        # Mean prediction
        mean = predictions.mean(dim=0)

        # Uncertainty: standard deviation across models
        uncertainty = predictions.std(dim=0).mean(dim=-1)  # Average over state dims

        return mean, uncertainty

    def should_trust_prediction(self, state, action, threshold=0.5):
        """Check if prediction is reliable based on ensemble disagreement."""
        _, uncertainty = self.predict_with_uncertainty(state, action)
        return uncertainty < threshold

When Models Fail
A single-step prediction error of 1% might seem small. But if you plan 20 steps ahead, that error compounds: after 20 steps, you might be predicting a state that’s nothing like reality.
This is why model-based methods must be careful about how far into the future they plan with a learned model.
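As a back-of-the-envelope illustration, suppose each step's prediction drifts multiplicatively by 1% (a deliberately crude assumption):

```python
# Crude illustration: a 1% per-step multiplicative drift compounds over a rollout.
per_step_error = 1.01
for horizon in (1, 5, 20):
    drift = per_step_error ** horizon
    print(f"after {horizon:2d} steps: {100 * (drift - 1):.0f}% off")
# after 20 steps the prediction is ~22% off, not 1%
```

Real dynamics errors don't compound this cleanly, but the qualitative lesson holds: small per-step errors grow quickly with horizon.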
Consider the weather forecast analogy:
- Tomorrow’s forecast: usually pretty accurate
- Next week’s forecast: somewhat reliable
- Next month’s forecast: basically a guess
The same applies to learned world models. Short-term predictions are reliable; long-term predictions degrade quickly.
If you train a policy purely on model rollouts, any systematic errors in the model will be exploited. The policy might learn to take actions that look good in the model but fail catastrophically in reality.
This is called model exploitation or model bias. Solutions include:
- Limiting planning horizon
- Mixing real and simulated experience (Dyna)
- Penalizing actions with high model uncertainty
- Using ensembles to avoid confident wrong predictions
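The uncertainty penalty, for instance, can be as simple as subtracting a scaled disagreement term from the model's predicted reward, in the spirit of MOPO-style penalties (a sketch; `lam` is a hypothetical tuning knob):

```python
def penalized_reward(model_reward, ensemble_uncertainty, lam=1.0):
    """Pessimistic reward: discourage the policy from exploiting uncertain regions."""
    return model_reward - lam * ensemble_uncertainty

# A confident prediction keeps most of its reward...
print(penalized_reward(1.0, 0.05))  # 0.95
# ...while a highly uncertain one is heavily discounted
print(penalized_reward(1.0, 0.5))   # 0.5
```

Feeding the penalized reward (rather than the raw model reward) into policy updates makes exploiting poorly-modeled regions unattractive by construction.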
Summary
Model learning is the foundation of model-based RL. By learning how the environment works, we gain the ability to simulate experiences and plan ahead—dramatically improving sample efficiency.
But learned models are imperfect. The art of model-based RL lies in using models effectively while being robust to their errors. In the next section, we’ll see how to actually use these models for planning.