Transformers 文档
注意力后端
并获得增强的文档体验
开始使用
Attention 后端
所有注意力实现执行的计算都是相同的。每个 token 都会与其他所有 token 进行比较。区别在于计算的执行方式。基础注意力机制的扩展性较差,因为它会将完整的注意力矩阵存入内存,从而产生瓶颈,降低推理速度。优化后的实现通过调整数学运算方式来减少内存流量,从而实现更快、更经济的推理。
AttentionInterface 提供了优化后的注意力实现。它将注意力实现与模型实现解耦,从而简化了对不同函数的实验。通过这种一致的接口,可以轻松添加新的后端。
| 注意力后端 | description |
|---|---|
"flash_attention_3" | 对 FlashAttention-2 进行了改进,通过重叠操作并更紧密地融合前向和后向传递来提升性能 |
"flash_attention_2" | 将计算分块(tiling)为更小的区块,并使用快速的片上内存 |
"flex_attention" | 无需手动编写低级内核,即可指定自定义注意力模式(稀疏、块局部、滑动窗口)的框架 |
"sdpa" | PyTorch 内置的 缩放点积注意力 (scaled dot product attention) 实现 |
“paged|flash_attention_3” | FlashAttention-3 的分页版本 |
“paged|flash_attention_2” | FlashAttention-2 的分页版本 |
“paged|sdpa” | SDPA 的分页版本 |
“paged|eager” | eager 模式的分页版本 |
设置注意力后端
在 from_pretrained() 中使用 attn_implementation 参数来实例化具有特定注意力函数的模型。
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_2"
)使用 set_attn_implementation() 在运行时切换注意力后端,无需重新加载模型。
model.set_attn_implementation("sdpa")Kernels
使用 Kernels 库在运行时直接从 Hub 下载并加载编译好的计算内核。这避免了因 PyTorch 或 CUDA 版本不匹配导致的打包问题。
内核在检测到后会自动注册到 AttentionInterface。您无需显式安装 FlashAttention 包。
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", attn_implementation="kernels-community/flash-attn2"
)SDPA 上下文管理器
PyTorch 的缩放点积注意力 (SDPA) 会自动为 CUDA 后端选择最快的注意力函数。对于其他后端,它默认为 PyTorch C++ 实现。
使用 torch.nn.attention.sdpa_kernel 上下文管理器强制 SDPA 使用特定的实现。
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B", attn_implementation="sdpa"
)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
outputs = model.generate(**inputs)特定骨干网络(Backbone)的注意力机制
多模态模型为每种模态使用不同的骨干网络。通过为每个骨干网络分配特定的注意力函数来优化性能。例如,某些视觉骨干网络在 fp32 下性能更好,而 FlashAttention 并不支持 fp32。
使用字典将视觉骨干网络映射到不同的注意力函数,同时文本骨干网络继续使用 FlashAttention。注意力实现中的键必须与子配置名称匹配。
from transformers import AutoModelForImageTextToText
attention_implementation_per_backbone = {"vision_config": "sdpa", "text_config": "flash_attention_2"}
for key in attention_implementation_per_backbone:
assert key in model.config.sub_configs, f"Invalid key in `attention_implementation`"
model = AutoModelForImageTextToText.from_pretrained(
"facebook/chameleon-7b", attn_implementation=attention_implementation_per_backbone
)从字典中省略某些骨干网络以使用默认的注意力函数 (SDPA)。
model = AutoModelForImageTextToText.from_pretrained(
"facebook/chameleon-7b", attn_implementation={"text_config": "flash_attention_2"}
)使用单个字符串为所有骨干网络设置相同的注意力函数。
model = AutoModelForImageTextToText.from_pretrained(
"facebook/chameleon-7b", attn_implementation="eager"
)使用空键全局设置注意力函数。
model = AutoModelForImageTextToText.from_pretrained(
"facebook/chameleon-7b", attn_implementation={"": "eager"}
)创建新的注意力函数
通过使用 AttentionInterface.register() 将它们添加到注意力注册表中,可以自定义或创建新的注意力函数。模型通过 attn_implementation 参数使用这些函数。
在注册自定义注意力函数时,请同时注册一个匹配的注意力掩码函数。如果自定义的attn_implementation名称未在 AttentionMaskInterface 中注册,Transformers 将跳过掩码创建并将attention_mask=None传递给注意力层。您的注意力函数必须自行处理因果、填充、打包或滑动窗口约束,否则这些约束可能会被静默丢弃。
此示例自定义了注意力函数,以在每一层打印一条语句。它通过注册 masking_utils.sdpa_mask 作为注意力掩码函数,保持了原始实现中的掩码逻辑。
import torch
from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.masking_utils import sdpa_mask
def my_new_sdpa(*args, **kwargs):
print("I just entered the attention computation")
return sdpa_attention_forward(*args, **kwargs)
AttentionInterface.register("my_new_sdpa", my_new_sdpa)
AttentionMaskInterface.register("my_new_sdpa", sdpa_mask) # must have the same name as the registered attention function
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_new_sdpa")
model(torch.ones(1, 5, dtype=int))您还可以向注意力函数添加新参数。支持 AttentionInterface 的模型会将 kwargs 传播到注意力层和注意力函数。在模型的前向传递函数中以 kwargs 的形式传递参数。自定义注意力函数必须遵循此签名和返回格式。
import torch
from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
from transformers.masking_utils import sdpa_mask
def custom_attention(
module: torch.nn.Module, # required arg
query: torch.Tensor, # required arg
key: torch.Tensor, # required arg
value: torch.Tensor, # required arg
attention_mask: Optional[torch.Tensor], # required arg
a_new_kwargs = None, # You can now add as many kwargs as you need
another_new_kwargs = None, # You can now add as many kwargs as you need
**kwargs, # You need to accept **kwargs as models will pass other args
) -> tuple[torch.Tensor, Optional[torch.Tensor]]
... # do your magic!
return attn_output, attn_weights # attn_weights are optional here
AttentionInterface.register("custom", custom_attention)
AttentionMaskInterface.register("custom", sdpa_mask) # to leave the existing mask untouched
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)查看模型的 建模代码,以确认它发送给注意力函数的参数和 kwargs。
AttentionMaskInterface
AttentionMaskInterface 是一个注册表,create_*_mask 函数会查询该注册表,以将掩码转换为活动注意力后端所期望的格式。FlexAttention 需要 BlockMask,SDPA 需要 4D 张量,而 FlashAttention 需要基础 2D 填充掩码。使用 AttentionMaskInterface.register() 注册自定义后端或覆盖现有后端的格式化程序。
import torch
from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask
def my_new_sdpa_mask(*args, **kwargs):
print("I just entered the attention mask computation")
return sdpa_mask(*args, **kwargs)
AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)如果活动 attn_implementation 没有已注册的格式化程序,将跳过掩码创建,并将 attention_mask=None 传递给注意力层。
注册的函数必须匹配此签名。
def custom_attention_mask(
batch_size: int, # required arg
q_length: int, # required arg
kv_length: int, # required arg
q_offset: int = 0, # required arg
kv_offset: int = 0, # required arg
mask_function: Callable = causal_mask_function, # required arg
attention_mask: Optional[torch.Tensor] = None, # required arg
**kwargs, # a few additional args may be passed as kwargs, especially the model's config is always passed
) -> Optional[torch.Tensor]:mask_function 参数是一个模拟 PyTorch mask_mod 函数的 Callable。它接受 4 个索引 (batch_idx, head_idx, q_idx, kv_idx),并返回一个布尔值,指示该位置是否参与注意力计算。这与 构建注意力掩码 中 or_mask_function 和 and_mask_function 使用的原始形状相同。
如果
mask_function无法创建掩码,请使用此 变通方法 进行 torch.export。
构建注意力掩码
使用 transformers.masking_utils 中的 create_*_mask 函数来构建注意力掩码。每个函数都会从模型配置中读取活动的注意力后端,在 AttentionMaskInterface 中查找该后端的掩码格式化程序,并返回该后端所期望的格式。您无需自行反转、扩展或转换掩码类型。
选择与注意力模式匹配的函数。
| 函数 | 使用场景 |
|---|---|
create_causal_mask | 仅解码器(decoder-only)模型,其中每个 token 仅关注自身和之前的 token |
create_bidirectional_mask | 编码器模型,或从解码器到编码器状态的交叉注意力 |
create_sliding_window_causal_mask | 具有滑动窗口注意力模式的解码器模型 |
create_chunked_causal_mask | 将序列分块为固定大小区块的解码器模型 |
create_bidirectional_sliding_window_mask | 具有滑动窗口注意力模式的编码器模型 |
旧版可调用掩码辅助函数(
get_extended_attention_mask,create_extended_attention_mask_for_decoder,invert_attention_mask)会触发弃用警告,并将在未来版本中删除。请改用create_*_mask函数。
在解码器的前向传递中调用 create_causal_mask。传递配置、输入嵌入、用户提供的 2D attention_mask 以及缓存。该函数使用嵌入来读取批次大小、查询长度、数据类型和设备,并使用缓存来计算键长度。
from transformers.masking_utils import create_causal_mask
attention_mask = create_causal_mask(
config=self.config,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=past_key_values,
)使用 or_mask_function 和 and_mask_function 参数在基础掩码之上添加额外约束。使用 or_mask_function 允许额外的位置参与注意力计算,使用 and_mask_function 进一步限制基础模式。两者都遵循 AttentionMaskInterface 中描述的 4 索引 mask_function 签名。它们接受 (batch_idx, head_idx, q_idx, kv_idx) 并返回一个布尔值。
or_mask_function和and_mask_function可以表示任何注意力模式,但它们比内置模式慢,且与 ExecuTorch 不兼容。这种开销在较小的模型(约 2 亿参数)上最为明显,因为掩码创建在前向传递时间中占比较大。仅在标准的create_*_mask函数无法满足需求时才使用它们。
例如,在因果掩码上叠加一个在任何地方都返回 True 的函数,将其转变为完全的双向掩码。与因果模式的并集使每个 token 都可以关注到其他所有 token。
mask_kwargs = {
"config": self.config,
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
"or_mask_function": lambda *args: torch.tensor(True, dtype=torch.bool),
}
attention_mask = create_causal_mask(**mask_kwargs)在生成过程中,generate() 通过 create_masks_for_generate 构建掩码,该函数根据模型配置分发到正确的 create_*_mask。您可以在模型类上重写它,以插入自定义的生成掩码策略。
双向注意力
仅解码器模型默认使用因果(单向)注意力,其中每个 token 仅关注自身和之前的 token。设置 is_causal=False 可切换到双向注意力,其中每个 token 都可以关注到其他所有 token。这使您可以将仅解码器模型用作文本编码器,例如用于生成嵌入。
这仅适用于因果(解码器)模型。它不会将编码器模型转变为解码器模型。
在模型配置中设置 is_causal=False,使双向注意力成为每次前向传递的默认设置。
from transformers import AutoModel, AutoConfig
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B")
config.is_causal = False
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", config=config)
# all forward passes now use bidirectional attention
outputs = model(**inputs)在前向调用中传递 is_causal 而不是在模型配置中设置,可以在不重新加载模型的情况下切换因果和双向注意力。该 kwarg 会临时覆盖配置,并在调用后恢复。
from transformers import AutoModel
model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B")
# run with bidirectional attention
outputs = model(**inputs, is_causal=False)
# run with default causal attention
outputs = model(**inputs)