Baselines and Variance Reduction
We’ve seen that REINFORCE suffers from high variance. The solution is elegantly simple: subtract a baseline from the returns. This dramatically reduces variance without biasing the gradient - a free lunch in the world of optimization.
The Baseline Trick
A baseline $b(s)$ is a function we subtract from the return in the policy gradient:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\big(G_t - b(s_t)\big)\right]$$
Key properties:
- The baseline can depend on the state $s_t$, but NOT on the action $a_t$
- Subtracting a baseline does not change the expected gradient
- A good baseline dramatically reduces variance
The baseline provides context. Instead of asking “was this return good?” we ask “was this return better than expected?”
Imagine two scenarios:
- Without baseline: You get return 100. Is that good? No idea without context.
- With baseline: You get return 100, baseline is 80. You did 20 better than expected - definitely reinforce those actions.
The baseline tells you what to expect, so you can judge whether the actual outcome was above or below average.
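To make this concrete, here is a toy sketch (the numbers are invented for illustration). Without a baseline, every sample's gradient weight is positive; subtracting the mean return splits samples into above- and below-average:

```python
import torch

# Three sampled returns from the same state (toy numbers).
returns = torch.tensor([60.0, 80.0, 100.0])

# Without a baseline, every weight is positive: all three samples
# push their action's probability up, just by different amounts.
weights_no_baseline = returns

# With the mean return as baseline, the above-average sample is
# reinforced and the below-average one is suppressed.
baseline = returns.mean()                    # 80.0
weights_with_baseline = returns - baseline   # tensor([-20., 0., 20.])
```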
Why Baselines Don’t Bias the Gradient
The key insight is that, for any function $b(s)$ that depends only on the state:

$$\mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}\left[\nabla_\theta \log \pi_\theta(a \mid s)\, b(s)\right] = 0$$

Proof:

Since $b(s)$ doesn't depend on $a$, we can factor it out of the expectation:

$$\mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}\left[\nabla_\theta \log \pi_\theta(a \mid s)\, b(s)\right] = b(s)\, \mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}\left[\nabla_\theta \log \pi_\theta(a \mid s)\right]$$

Now we use the fact that the expected score is zero:

$$\mathbb{E}_{a \sim \pi_\theta(\cdot \mid s)}\left[\nabla_\theta \log \pi_\theta(a \mid s)\right] = \sum_a \pi_\theta(a \mid s)\, \frac{\nabla_\theta \pi_\theta(a \mid s)}{\pi_\theta(a \mid s)} = \nabla_\theta \sum_a \pi_\theta(a \mid s) = \nabla_\theta 1 = 0$$

Therefore, subtracting $b(s)$ changes the variance but not the expected value - the gradient remains unbiased!
The baseline must NOT depend on the action $a$. If it did, the proof above would break - we couldn't factor $b$ out of the expectation over actions.
This is why we use $V(s)$ as a baseline, not $Q(s, a)$.
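The zero-expected-score fact is easy to check numerically. A minimal sketch for a softmax policy over a few actions (the logits are arbitrary):

```python
import torch

# Softmax policy over 3 actions with arbitrary logits.
logits = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)
probs = torch.softmax(logits, dim=0)

# E_{a~pi}[grad log pi(a|s)] = sum_a pi(a) * grad log pi(a|s).
# Detaching the probability weights makes autograd compute exactly that sum.
expected_score = torch.autograd.grad(
    (probs.detach() * torch.log(probs)).sum(), logits
)[0]

# The expected score vanishes (up to float error), so multiplying it by
# any state-dependent constant b(s) contributes zero to the gradient.
print(expected_score)  # ~tensor([0., 0., 0.])
```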
Common Baseline Choices
1. Constant Baseline
The simplest baseline: subtract the average return across all episodes.

$$b = \frac{1}{N} \sum_{i=1}^{N} G_i$$

This centers returns around zero, which helps prevent all gradients from pointing in the same direction.
```python
import numpy as np
import torch

# compute_returns() and the policy/optimizer are defined in the previous chapter.
def reinforce_with_constant_baseline(policy, optimizer, episodes):
    """REINFORCE with constant baseline (average return)."""
    # Collect all returns
    all_returns = []
    for states, actions, rewards in episodes:
        returns = compute_returns(rewards)
        all_returns.extend(returns.tolist())

    # Compute baseline
    baseline = np.mean(all_returns)

    # Compute loss with baseline
    total_loss = 0
    for states, actions, rewards in episodes:
        returns = compute_returns(rewards)
        advantages = returns - baseline  # Subtract baseline
        states_tensor = torch.stack(states)
        actions_tensor = torch.tensor(actions)
        log_probs = policy.log_prob(states_tensor, actions_tensor)
        total_loss += -(log_probs * advantages).sum()

    # Update
    optimizer.zero_grad()
    (total_loss / len(episodes)).backward()
    optimizer.step()
```
2. Running Average Baseline
Track the average return over time using an exponential moving average:

$$b \leftarrow \alpha\, \bar{G} + (1 - \alpha)\, b$$

where $\bar{G}$ is the mean return of the most recent episode. This adapts to the current policy's performance level.
```python
class RunningBaseline:
    """Exponential moving average of returns."""

    def __init__(self, alpha=0.1):
        self.alpha = alpha
        self.value = 0.0
        self.initialized = False

    def update(self, returns):
        """Update baseline with new returns."""
        mean_return = returns.mean().item()
        if not self.initialized:
            self.value = mean_return
            self.initialized = True
        else:
            self.value = self.alpha * mean_return + (1 - self.alpha) * self.value

    def __call__(self, states):
        """Return baseline value (constant for all states)."""
        return self.value
```
3. Learned State-Value Baseline (Optimal)
The theoretically optimal baseline is $b(s) = V^{\pi}(s)$ - the expected return from state $s$ under the current policy. This tells you exactly what to expect from each state.
If you get return 100 from a state that usually gives 80, you did better than expected. If the state usually gives 120, you did worse.
With $V^{\pi}(s)$ as baseline, the policy gradient becomes:

$$\nabla_\theta J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\left[\sum_{t} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\,\big(G_t - V^{\pi}(s_t)\big)\right]$$

The term $G_t - V^{\pi}(s_t)$ is an estimate of the advantage - how much better the actual action was compared to the average action.
This is the bridge to actor-critic methods!
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ValueNetwork(nn.Module):
    """Neural network for state-value function."""

    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        return self.network(state).squeeze(-1)


class REINFORCEWithBaseline:
    """REINFORCE with learned value function baseline."""

    def __init__(self, state_dim, n_actions, lr_policy=1e-3,
                 lr_value=1e-3, gamma=0.99):
        self.gamma = gamma

        # Policy network
        self.policy = PolicyNetwork(state_dim, n_actions)
        self.policy_optimizer = torch.optim.Adam(
            self.policy.parameters(), lr=lr_policy
        )

        # Value network (baseline)
        self.value = ValueNetwork(state_dim)
        self.value_optimizer = torch.optim.Adam(
            self.value.parameters(), lr=lr_value
        )

    def compute_returns(self, rewards):
        """Compute discounted returns."""
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + self.gamma * G
            returns.insert(0, G)
        return torch.tensor(returns, dtype=torch.float32)

    def update(self, states, actions, rewards):
        """Update policy and value function."""
        returns = self.compute_returns(rewards)
        states_tensor = torch.stack(states)
        actions_tensor = torch.tensor(actions, dtype=torch.long)

        # Compute value predictions (baseline)
        values = self.value(states_tensor)

        # Compute advantages: G - V(s), detached so the policy loss
        # does not backprop into the value network
        advantages = returns - values.detach()

        # Normalize advantages for stability
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy loss: -log_prob * advantage
        log_probs = self.policy.log_prob(states_tensor, actions_tensor)
        policy_loss = -(log_probs * advantages).sum()

        # Value loss: MSE between V(s) and actual returns
        value_loss = F.mse_loss(values, returns)

        # Update policy
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Update value function
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

        return policy_loss.item(), value_loss.item()
```
How Much Does the Baseline Help?
The variance reduction can be dramatic. Consider a task where returns range from 0 to 100:
Without baseline: Gradients are scaled by 0-100. High variance.
With good baseline: If $b = 50$, then $G - b$ ranges from -50 to +50. Centering doesn't change the spread of the returns themselves, but it shrinks the per-sample gradient terms $\nabla_\theta \log \pi_\theta \cdot (G - b)$ and lets positive and negative contributions cancel, reducing the variance of the gradient estimate.
More importantly, the baseline provides relative information. Actions that do better than expected get reinforced; actions that do worse get suppressed.
The variance of the policy gradient estimate with baseline is:

$$\mathrm{Var}\left[\nabla_\theta \log \pi_\theta(a \mid s)\,\big(G - b(s)\big)\right]$$

The optimal baseline that minimizes this variance is approximately:

$$b^{*}(s) = \frac{\mathbb{E}\left[\|\nabla_\theta \log \pi_\theta(a \mid s)\|^2\, G\right]}{\mathbb{E}\left[\|\nabla_\theta \log \pi_\theta(a \mid s)\|^2\right]}$$

In practice, $V^{\pi}(s)$ is close to optimal and much easier to estimate.
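As a sanity check on the $b^{*}$ formula, we can compare the variance of a scalar surrogate $g \cdot (G - b)$ with $b = 0$ against the optimal constant baseline on synthetic data (the distributions below are arbitrary choices, not from the text):

```python
import torch

torch.manual_seed(0)

# Synthetic scalar stand-ins: g plays the role of the score, G the return.
g = torch.randn(100_000)
G = 50.0 + 10.0 * torch.randn(100_000)

# Optimal constant baseline: b* = E[g^2 G] / E[g^2]
b_star = (g**2 * G).mean() / (g**2).mean()

var_no_baseline = (g * G).var()
var_with_baseline = (g * (G - b_star)).var()

print(f"b* = {b_star:.1f}")  # close to E[G] = 50, since g and G are independent here
print(f"variance without baseline: {var_no_baseline:.0f}")
print(f"variance with baseline:    {var_with_baseline:.0f}")
```

With these numbers the baseline cuts the variance by roughly an order of magnitude, even though it is just a constant shift.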
```python
import numpy as np
import torch


def get_flat_gradient(model):
    """Get flattened gradient vector."""
    grads = []
    for param in model.parameters():
        if param.grad is not None:
            grads.append(param.grad.clone().flatten())
    return torch.cat(grads)


# sample_episode() and compute_returns() are defined in the previous chapter.
def compare_variance_reduction(policy, env, n_episodes=50, gamma=0.99):
    """Compare gradient variance with and without baseline."""
    gradients_no_baseline = []
    gradients_with_baseline = []
    all_returns = []

    # Collect episodes
    episodes = []
    for _ in range(n_episodes):
        states, actions, rewards = sample_episode(policy, env)
        returns = compute_returns(rewards, gamma)
        all_returns.extend(returns.tolist())
        episodes.append((states, actions, rewards, returns))

    baseline = np.mean(all_returns)

    # Compute gradients
    for states, actions, rewards, returns in episodes:
        states_tensor = torch.stack(states)
        actions_tensor = torch.tensor(actions)

        # Without baseline
        policy.zero_grad()
        log_probs = policy.log_prob(states_tensor, actions_tensor)
        loss_no_baseline = -(log_probs * returns).sum()
        loss_no_baseline.backward()
        gradients_no_baseline.append(get_flat_gradient(policy))

        # With baseline (recompute log-probs: backward() freed the first graph)
        policy.zero_grad()
        advantages = returns - baseline
        log_probs = policy.log_prob(states_tensor, actions_tensor)
        loss_with_baseline = -(log_probs * advantages).sum()
        loss_with_baseline.backward()
        gradients_with_baseline.append(get_flat_gradient(policy))

    # Compute per-parameter variance across episodes
    gradients_no_baseline = torch.stack(gradients_no_baseline)
    gradients_with_baseline = torch.stack(gradients_with_baseline)
    var_no_baseline = gradients_no_baseline.var(dim=0).mean().item()
    var_with_baseline = gradients_with_baseline.var(dim=0).mean().item()

    print(f"Variance without baseline: {var_no_baseline:.4f}")
    print(f"Variance with baseline: {var_with_baseline:.4f}")
    print(f"Variance reduction: {(1 - var_with_baseline / var_no_baseline) * 100:.1f}%")
```
The Advantage Function
The advantage function measures how much better action $a$ is compared to the average action in state $s$:

$$A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$$

Properties:
- $A^{\pi}(s, a) > 0$: Action $a$ is better than average
- $A^{\pi}(s, a) < 0$: Action $a$ is worse than average
- $\mathbb{E}_{a \sim \pi}\left[A^{\pi}(s, a)\right] = 0$: Advantages average to zero
The advantage is the natural quantity to multiply the gradient by:
- Positive advantage: “This action was better than expected, reinforce it”
- Negative advantage: “This action was worse than expected, suppress it”
Using $G_t - V(s_t)$ in REINFORCE estimates the advantage:

$$G_t - V^{\pi}(s_t) \approx A^{\pi}(s_t, a_t)$$
This connects baselines directly to the advantage function, setting the stage for actor-critic methods.
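A quick tabular sketch of these properties (the policy probabilities and $Q$ values below are made up for illustration):

```python
import torch

# One state, three actions; pi and Q are toy numbers.
pi = torch.tensor([0.5, 0.3, 0.2])     # policy probabilities pi(a|s)
Q = torch.tensor([10.0, 4.0, -2.0])    # action values Q(s, a)

V = (pi * Q).sum()   # V(s) = E_{a~pi}[Q(s, a)]
A = Q - V            # advantages A(s, a) = Q(s, a) - V(s)

print(V.item())        # ~5.8
print(A)               # ~tensor([ 4.2, -1.8, -7.8])
print((pi * A).sum())  # ~0: advantages average to zero under the policy
```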
Complete REINFORCE with Baseline
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gymnasium as gym


class PolicyNetwork(nn.Module):
    """Policy network for discrete actions."""

    def __init__(self, state_dim, n_actions, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )

    def forward(self, state):
        logits = self.network(state)
        return F.softmax(logits, dim=-1)

    def sample(self, state):
        probs = self.forward(state)
        dist = torch.distributions.Categorical(probs)
        return dist.sample().item()

    def log_prob(self, states, actions):
        probs = self.forward(states)
        dist = torch.distributions.Categorical(probs)
        return dist.log_prob(actions)


class ValueNetwork(nn.Module):
    """Value network for state-value function."""

    def __init__(self, state_dim, hidden_dim=128):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, state):
        return self.network(state).squeeze(-1)


def train_reinforce_baseline(env_name='CartPole-v1', episodes=500,
                             gamma=0.99, lr_policy=1e-3, lr_value=1e-2):
    """Train REINFORCE with learned baseline."""
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy = PolicyNetwork(state_dim, n_actions)
    value = ValueNetwork(state_dim)
    policy_optimizer = torch.optim.Adam(policy.parameters(), lr=lr_policy)
    value_optimizer = torch.optim.Adam(value.parameters(), lr=lr_value)

    episode_rewards = []
    for ep in range(episodes):
        # Collect episode
        states, actions, rewards = [], [], []
        state, _ = env.reset()
        done = False
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32)
            states.append(state_tensor)
            action = policy.sample(state_tensor.unsqueeze(0))
            actions.append(action)
            state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            rewards.append(reward)

        # Compute discounted returns
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.insert(0, G)
        returns = torch.tensor(returns, dtype=torch.float32)

        states_tensor = torch.stack(states)
        actions_tensor = torch.tensor(actions, dtype=torch.long)

        # Compute baseline (value function)
        values = value(states_tensor)

        # Compute and normalize advantages
        advantages = returns - values.detach()
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Update policy
        log_probs = policy.log_prob(states_tensor, actions_tensor)
        policy_loss = -(log_probs * advantages).sum()
        policy_optimizer.zero_grad()
        policy_loss.backward()
        policy_optimizer.step()

        # Update value function
        value_loss = F.mse_loss(values, returns)
        value_optimizer.zero_grad()
        value_loss.backward()
        value_optimizer.step()

        # Track progress
        episode_rewards.append(sum(rewards))
        if (ep + 1) % 100 == 0:
            avg_reward = np.mean(episode_rewards[-100:])
            print(f"Episode {ep + 1}, Avg Reward: {avg_reward:.2f}")

    env.close()
    return policy, value, episode_rewards


if __name__ == "__main__":
    policy, value, rewards = train_reinforce_baseline()
```
Summary
Baselines are a powerful variance reduction technique:
- No bias: Subtracting a state-dependent baseline doesn’t change the expected gradient
- Lower variance: Centering returns around expected values reduces gradient magnitude variation
- Better credit assignment: Actions are judged relative to expectations, not absolute returns
- Bridge to actor-critic: Using $V(s)$ as the baseline yields advantage estimates, the core quantity in actor-critic methods
Baselines are so important that they’re present in virtually every modern policy gradient algorithm. PPO, A2C, and other methods all use advantage estimation, which is just a sophisticated form of baseline subtraction.
The natural next step: if we're already learning $V(s)$ for the baseline, why not use it more directly? That's the idea behind actor-critic methods, which we'll explore in the next chapter.