Generalized Advantage Estimation (GAE)
GAE provides a principled way to blend different advantage estimates, giving you smooth control over the bias-variance tradeoff. It’s a key component of PPO and other modern algorithms.
The Bias-Variance Tradeoff
When estimating advantages, we face a fundamental choice:
TD (1-step):
- Uses one real reward, then trusts the value function
- Low variance (only one step of randomness)
- High bias (value function may be wrong)
Monte Carlo:
- Uses all actual rewards until episode end
- High variance (many steps of randomness accumulate)
- Low bias (actual returns, not estimates)
Neither extreme is ideal. GAE smoothly interpolates between them.
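To make the two extremes concrete, here is a tiny sketch on a made-up 3-step episode (the rewards and value estimates are illustrative, with γ = 1 for simplicity):

```python
# A toy 3-step episode with made-up rewards and (possibly wrong) critic
# value estimates, using gamma = 1 for simplicity.
rewards = [1.0, 0.0, 2.0]     # r_0, r_1, r_2; episode ends after step 2
values = [2.5, 1.8, 2.1]      # V(s_0), V(s_1), V(s_2) from the critic

# TD (1-step) advantage at t=0: one real reward, then trust V(s_1).
adv_td = rewards[0] + values[1] - values[0]

# Monte Carlo advantage at t=0: all real rewards, no bootstrapping.
adv_mc = sum(rewards) - values[0]

print(adv_td, adv_mc)   # the two extremes disagree whenever V is wrong
```

The gap between the two estimates comes entirely from the critic's error: TD inherits V's bias, while MC pays for it with the variance of the sampled rewards.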
N-Step Returns: The Spectrum
Before GAE, let's understand n-step returns:

G_t^(n) = r_t + γ r_{t+1} + ... + γ^(n-1) r_{t+n-1} + γ^n V(s_{t+n})

The n-step return uses n real rewards, then bootstraps with the value estimate. Each n gives a different bias-variance point:
| N | Bias | Variance | Description |
|---|---|---|---|
| 1 | High | Low | TD(0) |
| 5 | Medium | Medium | Short-term |
| 20 | Low-ish | High-ish | Long-term |
| Inf | None | High | Monte Carlo |
Think of n as “how many steps to trust reality before falling back to estimates”:
- n=1: “Look one step ahead, then use your prediction”
- n=5: “Look five steps ahead, then use your prediction”
- n=infinity: “Look all the way to the end, don’t use predictions”
More reality = more accurate on average (less bias), but noisier (more variance).
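A minimal sketch of the n-step return above (the function name and the fall-back behavior at episode end are my own choices, not from a particular library):

```python
def n_step_return(rewards, values, t, n, gamma=0.99):
    """G_t^(n): n real rewards, then bootstrap with V(s_{t+n}).
    Assumes the episode terminates at len(rewards); if t + n runs past
    the end, this degrades to the Monte Carlo return (no bootstrap)."""
    T = len(rewards)
    horizon = min(n, T - t)
    g = sum(gamma**k * rewards[t + k] for k in range(horizon))
    if t + n < T:
        g += gamma**n * values[t + n]   # trust the estimate from here on
    return g

# With gamma = 1: n=1 is the TD(0) target, n >= episode length is Monte Carlo.
rewards = [1.0, 1.0, 1.0]
values = [0.5, 0.5, 0.5]
print(n_step_return(rewards, values, t=0, n=1, gamma=1.0))   # 1.0 + V(s_1) = 1.5
print(n_step_return(rewards, values, t=0, n=3, gamma=1.0))   # full return = 3.0
```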
GAE: Blending All Estimates
GAE computes an exponentially-weighted average of n-step advantage estimates:

A_t^GAE(γ,λ) = Σ_{l=0}^{∞} (γλ)^l δ_{t+l}

where δ_t = r_t + γ V(s_{t+1}) − V(s_t) is the TD error.

The parameter λ controls the bias-variance tradeoff:
- λ = 0: Pure TD (1-step)
- λ = 1: Pure Monte Carlo
- 0 < λ < 1: Smooth blend
Derivation: GAE can be derived as an exponentially-weighted sum of n-step advantages.

The n-step advantage estimate is:

A_t^(n) = δ_t + γ δ_{t+1} + ... + γ^(n-1) δ_{t+n-1} = Σ_{l=0}^{n-1} γ^l δ_{t+l}

GAE weights these by (1 − λ) λ^(n−1):

A_t^GAE = (1 − λ) (A_t^(1) + λ A_t^(2) + λ² A_t^(3) + ...)

After simplification:

A_t^GAE = Σ_{l=0}^{∞} (γλ)^l δ_{t+l}
This elegant form shows that GAE just sums TD errors with exponentially decaying weights.
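The equivalence between the explicit weighted sum and the backward recursion A_t = δ_t + γλ A_{t+1} can be checked numerically on random data (a sketch for a single episode of length T, with V(s_T) included for bootstrapping):

```python
import numpy as np

rng = np.random.default_rng(0)
T, gamma, lam = 50, 0.99, 0.95
rewards = rng.normal(size=T)
values = rng.normal(size=T + 1)   # includes V(s_T) for the bootstrap

# TD errors: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
deltas = rewards + gamma * values[1:] - values[:-1]

# Explicit form: A_t = sum_l (gamma * lam)^l * delta_{t+l}
explicit = np.array([
    sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t))
    for t in range(T)
])

# Recursive form: A_t = delta_t + gamma * lam * A_{t+1}
recursive = np.zeros(T)
acc = 0.0
for t in reversed(range(T)):
    acc = deltas[t] + gamma * lam * acc
    recursive[t] = acc

print(np.allclose(explicit, recursive))   # the two forms agree
```

The recursion is what implementations actually use: it turns an O(T²) double sum into a single backward pass.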
GAE asks: "How much should I trust what happened l steps in the future?"
- Immediate TD error (δ_t): Weight = 1
- Next TD error (δ_{t+1}): Weight = γλ
- Two steps out (δ_{t+2}): Weight = (γλ)²
- And so on…
With γ = 0.99 and λ = 0.95 (so γλ ≈ 0.94):
- l=0: weight 1.00
- l=5: weight 0.74
- l=10: weight 0.54
- l=20: weight 0.29
Recent TD errors matter most; distant ones are heavily discounted.
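The decay schedule is a one-liner to reproduce for any γ and λ:

```python
# Weight (gamma * lam)^l placed on the TD error l steps in the future,
# for the commonly used defaults.
gamma, lam = 0.99, 0.95
for l in (0, 5, 10, 20):
    print(l, round((gamma * lam) ** l, 2))
```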
Computing GAE Efficiently
import torch
import numpy as np
def compute_gae(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
"""
Compute Generalized Advantage Estimation.
Args:
rewards: List or array of rewards [r_0, r_1, ..., r_{T-1}]
values: List or array of value estimates [V(s_0), ..., V(s_{T-1})]
dones: List or array of done flags [d_0, ..., d_{T-1}]
next_value: V(s_T) - value of the state after the last step
gamma: Discount factor
lam: GAE lambda parameter (0 = TD, 1 = MC)
Returns:
advantages: GAE advantages for each timestep
returns: Advantage + value (targets for value function)
"""
T = len(rewards)
advantages = np.zeros(T, dtype=np.float32)
last_gae = 0
# Work backwards from the end
for t in reversed(range(T)):
if dones[t]:
# Episode ended - no bootstrap
next_val = 0
last_gae = 0
elif t == T - 1:
# Last step of rollout - bootstrap with next_value
next_val = next_value
else:
next_val = values[t + 1]
# TD error
delta = rewards[t] + gamma * next_val - values[t]
# GAE recursive formula: A_t = delta_t + gamma * lambda * A_{t+1}
last_gae = delta + gamma * lam * last_gae * (1 - dones[t])
advantages[t] = last_gae
returns = advantages + np.array(values)
return advantages, returns
def compute_gae_pytorch(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
"""
PyTorch version of GAE computation.
"""
T = len(rewards)
advantages = torch.zeros(T)
last_gae = 0
# Append next_value to values for easy indexing
values_extended = torch.cat([values, torch.tensor([next_value])])
for t in reversed(range(T)):
if dones[t]:
next_val = 0
last_gae = 0
else:
next_val = values_extended[t + 1]
delta = rewards[t] + gamma * next_val - values[t]
last_gae = delta + gamma * lam * last_gae
advantages[t] = last_gae
returns = advantages + values
    return advantages, returns

The Lambda Parameter
Lambda (λ) controls the bias-variance tradeoff:
λ = 0 (pure TD):
- Maximum bias (relies entirely on value function)
- Minimum variance (only one step of randomness)
λ = 1 (pure MC):
- Minimum bias (uses actual returns)
- Maximum variance (all future randomness)
λ = 0.95 (typical):
- Moderate bias
- Moderate variance
- Usually works well in practice
Lambda is like a "trust horizon":
- λ = 0: "I only trust the next step, then use my estimate"
- λ = 1: "I trust actual experience all the way to the end"
- 0 < λ < 1: "I trust actual experience but gradually rely more on estimates for the distant future"
Higher lambda = more trust in actual experience = less bias but more variance.
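A useful rule of thumb (treating γ ≈ 1): the geometric weights give an effective credit-assignment horizon of roughly 1/(1 − λ) steps, which makes the "trust horizon" intuition quantitative:

```python
# Effective horizon ~ 1 / (1 - lam): the mean of the geometric weight
# distribution, ignoring the extra gamma discount.
for lam in (0.9, 0.95, 0.99):
    print(lam, round(1 / (1 - lam)))   # roughly 10, 20, 100 steps
```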
Choosing lambda in practice:
| Situation | Lambda | Reasoning |
|---|---|---|
| Good value function | 0.90-0.95 | Can trust V more |
| Poor value function | 0.97-0.99 | Rely more on actual returns |
| Short episodes | 0.95 | Less variance anyway |
| Long episodes | 0.90 | Control variance |
| Noisy rewards | 0.90-0.95 | Reduce variance |
Default: Start with λ = 0.95. This works well for most problems.
GAE in Practice
PPO with GAE
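The class below references an `ActorCriticNetwork` that isn't defined in this section. A minimal sketch that matches how it is called (the hidden size and two-layer Tanh trunk are illustrative assumptions, not a prescribed architecture):

```python
import torch
import torch.nn as nn

class ActorCriticNetwork(nn.Module):
    """Shared-trunk actor-critic: returns policy logits and a state value."""
    def __init__(self, state_dim, n_actions, hidden=64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Linear(state_dim, hidden), nn.Tanh(),
            nn.Linear(hidden, hidden), nn.Tanh(),
        )
        self.policy_head = nn.Linear(hidden, n_actions)
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, x):
        h = self.trunk(x)
        return self.policy_head(h), self.value_head(h).squeeze(-1)
```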
class PPOWithGAE:
"""PPO implementation using GAE for advantage estimation."""
def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99,
lam=0.95, clip_ratio=0.2):
self.gamma = gamma
self.lam = lam
self.clip_ratio = clip_ratio
self.network = ActorCriticNetwork(state_dim, n_actions)
self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)
def collect_rollout(self, env, n_steps):
"""Collect experience for training."""
states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []
state, _ = env.reset()
for _ in range(n_steps):
state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
action_logits, value = self.network(state_t)
dist = torch.distributions.Categorical(logits=action_logits)
action = dist.sample()
log_prob = dist.log_prob(action)
next_state, reward, terminated, truncated, _ = env.step(action.item())
done = terminated or truncated
states.append(state_t.squeeze(0))
actions.append(action.item())
rewards.append(reward)
dones.append(done)
log_probs.append(log_prob)
values.append(value.squeeze())
state = next_state
if done:
state, _ = env.reset()
# Get value of final state for bootstrapping
with torch.no_grad():
_, next_value = self.network(
torch.tensor(state, dtype=torch.float32).unsqueeze(0)
)
next_value = next_value.item()
if dones[-1]:
next_value = 0
return {
'states': torch.stack(states),
'actions': torch.tensor(actions),
'rewards': rewards,
'dones': dones,
'log_probs': torch.stack(log_probs).squeeze(),
'values': torch.stack(values).squeeze(),
'next_value': next_value
}
def compute_advantages(self, rollout):
"""Compute GAE advantages from rollout."""
rewards = rollout['rewards']
values = rollout['values'].detach().numpy()
dones = rollout['dones']
next_value = rollout['next_value']
advantages, returns = compute_gae(
rewards, values, dones, next_value,
gamma=self.gamma, lam=self.lam
)
        return torch.tensor(advantages), torch.tensor(returns)

Vectorized GAE
For parallel environments, we need vectorized GAE computation:
def compute_gae_vectorized(rewards, values, dones, next_values,
gamma=0.99, lam=0.95):
"""
Compute GAE for multiple parallel environments.
Args:
rewards: Tensor [T, num_envs] of rewards
values: Tensor [T, num_envs] of value estimates
dones: Tensor [T, num_envs] of done flags
next_values: Tensor [num_envs] of bootstrap values
gamma: Discount factor
lam: GAE lambda
Returns:
advantages: Tensor [T, num_envs]
returns: Tensor [T, num_envs]
"""
T, num_envs = rewards.shape
advantages = torch.zeros_like(rewards)
last_gae = torch.zeros(num_envs)
# Append next_values
values_extended = torch.cat([values, next_values.unsqueeze(0)], dim=0)
for t in reversed(range(T)):
# Handle episode boundaries
not_done = 1.0 - dones[t]
# TD error for all environments
delta = rewards[t] + gamma * values_extended[t + 1] * not_done - values[t]
# GAE update
last_gae = delta + gamma * lam * last_gae * not_done
advantages[t] = last_gae
returns = advantages + values
    return advantages, returns

Comparison: A2C vs GAE
A2C (n-step returns):
- Uses fixed n-step returns
- Hard cutoff at n steps
- Simple but inflexible
GAE:
- Exponentially-weighted blend of all n-step returns
- Smooth decay of influence
- Principled bias-variance control
Example: A2C with n = 5 vs GAE with λ = 0.95 (weights shown as λ^(step−1), with the γ discount omitted for clarity):
| Step | A2C Weight | GAE Weight |
|---|---|---|
| 1 | 1.0 | 1.0 |
| 2 | 1.0 | 0.95 |
| 3 | 1.0 | 0.90 |
| 4 | 1.0 | 0.86 |
| 5 | 1.0 | 0.81 |
| 6 | 0.0 (bootstrap) | 0.77 |
| 7 | 0.0 | 0.74 |
A2C has a hard cutoff; GAE smoothly decays.
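The contrast is easy to reproduce (weights on the TD error at each step, γ omitted as in the table above):

```python
# Fixed n-step (A2C-style) weights the first n TD errors at 1.0 and then
# cuts off; GAE decays them geometrically as lam^(step - 1).
lam, n = 0.95, 5
for step in range(1, 8):
    a2c_w = 1.0 if step <= n else 0.0
    gae_w = lam ** (step - 1)
    print(step, a2c_w, round(gae_w, 2))
```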
Summary
GAE is a powerful technique for advantage estimation:
- Blends TD and MC: Smooth interpolation via λ
- Bias-variance control: λ = 0 is TD, λ = 1 is MC
- Efficient computation: Simple backward recursion
- Widely used: Key component of PPO, TRPO, and other algorithms
Default settings that work well: γ = 0.99, λ = 0.95.
These values are used in the original PPO paper and work across many domains.
The combination of actor-critic with GAE provides stable, efficient learning. The final piece is controlling policy updates to prevent catastrophic forgetting - that’s where PPO comes in.