Transformers 文档

Attention Interface

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

注意力接口

本页描述如何使用 `AttentionInterface` 来注册自定义注意力函数,以用于支持的模型。

自定义注意力函数

得益于一个简单的映射,大多数最新模型现在可以从注意力层中使用的注意力函数切换到另一个注意力函数。默认情况下,我们提供了 `sdpa`、`flash_attention_2` 和 `flex_attention` 的实现,以及 `eager`,它是一个简单的矩阵乘法,没有任何优化。
这是您在实例化模型时通常可以选择的设置

from transformers import AutoModelForCausalLM

model_id = "meta-llama/Llama-3.2-1B"

# Here, using flash attention as an example
model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="flash_attention_2")

但是,如果您想创建自己的注意力函数呢?或者只是尝试现有的函数,在其中添加一些语句?现在,您可以使用 `AttentionInterface` 来实现!这是一个示例

from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch

model_id = "meta-llama/Llama-3.2-1B"

def my_new_sdpa(*args, **kwargs):
    print("I just entered the attention computation")
    return sdpa_attention_forward(*args, **kwargs)

AttentionInterface.register("my_new_sdpa", my_new_sdpa)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="my_new_sdpa")
# Try running the forward with the new attention function
model(torch.ones(1, 5, dtype=int))

您将看到它打印“I just entered the attention computation”,打印次数与模型中的层数一样多(在此示例中为 16 次)。

动态切换注意力函数

您还可以通过覆盖 `config._attn_implementation` 字段来动态更改模型的注意力函数

# Back to use original sdpa implementation
model.config._attn_implementation = "sdpa"

model(torch.ones(1, 5, dtype=int))

它将停止打印语句,因为它现在使用 `sdpa` 注意力。
这允许快速更改注意力函数,而无需重新加载模型!

我的自定义注意力函数中需要新参数怎么办?

但确实,如果新函数需要新参数才能正常使用怎么办?这不是问题!支持 `AttentionInterface` 的模型会将 kwargs 一直传播到注意力层和所使用的注意力函数。这样,您只需在模型的 forward 中传递参数(作为 kwargs,即您需要限定参数名称),它就会在注意力中正确使用。但是,自定义注意力函数有一些限制。特别是,它必须遵循其他注意力函数的签名和返回格式,即

from transformers import AutoModelForCausalLM, AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward
import torch

def custom_attention(
    module: torch.nn.Module,  # required arg
    query: torch.Tensor,  # required arg
    key: torch.Tensor,  # required arg
    value: torch.Tensor,  # required arg
    attention_mask: Optional[torch.Tensor],  # required arg
    a_new_kwargs = None,  # You can now add as many kwargs as you need
    another_new_kwargs = None,  # You can now add as many kwargs as you need
    **kwargs,  # You need to accept **kwargs as models will pass other args
) -> tuple[torch.Tensor, Optional[torch.Tensor]]
    ...  # do your magic!
    return attn_output, attn_weights  # attn_weights are optional here

AttentionInterface.register("custom", custom_attention)

model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
# Forward pass with the new kwargs
model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)

如果对给定模型发送给注意力函数的参数/关键字参数有疑问,只需查看该模型在 GitHub 上的建模代码!

访问当前可用实现

大多数情况下,您只需 `注册` 一个新函数。但是,如果您需要访问现有函数,和/或执行一些检查,首选的方式是使用全局 `ALL_ATTENTION_FUNCTIONS`。它的行为方式与您期望的普通 Python 字典相同

>>> from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

>>> list(ALL_ATTENTION_FUNCTIONS.keys())
>>> ['flash_attention_2', 'flex_attention', 'sdpa']

>>> ALL_ATTENTION_FUNCTIONS["sdpa"]
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>

>>> ALL_ATTENTION_FUNCTIONS.get("sdpa", None)
>>> <function transformers.integrations.sdpa_attention.sdpa_attention_forward>

# You can also globally `register` a new function directly on it
>>> ALL_ATTENTION_FUNCTIONS.register("new_func", new_func)

注意力掩码接口

拥有一个新的注意力函数可能意味着您需要一种新的注意力掩码格式来决定查询令牌应该关注哪些键和值令牌。现在,通过 `AttentionMaskInterface` 可以实现这一点!它的工作方式与 `AttentionInterface` 相同。

from transformers import AttentionMaskInterface
from transformers.masking_utils import sdpa_mask
import torch

def my_new_sdpa_mask(*args, **kwargs):
    print("I just entered the attention mask computation")
    return sdpa_mask(*args, **kwargs)

AttentionMaskInterface.register("my_new_sdpa_mask", my_new_sdpa_mask)

您必须注册它的原因是,我们需要根据注意力实现自动更正您的掩码格式(例如,flex attention 使用 BlockMask 格式,而 sdpa 使用 4D 张量)。默认情况下,如果您没有注册注意力掩码函数以及您的注意力函数,将跳过掩码创建,并且 `attention_mask=None` 将传递给注意力层。

注意力掩码函数的默认签名如下:

def custom_attention_mask(
    batch_size: int,  # required arg
    cache_position: torch.Tensor,  # required arg
    kv_length: int,  # required arg
    kv_offset: int = 0,  # required arg
    mask_function: Callable = causal_mask_function,  # required arg
    attention_mask: Optional[torch.Tensor] = None,  # required arg
    **kwargs,  # a few additional args may be passed as kwargs, especially the model's config is always passed
) -> Optional[torch.Tensor]:

它主要通过 `mask_function` 实现,这是一个 `Callable`,形式类似于 torch 的 `mask_mod` 函数,它接受 4 个索引作为输入,并返回一个布尔值以指示该位置是否应参与注意力计算。

如果由于某种原因无法使用 `mask_function` 创建掩码,您可以尝试通过类似于我们的 torch 导出变通方法 来解决。

< > 在 GitHub 上更新