PPO and Trust Region Methods
What You'll Learn
- Explain why unconstrained policy updates can be unstable
- Understand the trust region concept and its motivation
- Describe the TRPO algorithm at a high level
- Implement PPO (Proximal Policy Optimization) from scratch
- Explain the clipped objective and why it works
- Apply PPO to challenging environments like LunarLander
The Stability Problem
Actor-critic methods improved on REINFORCE by using a critic for lower variance. But there’s still a lurking danger: unstable updates.
Here’s what can go wrong:
- We collect experience with the current policy $\pi_{\theta_\text{old}}$
- We compute a gradient and take a step: $\theta \leftarrow \theta + \alpha \nabla_\theta J(\theta)$
- But if the step is too big, the new policy $\pi_\theta$ might be completely different from $\pi_{\theta_\text{old}}$
- The new policy could be terrible—performance crashes
- Now we’re collecting bad experience, making things worse
This is especially problematic because neural network policies are sensitive. A small change in parameters can cause large changes in behavior. One bad update can undo hours of training.
Policy collapse is a real problem in practice. You might see training progress nicely for hours, then suddenly the reward drops to near zero. Without safeguards, it can be very hard to recover.
Why Is This Worse for Policies?
In supervised learning, a bad update just means worse predictions for a while. You can recover on the next batch.
In RL, it’s worse:
- Actions affect data: A bad policy collects bad experience
- Feedback loop: Bad experience leads to worse updates
- No ground truth: Unlike supervised learning, we can’t just “look up” the right answer
This is why RL is fundamentally harder to stabilize than supervised learning. We need to be careful about how much the policy changes.
Trust Regions: The Core Idea
The solution is to limit how much the policy can change in each update. This is the trust region concept.
A trust region is like a safe zone around your current policy:
- Inside the zone: Updates are probably safe, won’t break things
- Outside the zone: Updates might be dangerous, could cause collapse
We want to take the best step we can while staying inside the trust region. This way, we improve as fast as possible while maintaining stability.
But what does “distance” mean for policies? We can’t just measure Euclidean distance in parameter space—that doesn’t correspond to how different the policies actually behave.
KL Divergence: Measuring Policy Distance
Mathematical Details
The KL divergence measures how different two probability distributions are:

$$D_{\mathrm{KL}}(P \,\|\, Q) = \sum_x P(x) \log \frac{P(x)}{Q(x)}$$

Key properties:
- $D_{\mathrm{KL}}(P \,\|\, Q) \ge 0$ always
- $D_{\mathrm{KL}}(P \,\|\, Q) = 0$ only when $P = Q$
- Not symmetric: $D_{\mathrm{KL}}(P \,\|\, Q) \ne D_{\mathrm{KL}}(Q \,\|\, P)$

For policies, we average over states:

$$\bar{D}_{\mathrm{KL}}(\pi_{\text{old}}, \pi_\theta) = \mathbb{E}_s\!\left[D_{\mathrm{KL}}\big(\pi_{\text{old}}(\cdot \mid s) \,\|\, \pi_\theta(\cdot \mid s)\big)\right]$$

This measures: "On average, how different do these policies act?"
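To make this concrete, here is a minimal sketch of how the averaged KL could be computed for two discrete policies using PyTorch's built-in kl_divergence. The logits below are made-up placeholders, not values from this chapter's training code.

import torch
from torch.distributions import Categorical, kl_divergence

# Made-up action logits for a batch of 4 states and 3 actions
logits_old = torch.tensor([[1.0, 0.5, -0.5],
                           [0.2, 0.2, 0.2],
                           [2.0, -1.0, 0.0],
                           [0.0, 1.0, 1.5]])
logits_new = logits_old + 0.3 * torch.randn_like(logits_old)  # a slightly perturbed policy

pi_old = Categorical(logits=logits_old)
pi_new = Categorical(logits=logits_new)

# KL divergence per state, then averaged over the batch of states
kl_per_state = kl_divergence(pi_old, pi_new)  # shape [4]
print(f"Average KL(pi_old || pi_new): {kl_per_state.mean().item():.4f}")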
TRPO: Trust Region Policy Optimization
TRPO (Trust Region Policy Optimization) formalizes the trust region idea.
Mathematical Details
TRPO Objective:

$$\max_\theta \; \mathbb{E}_t\!\left[\frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_\text{old}}(a_t \mid s_t)}\, \hat{A}_t\right] \quad \text{subject to} \quad \mathbb{E}_t\!\left[D_{\mathrm{KL}}\big(\pi_{\theta_\text{old}}(\cdot \mid s_t) \,\|\, \pi_\theta(\cdot \mid s_t)\big)\right] \le \delta$$

The ratio $\pi_\theta(a_t \mid s_t) / \pi_{\theta_\text{old}}(a_t \mid s_t)$ is called the importance sampling ratio (we'll use $r_t(\theta)$ for short).
- If $r_t(\theta) > 1$: New policy is more likely to take this action
- If $r_t(\theta) < 1$: New policy is less likely to take this action
- If $r_t(\theta) = 1$: Same probability
The constraint on the average KL divergence ensures the new policy doesn't change too much from the old one.
TRPO says: "Maximize improvement, but stay close to the current policy."
The importance sampling ratio lets us evaluate the new policy using data from the old policy. This is crucial—we collected experience with $\pi_{\theta_\text{old}}$, but we want to improve $\pi_\theta$.
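As a minimal illustration (with made-up numbers rather than data from an actual rollout), the surrogate objective TRPO maximizes can be evaluated from the log-probabilities stored at collection time and the log-probabilities recomputed under the new parameters:

import torch

# Made-up per-timestep data
log_probs_old = torch.tensor([-1.2, -0.7, -2.1])   # log pi_old(a_t | s_t), stored during collection
log_probs_new = torch.tensor([-1.0, -0.9, -1.8])   # log pi_theta(a_t | s_t), recomputed now
advantages    = torch.tensor([ 0.5, -0.3,  1.0])   # advantage estimates

# Importance sampling ratio r_t(theta) = pi_theta(a_t|s_t) / pi_old(a_t|s_t)
ratio = torch.exp(log_probs_new - log_probs_old)

# TRPO's surrogate objective (the KL constraint is handled separately): E_t[r_t * A_t]
surrogate = (ratio * advantages).mean()
print(ratio, surrogate)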
The catch: TRPO is hard to implement. It requires computing the KL divergence constraint exactly and using second-order optimization (the natural gradient). This makes it slow and complicated.
TRPO was groundbreaking for its theoretical guarantees, but it’s rarely used in practice today. The implementation complexity and computational cost led to simpler alternatives—most notably, PPO.
PPO: Proximal Policy Optimization
PPO achieves similar stability to TRPO but with a much simpler algorithm. It’s the most popular deep RL algorithm today.
PPO’s insight: instead of a hard KL constraint, use a clipped objective that naturally discourages large policy changes.
The objective is designed so that:
- When the policy improves, we get gradient signal
- When the policy changes too much, the gradient disappears
- No second-order optimization needed—just SGD!
The Clipped Objective
Mathematical Details
PPO Clipped Objective:

$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\!\left[\min\big(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t\big)\right]$$

where:
- $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_\text{old}}(a_t \mid s_t)$ is the probability ratio
- $\hat{A}_t$ is the advantage estimate
- $\epsilon$ is the clip parameter (typically 0.1 to 0.2)
The $\min$ takes the lower of:
- The normal policy gradient term: $r_t(\theta)\,\hat{A}_t$
- A clipped version: $\operatorname{clip}(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_t$
Understanding the Clip
Let’s trace through what the clipping does:
Case 1: Advantage is positive (action was good)
- We want to increase probability of this action
- As $r_t$ increases from 1, the objective increases (good!)
- But when $r_t > 1 + \epsilon$, we clip: the objective stops increasing
- No incentive to push further—the policy won’t change too much
Case 2: Advantage is negative (action was bad)
- We want to decrease probability of this action
- As $r_t$ decreases from 1, the objective increases (good—we're avoiding bad actions!)
- But when $r_t < 1 - \epsilon$, we clip: the objective stops increasing
- No incentive to push further—again, bounded change
The clipping acts like automatic training wheels. You can only change an action's probability by a factor of about $1 \pm \epsilon$ before the gradient signal vanishes.
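For concreteness, here is the arithmetic for both cases with $\epsilon = 0.2$ (the numbers are purely illustrative):
- Positive advantage: $\hat{A}_t = +1$, $r_t = 1.5$. The unclipped term is $1.5 \times 1 = 1.5$, the clipped term is $1.2 \times 1 = 1.2$, and the $\min$ selects $1.2$. Pushing $r_t$ higher cannot increase the objective any further.
- Negative advantage: $\hat{A}_t = -1$, $r_t = 0.5$. The unclipped term is $0.5 \times (-1) = -0.5$, the clipped term is $0.8 \times (-1) = -0.8$, and the $\min$ selects $-0.8$. Pushing $r_t$ lower cannot increase the objective any further.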
Implementation
import torch

def compute_ppo_loss(log_probs_old, log_probs_new, advantages, epsilon=0.2):
"""
Compute PPO clipped policy loss.
Args:
log_probs_old: Log probabilities under old policy
log_probs_new: Log probabilities under new policy
advantages: Advantage estimates
epsilon: Clipping parameter
Returns:
PPO clipped loss (to minimize, so it's negative of objective)
"""
# Compute probability ratio
ratio = torch.exp(log_probs_new - log_probs_old)
# Clipped ratio
ratio_clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
# The two terms
term1 = ratio * advantages
term2 = ratio_clipped * advantages
# Take the minimum (pessimistic bound)
loss = -torch.min(term1, term2).mean()
    return loss

Why Does Clipping Work?
The intuition behind clipping
Consider what happens during optimization:
- Early in training: ratio is close to 1, no clipping, normal gradients
- Policy starts changing: ratio moves away from 1
- Ratio hits clip boundary: gradient becomes zero for that sample
- Effect: Policy can’t change more than the clip allows per batch
The clip essentially says: "I trust the old policy's data to tell me about changes up to a factor of $1 \pm \epsilon$. Beyond that, I shouldn't extrapolate."
This is a soft version of TRPO’s hard constraint. Instead of explicitly limiting KL divergence, we limit how much the probability of any action can change.
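To check this behavior numerically, here is a small gradient probe built on the compute_ppo_loss function above (the specific numbers are made up): with a positive advantage and $\epsilon = 0.2$, a ratio inside the clip range should yield a nonzero gradient, while a ratio outside it should yield a zero gradient.

import torch

log_prob_old = torch.tensor([0.0])   # log pi_old(a|s) = 0, chosen purely for illustration
advantage = torch.tensor([1.0])      # positive advantage

for target_ratio in [1.1, 1.5]:      # inside vs. outside the clip range [0.8, 1.2]
    # Choose log pi_new so that exp(log_new - log_old) equals the desired ratio
    log_prob_new = torch.log(torch.tensor([target_ratio])).requires_grad_(True)
    loss = compute_ppo_loss(log_prob_old, log_prob_new, advantage, epsilon=0.2)
    loss.backward()
    print(f"ratio={target_ratio}: grad wrt log_prob_new = {log_prob_new.grad.item():.3f}")
# Expected: a nonzero gradient at ratio 1.1, a zero gradient at ratio 1.5 (clipped away)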
Generalized Advantage Estimation (GAE)
PPO typically uses GAE for advantage estimation, which we mentioned in the actor-critic chapter.
Mathematical Details
GAE combines n-step returns with exponential weighting:

$$\hat{A}_t^{\text{GAE}(\gamma,\lambda)} = \sum_{l=0}^{\infty} (\gamma\lambda)^l\, \delta_{t+l}$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.
This can be computed recursively:

$$\hat{A}_t = \delta_t + \gamma\lambda\, \hat{A}_{t+1}$$

Parameters:
- $\lambda = 0$: Pure TD, $\hat{A}_t = \delta_t$ (low variance, some bias)
- $\lambda = 1$: Monte Carlo, $\hat{A}_t = \sum_{l \ge 0} \gamma^l \delta_{t+l} = G_t - V(s_t)$ (no bias, high variance)
Typical value: $\lambda = 0.95$ balances bias and variance well.
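As a quick illustration of the recursion (the TD errors here are made up), take $\gamma = 0.99$ and $\lambda = 0.95$, so $\gamma\lambda \approx 0.94$, with three timesteps, no episode boundary, and TD errors $\delta_0 = 1.0$, $\delta_1 = 0.5$, $\delta_2 = -0.2$. Working backwards:
- $\hat{A}_2 = \delta_2 = -0.2$ (the last step in this toy example, so nothing to bootstrap from)
- $\hat{A}_1 = \delta_1 + \gamma\lambda\,\hat{A}_2 = 0.5 + 0.94 \times (-0.2) \approx 0.31$
- $\hat{A}_0 = \delta_0 + \gamma\lambda\,\hat{A}_1 = 1.0 + 0.94 \times 0.31 \approx 1.29$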
Implementation
def compute_gae(rewards, values, dones, gamma=0.99, gae_lambda=0.95):
"""
Compute Generalized Advantage Estimation.
Args:
rewards: Tensor of rewards [T]
values: Tensor of value estimates [T+1] (includes bootstrap)
dones: Tensor of done flags [T]
gamma: Discount factor
gae_lambda: GAE lambda parameter
Returns:
advantages: Tensor of advantage estimates [T]
returns: Tensor of return targets for value function [T]
"""
T = len(rewards)
advantages = torch.zeros(T)
# Compute GAE backwards
gae = 0
for t in reversed(range(T)):
if dones[t]:
next_value = 0
gae = 0 # Reset GAE at episode boundary
else:
next_value = values[t + 1]
delta = rewards[t] + gamma * next_value - values[t]
gae = delta + gamma * gae_lambda * gae
advantages[t] = gae
# Returns = advantages + values (for value function training)
returns = advantages + values[:-1]
    return advantages, returns

Complete PPO Implementation
Let’s put it all together into a complete PPO agent.
Implementation
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import gymnasium as gym
class PPONetwork(nn.Module):
"""Actor-Critic network for PPO."""
def __init__(self, state_dim, n_actions, hidden_dim=64):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh()
)
self.actor = nn.Linear(hidden_dim, n_actions)
self.critic = nn.Linear(hidden_dim, 1)
def forward(self, state):
features = self.shared(state)
return self.actor(features), self.critic(features)
def get_action_and_value(self, state, action=None):
logits, value = self.forward(state)
dist = Categorical(logits=logits)
if action is None:
action = dist.sample()
return action, dist.log_prob(action), dist.entropy(), value.squeeze(-1)
class PPO:
"""Proximal Policy Optimization algorithm."""
def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99,
gae_lambda=0.95, clip_epsilon=0.2, epochs=10,
value_coef=0.5, entropy_coef=0.01):
"""
Args:
state_dim: Dimension of state space
n_actions: Number of discrete actions
lr: Learning rate
gamma: Discount factor
gae_lambda: GAE lambda parameter
clip_epsilon: PPO clipping parameter
epochs: Number of optimization epochs per batch
value_coef: Coefficient for value loss
entropy_coef: Coefficient for entropy bonus
"""
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_epsilon = clip_epsilon
self.epochs = epochs
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.network = PPONetwork(state_dim, n_actions)
self.optimizer = optim.Adam(self.network.parameters(), lr=lr, eps=1e-5)
def collect_rollout(self, env, n_steps):
"""Collect experience from environment."""
states = []
actions = []
rewards = []
dones = []
log_probs = []
values = []
state, _ = env.reset()
for _ in range(n_steps):
state_tensor = torch.FloatTensor(state).unsqueeze(0)
with torch.no_grad():
action, log_prob, _, value = self.network.get_action_and_value(state_tensor)
states.append(state)
actions.append(action.item())
log_probs.append(log_prob.item())
values.append(value.item())
next_state, reward, terminated, truncated, _ = env.step(action.item())
done = terminated or truncated
rewards.append(reward)
dones.append(done)
if done:
state, _ = env.reset()
else:
state = next_state
# Get value of final state for bootstrapping
with torch.no_grad():
final_state = torch.FloatTensor(state).unsqueeze(0)
_, _, _, final_value = self.network.get_action_and_value(final_state)
final_value = final_value.item()
# Convert to tensors
states = torch.FloatTensor(np.array(states))
actions = torch.LongTensor(actions)
rewards = torch.FloatTensor(rewards)
dones = torch.BoolTensor(dones)
log_probs = torch.FloatTensor(log_probs)
values = torch.FloatTensor(values + [final_value])
return states, actions, rewards, dones, log_probs, values
def compute_advantages(self, rewards, values, dones):
"""Compute GAE advantages."""
T = len(rewards)
advantages = torch.zeros(T)
gae = 0
for t in reversed(range(T)):
if dones[t]:
next_value = 0
gae = 0
else:
next_value = values[t + 1]
delta = rewards[t] + self.gamma * next_value - values[t]
gae = delta + self.gamma * self.gae_lambda * gae
advantages[t] = gae
returns = advantages + values[:-1]
return advantages, returns
def update(self, states, actions, old_log_probs, advantages, returns):
"""Perform PPO update for multiple epochs."""
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
total_policy_loss = 0
total_value_loss = 0
total_entropy = 0
for _ in range(self.epochs):
# Get current policy outputs
_, new_log_probs, entropy, new_values = self.network.get_action_and_value(
states, actions
)
# Compute probability ratio
ratio = torch.exp(new_log_probs - old_log_probs)
# Clipped policy loss
term1 = ratio * advantages
term2 = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * advantages
policy_loss = -torch.min(term1, term2).mean()
# Value loss
value_loss = nn.functional.mse_loss(new_values, returns)
# Entropy bonus
entropy_loss = -entropy.mean()
# Combined loss
loss = policy_loss + self.value_coef * value_loss + self.entropy_coef * entropy_loss
# Update
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.network.parameters(), max_norm=0.5)
self.optimizer.step()
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
total_entropy += -entropy_loss.item()
return {
'policy_loss': total_policy_loss / self.epochs,
'value_loss': total_value_loss / self.epochs,
'entropy': total_entropy / self.epochs
}
def train_ppo(env_name='CartPole-v1', total_steps=100000, n_steps=2048):
"""Train PPO agent."""
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
n_actions = env.action_space.n
agent = PPO(state_dim, n_actions)
# For tracking
episode_rewards = []
current_episode_reward = 0
steps_done = 0
while steps_done < total_steps:
# Collect rollout
states, actions, rewards, dones, log_probs, values = agent.collect_rollout(env, n_steps)
steps_done += n_steps
# Track episode rewards
for r, d in zip(rewards, dones):
current_episode_reward += r.item()
if d:
episode_rewards.append(current_episode_reward)
current_episode_reward = 0
# Compute advantages
advantages, returns = agent.compute_advantages(rewards, values, dones)
# Update policy
metrics = agent.update(states, actions, log_probs, advantages, returns)
# Log progress
        # Log roughly every 10,000 steps (steps_done advances in increments of n_steps)
        if episode_rewards and steps_done // 10000 > (steps_done - n_steps) // 10000:
avg_reward = np.mean(episode_rewards[-100:])
print(f"Steps: {steps_done}, Avg Reward: {avg_reward:.1f}")
return agent, episode_rewards
# Train the agent
agent, rewards = train_ppo('CartPole-v1', total_steps=50000)

Training on LunarLander
Implementation
# Train on a more challenging environment
agent, rewards = train_ppo('LunarLander-v2', total_steps=500000, n_steps=2048)
# Plot learning curve
import matplotlib.pyplot as plt
window = 50
smoothed = np.convolve(rewards, np.ones(window)/window, mode='valid')
plt.figure(figsize=(10, 6))
plt.plot(smoothed)
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('PPO on LunarLander-v2')
plt.axhline(y=200, color='r', linestyle='--', label='Solved threshold')
plt.legend()
plt.show()

PPO hyperparameters that typically work well (an example instantiation follows the list):
- clip_epsilon: 0.1 to 0.2
- epochs: 3 to 10 per batch
- gae_lambda: 0.95
- learning rate: 3e-4 with Adam
- batch size: 2048+ steps for stability
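These defaults map directly onto the PPO class defined earlier. A typical instantiation might look like the following sketch (the state and action dimensions shown are those of LunarLander-v2; the values are starting points, not tuned results):

# Suggested starting hyperparameters wired into the PPO class from this chapter
agent = PPO(
    state_dim=8,          # LunarLander-v2 observation dimension
    n_actions=4,          # LunarLander-v2 has 4 discrete actions
    lr=3e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_epsilon=0.2,
    epochs=10,
    value_coef=0.5,
    entropy_coef=0.01,
)
# Collect rollouts of 2048+ steps per update for stable advantage estimates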
Why PPO Is Popular
PPO has become the default choice for many RL applications because:
- Simple to implement: Just a clipped loss function—no Hessians, no conjugate gradients
- Stable: The clipping naturally prevents catastrophic updates
- Sample efficient: Multiple epochs of updates on the same batch of data (something vanilla policy gradient methods cannot safely do)
- Good performance: Competitive with or better than more complex methods
- Robust to hyperparameters: Works well with default settings on many tasks
PPO is used in:
- OpenAI Five (Dota 2)
- ChatGPT training (RLHF)
- Robotics research
- Most RL benchmarks
PPO isn’t always the best choice. For very sample-efficient learning, off-policy methods like SAC might be better. For very simple tasks, simpler methods work fine. But when in doubt, PPO is a solid default.
Summary
Key Takeaways
- Trust regions limit how much the policy can change per update, preventing instability
- TRPO uses a hard KL divergence constraint but is complex to implement
- PPO achieves similar stability with a simple clipped objective
- The probability ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_\text{old}}(a_t \mid s_t)$ measures policy change
- Clipping at $1 \pm \epsilon$ removes gradient signal when the policy changes too much
- GAE (Generalized Advantage Estimation) balances bias and variance with the parameter $\lambda$
- PPO allows multiple epochs on the same data, improving sample efficiency
- PPO is simple, stable, and effective—the go-to algorithm for many applications
You now have a complete understanding of policy gradient methods, from the basic REINFORCE algorithm through the state-of-the-art PPO. These tools power many of the most impressive RL applications today.
In the next chapter, we’ll see these methods applied to real-world challenges: robotics, language models, and more.
Exercises
Conceptual Questions
-
Explain why large policy updates can be harmful. Give a concrete example of what could go wrong.
-
What does the probability ratio $r_t(\theta)$ measure? What does it mean when $r_t(\theta) > 1$? When $r_t(\theta) < 1$?
-
Why does clipping prevent large policy changes? Trace through what happens to the gradient when the ratio exceeds the clip bounds.
-
Compare TRPO and PPO. What does each do to constrain policy updates? Why is PPO more popular?
Coding Challenges
-
Implement PPO from scratch and train on CartPole. Compare learning curves to A2C from the previous chapter.
-
Experiment with clip values. Try several values of $\epsilon$, from very small (e.g., 0.05) to very large (e.g., 0.5). What happens at the extremes? Is there a "best" value?
-
Implement GAE and compare with simple TD advantage estimation. How does $\lambda$ affect learning?
Exploration
-
Multiple epochs per batch. PPO typically uses 3-10 epochs of gradient descent on the same batch. Experiment with 1, 5, 10, and 20 epochs. What do you observe? Why might too many epochs be harmful?
-
KL divergence monitoring. Add code to track the average KL divergence between old and new policies during training. How does it correlate with the clip ratio? What happens when KL gets too large?