在 nanoVLM 中从零开始实现 KV Cache
TL;DR
我们已经在 nanoVLM 仓库(一个使用纯 PyTorch 训练自己的视觉语言模型的小型代码库)中从零开始实现了 KV 缓存。这使我们的生成速度提升了 38%。在这篇博客文章中,我们将介绍 KV 缓存以及我们在实现它时获得的所有经验。所学到的经验是通用的,可以应用于所有自回归语言模型的生成。在一个小型代码库上从零开始实现是一个很好的学习经验,欢迎加入!
引言
自回归语言模型通过**一次采样一个 token** 来生成文本。在推理过程中,模型处理给定的输入序列,预测下一个 token,将其附加到序列中,并重复此过程直到满足某个停止条件。
这种逐步生成本质上是顺序的。
- 为了生成 token ,模型必须考虑从 到 的整个序列。在上述示例中, 将是
the
,而所有之前的 token 到 将是[What, is, in]
。 - 尽管 Transformer 内部是并行的,但每个新的预测都需要对所有 Transformer 层进行一次完整的正向传播,这会带来与序列长度呈二次方的内存/计算开销。
这种重复也会导致计算上的**冗余**。在这篇文章中,我们将探讨**KV 缓存**,这是一种缓解这种低效率的优化技术。
目录
重温 Transformer 架构
在深入探讨缓存之前,让我们回顾一下 Transformer 模型中注意力的运作方式。Transformer 语言模型由堆叠层组成,每层包含:
- 多头自注意力
- 前馈网络 (MLP)
- 残差连接和层归一化
为了理解**KV 缓存的帮助之处**,我们重点关注**自注意力**机制,特别是单个注意力头内部。
让我们通过一个简单的 PyTorch 实现来可视化关键计算。
import torch
input_seq_length = 5
dim_model = 10
input_ids_emb = torch.randn(input_seq_length, dim_model)
W_q = torch.randn(dim_model, dim_model)
W_k = torch.randn(dim_model, dim_model)
W_v = torch.randn(dim_model, dim_model)
Q = input_ids_emb @ W_q
K = input_ids_emb @ W_k
V = input_ids_emb @ W_v
自注意力计算
对于 个输入嵌入序列,表示为 ,自注意力计算如下:
- ,其中
- ,其中
- ,其中
- 因果掩码 用于防止访问未来 token。
最终输出为
这是一个使用因果掩码的最小 PyTorch 等效实现。
import torch.nn.functional as F
import math
d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)
# Lower triangular mask to prevent future token access
causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length))
masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf'))
attention_weights = F.softmax(masked_scores, dim=-1)
output = attention_weights @ V
冗余之处
在自回归生成中,模型每次生成一个 token。在每一步中,它都会**为整个序列**重新计算 、 和 ,即使较早的 token 并未改变。
new_token_emb = torch.randn(1, dim_model)
extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0)
Q_ext = extended_input @ W_q
K_ext = extended_input @ W_k
V_ext = extended_input @ W_v
# (output_ext would be computed using Q_ext, K_ext, V_ext + masking)
为了确认冗余
torch.testing.assert_close(K, K_ext[:input_seq_length]) # test pass
torch.testing.assert_close(V, V_ext[:input_seq_length]) # test pass
这些检查表明,对于除最新 token 之外的所有 token, 和 与先前计算的值相同。
Original (5×5): Extended (6×6):
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ → ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ ■ □
□ □ □ □ □ □
- **■** = 已计算并重复使用
- **□** = 不必要地重新计算
大部分注意力计算被不必要地重复。随着序列的增长,这会变得更加昂贵。
KV 缓存如何解决它
为了消除这种低效率,我们使用 **KV 缓存**:
- 在处理完初始提示后,我们**缓存**每个层计算出的键 () 和值 ()。
- 在生成过程中,我们**只计算新 token 的** **和** ,并**将其附加**到缓存中。
- 我们计算当前 token 的 ,并将其与**缓存的 和 ** 一起使用以获得输出。
这使得生成从全序列重新计算变为轻量级的增量更新。
✅ 在实践中,此缓存是一个逐层字典,包含“key”和“value”,每个形状为 (
batch_size
,num_heads
,seq_len_cached
,head_dim
)。
这是现代 LLM 如何高效生成长输出的基础。
nanoVLM 中的 KV 缓存:从理论到实践
既然我们已经理解了 KV 缓存背后的理论,接下来让我们看看它在我们的 nanoVLM 仓库中是如何实际实现的。这是一个理想的测试平台,因为它是一个超级简洁且自包含的代码库。
KV 缓存体现在我们模型的三个关键组件中:
- 使用和更新 KV 缓存的**注意力块**
- 跟踪每层缓存的**语言模型**
- 区分**预填充**(使用输入提示的初始传递)和顺序**解码**阶段的**生成循环**
1. 在注意力块中更新 KV 缓存
在 `LanguageModelGroupedAttention` 类中,我们修改了 `forward` 函数,使其接受并更新键和值(`block_kv_cache`)的缓存。
以前,模型在每个生成步骤都会重新计算 和 。现在我们只计算当前 token 的 和 ,并将其附加到缓存的值中。
def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
is_prefill = block_kv_cache is None
B, T_curr, C = x.size()
# Project inputs to Q, K, V
q_curr, k_curr, v_curr = project_current_tokens(x)
q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)
if not is_prefill and block_kv_cache['key'] is not None:
# Append new keys and values to the cache
k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
else:
# First pass (prefill) — no cache
k, v = k_rotated, v_curr
block_kv_cache = {'key': k, 'value': v}
return attention_output, block_kv_cache
2. 跨层跟踪缓存
在 `LanguageModel` 类中,我们引入了**逐层缓存跟踪**。`start_pos` 参数有助于模型为新生成的 token 计算正确的**旋转位置编码**。
def forward(self, x, kv_cache=None, start_pos=0):
T_curr = x.size(1)
position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device)
cos, sin = self.rotary_embd(position_ids)
for i, block in enumerate(self.blocks):
# Pass per-layer KV cache
x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])
return x, kv_cache
- `kv_cache`:一个字典列表,每个 transformer 层一个,存储着先前的键和值。
- `start_pos`:确保旋转嵌入与当前生成索引对齐。
3. 生成循环中的预填充与解码
`VisionLanguageModel` 的 `generate()` 方法发生了最大的架构变化。
我们将**生成分为两个阶段**:
- **预填充阶段:**编码完整提示并构建初始缓存。
- **解码阶段:**使用缓存的键/值一次生成一个 token。
PREFILL PHASE (cache construction)
[Prompt: "What is"] → [Transformer] → [Cache: K, V for all layers]
DECODE PHASE (token-by-token)
[Token: "the"] → [Q("the") + cached K/V] → [next token: "?"] → ...
相应的代码如下:
# PREFILL: Process the input prompt, fill the cache
prompt_output, kv_cache_list = self.forward(
inputs,
kv_cache=None,
start_pos=0
)
# DECODE: Generate one token at a time using cached K/V
for i in range(max_new_tokens):
next_token = sample_from(prompt_output)
decode_output, kv_cache_list = self.forward(
next_token,
kv_cache=kv_cache_list,
start_pos=current_position # updated with each step
)
prompt_output = decode_output
通过分离这些阶段,我们避免了冗余计算,并显著加快了推理速度,特别是对于长提示。
更改总结
模块 | 原始行为 | 新行为 |
---|---|---|
LanguageModelGroupedAttention.forward |
每步重新计算 、、 | 使用并更新 KV 缓存 |
LanguageModel.forward |
没有之前的状态记忆 | 跟踪逐层 KV 缓存,处理 `start_pos` |
VisionLanguageModel.generate |
单阶段生成循环 | 分为**预填充**和**解码**阶段 |
总结:KV 缓存的重要性
益处 | 说明 |
---|---|
增量增长 | 缓存每增加一个新 token 就增加一行 |
位置感知解码 | `start_pos` 确保位置编码计算的正确性 |
效率 | 将每个 token 的推理时间复杂度从二次方降低到 O(`seq len`) |
KV 缓存消除了自回归生成过程中不必要的计算,从而实现了更快、更高效的推理,尤其是在长序列和实时应用中。这是速度与内存之间的权衡,其缺点可能是代码更复杂,并限制了更高级的推理方案,如束搜索等。KV 缓存是加速 LLM 推理的一种流行方法,使得它们可以在消费级硬件上运行,现在你也知道它是如何工作的了!