Accelerate 文档
执行与推迟作业
并获得增强的文档体验
开始使用
执行与推迟作业
当您运行常规脚本时,指令会按顺序执行。使用 Accelerate 在多个 GPU 上同时部署脚本会带来一个复杂问题:虽然每个进程都按顺序执行所有指令,但某些进程可能会比其他进程快。
您可能需要等待所有进程都到达某个特定点后,才能执行给定的指令。例如,在确定每个进程都完成训练之前,您不应该保存模型;在所有模型权重都加载完毕之前,您也不想继续训练。要实现这一点,只需在代码中写入以下行:
accelerator.wait_for_everyone()
这条指令会阻塞所有先到达的进程,直到所有其他进程都到达该点(如果您只在一个 GPU 或 CPU 上运行脚本,这条指令将不起任何作用)。
下面列出了一些使用此功能的示例情况:
其中一些是与 main_process_first() 上下文管理器一起使用的,该管理器利用 wait_for_everyone() 在主进程上预先运行一组特定的代码,然后再触发和启动其他进程。
下载数据集
下载数据集时,您应该先在主进程上下载,然后再加载缓存的数据集。
load_dataset
会在内部执行锁定,以阻止同时进行多次下载,但如果您下载的内容未使用此库,则应使用此方法。
with accelerator.main_process_first():
datasets = load_dataset("glue", "mrpc")
这在底层与调用以下代码相同:
# First do something on the main process
if accelerator.is_main_process:
datasets = load_dataset("glue", "mrpc")
else:
accelerator.wait_for_everyone()
# And then send it to the rest of them
if not accelerator.is_main_process:
datasets = load_dataset("glue", "mrpc")
else:
accelerator.wait_for_everyone()
保存 state_dict
在保存模型的 state_dict
时,由于您通常只在主进程上保存一个文件,您应该明确指定这一点。
if accelerator.is_main_process:
model = accelerator.unwrap_model(model)
torch.save(model.state_dict(), "weights.pth")
加载 state_dict
将 state_dict
加载到模型、优化器或调度器时,您应等待所有工作进程都加载完权重后,再继续进行训练。
with accelerator.main_process_first():
state = torch.load("weights.pth")
model.load_state_dict(state)
应用多工作进程 CPU 操作
在多个工作进程上应用 map()
操作(如分词)时,应首先在主进程上完成,然后传播到每个工作进程。
datasets = load_dataset("glue", "mrpc")
with accelerator.main_process_first():
tokenized_datasets = datasets.map(
tokenize_function,
batched=True,
remove_columns=["idx", "sentence1", "sentence2"],
)
应用检查,如提前停止
要实现一个由特定进程设置标志的检查,应使用 set_trigger
和 check_trigger
API。这在某些情况下非常有用,例如使用提前停止和监控损失(因为每个进程上的损失会略有不同)。
当您的条件满足时,调用 Accelerator.set_trigger();当检查任何进程中该条件是否已满足时,调用 Accelerator.check_trigger()。
for (x,y) in data_loader:
logits = model(x)
loss = loss_func(logits, y)
# Assume `should_do_early_stopping` is a custom defined function that returns a conditional
if should_do_early_stopping(loss):
accelerator.set_trigger()
# Later in the training script when we need to check for the breakpoint
if accelerator.check_trigger():
break