KV 缓存解释:优化 Transformer 推理效率
引言
当 AI 模型生成文本时,它们经常重复许多相同的计算,这会降低速度。键值缓存(Key-Value caching)是一种通过记住前一步骤中的重要信息来加速此过程的技术。模型无需从头开始重新计算所有内容,而是重用已计算的内容,从而使文本生成更快、更高效。
在这篇博客文章中,我们将以易于理解的方式分解 KV 缓存,解释其作用,并展示它如何帮助 AI 模型更快地工作。

先决条件
为了完全理解本文内容,读者应熟悉以下内容:
- Transformer 架构:熟悉注意力机制等组件。
- 自回归建模:了解 GPT 等模型如何生成序列。
- 线性代数基础:矩阵乘法和转置等概念,这些对于理解注意力计算至关重要。
这篇 👉 博客 应该涵盖本文所需的大部分先决条件。
标准推理与 KV 缓存的兴起
当模型生成文本时,它会查看所有先前的标记来预测下一个标记。通常,它会为每个新标记重复相同的计算,这会降低速度。
KV 缓存通过记住先前步骤中的这些计算来解决计算重叠问题,这可以通过在推理过程中存储注意力层的中间状态来实现。
KV 缓存如何工作?
分步过程
- 首次生成:当模型看到第一个输入时,它会计算并将其键和值存储在缓存中。
- 接下来的词:对于每个新词,模型检索存储的键和值,并添加新的键和值,而不是重新开始。
- 高效注意力计算:使用缓存的 和 以及新的 (查询) 来计算输出。
- 更新输入:将新生成的标记添加到输入中,然后 直到生成结束。

过程如下所示:
Token 1: [K1, V1] ➔ Cache: [K1, V1]
Token 2: [K2, V2] ➔ Cache: [K1, K2], [V1, V2]
...
Token n: [Kn, Vn] ➔ Cache: [K1, K2, ..., Kn], [V1, V2, ..., Vn]
KV 缓存 | 标准推理 |
---|---|
上表为了更好的视觉效果,我们使用了 ,请注意,这个数字可能比我们展示的要大得多。
比较:KV 缓存与标准推理
以下是 KV 缓存与常规生成方式的比较:
特性 | 标准推理 | KV 缓存 |
---|---|---|
每词计算量 | 模型为每个词重复相同的计算。 | 模型重用过去的计算以获得更快的结果。 |
内存使用 | 每一步使用更少的内存,但内存随文本长度的增加而增长。 | 使用额外内存存储过去的信息,但保持高效。 |
速度 | 文本越长,速度越慢,因为它重复工作。 | 即使文本更长也能保持快速,避免重复工作。 |
效率 | 计算成本高,响应时间慢。 | 更快更高效,因为模型会记住过去的工作。 |
处理长文本 | 由于重复计算,处理长文本时效率低下。 | 非常适合长文本,因为它会记住过去的步骤。 |
KV 缓存对速度和效率影响巨大,特别是对于长文本。通过保存和重用过去的计算,它避免了每次都重新开始,从而比常规文本生成方式快得多。
实际实现
这是一个在 PyTorch 中实现 KV 缓存的简化示例
# Pseudocode for KV Caching in PyTorch
class KVCache:
def __init__(self):
self.cache = {"key": None, "value": None}
def update(self, key, value):
if self.cache["key"] is None:
self.cache["key"] = key
self.cache["value"] = value
else:
self.cache["key"] = torch.cat([self.cache["key"], key], dim=1)
self.cache["value"] = torch.cat([self.cache["value"], value], dim=1)
def get_cache(self):
return self.cache
使用 transformers 库时,此行为通过 use_cache
参数默认启用,您还可以通过 cache_implementation
参数访问多种缓存方法,这是一个极简代码:
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('HuggingFaceTB/SmolLM2-1.7B')
model = AutoModelForCausalLM.from_pretrained('HuggingFaceTB/SmolLM2-1.7B').cuda()
tokens = tokenizer.encode("The red cat was", return_tensors="pt").cuda()
output = model.generate(
tokens, max_new_tokens=300, use_cache = True # by default is set to True
)
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
我们使用 T4 GPU 对上述代码进行带/不带 KV 缓存的基准测试,结果如下:
使用 KV 缓存 | 标准推理 | 加速比 |
---|---|---|
11.7 秒 | 1 分 1 秒 | 快约 5.21 倍 |
结论
KV 缓存是一种简单但功能强大的技术,可帮助 AI 模型更快、更有效地生成文本。通过记住过去的计算而不是重复它们,它减少了预测新词所需的时间和精力。虽然它确实需要额外的内存,但这种方法对于长时间对话特别有用,可确保快速高效的生成。
理解 KV 缓存可以帮助开发人员和 AI 爱好者构建更快、更智能、更可扩展的语言模型,以应用于实际场景。
我衷心感谢 Aritra Roy Gosthipaty 🤗 为本博客文章的开发提供了宝贵的支持、反馈和奉献。