Transformers 文档

缓存策略

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

缓存策略

键值(KV)向量用于计算注意力分数。对于自回归模型,由于模型一次预测一个 token,因此 KV 分数会在每次计算时都进行计算。每次预测都依赖于前一个 token,这意味着模型每次都会执行相同的计算。

KV 缓存 会存储这些计算结果,以便在不重新计算的情况下重复使用。高效的缓存对于优化模型性能至关重要,因为它可以减少计算时间并提高响应速率。有关缓存工作原理的更详细说明,请参阅缓存文档。

Transformers 提供了多种 Cache 类,用于实现不同的缓存机制。其中一些 Cache 类经过优化,可以节省内存,而另一些则旨在最大限度地提高生成速度。请参阅下表,比较不同缓存类型,并帮助您为用例选择最佳缓存。

缓存类型 支持滑动层 支持卸载 支持 torch.compile() 预期内存使用量
动态缓存 中型
静态缓存
量化缓存 否   

本指南将向您介绍不同的 Cache 类,并演示如何使用它们进行生成。

默认缓存

DynamicCache 是所有模型的默认缓存类。它允许缓存大小动态增长,以便在生成过程中存储越来越多的键值。

请注意,对于使用滑动窗口注意力(Mistral、Gemma2 等)或分块注意力(Llama4)的模型,缓存将在使用这些注意力类型的层达到最大尺寸(滑动窗口或分块大小)时停止增长。

通过在 generate() 中配置 use_cache=False 来禁用缓存。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", dtype=torch.float16, device_map="auto")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)

也可以在调用并将其传递给模型的 past_key_values 参数之前,先初始化缓存类。这对于更细粒度的控制或更高级的用法(如上下文缓存)非常有用。

在大多数情况下,在 cache_implementation 参数中定义缓存策略更为容易。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", dtype=torch.float16, device_map="auto")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

past_key_values = DynamicCache(config=model.config)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=past_key_values)

固定大小缓存

默认的 DynamicCache 由于缓存大小不是固定的,因此无法利用大多数即时(JIT)优化。JIT 优化可以让你在内存使用方面付出代价来最大化延迟。以下所有缓存类型都兼容 JIT 优化,例如 torch.compile,以加速生成。

固定大小缓存(StaticCache)会为 kv 对预分配一个特定的最大缓存大小。你可以在不修改的情况下生成多达最大缓存大小的内容。但是,键/值状态具有固定(通常很大)的大小意味着在生成过程中,许多 token 将被掩码,因为它们不应参与注意力计算。因此,这个技巧可以轻松地 compile 解码阶段,但会浪费 token 在注意力计算中。就像所有事物一样,这是一个权衡,如果你以长度大致相同的多个序列生成,它会非常好,但如果你有一个非常长的序列,然后只有短序列(因为固定缓存大小会很大,短序列会浪费很多),它可能不是最佳选择。使用它时请务必理解其影响!

DynamicCache 类似,请注意,对于使用滑动窗口注意力(Mistral、Gemma2 等)或分块注意力(Llama4)的模型,即使指定的最大长度更大,缓存也不会超过使用这些注意力类型的层的滑动窗口/分块大小。

可以通过在 generate() 中配置 cache_implementation="static" 来启用 StaticCache。这还将自动启用对贪婪和采样解码策略的解码阶段进行编译

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"

缓存卸载

KV 缓存可能占用大量内存,并成为长上下文生成的瓶颈。内存高效缓存侧重于牺牲速度以换取更少的内存使用。这对于大型语言模型(LLM)以及硬件内存受限的情况尤其重要。

卸载缓存可以通过将除一层外的模型层的 KV 缓存移动到 CPU 来节省 GPU 内存。在模型通过各层进行 forward 迭代期间,只有当前层的缓存保留在 GPU 上。它将异步预取下一层的缓存,并在注意力计算完成后将当前层的缓存发送回 CPU。

如果您使用的 GPU 较小且遇到内存不足(OOM)错误,则可以考虑卸载。

与完整的设备内缓存相比,生成吞吐量可能会略有下降,具体取决于您的模型和生成选项(上下文大小、生成的 token 数量、beam 数量等)。这是因为来回移动键/值状态需要一些工作。

卸载功能可用于 DynamicCacheStaticCache。您可以通过在 GenerationConfiggenerate() 中配置 cache_implementation="offloaded"(动态版本)或 cache_implementation="offloaded_static"(静态版本)来启用它。此外,您还可以先实例化自己的 DynamicCacheStaticCache,并设置 offloading=True 选项,然后将其传递给 generate 或模型的 forward(例如,对于动态缓存,past_key_values=DynamicCache(config=model.config, offloading=True))。

请注意,上面提到的两个 Cache 类在直接实例化时还有一个附加选项 offload_only_non_sliding。这个附加参数决定了使用滑动窗口/分块注意力的层(如果有)是否也会被卸载。由于这些层通常很短,最好避免卸载它们,因为卸载可能会导致速度下降。默认情况下,此选项对 DynamicCacheFalse,对 StaticCacheTrue

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=torch.float16, device_map="auto")
inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.

下面的示例展示了当内存不足时如何回退到卸载缓存。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import Accelerator

def resilient_generate(model, *args, **kwargs):
    oom = False
    device = Accelerator().device
    torch_device_module = getattr(torch, device, torch.cuda)
    try:
        return model.generate(*args, **kwargs)
    except torch.OutOfMemoryError as e:
        print(e)
        print("retrying with cache_implementation='offloaded'")
        oom = True
    if oom:
        torch_device_module.empty_cache()
        kwargs["cache_implementation"] = "offloaded"
        return model.generate(*args, **kwargs)

ckpt = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, dtype=torch.float16, device_map="auto")
prompt = ["okay "*1000 + "Fun fact: The most"]
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
beams = { "num_beams": 40, "num_return_sequences": 20, "max_new_tokens": 23, "early_stopping": True, }
out = resilient_generate(model, **inputs, **beams)
responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True)

量化缓存

QuantizedCache 通过将 KV 值量化为较低精度来减少内存需求。 QuantizedCache 目前支持两种量化后端:

  • hqq 支持 int2、int4 和 int8 数据类型。
  • quanto 支持 int2 和 int4 数据类型。这是默认的量化后端。

量化缓存可能会损害延迟,如果上下文长度较短且有足够的 GPU 内存可用于在不启用缓存量化的情况下进行生成。尝试在内存效率和延迟之间找到平衡。

通过在 GenerationConfig 中配置 cache_implementation="quantized" 来启用 QuantizedCache,并且量化后端以及任何额外的量化相关参数也应以字典形式传递。您应该使用这些附加参数的默认值,除非您遇到内存不足的情况。在这种情况下,请考虑减小残差长度。

<hfoptions id="quantized-cache">

对于 hqq 后端,我们建议将 axis-keyaxis-value 参数设置为 1

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, QuantizedCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", dtype=torch.float16, device_map="auto")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"backend": "hqq"})
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

对于 quanto 后端,我们建议将 axis-keyaxis-value 参数设置为 0

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", dtype=torch.float16, device_map="auto")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

编码器-解码器缓存

EncoderDecoderCache 专为编码器-解码器模型设计。它管理自注意力和交叉注意力缓存,以确保先前 kv 对的存储和检索。可以为编码器和解码器单独设置不同的缓存类型。

此缓存类型不需要任何设置。它只是一个包装了 2 个 Cache(如上所述)的包装器,模型将直接独立使用它们。

特定于模型的缓存

某些模型存储过去 kv 对或状态的方式独特,与其他缓存类不兼容。

Mamba 模型(例如 Mamba)需要特定的缓存,因为该模型没有注意力机制或 kv 状态。因此,它们与上述 Cache 类不兼容。

迭代生成

缓存也可以在迭代生成场景中使用,即与模型进行来回交互(聊天机器人)。与常规生成一样,迭代生成配合缓存可以使模型高效地处理持续的对话,而无需在每一步都重新计算整个上下文。

对于带有缓存的迭代生成,首先初始化一个空的缓存类,然后您可以输入新的提示。使用 聊天模板 来跟踪对话历史。

以下示例演示了 Llama-2-7b-chat-hf。如果您使用的是不同的聊天式模型,apply_chat_template() 可能会以不同的方式处理消息。它可能会根据 Jinja 模板的编写方式截断重要 token。

例如,某些模型在推理过程中使用特殊的 <think> ... </think> token。这些 token 在重新编码时可能会丢失,导致索引问题。您可能需要手动删除或调整完成中的额外 token 以保持稳定性。

import torch
from transformers import AutoTokenizer,AutoModelForCausalLM, DynamicCache, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)

user_prompts = ["Hello, what's your name?", "Btw, yesterday I was on a rock concert."]

past_key_values = DynamicCache(config=model.config)

messages = []
for prompt in user_prompts:
    messages.append({"role": "user", "content": prompt})
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
    input_length = inputs["input_ids"].shape[1]
    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
    completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
    messages.append({"role": "assistant", "content": completion})

预填充缓存(前缀缓存)

在某些情况下,您可能希望使用特定前缀提示填充 Cache,并将其重用于生成不同的序列。

下面的示例初始化了一个 StaticCache,然后缓存了一个初始提示。现在您可以从预填充的提示中生成多个序列。

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16, device_map={"": 0})
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
prompt_cache = StaticCache(config=model.config, max_cache_len=1024)

INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to(model.device.type)
# This is the common prompt cached, we need to run forward without grad to be able to copy
with torch.no_grad():
     prompt_cache = model(**inputs_initial_prompt, past_key_values = prompt_cache).past_key_values

prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
responses = []
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to(model.device.type)
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20)
    response = tokenizer.batch_decode(outputs)[0]
    responses.append(response)

print(responses)
在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.