# DDP Communication Hooks
Distributed Data Parallel (DDP) communication hooks provide a generic interface for controlling how gradients are communicated across workers by overriding the vanilla allreduce in `DistributedDataParallel`. A few built-in communication hooks are provided, and users can easily apply any of them to optimize communication.
- FP16 compression hook: compresses gradients by casting them to half-precision floating point format (`torch.float16`), reducing communication overhead.
- BF16 compression hook: similar to FP16, but uses the Brain floating point format (`torch.bfloat16`), which can be more efficient on some hardware.
- PowerSGD hook: an advanced gradient compression algorithm that provides high compression rates and can accelerate bandwidth-bound distributed training.
In this tutorial, you will see how to quickly set up DDP communication hooks and train with the utilities provided in Accelerate, which can be as simple as adding one new line of code! It demonstrates how to use DDP communication hooks to optimize gradient communication in distributed training with the Accelerate library.
## FP16 Compression Hook
```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[torch.cuda.current_device()])
model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)

# Training loop
# (assumes `data_loader`, `criterion`, and `optimizer` are defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
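With Accelerate, the hook can instead be selected through a kwargs handler, in the same way as the utility examples later in this guide. The snippet below is a minimal sketch: it assumes `DDPCommunicationHookType.FP16` can be passed as `comm_hook` (the same enum value appears later as a wrapper type), and that `dataset` and `criterion` are defined elsewhere.

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Select the FP16 compression hook through the DDP kwargs handler (assumed enum value)
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # assumes `dataset` is defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (assumes `criterion` is defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```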
## BF16 Compression Hook
The BF16 compression hook API is experimental and requires an NCCL version later than 2.9.6.
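If you are unsure which NCCL version your environment uses, a minimal check is to query the version bundled with a CUDA build of PyTorch via `torch.cuda.nccl.version()`:

```python
import torch

# Prints the NCCL version shipped with this PyTorch build, e.g. (2, 19, 3)
print(torch.cuda.nccl.version())
```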
```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[torch.cuda.current_device()])
model.register_comm_hook(state=None, hook=default_hooks.bf16_compress_hook)

# Training loop
# (assumes `data_loader`, `criterion`, and `optimizer` are defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
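The equivalent Accelerate setup is a minimal sketch analogous to the FP16 case: it assumes the `DDPCommunicationHookType` enum exposes a `BF16` member (the `bf16` wrapper is listed among the supported options later in this guide), and that `dataset` and `criterion` are defined elsewhere.

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Select the BF16 compression hook through the DDP kwargs handler (assumed enum value)
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.BF16)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # assumes `dataset` is defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (assumes `criterion` is defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```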
## PowerSGD Hook
PowerSGD typically requires extra memory of the same size as the model's gradients to enable error feedback, which can compensate for the bias introduced by compressed communication and improve accuracy.
```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

model = MyModel()
model = DDP(model, device_ids=[torch.cuda.current_device()])
state = powerSGD_hook.PowerSGDState(process_group=None)
model.register_comm_hook(state=state, hook=powerSGD_hook.powerSGD_hook)

# Training loop
# (assumes `data_loader`, `criterion`, and `optimizer` are defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
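With Accelerate, PowerSGD is selected with `DDPCommunicationHookType.POWER_SGD`, the same value used in the utility examples below. This is a minimal sketch assuming `dataset` and `criterion` are defined elsewhere.

```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# Select the PowerSGD hook through the DDP kwargs handler
ddp_kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.POWER_SGD)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # assumes `dataset` is defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (assumes `criterion` is defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```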
## DDP Communication Hook Utilities
There are two additional utilities for supporting optional functionality with the communication hooks.
### comm_wrapper
`comm_wrapper` is an option to wrap a communication hook with additional functionality. For example, it can be used to combine FP16 compression with other communication strategies. The currently supported wrappers are `no`, `fp16`, and `bf16`.
```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup: PowerSGD hook wrapped with FP16 compression
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_wrapper=DDPCommunicationHookType.FP16
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # assumes `dataset` is defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (assumes `criterion` is defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```
### comm_state_option
`comm_state_option` allows you to pass additional state information required by certain communication hooks. This is particularly useful for stateful hooks such as `PowerSGD`, which need to maintain hyperparameters and internal state across training steps. The example below shows how to use `comm_state_option` with the `PowerSGD` hook.
```python
from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
from torch.utils.data import DataLoader
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

# DDP Communication Hook setup: pass PowerSGD state via comm_state_option
ddp_kwargs = DistributedDataParallelKwargs(
    comm_hook=DDPCommunicationHookType.POWER_SGD,
    comm_state_option={"matrix_approximation_rank": 2}
)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])
model = MyModel()
optimizer = torch.optim.Adam(model.parameters())
data_loader = DataLoader(dataset, batch_size=16)  # assumes `dataset` is defined elsewhere
model, optimizer, data_loader = accelerator.prepare(model, optimizer, data_loader)

# Training loop (assumes `criterion` is defined elsewhere)
for data, targets in data_loader:
    outputs = model(data)
    loss = criterion(outputs, targets)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```
For more advanced usage and additional hooks, refer to the PyTorch DDP communication hooks documentation.