Policy Iteration
Policy iteration is a ping-pong algorithm. It alternates between two steps: evaluate the current policy π to get its value function V^π, then improve the policy by acting greedily with respect to V^π. Repeat until the policy stops changing. When it stops, you have found an optimal policy.
The Algorithm
An algorithm that alternates between:
- Policy Evaluation: Compute V^π for the current policy π
- Policy Improvement: Construct a new policy π' that is greedy with respect to V^π
Repeat until the policy no longer changes. The final policy is optimal.
Think of it as a feedback loop:
- Evaluation answers: “How good is this policy?”
- Improvement answers: “What policy would be best given these values?”
Each round, evaluation produces new values and improvement produces a new policy. They chase each other until both stabilize at the optimum.
The remarkable thing is that this process terminates, and when it does, we have found π* and V*.
Pseudocode
Algorithm: Policy Iteration
Input: MDP (S, A, P, R, γ)
Output: Optimal policy π*, optimal value function V*
- Initialize π(s) arbitrarily for all s ∈ S
- Repeat:
  - Policy Evaluation:
    - Compute V^π for the current π using iterative policy evaluation
  - Policy Improvement:
    - policy_stable ← True
    - For each state s ∈ S:
      - old_action ← π(s)
      - π(s) ← argmax_a Σ_{s'} P(s' | s, a) [R(s, a, s') + γ V^π(s')]
      - If old_action ≠ π(s): policy_stable ← False
- Until policy_stable is True
- Return π* = π, V* = V^π
Why It Works
Policy iteration converges because of two properties:
- Monotonic Improvement: Each policy is at least as good as the previous one (policy improvement theorem)
- Finite Policies: There are only finitely many deterministic policies for a finite MDP
Since we only improve and never get worse, and there are only finitely many policies, we must eventually stop improving. That can only happen at the optimal policy.
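The policy improvement theorem referenced above can be stated precisely with the action-value function Q^π (a standard result, sketched here for reference):

```latex
\text{If } Q^{\pi}\bigl(s, \pi'(s)\bigr) \;\ge\; V^{\pi}(s) \quad \text{for all } s \in S,
\qquad \text{then} \qquad V^{\pi'}(s) \;\ge\; V^{\pi}(s) \quad \text{for all } s \in S.
```

The greedy policy satisfies the premise by construction (it maximizes Q^π at every state), so each improvement step can only raise values.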
More formally, let Π be the set of all deterministic policies. For a finite MDP with |S| states and |A| actions, |Π| = |A|^|S|. This is finite. Each policy iteration step either:
- Produces a strictly better policy (different from before)
- Produces the same policy (we have converged)
We never revisit a policy because improvement is monotonic. Therefore, we must converge in at most |A|^|S| iterations.
In practice, convergence happens much faster, often in just a few iterations.
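For the 4x4 GridWorld used in the worked example below (16 states, 4 actions), the worst-case bound is already enormous:

```python
# Worst-case number of deterministic policies: |A| ** |S|
n_states, n_actions = 16, 4   # the 4x4 GridWorld example below
n_policies = n_actions ** n_states
print(n_policies)  # 4294967296 -- yet policy iteration finishes in a few iterations
```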
Complete Implementation
def policy_evaluation(mdp, policy, gamma=0.99, theta=1e-8):
    """
    Evaluate a policy using iterative policy evaluation.

    Args:
        mdp: MDP with states, actions(s), transitions(s, a)
        policy: dict mapping state -> action
        gamma: discount factor
        theta: convergence threshold

    Returns:
        V: dict mapping state -> value
    """
    V = {s: 0.0 for s in mdp.states}
    while True:
        delta = 0.0
        for s in mdp.states:
            if hasattr(mdp, 'terminal_states') and s in mdp.terminal_states:
                continue
            old_value = V[s]
            action = policy[s]
            # Bellman backup for a deterministic policy
            new_value = 0.0
            for s_next, prob, reward in mdp.transitions(s, action):
                new_value += prob * (reward + gamma * V[s_next])
            V[s] = new_value
            delta = max(delta, abs(old_value - new_value))
        if delta < theta:
            break
    return V
def policy_iteration(mdp, gamma=0.99, theta=1e-8):
    """
    Find the optimal policy using policy iteration.

    Args:
        mdp: MDP object with states, actions(s), transitions(s, a)
        gamma: discount factor
        theta: convergence threshold for policy evaluation

    Returns:
        policy: optimal policy (dict: state -> action)
        V: optimal value function (dict: state -> value)
        history: list of dicts with iteration statistics
    """
    import random

    # Initialize with an arbitrary policy
    policy = {}
    for s in mdp.states:
        if hasattr(mdp, 'terminal_states') and s in mdp.terminal_states:
            policy[s] = None
        else:
            policy[s] = random.choice(list(mdp.actions(s)))

    history = []
    iteration = 0
    while True:
        iteration += 1

        # Step 1: Policy Evaluation
        V = policy_evaluation(mdp, policy, gamma, theta)

        # Step 2: Policy Improvement
        policy_stable = True
        changes = 0
        for s in mdp.states:
            if hasattr(mdp, 'terminal_states') and s in mdp.terminal_states:
                continue
            old_action = policy[s]
            # Find the greedy action
            best_action = None
            best_value = float('-inf')
            for a in mdp.actions(s):
                action_value = 0.0
                for s_next, prob, reward in mdp.transitions(s, a):
                    action_value += prob * (reward + gamma * V[s_next])
                if action_value > best_value:
                    best_value = action_value
                    best_action = a
            policy[s] = best_action
            if old_action != best_action:
                policy_stable = False
                changes += 1

        # Record history
        history.append({
            'iteration': iteration,
            'policy_changes': changes,
            'max_value': max(V.values()),
            'min_value': min(V.values()),
        })
        print(f"Iteration {iteration}: {changes} policy changes")

        if policy_stable:
            print(f"Policy iteration converged after {iteration} iterations")
            break

    return policy, V, history

A Worked Example
Consider a 4x4 grid. The agent starts anywhere and wants to reach the bottom-right corner (the goal). Each step costs -1. The discount factor is γ = 1 (no discounting for this episodic task).
. . . .
. . . .
. . . .
. . . G    G = Goal (reward 0)

Iteration 0: Initialize with an arbitrary policy
Suppose we start with a policy that goes "down" everywhere, except in the bottom row, where it goes "right" (going down forever in the bottom row would bounce off the wall and never reach the goal):
v v v v
v v v v
v v v v
> > > G

Iteration 1: Evaluate, then Improve
After evaluation, we find:
- State (3,2), one step from the goal, has value -1; (3,1) has value -2; (3,0) has value -3
- Each row up adds one more step along the policy's path, so values drop by 1 per row
Values under this policy:
-6 -5 -4 -3
-5 -4 -3 -2
-4 -3 -2 -1
-3 -2 -1  0
Now we improve. For each state, we check which action is best:
- From (0,0): Right leads to -5, Down leads to -5. Tie! Pick either.
- From (0,3): Down leads to -2. That is best.
- From (3,0): Right leads to -2. That is best.
The improved policy becomes arrows (right or down) pointing along shortest paths toward the goal.
Iteration 2: Evaluate and Improve again
The new policy is already optimal (shortest paths to goal). After evaluation and improvement, no changes occur.
Result: Converged in 2 iterations!
Optimal policy:
> > > v
> > > v
> > > v
> > > G
Optimal values:
-6 -5 -4 -3
-5 -4 -3 -2
-4 -3 -2 -1
-3 -2 -1  0
Each value is the negative Manhattan distance to the goal at (3,3), the bottom-right corner.
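The whole worked example can be reproduced with a short, self-contained script. This is a compact sketch rather than the full implementation above; ties in the improvement step are broken toward 'right' so it produces the same arrows as the grids in the text:

```python
# 4x4 GridWorld from the worked example: step cost -1, goal at (3, 3), gamma = 1
GOAL = (3, 3)
STATES = [(r, c) for r in range(4) for c in range(4)]
# Order matters: max() keeps the first maximum, so ties resolve toward 'right'
MOVES = {'right': (0, 1), 'down': (1, 0), 'left': (0, -1), 'up': (-1, 0)}

def q(s, a, V):
    """Action value under gamma = 1: -1 step cost plus next-state value (walls bounce)."""
    dr, dc = MOVES[a]
    s_next = (min(max(s[0] + dr, 0), 3), min(max(s[1] + dc, 0), 3))
    return -1.0 + V[s_next]

def evaluate(policy, V, theta=1e-10):
    """In-place iterative policy evaluation (converges exactly on this grid)."""
    while True:
        delta = 0.0
        for s in STATES:
            if s == GOAL:
                continue
            new_v = q(s, policy[s], V)
            delta = max(delta, abs(V[s] - new_v))
            V[s] = new_v
        if delta < theta:
            return V

# Iteration 0: 'down' everywhere, 'right' along the bottom row
policy = {s: ('right' if s[0] == 3 else 'down') for s in STATES}
V = {s: 0.0 for s in STATES}
for iteration in range(1, 100):
    V = evaluate(policy, V)                      # policy evaluation
    stable = True
    for s in STATES:                             # policy improvement
        if s == GOAL:
            continue
        best = max(MOVES, key=lambda a: q(s, a, V))
        if best != policy[s]:
            policy[s], stable = best, False
    if stable:
        break

# Optimal values are the negative Manhattan distances to the goal
assert all(V[(r, c)] == -((3 - r) + (3 - c)) for (r, c) in STATES)
print(f"converged after iteration {iteration}")  # matches the 2 iterations above
```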
Convergence Speed
Policy iteration typically converges very quickly, often in just 2-5 iterations for many practical MDPs. This is surprising given that there could be |A|^|S| possible policies.
Why so fast? The policy improvement step makes large jumps in policy space. Unlike gradient descent which takes small steps, policy improvement switches to the globally greedy action at every state simultaneously. This aggressive improvement leads to rapid convergence.
| MDP Size | States | Actions | Typical Iterations |
|---|---|---|---|
| Small GridWorld | 16 | 4 | 2-3 |
| Medium GridWorld | 100 | 4 | 3-5 |
| Large GridWorld | 1000 | 4 | 4-7 |
| Complex stochastic MDP | 500 | 10 | 5-10 |
The iteration count grows slowly with MDP size. The bulk of computation is in policy evaluation, not in the number of improvement steps.
Computational Cost
The cost of policy iteration comes from two components:
Per policy evaluation: O(E · |S|²), where E is the number of evaluation sweeps (E depends on γ and θ; each sweep backs up every state over up to |S| successors)
Per policy improvement: O(|S|² · |A|) for one pass over all states
Total: If we need K policy iterations, each requiring E evaluation sweeps: O(K · |S|² · (E + |A|))
Since K is typically small (2-10) but E can be large (hundreds to thousands for γ close to 1), the evaluation cost dominates.
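As a back-of-envelope illustration (the sizes here are assumed, not from the text), counting individual backups shows why evaluation dominates whenever E is much larger than |A|:

```python
# Illustrative backup counts, assuming dense transitions (every state can
# reach every other). S = |S|, A = |A|, E = eval sweeps, K = PI iterations.
S, A, E, K = 1000, 4, 200, 5

eval_backups = K * E * S * S     # one backup per state per sweep, each O(S)
improve_backups = K * S * S * A  # every action at every state, each O(S)

print(eval_backups, improve_backups)  # evaluation dominates by E / A = 50x
```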
Variations
Modified Policy Iteration
Modified Policy Iteration is a practical variant that does not run policy evaluation to full convergence. Instead, it runs only k evaluation sweeps before improving.
When k = 1 (one sweep), this becomes very similar to value iteration. When k → ∞ (full convergence), this is standard policy iteration.
In practice, small values such as k = 5 to k = 10 often work well, balancing evaluation accuracy with computation time.
def modified_policy_iteration(mdp, gamma=0.99, k=10, theta=1e-6):
    """
    Modified policy iteration with limited evaluation sweeps.

    Args:
        mdp: MDP object
        gamma: discount factor
        k: number of evaluation sweeps per iteration
        theta: overall convergence threshold
    """
    import random

    # Initialize (terminal states get no action)
    policy = {s: random.choice(list(mdp.actions(s)))
              for s in mdp.states
              if not (hasattr(mdp, 'terminal_states') and s in mdp.terminal_states)}
    V = {s: 0.0 for s in mdp.states}

    iteration = 0
    while True:
        iteration += 1

        # Limited policy evaluation: only k sweeps
        for _ in range(k):
            for s in mdp.states:
                if hasattr(mdp, 'terminal_states') and s in mdp.terminal_states:
                    continue
                action = policy.get(s)
                if action is None:
                    continue
                new_value = 0.0
                for s_next, prob, reward in mdp.transitions(s, action):
                    new_value += prob * (reward + gamma * V[s_next])
                V[s] = new_value

        # Policy improvement
        policy_stable = True
        max_delta = 0.0
        for s in mdp.states:
            if hasattr(mdp, 'terminal_states') and s in mdp.terminal_states:
                continue
            old_action = policy.get(s)
            old_value = V[s]
            # Find the best action and its value
            best_action = None
            best_value = float('-inf')
            for a in mdp.actions(s):
                action_value = 0.0
                for s_next, prob, reward in mdp.transitions(s, a):
                    action_value += prob * (reward + gamma * V[s_next])
                if action_value > best_value:
                    best_value = action_value
                    best_action = a
            policy[s] = best_action
            V[s] = best_value  # update the value to the greedy value
            max_delta = max(max_delta, abs(old_value - best_value))
            if old_action != best_action:
                policy_stable = False

        if policy_stable and max_delta < theta:
            break

    return policy, V

Asynchronous Policy Iteration
In asynchronous variants, we do not wait to update all states before improving. Instead, we can:
- Update states in any order
- Improve the policy for a state as soon as its value is updated
- Prioritize states that are likely to change the most
These variants can converge faster in practice, especially for large MDPs with sparse reward structures.
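As a sketch (not from the text), here is one asynchronous variant in which each state receives a greedy backup and an immediate policy update, so evaluation and improvement are interleaved rather than alternated. The mdp interface matches the implementation above; the tiny two-state MDP at the bottom is hypothetical:

```python
from types import SimpleNamespace

def async_sweep(mdp, V, policy, gamma=0.99):
    """One asynchronous sweep: greedy backup + immediate policy update per state."""
    for s in mdp.states:  # any order works; priority queues are a common refinement
        best_a, best_q = None, float('-inf')
        for a in mdp.actions(s):
            q = sum(p * (r + gamma * V[s_next])
                    for s_next, p, r in mdp.transitions(s, a))
            if q > best_q:
                best_a, best_q = a, q
        V[s] = best_q       # value update for this state
        policy[s] = best_a  # improve the policy right away
    return V, policy

# Hypothetical two-state MDP: 'go' always moves to state 1 for reward +1,
# 'stay' loops in place for reward 0
mdp = SimpleNamespace(
    states=[0, 1],
    actions=lambda s: ['stay', 'go'],
    transitions=lambda s, a: [(1, 1.0, 1.0)] if a == 'go' else [(s, 1.0, 0.0)],
)
V = {0: 0.0, 1: 0.0}
policy = {0: 'stay', 1: 'stay'}
for _ in range(200):
    V, policy = async_sweep(mdp, V, policy, gamma=0.9)
# Both states learn to 'go'; V converges toward 1 / (1 - 0.9) = 10
```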
Common Mistakes
Mistake 1: Forgetting to handle terminal states
Terminal states have value 0 (or terminal reward) and no policy. Always check for terminal states:
if s in terminal_states:
    V[s] = 0
    policy[s] = None
    continue

Mistake 2: Using stochastic policies incorrectly
Standard policy iteration works with deterministic policies. If your policy maps state to action probabilities, the evaluation step must account for this:
# For a stochastic policy: weight each action's backup by its probability
new_value = sum(
    policy[s][a] * sum(p * (r + gamma * V[s_next])
                       for s_next, p, r in transitions(s, a))
    for a in actions(s)
)

Mistake 3: Not detecting convergence properly
Check whether ANY action changed, not just whether the values moved:
# Correct
if old_action != best_action:
    policy_stable = False

# Wrong (checking values can miss ties between equally good actions)
if abs(old_value - best_value) > epsilon:
    policy_stable = False

Policy Iteration vs. Other Methods
| Aspect | Policy Iteration | Value Iteration | Q-Learning |
|---|---|---|---|
| Model required? | Yes | Yes | No |
| Iterations | Few (2-10) | Many (100s) | Many (1000s+) |
| Per-iteration cost | High (full evaluation) | Low (one backup) | Very low (one sample) |
| Convergence | Exact | Exact | Approximate |
| Memory | O(|S|) | O(|S|) | O(|S| × |A|) |
Summary
Key Takeaways:
- Policy iteration alternates between evaluation and improvement
- Convergence is guaranteed and typically fast (2-10 iterations)
- Evaluation cost dominates, especially for γ close to 1
- Modified policy iteration trades evaluation accuracy for speed
- This is the gold standard for small MDPs with known models
Policy iteration gives us exact optimal policies when we know the MDP. But what if we want a simpler algorithm that combines evaluation and improvement in one step? That is value iteration, which we cover next.