The PPO Algorithm
PPO (Proximal Policy Optimization) achieves TRPO-like stability with nothing more than clipping and standard gradient descent. This section explains exactly how the clipped objective works and presents the full algorithm.
The Core Idea: Clipping
PPO uses a clipped surrogate objective:

L^CLIP(θ) = E_t[ min( r_t(θ) Â_t, clip(r_t(θ), 1 - ε, 1 + ε) Â_t ) ]

Where:
- r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio
- Â_t is the advantage estimate
- ε is the clip range (typically 0.2)
The clipped objective asks: “What’s the minimum of the clipped and unclipped objectives?”
This creates a pessimistic bound. If the policy wants to change too much, the clipped version gives a lower (worse) objective, and the minimum operation picks that.
When the minimum selects the clipped term, its gradient with respect to θ is zero, so the policy stops updating once the ratio hits the clip boundary.
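This zero-gradient behavior is easy to check with autograd. A minimal sketch (clipped_surrogate is a helper written here for illustration, not part of the implementation below):

```python
import torch

def clipped_surrogate(ratio: torch.Tensor, advantage: float, eps: float = 0.2) -> torch.Tensor:
    """Per-sample clipped objective: min(r * A, clip(r, 1-eps, 1+eps) * A)."""
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage
    return torch.min(unclipped, clipped)

# Inside the clip range: the gradient w.r.t. the ratio is the advantage.
r = torch.tensor(1.1, requires_grad=True)
clipped_surrogate(r, advantage=2.0).backward()
print(r.grad.item())  # 2.0

# Past the boundary: the min selects the clipped (constant) term.
r = torch.tensor(1.3, requires_grad=True)
clipped_surrogate(r, advantage=2.0).backward()
print(r.grad.item())  # 0.0
```

The same min/clamp pattern appears in the update method of the full implementation below.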
Understanding the Clip
Let’s analyze the two cases:
Case 1: Positive advantage (Â_t > 0)
The action was better than expected. We want to increase its probability, which means increasing r_t(θ).
- Unclipped: r_t(θ) Â_t - increases with r_t(θ)
- Clipped: clip(r_t(θ), 1 - ε, 1 + ε) Â_t - caps at (1 + ε) Â_t
The minimum:
- When r_t(θ) ≤ 1 + ε: uses unclipped (gradient exists)
- When r_t(θ) > 1 + ε: uses clipped (gradient is zero)
Result: We can increase r_t(θ) up to 1 + ε, then we stop.
Case 2: Negative advantage (Â_t < 0)
The action was worse than expected. We want to decrease its probability, which means decreasing r_t(θ).
- Unclipped: r_t(θ) Â_t - becomes less negative (better) as r_t(θ) decreases
- Clipped: clip(r_t(θ), 1 - ε, 1 + ε) Â_t - caps at (1 - ε) Â_t
The minimum:
- When r_t(θ) ≥ 1 - ε: uses unclipped (gradient exists)
- When r_t(θ) < 1 - ε: uses clipped (gradient is zero)
Result: We can decrease r_t(θ) down to 1 - ε, then we stop.
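Both cases can be traced numerically. A small sketch (ppo_objective is a throwaway helper defined only for this illustration):

```python
import numpy as np

def ppo_objective(r: float, adv: float, eps: float = 0.2):
    """Return the per-sample clipped objective and which branch the min picked."""
    unclipped = r * adv
    clipped = float(np.clip(r, 1 - eps, 1 + eps)) * adv
    objective = min(unclipped, clipped)
    return objective, "clipped" if clipped < unclipped else "unclipped"

# Case 1 (adv > 0): the objective tracks r until r > 1 + eps, then goes flat.
for r in (0.9, 1.0, 1.2, 1.4):
    print("A>0", r, ppo_objective(r, adv=+1.0))

# Case 2 (adv < 0): the objective goes flat once r < 1 - eps.
for r in (0.6, 0.8, 1.0, 1.2):
    print("A<0", r, ppo_objective(r, adv=-1.0))
```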
Think of it as a one-way gate:
- Good actions (Â_t > 0): Can increase probability up to 20% more, then stop
- Bad actions (Â_t < 0): Can decrease probability down to 20% less, then stop
The policy can’t run away in either direction. It’s bounded to stay close to the old policy.
Visualizing the Objective
Picture the objective as a function of the ratio r_t(θ):

For positive advantage (Â_t > 0):

  Objective
      |
      |              ________   <- clipped at r = 1+eps
      |             /
      |            /
      |           /
      +----------------------- r
          1-eps   1   1+eps

For negative advantage (Â_t < 0):

  Objective
      |
      |_______
      |       \
      |        \
      |         \   <- clipped (flat) for r < 1-eps
      |
      +----------------------- r
          1-eps   1   1+eps

The flat regions have zero gradient - the policy stops updating there.
The Full PPO Objective
The complete PPO loss combines three terms:

Loss(θ) = -L^CLIP(θ) + c_1 L^VF(θ) - c_2 S[π_θ]

Where:

Policy loss (clipped surrogate):
L^CLIP(θ) = E_t[ min( r_t(θ) Â_t, clip(r_t(θ), 1 - ε, 1 + ε) Â_t ) ]

Value function loss:
L^VF(θ) = E_t[ (V_θ(s_t) - V_t^target)² ]

Entropy bonus:
S[π_θ] = E_t[ H(π_θ(·|s_t)) ]

Typical coefficients: c_1 = 0.5, c_2 = 0.01
The PPO Algorithm
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple
class ActorCritic(nn.Module):
"""Shared actor-critic network for PPO."""
def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 64):
super().__init__()
self.shared = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
)
self.actor = nn.Linear(hidden_dim, action_dim)
self.critic = nn.Linear(hidden_dim, 1)
def forward(self, state):
features = self.shared(state)
action_logits = self.actor(features)
value = self.critic(features).squeeze(-1)
return action_logits, value
def get_action(self, state):
"""Sample action and return action, log_prob, value."""
action_logits, value = self.forward(state)
dist = torch.distributions.Categorical(logits=action_logits)
action = dist.sample()
log_prob = dist.log_prob(action)
return action, log_prob, value
def evaluate_actions(self, states, actions):
"""Evaluate log probs and values for given state-action pairs."""
action_logits, values = self.forward(states)
dist = torch.distributions.Categorical(logits=action_logits)
log_probs = dist.log_prob(actions)
entropy = dist.entropy()
return log_probs, values, entropy
class PPO:
"""Proximal Policy Optimization implementation."""
def __init__(
self,
state_dim: int,
action_dim: int,
lr: float = 3e-4,
gamma: float = 0.99,
gae_lambda: float = 0.95,
clip_epsilon: float = 0.2,
value_coef: float = 0.5,
entropy_coef: float = 0.01,
max_grad_norm: float = 0.5,
n_epochs: int = 10,
batch_size: int = 64,
):
self.gamma = gamma
self.gae_lambda = gae_lambda
self.clip_epsilon = clip_epsilon
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm
self.n_epochs = n_epochs
self.batch_size = batch_size
self.network = ActorCritic(state_dim, action_dim)
self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
def compute_gae(
self, rewards: List, values: List, dones: List, next_value: float
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute GAE advantages and returns."""
advantages = []
gae = 0
# Append next_value for easier indexing
values = values + [next_value]
for t in reversed(range(len(rewards))):
if dones[t]:
delta = rewards[t] - values[t]
gae = delta
else:
delta = rewards[t] + self.gamma * values[t + 1] - values[t]
gae = delta + self.gamma * self.gae_lambda * gae
advantages.insert(0, gae)
advantages = torch.tensor(advantages, dtype=torch.float32)
returns = advantages + torch.tensor(values[:-1], dtype=torch.float32)
return advantages, returns
def update(
self,
states: torch.Tensor,
actions: torch.Tensor,
old_log_probs: torch.Tensor,
returns: torch.Tensor,
advantages: torch.Tensor,
) -> Dict[str, float]:
"""Perform PPO update with multiple epochs."""
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Store losses for logging
total_policy_loss = 0
total_value_loss = 0
total_entropy = 0
n_updates = 0
for _ in range(self.n_epochs):
# Create mini-batches
indices = np.random.permutation(len(states))
for start in range(0, len(states), self.batch_size):
end = start + self.batch_size
batch_indices = indices[start:end]
batch_states = states[batch_indices]
batch_actions = actions[batch_indices]
batch_old_log_probs = old_log_probs[batch_indices]
batch_returns = returns[batch_indices]
batch_advantages = advantages[batch_indices]
# Get current policy outputs
new_log_probs, values, entropy = self.network.evaluate_actions(
batch_states, batch_actions
)
# Compute probability ratio
ratio = torch.exp(new_log_probs - batch_old_log_probs)
# Clipped surrogate objective
surr1 = ratio * batch_advantages
surr2 = (
torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
* batch_advantages
)
policy_loss = -torch.min(surr1, surr2).mean()
# Value loss
value_loss = F.mse_loss(values, batch_returns)
# Entropy bonus
entropy_loss = -entropy.mean()
# Total loss
loss = (
policy_loss
+ self.value_coef * value_loss
+ self.entropy_coef * entropy_loss
)
# Gradient step
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(
self.network.parameters(), self.max_grad_norm
)
self.optimizer.step()
# Accumulate for logging
total_policy_loss += policy_loss.item()
total_value_loss += value_loss.item()
total_entropy += entropy.mean().item()
n_updates += 1
return {
"policy_loss": total_policy_loss / n_updates,
"value_loss": total_value_loss / n_updates,
"entropy": total_entropy / n_updates,
        }

Multiple Epochs: Reusing Experience
A key feature of PPO is running multiple optimization epochs on the same batch of experience.
In vanilla policy gradient (REINFORCE), each batch is used once then discarded. PPO can reuse each batch for 3-10 epochs because the clipping prevents the policy from drifting too far.
This dramatically improves sample efficiency - we extract more learning from each environment interaction.
Too many epochs can cause problems:
- Overfitting to the batch: The policy becomes too specialized to this particular batch
- KL divergence grows: Eventually the old data becomes stale
- Clipping triggers everywhere: All ratios hit the clip boundary, gradients go to zero
If you see policy loss going to zero with no improvement, you might be using too many epochs.
Practical epoch guidelines:
- Start with 3-5 epochs for simple environments
- Use 10 epochs for more complex problems
- Monitor the ratio - if most samples are getting clipped, reduce epochs
- Some implementations add early stopping when KL gets too large
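The last two guidelines can be monitored with two small diagnostics. A sketch, assuming the log-prob tensors from the update loop above (approx_kl and clip_fraction are illustrative helpers, and the estimator E[(r - 1) - log r] is one common low-variance choice; thresholds like 1.5 × a target KL vary by implementation):

```python
import torch

def approx_kl(old_log_probs: torch.Tensor, new_log_probs: torch.Tensor) -> float:
    """Low-variance estimator of KL(old || new): mean of (r - 1) - log r."""
    log_ratio = new_log_probs - old_log_probs
    return ((log_ratio.exp() - 1) - log_ratio).mean().item()

def clip_fraction(old_log_probs: torch.Tensor, new_log_probs: torch.Tensor,
                  eps: float = 0.2) -> float:
    """Fraction of samples whose ratio has left the [1-eps, 1+eps] band."""
    ratio = (new_log_probs - old_log_probs).exp()
    return ((ratio - 1.0).abs() > eps).float().mean().item()

# Example: two of the four ratios below (1.3 and 0.7) are outside the band.
old = torch.log(torch.tensor([0.5, 0.5, 0.5, 0.5]))
new = torch.log(torch.tensor([0.5, 0.55, 0.65, 0.35]))
print(f"approx KL:     {approx_kl(old, new):.4f}")
print(f"clip fraction: {clip_fraction(old, new):.2f}")  # 0.50
```

For early stopping, one would compute approx_kl on each mini-batch inside the epoch loop and break out when it exceeds the chosen threshold; a rising clip fraction is the "clipping triggers everywhere" symptom described above.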
The Training Loop
import gymnasium as gym
def train_ppo(
env_name: str = "CartPole-v1",
total_timesteps: int = 100000,
rollout_length: int = 2048,
**kwargs
):
"""Complete PPO training loop."""
env = gym.make(env_name)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPO(state_dim, action_dim, **kwargs)
state, _ = env.reset()
episode_rewards = []
current_episode_reward = 0
for step in range(0, total_timesteps, rollout_length):
# Collect rollout
states, actions, rewards, dones = [], [], [], []
log_probs, values = [], []
for _ in range(rollout_length):
state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
action, log_prob, value = agent.network.get_action(state_tensor)
next_state, reward, terminated, truncated, _ = env.step(action.item())
done = terminated or truncated
states.append(state_tensor.squeeze(0))
actions.append(action.item())
rewards.append(reward)
dones.append(done)
log_probs.append(log_prob.item())
values.append(value.item())
current_episode_reward += reward
state = next_state
if done:
episode_rewards.append(current_episode_reward)
current_episode_reward = 0
state, _ = env.reset()
# Compute value of final state for bootstrapping
with torch.no_grad():
state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
_, _, next_value = agent.network.get_action(state_tensor)
next_value = next_value.item() if not dones[-1] else 0
# Compute advantages and returns
advantages, returns = agent.compute_gae(rewards, values, dones, next_value)
# Convert to tensors
states_tensor = torch.stack(states)
actions_tensor = torch.tensor(actions, dtype=torch.long)
old_log_probs_tensor = torch.tensor(log_probs, dtype=torch.float32)
# Update policy
losses = agent.update(
states_tensor,
actions_tensor,
old_log_probs_tensor,
returns,
advantages,
)
# Logging
if len(episode_rewards) > 0:
recent_avg = np.mean(episode_rewards[-100:])
print(
f"Step {step + rollout_length}, "
f"Avg Reward: {recent_avg:.2f}, "
f"Policy Loss: {losses['policy_loss']:.4f}, "
f"Entropy: {losses['entropy']:.4f}"
)
env.close()
return agent, episode_rewards
if __name__ == "__main__":
    agent, rewards = train_ppo()

Comparing Clipped vs. Unclipped
Let’s trace through a concrete example:
Setup: ε = 0.2, advantage Â = 2, old log prob = -1

After 1 epoch: ratio = 1.1 (new policy slightly prefers this action)
- Unclipped objective: 1.1 × 2 = 2.2
- Clipped objective: 1.1 × 2 = 2.2 (not clipped yet)
- Minimum: 2.2, gradient exists, keep optimizing

After 5 epochs: ratio = 1.25 (new policy strongly prefers this action)
- Unclipped objective: 1.25 × 2 = 2.5
- Clipped objective: 1.2 × 2 = 2.4 (ratio clipped to 1.2)
- Minimum: 2.4, gradient is zero for the clipped term
At this point, further increasing the ratio doesn’t help - we’ve hit the trust region boundary.
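The trace can be reproduced in a few lines of plain Python (trace is a helper written just for this check, using the numbers above):

```python
eps, advantage = 0.2, 2.0

def trace(ratio: float):
    """Return (unclipped, clipped, minimum) for one sample of the objective."""
    unclipped = ratio * advantage
    clipped = max(1 - eps, min(ratio, 1 + eps)) * advantage
    return unclipped, clipped, min(unclipped, clipped)

print(trace(1.1))   # (2.2, 2.2, 2.2) - inside the clip range, gradient flows
print(trace(1.25))  # (2.5, 2.4, 2.4) - the clipped, constant term wins
```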
Summary
The PPO algorithm combines:
- Clipped surrogate objective: Prevents destructively large updates
- Multiple epochs: Extracts more learning from each batch
- GAE advantages: Low-variance gradient estimates
- Standard gradient descent: No second-order methods needed
The result is an algorithm that’s simple to implement, stable to train, and effective across a wide range of problems.