🕳️ LLM 中的注意力池(Attention Sinks)实现无限流畅性

社区文章 发布于 2023 年 10 月 9 日

摘要

使用带有注意力池(attention sink)词元的窗口注意力机制,可以让预训练的聊天式大语言模型(LLM),如所有 Llama、Mistral、MPT、Falcon 和 GPT-NeoX (Pythia) 模型,在数百个连续的提示中保持流畅性,这与使用 transformers 加载这些模型时的情况不同。此外,这种方法可以实现恒定的内存使用,而大多数用 transformers 加载的 LLM 具有线性空间复杂度,会导致内存问题。

使用这种形式的注意力非常简单,只需从 attention_sinks 而不是 transformers 导入你的模型类即可。

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device_map="auto")

目录

聊天助手 LLM 的局限性

大型语言模型(LLM)席卷了整个行业,推动了聊天机器人和虚拟助手领域的发展。LLM 似乎特别擅长扮演(专业化的)个人助理角色,但它们也存在各种局限性。在这篇博文中,我们将重点关注以下两个主要限制:

  • 显存使用:许多 LLM(例如 Llama 2)在推理时存在线性空间复杂度的问题。在聊天助手场景中,这意味着设备的显存限制将制约用户持续进行顺序提示的能力。

  • 流畅性丧失:迄今为止训练的所有 LLM 在输入过长时都会丧失流畅性。当这种情况发生时,模型将失去生成语言的能力,并开始生成例如无尽的换行符、任意字符(0OgOATO0OATO)、损坏的 Unicode(���)或重复的单词(assistant: assistant: assistant: assistant:)。

    大多数 LLM 在输入长度超过其预训练长度后会出现这种行为。例如,Llama 2 7B 在超过 4096 个词元后会遇到此问题,而 Mistral-7B-v0.1 在大约 1 万个词元后会失去流畅性。

这些局限性在实践中很容易展现出来,例如让 LLM 根据之前的所有词元预测一本书的下一个词元。这种预测的平均负对数似然损失称为对数困惑度 (log perplexity),它是衡量 LLM 质量的常用指标。较低的对数困惑度对应较低的平均损失,因此越低越好。显存也很容易测量,这两个指标都在下图中绘制出来:

有关生成这些图表的脚本的更多信息,请参见困惑度

这些局限性严重阻碍了在生产环境中使用 LLM 作为聊天助手的能力。

窗口注意力

应对问题 1(显存使用)的一个简单尝试是限制输入到 LLM 的词元数量。在 transformers 的基础上构建这个功能是一个相当复杂的过程,但其要点是,每当生成一个词元时,如果当前大小超过窗口大小,past_key_values 缓存就会被缩小到窗口大小。

在我的实验中,我使用的窗口大小仅为 1024 个词元。结果如下图所示:

窗口注意力确实在生成 1024 个词元后保持了内存使用的恒定,但一旦超过这个窗口大小,对数困惑度立即飙升。这使得它和使用 transformers 加载模型一样,都是不可行的方法。

注意力池

Xiao 等人(2023)注意到,当应用窗口注意力时,即使在第一个词元被从窗口中移除后,模型也会立即失去流畅性。他们注意到自回归 LLM 的一个有趣现象:最初的几个词元占据了惊人比例的注意力分数,即使这些词元在语义上并不重要。

这种行为在下图中可视化呈现:image/png

除了最初的两层,几乎所有的注意力都集中在开头的几个词元上,作者称之为**注意力池 (attention sinks)**。直观的解释是,如果下一个要生成的词元与之前的任何词元都不匹配,Softmax 操作仍然会强制注意力分数总和为 1。因此,LLM 学会将注意力分数“卸载”到开头的几个词元上。

因此,当窗口注意力机制导致第一个词元掉出窗口时,LLM 就无法再将注意力分数卸载到该词元上。结果,注意力分数被分散到所有其他词元上,总和仍然为 1。这导致即使与待生成词元匹配度不高的词元也会意外地获得高注意力分数。其后果是:模型“崩溃”并失去流畅性。

在发现这一现象后,作者提出了一种对窗口注意力的改进,即**始终**保留序列最初的 4 个词元,也就是注意力池词元。这可以像这样可视化:image/png

此外,在向缓存词元添加位置信息时,该方法使用缓存内的位置,而不是真实文本中的位置。因此,注意力池词元总是靠近其余的词元,从而可以有效地用于卸载注意力。

举一个简单的例子,我们考虑一个窗口大小为 10 的场景,其中包括 4 个注意力池词元,文本是一个用空格分隔的字母表。生成时,模型看到的是:

A
A B
A B C
A B C D
A B C D E
A B C D E F
A B C D E F G
A B C D E F G H 
A B C D E F G H I
A B C D E F G H I J
A B C D F G H I J K
A B C D G H I J K L
A B C D H I J K L M
...

分配的位置如下:

0
0 1
0 1 2
0 1 2 3
0 1 2 3 4
0 1 2 3 4 5
0 1 2 3 4 5 6
0 1 2 3 4 5 6 7
0 1 2 3 4 5 6 7 8
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
0 1 2 3 4 5 6 7 8 9
...

简而言之,分配的位置仅取决于缓存中的位置,而不是完整文本中的位置。

注意力池 - 困惑度实验

在我使用注意力池的实验中,我调整了我的窗口注意力实现,使其包含 4 个永不离开窗口的注意力池词元,并将窗口大小保持在 1024。结果如下图所示:

Falcon-7B、MPT-7B 和 Pythia-6.9B 的结果请参见此处

结果令人瞩目:使用带有注意力池的窗口注意力的 LLM 兼具两者的优点:恒定的空间复杂度和稳定的困惑度。Xiao 等人(2023)的研究表明,困惑度在高达 400 万个词元的情况下都保持稳定,之后他们的数据就用完了(!)。

请注意,在约 8000 个词元时,注意力池方法的对数困惑度略高于(即更差)基线。这是因为注意力池仅使用 1024 个词元的窗口大小。这个窗口大小可以增加到例如 8192 个词元,这样对数困惑度将在 8000 个词元时与基线持平,*并且*将内存保持在约 14.85GB 的恒定水平。

注意力池 - 无尽生成实验

批评者认为,困惑度是衡量 LLM 质量的一个不完美指标,例如因为它并不实际要求模型生成词元。为了证明注意力池确实有效,我使用 Llama-2-7B,采用本博文中描述的三种方法生成多达 10,000 个词元:默认(例如 transformers)、windowedattention_sinks

如果一个模型开始失去流畅性,我就会终止生成。我已将每种方法的完整日志上传到我的仓库。

  • transformers完整日志:该模型在约 1900 个词元后失去流畅性,并开始无休止地生成损坏的 Unicode 字符,如 🤖🧠👨‍��������������������� ❌。
  • 窗口注意力 (window attention):完整日志:该模型在约 1000 个词元后失去流畅性,生成了数百个换行符,并夹杂着诸如 OOOMMO̶OANOOAMOO̶OMMO ❌ 之类的文本。
  • attention_sinks 完整日志:在测试的全部 1 万个词元中都保持流畅 ✅。

有关复现这些结果的脚本的更多信息,请参见无尽生成中的流畅性

注意力池 - 聊天助手实验

注意力池方法非常适合聊天式 LLM 应用,因为它比仅使用 transformers 加载模型时要流畅得多,并且占用的内存也少得多。因此,一个自然的基准测试是在常见的聊天助手场景中对各种方法进行实验。

在这个基准测试中,我将 MT-Bench 中的连续提示发送给模型,并自动检测流畅性何时丧失。这模拟了一个场景:聊天助手在同一历史记录中被提示数百次,在此期间模型必须处理数万个词元的历史记录。

如果一个响应满足以下条件,我会自动将其归类为失败:

  • 包含少于 26 个不同字符,并且
  • 长度超过 1000 个词元。

在实践中,这种启发式方法似乎能准确检测流畅性的丧失。我已将结果绘制在下图中:

有关复现这些结果的脚本的更多信息,请参见聊天式 LLM 在连续提示下的流畅性

对于 Llama-2-7b-chat,transformers 会耗尽显存,因此它只能处理少数几个连续的提示。对于 MPT-7B-chat,当输入长度超过 2048 时,transformers 会遇到一个 RuntimeError。这些图表清楚地表明,使用 attention_sinks 加载模型对模型在连续提示下的流畅性有非常积极的影响。然而,正如 Llama-2-7B-chat-hf 的情况所示,它并不能完全避免所有流畅性问题。

注意力池 - 基准测试结论

本博文中描述的基准测试,以及我在我的 attention_sinks 仓库中描述的针对 MPT、Pythia 和 Falcon 模型的附加基准测试,都清楚地表明,注意力池可以有效地应用于预训练的 LLM,以应对模型的不稳定性和流畅性丧失。这种额外的稳定性不带来任何额外成本,甚至可以实现恒定的内存使用,而不是大多数 LLM 的线性内存使用。

任何希望使用助手式 LLM 的组织或用户都应该考虑注意力池技术。

注意力池的实际应用

在最先进的研究和从业者可以合理使用的技术之间通常存在显著的差距。然而,我很高兴地说,注意力池可以以几乎零额外工作量添加到任何预训练的 LLM 中。

我发布了 attention_sinks Python 模块,它可以作为 transformers API 的直接替代品。这个 Python 模块支持所有使用 Llama、Mistral、Falcon、MPT 和 GPT-NeoX (Pythia) 架构的模型,可以像这样使用:

from attention_sinks import AutoModel

model = AutoModel.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", device_map="auto")

这将自动为模型添加一个注意力池 KV 缓存 (Attention Sink KV Cache),它能正确地将注意力池保留在窗口中。您可以使用以下参数配置此缓存:

  • attention_sink_sizeint,默认为 4:用作注意力池的初始词元数量。这些词元总是包含在注意力池 KV 缓存中。
  • attention_sink_window_sizeint,默认为 1020:滑动窗口的大小,即包含在注意力池 KV 缓存中的“最近词元”数量。较大的窗口大小会消耗更多内存。不建议将其设置得比 LLM 的上下文窗口大,因为 LLM 仍然只能处理最后的 context window 个词元。

总窗口大小将是这两个参数的总和,例如默认为 1024。

例如,加载具有更大窗口大小的 Llama-2-7B-chat 可以这样做:

from attention_sinks import AutoModel

model = AutoModel.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    device_map="auto",
    attention_sink_size=4,
    attention_sink_window_size=4092,
)

请参阅流式演示,这是一个可以执行的脚本,用于模拟向您选择的 LLM 输入数百个连续提示。(注意,您可能需要更改聊天模板)。

常见问题解答

本 FAQ 主要由 Xiao 等人(2023)撰写。

  1. 对于 LLM 来说,“处理无限长度输入”意味着什么?

    用 LLM 处理无限长度的文本带来了挑战。值得注意的是,存储所有先前的键 (Key) 和值 (Value) 状态需要大量内存,并且模型可能难以生成超出其训练序列长度的文本。注意力池模型通过仅保留最近的词元和注意力池,丢弃中间的词元来解决这个问题。这使得模型能够从最近的词元生成连贯的文本,而无需重置缓存——这是早期方法所不具备的能力。

  2. LLM 的上下文窗口是否被扩展了?

    没有。上下文窗口保持不变。只有最近的词元和注意力池被保留,中间的词元被丢弃。这意味着模型只能处理最新的词元。上下文窗口仍然受其初始预训练的限制。例如,如果 Llama-2 是用 4096 个词元的上下文窗口进行预训练的,那么基于 Llama-2 的注意力池模型的最大缓存大小仍然是 4096。

  3. 我能将一篇长文,比如一本书,输入到注意力池模型中进行摘要吗?

    虽然您可以输入一篇长文,但模型只会识别最新的词元。因此,如果输入的是一本书,注意力池模型可能只会总结结尾的段落,这可能没有太大意义。如前所述,我们既没有扩展 LLM 的上下文窗口,也没有增强它们的长期记忆。注意力池模型的优势在于能够从最近的词元生成流畅的文本,而无需刷新缓存。

  4. 注意力池模型的理想用例是什么?

    注意力池模型针对流式应用进行了优化,例如多轮对话。它非常适合模型需要持续运行,而不需要大量内存或依赖过去数据的场景。一个例子是基于 LLM 的日常助手。注意力池模型可以让模型持续运行,根据最近的对话做出回应,而无需刷新其缓存。早期的方法要么在对话长度超过训练长度时需要重置缓存(丢失最近的上下文),要么需要从最近的文本历史中重新计算 KV 状态,这可能非常耗时。

  5. 注意力池方法与最近关于上下文扩展的研究有什么关系?

    注意力池方法与最近的上下文扩展方法是正交的,并且可以与它们集成。在注意力池模型的背景下,“上下文扩展”指的是使用更大的缓存大小来存储更多最近词元的可能性。有关实际演示,请参阅论文中的图 9,其中 LongChat-7B-v1.5-32K 和 Llama-2-7B-32K-Instruct 均已适配注意力池技术。

了解更多

查看以下资源以获取有关此主题的更多信息:

引用

@article{xiao2023streamingllm,
    title={Efficient Streaming Language Models with Attention Sinks},
    author={Xiao, Guangxuan and Tian, Yuandong and Chen, Beidi and Han, Song and Lewis, Mike},
    journal={arXiv},
    year={2023}
}

社区

注册登录以发表评论