Generalized Advantage Estimation (GAE)
Generalized Advantage Estimation (GAE) is a method for estimating the advantage function in Reinforcement Learning (RL). It reduces the variance of advantage estimates while introducing a tunable bias via an exponentially weighted sum of Temporal Difference (TD) errors. This can lead to more stable policy gradient updates.
A great explanation of GAE can be found in other blog posts (for instance, see this blog), but here we will focus on the core implementation details.
Let’s see the formulation of GAE and try to break it down into parts. For a time step $t$, the GAE estimate of the advantage is given by

$$\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l},$$

where $\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)$ is the TD error. In simpler terms, this means we start at time $t$ and sum the discounted TD errors from $t$ onward. Although the mathematical expression goes to infinity, in practice we implement this on finite trajectories.
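Writing out the first few terms makes the exponential weighting explicit:

$$\hat{A}_t = \delta_t + (\gamma \lambda)\,\delta_{t+1} + (\gamma \lambda)^2\,\delta_{t+2} + (\gamma \lambda)^3\,\delta_{t+3} + \cdots$$

Setting $\lambda = 0$ collapses this to the one-step TD error (lower variance, more bias), while $\lambda = 1$ keeps the full sum of discounted TD errors (less bias, more variance).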
Backward recursion
Directly summing

$$\hat{A}_t = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}$$

in a forward manner can be computationally expensive for long trajectories, and handling terminal states can become cumbersome. The key insight from the original GAE paper is that we can rewrite the above summation as a backward recursion:

$$\hat{A}_t = \delta_t + \gamma \lambda \, \hat{A}_{t+1}.$$

This means if we have already computed the advantage at time $t+1$, we can use it to get $\hat{A}_t$ directly. By iterating backward (from $t = T-1$ down to $t = 0$), we only compute TD errors once and avoid repeatedly summing over overlapping segments. This also simplifies handling terminal states because you can “reset” the accumulation at episode boundaries.
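To see why the recursion holds, split off the first term of the sum and factor out one power of $\gamma \lambda$:

$$\hat{A}_t = \delta_t + \sum_{l=1}^{\infty} (\gamma \lambda)^l \, \delta_{t+l} = \delta_t + \gamma \lambda \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{(t+1)+l} = \delta_t + \gamma \lambda \, \hat{A}_{t+1}.$$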
Code Example
When we run an environment to sample experiences for an episode (or rollout), we typically collect the following information at each time step:
- rewards[t] ($r_t$),
- values[t] ($V(s_t)$),
- next_value ($V(s_T)$ for the state following the final time step, if not terminal),
- gamma ($\gamma$),
- lam ($\lambda$).
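For context, here is a minimal sketch of how such a rollout might be collected; the gym-style env API and the policy / value_fn callables are hypothetical names used only for illustration:

# Hypothetical rollout collection (gym-style API assumed, for illustration only)
rewards, values = [], []
obs, _ = env.reset()
for _ in range(rollout_length):
    values.append(value_fn(obs))   # V(s_t) from a hypothetical value network
    action = policy(obs)           # hypothetical policy callable
    obs, reward, terminated, truncated, _ = env.step(action)
    rewards.append(reward)
    if terminated or truncated:
        break
# bootstrap value for the state after the final step (0 if terminal)
next_value = 0.0 if terminated else value_fn(obs)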
Using these, we can implement GAE as follows:
Code
def gae(rewards, values, next_value, gamma, lam):
    T = len(rewards)  # number of time steps in this trajectory
    # advantage estimate for each time step
    A = [0.0] * T  # initialize to zero at each instance
    running_gae = 0.0  # running GAE value; zero beyond the last step
    values = values + [next_value]  # append the bootstrap value V(s_T)
    # now let's do the backward recursion
    for t in reversed(range(T)):
        # TD error: \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # GAE at time t: \hat{A}_t = \delta_t + \gamma \lambda \hat{A}_{t+1}
        running_gae = delta + gamma * lam * running_gae
        A[t] = running_gae
    return A  # one advantage estimate per time step
With this backward pass, we obtain an advantage estimate for every time step of the trajectory.
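As a quick sanity check, here is a minimal usage sketch with made-up numbers (the rewards, values, and hyperparameters below are purely illustrative):

rewards = [1.0, 0.5, 2.0]      # r_0, r_1, r_2 (illustrative)
values = [0.8, 0.6, 1.2]       # V(s_0), V(s_1), V(s_2)
next_value = 0.9               # bootstrap value V(s_3)
advantages = gae(rewards, values, next_value, gamma=0.99, lam=0.95)
print(advantages)              # one advantage estimate per time step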
Handling Terminal States
If the episode ended before t == T-1, you can set next_value = 0 (or the appropriate bootstrap value) to avoid incorrectly carrying over estimated values beyond a terminal state. Another approach is to keep a separate list of dones indicating when an episode has ended, and reset the GAE accumulator when dones[t] == True.
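A minimal sketch of the dones-based approach, assuming dones[t] is True when the episode terminates at step t; the masking simply zeroes the bootstrap term and the accumulator across episode boundaries:

def gae_with_dones(rewards, values, next_value, dones, gamma, lam):
    T = len(rewards)
    A = [0.0] * T
    running_gae = 0.0
    values = values + [next_value]
    for t in reversed(range(T)):
        # mask out the bootstrap value and the accumulator at episode boundaries
        not_done = 1.0 - float(dones[t])
        # \delta_t = r_t + \gamma V(s_{t+1}) (only if not terminal) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] * not_done - values[t]
        # reset the recursion when the episode ends at step t
        running_gae = delta + gamma * lam * not_done * running_gae
        A[t] = running_gae
    return A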
References: