A Trick for Backpropagation of Linear Transformations

Many linear transformations, such as sums, matrix products, dot products, and Hadamard products, can be represented using an einsum (short for Einstein summation).
This post explains a simple trick for backpropagating through any einsum, regardless of which operation it represents.
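To make this concrete, here is a quick sketch (with some made-up example arrays) of a few of these operations written as einsums:
import numpy as np

x = np.arange(3)                  # [0, 1, 2]
y = np.arange(3, 6)               # [3, 4, 5]
M = np.arange(6).reshape(2, 3)
N = np.arange(6, 12).reshape(2, 3)

# sum of all elements, same as M.sum()
total = np.einsum('ij->', M)
# dot product, same as x @ y
dot = np.einsum('i,i->', x, y)
# Hadamard (element-wise) product, same as M * N
hadamard = np.einsum('ij,ij->ij', M, N)
# transpose, same as M.T
transposed = np.einsum('ij->ji', M)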

Example Einsum

For example, matrix multiplication can be written as an einsum like so:
import numpy as np

A = np.arange(2 * 3).reshape(2, 3)
# A = [
# [0, 1, 2],
# [3, 4, 5]
# ]
B = np.arange(3 * 4).reshape(3, 4)
# B = [
# [0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10, 11]
# ]

# C = A @ B (matrix multiplication)
# calculate with einsum
# A uses i and j, B uses j and k, and C uses i and k
C_einsum = np.einsum('ij,jk->ik', A, B)
# C_einsum = [
# [20, 23, 26, 29],
# [56, 68, 80, 92]
# ]

# calculate with for loops
C_forloop = np.zeros((2, 4))
# the order of the loops doesn't matter;
# you can swap them, just like you can swap
# the order of some double integrals
for i in range(2):
    for j in range(3):
        for k in range(4):
            # A uses i and j, B uses j and k, and C uses i and k,
            # just like before with the einsum
            C_forloop[i, k] += A[i, j] * B[j, k]
# C_forloop = [
# [20, 23, 26, 29],
# [56, 68, 80, 92]
# ]

Backpropagating Through an Einsum

Here comes the fun part. Backpropagating through an einsum is easy with a simple trick.
Let $L$ be the loss function, and assume we know $\frac{\partial L}{\partial C}$ and need to compute $\frac{\partial L}{\partial A}$.
The trick: swap what we did for $C$ with what we did for $A$, keeping each tensor's letters the same. In the backward einsum, $\frac{\partial L}{\partial C}$ takes $A$'s place as an input, and the output uses $A$'s letters. The code should make this clear.
# forward pass

# A uses i and j, B uses j and k, and C uses i and k
C = np.einsum('ij,jk->ik', A, B)

# backward pass

# dL_dC is the upstream gradient, computed earlier in the backward pass

# A uses i and j, B uses j and k, and C uses i and k
# again, just like before with the original einsum
# we swapped the letters for C with the letters for A
# we used dL_dC instead of A for the first parameter
# we computed dL_dA instead of C for the output
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
This swapping trick makes it easy to derive the backward-pass einsum for any forward-pass einsum.
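The same trick works for $\frac{\partial L}{\partial B}$: swap what we did for $C$ with what we did for $B$ instead. A minimal sketch, continuing with the tensors above:
# B uses j and k, and C uses i and k
# we swapped the letters for C with the letters for B
# we used dL_dC instead of B for the second parameter
# we computed dL_dB instead of C for the output
dL_dB = np.einsum('ij,ik->jk', A, dL_dC)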

Verifying the Shape

The result $\frac{\partial L}{\partial A}$ has the same shape as $A$ because the output letters in the backward pass (i and j) are exactly the letters $A$ used in the forward pass.
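As a quick sketch of checking this in code (pretending the upstream gradient $\frac{\partial L}{\partial C}$ is all ones):
dL_dC = np.ones((2, 4))  # same shape as C
dL_dA = np.einsum('ik,jk->ij', dL_dC, B)
# the output letters i and j give dL_dA the shape of A
assert dL_dA.shape == A.shape  # both are (2, 3)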

Interpreting the Einsum

If you swap the letters j and k, the einsum used for backpropagation becomes easier to interpret as a matrix multiplication.
# replacing j with k, and k with j
# doesn't change the result
dL_dA = np.einsum('ij,kj->ik', dL_dC, B)
This corresponds to multiplying $\frac{\partial L}{\partial C}$ by $B^\intercal$, i.e. $\frac{\partial L}{\partial A} = \frac{\partial L}{\partial C} B^\intercal$. We can tell because it looks like the example einsum for matrix multiplication, except the letters for $B$ are swapped, meaning $B$ is transposed.
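We can confirm this interpretation numerically, reusing the all-ones dL_dC from the shape check above:
# the swapped-letter einsum matches an explicit product with B transposed
assert np.allclose(dL_dA, dL_dC @ B.T)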

Verifying the Values

We can use JAX's automatic differentiation to verify the values.
import jax
import jax.numpy as jnp

# compute the loss
def loss(A, B):
    C = jnp.einsum('ij,jk->ik', A, B)
    return jnp.sum(C)

# compute the gradient of the loss with respect to A
def grad_A(A, B):
    # argnums=0 corresponds to A
    return jax.grad(loss, argnums=0)(A, B)

A = jnp.arange(2 * 3).reshape(2, 3).astype(jnp.float32)
B = jnp.arange(3 * 4).reshape(3, 4).astype(jnp.float32)

# autograd output
print(grad_A(A, B))

C = jnp.einsum('ij,jk->ik', A, B)
# gradient of the loss with respect to C
dL_dC = jnp.ones_like(C)
# gradient of the loss with respect to A
# this is the backpropagation formula we derived
dL_dA = jnp.einsum('ik,jk->ij', dL_dC, B)

# manually computed output
print(dL_dA)

# use allclose rather than exact float equality
assert jnp.allclose(dL_dA, grad_A(A, B))

print("Success!")

Conclusion

Einsums are a powerful tool for representing and reasoning about linear transformations, and this simple swapping trick makes it very easy to backpropagate through them. I hope you find this informative!