Transformers 文档
缓存
并获得增强的文档体验
开始使用
缓存
想象一下你正在和某人聊天,他们没有记住之前说过的话,而是每次你回应时都必须从头开始。这会非常慢且效率低下,对吗?
你可以将这个类比扩展到 transformer 模型。自回归模型生成可能很慢,因为它一次预测一个 token。每个新的预测都依赖于所有先前的上下文。
为了预测第 1000 个 token,模型需要来自前 999 个 token 的信息。这些信息表示为 token 表示之间的矩阵乘法。
为了预测第 1001 个 token,除了来自第 1000 个 token 的任何信息之外,你还需要来自前 999 个 token 的相同信息。模型必须为每个 token 反复计算大量的矩阵乘法!
键值 (KV) 缓存通过存储从先前处理的 token 的注意力层派生出的 kv 对来消除这种低效率。存储的 kv 对从缓存中检索并重新用于后续 token,避免了重新计算的需要。
缓存只应用于推理。如果在训练期间启用它,可能会导致意外错误。
为了更好地理解缓存的工作原理和原因,让我们仔细看看注意力矩阵的结构。
注意力矩阵
批处理大小为 b
,注意力头数为 h
,到目前为止的序列长度为 T
,每个注意力头的维度为 d_head
的缩放点积注意力计算如下:
查询 (Q
)、键 (K
) 和值 (V
) 矩阵是输入嵌入的投影,形状为 (b, h, T, d_head)
。
对于因果注意力,掩码阻止模型关注未来的 token。一旦 token 被处理,它的表示相对于未来的 token 就不会改变,这意味着和可以缓存并重新用于计算最后一个 token 的表示。
在推理时,你只需要最后一个 token 的查询来计算表示,它预测下一个 token。在每个步骤中,新的键和值向量都被存储在缓存中,并附加到过去的键和值中。
注意力在模型的每一层独立计算,并且缓存是逐层进行的。
请参考下表比较缓存如何提高效率。
不使用缓存 | 使用缓存 |
---|---|
对于每个步骤,重新计算所有先前的 K 和 V | 对于每个步骤,仅计算当前的 K 和 V |
每个步骤的注意力成本与序列长度呈二次关系 | 每个步骤的注意力成本与序列长度呈线性关系(内存线性增长,但计算/token 保持较低) |
缓存类
一个基本的 KV 缓存接口接收当前 token 的键张量和值张量,并返回更新后的 K
和 V
张量。这由模型的 forward
方法内部管理。
new_K, new_V = cache.update(k_t, v_t, layer_idx) attn_output = attn_layer_idx_fn(q_t, new_K, new_V)
当你使用 Transformers 的 Cache 类时,自注意力模块执行几个关键步骤来整合过去和现在的信息。
注意力模块将当前 kv 对与缓存中存储的过去 kv 对连接起来。这会创建形状为
(new_tokens_length, past_kv_length + new_tokens_length)
的注意力权重。当前和过去的 kv 对本质上是组合起来计算注意力分数,确保模型了解以前的上下文和当前输入。当
forward
方法迭代调用时,注意力掩码的形状与过去和当前 kv 对的组合长度匹配至关重要。注意力掩码的形状应为(batch_size, past_kv_length + new_tokens_length)
。这通常在 generate() 中内部处理,但如果你想使用 Cache 实现自己的生成循环,请记住这一点!注意力掩码应包含过去和当前 token 值。同样重要的是要注意
cache_position
。如果你想用forward
方法重用预填充的 Cache,这很重要,因为你必须传递一个有效的cache_position
值。这表示序列中的输入位置。cache_position
不受填充影响,并且它总是为每个 token 增加一个位置。例如,如果 kv 缓存包含 10 个 token - 无论填充 token 如何 - 下一个 token 的缓存位置应该是torch.tensor([10])
。
缓存存储实现
键值对的实际存储在不同的缓存实现之间有所不同。例如,考虑 DynamicCache。
在 DynamicCache 中,键值对作为两个张量列表存储。列表中的每个张量都具有形状 [batch_size, num_heads, seq_len, head_dim]
。
key_cache
:一个张量列表,每层一个。value_cache
:一个张量列表,每层一个。
当处理新 token 时
- 对于每一层,新的键和值状态与现有缓存连接。
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
随着更多 token 的处理,缓存会动态增长。序列长度维度(
seq_len
)随着每个新 token 的增加而增加。缓存通过
self._seen_tokens
维护已看到的 token 计数。当第一层处理新 token 时,此计数会更新。
以下示例演示了如何使用 DynamicCache 创建生成循环。如前所述,注意力掩码是过去和当前 token 值的连接,并且为下一个 token 的缓存位置添加 1
。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)
past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")
generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
max_new_tokens = 10
for _ in range(max_new_tokens):
outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
# Greedily sample one next token
next_token_ids = outputs.logits[:, -1:].argmax(-1)
generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
# Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
# and expanding attn mask for the new token, as explained above
attention_mask = inputs["attention_mask"]
attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
cache_position = cache_position[-1:] + 1 # add one more position for the next token
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST] Hello! My name is LLaMA,"
传统缓存格式
在 Cache 类之前,缓存通常以张量元组的元组形式存储。这种格式是动态的,因为它会随着文本的生成而增长,类似于 DynamicCache。
传统格式本质上是相同的数据结构,但组织方式不同。
- 它是一个元组的元组,其中每个内部元组包含一层的键和值张量。
- 张量具有相同的形状
[batch_size, num_heads, seq_len, head_dim]
。 - 这种格式灵活性较低,不支持量化或卸载等功能。
如果你的项目依赖于此传统格式,你可以使用 from_legacy_cache() 和 DynamicCache.to_legacy_cache() 函数在 DynamicCache 和元组的元组之间进行转换。如果你有用于以特定格式操作缓存的自定义逻辑,这将很有帮助。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
# `return_dict_in_generate=True` is required to return the cache and `return_legacy_cache` forces the returned cache
# in the legacy format
generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)
cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_format_cache = cache.to_legacy_cache()