加速文档

检查点

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验的访问权限

开始使用

检查点

使用 Accelerate 训练 PyTorch 模型时,您可能经常需要保存和继续训练状态。这样做需要保存和加载模型、优化器、RNG 生成器和 GradScaler。Accelerate 内部有两个方便函数可以快速实现这一点。

  • 使用 save_state() 将上述所有内容保存到文件夹位置。
  • 使用 load_state() 加载先前 save_state 中存储的所有内容。

要进一步自定义通过 save_state() 保存状态的位置和方式,可以使用 ProjectConfiguration 类。例如,如果启用了 automatic_checkpoint_naming,则每个保存的检查点将位于 Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}

需要注意的是,这些状态应来自同一个训练脚本,不应该来自两个独立的脚本。

  • 通过使用 register_for_checkpointing(),您可以注册自定义对象,以便自动存储或加载这两个先前函数,只要对象具有 state_dict **和** load_state_dict 功能。这可能包括学习率调度器等对象。

以下是一个使用检查点保存和重新加载训练期间状态的简短示例

from accelerate import Accelerator
import torch

accelerator = Accelerator(project_dir="my/save/path")

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state
accelerator.save_state()

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    my_scheduler.step()

# Restore the previous state
accelerator.load_state("my/save/path/checkpointing/checkpoint_0")

恢复 DataLoader 的状态

从检查点恢复后,如果在 epoch 中间保存了状态,您可能还想从活动 DataLoader 中的特定点恢复。您可以使用 skip_first_batches() 来实现这一点。

from accelerate import Accelerator

accelerator = Accelerator(project_dir="my/save/path")

train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# Assume the checkpoint was saved 100 steps into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# After the first iteration, go back to `train_dataloader`

# First epoch
for batch in skipped_dataloader:
    # Do something
    pass

# Second epoch
for batch in train_dataloader:
    # Do something
    pass
< > 在 GitHub 上更新