Generalized Advantage Estimation (GAE)

Generalized Advantage Estimation (GAE) is a method for estimating the advantage function in Reinforcement Learning (RL). It reduces the variance of advantage estimates while introducing a tunable bias via an exponentially weighted sum of Temporal Difference (TD) errors. This can lead to more stable policy gradient updates.

Great explanations of GAE can be found in other blog posts (for instance, see this blog), but here we will focus on the core implementation details.

Let’s look at the formulation of GAE and break it down into parts. For a time step $t$, the GAE advantage estimate is given by

$$\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l},$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error. In simpler terms, we start at time $t$ and sum the exponentially discounted TD errors from $t$ onward. Although the mathematical expression goes to infinity, in practice we implement this on finite trajectories.
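To make the truncated, finite-trajectory version concrete, here is a deliberately naive forward-sum sketch; the function name gae_forward and its arguments are illustrative choices, not from the original post, and the efficient version appears below.

def gae_forward(rewards, values, next_value, gamma, lam):
    # Naive O(T^2) forward summation of the truncated GAE series.
    # rewards, values: lists of length T; next_value bootstraps V(s_T).
    T = len(rewards)
    v = values + [next_value]
    deltas = [rewards[t] + gamma * v[t + 1] - v[t] for t in range(T)]
    return [
        sum((gamma * lam) ** l * deltas[t + l] for l in range(T - t))
        for t in range(T)
    ]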

Backward Recursion

Directly summing

$$\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}$$

in a forward manner can be computationally expensive for long trajectories, and handling terminal states can become cumbersome. The key insight from the original GAE paper is that we can rewrite the above summation as a backward recursion:

$$\hat{A}_t = \delta_t + \gamma \lambda \, \hat{A}_{t+1}.$$

This means that if we have already computed the advantage at time $t+1$, we can use it to get $\hat{A}_t$ directly. By iterating backward (from $t = T-1$ down to $t = 0$), we compute each TD error only once and avoid repeatedly summing over overlapping segments. This also simplifies handling terminal states because you can “reset” the accumulation at episode boundaries.
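To see why this recursion reproduces the infinite sum, factor a single power of $\gamma \lambda$ out of the tail:

$$\hat{A}_t = \delta_t + (\gamma\lambda)\,\delta_{t+1} + (\gamma\lambda)^2\,\delta_{t+2} + \cdots = \delta_t + \gamma\lambda\,\bigl(\delta_{t+1} + (\gamma\lambda)\,\delta_{t+2} + \cdots\bigr) = \delta_t + \gamma\lambda\,\hat{A}_{t+1}.$$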

Code Example

When we run an environment to sample experiences for an episode (or rollout), we typically collect the following information at each time step:

  • rewards[t] ($r_t$),
  • values[t] ($V(s_t)$),
  • next_value ($V(s_T)$, the value for the state following the final time step, if not terminal),
  • gamma ($\gamma$, the discount factor),
  • lam ($\lambda$, the GAE decay parameter).

Using these, we can implement GAE as follows:


 
def gae(rewards, values, next_value, gamma, lam):

    T = len(rewards) # number of time steps in this trajectory
    # we need the advantage estimate at each time step
    A = [0.0] * T # initialize it to zero at each instance

    gae = 0.0 # running GAE value; it is zero past the end of the trajectory

    values = values + [next_value]
    # now let's do the backward recursion
    for t in reversed(range(T)):

        # let's find the TD error
        # \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t+1] - values[t]

        # now the gae at time t
        # \hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}
        gae = delta + gamma * lam * gae
        A[t] = gae

    return A
 
Doing so, we obtain an advantage estimate for every time step of the trajectory.
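As a quick sanity check, here is a hypothetical call with made-up numbers (the values below are purely illustrative):

rewards = [1.0, 0.5, 2.0]
values = [0.8, 0.6, 1.5]        # V(s_0), V(s_1), V(s_2)
next_value = 0.0                # episode ends after the last step, so bootstrap with 0
advantages = gae(rewards, values, next_value, gamma=0.99, lam=0.95)
print(advantages)               # one advantage estimate per time step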
 

Handling Terminal States

If the episode terminates at the final collected step, you can set next_value = 0 (rather than a bootstrapped value) to avoid incorrectly carrying estimated values beyond a terminal state. If episodes can end in the middle of a rollout, another approach is to keep a separate list of dones indicating when an episode has ended and reset the GAE accumulator wherever dones[t] == True.
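A common way to implement the second approach is to mask both the bootstrap term and the running accumulator with the done flags. Below is a minimal sketch, assuming dones[t] is True when the episode terminates at step t (the name gae_with_dones is just for illustration):

def gae_with_dones(rewards, values, next_value, dones, gamma, lam):
    T = len(rewards)
    A = [0.0] * T
    gae = 0.0
    values = values + [next_value]
    for t in reversed(range(T)):
        # if the episode ends at t, drop the bootstrap and reset the accumulator
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        A[t] = gae
    return A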

References:

  1. Blog: Notes on the Generalized Advantage Estimation Paper
  2. Schulman et al., “High-Dimensional Continuous Control Using Generalized Advantage Estimation” (2015)