Accelerate 文档
检查点
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
检查点
当使用 Accelerate 训练 PyTorch 模型时,您可能经常希望保存并继续训练状态。这样做需要保存和加载模型、优化器、RNG 生成器和 GradScaler。Accelerate 内部有两个便捷函数可以快速实现此目的
- 使用 save_state() 将上面提到的所有内容保存到文件夹位置
- 使用 load_state() 加载从早期
save_state
存储的所有内容
为了进一步自定义通过 save_state() 保存状态的位置和方式,可以使用 ProjectConfiguration 类。例如,如果启用了 automatic_checkpoint_naming
,则每个保存的检查点都将位于 Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}
。
应该注意的是,期望这些状态来自同一个训练脚本,它们不应该来自两个单独的脚本。
- 通过使用 register_for_checkpointing(),您可以注册自定义对象以从先前的两个函数自动存储或加载,只要该对象具有
state_dict
和load_state_dict
功能即可。这可能包括学习率调度器等对象。
以下是使用检查点在训练期间保存和重新加载状态的简短示例
from accelerate import Accelerator
import torch
accelerator = Accelerator(project_dir="my/save/path")
my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)
# Register the LR scheduler
accelerator.register_for_checkpointing(my_scheduler)
# Save the starting state
accelerator.save_state()
device = accelerator.device
my_model.to(device)
# Perform training
for epoch in range(num_epochs):
for batch in my_training_dataloader:
my_optimizer.zero_grad()
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
outputs = my_model(inputs)
loss = my_loss_function(outputs, targets)
accelerator.backward(loss)
my_optimizer.step()
my_scheduler.step()
# Restore the previous state
accelerator.load_state("my/save/path/checkpointing/checkpoint_0")
恢复 DataLoader 的状态
从检查点恢复后,如果状态是在 epoch 中间保存的,则可能还需要从活动 DataLoader
中的特定点恢复。您可以使用 skip_first_batches() 来做到这一点。
from accelerate import Accelerator
accelerator = Accelerator(project_dir="my/save/path")
train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")
# Assume the checkpoint was saved 100 steps into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)
# After the first iteration, go back to `train_dataloader`
# First epoch
for batch in skipped_dataloader:
# Do something
pass
# Second epoch
for batch in train_dataloader:
# Do something
pass