为 AMD MI300 创建自定义核函数

发布于 2025 年 7 月 9 日
在 GitHub 上更新

AMD 核函数

Title card

导言

每天超过十亿次:这是对 ChatGPT 每日处理请求数量的保守估计,而且这个数字短期内不太可能下降。对于每个请求和每个生成的词元 (token),我们都会对一个拥有数十亿参数的模型进行一次推理。这就是为什么模型优化在每个层面上都至关重要:当处理如此巨大的规模时,即使是 1% 的延迟或功耗提升也能带来巨大的成本节约。

但是,这些提升能从何而来?模型架构已经相当成熟,流行的模型也早已实现了权重化。然而,还有一个关键层面可以优化模型推理:核函数 (kernel) 层面。核函数是你在网络中执行任何操作时运行的算法:有矩阵乘法核函数、卷积核函数、批量归一化核函数等。核函数是低级的、高度优化的算法,通常是为它们将要运行的设备量身定制的。它们编写起来 notoriously 长且困难,并且需要对 GPU 的内部工作原理有很好的理解。

核函数对于在神经网络中运行操作至关重要——没有核函数,一个操作实际上就无法使用。因此,新的创新产品通常会推出一个“day 0”核函数,该核函数通常只为最新的 Nvidia 硬件优化。这种方法排除了许多其他设备,特别是 AMD GPU,尽管它们提供相当甚至更优的规格,却常常被核函数开发者忽视。Hugging Face 与 AMD 合作,在 AMD 平台上提供最先进的性能,并让开源社区受益。作为这次合作的一部分,我们与 AMD 决定专注于提供开源的优化核函数,以提升在 8 个 MI300X 节点上使用 VLLM 以 FP8 格式服务 Llama 3.1 405B 的性能。

在这篇博客文章中,我们将探讨我们如何为 MI300X 优化性能,以及每个核函数是如何被单独微调的。但首先,让我们看看使用我们的自定义核函数所实现的性能提升。通过结合以下三个优化核函数:

  • 融合残差连接、RMS 范数和 FP8 转换的核函数
  • 融合 SwiGLU 激活和 FP8 转换的核函数
  • Skinny GEMM 核函数

我们在由 MI300X GPU 驱动的节点上运行 VLLM 时,实现了显著的加速。

Latency gains

测量是在输入大小为 1、输出大小为 128 的情况下进行的,以模拟解码模式。我们使用 30 次迭代的中位数来测量解码延迟。

这些性能提升是在 VLLM 中测量的,但你也可以单独使用这些核函数,具体方法见下文的“如何使用”部分。

如何使用这些核函数

hf-rocm-kernels 仓库

前面描述的所有核函数都可以在 hf-rocm-kernels 仓库中找到,地址在这里。在该仓库中,你会找到如何安装该包的说明、每个核函数的源代码、它们各自的 Python 绑定、各种基准测试脚本和一个测试套件。使用基准测试脚本和 MI300X,你甚至可以复现这篇博客文章中的结果。为了确保 Torch 或 VLLM 的结果一致,你可以使用与我们相同的容器。你也可以将该仓库作为基础来构建自己的核函数:它包含了如何将一个 CUDA 风格的核函数绑定到 Python 的说明和一个简单的示例核函数。你甚至可以查看正在开发中的分支,以了解新的核函数,比如这里描述的计算与通信核函数。

在 VLLM 中的集成

所描述的核函数很快将被集成到 VLLM 项目的 AMD 分支中,但如果你想自己看看如何实现类似的功能,可以查看这个分支和这份文档

优化过程

我们首先将快速回顾一下我们正在使用的设备架构:MI300X。然后,我们将看看在优化之前模型推理的状态。这将帮助我们识别瓶颈,并确定需要编写哪些自定义核函数。接着,我们将逐一审视我们编写的每个核函数,这将为我们提供一个从多个角度探讨核函数优化如何进行的机会。

MI300X 简介

在我们深入优化 GPU 代码之前,我们需要了解 GPU 是如何工作的。已经有很多资源对 GPU 的内部工作原理做了很好的解释,我将链接放在这里这里这里。我们仍然会快速过一遍 GPU 的不同层次,作为一个简短的回顾。如果你想跳过回顾,直接进入我们自定义核函数的细节,请点击这里

线程 (Threads)

GPU 中最小的工作单元是线程 (thread)。GPU 上完成的任何工作都是因为一个线程执行了一条指令。指令是基本操作,如加法、乘法、从一种数据类型到另一种的转换,或加载和存储。每个线程都有自己的内存,称为寄存器 (registers,或 VGPRs),只有它自己可以访问。一个线程最多可以有 256 个寄存器,每个寄存器 32 位宽。下面是一个线程及其可访问的 256 个 VGPRs 的示意图。

Representation of a thread

除了使用加载或存储指令外,线程只能在自己的寄存器上执行指令。例如,要将两个向量 A 和 B 相加,每个线程将:1) 将 A 中的一个元素加载到其寄存器中,2) 将 B 中的另一个元素加载到寄存器中,然后 3) 执行加法并将结果存储在另一个寄存器中,最后 4) 将该寄存器中的值存储到内存中。总共是 4 条指令。

线程束 (Warps)

下一个工作单元是线程束 (warp):每个线程束由 64 个线程组成。线程束没有自己的内存,但它们对我们很重要,因为一个线程束中的所有线程必须在同一时间执行相同的指令。这既是一种保证,也是一种约束。

Representation of a warp

线程束还允许不同的线程与同一线程束中的其他线程交换来自其寄存器的信息。尽管一个线程束中的不同线程可以访问不同的数据,但它们都必须执行相同的指令,这意味着在编写核函数时,你需要考虑的是线程束级别的行为。

计算单元 (Compute units)

线程束被捆绑成线程块 (thread blocks):线程块是软件抽象,但运行在称为计算单元 (CU) 的硬件组件上。一个计算单元可以同时运行多个线程块,但它最多只能容纳 16 个线程束。每个计算单元都有一个专用的 L1 缓存和共享内存。L1 缓存无法控制或分配,它有助于位于该 CU 上的所有线程束的数据重用。相反,共享内存可以被分配和用作所有线程束共享的存储空间。例如,当我们希望一个计算单元中的所有线程束(因此也是所有线程)访问同一个缓冲区时,我们会在共享内存中分配它。共享内存和 L1 缓存的访问速度都很快,因为它们“靠近”线程。

Representation of a compute unit

线程块还提供了同步其内部所有正在运行的线程的能力:这在处理影响共享内存的操作时非常有用,比如在共享内存中将一个数组初始化为零或进行归约操作。总的来说,在编写核函数时,线程块是需要考虑的最高级别:很难同步不同的线程块或让它们以任何方式进行交互。核函数的吞吐量与 GPU 上存在的计算单元数量紧密相关:CU 越多,可以同时运行的线程块就越多,如果你能充分利用所有 CU,吞吐量就会增加。

XCD

计算单元随后被分组为加速器复合晶片 (XCD),每个 XCD 包含 38 个计算单元。尽管 CU 之间可能无法直接交互,但它们都共享一个 L2 缓存,你无法控制这个缓存,但在重用数据时它可能非常有用。例如,在访问内存时,让位于同一个 XCD 上的两个计算单元访问相同的数据将大大减少加载延迟。L2 缓存相当大:大小为 4MB,而共享内存大小为 64kB,L1 缓存包含 32kB。

Representation of a XCD

整个 GPU (MI300X)

通过组装 8 个 XCD(这给了我们 8 * 38 = 304 个 CU),并增加最后一级缓存(称为 infinity cache,大小为 256MB)和大量的显存(192GB),我们就得到了 MI300X。

Representation of a MI300

所有的 XCD,因此所有的线程,都可以访问显存 (VRAM),但到达那里的速度相当慢。随着你离线程级别越远,内存访问速度变得越慢,但其大小和作用范围也越大,意味着它服务于更多的线程。在优化核函数时,总需要在执行大量操作和加载大量数据之间取得平衡,但总的来说,你应该尽可能少地访问显存(通常称为全局内存)。

当看这张图时,我们可以理解为什么 GPU 被称为“大规模并行”:在这里,我们有 304 个计算单元,每个计算单元可以运行 16 个线程束,每个线程束有 64 个线程。这意味着我们最多可以同时运行 311,296 个线程,每个线程执行自己的指令。请记住,一条指令是像加法这样的基本操作,所以像牛顿法这样的简单例程对于单个线程来说可能运行时间很长。GPU 并非为指令快速运行而优化,即不是为了降低每条指令的延迟:那是延迟导向的设备。它们被优化为让许多线程一起运行,消耗和输出大量数据:它是一个吞吐量导向的设备。在为 GPU 优化核函数时,我们相应地进行调整:最好是让一个算法在许多线程上同时运行几条指令,而不是让它在少数线程上运行许多指令。因此,将在 GPU 上运行的算法称为“并行”的。

有三件事可能会阻碍此类算法以优化方式运行:当需要加载大量数据时(内存受限)、当需要执行许多操作时(计算受限)或当线程必须协同工作时(同步开销)。

Day 0 性能分析

在优化工作负载时,写下第一行代码之前要做的第一件事就是对当前工作负载的状态进行性能分析。在我们的案例中,我们将对 VLLM 中的模型推理进行性能分析,以了解每个操作占用的时间。这有助于识别主要瓶颈以及我们可以首先处理哪些核函数以获得最大加速。例如,以下是批量大小为 32 时的分解图:

Disk plot ok kernels latency

我们可以通过每个切片看到网络的不同部分:

  • “Attention*”切片,我们将 RoPE、注意力 (attention) 和 KV 缓存核函数分组在一起;
  • “Attention GEMMs”,包括两个投影,QKV 和 Output;
  • “Communications”,由两个 all-reduce 操作组成,一个在 Attention 块之后,一个在 MLP 块之后,它们的存在是因为我们正在进行张量并行(TP8)工作;
  • “MLP GEMMs”,包括在 MLP 中进行的两个投影,Gate / Up 和 Down;
  • “RMS norm”和“SwiGLU”切片,每个核函数一个——请注意,RMS norm 核函数每个块被调用两次,一次在 Attention 之前,一次在 MLP 之前;
  • “Other”切片,重新组合了我们没有标记为更大类别的核函数,因为它们的影响较小。

我们已经可以看到,大部分延迟来自 GEMM 和通信,但注意力及其周围的操作对延迟的贡献并不大。这可能有点令人惊讶,因为许多论文都关注于注意力并降低其成本,但似乎通过 KV 缓存和 FlashAttention 的结合(VLLM 中已经进行了优化),这部分可能不再是首要任务。令人惊讶的是,对“RMS norm”核函数的两次调用成本相当高,因此优化该核函数可能会带来很大的好处。连同 SwiGLU 核函数,它们占总延迟的 15%,这是不可忽视的。总而言之,我们最好的行动方案可能是致力于这两个核函数,并尝试在 GEMM 上获得 небольшое 加速。为了确认这种性能分解不是偶然现象,我们可以看看其他批量大小:

Latency distribution over batch sizes

我们可以看到,在批量大小为 32 时出现的模式在其他批量大小下也成立,尽管随着批量大小的增加,GEMM 和通信的延迟贡献变得更大。此外,批量大小为 32 在 GEMM 的延迟方面似乎是一个异常值:这可能是因为当批量大小为 32 时选择的 GEMM 经过了手动调整,或者因为批量大小为 32 呈现出良好的内存对齐模式,所以批量大小为 32 的 GEMM 比批量大小为 24 或 28 的更快。

现在我们已经确定了一些需要优化的热点,让我们来看看我们编写的第一个核函数:RMS norm 核函数。


RMS norm 核函数

在每个解码器块中,我们有两个主要部分:一个注意力块和一个 MLP 块。两者都以两个输入之间的残差连接开始:当前隐藏状态 x x 和残差 r r 。它们具有相同的形状,即 n n 行(与词元数量相同)和 d d 列。将它们相加后,我们对 x x 应用逐行的均方根 (RMS) 范数,并且由于模型采用 FP8,我们使用一个缩放因子 s s x x 量化为 FP8。仅仅将这三个操作融合到一个核函数中就可以带来不错的性能提升。在数学上,我们需要执行的操作如下:

i+j+kxx+rrxV=i=1dxi2xxV+ϵxQ=Qfp8(sxw) \begin{align} \phantom{i + j + k} &\begin{aligned} x &\leftarrow x + r\\ r &\leftarrow x \end{aligned}\\ &\begin{aligned} V &= \sum_{i=1}^{d} x_i^2 \end{aligned}\\ &\begin{aligned} x &\leftarrow \frac{x}{\sqrt{V + \epsilon}} \\ x_Q &= Q_{\text{fp8}} \left( s * x * w\right) \end{aligned} \end{align}

其中 w w 是一个大小为 d d 的权重向量。步骤 (1) 和 (3) 都非常基础。对于步骤 (1),我们只需将每个线程定位到张量中的不同位置,加载 x x r r 的一些元素,将它们相加并存回 r r 。对于步骤 (3),每个线程执行一些标量操作(加法、平方根、除法)和一次到 FP8 的转换。所有这些,每个线程都可以独立完成:这完全符合 GPU 的并行特性。需要注意的步骤是 (2):我们需要对 d d 进行求和,这意味着要么每个线程将访问 d d 列中的每一列,要么我们需要在线程之间交换数据。d d 越大,第一种方案需要加载的数据就越多,因此可行性越低。我们将选择第二种方案:在块级别同步线程,它们将使用共享内存交换数据。每个线程将独立累加 V V 的一部分,然后我们将在整个线程块中对所有这些部分求和,这就是我们所说的归约 (reduction)。由于 V V 是跨整行计算的,我们将为每一行分配一个线程块。

与开箱即用的 PyTorch 相比,这个核函数的最基本版本带来了大约 10 倍的加速。但这还不够:在此基础上我们还可以添加许多优化。

优化:内存相关

就延迟而言,成本最高的操作之一是访问显存,也称为全局内存。幸运的是,有一些易于遵循的原则可以显著降低加载数据的成本。

首先,我们可以看看单个线程在单个指令中能加载多少数据:使用 MI300X 指令指南,我们看到从全局内存进行的最大加载是 128 位宽。由于我们加载的是 FP16 数据,我们将每次加载 128b / 16b = 8 个元素。对于 FP32 元素,这将对应于每次加载 4 个元素。

其次,我们确保内存访问是合并的。由于每个线程都是线程束的一部分,当一个线程到达“加载”指令时,线程束中的所有其他线程也同时到达。为了提高效率,这些“加载”指令会在整个线程束中被捆绑在一起。然后,线程束集体获取所需的数据,每个线程得到它需要的数据。当线程束获取一个没有任何间隙的单个数据块时,就达到了最高效率:这就是我们所说的连续数据。当我们需要的加载数据量超过一次“加载”指令所能加载时,就会出现问题,如下图所示。

Two loading scenarios

在这个假设的场景中,我们在同一个线程束中有两个线程。它们需要共同加载 16 个 fp32 元素,对于哪个线程加载哪个元素没有限制。这是一个典型的“归约”情况。由于一个线程每个指令只能加载 4 个 fp32 元素,我们至少有两种读取数据的方式,如场景 (a) 和 (b) 所示。要决定哪个场景最好,我们需要从线程束的角度来看,而不是线程的角度。在场景 (a) 中,第一次加载获取元素 0,1,2,3,8,9,10,11:我们看到数据不是连续的,因为元素 3 和 8 之间有间隙。而在场景 (b) 中,第一次加载获取元素 0,1,2,3,4,5,6,7:我们加载了连续的数据。第二次加载也是如此。因此场景 (b) 更好。尽管在场景 (a) 中,每个线程最终得到 8 个连续的元素,但这并不重要:重要的是线程束是否加载了连续的数据。这很重要,因为如果线程束在一个周期内只能加载 8 个连续元素,那么场景 (a) 的每次加载都需要两个周期来处理,而在场景 (b) 中,每次加载只需要一个周期。

第三,我们减少存储次数:当我们看步骤 (1) 和 (3) 时,可以看到只需要两次存储:一次是 r r ,一次是 xQ x_Q 。在步骤 (1) 之后,我们已经可以存储 r r 并完成该操作。但我们仍然需要在步骤 (2) 完成后访问 x x 的修改版本。为此,我们可以将 x x 的修改版本存储在全局内存中,并在步骤 (2) 完成后重新加载它,并依赖于重新加载时的缓存命中。或者,如果 x x 足够小,我们可以将其修改版本存储在共享内存中:如果 x x 是 FP16 格式,并且每个 CU 只有一个线程块,那么我们每个线程块可以在共享内存中存储 64KB / 2B = 32 * 1024 个元素。在 Llama 405B 的情况下,d d 等于 16384,所以这能放得下。使用共享内存比依赖缓存命中提供了更好的加速,特别是当许多线程块同时活动时:如果 L1 缓存不够大,无法容纳整个 x x ,那么我们必须依赖 L2 缓存,而 L2 缓存是由 38 个 CU 共享的。

除了内存访问,我们还可以优化计算效率,但我们将把这部分留到下一个核函数,因为两种情况下的优化是相似的。

结果

当我们应用上述优化后,我们得到以下结果:

Latency of RMS norm kernels

行数 Torch (μs) VLLM (μs) 我们的 (μs)
1 38.8998 5.5145 4.18138
2 43.2469 5.65645 4.36976
4 41.1304 5.6893 4.37628
8 43.8883 5.72275 4.39081
16 46.8876 5.85667 4.48165
32 55.2276 6.08502 4.72017
64 75.6086 6.4629 5.54214
128 98.1122 7.49166 6.27341
256 119.727 11.8812 10.739
512 195.782 23.1595 18.5549
1024 355.42 44.8143 34.7204
2048 671.513 81.2089 73.35

输入张量为形状为 [X, 16384] 的 FP16。我们核函数的最基本版本,称为“Pointwise”,没有任何与内存相关的优化,但已经比 Torch 快了至少 4 倍。它不如 VLLM 的核函数实现,但我们的“Vectorized”实现超过了“Pointwise”和 VLLM。这是实现了合并 128 位加载的核函数版本,仅次于“Vectorized + SMEM”(SMEM 代表共享内存)实现,后者在低和高批量大小下都提供了比 VLLM 明显更好的加速比。


SwiGLU 核函数

在 MLP 块中,在我们刚才讨论的核函数之后,是一个我们之前称之为“Gate / Up”投影的投影。我们之所以这样称呼它,是因为“Gate / Up”投影实际上是两个具有相同输入的投影的拼接:“Gate”和“Up”。因此,我们将“Gate / Up”投影的结果 x x 写为 x=xGxU x = x_G | x_U ,其中 | 是沿列轴应用的拼接运算符。xG x_G xU x_U 具有相同的维度。我们需要这两个投影的原因是紧随其后的 SwiGLU 激活函数,其结果 y y 由方程 (4) 定义。SwiGLU 激活函数之后是“Down”投影,在我们的案例中是 FP8 格式,所以我们还需要如方程 (5) 所示对 y y 进行量化。

i+j+ky=σ(xG)xUyQ=QFP8(sy) \begin{align} \phantom{i + j + k}& \begin{aligned} y = \sigma \left( x_G \right) \cdot x_U \\\end{aligned}\\ &\begin{aligned} y_Q = Q_\text{FP8} \left( s * y \right) \end{aligned} \end{align}

其中 σ \sigma 是 sigmoid 函数:σ(x)=ex/(1+x) \sigma (x) = e^{-x} / (1 + x) 。我们将编写一个融合核函数 (fused kernel) 来处理所有这些操作。对于这个核函数,除了共享内存缓冲区外,为 RMS 核函数描述的优化仍然适用。这里我们将重点关注与计算相关的优化。

优化:与计算相关

我们将通过两种方式来提高核函数的速度:增加每条执行指令完成的工作量,以及使用更快的指令。

为了增加每条指令完成的工作量,我们可以使用 打包 (packed) 指令。当我们要对多个元素应用相同操作时,打包指令非常有用:我们不是对每个元素执行一条指令,而是在一个元素向量上执行一条指令。在 CPU 中,打包(或矢量化)指令是单线程优化的基础,AVX 指令集家族就是明证。GPU 上也有一些打包指令,在适当的地方它们可以非常有用。在 MI300X 上,除其他外,还有用于 FP16 加法和乘法的打包指令,我们将在两个步骤中都使用它们。还存在从 FP32 到 FP8 的打包转换,与非打包转换相比,这可以显著提升性能。事实上,除了从 FP32,没有任何其他数据类型可以转换为 FP8,因此对于 RMS norm 核函数和这个核函数,我们必须先转到 FP32 精度才能转换为 FP8。

然而,在这个核函数中这不成问题:sigmoid 函数 σ \sigma 需要我们计算一个指数,这是一个能从 FP32 精度中获益匪浅的操作。这是一个我们可以通过使用更快的指令来优化计算的例子:我们不使用 exp 指令,而是将输入乘以 log(2) \text{log}(2) 并使用 exp2 指令,这要快得多。我们只遭受几乎可以忽略不计的精度损失,但却降低了延迟。

结果

对于形状为 [X, 16384] 的 FP16 输入张量,我们得到下表

行数 1 2 4 8 16 32 64 128 256 512 1024 2048
Torch (μs) 40.2731 29.923 35.305 23.5763 22.4738 25.3445 31.5829 40.3194 53.5369 79.8037 124.873 243.202
VLLM (μs) 3.84116 3.86192 3.92937 3.94151 4.01047 4.02421 4.08943 4.20317 4.48755 7.48465 13.7389 25.4306
我们的 (μs) 1.92981 1.93904 1.93524 1.99316 2.00415 1.91563 2.04498 2.61763 3.57726 5.47608 10.0482 19.8957
加速比 (VLLM / 我们的) 1.990434291 1.991665979 2.030430334 1.977518112 2.001082753 2.100724044 1.999740829 1.605715857 1.254465708 1.366789747 1.367299616 1.278195791

通过针对 MI300X 定制的内存和计算优化,我们得到的核函数平均比 Torch 快 14 倍以上,比 VLLM 的核函数快 27% 到 100%。


瘦 GEMM 核函数

正如我们之前所见,模型推理延迟的大约 60% 来自于投影,而投影依赖于 GEMM 核函数。GEMM 核函数在 AMD 的 hipBLASLT rocBLAS 等专用库中被高度优化,因此编写一个在所有情况下都表现更好的自定义核函数相当困难。但如果我们专注于一些与我们相关的边缘情况,并为这些特定情况编写一个 GEMM 核函数,那么我们的自定义核函数就有可能比专用库中的更快。

在预填充和解码阶段,网络中任何投影的输入行数都与正在处理的 token 数量相同。而在解码期间,正在处理的 token 数量等于批处理大小。因此,在解码期间,所有 GEMM 核函数的输入行数都等于批处理大小,为了我们的目的,这个范围在 1 到 256 之间。我们将关注非常小的批处理大小。当我们有一个 GEMM AB=C A * B = C A A 的行数很少而列数很多时,我们称之为 瘦 (skinny) GEMM。我们为这种 GEMM 使用一个特定术语的原因是,它们不适合我们在 GPU 上运行的经典 GEMM 算法。通常,GEMM 核函数的效率来自于 分块 (tiling):我们将结果矩阵分成许多子矩阵,称为块 (tile),并将每个块分配给一个不同的计算单元 (CU)。如果我们有很多块,就可以使用很多 CU,GPU 使用率就会很高。下图对此进行了说明。

Classic GEMM dimensions

但是如果输入 A A 的行数非常少,那么只能形成少数几个块,这导致只有少数计算单元处于活动状态,因此 GPU 利用率很低。

Skinny GEMM dimensions

瘦 GEMM 对 GPU 来说是天生不便的。在下一部分,我们将看到如何通过一个假设我们处于瘦 GEMM 上下文中的自定义核函数,使它们变得更方便。

优化:split-K

由于瘦 GEMM 的主要问题是我们使用的计算单元太少,所以我们首先要做的就是找出一种方法来使用更多的计算单元。为此,我们可以利用以下这个令人拍案叫绝的公式:

cij=k=1Kaikbkj=(k=1K/2aikbkj)+(k=1+K/2Kaikbkj) c_{ij} = \sum_{k=1}^K a_{ik} b_{kj} = \left( \sum_{k=1}^{K/2} a_{ik} b_{kj} \right) + \left( \sum_{k=1+K/2}^{K} a_{ik} b_{kj} \right)

借助和的结合律,我们可以沿着共享轴(通常称为 K 轴)拆分主 GEMM,并用几个并发执行的子 GEMM 替换一个 GEMM。每个子 GEMM 将使用与主 GEMM 一样多的 CU,因此使用的 CU 数量将乘以我们拆分 K 轴的次数。下图对此进行了说明。

Split-K algorithm

在这里,我们将 split-K 设置为 2,从而使一次性使用的 CU 数量增加了一倍。由于我们得到的是部分结果,我们需要在两个子 GEMM 都完成后将它们相加。可能看起来违反直觉的是,我们增加了一个操作——对部分结果求和,但我们声称这减少了整个过程的延迟。但由于每个 CU 都需要遍历整个 K 轴来计算结果,因为我们将其一分为二,所以每个 CU 完成的工作量也减少了一半。如果以这种方式节省的工作量能够抵消对最终结果求和所增加的工作量,那么我们就能实现整体优化。只要 K 很大且原始 GEMM 使用的 GPU 不到 50%,这通常是成立的。

优化:移除填充

如果我们假设通过 split-K,大多数计算单元都在忙于处理自己的块,我们就可以将优化范围集中在计算单元级别。我们将看一下实际的矩阵乘法是如何完成的,以及我们如何加速它。

在像 MI300X 这样的顶级 GPU 中,矩阵乘法由一个称为张量核心 (tensor core) 的专用硬件单元处理。张量核心只执行矩阵乘法,但速度非常快。张量核心指令的格式是 mfma_MxNxK...,其中 mfma 代表矩阵融合乘加 (matrix fused multiply-add),M 是左侧矩阵的行数,N 是右侧矩阵的列数,K 是两者的共享维度。我们在下面展示一个假设的指令 mfma_2x2x4

MFMA dense version

张量核心指令只有少数几种,但对于任何三元组 MxNxK,使用专用的张量核心指令都比任何其他替代方案快得多。张量核心指令还有两种类型:“密集 (dense)” 和 “稀疏 (sparse)”。密集指令对应于标准矩阵乘法。稀疏指令假设左侧矩阵 A A 具有 4:2 结构化稀疏模式,这意味着沿矩阵 K 轴每 4 个元素中就有两个是零。在数学上,对于任何 i,j i, j 使得 ai,4j+3 a_{i, 4j+3} A A 的一个元素,我们在 (ai,4j,ai,4j+1,ai,4j+2,ai,4j+3) \left( a_{i,4j}, a_{i,4j+1}, a_{i,4j+2}, a_{i,4j+3} \right) 中至少有两个零。下面是一个稀疏矩阵的例子。

A 4:2 sparse matrix

让我们回到我们的模型,FP8 精度的 Llama 405B。对于 FP8,我们只有两个密集张量核心指令:16x16x3232x32x16。我们还有一个大小为 16x16x64 的稀疏指令。对于一个有 8 行的输入,即使使用最小的密集指令 16x16x32 也意味着我们必须为输入添加 8 行填充,这是对计算资源的浪费。人们可能会想,我们是否可以改用稀疏指令:毕竟,如果一个 16 行矩阵的一半是 4:2 稀疏的,我们可以用一个密集的 8 行矩阵完全描述其非零系数。反之,如果我们有一个 8 行的密集矩阵,我们可以将其所有数据放入一个具有 4:2 稀疏性的 16 行矩阵中。而使用稀疏指令的好处是显而易见的:密集指令的 K=32,而稀疏指令的 K=64。在相同的周期数内,稀疏指令的深度是原来的两倍。我们在下图中用一个 1 行输入和 2x2x4 密集指令及其稀疏的 2x2x8 对应指令来说明这个稀疏技巧。

Using sparsity for skinny inputs

利用这个技巧,我们可以显著加快任何行数小于等于 8 的输入的 GEMM 速度,这导致任何批处理请求数少于 8 的解码的每 token 延迟降低。

优化:Warp 专用化和异步执行

我们已经看到,在瘦 GEMM 中,行数少的事实限制了输出块的数量,这反过来又限制了 GPU 的利用率。但行数少也限制了每个输出块的行数,这反过来又减少了我们所说的 算术强度 (arithmetic intensity)。简单地说,算术强度是完成的工作量除以为完成该工作而加载的数据量。让我们比较两个例子

sn=i=1nxitn=i=1nyi=y (1+tn1) s_n = \sum_{i=1}^{n} x_i \\ t_n = \sum_{i=1}^n y^i = y ~( 1 + t_{n-1})

其中 x x 是一个大小为 n n 的向量,而 y y 是一个标量。要计算 sn s_n ,我们加载 n n 个元素并执行 n1 n-1 次加法。要计算 tn t_n ,我们加载 1 个元素并执行 2n1 2n-1 次加法和乘法。因此,计算 sn s_n 的“算术强度”是 n1n \frac{n-1}{n} tn t_n 的是 2n1 2n - 1 :计算 tn t_n 比计算 sn s_n “算术强度”更高。我们在这里看到的是,当 算术强度越低,我们需要加载更多数据来执行工作

这对我们来说为什么重要?嗯,我们已经看到从 VRAM 加载数据有很高的延迟成本,这对 GPU 来说不是好事。换句话说,算术强度低的工作负载不适合 GPU,而事实证明,瘦 GEMM 的算术强度比它们的非瘦对应物要低。当看下面的图时,这一点变得直观:我们可以看到,当我们将加载的数据量减半时,由于 GEMM 维度的二次性质,输出系数的数量减少了四倍。

The arithmetic intensity of two GEMMs

在瘦 GEMM 中,输出块的行数是有限的,因此算术强度也是有限的。这已经意味着我们需要加载大量数据来计算一个输出块。此外,由于我们使用的是 FP8 算术,计算速度相当快,所以我们不能依靠计算时间来隐藏数据加载的延迟。总而言之,理想情况是让负责加载数据的线程多于负责计算结果的线程。

为了实现这一点,我们将使用一种称为 warp 专用化 (warp specialization) 的技术。我们不再让线程块中的所有 warp 执行相同的指令,而是将一些 warp 专门用于仅加载数据,另一些专门用于仅计算结果。负责加载数据的 warp 称为 生产者 (producers),计算结果的 warp 称为 消费者 (consumers)。生产者和消费者异步工作:生产者首先从 VRAM 加载数据(这很慢),然后通过将其存储在共享内存缓冲区中使其对消费者可用。在数据在共享内存中可用之前,消费者是空闲的。数据可用后,消费者从共享内存加载数据(这很快)并计算结果。生产者和消费者的协调是通过存储在共享内存中的队列来实现的。当生产者完成在共享内存缓冲区 i i 中存储数据时,它会改变队列的第 i i 个变量的状态,以表示数据在那里可用。消费者正在监视这一点,然后开始加载数据。当它完成后,它会改变队列的第 i i 个变量的状态,以表示数据可以被写入缓冲区 i i 。在下图中,我们展示了一个简单的异步 GEMM 中涉及的步骤,其中有一个生产者、一个消费者和一个大小为 2 的队列。

Async GEMM mechanism

使整个过程奏效的是,一旦缓冲区 0 0 被生产者填充,它就可以开始处理缓冲区 1 1 ,而无需等待消费者从缓冲区 0 0 加载数据。目标是拥有一个足够大的队列,以便生产者不断填充缓冲区,而消费者不断消费它们。队列的大小受共享内存大小的限制。

我们还需要调整生产者与消费者的比例:我们说过我们的算术强度低,所以我们需要加载大量数据来做一个相对快速的计算。因此,我们将有大量的生产者 warp(通常是 8 或 10 个)对应少数消费者 warp(比如 2 或 3 个)。此外,我们可以利用 GEMM 是瘦的事实,为输入(瘦矩阵)和权重(非瘦矩阵)设置不同的生产者。为了使输出块在不受约束的维度(即列维度)上更大,我们为权重分配更多的生产者。

关于异步 GEMM 更深入的博客文章,我鼓励您查看这篇博客文章。不过,其中的许多内容在我们的情况下不适用:MI300X 没有 warp 级别的屏障,只有一个线程块级别的屏障。这导致了一些“有趣”的把戏,比如使用 ASM 来确保 warp 在其屏障处等待,共享内存加载和存储在检查屏障状态之前得到解决,以及对队列的模块化特性进行仔细处理。所有这些在这里都会显得不合时宜,但我鼓励您查看代码或在评论中提问。未来可能会有关于异步处理细节的深入探讨。

通过 warp 专用化和异步工作,我们可以使我们的核函数适应低算术强度的负载,但这是否足以超越像 hipBLASLT 这样的库?答案是肯定的,在某些情况下。

结果

由于 Torch 已经绑定了取自 AMD 线性代数库的高度优化的 GEMM,我们不会得到与最后两个核函数相同范围的加速。我们首先将看一下我们感兴趣的三个 GEMM 维度:即与 QKV 投影、Gate / Up 投影和 Down 投影相关的 GEMM 维度。输出投影被排除在外,因为它的维度不符合瘦 GEMM 的情况。

M (行) N (列) K (深度) Torch 时间 (μs) SkG 时间 (μs) 加速比
1 2304 16384 14.938 ± 0.292 11.685 ± 0.299 127.84 %
8 2304 16384 16.300 ± 0.282 12.342 ± 0.375 132.07 %
16 2304 16384 16.693 ± 0.233 13.909 ± 0.295 120.02 %
32 2304 16384 16.817 ± 0.124 17.021 ± 0.133 98.80 %
1 13312 16384 77.636 ± 0.364 54.717 ± 0.628 141.88 %
8 13312 16384 80.031 ± 0.449 58.355 ± 0.612 137.15 %
16 13312 16384 75.236 ± 0.378 59.973 ± 1.922 125.45 %
32 13312 16384 82.198 ± 0.590 69.483 ± 1.672 118.30 %
1 16384 6656 31.066 ± 0.193 27.613 ± 0.218 112.51 %
8 16384 6656 31.559 ± 0.200 28.134 ± 0.209 112.17 %
16 16384 6656 31.671 ± 0.250 30.233 ± 0.267 104.76 %
32 16384 6656 35.561 ± 0.335 35.052 ± 1.365 101.45 %

测量是在 500 次预热迭代后,在 2000 次性能分析迭代中进行的,使用 CUDA graph 和多个权重以避免缓存命中。上面显示的 GEMM 维度按顺序对应 QKV 投影 (N = 2304 和 K = 16384)、Gate / Up 投影 (N = 13312 和 K = 16384) 和 Down 投影 (N = 16384 和 K = 6656)。我们可以看到,对于那些经过调整的维度,在行数较少 (M = 1, 8, 16) 的情况下有显著的加速,但在行数较多 (M = 32) 的情况下则不那么明显。特别是在我们可以使用稀疏技巧的维度 (M = 1, 8) 中,我们看到了比 Torch 显著的加速,Torch 可能将所有内容都填充到 16 行以使用最小的 MFMA 指令。

结论

在这篇文章中,我们只探讨了众多可用核函数优化技术中的一小部分。如果您有兴趣尝试它们,请随时深入 hf-rocm-kernels 仓库并开始动手实验!如果您开发了自己喜欢的核函数并希望分发它,请务必查看 kernel-builderkernels — 这两个 Hugging Face 软件包旨在帮助核函数构建者将其工作广泛提供并产生更大影响。

社区

干得漂亮!对于 Skinny GEMM,您有分别来自稀疏 mfma 和 warp 专用化 (WS) 的改进百分比数据吗?想了解在没有稀疏 mfma 的情况下,WS 对不同形状的影响。

·

看了代码,我的理解是稀疏性只用于 M = 8。这正确吗?

注册登录 以发表评论