Transformers 文档
概览
并获得增强的文档体验
开始使用
概述
Transformers 提供了多种推理优化技术,使模型运行更快速、成本更低且更易于访问。选项包括用于减少内存流量的替代注意力机制、用于加速执行的代码编译,以及用于提高吞吐量的优化算子(kernels)。堆叠使用这些技术可获得最大性能。
内存和速度密切相关,但并不相同。缩小内存占用会使模型变得“更快”,因为需要移动的数据量减少了。纯速度优化并不总是能减少内存,有时反而会增加内存使用。请根据您的用例和硬件选择合适的优化方案。
请使用下表选择一种优化技术。
| 技术 | 速度 | 内存 |
|---|---|---|
| 编译 (Compilation) | ✅ | |
| 注意力后端 | ✅ | ✅ |
| Kernels | ✅ | ✅ |
| 量化 | ✅ | ✅ |
| 缓存 | ✅ | ✅ |
| 并行性 | ✅ | |
| 连续批处理 | ✅ |
本指南为您提供 Transformers 优化的快速入门。
编译
torch.compile 可以减少 Python 开销,融合操作,并创建针对您的形状和硬件调优的算子。首次运行会进行预热,随后的运行将使用更快的编译路径。
向 generate() 传递一个 固定大小的缓存(fixed size cache) 以自动触发 torch.compile。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.float16, device_map="auto")
input = tokenizer("The French Bread Law states", return_tensors="pt").to(model.device)
output = model.generate(**input, do_sample=False, max_new_tokens=20, cache_implementation="static")
tokenizer.batch_decode(output, skip_special_tokens=True)[0]避免在 generate() 之外调用
torch.compile(model),以防止模型在每一步都重新编译。
Attention 后端
替代的 注意力后端(attention backends) 可以降低内存流量。例如,FlashAttention 通过对注意力计算进行分块(tiles)并避免产生大型中间张量,从而减少内存占用。
在 from_pretrained() 中设置 attn_implementation 以加载优化的注意力后端。
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", attn_implementation="flash_attention_2")Kernels
算子(Kernels)通过融合操作来提高吞吐量并减少内存使用。Kernels 库能够以灵活且版本安全的方式,从 Hub 加载优化的计算算子。
下面的示例在不安装软件包的情况下加载了一个优化的 FlashAttention-2 算子。
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)量化
量化(Quantization) 缩小了每个参数的大小,从而降低了内存占用并提高了速度,因为您可以执行更多的操作。
在 from_pretrained() 的 quantization_config 参数中传递量化配置。每个量化后端具有不同的配置和参数。下面的示例使用 bitsandbytes 后端将模型量化为 4 位,并配置计算精度(dtype)。
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
"allenai/Olmo-3-7B-Think", quantization_config=bnb_config
)缓存
缓存(Caching) 通过重用过去的键(keys)和值(values),而不是为每个 token 重新计算,从而加快生成速度。为了抵消并减少存储过去键值对的内存成本,Transformers 支持将缓存卸载(offloading)到 CPU。只有当前层保留在 GPU 上。
使用 generate() 中的 cache_implementation 参数来设置缓存策略。
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B", attn_implementation="kernels-community/flash-attn2"
)
inputs = tokenizer("The Le Décret Pain states that a baguette must,", return_tensors="pt")
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=50, cache_implementation="offloaded")并行化 (Parallelism)
并行化 将模型分布在多个设备上,使得单个设备无法容纳的大模型也能快速运行。由于分片开销和同步结果的通信,这种方法会使用更多内存。
张量并行 (Tensor parallelism) 将模型的某一层拆分到多个设备上。在 from_pretrained() 中设置 tp_plan="auto" 即可启用。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", tp_plan="auto")
print(model._tp_plan)连续批处理 (Continuous batching)
连续批处理 (Continuous batching) 通过动态调度和分块预填充(chunked prefill)使 GPU 保持繁忙,从而最大限度地提高吞吐量。推理服务 (Serving) 应用程序使用它来并发处理多个传入请求。
使用 generate_batch() 来启用连续批处理。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen3-0.6B",
attn_implementation="paged|sdpa",
device_map="cuda",
torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
prompts = [
"The Le Décret Pain states that a baguette must",
"Explain gravity in one sentence.",
"Name the capital of France.",
]
inputs = [tokenizer.encode(p) for p in prompts]
generation_config = GenerationConfig(
max_new_tokens=32,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=False,
max_batch_tokens=512,
)
outputs = model.generate_batch(
inputs=inputs,
generation_config=generation_config,
)
for request_id, output in outputs.items():
text = tokenizer.decode(output.generated_tokens, skip_special_tokens=True)
print(f"[{request_id}] {text}")