Transformers 文档

Transformer 中的张量并行

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Transformer 中的张量并行

张量并行将模型分片到多个 GPU 上,并并行化矩阵乘法等计算。它使更大的模型能够适应内存,并且速度更快,因为每个 GPU 都可以处理一个张量切片。本文档假定您已经熟悉张量并行的基础知识。如果还不熟悉,请参阅超大规模手册中关于张量并行的部分。

张量并行是通信密集型的,因此建议在具有多个 GPU 的单台机器上使用它,利用快速的节点内通信。对于多节点训练,流水线或数据并行方法更高效(取决于您的用例)。

张量并行需要对模型参数进行微小的更改,因此在 transformers 中,我们开箱即用地支持一些流行的模型。

展开下面的列表,查看支持张量并行的模型。如果模型目前不在列表中,请提交 GitHub 问题或拉取请求以添加支持。

支持的模型

使用 🤗 transformers

Transformers 提供了一个简单的接口用于张量并行。我们提供多个类来实现不同的分区策略,以及一个简单的入口点来并行化 nn.Module 实例。您不必直接与此接口交互,一切都在 PretrainedModel.from_pretrained 方法中为您完成。本节将首先讨论我们支持的分区策略,然后是您将与之交互的用户界面,最后将教您如何用您自己的分区策略扩展它。

分区策略

在 transformers 中,分区策略存在于 ParallelInterface 类中,它就像字符串到策略实现的映射。

class ParallelInterface(MutableMapping):
    """
    Dict-like object keeping track of allowed attention functions. You can easily add a new attention function
    with a call to `register()`. If a model needs to locally overwrite an existing attention function, say `sdpa`,
    it needs to declare a new instance of this class inside the `modeling_<model>.py`, and declare it on that instance.
    """
    _global_mapping = {
        "colwise": ColwiseParallel(),
        "rowwise": RowwiseParallel(),
        "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
        "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
        "local_colwise": ColwiseParallel(use_dtensor=False),
        "local_rowwise": RowwiseParallel(use_dtensor=False),
        "local": IsolatedParallel(),
        "gather": GatherParallel(),
        "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
        "sequence_parallel": SequenceParallel(),
        "replicate": ReplicateParallel(),
    }

我们支持以下策略:

  • ColwiseParallel - 简单的列式分区,能够同时处理权重和偏差,功能与我们之前讨论的完全一致。
  • RowwiseParallel - 再次,行式分区,如前所述,支持权重和偏差,此外还支持 nn.Embedding 模块。
  • SequenceParallel - 序列并行实现,用于支持 LayerNormDropout 层。还支持 RMSNorm 的 Python 实现(参见此处
  • PackedColwiseParallel - 列式分区的一种变体,但它作用于打包的权重(即 up_projgate_proj 打包在一起)。有关更多详细信息,请参阅此注释
  • PackedRowwiseParallel - 行式分区的一种变体,作用于打包的权重,有关更多详细信息请查看上面链接的注释。
  • GatherParallel - 一个非常简单的类,只将模块的输出在设备之间进行收集。
  • IsolatedParallel - 这是一个特殊情况,我们希望将模块与其余设备(世界)隔离。这用于 MoE 层中的专家,基本上创建了一种专家并行性。
  • ReplicateParallel - 许多 torch.distributed API 在模型部分分片时会中断,因此此类别用于在所有设备上复制模块。

模型分片

我们提供两种分片模型的方式,第一种是使用 auto 张量并行计划,它将根据我们预定义的配置自动分片模型。这需要模型在 transformers 中具有预定义的张量并行计划。

from transformers import AutoModelForCausalLM

# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" # better for smaller number of GPUs
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" # better to visualize all the possible strategies

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan="auto")

print(model._tp_plan)

有关支持张量并行的模型列表,请参阅上面的支持的模型部分。

第二种方法是手动指定您自己的分区计划。

from transformers import AutoModelForCausalLM

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise",
    "model.layers.*.self_attn.k_proj": "colwise",
    "model.layers.*.self_attn.v_proj": "colwise",
    "model.layers.*.self_attn.o_proj": "rowwise",
    ...
}

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)

print(model._tp_plan)

您可能已经注意到 ParallelInterface 映射中存在一些特殊情况,现在让我们来讨论它们。这将帮助您理解它们的用途,并有助于扩展到其他策略。

PackedRowwiseParallel

这个类是 RowwiseParallel 的一个特例,它用于分片打包的权重。权重打包是模型中常用的一种技术。它是一种将多个线性层打包成一个更大的层的方法。

例如,在 Llama4 模型中,我们将 up_projgate_proj 打包到一个 gate_up_proj 模块中。

class Llama4TextExperts(nn.Module):
    ...
    self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, 2 * self.expert_dim))

然后在前向传播中,我们可以使用批量矩阵乘法来计算 gate_up_proj 模块的输出。

def forward(self, hidden_states):
    ...
    gate_up = torch.bmm(hidden_states, self.gate_up_proj) # Compute the output of the gate_up_proj module
    gate, up = gate_up.chunk(2, dim=-1) # Split the output into gate and up

在这种情况下,我们需要使用 PackedRowwiseParallel 策略来分片 gate_up_proj 模块,因为使用简单的 RowwiseParallel 将错误地分片层。

如果这有点难以理解,请查看此注释,其中提供了为什么需要使用 Packed* 的精彩视觉表示。

local* 策略

您可能已经注意到存在 local* 策略,它们使用与 * 策略相同的层,但根本不使用 DTensor。这是因为 DTensor 不支持某些操作,例如 torch.chunk。因此,有时我们需要使用 local* 策略,它们使用普通的 torch.Tensor 并手动完成一些分布式逻辑。

手动指定您自己的分区计划需要对模型架构以及分区策略如何相互作用有很好的理解。如果您对此不确定,最终的模型可能会非常慢,甚至出现故障或不正确。再次,请参阅超大规模手册,它可以教您所需的一切。

使用您自己的分区策略扩展接口

这是一个非常高级的话题,需要对分布式集合和模型架构有很好的理解。您的自定义分区策略应该继承自integrations/tensor_parallel.py中定义的 TensorParallelLayer,并实现:partition_tensor_prepare_input_fn_prepare_output_fn。然后,它应该在 ParallelInterface 映射中注册,以便我们的调度逻辑在 tp_plan 中指定时可以找到它。

让我们以一个现有示例 ColwiseParallel 为例,逐步讲解此工作流程。

  1. 继承自 TensorParallelLayer 并进行初始化
class ColwiseParallel(TensorParallelLayer):
    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None, # The input layout coming from the previous layer
        output_layouts: Optional[Placement] = None, # The output layout we want to achieve
        use_local_output: bool = True, # Whether to use local output or not
        use_dtensor=True, # Whether to use DTensor or not
    ):
        self.input_layouts = (input_layouts or Replicate(),) # The input sharding coming from the previous layer
        self.output_layouts = (output_layouts or Shard(-1),) # Desired output sharding
        self.desired_input_layouts = (Replicate(),) # Desired input sharding, inputs should be replicated across GPUs
        self.use_local_output = use_local_output
        self.use_dtensor = use_dtensor

__init__ 方法中,我们定义这些属性,其中 input_layoutsoutput_layouts 描述了输入和输出张量应如何放置在设备上。desired_input_layouts 用于指定输入应该如何放置在设备上。

2a. 实现 partition_tensor 方法

def partition_tensor(
    self,
    param, # Full tensor of the parameter
    empty_param, # Empty tensor of the parameter, will be filled with the partitioned tensor
    param_type, # Type of the parameter, `bias` or `weight`
    param_casting_dtype, # The type to cast the parameter to
    to_contiguous, # Whether to convert the tensor to a contiguous memory layout
    rank, # The rank of the current device
    device_mesh, # The device mesh
) -> nn.Parameter: # Return the partitioned parameter
    ...

此方法用于对张量进行分区,并用分区后的张量填充 empty_param。我们提供了一些实用函数来帮助您完成此操作,例如 get_tensor_shard,它将为您获取此 rank 的原始参数的正确分片,或 get_packed_weights 以帮助处理打包权重。

2b. 实现 _prepare_input_fn_prepare_output_fn 方法

这些方法分别用作 pre-forwardforward 钩子。它们的作用是将输入和输出重新分配到 __init__ 方法中传入的所需布局。

def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
    ...
    # Do some custom logic, cast to DTensor etc.
    ...
    return inputs.redistribute(placements=desired_input_layouts, device_mesh=device_mesh)

def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
    ...
    # Do some custom logic, cast to DTensor etc.
    ...
    return outputs.redistribute(placements=output_layouts, device_mesh=device_mesh)
  1. 注册策略 恭喜!您已经实现了自己的分区策略。现在,要将其与您自己的 tp_plan 一起使用,您需要将其注册到 ParallelInterface 映射中。
from transformers.integrations.tensor_parallel import ParallelInterface

ParallelInterface.register_strategy("colwise_custom", ColwiseParallel)

现在您可以在 tp_plan 中使用它,如下所示:

tp_plan = {
    "model.layers.*.self_attn.q_proj": "colwise_custom",
    ...
}

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, tp_plan=tp_plan)

完整示例

让我们通过一个完整的张量并行推理示例。

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


# enable tensor parallelism
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    tp_plan="auto",
)

# prepare input tokens
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
prompt = "Can I help"
inputs = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

# distributed run
outputs = model(inputs)

torchrun 上启动上述推理脚本,每个 GPU 运行 4 个进程。

torchrun --nproc-per-node 4 demo.py

您可以从推理中获得显著的加速,特别是对于批处理大小或序列长度较大的输入。

对于序列长度为 512 且批处理大小不同的 Llama 的单个前向传播,您可以预期以下加速。

深入了解张量并行

我们的张量并行实现设计上是与框架无关的,但我们开发的具体实现依赖于 torch.distributed 包。我们大量利用 DeviceMeshDTensor 等抽象来为用户提供简单且可扩展的接口。

DeviceMesh

DeviceMesh 想象成一个相互通信的多维设备网格。不同的并行化策略需要不同类型的通信模式,因此我们可以创建具有多个子网格的 DeviceMesh

from torch.distributed.device_mesh import init_device_mesh

# Create a 1D mesh of 4 GPUs
device_mesh = init_device_mesh("cuda", (4,), mesh_dim_names=["tp"])

然后,大多数 torch.distributed 定义的并行化策略都可以应用于网格本身或其子网格,从而自动处理通信模式。

DTensor

分布式张量(Distributed Tensor)的缩写,DTensor 是一种张量子类,它在常规张量操作之上处理分布式逻辑。在张量并行的情况下,大多数模型权重都存储为 DTensor(有一些例外,稍后详述)。DTensor 最重要且必须理解的部分是 placement 属性。此属性告诉 PyTorch 张量如何放置在 DeviceMesh 的设备上。

它可以具有以下值:

  • Shard(dimension) - 注释此 DTensor 在其构建的 DeviceMesh 上按给定维度进行分片。例如,如果我们想对列式分区进行权重分片,我们会这样做:
weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(0)]) # Shard across the 1st (column-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Shard(-1)]) # Shard across the ONLY dimension

再举一个例子,对于行式分区,我们会这样做:

weight = ...
weight = DTensor.from_local(weight, device_mesh["tp"], placements=[Shard(1)]) # Shard across the 2nd (row-wise) dimension
bias = ...
bias = DTensor.from_local(bias, device_mesh["tp"], placements=[Replicate()]) # Replicate bias across all GPUs
  • Replicate() - 注释此 DTensorDeviceMesh 上被复制。非常直接,只在每个设备上创建张量的完整副本。
  • Partial() - 此 placement 对我们来说大部分无关紧要,它用于注释此张量正在等待归约操作。
< > 在 GitHub 上更新