Transformers 文档
缓存
并获得增强的文档体验
开始使用
缓存
想象一下你正在和某人交谈,但对方每次回应时,都要从头开始记住你们之前的对话。这会很慢而且效率低下,对吧?
你可以将这个类比延伸到 Transformer 模型。自回归模型生成可能很慢,因为它一次只做一个预测。每一次新预测都依赖于所有先前的上下文。
为了预测第 1000 个 token,模型需要来自先前 999 个 token 的信息。这些信息以 token 表示的矩阵乘法来表示。
为了预测第 1001 个 token,你需要来自先前 999 个 token 的相同信息,以及来自第 1000 个 token 的任何信息。这是模型需要为每个 token 一遍又一遍计算的大量矩阵乘法!
键值(KV)缓存通过存储先前处理过的 token 的注意力层派生的 kv 对来消除这种低效率。存储的 kv 对从缓存中检索并用于后续 token,从而避免了重新计算的需要。
缓存仅应用于推理。如果在训练期间启用缓存,可能会导致意外错误。
为了更好地理解缓存是如何以及为什么工作的,让我们仔细看看注意力矩阵的结构。
注意力矩阵
缩放点积注意力的计算方式如下,其中批次大小为 b,注意力头数量为 h,当前序列长度为 T,每个注意力头的维度为 d_head。
查询 (Q)、键 (K) 和值 (V) 矩阵是从形状为 (b, h, T, d_head) 的输入嵌入中投影出来的。
对于因果注意力,掩码会阻止模型关注未来的 token。一旦一个 token 被处理,它相对于未来 token 的表示就不会改变,这意味着和可以被缓存并重用,以计算最后一个 token 的表示。
在推理时,您只需要最后一个 token 的查询来计算表示来预测下一个 token $ t+1 $。每一步,新的键和值向量都会被存储在缓存中,并追加到过去的键和值中。
注意力在模型的每一层中独立计算,缓存也是按层进行的。
请参阅下表,了解缓存如何提高效率。
| 无缓存 | 有缓存 |
|---|---|
每一步,重新计算所有先前的 K 和 V | 每一步,仅计算当前的 K 和 V |
| 每步的注意力成本与序列长度呈二次方关系 | 每步的注意力成本与序列长度呈线性关系(内存呈线性增长,但每 token 的计算量保持较低) |
缓存类
基本的 KV 缓存接口接收当前 token 的键和值张量,并返回更新后的 K 和 V 张量。这由模型内部的 forward 方法管理。
new_K, new_V = cache.update(k_t, v_t, layer_idx) attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
当您使用 Transformers 的 Cache 类时,自注意力模块会执行几个关键步骤来集成过去和现在的信息。
注意力模块将当前的 kv 对与缓存中存储的 past kv 对连接起来。这会创建形状为
(new_tokens_length, past_kv_length + new_tokens_length)的注意力权重。当前和过去的 kv 对基本上被组合起来计算注意力分数,从而确保模型了解先前的上下文和当前的输入。当
forward方法被迭代调用时,注意力掩码的形状必须与 past 和 current kv 对的组合长度匹配至关重要。注意力掩码应具有形状(batch_size, past_kv_length + new_tokens_length)。这通常在 generate() 中内部处理,但如果您想使用 Cache 实现自己的生成循环,请记住这一点!注意力掩码应包含 past 和 current token 的值。了解
cache_position也很重要。如果您想使用forward方法重新使用预填充的 Cache,这一点很重要,因为您必须传递一个有效的cache_position值。这表示序列中的输入位置。cache_position不受填充的影响,并且对于每个 token 总是额外增加一个位置。例如,如果一个 kv 缓存包含 10 个 token——无论是否有填充 token——那么下一个 token 的缓存位置应该是torch.tensor([10])。
缓存存储实现
缓存被结构化为层列表,其中每层包含一个键和一个值缓存。键和值缓存是形状为 [batch_size, num_heads, seq_len, head_dim] 的张量。
层可以有不同的类型(例如 DynamicLayer、StaticLayer、StaticSlidingWindowLayer),这主要改变了序列长度的处理方式以及缓存的更新方式。
最简单的是 DynamicLayer,它随着处理更多的 token 而增长。序列长度维度 (seq_len) 随着每个新 token 的增加而增加。
cache.layers[idx].keys = torch.cat([cache.layers[idx].keys, key_states], dim=-2)
cache.layers[idx].values = torch.cat([cache.layers[idx].values, value_states], dim=-2)其他层类型,如 StaticLayer 和 StaticSlidingWindowLayer,具有固定的序列长度,该长度在创建缓存时设置。这使它们与 torch.compile 兼容。对于 StaticSlidingWindowLayer,当添加新 token 时,现有 token 会被移出缓存。
下面的示例演示了如何使用 DynamicCache 创建一个生成循环。如前所述,注意力掩码是 past 和 current token 值的连接,并且 1 被添加到下一个 token 的缓存位置。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator
device = Accelerator().device
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache(config=model.config)
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
max_new_tokens = 10
for _ in range(max_new_tokens):
outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
# Greedily sample one next token
next_token_ids = outputs.logits[:, -1:].argmax(-1)
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
# Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
# and expanding attn mask for the new token, as explained above
attention_mask = inputs["attention_mask"]
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
cache_position = cache_position[-1:] + 1 # add one more position for the next token
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA,"缓存位置
缓存位置跟踪在哪里将新 token 插入到注意力缓存中。它表示上下文中每个 token 的绝对位置,与填充或批次结构无关。假设您已经缓存了 N 个 token,现在正在处理 K 个新 token。新 token 的缓存位置将从 N 到 N + K - 1。换句话说,您正在处理位置为 - [N, N + 1, N + 2, ..., N + K - 1] 的 token。
缓存位置在内部用于两个目的:
- 选择要处理的新 token 并确保只有尚未缓存的 token 被传递给模型的
forward方法。 - 将键/值对存储在缓存的正确位置。这对于固定大小的缓存尤其重要,它预先分配了特定的缓存长度。
生成循环通常会处理缓存位置,但如果您正在编写自定义生成方法,请确保缓存位置准确,因为它们用于将键/值状态写入固定槽或从中读取。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from accelerate import Accelerator
device = Accelerator().device
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
messages = [{"role": "user", "content": "You are a helpful assistant."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=10)