
Performing gradient accumulation with Accelerate


Gradient accumulation is a technique where you can train on bigger batch sizes than your machine would normally be able to fit into memory. This is done by accumulating gradients over several batches, and only stepping the optimizer after a certain number of batches have been performed.

While technically standard gradient accumulation code would work fine in a distributed setup, it is not the most efficient method for doing so and you may experience considerable slowdowns!

In this tutorial you will see how to quickly set up gradient accumulation and perform it with the utilities provided in Accelerate, which in total requires adding just one new line of code!

This example will use a very simplistic PyTorch training loop that performs gradient accumulation every two batches:

device = "cuda"
model.to(device)

gradient_accumulation_steps = 2

for index, batch in enumerate(training_dataloader):
    inputs, targets = batch
    inputs = inputs.to(device)
    targets = targets.to(device)
    outputs = model(inputs)
    loss = loss_function(outputs, targets)
    loss = loss / gradient_accumulation_steps
    loss.backward()
    if (index + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

Converting it to Accelerate

First, the code shown earlier will be converted to use Accelerate without the special gradient accumulation helper:

+ from accelerate import Accelerator
+ accelerator = Accelerator()

+ model, optimizer, training_dataloader, scheduler = accelerator.prepare(
+     model, optimizer, training_dataloader, scheduler
+ )

  for index, batch in enumerate(training_dataloader):
      inputs, targets = batch
-     inputs = inputs.to(device)
-     targets = targets.to(device)
      outputs = model(inputs)
      loss = loss_function(outputs, targets)
      loss = loss / gradient_accumulation_steps
+     accelerator.backward(loss)
      if (index+1) % gradient_accumulation_steps == 0:
          optimizer.step()
          scheduler.step()
          optimizer.zero_grad()

In its current state, this code is not going to perform gradient accumulation efficiently due to a process called gradient synchronization. Read more about that in the Concepts tutorial!
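
For reference, here is a minimal sketch of the manual workaround, assuming the prepared objects from above: gradient synchronization is skipped with the no_sync() context manager on every batch where the optimizer does not step. This is essentially what Accelerate automates for you in the next section.

import contextlib

for index, batch in enumerate(training_dataloader):
    inputs, targets = batch
    is_update_step = (index + 1) % gradient_accumulation_steps == 0
    # only pay the gradient all-reduce cost on batches where we actually step
    ctx = contextlib.nullcontext if is_update_step else accelerator.no_sync
    with ctx(model):
        outputs = model(inputs)
        loss = loss_function(outputs, targets) / gradient_accumulation_steps
        accelerator.backward(loss)
    if is_update_step:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()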

Letting Accelerate handle gradient accumulation

All that is left now is to let Accelerate handle the gradient accumulation for us. To do so you should pass in a gradient_accumulation_steps parameter to Accelerator, dictating the number of steps to perform before each call to step() and how the loss will be automatically adjusted during the call to backward():

  from accelerate import Accelerator
- accelerator = Accelerator()
+ accelerator = Accelerator(gradient_accumulation_steps=2)

Alternatively, you can pass in a gradient_accumulation_plugin parameter to the Accelerator object's __init__, which will allow you to further customize the gradient accumulation behavior. Read more about it in the GradientAccumulationPlugin docs.
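
For example, a minimal sketch of passing a plugin instead of the bare integer (num_steps plays the same role as gradient_accumulation_steps, and adjust_scheduler controls whether prepared schedulers are adjusted to account for the accumulation):

from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

# num_steps is equivalent to gradient_accumulation_steps=2 above;
# adjust_scheduler tells Accelerate to adjust prepared schedulers for accumulation
plugin = GradientAccumulationPlugin(num_steps=2, adjust_scheduler=True)
accelerator = Accelerator(gradient_accumulation_plugin=plugin)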

From here you can use the accumulate() context manager from inside your training loop to automatically perform the gradient accumulation for you! You simply wrap it around the entire training part of our code:

- for index, batch in enumerate(training_dataloader):
+ for batch in training_dataloader:
+     with accelerator.accumulate(model):
          inputs, targets = batch
          outputs = model(inputs)

You can remove all the special checks for the step number and the loss adjustment:

- loss = loss / gradient_accumulation_steps
  accelerator.backward(loss)
- if (index+1) % gradient_accumulation_steps == 0:
  optimizer.step()
  scheduler.step()
  optimizer.zero_grad()

As you can see, the Accelerator is able to keep track of the batch number you are on and it will automatically know whether to step through the prepared optimizer and how to adjust the loss.
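
If you need to know whether the current batch is the one where the optimizer will actually step, for instance to clip gradients or log metrics only once per update, you can check accelerator.sync_gradients inside the context manager. A minimal sketch, assuming a maximum gradient norm of 1.0:

for batch in training_dataloader:
    with accelerator.accumulate(model):
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        # sync_gradients is only True on the batch where gradients are synced
        # and the optimizer takes a real step
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()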

Typically with gradient accumulation, you would need to adjust the number of steps to reflect the change to the total number of batches you are training on. Accelerate does this for you automatically by default. Behind the scenes we instantiate a GradientAccumulationPlugin configured to do just this.

The state.GradientState is synced with the active dataloader being iterated over. As such, it naively assumes that when we have reached the end of the dataloader everything will sync and a step will be performed. To disable this, set sync_with_dataloader to False in the GradientAccumulationPlugin:

from accelerate import Accelerator
from accelerate.utils import GradientAccumulationPlugin

plugin = GradientAccumulationPlugin(sync_with_dataloader=False)
accelerator = Accelerator(..., gradient_accumulation_plugin=plugin)

The finished code

Below is the finished implementation for performing gradient accumulation with Accelerate:

from accelerate import Accelerator
accelerator = Accelerator(gradient_accumulation_steps=2)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)
for batch in training_dataloader:
    with accelerator.accumulate(model):
        inputs, targets = batch
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

It's important that only one forward/backward pass is performed inside the with accelerator.accumulate(model) context manager.

To learn more about what magic this wraps around, read the Gradient Synchronization concept guide.

Self-contained example

Here is a self-contained example that you can run to see gradient accumulation in action with Accelerate:

import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import TensorDataset, DataLoader

# seed
set_seed(0)

# define toy inputs and labels
x = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.])
y = torch.tensor([2., 4., 6., 8., 10., 12., 14., 16.])
gradient_accumulation_steps = 4
per_device_batch_size = len(x) // gradient_accumulation_steps

# define dataset and dataloader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size)

# define model, optimizer and loss function
class SimpleLinearModel(torch.nn.Module):
    def __init__(self):
        super(SimpleLinearModel, self).__init__()
        self.weight = torch.nn.Parameter(torch.zeros((1, 1)))

    def forward(self, inputs):
        return inputs @ self.weight

model = SimpleLinearModel()
model_clone = copy.deepcopy(model)
criterion = torch.nn.MSELoss()
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, model_optimizer, dataloader = accelerator.prepare(model, model_optimizer, dataloader)
model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.02)
print(f"initial model weight is {model.weight.mean().item():.5f}")
print(f"initial model weight is {model_clone.weight.mean().item():.5f}")
for i, (inputs, labels) in enumerate(dataloader):
    with accelerator.accumulate(model):
        inputs = inputs.view(-1, 1)
        print(i, inputs.flatten())
        labels = labels.view(-1, 1)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()
loss = criterion(x.view(-1, 1) @ model_clone.weight, y.view(-1, 1))
model_clone_optimizer.zero_grad()
loss.backward()
model_clone_optimizer.step()
print(f"w/ accumulation, the final model weight is {model.weight.mean().item():.5f}")
print(f"w/o accumulation, the final model weight is {model_clone.weight.mean().item():.5f}")
initial model weight is 0.00000
initial model weight is 0.00000
0 tensor([1., 2.])
1 tensor([3., 4.])
2 tensor([5., 6.])
3 tensor([7., 8.])
w/ accumulation, the final model weight is 2.04000
w/o accumulation, the final model weight is 2.04000

Gradient accumulation on training samples of variable size

As was pointed out in this blog post, a common error occurs when performing gradient accumulation on training samples of variable size:

[...] for gradient accumulation across token-level tasks like causal LM training, the correct loss should be computed by dividing the total loss across all batches in a gradient accumulation step by the total number of non-padding tokens in those batches. This is not the same as the average of the per-batch loss values.

In other words, some adjustments must be made on losses that operate on a token-level basis.
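
As a toy illustration of the difference, assume two accumulated micro-batches with 3 and 1 non-padded tokens respectively, and the hypothetical per-token losses below:

# hypothetical per-token losses of two accumulated micro-batches
batch_1_token_losses = [2.0, 2.0, 2.0]  # 3 non-padded tokens
batch_2_token_losses = [8.0]            # 1 non-padded token

# naive: average the per-batch mean losses
naive = (sum(batch_1_token_losses) / 3 + sum(batch_2_token_losses) / 1) / 2   # 5.0
# correct: sum every token loss, then divide by the total number of non-padded tokens
correct = (sum(batch_1_token_losses) + sum(batch_2_token_losses)) / 4         # 3.5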

Skeleton code

from accelerate import Accelerator
import math
import contextlib

gradient_accumulation_steps = 2
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
model, optimizer, training_dataloader, scheduler = accelerator.prepare(
    model, optimizer, training_dataloader, scheduler
)

training_iterator = iter(training_dataloader)
num_samples_in_epoch = len(training_dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)
        

total_batched_samples = 0
for update_step in range(total_updates):
        # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss
        # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
        batch_samples = []
        num_batches_in_step = gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
        for _ in range(num_batches_in_step):
            batch_samples += [next(training_iterator)]
            
        # get local num items in batch 
        num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
        # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
        num_items_in_batch = accelerator.gather(num_items_in_batch).sum().item()
            
        for i, batch in enumerate(batch_samples):
            # if we perform gradient accumulation in a multi-device set-up, we want to avoid unnecessary communication when accumulating
            # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
            if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
                ctx = model.no_sync
            else:
                ctx = contextlib.nullcontext
            
            total_batched_samples += 1

            with ctx():
                inputs, targets = batch
                outputs = model(inputs)
                loss = loss_function(outputs, targets) # the loss function should sum over samples rather than averaging
                
                # We multiply by num_processes because DDP computes the average gradient across all devices, whereas dividing by num_items_in_batch already accounts for all devices
                # The same reasoning applies for gradient_accumulation_steps, but this time it is Accelerate that computes the average gradient across the accumulated steps
                loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
                
                accelerator.backward(loss)

        # Sync gradients and perform optimization steps once every gradient_accumulation_steps
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

Self-contained causal LM example

import torch
import copy
from accelerate import Accelerator
from accelerate.utils import set_seed
from accelerate.logging import get_logger
from torch.utils.data import Dataset, DataLoader
import math
import contextlib

# seed
set_seed(0)
logger = get_logger(__name__)

class MyDataset(Dataset):
    def __init__(self, num_samples):
        super().__init__()
        self.len = num_samples

    def __getitem__(self, index):
        input_ids = torch.arange(1, index+2, dtype=torch.float32)
        labels = torch.remainder(input_ids, 2)
        return {"input_ids": input_ids, "labels": labels}

    def __len__(self):
        return self.len
    
def collate_fn(features):
    input_ids = torch.nn.utils.rnn.pad_sequence([f["input_ids"] for f in features], batch_first=True, padding_value=-100)
    labels = torch.nn.utils.rnn.pad_sequence([f["labels"] for f in features], batch_first=True, padding_value=-100)
    return {"input_ids": input_ids[..., None], "labels": labels[..., None]}

# define toy inputs and labels
gradient_accumulation_steps = 2
per_device_batch_size = 4

# define accelerator
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)

# define dataset and dataloader
# for this toy example, we'll compute gradient descent over one single global batch
dataset = MyDataset(per_device_batch_size*gradient_accumulation_steps*accelerator.num_processes)
dataloader = DataLoader(dataset, batch_size=per_device_batch_size, collate_fn=collate_fn)

# define model, model_optimizer and loss function
model = torch.nn.Linear(1, 2, bias=False)
model_clone = copy.deepcopy(model)
criterion = torch.nn.CrossEntropyLoss(reduction="sum") # must sum over samples rather than averaging
model_optimizer = torch.optim.SGD(model.parameters(), lr=0.08)


logger.warning(f"initial model weight is {model.weight.detach().cpu().squeeze()}")
logger.warning(f"initial model clone weight is {model_clone.weight.detach().cpu().squeeze()}")

# prepare artifacts - accelerator handles device placement and dataloader splitting
model, model_optimizer = accelerator.prepare(model, model_optimizer)
dataloader = accelerator.prepare_data_loader(dataloader, device_placement=True)
training_iterator = iter(dataloader)

num_samples_in_epoch = len(dataloader)
remainder = num_samples_in_epoch % gradient_accumulation_steps
remainder = remainder if remainder != 0 else gradient_accumulation_steps
total_gradient_updates = math.ceil(num_samples_in_epoch / gradient_accumulation_steps)

total_batched_samples = 0
for update_step in range(total_gradient_updates):
        # In order to correctly compute the total number of non-padded tokens on which we'll compute the cross-entropy loss
        # we need to pre-load the full local batch - i.e. the next per_device_batch_size * accumulation_steps samples
        batch_samples = []
        num_batches_in_step = gradient_accumulation_steps if update_step != (total_gradient_updates - 1) else remainder
        for _ in range(num_batches_in_step):
            batch_samples += [next(training_iterator)]
            
        # get local num items in batch 
        local_num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
        logger.warning(f"Step {update_step} - Device {accelerator.process_index} - num items in the local batch {local_num_items_in_batch}", main_process_only=False)

        # to compute it correctly in a multi-device DDP training, we need to gather the total number of items in the full batch.
        num_items_in_batch = accelerator.gather(local_num_items_in_batch).sum().item()
        logger.warning(f"Total num items {num_items_in_batch}")

        for i, batch in enumerate(batch_samples):
            inputs, labels = batch["input_ids"], batch["labels"]
            total_batched_samples += 1
            # if we perform gradient accumulation in a multi-device set-up, we want to avoid unnecessary communication when accumulating
            # cf: https://muellerzr.github.io/blog/gradient_accumulation.html
            if (i < len(batch_samples) - 1 and accelerator.num_processes > 1):
                ctx = model.no_sync
            else:
                ctx = contextlib.nullcontext
            with ctx():

                outputs = model(inputs)
                loss = criterion(outputs.view(-1, 2), labels.view(-1).to(torch.int64))
                
                # We multiply by num_processes because DDP computes the average gradient across all devices, whereas dividing by num_items_in_batch already accounts for all devices
                # The same reasoning applies for gradient_accumulation_steps, but this time it is Accelerate that computes the average gradient across the accumulated steps
                loss = (loss * gradient_accumulation_steps * accelerator.num_processes) / num_items_in_batch
                accelerator.backward(loss)
        model_optimizer.step()
        model_optimizer.zero_grad()
                

logger.warning(f"Device {accelerator.process_index} - w/ accumulation, the final model weight is {accelerator.unwrap_model(model).weight.detach().cpu().squeeze()}", main_process_only=False)

# We now do the same operation but on a single device and without gradient accumulation

if accelerator.is_main_process:
    # prepare one single entire batch
    dataloader = DataLoader(dataset, batch_size=len(dataset), collate_fn=collate_fn)
    full_batch_without_accum = next(iter(dataloader))
    total_inputs, total_labels = full_batch_without_accum["input_ids"], full_batch_without_accum["labels"]
    model_clone_optimizer = torch.optim.SGD(model_clone.parameters(), lr=0.08)
    
    # train the cloned model
    loss = torch.nn.CrossEntropyLoss(reduction="mean")(model_clone(total_inputs).view(-1, 2), total_labels.view(-1).to(torch.int64))
    model_clone_optimizer.zero_grad()
    loss.backward()
    model_clone_optimizer.step()
    
    # We should have the same final weights.
    logger.warning(f"w/o accumulation, the final model weight is {model_clone.weight.detach().cpu().squeeze()}")

Results on a single device - gradient accumulation steps set to 1 and batch_size set to 8:

initial model weight is tensor([-0.0075,  0.5364])
initial model clone weight is tensor([-0.0075,  0.5364])
Step 0 - Device 0 - num items in the local batch 36
Total num items 36
Device 0 - w/ accumulation, the final model weight is tensor([0.0953, 0.4337])
w/o accumulation, the final model weight is tensor([0.0953, 0.4337])

Results on a two-device setup - gradient accumulation steps set to 2 and batch_size set to 4:

initial model weight is tensor([-0.0075,  0.5364])
initial model clone weight is tensor([-0.0075,  0.5364])
Step 0 - Device 0 - num items in the local batch 52
Step 0 - Device 1 - num items in the local batch 84
Total num items 136
Device 1 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
Device 0 - w/ accumulation, the final model weight is tensor([0.2117, 0.3172])
w/o accumulation, the final model weight is tensor([0.2117, 0.3172])

To go further:

You can find a complete example script for a real-world training run in the examples folder, at the path accelerate/examples/by_feature/gradient_accumulation_for_autoregressive_models.py.

Running it on several training configurations with a constant global batch size of 32 gives the following graph:

Note that the training losses are exactly the same up to training step 20. The small deviation after that step happens at the very end of the first epoch, because, by default, the dataloader duplicates samples from the beginning of the dataset when the total batch size doesn't exactly divide the dataset.
