Accelerate 文档

Accelerate 的内部机制

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Accelerate 的内部机制

在内部,Accelerate 首先分析脚本启动的环境,以确定使用了哪种分布式设置、有多少个不同的进程以及当前脚本所在的进程。所有这些信息都存储在 ~AcceleratorState 中。

这个类在您第一次实例化 ~Accelerator 时被初始化,并执行分布式设置所需的任何特定初始化。然后,它的状态在所有 AcceleratorState 实例中唯一共享。(同样的操作也可以通过 PartialState 完成,它是一个更精简的版本,并且继承自 AcceleratorState

然后,当调用 prepare() 时,该库会

虽然模型、优化器和调度器只是被放入简单的包装器中,但数据加载器是重新创建的。这主要是因为 PyTorch 不允许用户在数据加载器创建后更改其 batch_sampler,而该库通过更改 batch_sampler 来处理进程间的数据分片,使其每隔 num_processes 个批次产生一个批次(如果启用)。

DataLoaderShard 子类化了 DataLoader 并添加了以下功能:

  • 它在每次新迭代时同步所有进程的适当随机数生成器,以确保任何随机化(如洗牌)在所有进程中都以完全相同的方式进行。
  • 它在产生批次之前将批次放在正确的设备上(除非您已选择退出 device_placement=True)。

DataLoaderDispatcher 子类与 DataLoaderShard 的不同之处在于,在迭代 DataLoader 时,数据全部从进程 0 开始,然后 才被分割并发送到每个进程,而不是在数据集层面进行。

随机数生成器的同步默认会同步:

  • 对于 PyTorch >= 1.6,同步给定采样器(如 PyTorch RandomSampler)的 generator 属性
  • 在 PyTorch <=1.5.1 中同步主随机数生成器

您可以使用主 Acceleratorrng_types 参数选择要同步的随机数生成器。在 PyTorch >= 1.6 中,建议依赖本地 generator 以避免在所有进程的主随机数生成器中设置相同的种子。

同步主 torch(或 CUDA 或 XLA)随机数生成器将影响数据集中任何其他潜在的随机因素(如随机数据增强),因为所有进程将从 torch 随机模块获得相同的随机数(因此如果由 torch 控制,将应用相同的随机数据增强)。

自定义采样器、批处理采样器或可迭代数据集的随机化部分应使用本地 torch.Generator 对象(在 PyTorch >= 1.6 中),可参考传统的 RandomSampler 作为示例。

如果您安装了 torchdata>=0.8.0,并且已将 use_stateful_dataloader=True 传入您的 DataLoaderConfiguration,这些类将直接从 StatefulDataLoader 继承,并维护一个 state_dict

有关内部的更多详细信息,请参阅内部页面

< > 在 GitHub 上更新