在 PyTorch / XLA TPU 上运行 Hugging Face:更快更便宜的训练

发布于 2021 年 2 月 9 日
在 GitHub 上更新

Open In Colab

使用 PyTorch / XLA 在云 TPU 上训练你最喜欢的 Transformers 模型

PyTorch-TPU 项目最初是 Facebook PyTorch 和 Google TPU 团队的合作项目,并于 2019 年 PyTorch 开发者大会上正式启动。从那时起,我们与 Hugging Face 团队合作,为使用 PyTorch / XLA 在云 TPU 上进行训练提供了一流的支持。这项新的集成使得 PyTorch 用户能够在云 TPU 上运行和扩展他们的模型,同时保持与 Hugging Face 训练器完全相同的接口。

这篇博文概述了 Hugging Face 库中所做的更改,PyTorch / XLA 库的功能,一个让你开始在云 TPU 上训练你最喜欢的 transformers 的例子,以及一些性能基准。如果你迫不及待地想开始使用 TPU,请直接跳到“在云 TPU 上训练你的 Transformer”部分——我们在 Trainer 模块中为你处理了所有 PyTorch / XLA 的机制!

XLA:TPU 设备类型

PyTorch / XLA 为 PyTorch 添加了一种新的 xla 设备类型。这种设备类型的工作方式与其他 PyTorch 设备类型一样。例如,以下是如何创建和打印一个 XLA 张量

import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.randn(2, 2, device=xm.xla_device())
print(t.device)
print(t)

这段代码应该看起来很熟悉。PyTorch / XLA 使用与常规 PyTorch 相同的接口,并增加了一些内容。导入 torch_xla 会初始化 PyTorch / XLA,而 xm.xla_device() 会返回当前的 XLA 设备。根据你的环境,这可能是 CPU、GPU 或 TPU,但在本文中,我们将主要关注 TPU。

Trainer 模块利用一个 TrainingArguments 数据类来定义训练的具体细节。它处理多个参数,从批量大小、学习率、梯度累积等,到所使用的设备。基于以上内容,在 TrainingArguments._setup_devices() 中使用 XLA:TPU 设备时,我们只需返回要由 Trainer 使用的 TPU 设备即可。

@dataclass
class TrainingArguments:
    ...
    @cached_property
    @torch_required
    def _setup_devices(self) -> Tuple["torch.device", int]:
        ...
        elif is_torch_tpu_available():
            device = xm.xla_device()
            n_gpu = 0
        ...

        return device, n_gpu

XLA 设备上的单步计算

在典型的 XLA:TPU 训练场景中,我们在多个 TPU 核心上并行训练(一个云 TPU 设备包含 8 个 TPU 核心)。因此,我们需要确保通过合并梯度和执行优化器步骤,在数据并行副本之间交换所有梯度。为此,我们提供了 xm.optimizer_step(optimizer),它负责梯度合并和步进。在 Hugging Face 训练器中,我们相应地更新了训练步骤以使用 PyTorch / XLA API

class Trainer:
…
   def train(self, *args, **kwargs):
       ...
                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)

PyTorch / XLA 输入管道

运行 PyTorch / XLA 模型主要有两个部分:(1)惰性地追踪和执行模型的计算图(更深入的解释请参考下面的 “PyTorch / XLA 库” 部分)和(2)为你的模型提供数据。如果没有任何优化,模型的追踪/执行和输入供给将串行执行,导致主机 CPU 和 TPU 加速器分别出现空闲时间段。为了避免这种情况,我们提供了一个 API,将这两者流水线化,从而能够在第 n 步仍在执行时,重叠进行第 n+1 步的追踪。

alt text

import torch_xla.distributed.parallel_loader as pl
...
  dataloader = pl.MpDeviceLoader(dataloader, device)

检查点的写入和加载

当一个张量从 XLA 设备保存为检查点,然后从检查点加载回来时,它将被加载回原始设备。在为模型中的张量创建检查点之前,你需要确保所有的张量都在 CPU 设备上,而不是在 XLA 设备上。这样,当你加载回张量时,你会通过 CPU 设备加载它们,然后有机会将它们放置到你希望的任何 XLA 设备上。我们为此提供了 xm.save() API,它已经处理了只从每个主机上的一个进程(如果使用跨主机的共享文件系统,则全局只有一个)写入存储位置的问题。

class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
…
    def save_pretrained(self, save_directory):
        ...
        if getattr(self.config, "xla_device", False):
            import torch_xla.core.xla_model as xm

            if xm.is_master_ordinal():
                # Save configuration file
                model_to_save.config.save_pretrained(save_directory)
            # xm.save takes care of saving only from master
            xm.save(state_dict, output_model_file)
class Trainer:
…
   def train(self, *args, **kwargs):
       ...
       if is_torch_tpu_available():
           xm.rendezvous("saving_optimizer_states")
           xm.save(self.optimizer.state_dict(),
                   os.path.join(output_dir, "optimizer.pt"))
           xm.save(self.lr_scheduler.state_dict(),
                   os.path.join(output_dir, "scheduler.pt"))

PyTorch / XLA 库

PyTorch / XLA 是一个 Python 包,它使用 XLA 线性代数编译器将 PyTorch 深度学习框架与 XLA 设备(包括 CPU、GPU 和云 TPU)连接起来。以下部分内容也见于我们的 API_GUIDE.md

PyTorch / XLA 张量是惰性的

使用 XLA 张量和设备只需要更改几行代码。然而,尽管 XLA 张量的行为很像 CPU 和 CUDA 张量,但它们的内部机制是不同的。CPU 和 CUDA 张量会立即或“急切地”启动操作。而 XLA 张量则是“惰性”的。它们会将操作记录在一个图中,直到需要结果时才执行。像这样延迟执行可以让 XLA 对其进行优化。一个由多个独立操作组成的图可能会被融合成一个单一的优化操作。

惰性执行对调用者来说通常是不可见的。PyTorch / XLA 会自动构建图,将它们发送到 XLA 设备,并在 XLA 设备和 CPU 之间复制数据时进行同步。在执行优化器步骤时插入一个屏障会显式地同步 CPU 和 XLA 设备。

这意味着当你调用 model(input) 进行前向传播,计算损失 loss.backward(),并执行优化步骤 xm.optimizer_step(optimizer) 时,所有操作的图都在后台构建。只有当你显式地评估张量(例如打印张量或将其移动到 CPU 设备)或标记一个步骤(每次迭代 MpDeviceLoader 时都会这样做)时,完整的步骤才会被执行。

追踪、编译、执行,并重复

从用户的角度来看,在 PyTorch / XLA 上运行模型的典型训练方案包括运行前向传播、后向传播和优化器步骤。从 PyTorch / XLA 库的角度来看,情况略有不同。

当用户运行他们的前向和后向传播时,一个中间表示(IR)图会动态地被追踪。通向每个根/输出张量的 IR 图可以如下检查

>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>> t = torch.tensor(1, device=xm.xla_device())
>>> s = t*t
>>> print(torch_xla._XLAC._get_xla_tensors_text([s]))
IR {
  %0 = s64[] prim::Constant(), value=1
  %1 = s64[] prim::Constant(), value=0
  %2 = s64[] xla::as_strided_view_update(%1, %0), size=(), stride=(), storage_offset=0
  %3 = s64[] aten::as_strided(%2), size=(), stride=(), storage_offset=0
  %4 = s64[] aten::mul(%3, %3), ROOT=0
}

当用户程序运行前向和后向传播时,这个实时图会不断累积,一旦调用 xm.mark_step()(由 pl.MpDeviceLoader 间接调用),实时张量的图就会被切断。这种截断标志着一个步骤的完成,随后我们将 IR 图降级为 XLA 高级操作(HLO),这是 XLA 的 IR 语言。

然后,这个 HLO 图被编译成 TPU 二进制文件,并随后在 TPU 设备上执行。然而,这个编译步骤可能成本很高,通常比单个步骤耗时更长,所以如果我们每一步都编译用户的程序,开销会很大。为了避免这种情况,我们有缓存来存储已编译的 TPU 二进制文件,这些二进制文件以其 HLO 图的唯一哈希标识符为键。因此,一旦这个 TPU 二进制缓存 populated 在第一步被填充,后续的步骤通常不必重新编译新的 TPU 二进制文件;相反,它们可以简单地从缓存中查找必要的二进制文件。

由于 TPU 编译通常比步骤执行时间慢得多,这意味着如果图的形状不断变化,我们将会出现缓存未命中并过于频繁地编译。为了最小化编译成本,我们建议尽可能保持张量形状的静态。Hugging Face 库的形状大部分已经是静态的,输入标记会进行适当的填充,因此在整个训练过程中,缓存应该会持续命中。这可以使用 PyTorch / XLA 提供的调试工具来检查。在下面的例子中,你可以看到编译只发生了 5 次(CompileTime),而执行则在 1220 个步骤中的每一步都发生了(ExecuteTime)。

>>> import torch_xla.debug.metrics as met
>>> print(met.metrics_report())
Metric: CompileTime
  TotalSamples: 5
  Accumulator: 28s920ms153.731us
  ValueRate: 092ms152.037us / second
  Rate: 0.0165028 / second
  Percentiles: 1%=428ms053.505us; 5%=428ms053.505us; 10%=428ms053.505us; 20%=03s640ms888.060us; 50%=03s650ms126.150us; 80%=11s110ms545.595us; 90%=11s110ms545.595us; 95%=11s110ms545.595us; 99%=11s110ms545.595us
Metric: DeviceLockWait
  TotalSamples: 1281
  Accumulator: 38s195ms476.007us
  ValueRate: 151ms051.277us / second
  Rate: 4.54374 / second
  Percentiles: 1%=002.895us; 5%=002.989us; 10%=003.094us; 20%=003.243us; 50%=003.654us; 80%=038ms978.659us; 90%=192ms495.718us; 95%=208ms893.403us; 99%=221ms394.520us
Metric: ExecuteTime
  TotalSamples: 1220
  Accumulator: 04m22s555ms668.071us
  ValueRate: 923ms872.877us / second
  Rate: 4.33049 / second
  Percentiles: 1%=045ms041.018us; 5%=213ms379.757us; 10%=215ms434.912us; 20%=217ms036.764us; 50%=219ms206.894us; 80%=222ms335.146us; 90%=227ms592.924us; 95%=231ms814.500us; 99%=239ms691.472us
Counter: CachedCompile
  Value: 1215
Counter: CreateCompileHandles
  Value: 5
...

在云 TPU 上训练你的 Transformer

要配置您的 VM 和云 TPU,请遵循 “设置计算引擎实例”“启动云 TPU 资源”(撰写本文时为 pytorch-1.7 版本)部分。一旦您创建了 VM 和云 TPU,使用它们就像通过 SSH 连接到您的 GCE VM 并运行以下命令来启动 bert-large-uncased 训练一样简单(批量大小适用于 v3-8 设备,在 v2-8 上可能会内存溢出)

conda activate torch-xla-1.7
export TPU_IP_ADDRESS="ENTER_YOUR_TPU_IP_ADDRESS"  # ex. 10.0.0.2
export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"
git clone -b v4.2.2 https://github.com/huggingface/transformers.git
cd transformers && pip install .
pip install datasets==1.2.1
python examples/xla_spawn.py \
  --num_cores 8 \
  examples/language-modeling/run_mlm.py \
  --dataset_name wikitext \
  --dataset_config_name wikitext-103-raw-v1 \
  --max_seq_length 512 \
  --pad_to_max_length \
  --logging_dir ./tensorboard-metrics \
  --cache_dir ./cache_dir \
  --do_train \
  --do_eval \
  --overwrite_output_dir \
  --output_dir language-modeling \
  --overwrite_cache \
  --tpu_metrics_debug \
  --model_name_or_path bert-large-uncased \
  --num_train_epochs 3 \
  --per_device_train_batch_size 8 \
  --per_device_eval_batch_size 8 \
  --save_steps 500000

上述训练应在大约不到 200 分钟内完成,评估困惑度约为 3.25。

性能基准测试

下表显示了在运行 PyTorch / XLA 的 v3-8 云 TPU 系统(包含 4 个 TPU v3 芯片)上训练 bert-large-uncased 的性能。所有基准测试测量使用的数据集是 WikiText103 数据集,我们使用 Hugging Face 示例中提供的 run_mlm.py 脚本。为确保工作负载不受主机 CPU 限制,我们在这些测试中使用了 n1-standard-96 CPU 配置,但您也可以使用较小的配置而不会影响性能。

名称 数据集 硬件 全局批次大小 精度 训练时间(分钟)
bert-large-uncased WikiText103 4 个 TPUv3 芯片 (即 v3-8) 64 FP32 178.4
bert-large-uncased WikiText103 4 个 TPUv3 芯片 (即 v3-8) 128 BF16 106.4

开始在 TPU 上使用 PyTorch / XLA

请参阅 Hugging Face 示例下的 “在 TPU 上运行” 部分以开始使用。有关我们 API 的更详细描述,请查看我们的 API 指南,有关性能最佳实践,请参阅我们的 故障排除指南。对于通用的 PyTorch / XLA 示例,请运行我们提供的 Colab 笔记本,它们提供免费的云 TPU 访问。要直接在 GCP 上运行,请参阅我们文档网站上标记为“PyTorch”的教程。

还有其他问题吗?请在 https://github.com/huggingface/transformers/issues 或直接在 https://github.com/pytorch/xla/issues 上提出问题。

社区

注册登录 发表评论