Transformers 文档

KV 缓存策略

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始

KV 缓存策略

键值 (KV) 向量用于计算注意力分数。对于自回归模型,每次都会计算 KV 分数,因为模型一次预测一个 token。每次预测都依赖于之前的 token,这意味着模型每次都执行相同的计算。

KV 缓存存储这些计算结果,以便可以重用它们而无需重新计算。高效的缓存对于优化模型性能至关重要,因为它减少了计算时间并提高了响应速度。有关缓存工作原理的更详细说明,请参阅 缓存 文档。

Transformers 提供了几个 Cache 类,它们实现了不同的缓存机制。其中一些 Cache 类经过优化以节省内存,而另一些则旨在最大限度地提高生成速度。请参考下表比较缓存类型,并使用它来帮助您为您的用例选择最佳缓存。

缓存类型 内存效率   支持 torch.compile() 建议初始化 延迟 长上下文生成
动态缓存
静态缓存
卸载缓存
卸载静态缓存
量化缓存
滑动窗口缓存
Sink 缓存

本指南向您介绍不同的 Cache 类,并向您展示如何使用它们进行生成。

默认缓存

DynamicCache 是大多数模型的默认缓存类。它允许缓存大小动态增长,以便随着生成的进行存储越来越多的键和值。

通过在 generate() 中配置 use_cache=False 来禁用缓存。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

model.generate(**inputs, do_sample=False, max_new_tokens=20, use_cache=False)

缓存类也可以先初始化,然后再调用并将其传递给模型的 past_key_values 参数。这种缓存初始化策略仅建议用于某些缓存类型。

在大多数其他情况下,在 cache_implementation 参数中定义缓存策略更容易。

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

past_key_values = DynamicCache()
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=past_key_values)

内存高效缓存

KV 缓存可能占用大量内存,并成为 瓶颈 长上下文生成的瓶颈。内存高效缓存侧重于牺牲速度以减少内存使用量。这对于大型语言模型 (LLMs) 以及硬件内存受限的情况尤其重要。

卸载缓存

OffloadedCache 通过将大多数模型层的 KV 缓存移动到 CPU 来节省 GPU 内存。在模型在各层上进行 forward 迭代期间,只有当前层缓存保留在 GPU 上。OffloadedCache 异步预取下一层缓存并将上一层缓存发送回 CPU。

这种缓存策略始终生成与 DynamicCache 相同的结果,并且可以作为直接替换或回退方案。如果您有 GPU 并且遇到内存不足 (OOM) 错误,您可能需要使用 OffloadedCache

DynamicCache 相比,您可能会注意到生成吞吐量略有下降,具体取决于您的模型和生成选择(上下文大小、生成的 token 数量、beam 数量等)。

通过在 GenerationConfiggenerate() 中配置 cache_implementation="offloaded" 来启用 OffloadedCache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

ckpt = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("Fun fact: The shortest", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=23, cache_implementation="offloaded")
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
Fun fact: The shortest war in history was between Britain and Zanzibar on August 27, 1896.

下面的示例展示了当您内存不足时如何回退到 OffloadedCache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def resilient_generate(model, *args, **kwargs):
    oom = False
    try:
        return model.generate(*args, **kwargs)
    except torch.cuda.OutOfMemoryError as e:
        print(e)
        print("retrying with cache_implementation='offloaded'")
        oom = True
    if oom:
        torch.cuda.empty_cache()
        kwargs["cache_implementation"] = "offloaded"
        return model.generate(*args, **kwargs)

ckpt = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16).to("cuda:0")
prompt = ["okay "*1000 + "Fun fact: The most"]
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
beams = { "num_beams": 40, "num_beam_groups": 40, "num_return_sequences": 40, "diversity_penalty": 1.0, "max_new_tokens": 23, "early_stopping": True, }
out = resilient_generate(model, **inputs, **beams)
responses = tokenizer.batch_decode(out[:,-28:], skip_special_tokens=True)

量化缓存

QuantizedCache 通过将 KV 值量化为较低的精度来减少内存需求。QuantizedCache 目前支持两个量化后端。

如果上下文长度较短并且有足够的 GPU 内存可用于生成而无需启用缓存量化,则量化缓存可能会损害延迟。尝试在内存效率和延迟之间找到平衡。

通过在 GenerationConfig 中配置 cache_implementation="quantized" 来启用 QuantizedCache,并在 QuantizedCacheConfig 中指示量化后端。任何其他与量化相关的参数也应作为 dict 或 QuantizedCacheConfig 的实例传递。除非您遇到内存不足的情况,否则应使用这些附加参数的默认值。在这种情况下,请考虑减少残差长度。

HQQQuantizedCache
Quanto

对于 HQQQuantizedCache,我们建议将 axis-keyaxis-value 参数设置为 1

from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"axis-key": 1, "axis-value": 1, "backend": "hqq"})
print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel

Sink 缓存

SinkCache 能够生成非常长的序列(根据论文为“无限长度”),方法是仅保留序列中的几个初始 token。这些被称为 sink token,因为它们在生成过程中占注意力分数的很大一部分。后续的 token 在滑动窗口的基础上被丢弃,并且仅保留最新的 window_size 个 token。这意味着大部分先前的知识被丢弃。

Sink token 允许模型即使在处理非常长的文本序列时也能保持稳定的性能。

通过首先使用 window_lengthnum_sink_tokens 参数初始化 SinkCache,然后再将其传递给 past_key_valuesgenerate() 中来启用 SinkCache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("This is a long story about unicorns, fairies and magic.", return_tensors="pt").to(model.device)

past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
out = model.generate(**inputs, do_sample=False, max_new_tokens=30, past_key_values=past_key_values)
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"This is a long story about unicorns, fairies and magic. It is a fantasy world where unicorns and fairies live together in harmony. The story follows a young girl named Lily"

速度优化缓存

默认的 DynamicCache 阻止您利用即时 (JIT) 优化,因为缓存大小不是固定的。JIT 优化使您能够以牺牲内存使用量为代价来最大限度地提高延迟。以下所有缓存类型都与 JIT 优化(如 torch.compile)兼容,以加速生成。

静态缓存

StaticCache 为 kv 对预先分配一个特定的最大缓存大小。您可以生成最多达到最大缓存大小的内容,而无需修改它。

通过在 generate() 中配置 cache_implementation="static" 来启用 StaticCache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="static")
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"

卸载静态缓存

OffloadedStaticCache卸载缓存 非常相似,不同之处在于缓存大小设置为最大缓存大小。否则,OffloadedStaticCache 仅将当前层缓存保留在 GPU 上,其余部分移动到 CPU。

通过在 generate() 中配置 cache_implementation="offloaded_static" 来启用 OffloadedStaticCache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="offloaded_static")
tokenizer.batch_decode(out, skip_special_tokens=True)[0]
"Hello, my name is [Your Name], and I am a [Your Profession] with [Number of Years] of"

缓存卸载需要 CUDA GPU。

滑动窗口缓存

SlidingWindowCache 在之前的 kv 对上实现滑动窗口,并且仅保留最后 sliding_window 个 token。这种缓存类型旨在仅与支持滑动窗口注意力的模型一起使用,例如 Mistral。较旧的 kv 状态被丢弃并由新的 kv 状态替换。

通过在 generate() 中配置 cache_implementation="sliding_window" 来启用 SlidingWindowCache

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("Yesterday I was on a rock concert and.", return_tensors="pt").to(model.device)

out = model.generate(**inputs, do_sample=False, max_new_tokens=30, cache_implementation="sliding_window")
tokenizer.batch_decode(out, skip_special_tokens=True)[0]

模型缓存

某些模型类型,如编码器-解码器模型或 Gemma2Mamba,具有专用的缓存类。

编码器-解码器缓存

EncoderDecoderCache 专为编码器-解码器模型设计。它管理自注意力缓存和交叉注意力缓存,以确保存储和检索之前的 kv 对。可以为编码器和解码器分别设置不同的缓存类型。

这种缓存类型不需要任何设置。在调用 generate() 或模型的 forward 方法时可以使用它。

EncoderDecoderCache 目前仅支持 Whisper

模型特定缓存

某些模型具有独特的存储过去 kv 对或状态的方式,这与其他任何缓存类都不兼容。

Gemma2 需要 HybridCache,它在底层结合使用了 SlidingWindowCache 用于滑动窗口注意力,以及 StaticCache 用于全局注意力。

Mamba 需要 MambaCache,因为该模型没有注意力机制或 kv 状态。

迭代生成

缓存也可以在迭代生成设置中使用,在这种设置中,模型之间存在来回交互(聊天机器人)。与常规生成一样,使用缓存的迭代生成允许模型有效地处理正在进行的对话,而无需在每个步骤中重新计算整个上下文。

对于使用缓存的迭代生成,首先初始化一个空的缓存类,然后您可以输入新的提示。使用 聊天模板 跟踪对话历史记录。

如果您正在使用 SinkCache,则需要将输入截断为最大长度,因为 SinkCache 可以生成超过其最大窗口大小的文本。但是,第一个输入不应超过最大缓存长度。

下面的示例演示了如何使用缓存进行迭代生成。

import torch
from transformers import AutoTokenizer,AutoModelForCausalLM
from transformers.cache_utils import (
    DynamicCache,
    SinkCache,
    StaticCache,
    SlidingWindowCache,
    QuantoQuantizedCache,
    QuantizedCacheConfig,
)

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map='auto')
tokenizer = AutoTokenizer.from_pretrained(model_id)

user_prompts = ["Hello, what's your name?", "Btw, yesterday I was on a rock concert."]

past_key_values = DynamicCache()
max_cache_length = past_key_values.get_max_length()

messages = []
for prompt in user_prompts:
    messages.append({"role": "user", "content": prompt})
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)
    if isinstance(past_key_values, SinkCache):
        inputs = {k: v[:, -max_cache_length:] for k, v in inputs.items()}
    input_length = inputs["input_ids"].shape[1]
    outputs = model.generate(**inputs, do_sample=False, max_new_tokens=256, past_key_values=past_key_values)
    completion = tokenizer.decode(outputs[0, input_length: ], skip_special_tokens=True)
    messages.append({"role": "assistant", "content": completion})

预填充缓存

在某些情况下,您可能希望使用特定前缀提示的键值对来填充 Cache,并复用它来生成不同的序列。

下面的示例初始化了一个 StaticCache,然后缓存了一个初始提示。现在您可以从预填充的提示生成多个序列。

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Init StaticCache with big enough max-length (1024 tokens for the below example) 
# You can also init a DynamicCache, if that suits you better
prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)

INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
# This is the common prompt cached, we need to run forward without grad to be able to copy
with torch.no_grad():
     prompt_cache = model(**inputs_initial_prompt, past_key_values = prompt_cache).past_key_values

prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
responses = []
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values,max_new_tokens=20) 
    response = tokenizer.batch_decode(outputs)[0]
    responses.append(response)

print(responses)
< > 在 GitHub 上更新