由于主包最近准备了一场 Coding面试,初步领略了下Flash
Attention的神奇,其中的关键操作便是对于Softmax的优化。面试结束之后主包也是决定记录一下这个操作,于是有了这篇博客。
由于是准备的Coding面,这里也会附上代码的😁 (所以也可以叫
手撕online softmax?)
为什么需要Online Softmax
这个首先要从Flash
Attention说起,这里简单介绍,原始Transformer中,会存在一个Score矩阵,其维度为
N ⋅ N,非常巨大。这导致其在GPU上没法分块计算。由此我们想要分块计算Q,K,V矩阵,Flash
Attention便完成了这个任务,有效减小了GPU的IO操作数量,从而实现了加速。
想法很简单,但是实现起来存在一些问题,其中的关键便是Softmax,我们从头开始说起
Softmax的“进化”
标准的Softmax
我们都知道Softmax的公式如下:
$$
Softmax(x_i) = \frac{e^{x_i}}{\sum_{j}^{N} e^{x_j}}
$$
这个的手撕代码很简单,我们对矩阵做行Softmax:
1234567891011import torchimport torch.nn.functional as Fx = torch.randn(2,5)row_sum = torch.exp(x).sum(dim=-1,keepdim=True)ours_out = torch.exp(x) / row_sumprint("======= ours softmax =======")print(ours_out)print("======= Standard softmax =======")print(F.softmax(x, dim=-1))
输出如下:
123456======= ours softmax =======tensor([[0.4979, 0.0880, 0.0198, 0.0267, 0.3675], [0.1271, 0.1828, 0.3216, 0.1422, 0.2263]])======= Standard softmax =======tensor([[0.4979, 0.0880, 0.0198, 0.0267, 0.3675], [0.1271, 0.1828, 0.3216, 0.1422, 0.2263]])
看上去非常不错,但是当我们希望使用fp16精度或者输入的数据大一些,由于指数的存在,会很容易的溢出,比如我们仅仅将x的输入扩大100倍(x = torch.randn(2,5)*100),
输出就会变为
123======= ours softmax =======tensor([[nan, 0., 0., 0., 0.], [0., 0., nan, 0., 0.]])
结果变得不稳定,同时出现数据溢出。由此,便提出了
Safe_softmax
Safe_softmax
为了解决上述问题,我们可以很简单的对输入做一个平移。具体而言我们可以让每个x减去其所在行的最大值,核心可以描述为如下公式:
$$
Softmax(x_i) = \frac{e^{x_i}}{\sum_{j}^{N} e^{x_j}} =
\frac{e^{x_i-x_{max}}}{\sum_{j}^{N} e^{x_j-x_{max}}}
$$
证明比较简单,我们用上一个标准的softmax不能跑的情况测试一次
123456789101112import torchimport torch.nn.functional as Fx = torch.randn(2,5)*100x_max,_ = x.max(dim=-1, keepdim=True)x = x - x_maxrow_sum = torch.exp(x).sum(dim=-1,keepdim=True)ours_out = torch.exp(x) / row_sumprint("======= ours softmax =======")print(ours_out)print("======= Standard softmax =======")print(F.softmax(x, dim=-1))
输出:
123456======= ours softmax =======tensor([[0.0000e+00, 0.0000e+00, 7.4508e-32, 9.3580e-40, 1.0000e+00], [5.3247e-28, 0.0000e+00, 1.0000e+00, 2.9268e-07, 0.0000e+00]])======= Standard softmax =======tensor([[0.0000e+00, 0.0000e+00, 7.4508e-32, 9.3580e-40, 1.0000e+00], [5.3247e-28, 0.0000e+00, 1.0000e+00, 2.9268e-07, 0.0000e+00]])
目前我们简单解决了Softmax的溢出问题,但是我们会发现其计算过程中,依赖全局的和。这也限制了我们希望能够局部分块处理Attention的计算。由此提出了Online Softmax。
Online Softmax
Online Softmax将原来的计算过程,通过两个全局变量转变为了流式计算过程。
具体而言,我们假设输入X可以分为两个部分,即 X = (X1,X2)
。我们依次处理两个部分。
首先针对X1,
我们通过如下公式计算,其行和最大值 M1,以及Row_Sum:L1,
$$
L_1 = \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M_1) , 即局部exp之和
$$
由此,基于上述两个全局变量我们能够计算出这个局部的(即X1)softmax值,同理我们可以得到M2, L2
但是想要基于全局准确的Softmax结果还要进一步计算,具体可以做如下推导:
$$
\begin{aligned}
M &= Max(M_1, M_2) \\
L &= \sum_{j=1}^{X.size(-1)} exp(X^j - M) \\
&= \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M) +
\sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M) \\
&= exp(M_1 - M) * \sum_{j=1}^{X_1.size(-1)} exp(X_1^j - M_1) +
exp(M_2 - M) * \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M_2) \\
& = L_1 * exp(M_1 - M) + L_2 * exp(M_2 - M)
\end{aligned}
$$
由此我们更具局部的信息计算出来全局的两个信息,当然在实际输入顺序中,L2
可以不用计算,可以计算得到M之后,直接计算局部和, 即:
$$
L = L_1 * exp(M_1 - M) + \sum_{j=1}^{X_2.size(-1)} exp(X_2^j - M)
$$
得到两个的全局信息之后,在遍历一遍即可得到最后的Softmax值,具体的代码实现如下:
12345678910111213141516171819202122232425import torchimport torch.nn.functional as Fx = torch.randn(2,16)out = torch.zeros_like(x)x_blocks = torch.split(x, 4, dim=1)out_blocks = list(torch.split(out, 4, dim=1))m = torch.ones(x.size(0)).unsqueeze(-1)*-1e9l = torch.zeros(x.size(0)).unsqueeze(-1)for xi in x_blocks: mi,_ = xi.max(dim=-1, keepdim=True) m_new = torch.maximum(mi,m) l = l*torch.exp(m - m_new) + torch.exp(xi-m_new).sum(dim=-1, keepdim=True) m = m_newfor i in range(len(x_blocks)): out_blocks[i] = torch.exp(x_blocks[i] - m) / l out = torch.cat(out_blocks, dim=1)print("======= ours softmax =======")print(out)print("======= Standard softmax =======")print(F.softmax(x, dim=-1))
输出:
12345678910======= ours softmax =======tensor([[0.0072, 0.1026, 0.0273, 0.0057, 0.2055, 0.1391, 0.0090, 0.0694, 0.0036, 0.0191, 0.0087, 0.0462, 0.0069, 0.0355, 0.2882, 0.0259], [0.1437, 0.0477, 0.0249, 0.0123, 0.0210, 0.0544, 0.0244, 0.0082, 0.2192, 0.1388, 0.0311, 0.1340, 0.0217, 0.0071, 0.0646, 0.0468]])======= Standard softmax =======tensor([[0.0072, 0.1026, 0.0273, 0.0057, 0.2055, 0.1391, 0.0090, 0.0694, 0.0036, 0.0191, 0.0087, 0.0462, 0.0069, 0.0355, 0.2882, 0.0259], [0.1437, 0.0477, 0.0249, 0.0123, 0.0210, 0.0544, 0.0244, 0.0082, 0.2192, 0.1388, 0.0311, 0.1340, 0.0217, 0.0071, 0.0646, 0.0468]])
最后用一张图片来总结一下online
softmax的思路,我觉得挺贴切的。即可以将原本三次扫描的Safe
Softmax转变为两次扫描的 Online Softmax
示意图
当然在实际的Flash Attention中,可以进一步简化为one
pass的扫描,这就下回分解了😋
参考
【手撕online
softmax】Flash Attention前传,一撕一个不吱声