MInference 1.0:用单张 GPU 实现快 10 倍的百万上下文推理

社区文章 发布于 2024年7月11日

本文介绍 MInference 1.0:一种基于动态稀疏注意力的预填充加速方法。该方法为每个注意力头寻找最佳稀疏模式,根据输入提示动态构建稀疏索引,最后对长上下文注意力进行稀疏计算。

在单张 A100 GPU 上,MInference 1.0 可实现高达 10 倍的预填充加速,同时在众多任务上保持与全注意力相同甚至更优的准确性

MInference 1.0

长上下文 LLM 的推理瓶颈

由于多头注意力操作的二次复杂度,它是一种极其昂贵的 token 混合方法。

在 BERT 和 GPT-2 时代,由于输入序列相对较短,注意力引起的延迟通常可以接受。然而,在当前 LLM 普遍追求长文本能力的情况下,这种延迟正成为主要的瓶颈之一。

*Test hardware: single A100-80G GPU*

测试硬件:单张 A100-80G GPU

如上图 (a) 所示,当输入提示长度超过 500K 时,推理的预填充阶段(又称首个 token 时间,TTFT)耗时近 10 分钟,其中大部分时间花在注意力上。当提示长度继续增长到 1M 时,此等待时间将达到 30 分钟。

同时,我们还发现注意力操作在长上下文场景中效率极低。如图 (b) 所示,当提示长度为 128K 时,如果在注意力中只使用 Top-4K 列进行计算,可以召回超过 96% 的全局注意力分数。

这表明注意力矩阵在长文本中极其稀疏。换句话说,二次复杂度的注意力操作在接近零的元素上花费了大量不必要的计算能力,导致整个预填充阶段效率低下,这个问题在长上下文场景中尤为突出。

注意力矩阵的稀疏性已在许多工作中详细分析。如下图所示,这一特性也促成了许多高效相关的工作。

image/png

然而,这些稀疏注意力方法通常依赖于固定模式——例如,滑动窗口总是关注局部 token,而 StreamingLLM 类似的模式总是关注局部 + 初始 token。这种固定模式与注意力操作的另一个特性——动态性相矛盾。

image/png

再次看这张图,(b) 显示只计算 top-4k 列可以覆盖大部分注意力分数;但如果我们将此过程中的 top-4K 索引用于另一个提示,我们得到 (c) 中的结果——其覆盖注意力分数的能力显著下降到 83%,甚至在某些层中更低。

这表明注意力的稀疏分布是高度依赖于输入的,其稀疏分布随不同输入而变化很大。这使得固定模式难以在不同场景和输入中实现良好的泛化。

总而言之,理想的高效长上下文注意力需要同时考虑注意力的稀疏性和动态性:根据注意力输入动态估计与输入相关的稀疏掩码,然后完成注意力的稀疏计算。

MInference 1.0

MInference 1.0 旨在构建可以在 GPU 上高速计算的动态稀疏注意力。

总体而言,MInference 分析了长上下文 LLM 中注意力的稀疏分布,并提出了三种易于在 GPU 上加速的稀疏注意力模式,并为每种模式实现了高效的 GPU 内核。

*A-shape sparse, vertical-slash (v-s) sparse, and block sparse*

A 形稀疏、垂直斜线 (v-s) 稀疏和块稀疏

我们发现,在长上下文 LLM 的多头注意力矩阵中,存在明显的空间聚类,并且大多数属于以下三种模式之一。

例如,一些注意力头显示出A 形稀疏模式:注意力主要集中在局部和初始 token 上(上图 1,即 StreamingLLM 的稀疏模式)。此外,还有一种强调特殊 token 的垂直模式,与强调相对位置的对角线模式配对(垂直斜线,上图 2)。最后,还有一种块稀疏模式,它以块状聚类(上图 3)。

提出这三种模式有两个主要好处

  1. 三者的组合几乎可以覆盖所有稀疏注意力分布。
  2. 这些模式的特性允许它们使用 Flash Attention 和 PIT 方法在 GPU 上高效计算。

Comparing the three sparse patterns of MInference against the Top-K sparse attntion

MInference 的三种稀疏模式与 Top-K 稀疏注意力的比较

问题:为什么这些模式可以高效计算,而 Top-K 稀疏注意力却不能?

因为 MInference 的三种模式显示出清晰的空间聚类,GPU 内核可以使用 64×64(A 形和块稀疏头)或 64×1(v-s 头)的块高效完成稀疏操作。

另一方面,Top-K 稀疏注意力(即,只计算每个 Q 的 Top-k K 向量,如右图所示)由于其过于细粒度的稀疏分布,需要很长时间来构建稀疏索引,并且在使用块在 GPU 上进行计算时会产生大量无效操作。

MInference 指定三种稀疏模式后,将完成以下三个步骤

  1. 在给定 FLOPs 预算下,为每个注意力头找到最佳模式。
  2. 动态计算每个输入的最佳稀疏分布(例如,垂直斜线模式中垂直线的位置,或块稀疏模式中块的索引)。
  3. 根据步骤 2 中获得的稀疏索引执行注意力的稀疏计算。

为每个注意力头寻找最佳模式

search space

搜索空间

对于每个注意力头,MInference 将搜索最佳稀疏模式,即在总 FLOPs 限制下,找出哪种模式可以在该注意力头上召回尽可能多的注意力分数。

为此,MInference 提出了内核感知稀疏模式搜索。该算法基于 GPU 内核的实际计算量(内核感知),而不是注意力的稀疏率,在给定 FLOPs 预算下搜索最佳模式,具体包括以下两个步骤:

  1. 确定注意力头属于 A 形、垂直斜线还是块稀疏?
  2. 确定注意力头的最佳稀疏率(即垂直斜线模式中垂直线和对角线的数量、块稀疏模式中块的数量等)。

根据我们的观察,注意力头的稀疏模式类型是与输入无关的,即显示垂直斜线模式的注意力头在不同输入下总是显示垂直斜线模式(注意:但是,垂直线和对角线的具体分布位置仍然与输入有关)。

因此,MInference 在完成内核感知稀疏模式搜索时实际上采用了离线方法。

动态稀疏索引构建

*vs pattern: using last_q as proxy; block pattern: using pooling for dimensionality reduction*

vs 模式:使用 last_q 作为代理;块模式:使用池化进行降维

在确认稀疏模式的类型和稀疏率后,MInference 将根据输入动态构建稀疏索引:对于垂直斜线模式,此步骤是确定注意力矩阵中最重要的垂直线和对角线的位置;对于块稀疏模式,此步骤是找到整个注意力矩阵中最相关的块

然而,由于稀疏索引的构建需要引入额外开销,因此在此步骤中我们需要非常高效的算法,以避免稀疏注意力带来的延迟降低浪费在索引构建步骤上。

我们使用两种高效的估计方法分别估计垂直斜线和块稀疏模式的注意力矩阵,以获取稀疏计算所需的稀疏索引:

  1. 对于垂直斜线模式,我们发现使用 Q 矩阵末尾位置的一些 q(即算法 2 中的 last_q)可以准确获得全局垂直线和对角线位置。
  2. 对于块稀疏模式,我们对 seq_len 维度应用大小为 64 的池化层,以降低 Q 和 K 矩阵的维度,然后计算降维后的𝑄̂𝐾̂𝑇,并找到幅度最大的块。

在复杂度方面,垂直斜线模式的索引构建(算法 2)的复杂度为 O(Nd)。至于算法 3,由于使用了大小为 64 的池化层进行降维,其复杂度仅为全注意力复杂度的 1/64*64。这表明两种算法都非常高效。

稀疏注意力计算

MInference 为所提出的三种稀疏注意力模式实现了相应的 GPU 内核:对于 A 形模式和块稀疏模式,内核使用 64×64 大小的块进行计算,而对于垂直斜线模式,内核使用 64×1 大小的块进行计算。

结果

MInference 1.0 在单张 A100 上处理 500K 长度的输入提示速度提高了 6.8 倍,当提示长度为 1M 时,MInference 可以实现 10 倍的预填充加速。

*All green Needle + 10x acceleration*

所有绿色 Needle 结果 + 10 倍加速

同时,与全注意力相比,MInference 的动态稀疏注意力在众多任务上实现了与全注意力相同甚至超越的准确性。

RULER

*Better accuracy than full attention, and effective context length surpassing full attention*

比全注意力更高的准确性,并且有效上下文长度超越全注意力

“大海捞针”

*MInference achieves the same or better performance as full attention on GLM-4, Yi, Phi-3, Qwen2 and other LLMs*

MInference 在 GLM-4、Yi、Phi-3、Qwen2 等 LLM 上实现了与全注意力相同或更好的性能

将 MInference 与 vLLM 和 HF 模型一起使用

在 HuggingFace 模型和 vLLM 模型上使用 MInference 1.0 只需三行代码

# For HuggingFace models
from transformers import AutoModel
+from minference import MInferenceConfig, apply_minference

model = AutoModel.from_pretrained("meta-llama/Llama-2-7b-hf")
+config = MInferenceConfig(sparsity_ratio=0.9)
+model = apply_minference(model, config)
# For vLLM models
from vllm import LLM
+from minference import MInferenceConfig, apply_minference

llm = LLM(model="meta-llama/Llama-2-7b-hf")
+config = MInferenceConfig(sparsity_ratio=0.9)
+llm.model = apply_minference(llm.model, config)

查看更多 MInference 信息

社区

注册登录 以发表评论