Policy Gradient Methods • Part 4 of 4

Generalized Advantage Estimation

Balancing bias and variance in advantage estimation

Generalized Advantage Estimation (GAE)

GAE provides a principled way to blend different advantage estimates, giving you smooth control over the bias-variance tradeoff. It’s a key component of PPO and other modern algorithms.

The Bias-Variance Tradeoff

When estimating advantages, we face a fundamental choice:

TD (1-step): $\hat{A}_t = r_t + \gamma V(s_{t+1}) - V(s_t)$

  • Uses one real reward, then trusts the value function
  • Low variance (only one step of randomness)
  • High bias (value function may be wrong)

Monte Carlo: $\hat{A}_t = G_t - V(s_t)$

  • Uses all actual rewards until episode end
  • High variance (many steps of randomness accumulate)
  • Low bias (actual returns, not estimates)

Neither extreme is ideal. GAE smoothly interpolates between them.
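The two extremes are easy to compare side by side. Here is a minimal sketch on a toy four-step episode; the rewards and value estimates below are made up for illustration:

```python
import numpy as np

# Toy 4-step episode with made-up rewards and (imperfect) value estimates.
rewards = np.array([1.0, 0.0, 2.0, 1.0])
values = np.array([3.0, 2.5, 2.0, 1.0])   # V(s_0)..V(s_3); episode ends after step 3
gamma = 0.99

# Monte Carlo advantage: discounted return-to-go minus the baseline.
returns = np.zeros_like(rewards)
running = 0.0
for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    returns[t] = running
mc_advantages = returns - values

# TD (1-step) advantage: one real reward, then bootstrap from V(s_{t+1})
# (zero at the terminal state).
next_values = np.append(values[1:], 0.0)
td_advantages = rewards + gamma * next_values - values

print("MC:", mc_advantages)
print("TD:", td_advantages)
```

The MC estimate reflects everything that actually happened; the TD estimate leans on the possibly-wrong value function after a single step.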

N-Step Returns: The Spectrum

Mathematical Details

Before GAE, let’s understand n-step returns:

$$G_t^{(1)} = r_t + \gamma V(s_{t+1})$$
$$G_t^{(2)} = r_t + \gamma r_{t+1} + \gamma^2 V(s_{t+2})$$
$$G_t^{(n)} = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n V(s_{t+n})$$
$$G_t^{(\infty)} = \sum_{k=0}^{\infty} \gamma^k r_{t+k} = G_t \text{ (Monte Carlo)}$$

Each n gives a different bias-variance point:

| n | Bias | Variance | Description |
|---|------|----------|-------------|
| 1 | High | Low | TD(0) |
| 5 | Medium | Medium | Short-term |
| 20 | Low-ish | High-ish | Long-term |
| ∞ | None | High | Monte Carlo |

Think of n as “how many steps to trust reality before falling back to estimates”:

  • n=1: “Look one step ahead, then use your prediction”
  • n=5: “Look five steps ahead, then use your prediction”
  • n=infinity: “Look all the way to the end, don’t use predictions”

More reality = more accurate on average (less bias), but noisier (more variance).
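A minimal sketch of $G_t^{(n)}$ makes the spectrum concrete. The rewards and value estimates below are arbitrary toy numbers, not output from a trained critic:

```python
def n_step_return(rewards, values, t, n, gamma=0.99):
    """G_t^{(n)}: sum n real rewards, then bootstrap from V(s_{t+n})."""
    T = len(rewards)
    n = min(n, T - t)                       # can't look past the episode end
    g = sum(gamma**k * rewards[t + k] for k in range(n))
    if t + n < T:                           # bootstrap only if we stopped early
        g += gamma**n * values[t + n]
    return g

rewards = [1.0, 0.0, 2.0, 1.0]
values = [3.0, 2.5, 2.0, 1.0]               # V(s_0)..V(s_3)

print(n_step_return(rewards, values, t=0, n=1))   # r_0 + gamma * V(s_1)
print(n_step_return(rewards, values, t=0, n=4))   # full Monte Carlo return
```

With `n=1` the estimate trusts one real reward and then the value function; with `n` at least the episode length there is nothing to bootstrap, and the n-step return becomes the Monte Carlo return.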

GAE: Blending All Estimates

📖Generalized Advantage Estimation (GAE)

GAE computes an exponentially-weighted average of n-step advantage estimates:

$$\hat{A}_t^{GAE(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error.

The parameter $\lambda \in [0, 1]$ controls the bias-variance tradeoff:

  • $\lambda = 0$: Pure TD (1-step)
  • $\lambda = 1$: Pure Monte Carlo
  • $\lambda \in (0, 1)$: Smooth blend

Mathematical Details

Derivation: GAE can be derived as an exponentially-weighted sum of n-step advantages.

The n-step advantage estimate is: $\hat{A}_t^{(n)} = \sum_{l=0}^{n-1} \gamma^l \delta_{t+l}$

GAE weights these by $(1-\lambda)\lambda^{n-1}$: $\hat{A}_t^{GAE} = (1-\lambda) \sum_{n=1}^{\infty} \lambda^{n-1} \hat{A}_t^{(n)}$

After simplification: $\hat{A}_t^{GAE} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \delta_{t+l}$

This elegant form shows that GAE just sums TD errors with exponentially decaying weights.
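The identity is easy to check numerically. The sketch below uses random TD errors for a short episode (zero afterwards) and truncates the infinite sum over $n$ at a large $N$, where the $\lambda^{n-1}$ tail is negligible:

```python
import numpy as np

gamma, lam = 0.99, 0.95
T = 6
rng = np.random.default_rng(0)
deltas = rng.normal(size=T)      # TD errors delta_t..delta_{t+T-1}; zero afterwards

# Right-hand side: discounted sum of TD errors.
rhs = sum((gamma * lam) ** l * deltas[l] for l in range(T))

# Left-hand side: (1 - lambda)-weighted sum of n-step advantages,
# truncated at a large N.
lhs = 0.0
for n in range(1, 2001):
    a_n = sum(gamma**l * deltas[l] for l in range(min(n, T)))
    lhs += (1 - lam) * lam ** (n - 1) * a_n

print(abs(lhs - rhs))   # tiny: floating-point and truncation noise only
```

The two sides agree because each $\delta_{t+l}$ is included in every $\hat{A}_t^{(n)}$ with $n > l$, and those weights sum to $\lambda^l$.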

GAE asks: “How much should I trust what happened l steps in the future?”

  • Immediate TD error ($l=0$): Weight = 1
  • Next TD error ($l=1$): Weight = $\gamma \lambda$
  • Two steps out ($l=2$): Weight = $(\gamma \lambda)^2$
  • And so on…

With $\gamma = 0.99$ and $\lambda = 0.95$ (so $\gamma \lambda \approx 0.94$):

  • l=0: weight 1.0
  • l=5: weight ≈ 0.74
  • l=10: weight ≈ 0.54
  • l=20: weight ≈ 0.29

Recent TD errors matter most; distant ones are heavily discounted.
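These weights follow directly from $(\gamma \lambda)^l$ and can be reproduced in a couple of lines:

```python
gamma, lam = 0.99, 0.95
for l in [0, 5, 10, 20]:
    # Weight on the TD error l steps in the future.
    print(l, round((gamma * lam) ** l, 2))
```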

Computing GAE Efficiently

</>Implementation
import torch
import numpy as np


def compute_gae(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """
    Compute Generalized Advantage Estimation.

    Args:
        rewards: List or array of rewards [r_0, r_1, ..., r_{T-1}]
        values: List or array of value estimates [V(s_0), ..., V(s_{T-1})]
        dones: List or array of done flags [d_0, ..., d_{T-1}]
        next_value: V(s_T) - value of the state after the last step
        gamma: Discount factor
        lam: GAE lambda parameter (0 = TD, 1 = MC)

    Returns:
        advantages: GAE advantages for each timestep
        returns: Advantage + value (targets for value function)
    """
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    last_gae = 0

    # Work backwards from the end
    for t in reversed(range(T)):
        if dones[t]:
            # Episode ended - no bootstrap
            next_val = 0
            last_gae = 0
        elif t == T - 1:
            # Last step of rollout - bootstrap with next_value
            next_val = next_value
        else:
            next_val = values[t + 1]

        # TD error
        delta = rewards[t] + gamma * next_val - values[t]

        # GAE recursive formula: A_t = delta_t + gamma * lambda * A_{t+1}
        last_gae = delta + gamma * lam * last_gae * (1 - dones[t])
        advantages[t] = last_gae

    returns = advantages + np.array(values)
    return advantages, returns


def compute_gae_pytorch(rewards, values, dones, next_value, gamma=0.99, lam=0.95):
    """
    PyTorch version of GAE computation.
    """
    T = len(rewards)
    advantages = torch.zeros(T)
    last_gae = 0

    # Append next_value to values for easy indexing
    values_extended = torch.cat([values, torch.tensor([next_value])])

    for t in reversed(range(T)):
        if dones[t]:
            next_val = 0
            last_gae = 0
        else:
            next_val = values_extended[t + 1]

        delta = rewards[t] + gamma * next_val - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae

    returns = advantages + values
    return advantages, returns
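As a sanity check, the backward recursion should reproduce the defining sum $\sum_l (\gamma \lambda)^l \delta_{t+l}$ exactly. A small sketch on made-up numbers (a three-step rollout, bootstrapped, with no episode end):

```python
import numpy as np

gamma, lam = 0.99, 0.95
rewards = np.array([1.0, 0.0, 2.0])
values = np.array([3.0, 2.5, 2.0])
next_value = 1.5                         # V(s_3): bootstrap value, no episode end
T = len(rewards)

values_ext = np.append(values, next_value)
deltas = rewards + gamma * values_ext[1:] - values

# Direct definition: A_t = sum over l of (gamma*lam)^l * delta_{t+l}
direct = np.array([
    sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t))
    for t in range(T)
])

# Backward recursion: A_t = delta_t + gamma * lam * A_{t+1}
recursive = np.zeros(T)
last_gae = 0.0
for t in reversed(range(T)):
    last_gae = deltas[t] + gamma * lam * last_gae
    recursive[t] = last_gae

print(np.allclose(direct, recursive))    # True
```

The recursion is just Horner's rule applied to the discounted sum, which is why a single backward pass suffices.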

The Lambda Parameter

Mathematical Details

Lambda ($\lambda$) controls the bias-variance tradeoff:

$\lambda = 0$ (pure TD): $\hat{A}_t = \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$

  • Maximum bias (relies entirely on value function)
  • Minimum variance (only one step of randomness)

$\lambda = 1$ (pure MC): $\hat{A}_t = \sum_{l=0}^{\infty} \gamma^l \delta_{t+l} = G_t - V(s_t)$

  • Minimum bias (uses actual returns)
  • Maximum variance (all future randomness)

$\lambda = 0.95$ (typical):

  • Moderate bias
  • Moderate variance
  • Usually works well in practice

Lambda is like a “trust horizon”:

  • $\lambda = 0$: “I only trust the next step, then use my estimate”
  • $\lambda = 1$: “I trust actual experience all the way to the end”
  • $\lambda = 0.95$: “I trust actual experience but gradually rely more on estimates for the distant future”

Higher lambda = more trust in actual experience = less bias but more variance.
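Both endpoints can be verified on a toy episode (rewards and values below are arbitrary): with $\lambda = 0$ the recursion returns the raw TD errors, and with $\lambda = 1$ it reproduces the Monte Carlo advantage $G_t - V(s_t)$:

```python
import numpy as np

gamma = 0.99
rewards = np.array([1.0, 0.0, 2.0, 1.0])
values = np.array([3.0, 2.5, 2.0, 1.0])    # episode ends after the last step
T = len(rewards)
next_values = np.append(values[1:], 0.0)   # terminal state has value 0
deltas = rewards + gamma * next_values - values

def gae(lam):
    adv = np.zeros(T)
    last = 0.0
    for t in reversed(range(T)):
        last = deltas[t] + gamma * lam * last
        adv[t] = last
    return adv

# Monte Carlo returns for comparison.
returns = np.zeros(T)
run = 0.0
for t in reversed(range(T)):
    run = rewards[t] + gamma * run
    returns[t] = run

print(np.allclose(gae(0.0), deltas))             # lambda=0: raw TD errors
print(np.allclose(gae(1.0), returns - values))   # lambda=1: MC advantage
print(gae(0.95))                                 # a blend of the two
```

The $\lambda = 1$ case works because the discounted sum of TD errors telescopes into $G_t - V(s_t)$ when the terminal value is zero.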

💡Tip

Choosing lambda in practice:

| Situation | Lambda | Reasoning |
|-----------|--------|-----------|
| Good value function | 0.90-0.95 | Can trust V more |
| Poor value function | 0.97-0.99 | Rely more on actual returns |
| Short episodes | 0.95 | Less variance anyway |
| Long episodes | 0.90 | Control variance |
| Noisy rewards | 0.90-0.95 | Reduce variance |

Default: Start with $\lambda = 0.95$. This works well for most problems.

GAE in Practice

PPO with GAE

</>Implementation
class PPOWithGAE:
    """PPO implementation using GAE for advantage estimation."""

    def __init__(self, state_dim, n_actions, lr=3e-4, gamma=0.99,
                 lam=0.95, clip_ratio=0.2):
        self.gamma = gamma
        self.lam = lam
        self.clip_ratio = clip_ratio

        self.network = ActorCriticNetwork(state_dim, n_actions)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=lr)

    def collect_rollout(self, env, n_steps):
        """Collect experience for training."""
        states, actions, rewards, dones, log_probs, values = [], [], [], [], [], []

        state, _ = env.reset()

        for _ in range(n_steps):
            state_t = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

            with torch.no_grad():
                action_logits, value = self.network(state_t)
                dist = torch.distributions.Categorical(logits=action_logits)
                action = dist.sample()
                log_prob = dist.log_prob(action)

            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated

            states.append(state_t.squeeze(0))
            actions.append(action.item())
            rewards.append(reward)
            dones.append(done)
            log_probs.append(log_prob)
            values.append(value.squeeze())

            state = next_state
            if done:
                state, _ = env.reset()

        # Get value of final state for bootstrapping
        with torch.no_grad():
            _, next_value = self.network(
                torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            )
            next_value = next_value.item()
            if dones[-1]:
                next_value = 0

        return {
            'states': torch.stack(states),
            'actions': torch.tensor(actions),
            'rewards': rewards,
            'dones': dones,
            'log_probs': torch.stack(log_probs).squeeze(),
            'values': torch.stack(values).squeeze(),
            'next_value': next_value
        }

    def compute_advantages(self, rollout):
        """Compute GAE advantages from rollout."""
        rewards = rollout['rewards']
        values = rollout['values'].detach().numpy()
        dones = rollout['dones']
        next_value = rollout['next_value']

        advantages, returns = compute_gae(
            rewards, values, dones, next_value,
            gamma=self.gamma, lam=self.lam
        )

        return torch.tensor(advantages), torch.tensor(returns)

Vectorized GAE

For parallel environments, we need vectorized GAE computation:

</>Implementation
def compute_gae_vectorized(rewards, values, dones, next_values,
                           gamma=0.99, lam=0.95):
    """
    Compute GAE for multiple parallel environments.

    Args:
        rewards: Tensor [T, num_envs] of rewards
        values: Tensor [T, num_envs] of value estimates
        dones: Tensor [T, num_envs] of done flags
        next_values: Tensor [num_envs] of bootstrap values
        gamma: Discount factor
        lam: GAE lambda

    Returns:
        advantages: Tensor [T, num_envs]
        returns: Tensor [T, num_envs]
    """
    T, num_envs = rewards.shape
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros(num_envs)

    # Append next_values
    values_extended = torch.cat([values, next_values.unsqueeze(0)], dim=0)

    for t in reversed(range(T)):
        # Handle episode boundaries
        not_done = 1.0 - dones[t]

        # TD error for all environments
        delta = rewards[t] + gamma * values_extended[t + 1] * not_done - values[t]

        # GAE update
        last_gae = delta + gamma * lam * last_gae * not_done
        advantages[t] = last_gae

    returns = advantages + values
    return advantages, returns
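One way to sanity-check the vectorized pass is to run the same recursion independently per environment and compare. The sketch below uses random toy data, including one mid-rollout episode end:

```python
import torch

gamma, lam = 0.99, 0.95
T, num_envs = 5, 3
torch.manual_seed(0)
rewards = torch.randn(T, num_envs)
values = torch.randn(T, num_envs)
dones = torch.zeros(T, num_envs)
dones[2, 1] = 1.0                          # env 1 ends an episode at step 2
next_values = torch.randn(num_envs)

# Vectorized backward pass (all environments at once).
values_ext = torch.cat([values, next_values.unsqueeze(0)], dim=0)
advantages = torch.zeros(T, num_envs)
last_gae = torch.zeros(num_envs)
for t in reversed(range(T)):
    not_done = 1.0 - dones[t]
    delta = rewards[t] + gamma * values_ext[t + 1] * not_done - values[t]
    last_gae = delta + gamma * lam * last_gae * not_done
    advantages[t] = last_gae

# Reference: the same recursion, one environment at a time.
ref = torch.zeros(T, num_envs)
for e in range(num_envs):
    last = 0.0
    for t in reversed(range(T)):
        nd = 1.0 - dones[t, e]
        delta = rewards[t, e] + gamma * values_ext[t + 1, e] * nd - values[t, e]
        last = delta + gamma * lam * last * nd
        ref[t, e] = last

print(torch.allclose(advantages, ref))     # True
```

The `not_done` mask both zeroes the bootstrap term in the TD error and cuts the GAE accumulation at episode boundaries, so advantages never leak across episodes.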

Comparison: A2C vs GAE

A2C (n-step returns):

  • Uses fixed n-step returns
  • Hard cutoff at n steps
  • Simple but inflexible

GAE:

  • Exponentially-weighted blend of all n-step returns
  • Smooth decay of influence
  • Principled bias-variance control

Example: the weight each method places on the TD error $l$ steps ahead, for A2C with $n=5$ and GAE($\lambda = 0.95$), both at $\gamma = 0.99$:

| Step | A2C Weight (n=5) | GAE Weight |
|------|------------------|------------|
| 1 | 1.00 | 1.00 |
| 2 | 0.99 | 0.94 |
| 3 | 0.98 | 0.88 |
| 4 | 0.97 | 0.83 |
| 5 | 0.96 | 0.78 |
| 6 | 0.00 (bootstrap) | 0.74 |
| 7 | 0.00 | 0.69 |
A2C has a hard cutoff; GAE smoothly decays.
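Both weight columns follow from the definitions: the n-step advantage weights the TD error at offset $l$ by $\gamma^l$ up to the cutoff, while GAE uses $(\gamma \lambda)^l$ throughout:

```python
gamma, lam, n = 0.99, 0.95, 5
for step in range(1, 8):
    l = step - 1
    a2c = gamma ** l if step <= n else 0.0   # hard cutoff after n steps
    gae = (gamma * lam) ** l                 # smooth exponential decay
    print(step, round(a2c, 2), round(gae, 2))
```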

Summary

GAE is a powerful technique for advantage estimation:

  • Blends TD and MC: Smooth interpolation via $\lambda$
  • Bias-variance control: $\lambda = 0$ is TD, $\lambda = 1$ is MC
  • Efficient computation: Simple backward recursion
  • Widely used: Key component of PPO, TRPO, and other algorithms

💡Tip

Default settings that work well:

  • $\gamma = 0.99$
  • $\lambda = 0.95$

These values are used in the original PPO paper and work across many domains.

The combination of actor-critic with GAE provides stable, efficient learning. The final piece is controlling the size of policy updates so that a single bad update can't collapse the policy; that's where PPO comes in.