Accelerate 文档
实验跟踪器
并获得增强的文档体验
开始使用
实验跟踪器
有大量的实验跟踪 API 可用,但是要让它们都在多进程环境中工作通常很复杂。Accelerate 提供了一个通用的跟踪 API,可以通过 Accelerator.log()
在脚本运行时记录有用的项目。
集成跟踪器
目前,Accelerate
开箱即用地支持七种跟踪器
- TensorBoard
- WandB
- CometML
- Aim
- MLFlow
- ClearML
- DVCLive
要使用它们中的任何一个,请将所选类型传递给 Accelerate
中的 log_with
参数
from accelerate import Accelerator
from accelerate.utils import LoggerType
accelerator = Accelerator(log_with="all") # For all available trackers in the environment
accelerator = Accelerator(log_with="wandb")
accelerator = Accelerator(log_with=["wandb", LoggerType.TENSORBOARD])
在实验开始时,应该使用 Accelerator.init_trackers()
来设置你的项目,并可能添加任何要记录的实验超参数
hps = {"num_iterations": 5, "learning_rate": 1e-2}
accelerator.init_trackers("my_project", config=hps)
当你准备好记录任何数据时,应该使用 Accelerator.log()
。也可以传入一个 step
来将数据与训练循环中的特定步骤相关联。
accelerator.log({"train_loss": 1.12, "valid_loss": 0.8}, step=1)
一旦你完成训练,请确保运行 Accelerator.end_training(),以便所有跟踪器都可以运行它们的完成功能(如果有)。
accelerator.end_training()
一个完整的示例如下
from accelerate import Accelerator
accelerator = Accelerator(log_with="all")
config = {
"num_iterations": 5,
"learning_rate": 1e-2,
"loss_function": str(my_loss_function),
}
accelerator.init_trackers("example_project", config=config)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)
device = accelerator.device
my_model.to(device)
for iteration in range(config["num_iterations"]):
for step, batch in enumerate(my_training_dataloader):
my_optimizer.zero_grad()
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
outputs = my_model(inputs)
loss = my_loss_function(outputs, targets)
accelerator.backward(loss)
my_optimizer.step()
accelerator.log({"training_loss": loss}, step=step)
accelerator.end_training()
如果跟踪器需要一个目录来保存数据,例如 TensorBoard
,则将目录路径传递给 project_dir
。当有其他配置要与 ProjectConfiguration 数据类结合使用时,project_dir
参数非常有用。例如,你可以将 TensorBoard 数据保存到 project_dir
,而其他所有内容都可以记录在 [~utils.ProjectConfiguration
] 的 logging_dir
参数中
accelerator = Accelerator(log_with="tensorboard", project_dir=".")
# use with ProjectConfiguration
config = ProjectConfiguration(project_dir=".", logging_dir="another/directory")
accelerator = Accelerator(log_with="tensorboard", project_config=config)
实现自定义跟踪器
要实现一个在 Accelerator
中使用的新跟踪器,可以通过实现 GeneralTracker
类来创建一个新的跟踪器。每个跟踪器必须实现三个函数并具有三个属性
__init__
:- 应该存储一个
run_name
并初始化集成库的跟踪器 API。 - 如果跟踪器在本地存储其数据(例如 TensorBoard),则可以添加
logging_dir
参数。
- 应该存储一个
store_init_configuration
:- 应该接收一个
values
字典并将它们存储为一次性实验配置
- 应该接收一个
log
:- 应该接收一个
values
字典和一个step
,并将它们记录到运行中
- 应该接收一个
name
(str
)- 跟踪器的唯一字符串名称,例如 wandb 跟踪器的
"wandb"
。 - 这将用于专门与此跟踪器交互
- 跟踪器的唯一字符串名称,例如 wandb 跟踪器的
requires_logging_directory
(bool
)- 此特定跟踪器是否需要
logging_dir
以及是否使用它。
- 此特定跟踪器是否需要
tracker
:- 这应该实现为一个
@property
函数 - 应该返回库使用的内部跟踪机制,例如
wandb
的run
对象。
- 这应该实现为一个
如果记录器应该仅在主进程上执行,则每个方法还应该使用 state.PartialState 类。
下面可以看到一个简短的示例,它与 Weights and Biases 集成,仅包含相关信息并在主进程上进行日志记录
from accelerate.tracking import GeneralTracker, on_main_process
from typing import Optional
import wandb
class MyCustomTracker(GeneralTracker):
name = "wandb"
requires_logging_directory = False
@on_main_process
def __init__(self, run_name: str):
self.run_name = run_name
run = wandb.init(self.run_name)
@property
def tracker(self):
return self.run.run
@on_main_process
def store_init_configuration(self, values: dict):
wandb.config(values)
@on_main_process
def log(self, values: dict, step: Optional[int] = None):
wandb.log(values, step=step)
当你准备好构建你的 Accelerator
对象时,将你的跟踪器的**实例**传递给 Accelerator.log_with
,使其自动与 API 一起使用
tracker = MyCustomTracker("some_run_name")
accelerator = Accelerator(log_with=tracker)
这些也可以与现有跟踪器混合使用,包括 "all"
tracker = MyCustomTracker("some_run_name")
accelerator = Accelerator(log_with=[tracker, "all"])
访问内部跟踪器
如果可能需要直接与跟踪器进行一些自定义交互,你可以使用 Accelerator.get_tracker() 方法快速访问一个跟踪器。只需传入与跟踪器的 .name
属性对应的字符串,它将在主进程上返回该跟踪器。
此示例展示了如何使用 wandb 执行此操作
wandb_tracker = accelerator.get_tracker("wandb")
从那里你可以像往常一样与 wandb
的 run
对象进行交互
wandb_tracker.log_artifact(some_artifact_to_log)
如果你想完全移除 Accelerate 的包装,你可以通过以下方式实现相同的效果
wandb_tracker = accelerator.get_tracker("wandb", unwrap=True)
if accelerator.is_main_process:
wandb_tracker.log_artifact(some_artifact_to_log)
当包装器无法工作时
如果一个库的 API 不遵循严格的 .log
和整体字典(例如 Neptune.AI),则可以在 if accelerator.is_main_process
语句下手动完成日志记录
from accelerate import Accelerator
+ import neptune
accelerator = Accelerator()
+ run = neptune.init_run(...)
my_model, my_optimizer, my_training_dataloader = accelerate.prepare(my_model, my_optimizer, my_training_dataloader)
device = accelerator.device
my_model.to(device)
for iteration in config["num_iterations"]:
for batch in my_training_dataloader:
my_optimizer.zero_grad()
inputs, targets = batch
inputs = inputs.to(device)
targets = targets.to(device)
outputs = my_model(inputs)
loss = my_loss_function(outputs, targets)
total_loss += loss
accelerator.backward(loss)
my_optimizer.step()
+ if accelerator.is_main_process:
+ run["logs/training/batch/loss"].log(loss)