Nyströmformer:通过 Nyström 方法在线性时间和内存中逼近自注意力机制
引言
Transformer 在各种自然语言处理和计算机视觉任务上都表现出了卓越的性能。其成功可归功于自注意力机制,该机制能捕捉输入中所有词元(token)之间的成对交互。然而,标准的自注意力机制的时间和内存复杂度为 (其中 是输入序列的长度),这使得在长输入序列上进行训练的成本很高。
Nyströmformer 是众多高效 Transformer 模型之一,它以 的复杂度逼近了标准自注意力机制。Nyströmformer 在各种下游 NLP 和 CV 任务上展现出有竞争力的性能,同时提升了标准自注意力机制的效率。这篇博文旨在向读者概述 Nyström 方法以及如何将其应用于逼近自注意力机制。
用于矩阵逼近的 Nyström 方法
Nyströmformer 的核心是用于矩阵逼近的 Nyström 方法。它允许我们通过采样矩阵的部分行和列来逼近整个矩阵。让我们考虑一个矩阵 ,完整计算该矩阵的成本很高。因此,我们转而使用 Nyström 方法对其进行逼近。我们首先从 中采样 行和 列。然后我们可以将采样出的行和列排列如下:
现在我们有四个子矩阵: 和 ,它们的尺寸分别为 和 。采样的 列包含在 和 中,而采样的 行包含在 和 中。因此, 和 的元素是已知的,我们将估计 。根据 Nyström 方法, 由下式给出:
这里, 表示 Moore-Penrose 逆(或伪逆)。因此, 的 Nyström 逼近 可以写为:
如第二行所示, 可以表示为三个矩阵的乘积。这样做的原因稍后会变得清晰。
我们能用 Nyström 方法逼近自注意力机制吗?
我们的最终目标是逼近标准自注意力机制中的 softmax 矩阵:S = softmax
这里, 和 分别表示查询(queries)和键(keys)。按照上面讨论的程序,我们会从 中采样 行和 列,形成四个子矩阵,并得到 。
但是,从 中采样一列意味着什么呢?这意味着我们从每一行中选择一个元素。回想一下 S 是如何计算的:最后一步是逐行进行 softmax。要计算一行中的单个元素,我们必须访问该行的所有其他元素(用于 softmax 的分母)。因此,采样一列需要我们知道矩阵中的所有其他列。所以,我们无法直接应用 Nyström 方法来逼近 softmax 矩阵。
如何调整 Nyström 方法以逼近自注意力机制?
作者提出,不从 中采样,而是从查询和键中采样地标点(landmarks)(或 Nyström 点)。我们将查询地标点和键地标点表示为 和 。 和 可用于构建三个矩阵,这些矩阵对应于 的 Nyström 逼近中的矩阵。我们定义以下矩阵:
、 和 和 。我们用定义的新矩阵替换 的 Nyström 逼近中的三个矩阵,得到一个替代的 Nyström 逼近:
这就是自注意力机制中 softmax 矩阵的 Nyström 逼近。我们将这个矩阵与值()相乘,得到自注意力的线性逼近。请注意,我们从未计算乘积 ,从而避免了 的复杂度。
如何选择地标点(landmarks)?
作者提出,不从 和 中采样 行,而是使用分段均值来构建 和 。在这个过程中, 个词元被分成 个段,并计算每个段的均值。理想情况下, 远小于 。根据论文中的实验,即使对于长序列( 或 ),仅选择 或 个地标点就能产生与标准自注意力机制及其他高效注意力机制相媲美的性能。
论文中的下图总结了整个算法:
上图中的三个橙色矩阵对应于我们使用键和查询地标点构建的三个矩阵。另外,请注意有一个 DConv 框。这对应于使用一维深度卷积向值(values)添加的跳跃连接(skip connection)。
Nyströmformer 是如何实现的?
Nyströmformer 的原始实现可以在这里找到,HuggingFace 的实现可以在这里找到。让我们看一下 HuggingFace 实现中的几行代码(添加了一些注释)。请注意,为了简化,省略了一些细节,如归一化、注意力掩码和深度卷积。
key_layer = self.transpose_for_scores(self.key(hidden_states)) # K
value_layer = self.transpose_for_scores(self.value(hidden_states)) # V
query_layer = self.transpose_for_scores(mixed_query_layer) # Q
q_landmarks = query_layer.reshape(
-1,
self.num_attention_heads,
self.num_landmarks,
self.seq_len // self.num_landmarks,
self.attention_head_size,
).mean(dim=-2) # \tilde{Q}
k_landmarks = key_layer.reshape(
-1,
self.num_attention_heads,
self.num_landmarks,
self.seq_len // self.num_landmarks,
self.attention_head_size,
).mean(dim=-2) # \tilde{K}
kernel_1 = torch.nn.functional.softmax(torch.matmul(query_layer, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{F}
kernel_2 = torch.nn.functional.softmax(torch.matmul(q_landmarks, k_landmarks.transpose(-1, -2)), dim=-1) # \tilde{A} before pseudo-inverse
attention_scores = torch.matmul(q_landmarks, key_layer.transpose(-1, -2)) # \tilde{B} before softmax
kernel_3 = nn.functional.softmax(attention_scores, dim=-1) # \tilde{B}
attention_probs = torch.matmul(kernel_1, self.iterative_inv(kernel_2)) # \tilde{F} * \tilde{A}
new_value_layer = torch.matmul(kernel_3, value_layer) # \tilde{B} * V
context_layer = torch.matmul(attention_probs, new_value_layer) # \tilde{F} * \tilde{A} * \tilde{B} * V
在 HuggingFace 中使用 Nyströmformer
用于掩码语言建模(MLM)的 Nyströmformer 已在 HuggingFace 上提供。目前有 4 个检查点,对应不同的序列长度:nystromformer-512
、nystromformer-1024
、nystromformer-2048
和 nystromformer-4096
。地标点的数量 可以通过 NystromformerConfig
中的 num_landmarks
参数来控制。让我们看一个 Nyströmformer 用于 MLM 的最小示例:
from transformers import AutoTokenizer, NystromformerForMaskedLM
import torch
tokenizer = AutoTokenizer.from_pretrained("uw-madison/nystromformer-512")
model = NystromformerForMaskedLM.from_pretrained("uw-madison/nystromformer-512")
inputs = tokenizer("Paris is the [MASK] of France.", return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
# retrieve index of [MASK]
mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
tokenizer.decode(predicted_token_id)
Output:
----------------------------------------------------------------------------------------------------
capital
另外,我们也可以使用 pipeline API(它为我们处理了所有复杂性):
from transformers import pipeline
unmasker = pipeline('fill-mask', model='uw-madison/nystromformer-512')
unmasker("Paris is the [MASK] of France.")
Output:
----------------------------------------------------------------------------------------------------
[{'score': 0.829957902431488,
'token': 1030,
'token_str': 'capital',
'sequence': 'paris is the capital of france.'},
{'score': 0.022157637402415276,
'token': 16081,
'token_str': 'birthplace',
'sequence': 'paris is the birthplace of france.'},
{'score': 0.01904447190463543,
'token': 197,
'token_str': 'name',
'sequence': 'paris is the name of france.'},
{'score': 0.017583081498742104,
'token': 1107,
'token_str': 'kingdom',
'sequence': 'paris is the kingdom of france.'},
{'score': 0.005948934704065323,
'token': 148,
'token_str': 'city',
'sequence': 'paris is the city of france.'}]
结论
Nyströmformer 为标准自注意力机制提供了一种高效的逼近方法,同时其性能优于其他线性自注意力方案。在这篇博文中,我们概要地介绍了 Nyström 方法以及如何将其用于自注意力机制。有兴趣在下游任务中部署或微调 Nyströmformer 的读者可以在这里找到 HuggingFace 的文档。