一次失败的实验:Infini-Attention,以及我们为何应继续尝试?
TLDR:Infini-attention 的性能随着内存压缩次数的增加而变差,据我们所知,环注意力(ring attention)、YaRN 和 rope scaling 仍然是扩展预训练模型到更长上下文长度的最佳方法。
第 0 节:引言
语言模型的上下文长度是其核心属性之一,与模型性能并列。自上下文学习兴起以来,向模型输入添加相关信息变得越来越重要。因此,上下文长度迅速从段落(BERT/GPT-1 的 512 个 token)增加到页面(GPT-2 和 GPT-3 分别为 1024/2048 个 token),再到书籍(Claude 的 128k),直至书籍集合(Gemini 的 1-10M 个 token)。然而,将标准注意力扩展到如此长的上下文长度仍然具有挑战性。
环注意力(Ring Attention)简介:环注意力最初由加州大学伯克利分校的研究人员于 2024 年提出(据我们所知)[链接]。这项工程技术通过分块执行自注意力和前馈网络计算,并将序列维度分布到多个设备上,从而实现并发计算和通信,有助于克服内存限制。
即使使用环注意力,以批量大小为 1、100 万 token 上下文长度训练 Llama 3 8B 仍需要 512 个 GPU。正如缩放定律所表明的,模型大小与其下游性能之间存在强相关性,这意味着模型越大越好(当然,两个模型都应该经过良好训练)。因此,我们不仅需要 100 万的上下文长度,还需要在最大的模型(例如 Llama 3 8B 405B)上实现 100 万的上下文长度。而目前只有少数公司拥有这样做的资源。
自注意力内存复杂性回顾:在标准注意力(非 Flash Attention)中,每个 token 都会关注序列中的所有其他 token,导致注意力矩阵的大小为 [seq_len, seq_len]。对于每对 token,我们都会计算一个注意力分数,并且随着序列长度 (seq_len) 的增加,内存和计算需求呈二次方增长:注意力矩阵的内存复杂度为 O(seq_len^2)。例如,序列长度增加 10 倍会导致内存需求增加 100 倍。即使是内存效率高的注意力方法,如 Flash Attention,其内存需求也随上下文长度线性增加,并受限于单个 GPU 内存,导致当今 GPU 上的典型最大上下文长度远低于 1M token。
受此启发,我们探索了标准注意力的另一种方法:Infini-attention。该论文由 Google 的研究人员于 2024 年 4 月发布[链接]。Infini-attention 不计算每个词之间的注意力分数,而是将序列分成更小的固定大小的“片段”,将较早的片段压缩到固定缓冲区中,并允许下一个片段从较早的片段中检索内存,同时将注意力分数限制在当前片段的词中。一个关键的优势是其固定的缓冲区大小限制了总内存使用量。它还在一个片段中使用相同的查询来访问其自身片段和压缩内存中的信息,这使得我们能够廉价地扩展预训练模型的上下文长度。理论上,我们可以实现无限上下文长度,因为它只为所有较早片段的内存保留一个缓冲区。然而,实际上,压缩限制了可以有效存储的信息量,因此问题是:这种压缩后的内存有多大的可用性?
虽然在纸面上理解一种新方法相对容易,但实际使其工作通常是另一回事,而且这个故事很少公开分享。受此启发,我们决定分享我们在重现 Infini-attention 论文方面的实验和历程,是什么激励我们在调试过程中(我们花了 90% 的时间调试收敛问题),以及让这些事情工作起来有多么困难。
随着 Llama 3 8B 的发布(其上下文长度限制为 8k token),我们试图将其长度扩展到 100 万 token,而无需二次方增加内存。在这篇博客文章中,我们将首先解释 Infini-attention 的工作原理。然后,我们将概述我们的复现原则,并描述我们最初的小规模实验。我们将讨论我们面临的挑战,我们如何解决这些挑战,并总结我们的发现和我们探索过的其他想法。如果您有兴趣测试我们训练好的检查点[链接],您可以在以下存储库中找到它[链接](请注意,我们目前按原样提供代码)。
第 1 节:复现原则
我们发现在实现新方法时,以下规则很有帮助,并将其作为我们许多工作的指导原则:
- 原则 1: 从能提供良好信号的最小模型尺寸开始,一旦获得良好信号,再扩大实验规模。
- 原则 2: 始终训练一个可靠的基线模型来衡量进展。
- 原则 3: 为了确定修改是否能提高性能,训练两个模型,除了被测试的修改之外,其他设置均相同。
牢记这些原则,让我们深入了解 Infini-attention 的实际工作原理。理解其机制对我们后续实验至关重要。
第 2 节:Infini-attention 的工作原理
步骤 1:将输入序列分割成更小的、固定大小的块,称为“片段”。
步骤 2:在每个片段内计算标准的因果点积注意力。
步骤 3:使用当前片段的查询向量从压缩内存中提取相关信息。检索过程的数学定义如下:
- :从内存中检索到的内容,表示长期上下文。
- :查询矩阵,其中 是查询数量, 是每个查询的维度。
- :来自前一个片段的内存矩阵,存储键值对。
- :一个非线性激活函数,具体为逐元素的指数线性单元(ELU)加 1。
- :一个归一化项。
import torch.nn.functional as F
from torch import einsum
from einops import rearrange
def _retrieve_from_memory(query_states, prev_memory, prev_normalization):
...
sigma_query_states = F.elu(query_states) + 1
retrieved_memory = einsum(
sigma_query_states,
prev_memory,
"batch_size n_heads seq_len d_k, batch_size n_heads d_k d_v -> batch_size n_heads seq_len d_v",
)
denominator = einsum(
sigma_query_states,
prev_normalization,
"batch_size n_heads seq_len d_head, batch_size n_heads d_head -> batch_size n_heads seq_len",
)
denominator = rearrange(
denominator,
"batch_size n_heads seq_len -> batch_size n_heads seq_len 1",
)
# NOTE: because normalization is the sum of all the keys, so each word should have the same normalization
retrieved_memory = retrieved_memory / denominator
return retrieved_memory
步骤 4:将局部上下文(来自当前片段)与长期上下文(从压缩内存中检索)结合,生成最终输出。通过这种方式,注意力输出可以同时考虑短期和长期上下文。
- :合并后的注意力输出。
- :一个可学习的标量参数,用于控制长期记忆内容 与局部上下文之间的权衡。
- :使用点积注意力从当前片段获得的注意力输出。
步骤 5:通过添加当前片段的键值状态来更新压缩内存,从而允许我们随着时间的推移累积上下文。
- :当前片段的更新内存矩阵,包含了新信息。
- :当前片段的键矩阵,表示要存储的新键。
- :当前片段的值矩阵,表示与键关联的新值。
- :键矩阵中的第 个键向量。
- :当前片段的更新归一化项。
import torch
def _update_memory(prev_memory, prev_normalization, key_states, value_states):
...
sigma_key_states = F.elu(key_states) + 1
if prev_memory is None or prev_normalization is None:
new_value_states = value_states
else:
numerator = einsum(
sigma_key_states,
prev_memory,
"batch_size n_heads seq_len d_k, batch_size n_heads d_k d_v -> batch_size n_heads seq_len d_v",
)
denominator = einsum(
sigma_key_states,
prev_normalization,
"batch_size n_heads seq_len d_k, batch_size n_heads d_k -> batch_size n_heads seq_len",
)
denominator = rearrange(
denominator,
"batch_size n_heads seq_len -> batch_size n_heads seq_len 1",
)
prev_v = numerator / denominator
new_value_states = value_states - prev_v
memory = torch.matmul(sigma_key_states.transpose(-2, -1), new_value_states)
normalization = reduce(
sigma_key_states,
"batch_size n_heads seq_len d_head -> batch_size n_heads d_head",
reduction="sum",
...
)
memory += prev_memory if prev_memory is not None else 0
normalization += prev_normalization if prev_normalization is not None else 0
return memory, normalization
- 步骤 6:当我们从一个片段移动到下一个片段时,我们丢弃前一个片段的注意力状态,并将更新后的压缩内存传递给下一个片段。
def forward(...):
...
outputs = []
global_weights = F.sigmoid(self.balance_factors)
...
local_weights = 1 - global_weights
memory = None
normalization = None
for segment_hidden_state, segment_sequence_mask in zip(segment_hidden_states, segment_sequence_masks):
attn_outputs = self.forward_with_hidden_states(
hidden_states=segment_hidden_state, sequence_mask=segment_sequence_mask, return_qkv_states=True
)
local_attn_outputs = attn_outputs["attention_output"]
query_states, key_states, value_states = attn_outputs["qkv_states_without_pe"]
q_bs = query_states.shape[0]
q_length = query_states.shape[2]
...
retrieved_memory = _retrieve_from_memory(
query_states, prev_memory=memory, prev_normalization=normalization
)
attention_output = global_weights * retrieved_memory + local_weights * local_attn_outputs
...
output = o_proj(attention_output)
memory, normalization = _update_memory(memory, normalization, key_states, value_states)
outputs.append(output)
outputs = torch.cat(outputs, dim=1) # concat along sequence dimension
...
既然我们已经掌握了理论,是时候卷起袖子,进行一些实际的实验了。让我们从小处着手,以便快速获得反馈并快速迭代。
第 3 节:小规模初步实验
Llama 3 8B 相当大,因此我们决定从一个 200M 的 Llama 模型开始,使用 Nanotron [链接] 和 Fineweb 数据集 [链接] 从头开始预训练 Infini-attention。一旦我们获得了 200M 模型的良好结果,我们便开始对 Llama 3 8B 进行持续预训练。我们使用 200 万 token 的批量大小,256 的上下文长度,梯度裁剪为 1,权重衰减为 0.1,前 5,000 次迭代为线性预热,其余步骤为余弦衰减,学习率为 3e-5。
使用通行码检索任务进行评估
通行码检索任务最初由 EPFL 的研究人员引入[链接]。它旨在评估模型从长上下文(信息位置可控)中检索信息的能力。提示模型的输入格式结构如下:
有重要信息隐藏在大量无关文本中。找到并记住它们。我将考你关于那里的重要信息。草是绿色的。天空是蓝色的。太阳是黄色的。我们开始吧。来来回回。(重复 x 次)通行码是 9054。记住它。9054 是通行码。草是绿色的。天空是蓝色的。太阳是黄色的。我们开始吧。来来回回。(重复 y 次)通行码是什么?通行码是
如果模型的输出包含“针”(在上述情况下为“9054”),则我们认为模型在此任务中成功;如果模型输出不包含,则认为不成功。在我们的实验中,我们将针放置在上下文中的不同位置,具体是总上下文长度的 0%、5%、10%、...、95% 和 100%(0% 是离生成 token 最远的位置)。例如,如果上下文长度为 1024 个 token,将针放置在 10% 意味着它位于大约第 102 个 token 的位置。在每个深度位置,我们用 10 个不同的样本测试模型,并计算平均成功率。
初步结果
以下是一些 200M 小模型的初步结果:
如你所见,它在一定程度上起作用。如果你查看样本生成,你会发现 Infini-attention 生成的内容与早期片段相关。
由于 Infini-attention 通过以第一个片段的全部内容为条件来预测第二个片段的第一个 token(它将第一个 token 生成为“_grad”),这提供了一个良好的信号。为了验证该信号是否是假阳性,我们假设 Infini-attention 生成与其早期片段相关的内容,因为当给定“_grad”作为第二个片段的第一个生成 token 时,它始终生成与 PyTorch 相关的教程,而这些教程恰好与它之前的片段相关。因此,我们进行了一个健全性测试,其中唯一的输入 token 是“_grad”,它生成了[文本在此]。这表明它确实使用了内存,但使用得不够好(无法检索精确的针或继续其早期片段的精确内容)。生成结果如下:
_graduate_education.html
Graduate Education
The Department of Physics and Astronomy offers a program leading to the Master of Science degree in physics. The program is designed to provide students with a broad background in
根据这些结果,模型似乎确实使用了压缩内存。我们决定通过持续预训练 Llama 3 8B 来扩大实验规模。不幸的是,当针被放置在较早的片段中时,模型未能通过针评估。
我们决定检查所有层中的平衡因子(平衡压缩内存和未压缩内存的因子)。根据图 3a 和图 3b,我们发现大约 95% 的权重集中在 0.5 左右。回想一下,权重是否能收敛到理想范围取决于两个一般因素:步长和梯度的幅度。然而,Adam 将梯度归一化为 1 的幅度,所以问题变成了:训练超参数是否正确,以便微调能够收敛?
第 4 节:研究收敛性?
我们决定模拟在梯度处于良好范围(L2 范数为 0.01)时平衡权重在训练期间会改变多少,发现根据上一次 8B LLaMA3 微调实验的配置,权重的绝对总变化将为 0.03。由于我们将平衡因子初始化为 0(在这种情况下无关紧要),因此最终权重将在 [0 - 0.03, 0 + 0.03] = [-0.03, 0.03] 范围内。
对于 Infini-attention 的良好工作,一个合理的猜测是全局权重如论文中所示在 0 和 1 范围内分散。鉴于上述权重,sigmoid([-0.03, 0.03]) = tensor([0.4992, 0.5008])(这与我们之前的实验结果相符,即平衡因子约为 0.5)。我们决定下一步对平衡因子使用更高的学习率(所有其他参数使用 Llama 3 8B 的学习率),并增加训练步数,以允许平衡因子至少改变 4,这样我们可以让全局权重在梯度下降需要时达到理想权重(sigmoid(-4) ≈ 0, sigmoid(4) ≈ 1)。
我们还注意到,由于梯度并不总是朝同一个方向,因此会出现抵消。这意味着我们应该将学习率和训练步数设定得显著大于总绝对变化。回想一下,Llama 3 8B 的学习率为 3.0x10^-4,这意味着如果我们将其用作全局学习率,门控功能将无法收敛。
结论:我们决定采用 3.0x10^-4 的全局学习率和 0.01 的门控学习率,这应该能使门控函数收敛。
在这些超参数下,Infini-attention 中的平衡因子是可训练的,但我们观察到 200M llama 的损失在 20B token 后变为 NaN(我们尝试了从 0.001 到 1.0e-6 的学习率)。我们检查了 20B token 检查点(10k 训练步)的一些生成结果,您可以在图 4a 中看到。模型现在继续生成精确的内容并召回身份(如果内存被清除,它会生成垃圾)。
但它仍然无法从一个片段中召回另一个片段中的“针”(它在片段内可靠地完成)。当“针”放置在第一个片段中时,针评估完全失败(当放置在第二个片段中时,总共两个片段,成功率为 100%)。如图 4b 所示,我们还观察到平衡因子在 5,000 步后停止变化。虽然我们取得了一些进展,但我们尚未完全摆脱困境。平衡因子仍然没有按我们希望的方式表现。我们决定深入挖掘并进行更多调整。
第 5 节:平衡因子无权重衰减
再次仔细检查平衡因子,我们看到了一些进展:大约 95% 的头部现在显示全局权重在 0.4 到 0.5 之间,并且没有一个头部的全局权重大于 0.6。但是权重仍然不在理想范围内。
我们想到了另一个潜在原因:权重衰减,它会促使平衡因子 L2 范数较小,导致 sigmoid 值收敛到接近零,并且因子集中在 0.5 附近。
另一个潜在原因是我们的“rollout”设置太小。在 200M 实验中,我们只使用了 4 个“rollout”,而在 8B 实验中,我们只使用了 2 个“rollout”(8192**2)。使用更大的“rollout”应该会促使模型更好地压缩和使用内存。因此,我们决定将“rollout”数量增加到 16,并且不使用权重衰减。我们将上下文长度缩小到 1024,并使用 16 个“rollout”,从而获得 64 的片段长度。
如您所见,全局权重现在分布在 0 到 1 的范围内,其中 10% 的头部全局权重在 0.9 到 1.0 之间,尽管在 18k 步之后,大多数头部停止了其全局权重的变化。我们现在非常有信心,如果梯度下降的“精神”与我们同在,实验设置将允许收敛。唯一剩下的问题是 Infini-attention 的总体方法是否能够很好地工作。
以下评估在 1.5B token 下运行。
- 0-短:在提示 2 中,它回忆了一个人学习的地方(昨天的 8b 模型在这方面失败了),但在针通行码方面失败了(尚未全面运行;将运行)。
- 1-短
- 提示 3:它识别了一个人的位置。
- 提示 4:它通过了通行码检测。
在这种情况下,模型会继续生成与早期片段完全相同的内容。(在我们之前的实验中,模型未能继续生成早期片段的精确内容,而只是生成了大致相关的内容;因此,新模型已经好很多了。)
第 6 节:结论
不幸的是,尽管取得了这些进展,我们发现在我们的实验中 Infini-attention 并没有足够的说服力,尤其是不够可靠。在我们的复现阶段,我们仍然认为环注意力(Ring Attention)[链接]、YaRN [链接] 和 rope scaling [链接] 是将预训练模型扩展到更长上下文长度的更好选择。
对于超大型模型(例如 400B 及以上),这些技术仍然需要大量的资源。因此,我们仍然认为探索压缩技术或继续推进我们在本博客文章中描述的系列实验对社区来说具有巨大的兴趣,我们很高兴能关注并尝试可能开发出来的新技术,以克服当前工作的一些限制。
总结
- 训练神经网络的意义:提供好的数据,设置好的架构和训练以接收良好的梯度信号,并使其收敛。
- Infini-attention 的长上下文性能随着内存压缩次数的增加而下降。
- 门控很重要;调整训练以使门控收敛可以提高 Infini-attention 的长上下文性能(但还不够好)。
- 始终训练一个好的参考模型作为基线来衡量进展。
- 还有一个错误会搞乱注意力输出的维度,导致即使训练过程中损失不断下降,模型仍然无法在其片段长度内生成连贯的文本。得到的教训是:即使你对模型进行了糟糕的条件设置,梯度下降仍然可以找到降低损失的方法。然而,模型不会按预期工作,所以务必进行评估。
致谢
感谢 Leandro von Werra 和 Thomas Wolf 在项目中的指导,以及 Tsendsuren Munkhdalai 分享了原始实验的更多细节。我们还要感谢 Leandro 对博客文章的反馈,并感谢 Hugging Face 的科学集群提供的计算资源。