I was looking at the FlashAttention-2 paper recently. It's about optimizing the forward and backward pass through the attention layer, which is the bottleneck when scaling transformers to longer sequence lengths. As a start, I implemented the forward and backward passes of the standard (unoptimized) attention module in PyTorch, but with einsum and einops instead of the view, reshape, transpose, permute, and matmul calls I habitually use.
Einsum and Einops
I found einsum and einops to be very useful and elegant for working with tensors, a core requirement of deep learning. One ends up writing more readable and less error-prone PyTorch code with einsum (based on Einstein summation) and einops (general tensor manipulation using Einstein notation). Moreover, einsum + einops can lead to faster and more memory-efficient implementations, especially by potentially fusing operations; however, I still need to investigate the efficiency claims.
There are excellent tutorials and introductions that will get you started with einsum/einops in no time - 1, 2, 3.
As a motivating example, consider a vector A of shape (3) and a matrix B of shape (3, 4). We want to multiply each element of vector A with the corresponding row of matrix B (element-wise), and then sum the results along each row to get a final vector. In PyTorch, we can do
(A[:, None] * B).sum(axis=1)
where \(*\) denotes the element-wise product. A bit tedious, isn't it? With einsum, we simply do:
torch.einsum('i,ij -> i', A, B)
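As a quick sanity check that the two expressions agree (a tiny standalone sketch with random tensors):
import torch

A = torch.randn(3)
B = torch.randn(3, 4)

manual = (A[:, None] * B).sum(axis=1)          # broadcast A along the columns of B, then reduce each row
via_einsum = torch.einsum('i,ij -> i', A, B)   # the same computation in a single call
assert torch.allclose(manual, via_einsum)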
Einsum is a little limited in its functionality. That is why we also use einops, which can be thought of as an extension of einsum. Whereas einsum mostly permits reduce-sum type operations, einops is very convenient for adding dimensions (repeat), performing general reductions such as max, sum, or mean (reduce), and providing a new view of the tensor (rearrange). Taken together, one can do almost all tensor computations using einsum/einops.
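To make that concrete, here is a small sketch of the three einops primitives (the shapes are arbitrary and chosen only for illustration):
import torch
from einops import rearrange, reduce, repeat

x = torch.randn(2, 3, 4, 5)                         # (batch, channels, height, width)

y = rearrange(x, 'b c h w -> b h w c')              # new view of the tensor (channels-last)
z = reduce(x, 'b c h w -> b c', 'mean')             # general reduction (mean over h and w)
t = repeat(torch.randn(4, 5), 'h w -> h w c', c=3)  # add a new trailing dimension and tile along it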
Implementing Attention Forward Pass using Einsum/Einops
Consider the implementation of standard single-head attention from here:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn
With einsum, the same implementation is more concise, and it took me less time and effort to write. Note that our implementation below works for multi-head attention, but for comparison with the reference code above we use a single head (\(h = 1\)). Moreover, we don't make use of mask or dropout when comparing against the reference implementation.
def forward(self, q, k, v):
    # I'm using notation from the FlashAttention-2 paper
    S = torch.einsum('bihd,bjhd -> bhij', q, k) / self.scale
    self.P = F.softmax(S, dim=-1)
    O = torch.einsum('bhin,bnhj -> bihj', self.P, v)
    return O
To test, we simply pass random q, k, v of appropriate shapes to both implementations and compare the outputs.
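Concretely, a minimal version of that check might look as follows. Here einsum_attention is a standalone stand-in for the forward method above (assuming self.scale is \(\sqrt{d_k}\)), and the transposes account for the reference function expecting a (batch, heads, seq, dim) layout; the names and shapes are mine, not from the notebook.
import math
import torch
import torch.nn.functional as F

def einsum_attention(q, k, v):
    # q, k, v: (batch, seq_len, heads, head_dim)
    scale = math.sqrt(q.size(-1))                        # assuming self.scale = sqrt(d_k)
    S = torch.einsum('bihd,bjhd -> bhij', q, k) / scale
    P = F.softmax(S, dim=-1)
    return torch.einsum('bhin,bnhj -> bihj', P, v)

b, n, h, d = 2, 8, 1, 16                                 # single head, matching the comparison above
q, k, v = (torch.randn(b, n, h, d) for _ in range(3))

ref, _ = attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
ours = einsum_attention(q, k, v)
assert torch.allclose(ref.transpose(1, 2), ours, atol=1e-6)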
Backprop through Softmax
To implement our own backward pass of the attention module (so that we can later optimize it), one can use the equations in Section 2.2 of the FlashAttention-2 paper. It's straightforward matrix calculus. One possible source of confusion is that the symbols are written in shorthand: \(\mathbf dO\) stands for \(\frac{\partial L}{\partial O}\), that is, the partial derivative of the loss function \(L\) with respect to \(\mathbf O\), and it has the same shape as \(\mathbf O\). The same goes for \(\mathbf dV, dP, dS, dQ, dK\).
The derivation of the softmax Jacobian may need some explanation (refer to Section 2.2 of the paper). Given \({\mathbf S} \in {\mathbb R}^{N \times N}\), the attention matrix is computed as:
\({\mathbf P} = \operatorname{softmax}({\mathbf S}) \in {\mathbb R}^{N \times N}\)
where the softmax is applied row-wise to \(\mathbf S\). Given \({\mathbf dP} \in {\mathbb R}^{N \times N}\), we want to derive \({\mathbf dS} \in {\mathbb R}^{N \times N}\). First, let’s consider one row of \(\mathbf S\) and \(\mathbf P\) and denote them by \(\mathbf s\) and \(\mathbf p\) respectively. The corresponding rows of \(\mathbf dS\) and \(\mathbf dP\) are denoted by \(\mathbf ds\) and \(\mathbf dp\). By chain rule, we have
\({\mathbf ds} = J^T \cdot {\mathbf dp}\)
where \(J\) is the Jacobian matrix \(\frac{\partial \mathbf p}{\partial \mathbf s}\). Let \({\mathbf p} = [p_1, p_2, \dots , p_N]\), where \(p_i = \frac{\exp(s_i)}{\sum_k {\exp(s_k)}}\). The diagonal entries of the Jacobian are given by
\(J_{ii} = \frac{\exp(s_i)}{\sum_k {\exp(s_k)}} - \frac{\exp(2s_i)}{\left(\sum_k \exp(s_k)\right)^2} = p_i - p_i^2\)
and for the off-diagonal elements (\(i \neq j\)), we have
\(J_{ij} = - \frac{\exp(s_i + s_j)}{\left(\sum_k \exp(s_k)\right)^2} = - p_i \cdot p_j\)
The Jacobian can be written more concisely as
\(J = {\text diag}({\mathbf p}) - {\mathbf p}{\mathbf p}^T\)
where diag(\(\mathbf p\)) is the matrix formed by placing the elements of \({\mathbf p}\) on the diagonal, and the second term is the outer product of \(\mathbf p\) with itself. In our case \(J\) is symmetric, so \(J^T = J\) and hence
\({\mathbf ds} = ({\text diag}({\mathbf p}) - {\mathbf p}{\mathbf p}^T) \cdot {\mathbf dp}\)
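As a quick numerical check of this row-wise formula (a small standalone sketch, not taken from the notebook):
import torch
from torch.autograd.functional import jacobian

s = torch.randn(5)
p = torch.softmax(s, dim=0)

J_autograd = jacobian(lambda x: torch.softmax(x, dim=0), s)   # full N x N Jacobian via autograd
J_formula = torch.diag(p) - torch.outer(p, p)                 # diag(p) - p p^T
assert torch.allclose(J_autograd, J_formula, atol=1e-6)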
To generalize from rows of \({\mathbf P}, {\mathbf S}\) to full matrices, we’ll manipulate the above expression a bit:
\({\mathbf ds} = ({\text diag}({\mathbf p}) \cdot {\mathbf dp}- {\mathbf p}{\mathbf p}^T \cdot {\mathbf dp})\)
The above can be written in terms of the element-wise product (\(*\)):
\({\mathbf ds} = {\mathbf p} * {\mathbf dp}- {\mathbf p} * ({\mathbf p}^T \cdot {\mathbf dp})\)
The dot product \(({\mathbf p}^T \cdot {\mathbf dp})\) is a scalar. Assuming broadcasting, we can write
\({\mathbf ds} = {\mathbf p} * ({\mathbf dp}- ({\mathbf p}^T \cdot {\mathbf dp}))\)
The above formulation was for single rows of \(\mathbf S, P, dS, dP\), but viewed as column vectors. Generalizing to full matrices, we derive
\({\mathbf dS} = {\mathbf P} * ({\mathbf dP} - \operatorname{dot\_prod})\)
where \(\operatorname{dot\_prod}[i] = \sum_k P[i, k] * dP[i, k]\) is the row-wise dot product along the last dimension of \(\mathbf P\) and \(\mathbf dP\). This works because \(\operatorname{dot\_prod}\) gets broadcast to the shape of \(\mathbf P\); for that, we need to insert an extra dimension at the end of \(\operatorname{dot\_prod}\) (or keep it via keepdim=True), else PyTorch will complain about a dimension mismatch. One can check the correctness of our own backward() implementation by comparing its output with that given by autograd. The full implementation is available in this notebook.
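As a minimal sketch of that check for the softmax step alone (standalone code with hypothetical names; the notebook's implementation may differ):
import torch

N = 6
S = torch.randn(N, N, requires_grad=True)
P = torch.softmax(S, dim=-1)

dP = torch.randn(N, N)                              # pretend upstream gradient w.r.t. P
P.backward(dP)                                      # autograd's dS lands in S.grad

P_det = P.detach()
dot_prod = (P_det * dP).sum(dim=-1, keepdim=True)   # row-wise dot product, extra dim kept for broadcasting
dS_manual = P_det * (dP - dot_prod)                 # dS = P * (dP - dot_prod)

assert torch.allclose(S.grad, dS_manual, atol=1e-6)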
In summary, we implemented the forward and backward passes of the standard attention module as an exercise in using einsum and einops, and also because it gives us a starting point for hardware optimizations to speed up attention for long context lengths. Along the way, we saw that the tensor formulation of backprop through softmax needed some matrix manipulations.
Backpropagation is amazing since it lets us reuse computation, so the cost of computing gradients is comparable to that of the forward pass (within a constant factor). Autograd is great since it speeds up implementation by freeing us from writing gradient code. However, notice that autograd required saving intermediate computational state (in our example, we saved the \(\bf P\) matrix; see the notebook). That's a load on limited GPU memory.
A fun exercise: why do you think topological sort is needed during backprop?
Thanks to Eric for introducing me to new concepts.