在微调过程中使用无填充 Transformer 层节省内存
对于长序列训练,注意力计算可能成为内存瓶颈,因为朴素实现需要 内存,其中 是序列长度。然而,最近提出了 FlashAttention [1,2],它优化了 IO 并使用在线 softmax [3] 来减少 GPU 内存(通常是数据中心 GPU 的 HBM)和 GPU 缓存的数据移动 [4]。FlashAttention 算法还将注意力计算的内存需求从 减少到 ,即从二次方减少到线性。
目前大多数库(包括 HuggingFace transformers)中 FlashAttention [1,2] 的训练集成是基于对这些库的非侵入性修改,大多数实现只是用 FlashAttention [1,2] 替换了朴素注意力模块。尽管这很容易实现,但在批次中使用可变长度序列时(即批次中存在填充时)它的性能会次优。除了注意力之外的所有操作都独立地应用于 Transformer 中的每个 token 位置。由于 FlashAttention [1,2] 完全避免了对填充 token 的任何计算/内存需求,因此在使用 FlashAttention 时,可以从 Transformer 模型中删除所有多余的计算和填充 token 所需的内存,从而本质上创建一个无填充 Transformer 模型。原始的 FlashAttention 训练代码库中也执行了此操作。需要注意的是,这是模型的精确实现,没有近似值。
HuggingFace TGI 中也进行了类似的优化以提高推理效率。需要注意的是,在不需要批次填充的情况下,例如批次中所有示例的长度相等,或者在使用密集打包示例时(如预训练模型的情况),这不会成为问题。
在本博客中,我们给出了朴素注意力、带有填充 Transformer 块的 FlashAttention(HuggingFace transformers 库中的当前实现)和无填充 Transformer 块的理论内存消耗。
假设一个形状为 的嵌入输入批次作为 Transformer 层的输入,其中 、 和 分别表示批次大小、批次中第 个示例的未填充序列长度以及 Transformer 模型的隐藏大小。为了训练模型,每个 Transformer 层都需要缓存每个操作(在前向传播中计算)的激活以进行反向传播。为了简化,我们假设训练使用 16 位精度(张量中每个值 2 字节)。我们还假设多头注意力 [6] 具有 个注意力头。尽管相同的思想也适用于多查询注意力 [7] 和分组查询注意力 [8]。
朴素注意力
输入层归一化接收形状为 的输入,该输入需要为反向传播缓存。均值和方差也需要缓存,它们的形状都是 。由于 ,我们可以忽略均值和方差的 元素。此操作的总激活内存为
QKV 投影矩阵的输入(在 Q、K 和 V 投影之间共享)需要缓存。它也有 个元素,占用 字节。Q、K 和 V 投影的每个输出也需要缓存,每个输出都包含 个元素,每个占用 字节。此操作的总激活内存为
softmax 的输出,它具有 个元素,也需要缓存。此操作的总激活内存为
注意力 softmax 有一个 dropout,需要保存一个包含 个元素的掩码。每个元素占用一个字节,因为 PyTorch 不允许位张量。这可能是因为 GPU 通常是字节寻址的,从而简化了实现。此操作的总内存为
softmax dropout 的输出包含 个元素,也需要缓存。此操作的总激活内存为
我们缓存上述乘法的输出,它是投影矩阵的输入。它包含 个元素。此操作的总激活内存为
只需要缓存 dropout 掩码。此操作的总内存为
与之前的层归一化相同。内存需求为
我们在这里假设前馈隐藏维度为 ,这对于标准 Transformer 来说是典型的。每个线性层的输入和 GELU 激活函数的输入都需要缓存。它们分别占用 、 字节和 字节。MLP 块所需的内存为
所需的内存与上面的第 (8) 点相同,即
将这些相加,每层的总激活内存由下式给出:
带有 FlashAttention 的 Transformer 层
FlashAttention [1,2] 已集成到 HuggingFace transformers API 中。在撰写本博客时,当前实现会在 FlashAttention 核执行之前执行一次去填充操作。此操作会将形状为 的输入 Q、K、V 转换为形状为 的输入(其中批次中的每个示例都被一个接一个地连接起来,形成一个二维张量),并启动 FlashAttention 核。注意力计算后,输出再次填充回 形状。
FlashAttention [1,2] 避免了在内存中具体化 二次矩阵,并使用在线 softmax [3],从而无需在点 (3) 缓存激活。相反,我们只需要具体化输出矩阵,其形状为 ,2 个 softmax 统计数据都具有相同的形状 以及用于 dropout 的随机数生成器状态,这里我们忽略了这一点。有关算法详情,请参阅 FlashAttention [1,2] 论文。我们还需要缓存用于填充和取消填充的布尔注意力掩码。尽管在计算中我们忽略了它,因为它对每个层都相同,并且可以为整个 Transformer 模型缓存一次,而不需要在每个层都缓存。因此,注意力所需的内存变为
因此,FlashAttention [1,2] 的每层总激活内存如下:
无填充 Transformer 层
由于 Transformer 层中的所有操作(注意力除外)对于每个 token 位置都相同,我们可以避免填充和去填充操作,从而进一步减少 Transformer 层所需的激活内存,这需要对 HuggingFace transformers 实现进行少量修改。在此 Transformer 实现中,根本没有浪费填充 token 位置的内存!在这种情况下,整个 Transformer 模型的输入形状为 。在这种情况下,内存由以下公式给出:
需要注意的是,当没有填充时,即当 时,。这种优化类似于运行带有嵌套张量的 Transformer 模型。尽管已经付出了巨大的努力来通过按上下文长度对示例进行分桶等方法来解决此问题,但这些方法会导致模型性能下降,尤其是在微调期间。
使用无填充 Transformer 层的动机
现在,我们分析了 3 种 Transformer 层实现中的内存消耗。我们假设我们有一个序列长度遵循离散均匀分布的数据集,即 ,其中 是表示批次中第 是数据集和模型最大序列长度。我们以 个示例批次采样,序列长度为 。我们计算在离散均匀分布下,、 和 。为此,我们考虑另一个随机变量 。可以导出 的累积分布函数为: 现在,利用批次中的示例是独立同分布的这一事实,我们有 ,因此我们有 的概率质量函数为: 我们可以使用计算方法或 Faulhaber 公式 [9] 以及上述推导结果来计算这 3 种方法的内存使用期望。下表报告了使用 20B 参数模型的方程推导出的理论内存消耗。我们发现使用无填充版本的 Transformer 层节省了 的激活内存,并且还节省了大量冗余的 FLOPs。本博客未对 FLOPs 进行分析,但它们很容易推导。 个样本的序列长度的随机变量,
序列长度 | 朴素注意力 (Naive Attention) | Flash Attention | 无填充 Transformer |
---|---|---|---|
512 | 1.085 GB | 0.721 GB | 0.411 GB |
1024 | 2.919 GB | 1.441 GB | 0.821 GB |
2048 | 8.837 GB | 2.882 GB | 1.642 GB |
4096 | 29.674 GB | 5.763 GB | 3.283 GB |
8192 | 107.347 GB | 11.524 GB | 6.566 GB |
16384 | 406.693 GB | 23.048 GB | 13.132 GB |
32768 | 1581.386 GB | 46.096 GB | 26.263 GB |
表:在不同上下文长度下,20B 参数模型在不同注意力实现下的每 Transformer 层内存使用情况,其中上下文长度为 ,隐藏层大小为 ,FFN 隐藏层大小为 ,注意力头数为 。
结论
在本博客中,我们提出了一种在使用 FlashAttention 微调 Transformer 模型期间完全避免填充 token 的计算和内存需求的方法。我们的修改可以轻松集成到 HuggingFace transformers 生态系统中进行微调。我们还在本博客中推导了相同理论内存消耗的方程。该方法不涉及编写任何低级设备代码。我们使用的唯一非原生 PyTorch 代码是 FlashAttention,它已经可用。
参考文献
- Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." Advances in Neural Information Processing Systems 35 (2022): 16344-16359.
- Dao, Tri. "Flashattention-2: Faster attention with better parallelism and work partitioning." arXiv preprint arXiv:2307.08691 (2023).
- Milakov, Maxim, and Natalia Gimelshein. "Online normalizer calculation for softmax." arXiv preprint arXiv:1805.02867 (2018).
- Ivanov, Andrei, et al. "Data movement is all you need: A case study on optimizing transformers." Proceedings of Machine Learning and Systems 3 (2021): 711-732.
- Korthikanti, Vijay Anand, et al. "Reducing activation recomputation in large transformer models." Proceedings of Machine Learning and Systems 5 (2023).
- Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).
- Shazeer, Noam. "Fast transformer decoding: One write-head is all you need." arXiv preprint arXiv:1911.02150 (2019).
- Ainslie, Joshua, et al. "Gqa: Training generalized multi-query transformer models from multi-head checkpoints." arXiv preprint arXiv:2305.13245 (2023).
- Knuth, Donald E. "Johann Faulhaber and sums of powers." Mathematics of Computation 61.203 (1993): 277-294.