Accelerate 文档
Accelerate 的内部机制
并获得增强的文档体验
开始使用
Accelerate 的内部机制
在内部,Accelerate 首先分析脚本启动的环境,以确定使用了哪种分布式设置、有多少个不同的进程以及当前脚本所在的进程。所有这些信息都存储在 ~AcceleratorState
中。
这个类在您第一次实例化 ~Accelerator 时被初始化,并执行分布式设置所需的任何特定初始化。然后,它的状态在所有 AcceleratorState 实例中唯一共享。(同样的操作也可以通过 PartialState 完成,它是一个更精简的版本,并且继承自 AcceleratorState
)
然后,当调用 prepare() 时,该库会
- 将您的模型包装在适用于分布式设置的容器中,
- 将您的优化器包装在 AcceleratedOptimizer 中,
- 将您的调度器包装在 AcceleratedScheduler 中
- 在 DataLoaderShard 或 DataLoaderDispatcher 中创建数据加载器的新版本
虽然模型、优化器和调度器只是被放入简单的包装器中,但数据加载器是重新创建的。这主要是因为 PyTorch 不允许用户在数据加载器创建后更改其 batch_sampler
,而该库通过更改 batch_sampler
来处理进程间的数据分片,使其每隔 num_processes
个批次产生一个批次(如果启用)。
DataLoaderShard 子类化了 DataLoader
并添加了以下功能:
- 它在每次新迭代时同步所有进程的适当随机数生成器,以确保任何随机化(如洗牌)在所有进程中都以完全相同的方式进行。
- 它在产生批次之前将批次放在正确的设备上(除非您已选择退出
device_placement=True
)。
DataLoaderDispatcher 子类与 DataLoaderShard 的不同之处在于,在迭代 DataLoader
时,数据全部从进程 0 开始,然后 才被分割并发送到每个进程,而不是在数据集层面进行。
随机数生成器的同步默认会同步:
- 对于 PyTorch >= 1.6,同步给定采样器(如 PyTorch
RandomSampler
)的generator
属性 - 在 PyTorch <=1.5.1 中同步主随机数生成器
您可以使用主 Accelerator 的 rng_types
参数选择要同步的随机数生成器。在 PyTorch >= 1.6 中,建议依赖本地 generator
以避免在所有进程的主随机数生成器中设置相同的种子。
同步主 torch(或 CUDA 或 XLA)随机数生成器将影响数据集中任何其他潜在的随机因素(如随机数据增强),因为所有进程将从 torch 随机模块获得相同的随机数(因此如果由 torch 控制,将应用相同的随机数据增强)。
自定义采样器、批处理采样器或可迭代数据集的随机化部分应使用本地 torch.Generator
对象(在 PyTorch >= 1.6 中),可参考传统的 RandomSampler
作为示例。
如果您安装了 torchdata>=0.8.0
,并且已将 use_stateful_dataloader=True
传入您的 DataLoaderConfiguration,这些类将直接从 StatefulDataLoader
继承,并维护一个 state_dict
。
有关内部的更多详细信息,请参阅内部页面。
< > 在 GitHub 上更新