
DDP Communication Hooks

Distributed Data Parallel (DDP) communication hooks provide a generic interface to control how gradients are communicated across workers by overriding the vanilla allreduce in DistributedDataParallel. A few built-in communication hooks are provided, and users can easily apply any of them to optimize communication.

  • FP16 compression hook: compresses gradients by casting them to the half-precision floating point format (torch.float16), reducing communication overhead.
  • BF16 compression hook: similar to FP16, but uses the Brain Floating Point format (torch.bfloat16), which can be more efficient on certain hardware.
  • PowerSGD hook: an advanced gradient compression algorithm that provides high compression rates and can accelerate bandwidth-constrained distributed training.

In this tutorial, you will see how to quickly set up DDP communication hooks and perform training with the utilities provided in Accelerate, which can be as simple as adding one new line of code! It demonstrates how to use DDP communication hooks to optimize gradient communication in distributed training with the Accelerate library.

FP16 Compression Hook

PyTorch:
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Assumes the default process group is already initialized
# (e.g. torch.distributed.init_process_group("nccl") under torchrun)
model = MyModel().to(torch.cuda.current_device())
model = DDP(model, device_ids=[torch.cuda.current_device()])

# Compress gradients to torch.float16 before the allreduce
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

# Training loop (criterion, optimizer, and data_loader are assumed to be defined)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
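
With Accelerate, the same hook can be attached through DistributedDataParallelKwargs instead of calling register_comm_hook yourself. The following is a minimal sketch, assuming DDPCommunicationHookType.FP16 can be passed as comm_hook (the same enum member the utilities section below uses as a wrapper) and that dataset and criterion are defined elsewhere:

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Ask Accelerate to register the FP16 compression hook on the wrapped DDP model
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # dataset assumed to be defined

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (criterion assumed to be defined)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()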

BF16 Compression Hook

The BF16 compression hook API is experimental and requires an NCCL version later than 2.9.6.
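
To check which NCCL build your PyTorch installation was compiled against, you can query it at runtime; a small sketch using torch.cuda.nccl.version(), which returns the compiled NCCL version (as a tuple in recent PyTorch releases):

import torch

# e.g. (2, 18, 1); older PyTorch releases may return an int instead of a tuple
nccl_version = torch.cuda.nccl.version()
print(nccl_version)
assert nccl_version > (2, 9, 6), "bf16_compress_hook requires NCCL later than 2.9.6"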

PyTorch:
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Assumes the default process group is already initialized
model = MyModel().to(torch.cuda.current_device())
model = DDP(model, device_ids=[torch.cuda.current_device()])

# Compress gradients to torch.bfloat16 before the allreduce
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)

# Training loop (criterion, optimizer, and data_loader are assumed to be defined)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
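
With Accelerate, the setup mirrors the FP16 sketch above; assuming the hook enum also exposes a BF16 member, only the kwargs change:

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs

# Same pattern as the FP16 sketch, swapping in the BF16 hook (assumed enum member)
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.BF16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])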

PowerSGD Hook

PowerSGD typically requires extra memory of the same size as the model's gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.

PyTorch:
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Assumes the default process group is already initialized
model = MyModel().to(torch.cuda.current_device())
model = DDP(model, device_ids=[torch.cuda.current_device()])

# PowerSGDState holds the hook's hyperparameters and internal state
state = powerSGD_hook.PowerSGDState(process_group=None)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)

# Training loop (criterion, optimizer, and data_loader are assumed to be defined)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
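
The memory/accuracy trade-off described above is controlled through PowerSGDState. The following sketch lists the constructor arguments that are most often tuned; the names follow the upstream PyTorch hook, but defaults can vary between PyTorch releases:

from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

state = powerSGD_hook.PowerSGDState(
    process_group=None,           # use the default process group
    matrix_approximation_rank=2,  # higher rank: better accuracy, less compression
    start_powerSGD_iter=1000,     # run vanilla allreduce for the first steps
    use_error_feedback=True,      # keeps the extra gradient-sized buffer mentioned above
    warm_start=True,              # reuse the low-rank factors across iterations
)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)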

DDP Communication Hooks utilities

There are two additional utilities for supporting optional functionality with the communication hooks.

comm_wrapper

comm_wrapper is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. The currently supported wrappers are no, fp16, and bf16.

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_wrapper=DDPCommunicationHookType.FP16
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # dataset is assumed to be defined

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (criterion is assumed to be defined)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

comm_state_option

comm_state_option allows you to pass additional state information required by certain communication hooks. This is particularly useful for stateful hooks such as PowerSGD, which needs to maintain hyperparameters and internal state across training steps. The example below shows the usage of comm_state_option with the PowerSGD hook.

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_state_option={"matrix_approximation_rank": 2}
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # dataset is assumed to be defined

model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (criterion is assumed to be defined)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

For more advanced usage and additional hooks, refer to the PyTorch DDP communication hooks documentation (https://pytorch.org/docs/stable/ddp_comm_hooks.html).
