加速文档

执行和延迟作业

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

执行和延迟作业

当您运行常规脚本时,指令按顺序执行。使用 Accelerate 在多个 GPU 上同时部署脚本会引入一个复杂性:虽然每个进程都按顺序执行所有指令,但某些进程可能比其他进程更快。

您可能需要等待所有进程都达到某个点,然后才能执行给定的指令。例如,在确定每个进程都完成训练之前,不应保存模型;并且在所有模型权重都已加载之前,您不会希望继续训练。要做到这一点,只需在您的代码中写入以下行

accelerator.wait_for_everyone()

此指令将阻止所有先到达的进程,直到所有其他进程都到达该点为止(如果您仅在一个 GPU 或 CPU 上运行脚本,则此操作不会执行任何操作)。

以下列出了一些何时使用此实用程序的示例情况

其中一些与 main_process_first() 上下文管理器一起使用,该管理器利用 wait_for_everyone() 在触发和启动其他进程之前,首先在主进程上运行特定的代码集

下载数据集

下载数据集时,应首先在主进程上下载,然后在之后加载缓存的数据集

load_dataset 将在后台执行锁定,以阻止同时发生多次下载,但是如果您下载的内容未使用此库,则应使用此方法。

with accelerator.main_process_first():
    datasets = load_dataset("glue", "mrpc")

在后台,这与调用相同

# First do something on the main process
if accelerator.is_main_process:
    datasets = load_dataset("glue", "mrpc")
else:
    accelerator.wait_for_everyone()

# And then send it to the rest of them
if not accelerator.is_main_process:
    datasets = load_dataset("glue", "mrpc")
else:
    accelerator.wait_for_everyone()

保存 state_dict

当保存模型的 state_dict 时,由于通常只在主进程上保存一个文件,因此您应该指定这一点

if accelerator.is_main_process:
    model = accelerator.unwrap_model(model)
    torch.save(model.state_dict(), "weights.pth")

加载 state_dict

当将 state_dict 加载到模型、优化器或调度器中时,应等待所有工作进程都加载完权重,然后再继续进行训练

with accelerator.main_process_first():
    state = torch.load("weights.pth")
    model.load_state_dict(state)

应用多工作进程 CPU 操作

在多个工作进程上应用 map() 操作(例如,分词)应首先在主进程上完成,然后再传播到每个工作进程。

datasets = load_dataset("glue", "mrpc")

with accelerator.main_process_first():
    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=["idx", "sentence1", "sentence2"],
    )

应用检查(例如,提前停止)

为了进行与特定进程设置的标志一起使用的检查,应使用 set_triggercheck_trigger API。这样做的有用示例可以包括使用提前停止和监视损失的情况(因为每个进程上的损失略有不同)。

当您的条件已满足时,调用 Accelerator.set_trigger();当检查任何进程中是否已满足该条件时,调用 Accelerator.check_trigger()

for (x,y) in data_loader:
    logits = model(x)
    loss = loss_func(logits, y)
    # Assume `should_do_early_stopping` is a custom defined function that returns a conditional
    if should_do_early_stopping(loss):
        accelerator.set_trigger()

    # Later in the training script when we need to check for the breakpoint
    if accelerator.check_trigger():
        break
< > 在 GitHub 上更新