In this post, we will derive a backprop-like algorithm to compute the gradient of a finite horizon optimal control problem. The technique we use here is well established, known as the method of adjoints. The derivation I am using is based on these excellent notes. $ \newcommand{\abs}[1]{| #1 |} \newcommand{\bigabs}[1]{\left| #1 \right|} \newcommand{\R}{\mathbb{R}} \newcommand{\E}{\mathbb{E}} \renewcommand{\Pr}{\mathbb{P}} \newcommand{\calE}{\mathcal{E}} \newcommand{\calF}{\mathcal{F}} \newcommand{\calD}{\mathcal{D}} \newcommand{\calN}{\mathcal{N}} \newcommand{\calL}{\mathcal{L}} \newcommand{\calM}{\mathcal{M}} \newcommand{\bbP}{\mathbb{P}} \newcommand{\bbQ}{\mathbb{Q}} \newcommand{\ip}[2]{\langle #1, #2 \rangle} \newcommand{\bigip}[2]{\left\langle #1, #2 \right\rangle} \newcommand{\T}{\mathsf{T}} \newcommand{\Tr}{\mathrm{Tr}} \newcommand{\ind}{\mathbf{1}} \newcommand{\norm}[1]{\lVert #1 \rVert}$

Consider the problem $$ \begin{align*} \mathop{\mathrm{minimize}}_{u_0, ..., u_{N-1}} \sum_{t=0}^{N-1} c_t(x_t, u_t) + c_N(x_N) \: : \: x_{t+1} = f_t(x_t, u_t) \:, \end{align*} $$ where $x_0$ is given. Suppose that the $c_t$'s and $f_t$'s are differentiable. Note that we can write the entire problem as an unconstrained minimization problem over some differentiable function $g(u_0, ..., u_{N-1})$. The question is, can we efficiently compute $\nabla_{u_k} g$?
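To make the unconstrained reformulation concrete, here is a minimal Python sketch (the names `g`, `fs`, `cs`, `c_N` are mine, not from the problem statement) that evaluates $g(u_0, ..., u_{N-1})$ by substituting the dynamics into the costs:

```python
def g(us, fs, cs, c_N, x0):
    """Evaluate g(u_0, ..., u_{N-1}) by rolling the dynamics forward.

    us:  list of N control vectors
    fs:  list of N dynamics maps, fs[t](x, u) -> x_{t+1}
    cs:  list of N stage costs, cs[t](x, u) -> scalar
    c_N: terminal cost, c_N(x) -> scalar
    x0:  the given initial state
    """
    x, total = x0, 0.0
    for t, u in enumerate(us):
        total += cs[t](x, u)   # accumulate c_t(x_t, u_t)
        x = fs[t](x, u)        # step the dynamics: x_{t+1} = f_t(x_t, u_t)
    return total + c_N(x)      # add the terminal cost c_N(x_N)
```

Any gradient we compute below should agree with finite differences of this function, which makes for an easy sanity check.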

Let us first see what happens if we try to compute the gradient directly. For concreteness, let us look at $c_3(x_3, u_3)$, $$ c_3(x_3, u_3) = c_3(f_2(f_1(f_0(x_0, u_0), u_1), u_2), u_3) \:. $$ By application of the chain rule, $\nabla_u c_3$ is $$ \nabla_u c_3 = \begin{bmatrix} D_x c_3 D_x f_2 D_x f_1 D_u f_0 \\ D_x c_3 D_x f_2 D_u f_1 \\ D_x c_3 D_u f_2 \\ D_u c_3 \\ 0 \\ \vdots \end{bmatrix} \:, $$ where each block is written as a row vector (strictly speaking, the gradient stacks their transposes). The generalization to $c_t(x_t, u_t)$ is clear, and therefore computing $\nabla_u c_t$ requires $O(t^2)$ operations. Hence to compute $\nabla_u c_t$ for all $t$ we will need $O(\sum_{t=1}^{N} t^2) = O(N^3)$ operations. Note that in the $O(\cdot)$ notation here I am suppressing the dependence on the dimension of $x_t$ and $u_t$, which I am treating as fixed while $N$ grows.
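To see where the $O(t^2)$ count comes from, here is a sketch of the direct chain-rule computation for a single $c_t$. The Jacobian callables `Dx_f`, `Du_f` and the cost gradients `Dx_c`, `Du_c` are assumptions of mine (supplied by the user, returning NumPy arrays), not anything defined in the post:

```python
import numpy as np

def naive_grad_ct(t, xs, us, Dx_f, Du_f, Dx_c, Du_c):
    """Gradient of the single stage cost c_t with respect to (u_0, ..., u_{N-1}),
    formed directly from the chain rule.  xs is the rolled-out trajectory.

    Conventions: Dx_c[t](x, u) and Du_c[t](x, u) return gradient vectors;
    Dx_f[t](x, u) and Du_f[t](x, u) return Jacobian matrices.
    """
    grads = [np.zeros_like(u) for u in us]
    grads[t] = Du_c[t](xs[t], us[t])             # direct dependence of c_t on u_t
    for k in range(t):                           # dependence of c_t on u_k through x_t
        v = Dx_c[t](xs[t], us[t])                # start with D_x c_t (as a row)
        for s in range(t - 1, k, -1):            # multiply by D_x f_{t-1}, ..., D_x f_{k+1}
            v = v @ Dx_f[s](xs[s], us[s])
        grads[k] = Du_f[k](xs[k], us[k]).T @ v   # finish with (D_u f_k)^T
    return grads
```

Every $u_k$ restarts the Jacobian chain from $D_x c_t$, and this repeated work is exactly the redundancy the adjoint method removes.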

Let us derive a more efficient algorithm based on the method of adjoints. Let $\phi_k(u_0, ..., u_{k-1})$ denote the map such that $\phi_k = x_k$. That is, $\phi_0 = x_0$, $\phi_1(u_0) = f_0(x_0, u_0)$, $\phi_2(u_0, u_1) = f_1(f_0(x_0, u_0), u_1)$, and so on.

For what follows, we will use $\phi_t$ as shorthand for $\phi_t(u_0, ..., u_{t-1})$, $c_t$ as shorthand for $c_t(\phi_t, u_t)$, and $f_t$ as shorthand for $f_t(\phi_t, u_t)$. With this notation, we write $$ \begin{align*} g(u_0, ..., u_{N-1}) = \sum_{t=0}^{N-1} c_t(\phi_t, u_t) + c_N(\phi_N) \:. \end{align*} $$ Let $\lambda_k$ be specified as $$ \begin{align*} \lambda_{N-1}^\T &= - D_x c_N \:, \\ \lambda_{k}^\T &= \lambda_{k+1}^\T D_x f_{k+1} - D_x c_{k+1} \:, \:\: 0 \leq k \leq N-2 \:. \end{align*} $$ We form the Lagrangian $$ \begin{align*} \calL(u_0, ..., u_{N-1}) = \sum_{t=0}^{N-1} c_t(\phi_t, u_t) + c_N(\phi_N) + \sum_{t=0}^{N-1} \lambda_t^\T(\phi_{t+1} - f_t(\phi_t, u_t)) \:. \end{align*} $$ By construction, we have that $g = \calL$, since $\phi_{k+1} = f_k$. We now compute $D_{u_k} \calL$, starting with the base case $D_{u_{N-1}} \calL$. Using the fact that $(D_{u_k} \lambda_t^\T) (\phi_{t+1} - f_t) = 0$ for all $t$, since the constraints hold along the trajectory, $$ \begin{align*} D_{u_{N-1}} \calL &= D_u c_{N-1} + D_x c_N D_{u_{N-1}} \phi_N + \lambda_{N-1}^\T (D_{u_{N-1}} \phi_{N} - D_u f_{N-1}) \\ &= D_u c_{N-1} + (D_x c_N + \lambda_{N-1}^\T) D_{u_{N-1}} \phi_N - \lambda_{N-1}^\T D_u f_{N-1} \:. \end{align*} $$ Now using the setting $\lambda_{N-1}^\T = - D_x c_N$, we obtain $$ \begin{align*} D_{u_{N-1}} \calL = D_u c_{N-1} - \lambda_{N-1}^\T D_u f_{N-1} \:. \end{align*} $$ We now proceed for $0 \leq k < N-1$ as follows, $$ \begin{align*} D_{u_k} \calL &= D_u c_k + \sum_{t=k+1}^{N} D_x c_t D_{u_k} \phi_t + \lambda_k^\T( D_{u_k} \phi_{k+1} - D_u f_k) + \sum_{t=k+1}^{N-1} \lambda_t^\T( D_{u_k} \phi_{t+1} - D_x f_t D_{u_k} \phi_t) \\ &= D_u c_k - \lambda_k^\T D_u f_k + \sum_{t=k}^{N-2} ( D_x c_{t+1} + \lambda_t^\T - \lambda_{t+1}^\T D_x f_{t+1} ) D_{u_k} \phi_{t+1} + (D_x c_N + \lambda_{N-1}^\T) D_{u_k} \phi_N \:. \end{align*} $$ Recalling the setting of the $\lambda_k$'s, every coefficient in the sum vanishes, and we have $$ \begin{align*} D_{u_k} \calL &= D_u c_k - \lambda_k^\T D_u f_k \:. \end{align*} $$ Hence, using $g = \calL$, for all $0 \leq k \leq N-1$, $$ \begin{align*} \nabla_{u_k} g &= \nabla_{u_k} c_k - (D_u f_k)^\T \lambda_k \:. \end{align*} $$

These equations give us an efficient algorithm to compute $\nabla_u g$. First, we do a forward pass, where given inputs $u_0, ..., u_{N-1}$, we compute the associated trajectory $x_0, x_1, ..., x_N$. Next, we do a backward pass, where we recursively compute the values of the Lagrange multipliers $\lambda_{N-1}, ..., \lambda_0$. Once we have these values in hand, we can read off the gradient. Notice that the runtime of this algorithm is now $O(N)$ (compared to $O(N^3)$ before), but we require an extra $O(N)$ space. This is of course not a big deal, since it takes $O(N)$ space to write down the gradient in the first place.
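In code, the two passes might look like the following sketch. As with the earlier snippets, the Jacobians of the $f_t$ are assumed to be supplied as matrices and the cost gradients as vectors, and all names are my own:

```python
def adjoint_gradient(us, fs, x0, Dx_f, Du_f, Dx_c, Du_c, Dx_cN):
    """Gradient of g with respect to each u_k via the method of adjoints.

    fs[t](x, u) -> x_{t+1};  Dx_f[t], Du_f[t] return Jacobian matrices of f_t;
    Dx_c[t], Du_c[t], Dx_cN return gradient vectors of the costs.
    All arrays are assumed to be NumPy arrays.
    """
    N = len(us)

    # Forward pass: roll out the trajectory x_0, ..., x_N.
    xs = [x0]
    for t in range(N):
        xs.append(fs[t](xs[t], us[t]))

    # Backward pass: lambda_{N-1} = -grad_x c_N, then the adjoint recursion
    # lambda_k = (D_x f_{k+1})^T lambda_{k+1} - grad_x c_{k+1}.
    lams = [None] * N
    lams[N - 1] = -Dx_cN(xs[N])
    for k in range(N - 2, -1, -1):
        lams[k] = (Dx_f[k + 1](xs[k + 1], us[k + 1]).T @ lams[k + 1]
                   - Dx_c[k + 1](xs[k + 1], us[k + 1]))

    # Read off the gradient: grad_{u_k} g = grad_u c_k - (D_u f_k)^T lambda_k.
    return [Du_c[k](xs[k], us[k]) - Du_f[k](xs[k], us[k]).T @ lams[k]
            for k in range(N)]
```

Each Jacobian is evaluated once and each multiplier costs a single matrix-vector product, which is where the $O(N)$ runtime comes from; storing the trajectory $x_0, ..., x_N$ is the extra $O(N)$ space mentioned above.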

As an example, let us specialize to the case of LQR, where $c_t(x_t, u_t) = \frac{1}{2} x_t^\T Q_t x_t + \frac{1}{2} u_t^\T R_t u_t$, $c_N(x_N) = \frac{1}{2} x_N^\T Q_N x_N$, and $f_t(x_t, u_t) = A_t x_t + B_t u_t$. The forward pass is simply to set $x_{t+1} = A_t x_t + B_t u_t$ for $t = 0, ..., N-1$. For the backward pass, we set $\lambda_{N-1} = - Q_N x_N$, and then $\lambda_t = A_{t+1}^\T \lambda_{t+1} - Q_{t+1} x_{t+1}$ for $t = N-2, ..., 0$. The gradient $\nabla_u g$ is then $$ \begin{align*} \nabla_u g(u_0, ..., u_{N-1}) = \begin{bmatrix} R_0 u_0 - B_0^\T \lambda_0 \\ R_1 u_1 - B_1^\T \lambda_1 \\ \vdots \\ R_{N-1} u_{N-1} - B_{N-1}^\T \lambda_{N-1} \end{bmatrix} \:. \end{align*} $$
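For completeness, here is the LQR specialization as a sketch (again, the function and argument names are mine, and the matrices are assumed to be NumPy arrays):

```python
def lqr_gradient(us, As, Bs, Qs, Rs, x0):
    """Adjoint-method gradient for the LQR objective.

    us, As, Bs, Rs have length N; Qs has length N + 1 (Q_0, ..., Q_N).
    """
    N = len(us)

    # Forward pass: x_{t+1} = A_t x_t + B_t u_t.
    xs = [x0]
    for t in range(N):
        xs.append(As[t] @ xs[t] + Bs[t] @ us[t])

    # Backward pass: lambda_{N-1} = -Q_N x_N,
    # lambda_t = A_{t+1}^T lambda_{t+1} - Q_{t+1} x_{t+1}.
    lams = [None] * N
    lams[N - 1] = -Qs[N] @ xs[N]
    for t in range(N - 2, -1, -1):
        lams[t] = As[t + 1].T @ lams[t + 1] - Qs[t + 1] @ xs[t + 1]

    # grad_{u_t} g = R_t u_t - B_t^T lambda_t.
    return [Rs[t] @ us[t] - Bs[t].T @ lams[t] for t in range(N)]
```

A quick sanity check is to compare the output against finite differences of the quadratic objective on a small random instance; the two should agree up to discretization error in the finite differences.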