Elegant Tensors: Attention Forward & Backward Passes using PyTorch Einsum/Einops

Author: Pankaj Pansari

Published: April 29, 2025

I was looking at the FlashAttention-2 paper recently. It's about optimizing the forward and backward passes through the attention layer, which are the bottleneck when scaling transformers to longer sequence lengths. As a start, I implemented the forward and backward passes of the standard (unoptimized) attention module in PyTorch, but with einsum and einops instead of the view, reshape, transpose, permute, and matmul calls I habitually use.

Einsum and Einops

I found einsum and einops to be very useful and elegant for working with tensors - a core requirement of deep learning. One ends up writing more readable and less error-prone PyTorch code with einsum (based on Einstein summation) and einops (general tensor manipulation using Einstein notation). Moreover, einsum + einops can potentially lead to faster and more memory-efficient implementations, for example by fusing operations; however, I still need to investigate the efficiency aspect more.

There are excellent tutorials and introductions that will get you started with einsum/einops in no time - 1, 2, 3

As a motivating example, consider a vector A of shape (3,) and a matrix B of shape (3, 4). We want to multiply each element of vector A with the corresponding row of matrix B (element-wise) and then sum the results along each row to get a final vector of shape (3,). In PyTorch, we can do

(A[:, None] * B).sum(dim=1)

where \(*\) denotes element-wise product. A bit tedious, isn’t it? With einsum, we simply do:

torch.einsum('i,ij -> i', A, B)

Einsum is a little limited in its functionality. That is why we also use einops, which can be thought of as an extension of einsum. Whereas einsum mostly permits reduce-sum types of operations, einops is very convenient for adding dimensions (repeat), performing general reductions (whether max, sum, or mean), and providing a new view of a tensor (rearrange). Taken together, one can do almost all tensor computations using einsum/einops.
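As a small sketch of these three primitives (the shapes and axis names below are arbitrary choices of mine, not tied to any particular application):

import torch
from einops import rearrange, reduce, repeat

x = torch.randn(2, 3, 4)                      # (batch, rows, cols)

flat = rearrange(x, 'b r c -> b (r c)')       # new view: flatten the last two dims -> (2, 12)
maxes = reduce(x, 'b r c -> b r', 'max')      # general reduce: max over the last dim -> (2, 3)
tiled = repeat(x, 'b r c -> b r c k', k=5)    # add a new dimension of size 5 -> (2, 3, 4, 5)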

Implementing Attention Forward Pass using Einsum/Einops

Consider the implementation of standard single-head attention from here:

import math

import torch


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

With einsum, the same implementation is more concise, and it took me less time and effort to write. Note that our implementation below works for multi-head attention, but for comparison with the reference code above we use a single head (\(h = 1\)). Moreover, for this comparison we don't make use of the mask or dropout options of the reference implementation.

def forward(self, q, k, v):
    # I'm using notation from the FlashAttention-2 paper.
    # q, k, v have shape (batch, seq_len, heads, head_dim);
    # S and P have shape (batch, heads, query_pos, key_pos)
    S = torch.einsum('bihd,bjhd -> bhij', q, k) / self.scale
    self.P = F.softmax(S, dim=-1)
    # O has shape (batch, seq_len, heads, head_dim), same as q
    O = torch.einsum('bhin,bnhj -> bihj', self.P, v)
    return O

To test, we simply pass random q, k, v of appropriate shapes to both implementations and compare the output.
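A minimal sketch of such a check (the shapes, tolerance, and the assumption scale = sqrt(d) are my own choices; the notebook's actual test may differ): with a single head, we can squeeze out the head dimension and compare against the reference attention() above.

import math

import torch
import torch.nn.functional as F

b, n, h, d = 2, 16, 1, 32                               # single head for comparison
q = torch.randn(b, n, h, d)
k = torch.randn(b, n, h, d)
v = torch.randn(b, n, h, d)

# einsum forward pass, same einsum strings as above, with scale = sqrt(d)
S = torch.einsum('bihd,bjhd -> bhij', q, k) / math.sqrt(d)
P = F.softmax(S, dim=-1)
O = torch.einsum('bhin,bnhj -> bihj', P, v)             # (b, n, h, d)

# reference forward pass on the squeezed single-head tensors
ref, _ = attention(q.squeeze(2), k.squeeze(2), v.squeeze(2))

assert torch.allclose(O.squeeze(2), ref, atol=1e-5)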

Backprop through Softmax

To implement our own backward pass of the attention module (so that we can later optimize it), one can use the equations in Section 2.2 of the FlashAttention-2 paper. It's straightforward matrix calculus. One possible source of confusion is that the symbols are shorthand: \(\mathbf{dO}\) stands for \(\frac{\partial L}{\partial \mathbf{O}}\), the partial derivative of the loss function \(L\) with respect to \(\mathbf{O}\), and it has the same shape as \(\mathbf{O}\). The same holds for \(\mathbf{dV}, \mathbf{dP}, \mathbf{dS}, \mathbf{dQ}, \mathbf{dK}\).

The derivation of the softmax Jacobian may need some explanation (refer to Section 2.2 of the paper). Given \({\mathbf S} \in {\mathbb R}^{N \times N}\), the attention matrix is computed as:

\({\mathbf P} = \operatorname{softmax}({\mathbf S}) \in {\mathbb R}^{N \times N}\)

where the softmax is applied row-wise to \(\mathbf S\). Given \({\mathbf dP} \in {\mathbb R}^{N \times N}\), we want to derive \({\mathbf dS} \in {\mathbb R}^{N \times N}\). First, let’s consider one row of \(\mathbf S\) and \(\mathbf P\) and denote them by \(\mathbf s\) and \(\mathbf p\) respectively. The corresponding rows of \(\mathbf dS\) and \(\mathbf dP\) are denoted by \(\mathbf ds\) and \(\mathbf dp\). By chain rule, we have

\({\mathbf ds} = J^T \cdot {\mathbf dp}\)

where \(J = \frac{\partial \mathbf p}{\partial \mathbf s}\) is the Jacobian matrix. Let \({\mathbf p} = [p_1, p_2, \dots , p_N]\), where \(p_i\) is computed as \(\frac{\exp(s_i)}{\sum_k \exp(s_k)}\). The diagonal entries of the Jacobian matrix are given by

\(J_{ii} = \frac{\exp(s_i)}{\sum_k \exp(s_k)} - \frac{\exp(2s_i)}{(\sum_k \exp(s_k))^2} = p_i - p_i^2\)

and for the off-diagonal elements of \(J\) (\(i \neq j\)), we have

\(J_{ij} = - \frac{\exp(s_i + s_j)}{(\sum_k \exp(s_k))^2} = - p_i \cdot p_j\)

The Jacobian can be written more concisely as

\(J = \operatorname{diag}({\mathbf p}) - {\mathbf p}{\mathbf p}^T\)

where \(\operatorname{diag}({\mathbf p})\) is the matrix formed by placing the elements of \({\mathbf p}\) on the diagonal; the second term is the outer product of \(\mathbf p\) with itself. In our case, \(J\) is symmetric, so \(J^T = J\) and hence

\({\mathbf ds} = (\operatorname{diag}({\mathbf p}) - {\mathbf p}{\mathbf p}^T) \cdot {\mathbf dp}\)
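As a quick numerical sanity check of this Jacobian formula (a sketch of mine, not part of the notebook), we can compare the closed form against the Jacobian computed by autograd for a single row:

import torch

s = torch.randn(5)
p = torch.softmax(s, dim=0)

# closed form: J = diag(p) - p p^T
J_closed = torch.diag(p) - torch.outer(p, p)

# autograd's Jacobian of the row-wise softmax
J_autograd = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), s)

assert torch.allclose(J_closed, J_autograd, atol=1e-6)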

To generalize from single rows of \({\mathbf P}, {\mathbf S}\) to the full matrices, we'll manipulate this expression for \({\mathbf ds}\) a bit:

\({\mathbf ds} = \operatorname{diag}({\mathbf p}) \cdot {\mathbf dp} - {\mathbf p}{\mathbf p}^T \cdot {\mathbf dp}\)

The above can be written in terms of the element-wise product (\(*\)):

\({\mathbf ds} = {\mathbf p} * {\mathbf dp} - {\mathbf p} * ({\mathbf p}^T \cdot {\mathbf dp})\)

The dot product \(({\mathbf p}^T \cdot {\mathbf dp})\) is a scalar. Assuming broadcasting, we can write

\({\mathbf ds} = {\mathbf p} * ({\mathbf dp} - ({\mathbf p}^T \cdot {\mathbf dp}))\)

The above formulation was for single rows of \(\mathbf S, P, dS, dP\), but viewed as column vectors. Generalizing to full matrices, we derive

\({\mathbf dS} = {\mathbf P} * ({\mathbf dP} - \operatorname{dot\_prod})\)

where \(\operatorname{dot\_prod}[i] = \sum_k P[i, k] * dP[i, k]\) is the row-wise dot product (over the last dimension) between \(\mathbf P\) and \(\mathbf dP\). The above works because \(\operatorname{dot\_prod}\) gets broadcast to the shape of \(\mathbf P\); for that, we need to insert an extra singleton dimension at the end of \(\operatorname{dot\_prod}\), otherwise PyTorch will complain about a dimension mismatch. One can check the correctness of our backward() implementation by comparing its output with that given by autograd. The full implementation is available in this notebook.
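To make this concrete, here is a minimal sketch of such a backward pass in einsum notation, assuming the same tensor layout as the forward() above and an explicit scale argument matching self.scale in the forward (the function name and signature are my own; the notebook's implementation may differ):

import torch

def attention_backward(q, k, v, P, dO, scale):
    # q, k, v, dO: (batch, seq_len, heads, head_dim); P: (batch, heads, query_pos, key_pos)
    dV = torch.einsum('bhij,bihd -> bjhd', P, dO)           # dV = P^T · dO
    dP = torch.einsum('bihd,bjhd -> bhij', dO, v)           # dP = dO · V^T
    dot_prod = torch.einsum('bhij,bhij -> bhi', P, dP)      # row-wise dot product of P and dP
    dS = P * (dP - dot_prod.unsqueeze(-1))                  # backprop through softmax
    dQ = torch.einsum('bhij,bjhd -> bihd', dS, k) / scale   # dQ = dS · K / scale
    dK = torch.einsum('bhij,bihd -> bjhd', dS, q) / scale   # dK = dS^T · Q / scale
    return dQ, dK, dV

The returned gradients can then be compared against q.grad, k.grad, and v.grad obtained via autograd.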

In summary, we implemented the forward and backward passes of the standard attention module as an exercise in using einsum and einops, and also because it gives us a starting point for hardware optimizations to speed up attention for long context lengths. Along the way, we saw that the tensor formulation of backprop through softmax needed some matrix manipulation.

Note

Backpropagation is amazing since it lets us reuse computation, so that the cost of computing gradients is comparable to that of the forward pass (within a constant factor). Autograd is great since it speeds up implementation by freeing us from the need to write gradient code by hand. However, notice how autograd required saving intermediate computational state (in our example, we saved the \(\bf P\) matrix; see notebook). That's a load on limited GPU memory.

A fun exercise: why do you think topological sort is needed during backprop?


Thanks to Eric for introducing me to new concepts.