A Trick for Backpropagation of Linear Transformations
Linear transformations such as sums, matrix products, dot products, and Hadamard products can often be represented as an einsum (short for Einstein summation).
This post explains a simple trick to backpropagate through any einsum, regardless of what operations it represents.
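For instance, here is a rough sketch of how a few of those operations look as einsums (the arrays x and M below are just illustrative placeholders, not part of the example that follows):

import numpy as np

x = np.arange(3.0)                      # vector [0., 1., 2.]
M = np.arange(6.0).reshape(2, 3)        # 2x3 matrix

total  = np.einsum('ij->', M)           # sum of all elements of M
dot    = np.einsum('i,i->', x, x)       # dot product of x with itself
hada   = np.einsum('ij,ij->ij', M, M)   # Hadamard (element-wise) product
matvec = np.einsum('ij,j->i', M, x)     # matrix-vector product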
Example Einsum
For example, an einsum for matrix multiplication can be written like so:
import numpy as np

A = np.arange(2 * 3).reshape(2, 3)
# A = [
#     [0, 1, 2],
#     [3, 4, 5]
# ]

B = np.arange(3 * 4).reshape(3, 4)
# B = [
#     [0, 1, 2, 3],
#     [4, 5, 6, 7],
#     [8, 9, 10, 11]
# ]

# C = A @ B (matrix multiplication)

# calculate with einsum
# A uses i and j, B uses j and k, and C uses i and k
C_einsum = np.einsum('ij,jk->ik', A, B)
# C_einsum = [
#     [20, 23, 26, 29],
#     [56, 68, 80, 92]
# ]

# calculate with for loops
C_forloop = np.zeros((2, 4))
# order of the loops doesn't matter
# you can swap the order of the loops,
# just like you can swap the order of some double integrals
for i in range(2):
    for j in range(3):
        for k in range(4):
            # A uses i and j, B uses j and k, and C uses i and k,
            # just like before with the einsum
            C_forloop[i, k] += A[i, j] * B[j, k]
# C_forloop = [
#     [20, 23, 26, 29],
#     [56, 68, 80, 92]
# ]
Backpropagating Through Einsum
Here comes the fun part. Backpropagating through an einsum is easy with a simple trick.
Let $L$ be the loss function, and assume we know $\frac{\partial L}{\partial C}$ and need to compute $\frac{\partial L}{\partial A}$.
We can just swap the roles of $C$ and $A$: $\frac{\partial L}{\partial C}$ takes $A$'s place as an input, the output becomes $\frac{\partial L}{\partial A}$, and every tensor keeps the same letters it had in the forward pass. The code should explain it best.
# forward pass
# A uses i and j, B uses j and k, and C uses i and k
C = np.einsum('ij,jk->ik', A, B)

# backward pass
# dL_dC is computed somewhere in the backward pass
# A uses i and j, B uses j and k, and C uses i and k,
# again, just like before with the original einsum
# we swapped the letters for C with the letters for A
# we used dL_dC instead of A for the first parameter
# we computed dL_dA instead of C for the output
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
Why does this work? Each element of $C$ is a sum of products, $C_{ik} = \sum_j A_{ij} B_{jk}$, so by the chain rule $\frac{\partial L}{\partial A_{ij}} = \sum_k \frac{\partial L}{\partial C_{ik}} B_{jk}$, which is exactly the swapped einsum above. This simple swapping trick makes it very easy to create formulas when backpropagating through einsums.
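The same trick gives the gradient with respect to $B$: swap the letters for $C$ with the letters for $B$, and pass $\frac{\partial L}{\partial C}$ in $B$'s place. A minimal sketch, reusing A from the example above and the same dL_dC from earlier in the backward pass (the name dL_dB is mine):

# forward pass was: C = np.einsum('ij,jk->ik', A, B)
# backward pass for B:
# swap the letters for C ('ik') with the letters for B ('jk'),
# and use dL_dC in place of B
dL_dB = np.einsum('ij,ik->jk', A, dL_dC)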
Verifying the Shape
The result $\frac{\partial L}{\partial A}$ has the same shape as $A$ because the output letters of the backward einsum (i and j) are exactly the letters $A$ used as an input (i and j) during the forward pass.
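A quick shape check, using a placeholder upstream gradient of ones (this dL_dC is only for illustration; in a real backward pass it would come from the rest of the network):

# placeholder upstream gradient with the same shape as C
dL_dC = np.ones((2, 4))

# backward einsum from above: the output letters 'ij' are A's forward-pass letters
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)

assert dL_dA.shape == A.shape == (2, 3)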
Interpreting the Einsum
If you relabel the letters, the einsum used for backpropagation becomes easier to interpret as a kind of matrix multiplication.
# replacing j with k, and k with j
# doesn't change the result
dL_dA = np.einsum('ij,kj->ik', dL_dC, B)
This corresponds to multiplying $\frac{\partial L}{\partial C}$ by $B^\intercal$. We can tell because it looks like the example einsum for matrix multiplication, except the letters for $B$ are swapped, meaning $B$ is transposed.
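We can check this equivalence numerically, using the placeholder dL_dC from the shape check above (a small sketch, not part of the original derivation):

# the relabeled einsum and an explicit matrix product with B transposed
dL_dA_einsum = np.einsum('ij,kj->ik', dL_dC, B)
dL_dA_matmul = dL_dC @ B.T
assert np.allclose(dL_dA_einsum, dL_dA_matmul)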
Verifying the Values
We can use JAX's automatic differentiation to verify the values.
import jax
import jax.numpy as jnp

# compute the loss
def loss(A, B):
    C = jnp.einsum('ij,jk->ik', A, B)
    return jnp.sum(C)

# compute the gradient of the loss with respect to A
def grad_A(A, B):
    # argnums=0 corresponds to A
    return jax.grad(loss, argnums=0)(A, B)

A = jnp.arange(2 * 3).reshape(2, 3).astype(jnp.float32)
B = jnp.arange(3 * 4).reshape(3, 4).astype(jnp.float32)

# autograd output
print(grad_A(A, B))

C = jnp.einsum('ij,jk->ik', A, B)

# gradient of the loss with respect to C
dL_dC = jnp.ones_like(C)

# gradient of the loss with respect to A
# this is the backpropagation formula we derived
dL_dA = jnp.einsum('ik,jk->ij', dL_dC, B)

# manually computed output
print(dL_dA)

assert (dL_dA == grad_A(A, B)).all()
print("Success!")
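The same check works for the gradient with respect to B, continuing the script above (argnums=1 selects B; the swapped-letters einsum here is my own application of the trick, not part of the original verification):

# autograd gradient of the loss with respect to B
dL_dB_auto = jax.grad(loss, argnums=1)(A, B)

# swapping trick: exchange the letters for C ('ik') with the letters for B ('jk')
dL_dB = jnp.einsum('ij,ik->jk', A, dL_dC)

assert jnp.allclose(dL_dB, dL_dB_auto)
print("Success for B too!")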
Conclusion
Einsums are a powerful tool for representing and reasoning about linear transformations, and this simple swapping trick makes it very easy to backpropagate through them. I hope you find this informative!