使用 TensorFlow 在 TPU 上训练
如果您不需要详细解释,只想获取一些 TPU 代码示例来入门,请查看 我们的 TPU 示例笔记本!
什么是 TPU?
TPU 是 **张量处理单元 (Tensor Processing Unit)**。它们是谷歌设计的硬件,用于极大地加速神经网络中的张量计算,就像 GPU 一样。它们可用于网络训练和推理。它们通常通过谷歌的云服务访问,但也可以通过 Google Colab 和 Kaggle Kernels 免费直接访问小型 TPU。
因为 🤗 Transformers 中的所有 TensorFlow 模型都是 Keras 模型,所以本文档中的大多数方法通常适用于任何 Keras 模型的 TPU 训练!但是,有一些点是特定于 🤗 Transformers 和 Datasets 生态系统(hug-o-system?)的,当我们遇到它们时,我们会确保将其标记出来。
有哪些类型的 TPU 可用?
新用户常常对 TPU 的种类以及访问它们的不同方式感到困惑。首先要理解的关键区别是 **TPU 节点** 和 **TPU VM** 之间的区别。
当您使用 **TPU 节点** 时,您实际上是在间接访问远程 TPU。您需要一个单独的 VM,它将初始化您的网络和数据管道,然后将其转发到远程节点。当您在 Google Colab 上使用 TPU 时,您是在 **TPU 节点** 样式下访问它。
使用 TPU 节点可能会对不习惯它的人产生一些意想不到的行为!特别是,由于 TPU 位于与您运行 Python 代码的机器不同的物理系统上,因此您的数据无法位于您的机器本地 - 任何从机器内部存储加载数据的管道都将完全失败!相反,数据必须存储在 Google Cloud Storage 中,即使管道在远程 TPU 节点上运行,您的数据管道也可以访问它。
如果您可以在内存中将所有数据放入 np.ndarray
或 tf.Tensor
中,那么即使使用 Colab 或 TPU 节点,您也可以对该数据进行 fit()
,而无需将其上传到 Google Cloud Storage。
**🤗Hugging Face 特别提示🤗:**您将在我们的 TF 代码示例中看到的 Dataset.to_tf_dataset()
方法及其更高级别的包装器 model.prepare_tf_dataset()
都将在 TPU 节点上失败。其原因是,即使它们创建了一个 tf.data.Dataset
,它也不是一个“纯” tf.data
管道,并且使用 tf.numpy_function
或 Dataset.from_generator()
从底层的 HuggingFace Dataset
流式传输数据。此 HuggingFace Dataset
由位于本地磁盘上的数据支持,远程 TPU 节点将无法读取这些数据。
访问 TPU 的第二种方式是通过 **TPU VM**。使用 TPU VM 时,您可以直接连接到 TPU 所连接的机器,就像在 GPU VM 上进行训练一样。TPU VM 通常更容易使用,尤其是在处理数据管道时。以上所有警告都不适用于 TPU VM!
这是一篇带有个人观点的文档,因此以下是我们的观点:**如果可能,请避免使用 TPU 节点。**与 TPU VM 相比,它更令人困惑,也更难调试。它也可能在将来不受支持 - 谷歌最新的 TPU,TPUv4,只能作为 TPU VM 访问,这表明 TPU 节点将越来越成为一种“遗留”访问方法。但是,我们了解到,唯一的免费 TPU 访问是在 Colab 和 Kaggle Kernels 上,它们使用 TPU 节点 - 因此,如果必须使用,我们将尝试解释如何处理!请查看 TPU 示例笔记本,获取更详细的代码示例。
有哪些尺寸的 TPU 可用?
单个 TPU (v2-8/v3-8/v4-8) 运行 8 个副本。TPU 存在于可以同时运行数百或数千个副本的 **集群 (pods)** 中。当您使用多个 TPU 但少于整个集群(例如,v3-32)时,您的 TPU 集群被称为 **集群切片 (pod slice)**。
当您通过 Colab 访问免费 TPU 时,通常会获得单个 v2-8 TPU。
我总是听到关于 XLA 的事情。什么是 XLA,它与 TPU 有什么关系?
XLA 是一种优化编译器,TensorFlow 和 JAX 都使用它。在 JAX 中,它是唯一的编译器,而在 TensorFlow 中,它是可选的(但在 TPU 上是强制性的!)。在训练 Keras 模型时启用它的最简单方法是将参数 jit_compile=True
传递给 model.compile()
。如果您没有收到任何错误并且性能良好,则表示您已准备好迁移到 TPU!
在 TPU 上进行调试通常比在 CPU/GPU 上更难,因此我们建议您先在 CPU/GPU 上使用 XLA 运行代码,然后再尝试在 TPU 上运行。当然,您不必训练很长时间 - 只需运行几个步骤以确保您的模型和数据管道按预期工作即可。
XLA 编译的代码通常更快 - 因此,即使您不打算在 TPU 上运行,添加 jit_compile=True
也可以提高您的性能。但是,请务必注意下面关于 XLA 兼容性的注意事项!
**源于痛苦经验的提示:**虽然使用 jit_compile=True
是获得速度提升并测试您的 CPU/GPU 代码是否与 XLA 兼容的好方法,但如果在 TPU 上实际训练时保留它,它实际上会导致很多问题。XLA 编译将在 TPU 上隐式发生,因此请记住在 TPU 上实际运行代码之前删除该行!
如何使我的模型与 XLA 兼容?
在许多情况下,您的代码可能已经与 XLA 兼容了!但是,有一些在普通 TensorFlow 中有效但在 XLA 中无效的操作。我们已将它们提炼成以下三个核心规则
**🤗Hugging Face 特别提示🤗:**我们付出了很多努力来重写我们的 TensorFlow 模型和损失函数,以使其与 XLA 兼容。我们的模型和损失函数通常默认遵循规则 #1 和 #2,因此如果您使用的是 transformers
模型,可以跳过它们。但是,在编写您自己的模型和损失函数时,请不要忘记这些规则!
XLA 规则 #1:您的代码不能包含“依赖于数据的条件语句”
这意味着任何if
语句都不能依赖于tf.Tensor
内部的值。例如,此代码块无法使用XLA编译!
if tf.reduce_sum(tensor) > 10:
tensor = tensor / 2.0
起初这可能看起来非常限制,但大多数神经网络代码不需要这样做。您可以通过使用tf.cond
(请参阅此处的文档)或删除条件并找到一个巧妙的数学技巧来使用指示变量来解决此限制,如下所示
sum_over_10 = tf.cast(tf.reduce_sum(tensor) > 10, tf.float32)
tensor = tensor / (1.0 + sum_over_10)
此代码与上面的代码具有完全相同的效果,但通过避免条件,我们确保它能够在没有问题的情况下使用XLA编译!
XLA规则#2:您的代码不能具有“数据相关的形状”
这意味着代码中所有tf.Tensor
对象的形状都不能依赖于它们的值。例如,函数tf.unique
无法使用XLA编译,因为它返回一个包含输入中每个唯一值的单个实例的tensor
。此输出的形状显然会根据输入Tensor
的重复程度而有所不同,因此XLA拒绝处理它!
通常,大多数神经网络代码默认情况下都遵循规则#2。但是,在某些常见情况下,它会成为一个问题。一个非常常见的情况是使用**标签掩码**,将标签设置为负值以指示在计算损失时应忽略这些位置。如果您查看支持标签掩码的NumPy或PyTorch损失函数,您通常会看到如下使用布尔索引的代码
label_mask = labels >= 0
masked_outputs = outputs[label_mask]
masked_labels = labels[label_mask]
loss = compute_loss(masked_outputs, masked_labels)
mean_loss = torch.mean(loss)
此代码在NumPy或PyTorch中完全没问题,但在XLA中会出错!为什么?因为masked_outputs
和masked_labels
的形状取决于掩码了多少个位置——这使其成为**数据相关的形状**。但是,就像规则#1一样,我们通常可以重写此代码以产生完全相同的输出,而没有任何数据相关的形状。
label_mask = tf.cast(labels >= 0, tf.float32)
loss = compute_loss(outputs, labels)
loss = loss * label_mask # Set negative label positions to 0
mean_loss = tf.reduce_sum(loss) / tf.reduce_sum(label_mask)
在这里,我们通过计算每个位置的损失来避免数据相关的形状,但在计算均值时,在分子和分母中都将掩码位置清零,这与第一个代码块产生完全相同的结果,同时保持XLA兼容性。请注意,我们使用了与规则#1相同的技巧——将tf.bool
转换为tf.float32
并将其用作指示变量。这是一个非常有用的技巧,因此如果您需要将自己的代码转换为XLA,请记住它!
XLA规则#3:XLA需要为它看到的每个不同的输入形状重新编译您的模型
这是最重要的一条。这意味着如果您的输入形状变化很大,XLA将不得不反复重新编译您的模型,这将导致巨大的性能问题。这通常出现在NLP模型中,其中输入文本在标记化后具有可变长度。在其他模态中,静态形状更常见,此规则的问题要小得多。
如何解决规则#3?关键是**填充**——如果将所有输入填充到相同的长度,然后使用attention_mask
,则可以获得与可变形状相同的结果,而不会出现任何XLA问题。但是,过度填充也会导致严重的减速——如果将所有样本填充到整个数据集中最大长度,您最终可能会得到包含无数填充标记的批次,这将浪费大量计算和内存!
这个问题没有完美的解决方案。但是,您可以尝试一些技巧。一个非常有用的技巧是**将样本批次填充到 32 或 64 个标记的倍数**。这通常只会增加少量标记,但它大大减少了唯一输入形状的数量,因为每个输入形状现在都必须是 32 或 64 的倍数。更少的唯一输入形状意味着更少的XLA编译!
🤗HuggingFace 特定提示🤗:我们的标记器和数据整理器具有可以在这里帮助您的方法。在调用标记器以使其输出填充数据时,您可以使用padding="max_length"
或padding="longest"
。我们的标记器和数据整理器还具有一个pad_to_multiple_of
参数,您可以使用它来减少看到的唯一输入形状的数量!
如何真正地在TPU上训练我的模型?
一旦您的训练与XLA兼容,并且(如果您使用TPU节点/Colab)您的数据集已适当准备,在TPU上运行就会变得出奇地容易!您实际上只需要在代码中更改几行以初始化TPU,并确保您的模型和数据集是在TPUStrategy
范围内创建的。请查看我们的TPU示例笔记本以了解其操作方式!
总结
这里有很多内容,因此让我们总结一个快速清单,您可以在想要使模型准备好进行TPU训练时遵循。
- 确保您的代码遵循XLA的三条规则
- 使用
jit_compile=True
在CPU/GPU上编译您的模型,并确认您可以使用XLA进行训练 - 将您的数据集加载到内存中或使用与TPU兼容的数据集加载方法(请参阅笔记本)
- 将您的代码迁移到Colab(将加速器设置为“TPU”)或Google Cloud上的TPU虚拟机
- 添加TPU初始化代码(请参阅笔记本)
- 创建您的
TPUStrategy
并确保数据集加载和模型创建在strategy.scope()
内(请参阅笔记本) - 不要忘记在迁移到TPU后再次删除
jit_compile=True
! - 🙏🙏🙏🥺🥺🥺
- 调用
model.fit()
- 您做到了!