通过动态推测实现更快的辅助生成

发布于 2024 年 10 月 8 日
在 GitHub 上更新

⭐ 在这篇博文中,我们将探讨动态推测性解码——一种由 Intel 实验室和 Hugging Face 开发的新型方法,根据任务的不同,它可以将文本生成速度提高多达 2.7 倍。从 Transformers🤗 4.45.0 版本开始,此方法已成为辅助生成 的默认操作模式 ⭐

推测性解码

推测性解码是一种流行的技术,用于在保持准确性的同时加速大型语言模型的推理。如下图所示,推测性解码通过将生成过程分为两个阶段来工作。在第一阶段,一个快速但准确性较低的草稿模型(又称助手)自回归地生成一个令牌序列。在第二阶段,一个大型但准确性更高的目标模型对生成的草稿令牌进行并行验证。这个过程允许目标模型在单次前向传递中生成多个令牌,从而加速自回归解码。推测性解码的成功很大程度上取决于推测前瞻(SL),即草稿模型在每次迭代中生成的令牌数量。在实践中,SL 是一个静态值或基于启发式,两者都不是在推理过程中榨取最大性能的最佳选择。


推测性解码迭代。

动态推测性解码

Transformers🤗 提供了两种不同的方法来确定在推理过程中调整草稿(助手)令牌数量的调度。直接的方法,基于 Leviathan 等人的工作,使用静态的推测前瞻值,并在每次推测迭代中生成恒定数量的候选令牌。另一种方法是基于启发式的方法,它根据当前迭代的接受率调整下一次迭代的候选令牌数量。如果所有推测性令牌都正确,候选令牌的数量就会增加;否则,就会减少。

我们预计,一种增强的优化策略来管理生成的草稿令牌数量可以进一步减少延迟。为了验证这一论点,我们使用一个预言机来确定每次推测迭代的最佳推测前瞻值。预言机利用草稿模型自回归地生成令牌,直到草稿模型和目标模型的预测令牌之间出现差异。这个过程在每次推测迭代中重复进行,最终确定每次迭代接受的最佳(最大)草稿令牌数量。草稿/目标令牌不匹配是通过 Leviathan 等人引入的零温度拒绝采样算法识别的。这个预言机通过在每一步生成最大数量的有效草稿令牌并最小化对草稿模型和目标模型的调用次数,从而充分发挥了推测性解码的潜力。

下图中左侧的图表展示了来自 MBPP 数据集的一个代码生成示例中,在推测迭代中预言机和静态推测前瞻值的情况。观察到预言机推测前瞻值(橙色条)存在高度可变性。静态推测前瞻(蓝色条)中,生成的草稿令牌数量固定为 5,执行了 38 次目标前向传递和 192 次草稿前向传递,而预言机推测前瞻仅执行了 27 次目标前向传递和 129 次草稿前向传递——显著减少。右侧的图表显示了整个 Alpaca 数据集的预言机和静态推测前瞻。


一个 MBPP 示例上的预言机和静态推测前瞻 (SL) 值。


整个 Alpaca 数据集的平均预言机推测前瞻。

两张图都显示了预言机推测前瞻值存在显著差异,表明静态推测前瞻可能不是最优的。

为了更接近 Oracle 并获得额外的加速,我们开发了一种简单的方法,可以在每次迭代中动态调整推测前瞻值。在生成每个草稿令牌后,我们根据助手模型对其预测的置信度(通过 logits 的 softmax 估计)来决定草稿模型是继续生成下一个令牌还是切换到目标模型进行验证。如果助手模型对当前令牌预测的置信度低于预定义的阈值(称为 assistant_confidence_threshold),它将停止该迭代的令牌生成过程,即使尚未达到最大推测令牌数量 num_assistant_tokens。一旦停止,当前迭代中生成的草稿令牌将发送到目标模型进行验证。

基准测试

我们在一系列任务和模型配对上对动态方法和启发式方法进行了基准测试。动态方法在所有测试中都表现出更好的性能。值得注意的是,使用动态方法并将 Llama3.2-1B 作为 Llama3.1-8B 的助手时,我们观察到加速高达 1.52 倍,而启发式方法在相同的设置下没有显示出显著加速。另一个观察结果是,使用启发式方法时,codegen-6B-mono 会导致减速,而动态方法则显示出加速。

目标模型 草稿(助手)模型 任务 加速 - 启发式 加速 - 动态
facebook/opt-6.7b facebook/opt-125m 摘要 1.82倍 2.71倍
facebook/opt-6.7b facebook/opt-125m 开放式生成 1.23倍 1.59倍
Salesforce/codegen-6B-mono Salesforce/codegen-350M-mono 代码生成(Python) 0.89倍 1.09倍
google/flan-t5-xl google/flan-t5-small 摘要 1.18倍 1.31倍
meta-llama/Llama-3.1-8B meta-llama/Llama-3.2-1B 摘要 1.00倍 1.52倍
meta-llama/Llama-3.1-8B meta-llama/Llama-3.2-1B 开放式生成 1.00倍 1.18倍
meta-llama/Llama-3.1-8B meta-llama/Llama-3.2-1B 代码生成(Python) 1.09倍 1.15倍

代码

动态推测已集成到 Hugging Face Transformers 库的 4.45.0 版本中,现在作为辅助解码的默认操作模式。要使用动态推测进行辅助生成,无需更改代码——只需像往常一样执行代码即可

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

prompt = "Alice and Bob"
checkpoint = "EleutherAI/pythia-1.4b-deduped"
assistant_checkpoint = "EleutherAI/pythia-160m-deduped"
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
assistant_model = AutoModelForCausalLM.from_pretrained(assistant_checkpoint).to(device)

outputs = model.generate(**inputs, assistant_model=assistant_model)

默认的动态推测前瞻参数反映了最佳值,但可以通过以下代码进行调整,以提高特定模型对或数据集的性能

# confidence threshold
assistant_model.generation_config.assistant_confidence_threshold=0.4

# 'constant' means that num_assistant_tokens stays unchanged during generation
assistant_model.generation_config.num_assistant_tokens_schedule='constant'

# the maximum number of tokens generated by the assistant model.
# after 20 tokens the draft halts even if the confidence is above the threshold
assistant_model.generation_config.num_assistant_tokens=20

要恢复到启发式常数(如 Leviathan 等人的工作)方法,只需将 num_assistant_tokens_schedule 分别设置为 'heuristic''constant',并按如下所示设置 assistant_confidence_threshold=0num_assistant_tokens=5

# Use 'heuristic' or 'constant' or 'dynamic'
assistant_model.generation_config.num_assistant_tokens_schedule='heuristic'
assistant_model.generation_config.assistant_confidence_threshold=0
assistant_model.generation_config.num_assistant_tokens=5

接下来是什么?

我们引入了一种更快的辅助生成策略,称为动态推测性解码,它优于基于启发式的方法以及生成固定数量候选令牌的方法。

在即将发布的博客文章中,我们将展示一种新的辅助生成方法:将任何目标模型与任何辅助模型结合!这将为加速 Hugging Face Hub 上无数没有足够小的辅助变体的模型打开大门。例如,Phi 3Gemma 2CodeLlama 等等都将有资格进行推测性解码。敬请期待!

参考文献

引用

@article{mamou2024accelerating,
  title={Accelerating Speculative Decoding using Dynamic Speculation Length},
  author={Mamou, Jonathan and Pereg, Oren and Korat, Daniel and Berchansky, Moshe and Timor, Nadav and Wasserblat, Moshe and Schwartz, Roy},
  journal={arXiv preprint arXiv:2405.04304},
  year={2024}
}

社区

你好,
感谢您撰写如此有趣的文章。

我在 Llama-3.1-8B-Instruct/Llama-3.2.1B-Instruct 和 Qwen-2.5-7B-Instruct/Qwen-2.5-0.5B-Instruct 上重现实验时遇到一个问题。

在 A100 40GB GPU 上使用默认设置重现实验时,我发现 Disco 的生成速度比自回归解码慢。

您对此有什么想法吗?

非常感谢。

文章作者

@Minhnt27 👋

您能分享一个最小复现器吗?它将极大地帮助我们找出您遇到的具体问题 🤗

注册登录 评论