Flash Attention

前情提要
GPU 存储分为芯片内和芯片外,芯片内的 SRAM 用于储存需要计算的临时数据,显存 HBM 在芯片外:
- HBM:位于 GPU 芯片外,就是我们所说的显存,类似于 CPU 的 DRAM,储存模型训练和推理时的参数,容量大,例如 A100 一般为 40G 或 80G。
- SRAM:位于 GPU 芯片上,仅用于存储 CUDA Kernel 计算时所需的临时数据,容量极限一般在 20MB
- CUDA Kernel:GPU 上执行并行的计算函数,是实现并行计算任务的基本单元
原始的 Attention 计算是 ,GPU 需要以下步骤,总计 6 次通信:
- 从 HBM 中加载 、 到 SRAM
- Kernel 计算出
- 将 写会 HBM
- 将 加载到 SRAM
- 计算
- 将 写回 HBM
- 将 、 加载到 SRAM
- 计算
- 将 写回 HBM
但我这里就很奇怪了,既然 Kernel 能计算出 ,那不就代表 SRAM 能存下整个 矩阵了吗?那何必在 SRAM 和 HBM 里面来回移动呢?实际上,在传统 kernel 里 GPU 并不是把 整个 Q、K 都读进 SRAM 再算出 的,而是 分 tile读入并计算。我们的矩阵会被分为一个个 tile 如下图:
+----+----+----+----+
|t11 |t12 |t13 |... |
+----+----+----+----+
|t21 |t22 |t23 |... |
+----+----+----+----+
|t31 |t32 |t33 |... |
+----+----+----+----+
然后 kernel 按照下面的方式,一次次计算一个 tile 并返回,最后得到完整的计算结果,实际上就是分块矩阵的思想:
for Qi in Q_tiles:
load Qi from HBM
for Kj in K_tiles:
load Kj from HBM
Sij = Qi @ Kj^T
store Sij to HBM
所以实际上 SRAM 和 HBM 的通信次数是 ,我们说的 6 次是在 Matrix Level。
优化思路
假设不考虑 softmax 的过程,我们计算 ,这样我们只需要两次 HBM 和 SRAM 的通信了,把 、、 分块从 HBM 读入 SRAM,计算之后再把这一小块 从 SRAM 写会 HBM,最后就能得到计算结果了。但是问题就出在 softmax 身上,由于 softmax 每次需要一整行数据,但是分块后 只有一小块,并不是一整行,如下图。
前向传播
float16 支持的范围是 ±65504,意味着当 x>11时候, 将超过有效范围出现溢出,这就引出了 safe softmax 的概念(PS:在 CS336 手写 softmax 时候我们就实现过)。
每个数字减去最大值再求 softmax 不会改变最终结果,所以在实际使用时都用 safe softmax。
假设 的第一行为 ,我们把矩阵 和 分块,从 HBM 读入数据到 SRAM 分别得到了第一行的一部分 和 ,此时需要计算第一行的 softmax 值 :
- 计算每一块的最大值: 和
- 计算每一块的分子: 和
- 计算每一块的分母: 和
- 合并最大值:
- 计算全局分母:
- 计算最终 softmax 结果: 和
但是我们计算 的时候如何知道整个序列的 和 呢?我们进一步看一下论文是怎么写的:
我认为 Flash Attention 里面很巧妙的一点在于,它忽略了 和 直接计算 。刚刚我们认为存在问题是因为站在了 矩阵或者 矩阵的视角上,它的形状是 ,我们总想着怎么把它在列方向拆分。而 的形状是 ,这样我们就可以让分块矩阵 、、 在 的每一行上原地更新。
这里我举一个例子就能完全理解:将 求平均值然后和 做乘法。假设我们有行向量:
和列向量:
我们的目标是计算:
由于空间不足不能一次性读入 矩阵,所以我们只能一个一个获取 和 :
- 先看 ,此时分母为 ,我们计算
- 加上 ,新分母为 ,把旧 乘回旧分母,加上新项,再除新分母
- 加上 ,新分母为 ,我们同样操作
这就是 在线归一化 的思想。在 Flash Attention 中,我们除了需要重新计算全局分母,还需要在最大值更新时更新分子,这就是 Flash Attention 的全部思想。
for KV_block in KV: # 外层
load K_block, V_block -> shared memory
for Q_block in Q: # 内层
load Q_block
compute Q_block @ K_block^T
online softmax update
accumulate with V_block
假如我们有 n 个 Q block 和 m 个 KV block(通常情况下 n 都是大等于 m 的,比如 Inference 的情况)。先加载 KV block,它就可以和所有的 Q block 比较,总共需要 load 一次全部 KV block 和 m 次全部 Q block。假如我们先加载 Q block,那么总共需要 load 一次全部 Q block 和 n 次 KV block,明显这种 HBM 访问次数更多。
反向传播
反向传播的梯度计算太复杂了,这里就不具体推到了。它的核心在于,虽然前向计算中省略了 和 矩阵的计算,缺少了激活值,但是我们在 HBM 里面也存了 和 可以帮助我们在反向传播中很快的 recompute,性能不会差非常多。