Accelerate 文档

🤗 accelerate 中的上下文并行

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

🤗 accelerate 中的上下文并行

本指南将介绍在 🤗accelerate 中使用上下文并行(context parallelism)的基础知识。对于更好奇的读者,我们将在后面的章节中介绍一些技术细节。

为何需要上下文并行?

随着大型语言模型和最近推理模型的出现,序列长度迅速增长。这与注意力机制的二次方内存复杂度相结合,导致需要更有效的方法来训练具有长序列的模型。对于 128k 的序列长度,使用 `bf16` 精度和 vanilla attention 实现,注意力矩阵的内存需求为 `128k * 128k * 2 字节 * num_heads = ~32 GB * num_heads`。当然,使用不实例化这些注意力权重的 `flash attention` 或 `SDPA`,这个数值会大幅下降,但内存需求的增长仍然相当可观。

上下文并行允许我们沿序列维度对注意力计算的输入进行分片,并在多个 GPU 上并行计算注意力。这样,我们就可以训练具有长序列的模型,并有可能扩展到 1M+ 序列长度。

如何使用上下文并行?

from accelerate.utils import ParallelismConfig, TorchContextParallelConfig

+ cp_config = TorchContextParallelConfig(
+       cp_comm_strategy="alltoall", # no need to use cp_config at all, if you want to use the default "allgather"
+ )

+ parallelism_config = ParallelismConfig(
+     cp_size=8,
+     cp_handler=cp_config,  # or just cp_size=8, if you want to use the default "allgather"
+ )

accelerator = Accelerator(
    ...,
    parallelism_config=parallelism_config,
)

与 🤗accelerate 中的任何其他功能一样,您也可以通过向 `accelerate launch` 传递相应的标志来启用上下文并行。在这种情况下,没有区别

accelerate launch --parallelism-config-cp-size 8 --parallelism-config-cp-comm-strategy [allgather|alltoall] ...

你也可以在 `accelerate config` 命令中设置 `cp_size` 和 `cp_comm_strategy`,这会把它们保存在你的 `accelerate` 配置文件中,这样你就不用每次启动脚本时都传递它们了。

上下文并行与其他并行策略兼容,例如数据并行、张量并行和 FSDP2。你可以简单地通过将并行大小设置为期望的值来组合它们,例如 `--parallelism-config-dp-size 8 --parallelism-config-tp-size 2 --parallelism-config-cp-size 8`。或者,你可以使用 `ParallelismConfig` 类以编程方式设置它们。

上下文并行与 `FSDP2` 紧密耦合,你可以在 FSDP2 简介中了解更多。这意味着,上下文并行仅在您为程序使用 `FullyShardedDataParallelPlugin` 或将版本设置为 2 的 `--use-fsdp` 时才有效。如果不使用 `FSDP2`,将会引发错误。

上下文并行仅适用于SDPA,并且仅在没有掩码或使用因果掩码的情况下工作。我们无法为您正确检测这一点,因此您有责任确保您使用的是没有掩码或带有因果掩码的 `SDPA`。如果您使用任何其他注意力实现,它将引发错误。

通过上述方法启用上下文并行后,你可以将其应用于你的训练循环。我们提供了一个围绕 `torch.distributed.tensor.experimental.context_parallel` 的薄包装器,你可以在你的训练循环中使用它,它抽象了一些使用它的复杂性(稍后会详细介绍)。为了最小化你对训练循环的修改,我们提供了一个上下文管理器,如果上下文并行未启用,它就是一个 `noop`(空操作),如果启用了,它就应用上下文并行。这样,你可以在你的训练循环中使用它,而无需根据你的并行配置更改任何代码。你可以如下使用它

for batch in dataloader:
    with accelerator.maybe_context_parallel(
        buffers=[batch["input_ids"], batch["attention_mask"]],
        buffer_seq_dims=[1, 1],
        no_restore_buffers={batch["input_ids"], batch["labels"]},
    ):
        outputs = model(**batch)
        ...

这个上下文管理器必须在每个训练步骤中重新创建,如上例所示。这样做至关重要。

这有可能将您的上下文大小扩展到 1M+ 的序列长度。下面,我们展示了上下文并行在高达 256k 上下文大小下的速度和内存使用情况。我们可以看到,当我们加倍上下文大小和 GPU 数量时,我们可以实现一致的内存使用,从而可能实现无限的上下文长度扩展。

context parallelism memory usage
图 1:上下文并行在高达 256k 上下文大小下的内存使用和速度。

这些示例是使用您可以在示例文件夹中找到的脚本创建的。要在 8 个 H100 GPU(128k 序列长度)上运行该示例,您可以使用以下命令

accelerate launch --use-fsdp --fsdp-activation-checkpointing=TRUE examples/fsdp2/nd_parallel.py --cp-size=8 --sequence-length=128000

Accelerate 的接口

上下文管理器接受几个参数,用于配置上下文并行。

  • `buffers`:这是一个张量列表,它们将在序列维度上进行分片。这些张量通常是输入 ID、标签和注意力掩码。
  • `buffer_seq_dims`:这是一个整数列表,按 `buffers` 列表的顺序指定了缓冲区的序列维度。如果你传递 `buffers=[input_ids, shift_labels]`,两者形状都为 `[batch_size, sequence_length]`,那么你应该传递 `buffer_seq_dims=[1, 1]`,因为序列维度是张量的第二个维度。这对于正确计算模型输出是必需的。
  • `no_restore_buffers`:上下文并行的实现会原地修改缓冲区,将它们转换为 `torch.distributed.tensor.Dtensor`。在上下文管理器退出后,需要启动一个通信内核来将缓冲区恢复到其原始状态(通常是 all-gather)。这需要一些时间,所以建议传递与 `buffers` 参数中相同的张量,以避免不必要的通信,除非你确定在上下文管理器退出后需要使用这些缓冲区。

上下文并行与 `labels` 是 `input_ids` 的副本不兼容,因为 🤗 transformers 的模型可能会自行移动 `labels` 以启用因果语言建模。想象这种情况:labels = [l1, l2, l3, l4, … li],如果我们应用上下文并行,每个 rank 会得到一部分 labels,例如:labels_rank_0 = [l1, l2], labels_rank_1 = [l3, l4], … 在 transformers 的建模代码移动 labels 后,会变成:labels_rank_0 = [l2, PAD], labels_rank_1 = [l3, PAD], … 其中 `PAD` 是一个填充标记。这会导致损失计算不正确,因为 labels 不再与输入对齐。因此,你需要在将 labels 传入模型之前手动移动它们。

可配置选项

Accelerate 仅提供一个选项来配置上下文并行(除了 `cp_size`)

  • `cp_comm_strategy`:用于分片轮换的方法。我们强烈建议将其保持为 `"allgather"`,因为它很可能在大多数情况下优于 `"alltoall"`。

上下文并行大小相当不言自明,它是输入被分片的 rank 数量。上下文并行分片轮换定义了输入分片如何在 rank 之间轮换。我们将在下一节更详细地介绍这两种选项。

您可以在 ND 并行示例文件中看到一个端到端的示例,在那里您可以在单个 8xH100 节点上训练一个 8B 模型,上下文长度可达 128k。通过多节点训练,您可以在多个 GPU 上将其扩展到 1M+ 的序列长度。您还可以无缝地将其与其他并行策略结合起来,以满足您的需求。

技术细节

本节技术性较强,如果您不需要了解上下文并行的内部原理,可以跳过此节,直接开始构建 🚀

在接下来的章节中,我们将大量使用 `shard` (分片) 这个词,所以我们先来定义它。如果我们将一个张量称为在第 `D` 维上,跨 `N` 个 rank `sharded`(分片),我们的意思是这个张量被分成 `N` 部分,其中张量的每个部分的形状为 `[..., D//N, ...]`。

那么它是如何工作的呢?

上下文并行通过在序列维度上对 `Q、K 和 V` 矩阵进行分片来工作。每个 rank 都有其分配的 `Q` 分片,我们称之为 `Q_i`。在整个计算过程中,这个矩阵只保留在该 rank 上。同样,每个 rank 都有自己的 `K` 和 `V` 分片,我们称之为 `K_i` 和 `V_i`。然后,每个 rank 用自己的 `Q_i`、`K_i` 和 `V_i` 计算注意力,我们称之为 `attn_i`。在此计算过程中,会启动一个通信内核来从所有其他 rank 收集 `K` 和 `V`。使用哪种通信原语取决于 `context_parallel_shard_rotation` 选项。这样,每个 rank 首先用 `Q_i`、`K_i` 和 `V_i` 计算本地注意力,然后用所有其他 rank 的 `K_j` 和 `V_j` 计算。由于每个 rank 持有的 `Q、K 和 V` 矩阵都是在序列维度上分片的,因此结果矩阵更小,可以容纳在单个 GPU 上。

我们可以用以下伪代码来形式化这个过程

comm_kernel = {"allgather": allgather, "alltoall": alltoall}[context_parallel_shard_rotation]
Qi, Ki, Vi = shard(Q, K, V, seq_dim)
attn[i] = attn(Qi, Ki, Vi)
for j in range(context_parallel_size):
    Kj, Vj = comm_kernel()
    attn[j] = attn(Qi, Kj, Vj) # [batch, num_heads, seq_len // context_parallel_size, head_dim]

final_attn = combine(attn)

all-to-all vs all-gather

all-gather

那么 all-to-all 和 all-gather 有什么区别呢?使用 all-gather,通信非常简单。在我们计算完本地注意力 `attn_i` 之后(或者更确切地说,之前,因为它通常耗时更长),我们会启动一个 all-gather 来收集所有其他 rank 的 `K` 和 `V`。当这个通信完成后,每个 rank 就拥有了所有其他 rank 的 `K` 和 `V`,并可以依次与它们计算注意力。在理想情况下,all-gather 的完成时间恰好与 `attn_i` 的计算完成时间一致。然而,在实践中这从未发生,因此理想的实际重叠是在 `attn_i` 的全部计算与一部分通信重叠时实现的,然后为了开始用 `K_j` 和 `V_j` 进行计算,我们等待 all-gather 完成。

all-to-all

All-to-all,有时也称为 `ring-rotation`,利用了一种环状的通信模式。在完成 `attn_i` 计算后,会启动一个 all-to-all 操作,将 `K_i` 和 `V_i` 发送给相邻的 rank。然后我们重复这个操作 `context_parallel_size-1` 次,这样每个 rank 都能看到所有其他 rank 的 `K` 和 `V` 的分片一次。在理想情况下,我们预取相邻 rank 的分片 `K_i+1` 和 `V_i+1`,并且这个通信过程与我们当前 `attn_i` 的计算完全重叠。同样,现实中这种完美的重叠从未发生。鉴于这种方法的性质,如果我们没有实现完美的重叠,其代价要比使用 all-gather 大得多。

如何选择正确的轮换方法?

理论上,all-to-all 应该是更好的选择。但实际上,它很少如此。因此,我们默认使用 all-gather,因为它更有可能获得更好的性能。`torchtitan` 团队的广泛基准测试也表明,all-to-all 很少优于 all-gather。尽管如此,我们仍然提供两种选择,因为您可能会发现其中一种更适合您的用例。

您可以直接在下图的性能分析器输出中看到这个问题

all-to-all profiler output
图 1:红色部分显示了等待 all-to-all 内核完成时的空闲时间。在第一个蓝色条中高亮显示的部分,您可以看到它需要大约 250 微秒才能完成,这个过程在每次注意力调用中重复 N-1 次,其中 N 是上下文并行大小。

为何只支持 FSDP2?

我们只支持 `FSDP2` 的上下文并行,因为我们创建了一个 `context_parallel_size` 和 `dp_shard_size` 的联合网格来充分利用其潜力。它的工作原理是:我们在大小为 `cp_size*dp_shard_size` 的联合网格上对模型进行分片,这最大化了内存节省。这在某种程度上是“免费的午餐”,因为 `FSDP` 通信与注意力的计算完全重叠,如下图所示。

why FSDP2+CP
图 2:在蓝色矩形(Stream 23)中,您可以看到 `FSDP` 分片的预取与注意力的计算(Stream 7)完全重叠,而在红色矩形(Stream 24)中,您可以看到 all-gather 内核导致了一个空闲时间的“气泡”,在此期间我们的计算流(7)是空闲的。

在上图中,您还可以注意到 all-to-all 和 all-gather 之间的区别。在 all-to-all(图 1)中,我们每次注意力调用都会启动 N-1 次通信内核,而在 all-gather(图 2)中,我们只启动一次通信内核。这导致了一个更大的“气泡”,但每次注意力调用只发生一次,而在 all-to-all 中,它会发生 N-1 次。

联合网格中的数据分发

我们确保将同一批数据分发到整个 `cp` 子组,以确保结果正确。(意味着 `cp` 子组中的每个 rank 都会收到同一批数据。)然而,我们也会将不同的批次分发到 `dp_shard` 组的每个 rank。可以这样想象:

# 8 GPUS, --dp_shard_size 4, --cp_size 2
# mesh = [[0, 1], [2, 3], [4, 5], [6, 7]]
# model is sharded across the whole mesh (each GPU holds 1/8 of the model)
# GPUs 0,1 = batch 0
# GPUs 2,3 = batch 1
... and so on.
< > 在 GitHub 上更新