PPO in Practice
PPO is robust, but getting the best results requires understanding its hyperparameters and avoiding common pitfalls. This section provides practical guidance for implementing and tuning PPO.
Key Hyperparameters
Here’s a quick reference for PPO hyperparameters:
| Parameter | Typical Range | Default | Effect |
|---|---|---|---|
| Clip epsilon | 0.1 - 0.3 | 0.2 | Larger = more aggressive updates |
| Learning rate | 1e-4 to 3e-3 | 3e-4 | Step size for policy and value updates |
| Epochs per batch | 3 - 10 | 4 | More = better sample efficiency, risk of overfitting |
| Mini-batch size | 32 - 2048 | 64 | Larger = more stable gradients |
| Rollout length | 128 - 4096 | 2048 | Steps collected before each update |
| GAE lambda | 0.9 - 0.99 | 0.95 | Bias-variance tradeoff in advantages |
| Gamma | 0.99 - 0.999 | 0.99 | Discount factor |
| Value coefficient | 0.25 - 1.0 | 0.5 | Weight of value loss |
| Entropy coefficient | 0.0 - 0.05 | 0.01 | Exploration bonus |
| Max grad norm | 0.5 - 1.0 | 0.5 | Gradient clipping |
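For reference, the defaults in this table can be gathered into a single configuration object. The PPOConfig name and field names below are illustrative, not from any particular library:

```python
from dataclasses import dataclass

@dataclass
class PPOConfig:
    """Default PPO hyperparameters from the table above."""
    clip_eps: float = 0.2
    lr: float = 3e-4
    epochs: int = 4
    mini_batch_size: int = 64
    rollout_length: int = 2048
    gae_lambda: float = 0.95
    gamma: float = 0.99
    vf_coef: float = 0.5
    ent_coef: float = 0.01
    max_grad_norm: float = 0.5
```

Keeping the defaults in one place makes it easy to log the exact configuration alongside each run.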
Understanding Each Hyperparameter
Clip Epsilon
The clip range controls how much the policy can change per update.
Smaller epsilon (0.1):
- More conservative updates
- Higher stability
- Slower learning
- Good for sensitive tasks
Larger epsilon (0.3):
- More aggressive updates
- Faster learning
- Higher risk of instability
- Good for simple tasks
Rule of thumb: Start with 0.2. If training is unstable, decrease to 0.1. If learning is too slow, try 0.3.
```python
# Adaptive clipping (optional)
def adaptive_clip_epsilon(episode, initial_eps=0.2, final_eps=0.1, decay_episodes=1000):
    """Linearly decay clip epsilon over training."""
    progress = min(episode / decay_episodes, 1.0)
    return initial_eps - progress * (initial_eps - final_eps)

# Usage
epsilon = adaptive_clip_epsilon(current_episode)
```

Learning Rate
Learning rate in RL is typically lower than in supervised learning because:
- Gradients are noisier (estimated from samples)
- The objective changes as the policy updates
- Stability matters more than speed
Common schedule: Linear decay from initial value to zero over training.
```python
class LinearSchedule:
    """Linear learning rate decay."""

    def __init__(self, initial_lr, final_lr, total_steps):
        self.initial_lr = initial_lr
        self.final_lr = final_lr
        self.total_steps = total_steps

    def get_lr(self, step):
        progress = min(step / self.total_steps, 1.0)
        return self.initial_lr + progress * (self.final_lr - self.initial_lr)

# Usage with PyTorch optimizer
def update_learning_rate(optimizer, new_lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

# In training loop
schedule = LinearSchedule(3e-4, 0, total_timesteps)
current_lr = schedule.get_lr(current_step)
update_learning_rate(optimizer, current_lr)
```

Number of Epochs
More epochs extract more learning from each batch, but too many causes problems:
Too few epochs (1-2):
- Underutilize collected data
- Lower sample efficiency
- Safe but slow
Too many epochs (15+):
- Overfitting to batch
- Policy drifts too far from data distribution
- Clipping triggers everywhere
- Training can stall
Signs you need fewer epochs:
- Clip fraction approaches 1.0 (most samples are clipped)
- Policy loss goes to zero with no improvement
- KL divergence grows very large
```python
def ppo_update_with_early_stopping(
    network, optimizer, states, actions, old_log_probs, returns, advantages,
    clip_eps=0.2, target_kl=0.01, max_epochs=10
):
    """PPO update with early stopping based on KL divergence."""
    for epoch in range(max_epochs):
        # ... standard PPO update code ...

        # Compute approximate KL divergence
        with torch.no_grad():
            new_log_probs, _, _ = network.evaluate_actions(states, actions)
            approx_kl = (old_log_probs - new_log_probs).mean().item()
        if approx_kl > 1.5 * target_kl:
            print(f"Early stopping at epoch {epoch} due to KL = {approx_kl:.4f}")
            break
    return epoch + 1  # Number of epochs actually run
```

Advantage Normalization
Always normalize advantages! This is one of the most impactful implementation details.
Without normalization:
- Advantages can have large magnitude differences across batches
- Gradient magnitudes vary wildly
- Learning rate that works for one batch may be wrong for another
With normalization:
- Consistent gradient magnitudes
- More stable learning
- Less sensitivity to reward scaling
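Normalization assumes you already have raw advantages; in PPO these are typically produced by GAE(λ) (the GAE lambda entry in the hyperparameter table). A minimal sketch of that computation, with an illustrative compute_gae helper:

```python
import numpy as np

def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Compute GAE(lambda) advantages and returns for one rollout.

    rewards, values, dones: per-step arrays of length T.
    last_value: bootstrap value estimate for the state after the rollout.
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float64)
    gae = 0.0
    # Work backwards: each advantage accumulates discounted TD errors
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - dones[t]  # zero out bootstrap at episode ends
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        gae = delta + gamma * lam * next_nonterminal * gae
        advantages[t] = gae
    returns = advantages + np.asarray(values)  # targets for the value function
    return advantages, returns
```

The same advantages also yield the value targets (returns = advantages + values), so one pass over the rollout produces both.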
```python
def normalize_advantages(advantages):
    """Normalize advantages to mean 0, std 1."""
    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)

# Always apply before computing policy loss
advantages = normalize_advantages(advantages)
policy_loss = -(ratio * advantages).mean()
```

Implementation Tricks
Orthogonal Initialization
Orthogonal initialization with small gain for policy output helps stability:
```python
def init_weights(module, gain=1.0):
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight, gain=gain)
        if module.bias is not None:
            module.bias.data.fill_(0.0)

# Apply to network
network.apply(lambda m: init_weights(m, gain=np.sqrt(2)))

# Policy head with smaller gain
init_weights(network.actor, gain=0.01)

# Value head
init_weights(network.critic, gain=1.0)
```

The small gain (0.01) for the policy head ensures the initial policy is close to uniform, which helps exploration.
Value Function Clipping (Optional)
Some implementations clip the value function update similarly to the policy:
```python
def clipped_value_loss(values, old_values, returns, clip_eps=0.2):
    """Clipped value function loss."""
    # Unclipped loss
    value_loss_unclipped = (values - returns) ** 2
    # Clipped value
    values_clipped = old_values + torch.clamp(
        values - old_values, -clip_eps, clip_eps
    )
    value_loss_clipped = (values_clipped - returns) ** 2
    # Take maximum (pessimistic)
    return 0.5 * torch.max(value_loss_unclipped, value_loss_clipped).mean()
```

This prevents the value function from changing too drastically. However, the benefit is debated; many implementations skip it.
Observation Normalization
For continuous observation spaces, normalize observations with running statistics:
```python
class RunningMeanStd:
    """Tracks running mean and standard deviation."""

    def __init__(self, shape):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = 1e-4

    def update(self, x):
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]

        delta = batch_mean - self.mean
        total_count = self.count + batch_count

        self.mean = self.mean + delta * batch_count / total_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        M2 = m_a + m_b + delta ** 2 * self.count * batch_count / total_count
        self.var = M2 / total_count
        self.count = total_count

    def normalize(self, x):
        return (x - self.mean) / (np.sqrt(self.var) + 1e-8)

# Usage
obs_normalizer = RunningMeanStd(state_dim)
obs_normalizer.update(batch_of_observations)
normalized_obs = obs_normalizer.normalize(observation)
```

This stabilizes learning when observations have different scales across dimensions.
Reward Scaling
Reward scaling affects learning dynamics. Common approaches:
Reward normalization (divide by running std):
```python
reward_normalizer = RunningMeanStd(shape=())
reward_normalizer.update(rewards)
scaled_rewards = rewards / (np.sqrt(reward_normalizer.var) + 1e-8)
```

Reward clipping (prevent extreme values):

```python
clipped_rewards = np.clip(rewards, -10, 10)
```

Neither is always correct. If rewards are naturally bounded (e.g., 0-1), don’t scale. If they can be arbitrarily large, scaling helps.
Common Pitfalls
Pitfall 1: Forgetting to detach old values
When computing advantages, don’t backprop through the old value estimates:
```python
# Wrong - will backprop through old values
advantages = returns - values

# Correct - detach old values
with torch.no_grad():
    old_values = network.get_value(states)
advantages = returns - old_values  # or: returns - values.detach()
```

Pitfall 2: Computing ratio incorrectly
The ratio must use the same action:
```python
# Wrong - samples new action
new_action = policy.sample(state)
new_log_prob = policy.log_prob(state, new_action)
ratio = torch.exp(new_log_prob - old_log_prob)  # Comparing different actions!

# Correct - evaluates the same action
new_log_prob = policy.log_prob(state, old_action)
ratio = torch.exp(new_log_prob - old_log_prob)
```

Pitfall 3: Wrong loss sign
PPO maximizes the objective, but PyTorch minimizes loss:
```python
# Wrong - minimizes surrogate (bad!)
loss = torch.min(surr1, surr2).mean()

# Correct - minimizes negative surrogate (= maximizes surrogate)
loss = -torch.min(surr1, surr2).mean()
```

Debugging PPO
Metrics to monitor:
- Episode reward: Should increase over time
- Policy loss: Should fluctuate around zero (not monotonically decrease!)
- Value loss: Should decrease over time
- Entropy: Should slowly decrease as policy becomes more deterministic
- Clip fraction: Fraction of samples where ratio is clipped (should be 0.1-0.3)
- Approx KL: KL between old and new policy (should be around target_kl)
Signs of problems:
- Clip fraction near 1.0: Too many epochs or clip epsilon too small
- Entropy dropping to 0: Policy collapsed, increase entropy coefficient
- Value loss increasing: Learning rate too high or returns computed wrong
- Reward stuck: Exploration problem, increase entropy coefficient
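Two of these metrics, clip fraction and approximate KL, are cheap to compute from tensors already available during the update. A sketch with an illustrative ppo_diagnostics helper; the KL estimate here is the simple mean of old minus new log-probs, matching the early-stopping code earlier:

```python
import torch

def ppo_diagnostics(ratio, old_log_probs, new_log_probs, clip_eps=0.2):
    """Compute PPO health metrics for logging.

    ratio: exp(new_log_probs - old_log_probs), as in the surrogate objective.
    """
    with torch.no_grad():
        # Fraction of samples where the ratio left the clip range
        clip_fraction = ((ratio - 1.0).abs() > clip_eps).float().mean().item()
        # Simple approximate KL between old and new policy
        approx_kl = (old_log_probs - new_log_probs).mean().item()
    return {'clip_fraction': clip_fraction, 'approx_kl': approx_kl}
```

Logging these every update makes the failure modes above (clip fraction near 1.0, runaway KL) visible immediately rather than after reward curves flatline.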
When to Use PPO
Good use cases for PPO:
- General-purpose algorithm when you don’t know what will work
- Continuous action spaces (robotics, control)
- Discrete action spaces (games, decision-making)
- Large-scale training with parallelism
- When you need stable, predictable training
Consider alternatives when:
- Maximum sample efficiency needed (try SAC, TD3 for continuous)
- Discrete actions with large replay possible (try DQN variants)
- Model of environment available (try model-based methods)
- Very sparse rewards (try curiosity-driven methods)
Reference Implementation
Here’s a PPO implementation that puts these pieces together; it assumes an ActorCriticNet actor-critic module is defined elsewhere:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Dict

class PPOAgent:
    """Complete PPO implementation with best practices."""

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        continuous: bool = False,
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        epochs: int = 10,
        mini_batch_size: int = 64,
        target_kl: float = 0.03,
    ):
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.vf_coef = vf_coef
        self.ent_coef = ent_coef
        self.max_grad_norm = max_grad_norm
        self.epochs = epochs
        self.mini_batch_size = mini_batch_size
        self.target_kl = target_kl
        self.continuous = continuous

        # Network
        self.network = self._build_network(state_dim, action_dim, continuous)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr, eps=1e-5)

        # Running stats for observation normalization
        self.obs_rms = RunningMeanStd(shape=(state_dim,))

    def _build_network(self, state_dim, action_dim, continuous):
        """Build actor-critic network with orthogonal init."""
        network = ActorCriticNet(state_dim, action_dim, continuous)
        # Orthogonal initialization
        for module in network.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight, gain=np.sqrt(2))
                module.bias.data.zero_()
        # Small init for policy output
        nn.init.orthogonal_(network.actor[-1].weight, gain=0.01)
        return network

    def get_action(self, state: np.ndarray) -> Tuple[np.ndarray, float, float]:
        """Get action, log_prob, and value for a state."""
        # Normalize observation
        state = self.obs_rms.normalize(state)
        state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            if self.continuous:
                mean, std, value = self.network(state_t)
                dist = torch.distributions.Normal(mean, std)
                action = dist.sample()
                log_prob = dist.log_prob(action).sum(-1)
                action = action.squeeze(0).numpy()
            else:
                logits, value = self.network(state_t)
                dist = torch.distributions.Categorical(logits=logits)
                action = dist.sample()
                log_prob = dist.log_prob(action)
                action = action.item()
        return action, log_prob.item(), value.item()

    def update(self, rollout: Dict) -> Dict[str, float]:
        """Perform PPO update on collected rollout."""
        # Unpack rollout
        states = torch.tensor(rollout['states'], dtype=torch.float32)
        actions = torch.tensor(rollout['actions'])
        old_log_probs = torch.tensor(rollout['log_probs'], dtype=torch.float32)
        returns = torch.tensor(rollout['returns'], dtype=torch.float32)
        advantages = torch.tensor(rollout['advantages'], dtype=torch.float32)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Update observation statistics
        self.obs_rms.update(rollout['states'])

        # Training stats
        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        total_kl = 0
        n_updates = 0
        epochs_run = 0

        for epoch in range(self.epochs):
            # Shuffle and create mini-batches
            indices = np.random.permutation(len(states))
            for start in range(0, len(states), self.mini_batch_size):
                end = start + self.mini_batch_size
                mb_idx = indices[start:end]
                mb_states = states[mb_idx]
                mb_actions = actions[mb_idx]
                mb_old_log_probs = old_log_probs[mb_idx]
                mb_returns = returns[mb_idx]
                mb_advantages = advantages[mb_idx]

                # Forward pass
                if self.continuous:
                    mean, std, values = self.network(mb_states)
                    dist = torch.distributions.Normal(mean, std)
                    new_log_probs = dist.log_prob(mb_actions).sum(-1)
                    entropy = dist.entropy().sum(-1).mean()
                else:
                    logits, values = self.network(mb_states)
                    dist = torch.distributions.Categorical(logits=logits)
                    new_log_probs = dist.log_prob(mb_actions)
                    entropy = dist.entropy().mean()

                # Compute ratio
                ratio = torch.exp(new_log_probs - mb_old_log_probs)

                # Clipped surrogate objective
                surr1 = ratio * mb_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * mb_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss
                value_loss = F.mse_loss(values.squeeze(), mb_returns)

                # Total loss
                loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * entropy

                # Gradient step
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm)
                self.optimizer.step()

                # Stats
                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += entropy.item()
                n_updates += 1

            # Check KL for early stopping
            with torch.no_grad():
                if self.continuous:
                    mean, std, _ = self.network(states)
                    dist = torch.distributions.Normal(mean, std)
                    new_log_probs = dist.log_prob(actions).sum(-1)
                else:
                    logits, _ = self.network(states)
                    dist = torch.distributions.Categorical(logits=logits)
                    new_log_probs = dist.log_prob(actions)
                kl = (old_log_probs - new_log_probs).mean().item()
            total_kl = kl
            epochs_run = epoch + 1
            if kl > 1.5 * self.target_kl:
                break

        return {
            'policy_loss': total_policy_loss / n_updates,
            'value_loss': total_value_loss / n_updates,
            'entropy': total_entropy / n_updates,
            'kl': total_kl,
            'epochs': epochs_run,
        }
```

Summary
PPO in practice requires attention to:
- Hyperparameters: Start with defaults, tune if needed
- Normalization: Advantages must be normalized; observations often should be
- Implementation details: Correct detaching, loss signs, initialization
- Monitoring: Track multiple metrics to catch problems early
With these best practices, PPO provides stable, reliable training across a wide range of problems.