Transformers 文档
连续批处理
并获得增强的文档体验
开始使用
连续批处理 (Continuous batching)
持续批处理通过在每个生成步骤动态重新调度批次来最大限度地提高 GPU 利用率。当请求完成时,新请求会立即加入,而无需等待整个批次完成。GPU 始终保持满载,吞吐量保持在较高水平。
对于生产部署,请使用 transformers serve。它构建在 ContinuousBatchingManager 之上,并提供了一个兼容 OpenAI 的 HTTP 端点。
generate_batch
持续批处理通过 generate_batch() 提供支持。传入一个标记化(tokenized)提示词列表,并在所有请求完成后获取结果。generate_batch 在内部处理调度,并会阻塞直到所有请求完成。
对于服务和流式处理用例,请直接使用 ContinuousBatchingManager 来管理请求。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import ContinuousBatchingConfig, GenerationConfig
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-4B",
attn_implementation="flash_attention_2",
device_map="cuda",
dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
prompts = [
"Whats up?",
"Name a cat breed.",
"Write a detailed history of quantum mechanics.",
]
inputs = [tokenizer.encode(p) for p in prompts]
generation_config = GenerationConfig(
max_new_tokens=64,
eos_token_id=tokenizer.eos_token_id,
)
outputs = model.generate_batch(inputs=inputs, generation_config=generation_config)
for request_id, output in outputs.items():
text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
print(f"[{request_id}] {text}")ContinuousBatchingManager
ContinuousBatchingManager 运行一个后台线程,允许您独立提交请求并检索结果。在每个生成步骤中,它都会检查已完成的请求并调度新请求加入批次。这对于流式传输、实时服务或按请求到达顺序提交请求非常有用。
使用 continuous_batching_context_manager() 来安全地启动和停止管理器。下面的示例包含长度不等的输入。一旦最短的提示词生成完毕,它就会离开批次,而较长的提示词则继续生成。使用静态批处理时,您必须将它们填充(pad)到相同的长度。持续批处理可以释放已完成的提示词,以便您立即开始处理下一个提示词。
with model.continuous_batching_context_manager(generation_config=generation_config) as manager:
manager.add_request(
input_ids=tokenizer.encode("Write a detailed history of quantum mechanics."),
request_id="long",
max_new_tokens=512,
)
manager.add_request(
input_ids=tokenizer.encode("What's up?"),
request_id="short_0",
max_new_tokens=32,
)
manager.add_request(
input_ids=tokenizer.encode("Name a cat breed."),
request_id="short_1",
max_new_tokens=32,
)
for result in manager:
text = tokenizer.decode(result.generated_tokens, skip_special_tokens=True)
print(f"[{result.request_id}] {text}")您也可以调用 init_continuous_batching() 自行管理生命周期。
manager = model.init_continuous_batching(generation_config=generation_config)
manager.start()
# submit and retrieve requests...调用 ContinuousBatchingManager.stop() 来终止管理器。
manager.stop()
添加请求
add_request() 用于提交单个请求。提供一个 request_id,或者让管理器自动生成一个。
manager.add_request(input_ids=input_ids, request_id="my_request")add_requests() 用于一次性提交一个批次。它会自动对输入进行排序,以便在启用块共享(block sharing)时最大限度地提高前缀缓存命中率。
manager.add_requests(inputs=inputs)
使用 cancel_request() 取消请求。
manager.cancel_request(request_id="my_request")单请求采样参数
启用 per_request_processors 可以在同一前向传递中为每个请求独立应用 temperature、top_k 和 top_p,从而允许不同请求使用不同的采样参数(例如:创意型的高温输出与精确型的低温输出)。
cb_config = ContinuousBatchingConfig(per_request_processors=True)
# each request gets its own sampling parameters
manager.add_request(input_ids=inputs_a, temperature=0.9, top_p=0.95)
manager.add_request(input_ids=inputs_b, temperature=0.1, top_k=10)GenerationConfig 中的每个参数都必须是非默认值,以便在运行时创建相应的 Logits 处理器。例如,将 temperature 设置为 None 或 1 以外的值,以支持单请求温度控制。此后仍然可以创建温度为 1 的请求。
检索结果
遍历管理器以接收陆续到达的结果。
for result in manager:
print(tokenizer.decode(result.generated_tokens, skip_special_tokens=True))get_result() 从输出队列中获取下一个结果。传递 request_id 以筛选特定请求。如果队列中的下一个结果不匹配,它会被重新放入队列,并且该方法返回 None。
# next available result
result = manager.get_result()
# filter for a specific request
result = manager.get_result(request_id="my_request")流式处理
在请求上设置 streaming=True,然后使用 request_id_iter() 在 Token 生成时遍历部分输出。
from transformers.generation.continuous_batching import RequestStatus
manager.add_request(input_ids=input_ids, request_id="streamed", streaming=True)
for chunk in manager.request_id_iter(request_id="streamed"):
token = tokenizer.decode(chunk.generated_tokens[-1:], skip_special_tokens=True)
print(token, end="", flush=True)
if chunk.status == RequestStatus.FINISHED:
breakContinuousBatchingConfig
ContinuousBatchingConfig 控制 KV 缓存、调度、CUDA 图、内存使用等。将其与 GenerationConfig 一起传递以自定义持续批处理。
默认情况下,num_blocks 和 max_batch_tokens 会根据可用 GPU 内存自动推断。使用下表来帮助您选择合适的功能。
| 特性 | 内存 | 吞吐量 | 延迟 |
|---|---|---|---|
max_memory_percent / block_size | ✓ 控制 KV 预算 | ||
调度器 | ✓ 调度策略 | ✓ TTFT(首Token延迟) | |
| CUDA 图 | ↑ 图存储 | ✓ 减少分发开销 | ✓ |
| 异步批处理 | ↑ ~2倍 I/O 缓冲区 | ✓ 重叠 CPU/GPU 计算 | |
| 解码快速路径 | ↑ 每个请求的块表 | ✓ 加快仅解码步骤 | ✓ |
| CPU 卸载 | ↑ 固定 CPU 内存 | ✓ 跳过部分预填充 | |
| 前缀缓存 | ↓ 共享 KV 块 | ✓ 跳过冗余预填充 | ✓ TTFT(首Token延迟) |
| 分页注意力 (Paged Attention) | ↓ 无碎片 | ✓ 动态批处理成员 | |
| 滑动窗口 | ↓ 每层有界 KV | ||
| 单请求处理器 | ✓ 单批次内混合采样参数 |
from transformers.generation import ContinuousBatchingConfig
cb_config = ContinuousBatchingConfig(
max_memory_percent=0.8, # fraction of free GPU memory to use for the KV cache
block_size=256, # KV cache block size in tokens
scheduler_type="fifo", # "fifo" or "prefill_first"
)
outputs = model.generate_batch(
inputs=inputs,
generation_config=generation_config,
continuous_batching_config=cb_config,
)对数概率 (Log probabilities)
ContinuousBatchingConfig 在 return_logprobs=True 时返回每个生成的 Token 的对数概率。这对于强化学习(RL)很有用,因为对数概率是某些训练循环的输入。
cb_config = ContinuousBatchingConfig(return_logprobs=True)
# generate_batch()
for request_id, output in outputs.items():
for token_id, log_prob in zip(output.generated_tokens, output.logprobs):
token = tokenizer.decode([token_id])
print(f"{token} | logprob: {log_prob}")CUDA 图
CUDA 图通过记录 GPU 执行图并在具有匹配形状的批次中重放它来消除 CPU 分发开销。通过 use_cuda_graph=True 显式启用它们。
cb_config = ContinuousBatchingConfig(use_cuda_graph=True)当激活时,管理器会将查询和 KV 长度填充到固定间隔,以便形状重复并重用图形。较小的 q_padding_interval_size 和 kv_padding_interval_size 值会减少因填充而造成的计算浪费,但这意味着图形需要记录和存储更多唯一形状,从而消耗更多内存。
cb_config = ContinuousBatchingConfig(
use_cuda_graph=True,
q_padding_interval_size=64,
kv_padding_interval_size=16384,
max_cached_graphs=32,
)异步批处理
异步批处理将下一个批次的 CPU 调度与当前批次的 GPU 计算重叠起来。它需要 CUDA 图,并且大约会使输入张量使用的 VRAM 加倍。
cb_config = ContinuousBatchingConfig(
use_cuda_graph=True,
use_async_batching=True,
)解码快速路径
当批次仅包含解码请求(每个序列一个查询 Token)时,管理器可以分发到 flash_attn_with_kvcache 内核,而不是变长内核。这比变长路径更快,因为内核通过块表就地读取和写入分页 KV 缓存,而不是通过手动更新。有关内核级详细信息,请参阅 分页注意力 (Paged attention)。
快速路径的大小由 max_blocks_per_request 控制,它决定了每个请求的块表维度。默认情况下,这是自动推断的。如果管理器上设置了 max_prompt_length 和 max_generated_length,则块表的大小将调整为适合最大序列长度。否则,使用后备默认值(每个请求 32 个块)。
将 max_blocks_per_request 设置为特定值以显式调整块表大小。当您知道每个请求的最大序列长度并希望限制块表内存成本时,这非常有用。
cb_config = ContinuousBatchingConfig(max_blocks_per_request=64)将 max_blocks_per_request=0 设置为禁用快速路径,并强制每个批次通过变长内核。这会恢复默认之前的行为,当快速路径不可用于您的注意力实现时(管理器在无法使用底层内核时也会自动禁用它),这非常有用。
cb_config = ContinuousBatchingConfig(max_blocks_per_request=0)快速路径依赖于 flash_attn_with_kvcache 内核,该内核适用于两种设备和注意力实现组合。
| 设备 | attn_implementation |
|---|---|
| CUDA | flash_attention_3 |
| XPU | flash_attention_2 |
对于任何其他组合,或者当无法导入内核时,管理器将退回到变长路径。它仅在您显式设置 max_blocks_per_request 时记录警告。
CPU 卸载
当 GPU KV 缓存满时,CPU 卸载会将已驱逐的 KV 缓存块复制到预分配的固定 CPU 缓冲区中。当缓存空间可用后,管理器会将这些块复制回 GPU,并恢复请求,而无需重新计算其提示词和已生成的 Token。
将 cpu_offload_space 设置为以 GiB 为单位的 CPU 交换空间。默认值 0.0 禁用 CPU 卸载。
cb_config = ContinuousBatchingConfig(cpu_offload_space=8.0)默认情况下,当安装了 psutil 时,cpu_offload_space_safety_threshold=0.8 会将请求的空间限制为可用系统 RAM 的 80%。将 cpu_offload_space=None 设置为从安全阈值确定交换池大小。
前缀缓存
当多个请求共享一个通用前缀(如系统提示词)时,管理器会重用它们的 KV 缓存块,而不是重新计算它们。此功能默认启用,并要求所有模型层使用完全注意力(对于滑动窗口模型会自动禁用)。
cb_config = ContinuousBatchingConfig(
allow_block_sharing=True, # default
)分页注意力
持续批处理需要分页注意力后端。加载模型时请设置 attn_implementation。如果您使用非分页后端("flash_attention_2")加载模型,则在启动持续批处理时会自动添加 "paged|" 前缀。
| 后端 | attn_implementation | 要求 |
|---|---|---|
| FlashAttention | “paged|flash_attention_2” | flash-attn 包 |
| SDPA(PyTorch 原生) | “paged|sdpa” | 无 |
| Eager | “paged|eager” | 无 |
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-4B",
attn_implementation="paged|flash_attention_2",
device_map="cuda",
dtype=torch.bfloat16,
)滑动窗口注意力
具有滑动窗口注意力的模型(Mistral, Gemma 2)适用于持续批处理。要为微调或自定义实验手动配置滑动窗口,请在加载前在模型配置中进行设置。
from transformers import AutoConfig, AutoModelForCausalLM
config = AutoConfig.from_pretrained("google/gemma-2-2b")
config.sliding_window = 4096
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2-2b",
config=config,
attn_implementation="paged|sdpa",
device_map="cuda",
dtype=torch.bfloat16,
)当滑动窗口注意力处于激活状态时,前缀缓存会自动禁用。