利用自推测解码加速文本生成
自推测解码,由LayerSkip: Enabling Early Exit Inference and Self-Speculative Decoding提出,是一种新颖的文本生成方法。它结合了推测解码和大型语言模型(LLM)的提前退出优势。该方法通过使用**同一模型**的早期层来起草标记,并使用后续层进行验证,从而实现高效生成。
这项技术不仅加速了文本生成,还在内存和计算延迟方面取得了显著的节省。为了实现端到端加速,早期层的输出需要足够接近最后一层。这通过一种训练方案来实现,如论文所述,该方案可在预训练期间应用,也可在特定领域进行微调时应用。自推测解码对于实际应用特别高效,它使得模型能够在较小的 GPU 上部署,并降低了**大规模推理**所需的总体硬件占用。
在这篇博客文章中,我们将探讨自推测解码的概念、其实现以及使用 🤗 transformers 库的实际应用。您将了解其技术基础,包括**提前退出层**、**反嵌入**和**训练修改**。为了将这些概念付诸实践,我们提供了代码示例、与传统推测解码的基准比较以及性能权衡的见解。
请直接查看以下 Hugging Face 资源,以了解更多有关该方法的信息并亲身体验:
推测解码与自推测解码
LayerSkip 推理在
facebook/layerskip-llama2-7B
(使用 LayerSkip 方案持续预训练的 Llama2 7B)上的演示。
传统的推测解码使用**两个**模型:一个较小的模型(草稿模型)生成一系列草稿标记,一个较大的模型(验证模型)验证草稿的准确性。较小的模型执行大部分生成工作,而较大的模型则进行结果修正。这加快了文本生成速度,因为较大的模型可以一次性验证完整的序列,而不是一次生成一个草稿。
在自推测解码中,作者在此概念的基础上,使用大型模型的早期层来生成草稿标记,然后由模型的深层进行验证。这种推测解码的“自我”方面需要特定的训练,使得模型能够同时执行起草和验证。反过来,这与传统的推测解码相比,提高了速度并降低了计算成本。
使用 transformers
为了在 🤗 transformers 库中启用提前退出自推测解码,我们只需要在 generate()
函数中添加 assistant_early_exit
参数。
这是一个展示该功能的简单代码片段。
pip install transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
early_exit_layer = 4
prompt = "Alice and Bob"
checkpoint = "facebook/layerskip-llama2-7B"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
model = AutoModelForCausalLM.from_pretrained(checkpoint).to("cuda")
outputs = model.generate(**inputs, assistant_early_exit=early_exit_layer)
注意: 虽然
assistant_early_exit
参数理论上可以为任何仅解码器 Transformer 启用提前退出自推测解码,但中间层的 logits 无法进行**反嵌入**(通过 LM Head 解码的过程,稍后在博客文章中描述),除非模型经过专门训练。此外,只有当检查点经过特殊训练以提高早期层准确性时,您才能**获得加速**。LayerSkip 论文提出了一种训练方法来实现这一点(即,应用提前退出损失,并逐步增加层 dropout 率)。此处提供了一系列使用 LayerSkip 训练方法持续预训练的 Llama2、Llama3 和 Code Llama 检查点。
基准测试
我们进行了一系列广泛的基准测试,以衡量 LayerSkip 的自推测解码相对于各种模型上的自回归解码的加速效果。我们还比较了自推测解码(基于提前退出)与标准推测解码技术。要重现结果,您可以在此处找到代码,并在此电子表格中找到运行每个实验的命令。所有实验均在单个 80GB A100 GPU 上运行,Llama2 70B 实验除外,它们在包含 8 个 A100 GPU 的节点上运行。
Llama3.2 1B
模型变体 | 层数 | 辅助模型 | 辅助层 | 任务 | 总层数 | FLOPs/输入 (G) | 时间/输入 (s) | FLOPs/输出 (G) | 时间/输出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
facebook/layerskip-llama3.2-1B | 1 | 提前退出 @ 第 4 层 | 摘要 | 1 | 1195.28 | 9.96 | 2147.7 | 17.9 | 1.80 |
Llama3 8B
模型变体 | 层数 | 辅助模型 | 辅助层 | 任务 | 总层数 | FLOPs/输入 (G) | 时间/输入 (s) | FLOPs/输出 (G) | 时间/输出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Meta-Llama-3-8B | 8 | meta-llama/Llama-3.2-1B | 1 | 摘要 | 9 | 1872.46 | 19.04 | 2859.35 | 29.08 | 1.53 |
meta-llama/Meta-Llama-3-8B | 8 | meta-llama/Llama-3.2-3B | 3 | 摘要 | 11 | 2814.82 | 28.63 | 2825.36 | 28.73 | 1.00 |
facebook/layerskip-llama3-8B | 8 | 提前退出 @ 第 4 层 | 摘要 | 8 | 1949.02 | 15.75 | 3571.81 | 28.87 | 1.83 |
Llama2 70B
模型变体 | 层数 | 辅助模型 | 辅助层 | 任务 | 总层数 | FLOPs/输入 (G) | 时间/输入 (s) | FLOPs/输出 (G) | 时间/输出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-2-70b-hf | 70 | meta-llama/Llama-2-13b-hf | 13 | 摘要 | 83 | 5036.54 | 46.3 | 12289.01 | 112.97 | 2.44 |
meta-llama/Llama-2-70b-hf | 70 | meta-llama/Llama-2-7b-hf | 7 | 摘要 | 77 | 4357.55 | 40.06 | 12324.19 | 113.3 | 2.83 |
meta-llama/Llama-2-70b-hf | 70 | TinyLlama/TinyLlama_v1.1 | 1 | 摘要 | 71 | 4356.21 | 40.05 | 12363.22 | 113.66 | 2.84 |
facebook/layerskip-llama2-70B | 70 | 提前退出 @ 第 10 层 | 摘要 | 70 | 6012.04 | 54.96 | 1283.34 | 113.2 | 2.06 |
Llama2 13B
模型变体 | 层数 | 辅助模型 | 辅助层 | 任务 | 总层数 | FLOPs/输入 (G) | 时间/输入 (s) | FLOPs/输出 (G) | 时间/输出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-2-13b-hf | 13 | meta-llama/Llama-2-7b-hf | 7 | 摘要 | 20 | 3557.07 | 27.79 | 4088.48 | 31.94 | 1.15 |
meta-llama/Llama-2-13b-hf | 13 | TinyLlama/TinyLlama_v1.1 | 1 | 摘要 | 14 | 2901.92 | 22.67 | 4190.42 | 32.74 | 1.44 |
meta-llama/Llama-2-13b-hf | 13 | apple/OpenELM-270M | 0.27 | 摘要 | 13.27 | 2883.33 | 22.53 | 4521.12 | 35.32 | 1.57 |
meta-llama/Llama-2-13b-hf | 13 | apple/OpenELM-450M | 0.45 | 摘要 | 13.45 | 3267.69 | 25.53 | 4321.75 | 33.76 | 1.32 |
facebook/layerskip-llama2-13B | 13 | 提前退出 @ 第 4 层 | 摘要 | 13 | 4238.45 | 33.11 | 4217.78 | 32.95 | 0.995 | |
facebook/layerskip-llama2-13B | 13 | 提前退出 @ 第 8 层 | 摘要 | 13 | 2459.61 | 19.22 | 4294.98 | 33.55 | 1.746 |
Llama2 7B
模型变体 | 层数 | 辅助模型 | 辅助层 | 任务 | 总层数 | FLOPs/输入 (G) | 时间/输入 (s) | FLOPs/输出 (G) | 时间/输出 (s) | 效率 |
---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-2-7b-hf | 7 | TinyLlama/TinyLlama_v1.1 | 1 | 摘要 | 8 | 2771.54 | 21.65 | 3368.48 | 26.32 | 1.22 |
meta-llama/Llama-2-7b-hf | 7 | apple/OpenELM-270M | 0.27 | 摘要 | 7.27 | 2607.82 | 20.37 | 4221.14 | 32.98 | 1.62 |
meta-llama/Llama-2-7b-hf | 7 | apple/OpenELM-450M | 0.45 | 摘要 | 7.45 | 3324.68 | 25.97 | 4178.66 | 32.65 | 1.26 |
facebook/layerskip-llama2-7B | 7 | 提前退出 @ 第 4 层 | 摘要 | 7 | 2548.4 | 19.91 | 3306.73 | 25.83 | 1.297 |
我们可以从结果中得出以下观察:
- 如“总参数数量”列所示,自推测解码消耗的内存更少,因为它不需要单独的草稿模型,并且草稿阶段层的权重被重用。
- 对于除 Llama2 70B 之外的所有模型大小和生成,提前退出自推测解码都比常规的双模型推测解码更快。
Llama2 70B 上自推测解码速度提升相对有限的原因可能有多种,例如 LayerSkip Llama2 70B 检查点持续预训练的 token 数量较少(Llama2 70B 为 328M token,而 Llama2 7B 为 52B token)。但这仍然是未来研究需要深入探索的改进领域。尽管如此,70B 模型的自推测解码仍显著快于自回归解码。
提前退出与反嵌入
自推测解码的一项关键技术是提前退出,即生成过程可以在预设层停止。为了实现这一点,我们通过将这些层的 logits 投射到语言模型(LM)头部来预测下一个 token,从而**反嵌入**它们。这使得模型能够跳过后续层并提高推理时间。
反嵌入可以在任何 Transformer 层进行,将提前退出转变为一种高效的 token 预测机制。一个自然而然的问题是:LM 头如何适应反嵌入早期层的 logits,因为它最初是经过训练只与最终层一起工作的?这就是训练修改发挥作用的地方。
训练修改:层 dropout 和 提前退出损失
在训练阶段,我们引入了**层 dropout**,它允许模型在训练期间跳过某些层。dropout 率在深层中逐渐增加,使模型对后续层的依赖性降低,同时增强模型的泛化能力并加速训练。
除了层 dropout,还应用了**提前退出损失**,以确保 LM 头部学习反嵌入不同层。用于训练具有提前退出功能的模型的总损失函数由每次退出(中间层)的标准化损失之和给出。该技术通过在所有层之间分配学习任务来实现高效训练。
自起草与自验证
训练完成后,我们可以在推理期间应用自推测解码。该过程始于**自起草**,其中通过从某个中间层提前退出生成标记。推测标记的数量定义了在此阶段生成的草稿标记数量,而我们退出的层定义了草稿阶段的大小和准确性。这两个参数都可以在推理时根据速度和草稿阶段准确性之间的权衡进行指定。
下一阶段是**自验证**,其中使用完整模型来验证草稿标记。验证模型重用了草稿模型的部分缓存。如果草稿标记与验证标记一致,它们就会被添加到最终输出中,从而更好地利用我们系统的内存带宽,因为使用完整模型生成一系列标记比验证草稿的成本要高得多,只要有几个标记匹配。
在自验证阶段,由于早期层的结果在起草阶段已缓存,因此仅计算剩余层进行验证。
优化:共享权重、共享 KV 缓存和共享计算
自推测解码显著受益于缓存重用,特别是**KV 缓存**,它存储在起草阶段计算的键值对。此缓存允许模型跳过冗余计算,因为起草和验证阶段都使用相同的早期层。此外,**退出查询缓存**存储退出层中的查询向量,允许验证从起草阶段无缝继续。
与传统的双模型推测解码相比,提前退出自推测解码可以从以下节省中受益:
- 共享权重:重用前 层的权重,用于起草和验证。
- 共享 KV 缓存:重用前 层的键值对,用于起草和验证。
- 共享计算:通过使用**退出查询缓存**(只保存退出层 的查询向量),重用前 层的计算,从而使验证过程无需计算从 到 的层。
KV 和退出查询缓存的结合,称为**KVQ 缓存**,减少了内存开销并改善了推理延迟。
到目前为止,🤗 transformers 库已在此拉取请求中实现了第一项优化(共享权重)。随着使用此方法的模型数量增加,我们将考虑其他优化。如果您感兴趣,请随时提交 PR!
我们能多早退出?
草稿阶段的提前退出层是一个超参数,我们可以在推理过程中进行调整或修改。
- 我们越早退出,草稿 token 的生成速度越快,但准确性越低。
- 我们越晚退出,草稿 token 的准确性越高,但生成速度越慢。
我们编写了一个脚本,以在不同的提前退出层上测量 A100 GPU 上的每秒 token 数。在下表中,我们绘制了不同 Llama 模型(包括 LayerSkip 和基线检查点)的每秒 token 数与提前退出层之间的关系(您可以此处查看完整的日志)。
Llama3.2 1B
Llama3 8B
Code Llama3 34B
Code Llama3 7B
Llama2 70B
Llama2 13B
Llama2 7B
我们可以观察到以下几点:
- 对于未经过 LayerSkip 训练方案预训练或持续预训练的基线检查点,提前退出自推测解码比自回归解码慢。这是因为在大多数 LLM 的训练过程中,早期层没有被激励去学习预测输出,因此使用早期层生成标记的接受率会非常低。
- 另一方面,对于使用 LayerSkip 训练持续预训练的 Llama 检查点,提前退出自推测解码在至少部分层中比自回归解码具有更高的加速效果。
- 对于大多数模型,除了 Llama3.2 1B,我们注意到在遍历层时有一个规律:加速效果在前几层开始较低,逐渐增加到最佳点,然后再次下降。
- 提前退出层的最佳点是我们在高预测准确性和低标记生成开销之间达到最佳权衡的时候。这个最佳点取决于每个模型,也可能取决于提示或提示的领域。
这些观察结果为进一步的实验和探索提供了有趣的机会。我们鼓励读者在这些想法的基础上进行构建,测试变体,并追求自己的研究。这些努力可以带来宝贵的见解,并为该领域做出有意义的贡献。
结论
LayerSkip 利用提前退出、层 dropout 和缓存重用之间的协同作用,创建了一个快速高效的文本生成管道。通过训练模型以反嵌入不同层的输出,并利用缓存优化验证过程,该方法在速度和准确性之间取得了平衡。因此,它显著缩短了大型语言模型的推理时间,同时保持了高质量的输出。由于只使用一个模型作为草稿和验证模型,它还减少了与传统推测解码技术相比的内存消耗。
自推测是一个令人兴奋的领域,同一个 LLM 既可以生成草稿标记,又可以自行修正。其他自推测方法包括:
- 起草与验证 (Draft & Verify):草稿阶段涉及跳过预定的注意力层和前馈层。
- MagicDec:草稿阶段使用 KV 缓存的一个子集,这对于长上下文输入很有用。
- 雅可比解码 (Jacobi Decoding) 和 前瞻解码 (Lookahead Decoding):草稿阶段是一系列“猜测标记”,这些标记可以是随机的,也可以是从 N-gram 查找表中获得的。