KV 缓存解释:优化 Transformer 推理效率

社区文章 发布于 2025 年 1 月 30 日

引言

当 AI 模型生成文本时,它们经常重复许多相同的计算,这会降低速度。键值缓存(Key-Value caching)是一种通过记住前一步骤中的重要信息来加速此过程的技术。模型无需从头开始重新计算所有内容,而是重用已计算的内容,从而使文本生成更快、更高效。

在这篇博客文章中,我们将以易于理解的方式分解 KV 缓存,解释其作用,并展示它如何帮助 AI 模型更快地工作。

先决条件

为了完全理解本文内容,读者应熟悉以下内容:

  1. Transformer 架构:熟悉注意力机制等组件。
  2. 自回归建模:了解 GPT 等模型如何生成序列。
  3. 线性代数基础:矩阵乘法和转置等概念,这些对于理解注意力计算至关重要。

这篇 👉 博客 应该涵盖本文所需的大部分先决条件。

点击此处查看一些最重要的要点。
  • 注意力权重形状为 [batch,h,Seqlen,Seqlen] [\text{batch}, h, \mathrm{Seq}_{\mathrm{len}}, \mathrm{Seq}_{\mathrm{len}}]
  • 带掩码多头注意力允许每个标记由其自身和所有先前的标记表示。
  • 为了生成新标记,模型需要查看所有先前的标记及其由其前置标记表示的内容。

标准推理与 KV 缓存的兴起

当模型生成文本时,它会查看所有先前的标记来预测下一个标记。通常,它会为每个新标记重复相同的计算,这会降低速度。

KV 缓存通过记住先前步骤中的这些计算来解决计算重叠问题,这可以通过在推理过程中存储注意力层的中间状态来实现。

KV 缓存如何工作?

分步过程

  1. 首次生成:当模型看到第一个输入时,它会计算并将其键和值存储在缓存中。 \Downarrow
  2. 接下来的词:对于每个新词,模型检索存储的键和值,并添加新的键和值,而不是重新开始。
  3. 高效注意力计算:使用缓存的 KKVV 以及新的 QQ (查询) 来计算输出。
  4. 更新输入:将新生成的标记添加到输入中,然后 返回  2 \texttt{go back to step 2} 直到生成结束。

过程如下所示:

Token 1: [K1, V1] ➔ Cache: [K1, V1]
Token 2: [K2, V2] ➔ Cache: [K1, K2], [V1, V2]
...
Token n: [Kn, Vn] ➔ Cache: [K1, K2, ..., Kn], [V1, V2, ..., Vn]
KV 缓存 标准推理

上表为了更好的视觉效果,我们使用了 dk=5d_k = 5 ,请注意,这个数字可能比我们展示的要大得多。

比较:KV 缓存与标准推理

以下是 KV 缓存与常规生成方式的比较:

特性 标准推理 KV 缓存
每词计算量 模型为每个词重复相同的计算。 模型重用过去的计算以获得更快的结果。
内存使用 每一步使用更少的内存,但内存随文本长度的增加而增长。 使用额外内存存储过去的信息,但保持高效。
速度 文本越长,速度越慢,因为它重复工作。 即使文本更长也能保持快速,避免重复工作。
效率 计算成本高,响应时间慢。 更快更高效,因为模型会记住过去的工作。
处理长文本 由于重复计算,处理长文本时效率低下。 非常适合长文本,因为它会记住过去的步骤。

KV 缓存对速度效率影响巨大,特别是对于长文本。通过保存和重用过去的计算,它避免了每次都重新开始,从而比常规文本生成方式快得多。

实际实现

这是一个在 PyTorch 中实现 KV 缓存的简化示例

# Pseudocode for KV Caching in PyTorch
class KVCache:
    def __init__(self):
        self.cache = {"key": None, "value": None}

    def update(self, key, value):
        if self.cache["key"] is None:
            self.cache["key"] = key
            self.cache["value"] = value
        else:
            self.cache["key"] = torch.cat([self.cache["key"], key], dim=1)
            self.cache["value"] = torch.cat([self.cache["value"], value], dim=1)

    def get_cache(self):
        return self.cache

使用 transformers 库时,此行为通过 use_cache 参数默认启用,您还可以通过 cache_implementation 参数访问多种缓存方法,这是一个极简代码:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-1.7B')
model = AutoModelForCausalLM.from_pretrained('HuggingFaceTB/SmolLM2-1.7B').cuda()

tokens = tokenizer.encode("The red cat was", return_tensors="pt").cuda()
output = model.generate(
    tokens, max_new_tokens=300, use_cache = True # by default is set to True
)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]

我们使用 T4 GPU 对上述代码进行带/不带 KV 缓存的基准测试,结果如下:

使用 KV 缓存 标准推理 加速比
11.7 秒 1 分 1 秒 快约 5.21 倍

结论

KV 缓存是一种简单但功能强大的技术,可帮助 AI 模型更快、更有效地生成文本。通过记住过去的计算而不是重复它们,它减少了预测新词所需的时间和精力。虽然它确实需要额外的内存,但这种方法对于长时间对话特别有用,可确保快速高效的生成。

理解 KV 缓存可以帮助开发人员和 AI 爱好者构建更快、更智能、更可扩展的语言模型,以应用于实际场景。

我衷心感谢 Aritra Roy Gosthipaty 🤗 为本博客文章的开发提供了宝贵的支持、反馈和奉献。

参考文献和扩展阅读

  1. Transformer KV 缓存解释
  2. Transformer 键值缓存解释
  3. 掌握 LLM 技术:推理优化
  4. Hugging Face 文档 - Transformer 中的 KV 缓存

社区

这可以用于图像生成模型吗?(我不是程序员:- 或 AI 专家)

·
此评论已隐藏(标记为垃圾邮件)

感谢 Linux 工具 的解释

很好的参考资料,谢谢分享。

·
文章作者

非常感谢你的赞美 (≧∇≦)ノ✨

注册登录 评论