Temporal Difference Learning • Part 2 of 3
📝Draft

The SARSA Algorithm

State-Action-Reward-State-Action learning


SARSA is the first TD control algorithm we’ll study in depth. It’s elegant, intuitive, and the perfect stepping stone to Q-learning. The name itself tells you everything you need to know about how it works: State, Action, Reward, State, Action.

The SARSA Update Rule

📖SARSA Update

The SARSA update modifies the Q-value for a state-action pair based on the observed transition and the next action chosen by the policy:

Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t) \right]

Let’s break this down:

  1. We’re in state S_t and take action A_t
  2. The environment gives us reward R_{t+1} and moves us to state S_{t+1}
  3. We choose our next action A_{t+1} from our policy
  4. Now we have all five elements: (S_t, A_t, R_{t+1}, S_{t+1}, A_{t+1}) — SARSA!
  5. We update Q(S_t, A_t) toward the target R_{t+1} + γ Q(S_{t+1}, A_{t+1})

The key: we need to know A_{t+1} before we update. This action comes from our current policy (typically epsilon-greedy).
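In code, a single SARSA update is one line of arithmetic. Here is a minimal sketch; the two-state Q-table, the transition, and the hyperparameters are made up purely for illustration:

```python
import numpy as np

# Hypothetical Q-table: two states, two actions each, initialized to zero
Q = {"s0": np.zeros(2), "s1": np.zeros(2)}
alpha, gamma = 0.1, 0.9

# One observed transition: (S, A, R, S', A')
s, a, r, s_next, a_next = "s0", 1, -1.0, "s1", 0

# SARSA update: move Q(S, A) toward the target R + gamma * Q(S', A')
td_target = r + gamma * Q[s_next][a_next]
Q[s][a] += alpha * (td_target - Q[s][a])

print(Q[s][a])  # 0.1 * (-1 + 0.9 * 0 - 0) = -0.1
```

Because all Q-values start at zero, the bootstrapped term contributes nothing here and the update is driven entirely by the reward.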

Why Does SARSA Need A’?

This might seem like a minor detail, but it’s crucial. SARSA uses Q(S_{t+1}, A_{t+1}) — the value of the action we will actually take next.

Compare to what we could do instead:

  • Use max_a Q(S_{t+1}, a) — the value of the best action (this is Q-learning)
  • Use ∑_a π(a|S_{t+1}) Q(S_{t+1}, a) — the expected value over the policy (this is Expected SARSA)

SARSA uses the actual next action. This means:

  • If we explore randomly sometimes, SARSA accounts for that
  • The Q-values reflect what we actually do, not what we wish we would do
  • SARSA learns about the policy it follows—that’s what makes it on-policy
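The three candidate targets can be computed side by side. A sketch under an epsilon-greedy policy, where the next-state Q-values and the sampled action are invented for illustration:

```python
import numpy as np

q_next = np.array([1.0, 3.0, 2.0])  # hypothetical Q(S', a) for 3 actions
epsilon, n = 0.1, len(q_next)
sampled_action = 0  # suppose exploration happened to pick action 0

# SARSA: value of the action actually sampled
sarsa_target = q_next[sampled_action]          # 1.0

# Q-learning: value of the best action
q_learning_target = np.max(q_next)             # 3.0

# Expected SARSA: expectation under the epsilon-greedy policy
pi = np.full(n, epsilon / n)                   # epsilon/n for every action...
pi[np.argmax(q_next)] += 1 - epsilon           # ...plus 1-epsilon for the greedy one
expected_sarsa_target = np.dot(pi, q_next)     # 2.9
```

When exploration picks a poor action, the SARSA target drops with it; the Q-learning target ignores the sampled action entirely; Expected SARSA sits in between, weighting each action by its probability.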

The Complete Algorithm

Here’s the full SARSA algorithm:

Algorithm: SARSA (On-Policy TD Control)
────────────────────────────────────────
Parameters: step size α ∈ (0, 1], discount γ, exploration ε

Initialize Q(s, a) arbitrarily for all s, a
Initialize Q(terminal, ·) = 0

Loop for each episode:
    Initialize S
    Choose A from S using policy derived from Q (e.g., ε-greedy)

    Loop for each step of episode:
        Take action A, observe R, S'
        Choose A' from S' using policy derived from Q (e.g., ε-greedy)
        Q(S, A) ← Q(S, A) + α[R + γQ(S', A') - Q(S, A)]
        S ← S'
        A ← A'
    until S is terminal
ℹ️Note

Notice the structure: we choose A' before the update, then shift: S ← S' and A ← A'. This means we always have the action ready for the next step.

Mathematical Details

TD Error for SARSA:

\delta_t = R_{t+1} + \gamma Q(S_{t+1}, A_{t+1}) - Q(S_t, A_t)

This measures the “surprise”—how much better or worse the transition was compared to our current estimate.

The update in terms of TD error:

Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \cdot \delta_t

Epsilon-Greedy Policy

SARSA needs a policy to select actions. The standard choice is epsilon-greedy:

Mathematical Details

The epsilon-greedy policy selects actions as follows:

  • With probability 1 - ε + ε/n: select the greedy action argmax_a Q(s, a)
  • With probability ε/n: select each of the other actions (where n is the number of actions)

Or more simply:

  • With probability ε: random action
  • With probability 1 - ε: greedy action
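The two descriptions agree because the random branch can also pick the greedy action, giving it total probability 1 - ε + ε/n. A quick empirical check, with made-up Q-values:

```python
import numpy as np

rng = np.random.default_rng(0)
q = np.array([0.0, 1.0, 0.5])  # hypothetical Q(s, ·); greedy action is index 1
epsilon, n = 0.1, len(q)

def eps_greedy():
    if rng.random() < epsilon:
        return int(rng.integers(n))  # uniform random; may also hit the greedy action
    return int(np.argmax(q))

counts = np.bincount([eps_greedy() for _ in range(100_000)], minlength=n)
freq_greedy = counts[1] / 100_000
# Expected frequency of the greedy action: 1 - epsilon + epsilon/n ≈ 0.933
```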
💡Tip

Choosing epsilon:

  • Too high (e.g., 0.5): Too much exploration, slow to converge
  • Too low (e.g., 0.01): May miss better actions, get stuck in local optima
  • Good starting point: ε = 0.1

Some practitioners decay ε over time, starting high and reducing it as learning progresses.
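One common scheme is a multiplicative decay with a floor, so the agent never stops exploring entirely. A minimal sketch; the rate and floor here are arbitrary choices, not recommendations:

```python
epsilon = 1.0        # start fully exploratory
epsilon_min = 0.05   # floor: never stop exploring entirely
decay = 0.995        # per-episode multiplicative decay

schedule = []
for episode in range(1000):
    schedule.append(epsilon)
    epsilon = max(epsilon_min, epsilon * decay)

# epsilon shrinks geometrically toward the floor as training progresses
print(schedule[0], schedule[-1])
```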

A Worked Example

📌Example

SARSA in GridWorld

Consider a 3x3 grid. The agent starts at (0, 0) and wants to reach the goal at (2, 2). Actions: UP, DOWN, LEFT, RIGHT. Reward: -1 per step, +10 at goal.

Initial Q-values: All zeros.

Episode 1, Step 1:

  • State S = (0, 0)
  • Choose action A = RIGHT (random exploration)
  • Observe: R = -1, S' = (0, 1)
  • Choose next action A' = DOWN (epsilon-greedy, random this time)
  • SARSA update with α = 0.1, γ = 0.9:

Q((0,0), \text{RIGHT}) \leftarrow 0 + 0.1 \times [-1 + 0.9 \times 0 - 0] = -0.1

Episode 1, Step 2:

  • Now S = (0, 1), A = DOWN (from before)
  • Observe: R = -1, S' = (1, 1)
  • Choose A' = RIGHT
  • Update:

Q((0,1), \text{DOWN}) \leftarrow 0 + 0.1 \times [-1 + 0.9 \times 0 - 0] = -0.1

And so on. As the agent reaches the goal and gets the +10 reward, that positive value propagates backward through the Q-table.
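The two hand-computed updates above can be checked directly in code. This sketch encodes grid cells as (row, col) tuples and replays exactly the two steps of the example:

```python
import numpy as np
from collections import defaultdict

actions = {"UP": 0, "DOWN": 1, "LEFT": 2, "RIGHT": 3}
Q = defaultdict(lambda: np.zeros(4))  # all Q-values start at zero
alpha, gamma = 0.1, 0.9

def sarsa_update(s, a, r, s_next, a_next):
    Q[s][a] += alpha * (r + gamma * Q[s_next][a_next] - Q[s][a])

# Episode 1, Step 1: (0,0) --RIGHT--> (0,1), next action DOWN
sarsa_update((0, 0), actions["RIGHT"], -1, (0, 1), actions["DOWN"])

# Episode 1, Step 2: (0,1) --DOWN--> (1,1), next action RIGHT
sarsa_update((0, 1), actions["DOWN"], -1, (1, 1), actions["RIGHT"])

print(Q[(0, 0)][actions["RIGHT"]])  # -0.1, matching the hand calculation
print(Q[(0, 1)][actions["DOWN"]])   # -0.1
```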

Complete Implementation

</>Implementation
import numpy as np
from collections import defaultdict

class SARSAgent:
    """SARSA agent for discrete state-action spaces."""

    def __init__(self, n_actions, alpha=0.1, gamma=0.99, epsilon=0.1):
        """
        Initialize SARSA agent.

        Args:
            n_actions: Number of available actions
            alpha: Learning rate
            gamma: Discount factor
            epsilon: Exploration rate for epsilon-greedy
        """
        self.n_actions = n_actions
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon

        # Q-table: maps state to array of action values
        self.Q = defaultdict(lambda: np.zeros(n_actions))

    def select_action(self, state):
        """Select action using epsilon-greedy policy."""
        if np.random.random() < self.epsilon:
            return np.random.randint(self.n_actions)
        else:
            return np.argmax(self.Q[state])

    def update(self, state, action, reward, next_state, next_action, done):
        """
        Perform SARSA update.

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            next_action: Next action (from policy)
            done: Whether episode ended
        """
        if done:
            td_target = reward
        else:
            td_target = reward + self.gamma * self.Q[next_state][next_action]

        td_error = td_target - self.Q[state][action]
        self.Q[state][action] += self.alpha * td_error

        return td_error


def train_sarsa_episode(env, agent):
    """
    Train one episode with SARSA.

    Returns:
        total_reward: Sum of rewards in the episode
        steps: Number of steps taken
    """
    state = env.reset()
    action = agent.select_action(state)

    total_reward = 0
    steps = 0
    done = False

    while not done:
        # Take action, observe result
        next_state, reward, done, _ = env.step(action)
        total_reward += reward
        steps += 1

        # Choose next action (needed for SARSA update)
        next_action = agent.select_action(next_state)

        # SARSA update
        agent.update(state, action, reward, next_state, next_action, done)

        # Move to next state-action pair
        state = next_state
        action = next_action

    return total_reward, steps


def train_sarsa(env, agent, num_episodes=1000):
    """Train SARSA agent for multiple episodes."""
    rewards = []
    steps_list = []

    for episode in range(num_episodes):
        total_reward, steps = train_sarsa_episode(env, agent)
        rewards.append(total_reward)
        steps_list.append(steps)

        if episode % 100 == 0:
            avg_reward = np.mean(rewards[-100:])
            print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}")

    return rewards, steps_list

Example: GridWorld Environment

</>Implementation
class SimpleGridWorld:
    """A simple GridWorld for testing SARSA."""

    def __init__(self, size=4):
        self.size = size
        self.goal = (size - 1, size - 1)
        self.state = None

    def reset(self):
        """Start at top-left corner."""
        self.state = (0, 0)
        return self.state

    def step(self, action):
        """
        Take an action.
        Actions: 0=UP, 1=DOWN, 2=LEFT, 3=RIGHT
        """
        row, col = self.state

        if action == 0:    # UP
            row = max(0, row - 1)
        elif action == 1:  # DOWN
            row = min(self.size - 1, row + 1)
        elif action == 2:  # LEFT
            col = max(0, col - 1)
        elif action == 3:  # RIGHT
            col = min(self.size - 1, col + 1)

        self.state = (row, col)

        # Check if reached goal
        if self.state == self.goal:
            return self.state, 10, True, {}
        else:
            return self.state, -1, False, {}


# Train an agent
env = SimpleGridWorld(size=4)
agent = SARSAgent(n_actions=4, alpha=0.1, gamma=0.99, epsilon=0.1)
rewards, steps = train_sarsa(env, agent, num_episodes=500)

Visualizing Q-Values

</>Implementation
import matplotlib.pyplot as plt

def plot_q_values(agent, grid_size):
    """Visualize Q-values as a grid with arrows."""
    fig, ax = plt.subplots(figsize=(8, 8))

    # Action directions for arrows
    arrow_dirs = {
        0: (0, 0.3),   # UP
        1: (0, -0.3),  # DOWN
        2: (-0.3, 0),  # LEFT
        3: (0.3, 0)    # RIGHT
    }

    for row in range(grid_size):
        for col in range(grid_size):
            state = (row, col)
            q_values = agent.Q[state]
            best_action = np.argmax(q_values)

            # Draw cell value (max Q)
            ax.text(col, grid_size - 1 - row, f'{np.max(q_values):.1f}',
                   ha='center', va='center', fontsize=10)

            # Draw arrow for best action
            dx, dy = arrow_dirs[best_action]
            ax.arrow(col, grid_size - 1 - row, dx * 0.6, dy * 0.6,
                    head_width=0.1, head_length=0.05, fc='blue', ec='blue')

    ax.set_xlim(-0.5, grid_size - 0.5)
    ax.set_ylim(-0.5, grid_size - 0.5)
    ax.set_xticks(range(grid_size))
    ax.set_yticks(range(grid_size))
    ax.grid(True)
    ax.set_title('SARSA Q-Values and Policy')
    plt.show()

Expected SARSA: A Variance Reduction

One issue with SARSA is that the update depends on which specific action A' we happen to sample. This introduces variance — the same (S, A, R, S') could lead to different updates depending on our random action choice.

Expected SARSA fixes this by using the expected Q-value over the policy:

Q(S_t, A_t) \leftarrow Q(S_t, A_t) + \alpha \left[ R_{t+1} + \gamma \sum_a \pi(a|S_{t+1}) Q(S_{t+1}, a) - Q(S_t, A_t) \right]

Instead of sampling one action, we average over all actions weighted by their probability. This reduces variance while keeping the same expected behavior.

</>Implementation
def expected_sarsa_update(Q, state, action, reward, next_state, done,
                          alpha=0.1, gamma=0.99, epsilon=0.1, n_actions=4):
    """Expected SARSA update - uses expected value instead of sampled action."""
    if done:
        td_target = reward
    else:
        # Compute expected Q-value under epsilon-greedy policy
        q_next = Q[next_state]
        greedy_action = np.argmax(q_next)

        # Expected value: epsilon/n * sum(all) + (1-epsilon) * max
        expected_q = epsilon / n_actions * np.sum(q_next) + \
                     (1 - epsilon) * q_next[greedy_action]

        td_target = reward + gamma * expected_q

    td_error = td_target - Q[state][action]
    Q[state][action] += alpha * td_error

    return td_error

Summary

SARSA is a complete TD control algorithm:

  1. Initialize Q-values (typically to zero)
  2. For each episode:
    • Choose initial action using epsilon-greedy
    • For each step:
      • Take action, observe reward and next state
      • Choose next action using epsilon-greedy
      • Update Q using the SARSA rule
      • Move to next state-action pair

The key characteristic: SARSA uses the actual next action from the policy, making it an on-policy method. This means the learned Q-values reflect the behavior of the exploratory policy, not the optimal policy.

In the next section, we’ll explore what this “on-policy” property means in practice, and why it makes SARSA behave differently from Q-learning.