Einsum is better notation, too

Einsum is All you Need was a deeply influential blog post for me. But it recently occurred to me that einsum isn’t just good for efficiently expressing transformations: the index notation is also a much clearer way to write expressions.

I find math notation quite difficult to unpack. I constantly have to remind myself whether I am looking at a scalar, vector, or tensor. What are the dimensions? Where do I have to transpose to get everything lined up?

The first example in “Einsum is All you Need” does a standard matrix multiplication, then sums each column. Here it is:

$$ {\color{green}c_j} = \sum_i\sum_k {\color{red}A_{ik}}{\color{blue}B_{kj}} = {\color{red}A_{ik}}{\color{blue}B_{kj}} $$

Or as code:

np.einsum("ik,kj->j", A, B)
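A minimal runnable sketch of that line, with small made-up matrices, checking that the einsum really matches "multiply, then sum each column":

```python
import numpy as np

# Hypothetical 2x2 inputs, just for illustration.
A = np.array([[1.0, 2.0], [3.0, 4.0]])
B = np.array([[5.0, 6.0], [7.0, 8.0]])

# c_j = sum_i sum_k A_ik B_kj: contract i and k, keep only j.
c = np.einsum("ik,kj->j", A, B)

# Same thing written the long way: matrix product, then column sums.
assert np.allclose(c, (A @ B).sum(axis=0))
```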

The neat feature of this notation is that every value in the equation is a scalar. For the classic example of a mean-centered matrix X where each row is a sample, we can write $$var(X) = X^TX$$ (dropping the $1/(n-1)$ normalization).

With Einstein summation I don’t have to transpose anything, or even say what kind of objects var(X) and X are, because every indexed term must be a scalar:

$$ var_{jk} = {\color{red}X_{ij}}{\color{blue}X_{ik}} $$

Code:

np.einsum('ij,ik->jk', x, x)
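A quick sketch checking that the index form agrees with the matrix form, assuming a small random data matrix with samples in rows:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3))
x = x - x.mean(axis=0)            # mean-center each column

# var_{jk} = X_{ij} X_{ik}: contract over samples i.
v = np.einsum('ij,ik->jk', x, x)

# Identical to the transposed matrix product X^T X.
assert np.allclose(v, x.T @ x)
```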

This also comes up in cross-attention. Given the usual

$$ {\color{green}k_{ik}}, {\color{red}q_{jk}}, {\color{blue}v_{lk}} $$

we pick

$$ w_{ij} = {\color{green}k_{ik}} {\color{red}q_{jk}} $$

And now I can pick values aligned with either input just by switching the indices: $$o_{ik} = w_{ij}{\color{blue}v_{jk}}$$ or $$o_{jk} = w_{ij}{\color{blue}v_{ik}}$$

Code:

w_ij = np.einsum("ik,jk->ij", k, q)
o_ik = np.einsum("ij,jk->ik", w_ij, v)
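The whole attention sketch, runnable end to end. The sequence lengths and feature size are made up for illustration, and real attention would also scale and softmax the weights, which the post leaves out:

```python
import numpy as np

rng = np.random.default_rng(0)
n_i, n_j, d = 4, 5, 8                 # assumed: n_i keys, n_j queries
k = rng.normal(size=(n_i, d))
q = rng.normal(size=(n_j, d))

# w_{ij} = k_{ik} q_{jk}: similarity of key i and query j.
w = np.einsum("ik,jk->ij", k, q)

# Values aligned with the queries: o_{ik} = w_{ij} v_{jk}.
v_q = rng.normal(size=(n_j, d))
o1 = np.einsum("ij,jk->ik", w, v_q)   # one output row per key

# Values aligned with the keys: o_{jk} = w_{ij} v_{ik}.
v_k = rng.normal(size=(n_i, d))
o2 = np.einsum("ij,ik->jk", w, v_k)   # one output row per query
```

Switching which index of w the values contract against is all it takes to attend in either direction.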