Transformers 文档

缓存

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

缓存

想象一下你正在和某人交谈,但对方每次回应时,都要从头开始记住你们之前的对话。这会很慢而且效率低下,对吧?

你可以将这个类比延伸到 Transformer 模型。自回归模型生成可能很慢,因为它一次只做一个预测。每一次新预测都依赖于所有先前的上下文。

为了预测第 1000 个 token,模型需要来自先前 999 个 token 的信息。这些信息以 token 表示的矩阵乘法来表示。

为了预测第 1001 个 token,你需要来自先前 999 个 token 的相同信息,以及来自第 1000 个 token 的任何信息。这是模型需要为每个 token 一遍又一遍计算的大量矩阵乘法!

键值(KV)缓存通过存储先前处理过的 token 的注意力层派生的 kv 对来消除这种低效率。存储的 kv 对从缓存中检索并用于后续 token,从而避免了重新计算的需要。

缓存仅应用于推理。如果在训练期间启用缓存,可能会导致意外错误。

为了更好地理解缓存是如何以及为什么工作的,让我们仔细看看注意力矩阵的结构。

注意力矩阵

缩放点积注意力的计算方式如下,其中批次大小为 b,注意力头数量为 h,当前序列长度为 T,每个注意力头的维度为 d_headAttention(Q,K,V)=softmax(QKdhead×mask)V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V

查询 (Q)、键 (K) 和值 (V) 矩阵是从形状为 (b, h, T, d_head) 的输入嵌入中投影出来的。

对于因果注意力,掩码会阻止模型关注未来的 token。一旦一个 token 被处理,它相对于未来 token 的表示就不会改变,这意味着Kpast K_{\text{past}} Vpast V_{\text{past}} 可以被缓存并重用,以计算最后一个 token 的表示。Attention(qt,[k1,k2,,kt1cached,kt],[v1,v2,,vt1cached,vt]) \text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])

在推理时,您只需要最后一个 token 的查询来计算表示xt x_t 来预测下一个 token $ t+1 $。每一步,新的键和值向量都会被存储在缓存中,并追加到过去的键和值中。Kcacheconcat(Kpast,kt),Vcacheconcat(Vpast,vt) K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)

注意力在模型的每一层中独立计算,缓存也是按层进行的。

请参阅下表,了解缓存如何提高效率。

无缓存 有缓存
每一步,重新计算所有先前的 KV 每一步,仅计算当前的 KV
每步的注意力成本与序列长度呈二次方关系 每步的注意力成本与序列长度呈线性关系(内存呈线性增长,但每 token 的计算量保持较低)

缓存类

基本的 KV 缓存接口接收当前 token 的键和值张量,并返回更新后的 KV 张量。这由模型内部的 forward 方法管理。

new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)

当您使用 Transformers 的 Cache 类时,自注意力模块会执行几个关键步骤来集成过去和现在的信息。

  1. 注意力模块将当前的 kv 对与缓存中存储的 past kv 对连接起来。这会创建形状为 (new_tokens_length, past_kv_length + new_tokens_length) 的注意力权重。当前和过去的 kv 对基本上被组合起来计算注意力分数,从而确保模型了解先前的上下文和当前的输入。

  2. forward 方法被迭代调用时,注意力掩码的形状必须与 past 和 current kv 对的组合长度匹配至关重要。注意力掩码应具有形状 (batch_size, past_kv_length + new_tokens_length)。这通常在 generate() 中内部处理,但如果您想使用 Cache 实现自己的生成循环,请记住这一点!注意力掩码应包含 past 和 current token 的值。

  3. 了解 cache_position也很重要。如果您想使用 forward 方法重新使用预填充的 Cache,这一点很重要,因为您必须传递一个有效的 cache_position 值。这表示序列中的输入位置。cache_position不受填充的影响,并且对于每个 token 总是额外增加一个位置。例如,如果一个 kv 缓存包含 10 个 token——无论是否有填充 token——那么下一个 token 的缓存位置应该是 torch.tensor([10])

缓存存储实现

缓存被结构化为层列表,其中每层包含一个键和一个值缓存。键和值缓存是形状为 [batch_size, num_heads, seq_len, head_dim] 的张量。

层可以有不同的类型(例如 DynamicLayerStaticLayerStaticSlidingWindowLayer),这主要改变了序列长度的处理方式以及缓存的更新方式。

最简单的是 DynamicLayer,它随着处理更多的 token 而增长。序列长度维度 (seq_len) 随着每个新 token 的增加而增加。

cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)

其他层类型,如 StaticLayerStaticSlidingWindowLayer,具有固定的序列长度,该长度在创建缓存时设置。这使它们与 torch.compile 兼容。对于 StaticSlidingWindowLayer,当添加新 token 时,现有 token 会被移出缓存。

下面的示例演示了如何使用 DynamicCache 创建一个生成循环。如前所述,注意力掩码是 past 和 current token 值的连接,并且 1 被添加到下一个 token 的缓存位置。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator

device = Accelerator().device

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

past_key_values = DynamicCache(config=model.config)
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)

generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
max_new_tokens = 10

for _ in range(max_new_tokens):
    outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    # Greedily sample one next token
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
    # Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
    # and expanding attn mask for the new token, as explained above
    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
    cache_position = cache_position[-1:] + 1 # add one more position for the next token

print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST]  Hello! My name is LLaMA,"

缓存位置

缓存位置跟踪在哪里将新 token 插入到注意力缓存中。它表示上下文中每个 token 的绝对位置,与填充或批次结构无关。假设您已经缓存了 N 个 token,现在正在处理 K 个新 token。新 token 的缓存位置将从 NN + K - 1。换句话说,您正在处理位置为 - [N, N + 1, N + 2, ..., N + K - 1] 的 token。

缓存位置在内部用于两个目的:

  1. 选择要处理的新 token 并确保只有尚未缓存的 token 被传递给模型的 forward 方法。
  2. 将键/值对存储在缓存的正确位置。这对于固定大小的缓存尤其重要,它预先分配了特定的缓存长度。

生成循环通常会处理缓存位置,但如果您正在编写自定义生成方法,请确保缓存位置准确,因为它们用于将键/值状态写入固定槽或从中读取。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator

device = Accelerator().device

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)

messages = [{"role": "user", "content": "You are a helpful assistant."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10)
在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.