The SARSA Algorithm
SARSA is the first TD control algorithm we’ll study in depth. It’s elegant, intuitive, and the perfect stepping stone to Q-learning. The name itself tells you everything you need to know about how it works: State, Action, Reward, State, Action.
The SARSA Update Rule
The SARSA update modifies the Q-value for a state-action pair based on the observed transition and the next action chosen by the policy:

Q(S, A) ← Q(S, A) + α[R + γQ(S', A') − Q(S, A)]

Let's break this down:
- We're in state S and take action A
- The environment gives us reward R and moves us to state S'
- We choose our next action A' from our policy
- Now we have all five elements: S, A, R, S', A' — SARSA!
- We update Q(S, A) toward the target R + γQ(S', A')
The key: we need to know A' before we update. This action comes from our current policy (typically epsilon-greedy).
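To make the rule concrete, here is the update as a couple of lines of Python. This is a toy sketch: the table `q` and the transition values are made up for illustration, not part of the algorithm itself.

```python
import numpy as np

# Toy tabular Q: one array of action-values per state (illustrative setup)
q = {"s0": np.zeros(2), "s1": np.zeros(2)}
alpha, gamma = 0.5, 0.9

# One observed transition: (S, A, R, S', A')
s, a, r, s_next, a_next = "s0", 0, -1.0, "s1", 1

# SARSA update: move Q(S, A) toward the target R + gamma * Q(S', A')
td_target = r + gamma * q[s_next][a_next]
q[s][a] += alpha * (td_target - q[s][a])

print(q[s][a])  # 0 + 0.5 * (-1 + 0.9*0 - 0) = -0.5
```

Note that the update needs `a_next` before it can run; that single requirement is what the rest of this section is about.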
Why Does SARSA Need A’?
This might seem like a minor detail, but it's crucial. SARSA uses Q(S', A') — the value of the action we will actually take next.

Compare to what we could do instead:
- Use max_a Q(S', a) — the value of the best action (this is Q-learning)
- Use Σ_a π(a|S') Q(S', a) — the expected value over the policy (this is Expected SARSA)
SARSA uses the actual next action. This means:
- If we explore randomly sometimes, SARSA accounts for that
- The Q-values reflect what we actually do, not what we wish we would do
- SARSA learns about the policy it follows—that’s what makes it on-policy
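The three targets are easy to compare side by side in code. A small sketch, with the epsilon-greedy probabilities computed explicitly; the action-values in `q_next` are made-up numbers for illustration:

```python
import numpy as np

q_next = np.array([1.0, 3.0, 2.0])  # Q(S', .) for three actions (made-up values)
epsilon, n = 0.1, len(q_next)
a_next = 0  # suppose the policy happened to sample action 0 (exploration)

sarsa_target = q_next[a_next]       # value of the action we will actually take
q_learning_target = q_next.max()    # value of the best action

# Epsilon-greedy probabilities: epsilon/n on every action, plus 1-epsilon on the greedy one
probs = np.full(n, epsilon / n)
probs[q_next.argmax()] += 1 - epsilon
expected_sarsa_target = probs @ q_next  # expectation over the policy

print(sarsa_target, q_learning_target, expected_sarsa_target)
```

When exploration picks a poor action, the SARSA target can sit well below the Q-learning target; Expected SARSA lands in between, weighted by how likely each action actually is.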
The Complete Algorithm
Here’s the full SARSA algorithm:
Algorithm: SARSA (On-Policy TD Control)
────────────────────────────────────────
Parameters: step size α ∈ (0, 1], discount γ, exploration ε
Initialize Q(s, a) arbitrarily for all s, a
Initialize Q(terminal, ·) = 0
Loop for each episode:
    Initialize S
    Choose A from S using policy derived from Q (e.g., ε-greedy)
    Loop for each step of episode:
        Take action A, observe R, S'
        Choose A' from S' using policy derived from Q (e.g., ε-greedy)
        Q(S, A) ← Q(S, A) + α[R + γQ(S', A') - Q(S, A)]
        S ← S'
        A ← A'
    until S is terminal
Notice the structure: we choose A' before the update, then shift: S ← S' and A ← A'. This means we always have the action ready for the next step.

TD Error for SARSA:

δ = R + γQ(S', A') − Q(S, A)

This measures the "surprise": how much better or worse the transition was compared to our current estimate.

The update in terms of TD error:

Q(S, A) ← Q(S, A) + αδ
Epsilon-Greedy Policy
SARSA needs a policy to select actions. The standard choice is epsilon-greedy:
The epsilon-greedy policy selects actions as follows:
- With probability 1 − ε + ε/|A|: select the greedy action
- With probability ε/|A|: select each other action (where |A| is the number of actions)

Or more simply:
- With probability ε: random action (uniform over all actions)
- With probability 1 − ε: greedy action
Choosing epsilon:
- Too high (e.g., 0.5): Too much exploration, slow to converge
- Too low (e.g., 0.01): May miss better actions, get stuck in local optima
- Good starting point: ε = 0.1 (the default used in the implementation below)
Some practitioners decay ε over time, starting high and reducing it as learning progresses.
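One common schedule is multiplicative decay with a floor. A minimal sketch; the constants here are illustrative choices, not prescriptions:

```python
epsilon = 1.0        # start fully exploratory
epsilon_min = 0.05   # never stop exploring entirely
decay = 0.995        # per-episode multiplier

history = []
for episode in range(1000):
    history.append(epsilon)
    # ... run one SARSA episode with the current epsilon ...
    epsilon = max(epsilon_min, epsilon * decay)

print(history[0], history[-1])  # starts at 1.0, ends clamped at the 0.05 floor
```

The floor matters for SARSA in particular: since it learns about the policy it follows, some residual exploration keeps the Q-values honest about occasional random actions.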
A Worked Example
SARSA in GridWorld
Consider a 3x3 grid. The agent starts at (0, 0) and wants to reach the goal at (2, 2). Actions: UP, DOWN, LEFT, RIGHT. Reward: −1 per step, +10 at goal.
Initial Q-values: All zeros.
Episode 1, Step 1:
- State S = (0, 0)
- Choose action RIGHT (random exploration)
- Observe: R = −1, S' = (0, 1)
- Choose next action DOWN (epsilon-greedy, random this time)
- SARSA update with α = 0.1, γ = 0.99:
  Q((0,0), RIGHT) ← 0 + 0.1[−1 + 0.99 · Q((0,1), DOWN) − 0] = 0.1 · (−1) = −0.1

Episode 1, Step 2:
- Now S = (0, 1), A = DOWN (from before)
- Observe: R = −1, S' = (1, 1)
- Choose A' = RIGHT
- Update: Q((0,1), DOWN) ← 0 + 0.1[−1 + 0.99 · 0 − 0] = −0.1
And so on. As the agent reaches the goal and gets the +10 reward, that positive value propagates backward through the Q-table.
Complete Implementation
import numpy as np
from collections import defaultdict


class SARSAgent:
    """SARSA agent for discrete state-action spaces."""

    def __init__(self, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
        """
        Initialize SARSA agent.

        Args:
            n_actions: Number of available actions
            alpha: Learning rate
            gamma: Discount factor
            epsilon: Exploration rate for epsilon-greedy
        """
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        # Q-table: maps state to array of action values
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def select_action(self, state):
        """Select action using epsilon-greedy policy."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        else:
            return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, next_action, done):
        """
        Perform SARSA update.

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            next_action: Next action (from policy)
            done: Whether episode ended
        """
        if done:
            td_target = reward
        else:
            td_target = reward + self.gamma * self.Q[next_state][next_action]
        td_error = td_target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error
        return td_error


def train_sarsa_episode(env, agent):
    """
    Train one episode with SARSA.

    Returns:
        total_reward: Sum of rewards in the episode
        steps: Number of steps taken
    """
    state = env.reset()
    action = agent.select_action(state)
    total_reward = 0
    steps = 0
    done = False
    while not done:
        # Take action, observe result
        next_state, reward, done, _ = env.step(action)
        total_reward += reward
        steps += 1
        # Choose next action (needed for SARSA update)
        next_action = agent.select_action(next_state)
        # SARSA update
        agent.update(state, action, reward, next_state, next_action, done)
        # Move to next state-action pair
        state = next_state
        action = next_action
    return total_reward, steps


def train_sarsa(env, agent, num_episodes=1000):
    """Train SARSA agent for multiple episodes."""
    rewards = []
    steps_list = []
    for episode in range(num_episodes):
        total_reward, steps = train_sarsa_episode(env, agent)
        rewards.append(total_reward)
        steps_list.append(steps)
        if episode % 100 == 0:
            avg_reward = np.mean(rewards[-100:])
            print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}")
    return rewards, steps_list

Example: GridWorld Environment
class SimpleGridWorld:
    """A simple GridWorld for testing SARSA."""

    def __init__(self, size=4):
        self.size = size
        self.goal = (size - 1, size - 1)
        self.state = None

    def reset(self):
        """Start at top-left corner."""
        self.state = (0, 0)
        return self.state

    def step(self, action):
        """
        Take an action.

        Actions: 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT
        """
        row, col = self.state
        if action == 0:    # UP
            row = max(0, row - 1)
        elif action == 1:  # DOWN
            row = min(self.size - 1, row + 1)
        elif action == 2:  # LEFT
            col = max(0, col - 1)
        elif action == 3:  # RIGHT
            col = min(self.size - 1, col + 1)
        self.state = (row, col)
        # Check if reached goal
        if self.state == self.goal:
            return self.state, 10, True, {}
        else:
            return self.state, -1, False, {}


# Train an agent
env = SimpleGridWorld(size=4)
agent = SARSAgent(n_actions=4, alpha=0.1, gamma=0.99, epsilon=0.1)
rewards, steps = train_sarsa(env, agent, num_episodes=500)

Visualizing Q-Values
import matplotlib.pyplot as plt


def plot_q_values(agent, grid_size):
    """Visualize Q-values as a grid with arrows."""
    fig, ax = plt.subplots(figsize=(8, 8))
    # Action directions for arrows
    arrow_dirs = {
        0: (0, 0.3),    # UP
        1: (0, -0.3),   # DOWN
        2: (-0.3, 0),   # LEFT
        3: (0.3, 0)     # RIGHT
    }
    for row in range(grid_size):
        for col in range(grid_size):
            state = (row, col)
            q_values = agent.Q[state]
            best_action = np.argmax(q_values)
            # Draw cell value (max Q)
            ax.text(col, grid_size - 1 - row, f'{np.max(q_values):.1f}',
                    ha='center', va='center', fontsize=10)
            # Draw arrow for best action
            dx, dy = arrow_dirs[best_action]
            ax.arrow(col, grid_size - 1 - row, dx * 0.6, dy * 0.6,
                     head_width=0.1, head_length=0.05, fc='blue', ec='blue')
    ax.set_xlim(-0.5, grid_size - 0.5)
    ax.set_ylim(-0.5, grid_size - 0.5)
    ax.set_xticks(range(grid_size))
    ax.set_yticks(range(grid_size))
    ax.grid(True)
    ax.set_title('SARSA Q-Values and Policy')
    plt.show()

Expected SARSA: A Variance Reduction
One issue with SARSA is that the update depends on which specific action A' we happen to sample. This introduces variance — the same transition (S, A, R, S') could lead to different updates depending on our random action choice.

Expected SARSA fixes this by using the expected Q-value over the policy:

Q(S, A) ← Q(S, A) + α[R + γ Σ_a π(a|S') Q(S', a) − Q(S, A)]

Instead of sampling one action, we average over all actions weighted by their probability under the policy. This reduces variance while keeping the same expected update.
def expected_sarsa_update(Q, state, action, reward, next_state, done,
                          alpha=0.1, gamma=0.99, epsilon=0.1, n_actions=4):
    """Expected SARSA update - uses expected value instead of sampled action."""
    if done:
        td_target = reward
    else:
        # Compute expected Q-value under epsilon-greedy policy
        q_next = Q[next_state]
        greedy_action = np.argmax(q_next)
        # Expected value: epsilon/n * sum(all) + (1-epsilon) * max
        expected_q = epsilon / n_actions * np.sum(q_next) + \
                     (1 - epsilon) * q_next[greedy_action]
        td_target = reward + gamma * expected_q
    td_error = td_target - Q[state][action]
    Q[state][action] += alpha * td_error
    return td_error

Summary
SARSA is a complete TD control algorithm:
- Initialize Q-values (typically to zero)
- For each episode:
  - Choose initial action using epsilon-greedy
  - For each step:
    - Take action, observe reward and next state
    - Choose next action using epsilon-greedy
    - Update Q using the SARSA rule
    - Move to next state-action pair
The key characteristic: SARSA uses the actual next action from the policy, making it an on-policy method. This means the learned Q-values reflect the behavior of the exploratory policy, not the optimal policy.
In the next section, we’ll explore what this “on-policy” property means in practice, and why it makes SARSA behave differently from Q-learning.