在 nanoVLM 中从零开始实现 KV Cache

发布于 2025 年 6 月 4 日
在 GitHub 上更新

TL;DR

我们已经在 nanoVLM 仓库(一个使用纯 PyTorch 训练自己的视觉语言模型的小型代码库)中从零开始实现了 KV 缓存。这使我们的生成速度提升了 38%。在这篇博客文章中,我们将介绍 KV 缓存以及我们在实现它时获得的所有经验。所学到的经验是通用的,可以应用于所有自回归语言模型的生成。在一个小型代码库上从零开始实现是一个很好的学习经验,欢迎加入!

bar plot showcasing improvement in generation speed

引言

自回归语言模型通过**一次采样一个 token** 来生成文本。在推理过程中,模型处理给定的输入序列,预测下一个 token,将其附加到序列中,并重复此过程直到满足某个停止条件。

diagram for autoregression

这种逐步生成本质上是顺序的。

  • 为了生成 token ti+1 t_{i+1} ,模型必须考虑从 t0 t_0 ti t_i 的整个序列。在上述示例中,ti+1 t_{i+1} 将是 the,而所有之前的 token t0 t_0 ti t_i 将是 [What, is, in]
  • 尽管 Transformer 内部是并行的,但每个新的预测都需要对所有 Transformer 层进行一次完整的正向传播,这会带来与序列长度呈二次方的内存/计算开销。

这种重复也会导致计算上的**冗余**。在这篇文章中,我们将探讨**KV 缓存**,这是一种缓解这种低效率的优化技术。

目录

重温 Transformer 架构

在深入探讨缓存之前,让我们回顾一下 Transformer 模型中注意力的运作方式。Transformer 语言模型由堆叠层组成,每层包含:

  • 多头自注意力
  • 前馈网络 (MLP)
  • 残差连接和层归一化

为了理解**KV 缓存的帮助之处**,我们重点关注**自注意力**机制,特别是单个注意力头内部。

让我们通过一个简单的 PyTorch 实现来可视化关键计算。

import torch

input_seq_length = 5
dim_model = 10

input_ids_emb = torch.randn(input_seq_length, dim_model)
W_q = torch.randn(dim_model, dim_model)
W_k = torch.randn(dim_model, dim_model)
W_v = torch.randn(dim_model, dim_model)

Q = input_ids_emb @ W_q
K = input_ids_emb @ W_k
V = input_ids_emb @ W_v

自注意力计算

对于 T T 个输入嵌入序列,表示为 XRT×D X \in \mathbb{R}^{T \times D} ,自注意力计算如下:

  • Q=XWQ Q = XW_Q ,其中 WQRD×Dq W_Q \in \mathbb{R}^{D \times D_q}
  • K=XWK K = XW_K ,其中 WKRD×Dk W_K \in \mathbb{R}^{D \times D_k}
  • V=XWV V = XW_V ,其中 WVRD×Dv W_V \in \mathbb{R}^{D \times D_v}
  • 因果掩码 M M 用于防止访问未来 token。

最终输出为

Attention(X;Q,K,V)=softmax(QKMdk)V \text{Attention}(X; Q, K, V) = \text{softmax}\left( \frac{QK^\top \cdot M}{\sqrt{d_k}} \right)V

这是一个使用因果掩码的最小 PyTorch 等效实现。

import torch.nn.functional as F
import math

d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)

# Lower triangular mask to prevent future token access
causal_mask = torch.tril(torch.ones(input_seq_length, input_seq_length))
masked_scores = attention_scores.masked_fill(causal_mask == 0, float('-inf'))

attention_weights = F.softmax(masked_scores, dim=-1)
output = attention_weights @ V

冗余之处

在自回归生成中,模型每次生成一个 token。在每一步中,它都会**为整个序列**重新计算 Q Q K K V V ,即使较早的 token 并未改变。

new_token_emb = torch.randn(1, dim_model)
extended_input = torch.cat([input_ids_emb, new_token_emb], dim=0)

Q_ext = extended_input @ W_q
K_ext = extended_input @ W_k
V_ext = extended_input @ W_v

# (output_ext would be computed using Q_ext, K_ext, V_ext + masking)

为了确认冗余

torch.testing.assert_close(K, K_ext[:input_seq_length]) # test pass
torch.testing.assert_close(V, V_ext[:input_seq_length]) # test pass

这些检查表明,对于除最新 token 之外的所有 token,K K V V 与先前计算的值相同。

Original (5×5):         Extended (6×6):
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■    →         ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
■ ■ ■ ■ ■              ■ ■ ■ ■ ■ □
                       □ □ □ □ □ □
  • **■** = 已计算并重复使用
  • **□** = 不必要地重新计算

大部分注意力计算被不必要地重复。随着序列的增长,这会变得更加昂贵。

KV 缓存如何解决它

为了消除这种低效率,我们使用 **KV 缓存**:

  • 在处理完初始提示后,我们**缓存**每个层计算出的键 (K K ) 和值 (V V )。
  • 在生成过程中,我们**只计算新 token 的** K K **和** V V ,并**将其附加**到缓存中。
  • 我们计算当前 token 的 Q Q ,并将其与**缓存的 K K V V ** 一起使用以获得输出。

这使得生成从全序列重新计算变为轻量级的增量更新。

✅ 在实践中,此缓存是一个逐层字典,包含“key”和“value”,每个形状为 (batch_size, num_heads, seq_len_cached, head_dim)。

这是现代 LLM 如何高效生成长输出的基础。

nanoVLM 中的 KV 缓存:从理论到实践

既然我们已经理解了 KV 缓存背后的理论,接下来让我们看看它在我们的 nanoVLM 仓库中是如何实际实现的。这是一个理想的测试平台,因为它是一个超级简洁且自包含的代码库。

KV 缓存体现在我们模型的三个关键组件中:

  1. 使用和更新 KV 缓存的**注意力块**
  2. 跟踪每层缓存的**语言模型**
  3. 区分**预填充**(使用输入提示的初始传递)和顺序**解码**阶段的**生成循环**

1. 在注意力块中更新 KV 缓存

在 `LanguageModelGroupedAttention` 类中,我们修改了 `forward` 函数,使其接受并更新键和值(`block_kv_cache`)的缓存。

以前,模型在每个生成步骤都会重新计算 K K V V 。现在我们只计算当前 token 的 Knew K_{\text{new}} Vnew V_{\text{new}} ,并将其附加到缓存的值中。

def forward(self, x, cos, sin, attention_mask=None, block_kv_cache=None):
    is_prefill = block_kv_cache is None
    B, T_curr, C = x.size()

    # Project inputs to Q, K, V
    q_curr, k_curr, v_curr = project_current_tokens(x)
    q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

    if not is_prefill and block_kv_cache['key'] is not None:
        # Append new keys and values to the cache
        k = torch.cat([block_kv_cache['key'], k_rotated], dim=2)
        v = torch.cat([block_kv_cache['value'], v_curr], dim=2)
    else:
        # First pass (prefill) — no cache
        k, v = k_rotated, v_curr

    block_kv_cache = {'key': k, 'value': v}
    return attention_output, block_kv_cache

2. 跨层跟踪缓存

在 `LanguageModel` 类中,我们引入了**逐层缓存跟踪**。`start_pos` 参数有助于模型为新生成的 token 计算正确的**旋转位置编码**。

def forward(self, x, kv_cache=None, start_pos=0):
    T_curr = x.size(1)
    position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device)
    cos, sin = self.rotary_embd(position_ids)

    for i, block in enumerate(self.blocks):
        # Pass per-layer KV cache
        x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])

    return x, kv_cache
  • `kv_cache`:一个字典列表,每个 transformer 层一个,存储着先前的键和值。
  • `start_pos`:确保旋转嵌入与当前生成索引对齐。

3. 生成循环中的预填充与解码

`VisionLanguageModel` 的 `generate()` 方法发生了最大的架构变化。

我们将**生成分为两个阶段**:

  • **预填充阶段:**编码完整提示并构建初始缓存。
  • **解码阶段:**使用缓存的键/值一次生成一个 token。
PREFILL PHASE (cache construction)
[Prompt: "What is"] → [Transformer] → [Cache: K, V for all layers]

DECODE PHASE (token-by-token)
[Token: "the"] → [Q("the") + cached K/V] → [next token: "?"] → ...

相应的代码如下:

# PREFILL: Process the input prompt, fill the cache
prompt_output, kv_cache_list = self.forward(
    inputs,
    kv_cache=None,
    start_pos=0
)

# DECODE: Generate one token at a time using cached K/V
for i in range(max_new_tokens):
    next_token = sample_from(prompt_output)

    decode_output, kv_cache_list = self.forward(
        next_token,
        kv_cache=kv_cache_list,
        start_pos=current_position  # updated with each step
    )

    prompt_output = decode_output

通过分离这些阶段,我们避免了冗余计算,并显著加快了推理速度,特别是对于长提示。

更改总结

模块 原始行为 新行为
LanguageModelGroupedAttention.forward 每步重新计算 Q Q K K V V 使用并更新 KV 缓存
LanguageModel.forward 没有之前的状态记忆 跟踪逐层 KV 缓存,处理 `start_pos`
VisionLanguageModel.generate 单阶段生成循环 分为**预填充**和**解码**阶段

总结:KV 缓存的重要性

益处 说明
增量增长 缓存每增加一个新 token 就增加一行
位置感知解码 `start_pos` 确保位置编码计算的正确性
效率 将每个 token 的推理时间复杂度从二次方降低到 O(`seq len`)

KV 缓存消除了自回归生成过程中不必要的计算,从而实现了更快、更高效的推理,尤其是在长序列和实时应用中。这是速度与内存之间的权衡,其缺点可能是代码更复杂,并限制了更高级的推理方案,如束搜索等。KV 缓存是加速 LLM 推理的一种流行方法,使得它们可以在消费级硬件上运行,现在你也知道它是如何工作的了!

社区

感谢这篇精彩的文章!我从 nanoVLM 项目中学到了很多。
我不是生成式 AI 方面的专家,但我注意到注意力计算示例似乎缺少缩放 √(d_k)。这是为了简化而故意省略的吗?

d_k = K.shape[-1]
attention_scores = (Q @ K.T) / math.sqrt(d_k)

据我理解,这种缩放可以防止点积变得过大,并控制 softmax 区域。

·
文章作者

这个发现太棒了!

你愿意为博客文章提交一个包含这些更改的 PR 吗?

这是博客文章的源代码:https://github.com/huggingface/blog/blob/main/kv-cache.md

读得真好,我发现预填充和解码的解释非常直观。干得漂亮 👏

这是我用 kv 缓存时注意力机制内部发生的情况的可视化表示。
我想与社区分享 🤗

·
文章作者

太酷了!感谢分享。

注册登录 以评论