DDP Communication Hooks
Distributed Data Parallel (DDP) communication hooks provide a generic interface to control how gradients are communicated across workers by overriding the default allreduce in `DistributedDataParallel`. PyTorch ships a few built-in communication hooks, and users can easily apply any of them to optimize communication.
- FP16 compression hook: compresses gradients by casting them to half-precision floating point (`torch.float16`), reducing communication overhead.
- BF16 compression hook: similar to FP16, but uses the Brain floating-point format (`torch.bfloat16`), which can be more efficient on certain hardware.
- PowerSGD hook: an advanced gradient-compression algorithm that provides high compression rates and can accelerate bandwidth-bound distributed training.
In this tutorial, you will see how to quickly set up DDP communication hooks and train with the utilities provided in Accelerate, which can be as simple as adding one new line of code! The examples below show how to use DDP communication hooks to optimize gradient communication in distributed training with the Accelerate library.
FP16 Compression Hook
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from accelerate.test_utils.testing import get_backend

device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[device_id])
# Cast gradients to torch.float16 before the allreduce to cut communication volume
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

# Training loop (data_loader, criterion and optimizer are assumed to be defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
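With Accelerate, the same hook can be enabled without calling `register_comm_hook` by hand, using the `DistributedDataParallelKwargs` handler demonstrated later on this page. A minimal sketch, assuming the `DDPCommunicationHookType.FP16` value and that `dataset` and `criterion` are defined elsewhere:

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Select the FP16 compression hook through the kwargs handler; Accelerate applies it when the model is prepared
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # dataset assumed to be defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)  # criterion assumed to be defined elsewhere
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()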
BF16 Compression Hook
The BF16 compression hook API is experimental, and it requires NCCL version later than 2.9.6.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks
from accelerate.test_utils.testing import get_backend

device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[device_id])
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
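The Accelerate setup follows the same pattern as the FP16 sketch above, swapping in the BF16 hook type. A minimal sketch, assuming a `DDPCommunicationHookType.BF16` value exists alongside the types used later on this page:

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs

# Same as the FP16 sketch above, only the hook type changes
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.BF16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# model, optimizer and data_loader are then passed through accelerator.prepare() as before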
PowerSGD Hook
PowerSGD typically requires extra memory of the same size as the model's gradients to enable error feedback, which can compensate for biased compressed communication and improve accuracy.
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook
from accelerate.test_utils.testing import get_backend

device_type, _, _ = get_backend()
device_id = getattr(torch, device_type, torch.cuda).current_device()

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[device_id])
# PowerSGD is stateful: the state object holds hyperparameters and error-feedback buffers across iterations
state = powerSGD_hook.PowerSGDState(process_group=None)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)

# Training loop
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
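With Accelerate, PowerSGD is selected the same way; its state (for example the matrix approximation rank) is passed through `comm_state_option`, as shown in the utilities section below. A minimal sketch:

from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs

# Select PowerSGD; its hyperparameters can be supplied via comm_state_option (see below)
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.POWER_SGD)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
# model, optimizer and data_loader are then passed through accelerator.prepare() as before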
DDP Communication Hook Utilities
There are two additional utilities for supporting optional functionalities with the communication hooks.
comm_wrapper
`comm_wrapper` is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. Currently supported wrappers are `no`, `fp16`, and `bf16`.
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup: PowerSGD compression wrapped with FP16 casting
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_wrapper=DDPCommunicationHookType.FP16
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # dataset assumed to be defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (criterion is assumed to be defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
comm_state_option
`comm_state_option` allows you to pass additional state information required by certain communication hooks. This is particularly useful for stateful hooks like `PowerSGD`, which need to maintain hyperparameters and internal state across training steps. Below is an example showcasing the use of `comm_state_option` with the `PowerSGD` hook.
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup: pass PowerSGD state through comm_state_option
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_state_option={"matrix_approximation_rank": 2}
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # dataset assumed to be defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (criterion is assumed to be defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
For more advanced usage and additional hooks, refer to the PyTorch DDP communication hooks documentation.