Accelerate documentation


DDP Communication Hooks

Distributed Data Parallel (DDP) communication hooks provide a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in `DistributedDataParallel`. A few built-in communication hooks are provided, and users can easily apply any of them to optimize communication.

  • FP16 compression hook: compresses gradients by casting them to half-precision floating point (`torch.float16`), reducing communication overhead.
  • BF16 compression hook: similar to FP16, but uses the Brain floating point format (`torch.bfloat16`), which can be more efficient on some hardware.
  • PowerSGD hook: an advanced gradient compression algorithm that provides high compression rates and can accelerate bandwidth-constrained distributed training.
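To build intuition for what the FP16/BF16 hooks do, here is a minimal sketch of the underlying idea: the gradient tensor is cast to half precision before the allreduce and cast back afterwards. This is an illustration only; the real hooks operate on DDP's `GradBucket` during communication.

```python
import torch

torch.manual_seed(0)

# A stand-in for a gradient bucket
grad = torch.randn(4, dtype=torch.float32)

# Cast to half precision: halves the bytes sent over the wire
compressed = grad.to(torch.float16)

# Cast back to full precision after the (simulated) allreduce
restored = compressed.to(torch.float32)

print(compressed.dtype)                              # torch.float16
print(torch.allclose(grad, restored, atol=1e-2))     # True: small rounding error only
```

The round trip is lossy, but for gradients the rounding error is usually small relative to the gradient noise, which is why this trade-off tends to work in practice.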

In this tutorial, you will see how to quickly set up DDP communication hooks and train with the utilities provided in Accelerate, which can be as simple as adding one new line of code! It demonstrates how to use DDP communication hooks to optimize gradient communication in distributed training with the Accelerate library.

FP16 Compression Hook

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[torch.cuda.current_device()])
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

BF16 Compression Hook

The BF16 compression hook API is experimental, and it requires NCCL version later than 2.9.6.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[torch.cuda.current_device()])
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

PowerSGD Hook

PowerSGD typically requires extra memory of the same size as the model's gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[torch.cuda.current_device()])
state = powerSGD_hook.PowerSGDState(process_group=None)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
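The `PowerSGDState` constructor also exposes the knobs that govern the compression/accuracy trade-off. A short sketch (the specific values are illustrative, not recommendations):

```python
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

# Illustrative hyperparameters for PowerSGD
state = powerSGD_hook.PowerSGDState(
    process_group=None,            # None means the default process group
    matrix_approximation_rank=2,   # low-rank factor size: higher rank = better accuracy, less compression
    start_powerSGD_iter=1000,      # run vanilla allreduce for the first 1000 steps before compressing
)
```

Delaying compression with `start_powerSGD_iter` lets training stabilize on exact gradients before switching to the low-rank approximation.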

DDP Communication Hook utilities

There are two additional utilities for supporting optional functionality with the communication hooks.

comm_wrapper

comm_wrapper is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. Currently supported wrappers are `no`, `fp16`, and `bf16`.

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_wrapper=DDPCommunicationHookType.FP16
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

comm_state_option

comm_state_option allows you to pass additional state information required by certain communication hooks. This is particularly useful for stateful hooks like PowerSGD, which need to maintain hyperparameters and internal state across training steps. Below is an example showcasing the use of comm_state_option with the PowerSGD hook.

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_state_option={"matrix_approximation_rank": 2}
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

For more advanced usage and additional hooks, refer to the PyTorch DDP communication hooks documentation.
