Transformers 文档

缓存

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

缓存

想象一下你正在和某人聊天,他们没有记住之前说过的话,而是每次你回应时都必须从头开始。这会非常慢且效率低下,对吗?

你可以将这个类比扩展到 transformer 模型。自回归模型生成可能很慢,因为它一次预测一个 token。每个新的预测都依赖于所有先前的上下文。

为了预测第 1000 个 token,模型需要来自前 999 个 token 的信息。这些信息表示为 token 表示之间的矩阵乘法。

为了预测第 1001 个 token,除了来自第 1000 个 token 的任何信息之外,你还需要来自前 999 个 token 的相同信息。模型必须为每个 token 反复计算大量的矩阵乘法!

键值 (KV) 缓存通过存储从先前处理的 token 的注意力层派生出的 kv 对来消除这种低效率。存储的 kv 对从缓存中检索并重新用于后续 token,避免了重新计算的需要。

缓存只应用于推理。如果在训练期间启用它,可能会导致意外错误。

为了更好地理解缓存的工作原理和原因,让我们仔细看看注意力矩阵的结构。

注意力矩阵

批处理大小为 b,注意力头数为 h,到目前为止的序列长度为 T,每个注意力头的维度为 d_head缩放点积注意力计算如下:Attention(Q,K,V)=softmax(QKdhead×mask)V \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^\top}{\sqrt{d_{\text{head}}}} \times \text{mask} \right) V

查询 (Q)、键 (K) 和值 (V) 矩阵是输入嵌入的投影,形状为 (b, h, T, d_head)

对于因果注意力,掩码阻止模型关注未来的 token。一旦 token 被处理,它的表示相对于未来的 token 就不会改变,这意味着Kpast K_{\text{past}} Vpast V_{\text{past}} 可以缓存并重新用于计算最后一个 token 的表示。Attention(qt,[k1,k2,,kt1cached,kt],[v1,v2,,vt1cached,vt]) \text{Attention}(q_t, [\underbrace{k_1, k_2, \dots, k_{t-1}}_{\text{cached}}, k_{t}], [\underbrace{v_1, v_2, \dots, v_{t-1}}_{\text{cached}}, v_{t}])

在推理时,你只需要最后一个 token 的查询来计算表示xt x_t ,它预测下一个 tokent+1 t+1 。在每个步骤中,新的键和值向量都被存储在缓存中,并附加到过去的键和值中。Kcacheconcat(Kpast,kt),Vcacheconcat(Vpast,vt) K_{\text{cache}} \leftarrow \text{concat}(K_{\text{past}}, k_t), \quad V_{\text{cache}} \leftarrow \text{concat}(V_{\text{past}}, v_t)

注意力在模型的每一层独立计算,并且缓存是逐层进行的。

请参考下表比较缓存如何提高效率。

不使用缓存 使用缓存
对于每个步骤,重新计算所有先前的 KV 对于每个步骤,仅计算当前的 KV
每个步骤的注意力成本与序列长度呈二次关系 每个步骤的注意力成本与序列长度呈线性关系(内存线性增长,但计算/token 保持较低)

缓存类

一个基本的 KV 缓存接口接收当前 token 的键张量和值张量,并返回更新后的 KV 张量。这由模型的 forward 方法内部管理。

new_K, new_V = cache.update(k_t, v_t, layer_idx)
attn_output = attn_layer_idx_fn(q_t, new_K, new_V)

当你使用 Transformers 的 Cache 类时,自注意力模块执行几个关键步骤来整合过去和现在的信息。

  1. 注意力模块将当前 kv 对与缓存中存储的过去 kv 对连接起来。这会创建形状为 (new_tokens_length, past_kv_length + new_tokens_length) 的注意力权重。当前和过去的 kv 对本质上是组合起来计算注意力分数,确保模型了解以前的上下文和当前输入。

  2. forward 方法迭代调用时,注意力掩码的形状与过去和当前 kv 对的组合长度匹配至关重要。注意力掩码的形状应为 (batch_size, past_kv_length + new_tokens_length)。这通常在 generate() 中内部处理,但如果你想使用 Cache 实现自己的生成循环,请记住这一点!注意力掩码应包含过去和当前 token 值。

  3. 同样重要的是要注意 cache_position。如果你想用 forward 方法重用预填充的 Cache,这很重要,因为你必须传递一个有效的 cache_position 值。这表示序列中的输入位置。cache_position 不受填充影响,并且它总是为每个 token 增加一个位置。例如,如果 kv 缓存包含 10 个 token - 无论填充 token 如何 - 下一个 token 的缓存位置应该是 torch.tensor([10])

缓存存储实现

键值对的实际存储在不同的缓存实现之间有所不同。例如,考虑 DynamicCache

DynamicCache 中,键值对作为两个张量列表存储。列表中的每个张量都具有形状 [batch_size, num_heads, seq_len, head_dim]

  • key_cache:一个张量列表,每层一个。
  • value_cache:一个张量列表,每层一个。

当处理新 token 时

  1. 对于每一层,新的键和值状态与现有缓存连接。
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
  1. 随着更多 token 的处理,缓存会动态增长。序列长度维度(seq_len)随着每个新 token 的增加而增加。

  2. 缓存通过 self._seen_tokens 维护已看到的 token 计数。当第一层处理新 token 时,此计数会更新。

以下示例演示了如何使用 DynamicCache 创建生成循环。如前所述,注意力掩码是过去和当前 token 值的连接,并且为下一个 token 的缓存位置添加 1

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)

past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to("cuda:0")

generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device="cuda:0")
max_new_tokens = 10

for _ in range(max_new_tokens):
    outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    # Greedily sample one next token
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
    # Prepare inputs for the next generation step by leaving unprocessed tokens, in our case we have only one new token
    # and expanding attn mask for the new token, as explained above
    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
    cache_position = cache_position[-1:] + 1 # add one more position for the next token

print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
"[INST] Hello, what's your name. [/INST]  Hello! My name is LLaMA,"

传统缓存格式

Cache 类之前,缓存通常以张量元组的元组形式存储。这种格式是动态的,因为它会随着文本的生成而增长,类似于 DynamicCache

传统格式本质上是相同的数据结构,但组织方式不同。

  • 它是一个元组的元组,其中每个内部元组包含一层的键和值张量。
  • 张量具有相同的形状 [batch_size, num_heads, seq_len, head_dim]
  • 这种格式灵活性较低,不支持量化或卸载等功能。

如果你的项目依赖于此传统格式,你可以使用 from_legacy_cache()DynamicCache.to_legacy_cache() 函数在 DynamicCache 和元组的元组之间进行转换。如果你有用于以特定格式操作缓存的自定义逻辑,这将很有帮助。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

# `return_dict_in_generate=True` is required to return the cache and `return_legacy_cache` forces the returned cache
# in the legacy format
generation_outputs = model.generate(**inputs, return_dict_in_generate=True, return_legacy_cache=True, max_new_tokens=5)

cache = DynamicCache.from_legacy_cache(generation_outputs.past_key_values)
legacy_format_cache = cache.to_legacy_cache()
< > 在 GitHub 上更新