与 JAX 结合使用
本文档简要介绍了如何将 datasets
与 JAX 结合使用,重点介绍如何从我们的数据集中获取 jax.Array
对象,以及如何使用它们来训练 JAX 模型。
要重现上述代码,需要 jax
和 jaxlib
,因此请确保您已安装它们,方法是 pip install datasets[jax]
。
数据集格式
默认情况下,数据集返回常规 Python 对象:整数、浮点数、字符串、列表等,并且字符串和二进制对象保持不变,因为 JAX 仅支持数字。
要改为获取 JAX 数组(类似 NumPy),您可以将数据集的格式设置为 jax
>>> from datasets import Dataset
>>> data = [[1, 2], [3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': DeviceArray([1, 2], dtype=int32)}
>>> ds[:2]
{'data': DeviceArray([
[1, 2],
[3, 4]], dtype=int32)}
一个 Dataset 对象是 Arrow 表格的包装器,它允许快速读取数据集中的数组到 JAX 数组。
请注意,完全相同的过程适用于 DatasetDict
对象,因此当将 DatasetDict
的格式设置为 jax
时,其中的所有 Dataset
都将被格式化为 jax
。
>>> from datasets import DatasetDict
>>> data = {"train": {"data": [[1, 2], [3, 4]]}, "test": {"data": [[5, 6], [7, 8]]}}
>>> dds = DatasetDict.from_dict(data)
>>> dds = dds.with_format("jax")
>>> dds["train"][:2]
{'data': DeviceArray([
[1, 2],
[3, 4]], dtype=int32)}
您需要考虑的另一件事是,格式化不会在您实际访问数据之前应用。因此,如果您想从数据集中获取 JAX 数组,则需要先访问数据,否则格式将保持不变。
最后,要将数据加载到您选择的设备中,您可以指定 device
参数,但请注意 jaxlib.xla_extension.Device
不受支持,因为它既不能与 pickle
序列化,也不能与 dill
序列化,因此您需要使用其字符串标识符。
>>> import jax
>>> from datasets import Dataset
>>> data = [[1, 2], [3, 4]]
>>> ds = Dataset.from_dict({"data": data})
>>> device = str(jax.devices()[0]) # Not casting to `str` before passing it to `with_format` will raise a `ValueError`
>>> ds = ds.with_format("jax", device=device)
>>> ds[0]
{'data': DeviceArray([1, 2], dtype=int32)}
>>> ds[0]["data"].device()
TFRT_CPU_0
>>> assert ds[0]["data"].device() == jax.devices()[0]
True
请注意,如果未向 with_format
提供 device
参数,则它将使用默认设备,即 jax.devices()[0]
。
N维数组
如果您的数据集包含 N 维数组,您会发现默认情况下,如果形状固定,则它们被视为相同的张量。
>>> from datasets import Dataset
>>> data = [[[1, 2],[3, 4]], [[5, 6],[7, 8]]] # fixed shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': Array([[1, 2],
[3, 4]], dtype=int32)}
>>> from datasets import Dataset
>>> data = [[[1, 2],[3]], [[4, 5, 6],[7, 8]]] # varying shape
>>> ds = Dataset.from_dict({"data": data})
>>> ds = ds.with_format("jax")
>>> ds[0]
{'data': [Array([1, 2], dtype=int32), Array([3], dtype=int32)]}
但是,这种逻辑通常需要缓慢的形状比较和数据复制。为了避免这种情况,您必须显式使用 Array
特征类型并指定张量的形状。
>>> from datasets import Dataset, Features, Array2D
>>> data = [[[1, 2],[3, 4]],[[5, 6],[7, 8]]]
>>> features = Features({"data": Array2D(shape=(2, 2), dtype='int32')})
>>> ds = Dataset.from_dict({"data": data}, features=features)
>>> ds = ds.with_format("torch")
>>> ds[0]
{'data': Array([[1, 2],
[3, 4]], dtype=int32)}
>>> ds[:2]
{'data': Array([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]], dtype=int32)}
其他特征类型
ClassLabel 数据被正确转换为数组。
>>> from datasets import Dataset, Features, ClassLabel
>>> labels = [0, 0, 1]
>>> features = Features({"label": ClassLabel(names=["negative", "positive"])})
>>> ds = Dataset.from_dict({"label": labels}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[:3]
{'label': DeviceArray([0, 0, 1], dtype=int32)}
字符串和二进制对象保持不变,因为 JAX 只支持数字。
要使用 Image 特征类型,您需要安装 vision
扩展,方法是 pip install datasets[vision]
。
>>> from datasets import Dataset, Features, Image
>>> images = ["path/to/image.png"] * 10
>>> features = Features({"image": Image()})
>>> ds = Dataset.from_dict({"image": images}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[0]["image"].shape
(512, 512, 3)
>>> ds[0]
{'image': DeviceArray([[[ 255, 255, 255],
[ 255, 255, 255],
...,
[ 255, 255, 255],
[ 255, 255, 255]]], dtype=uint8)}
>>> ds[:2]["image"].shape
(2, 512, 512, 3)
>>> ds[:2]
{'image': DeviceArray([[[[ 255, 255, 255],
[ 255, 255, 255],
...,
[ 255, 255, 255],
[ 255, 255, 255]]]], dtype=uint8)}
要使用 Audio 特征类型,您需要安装 audio
扩展,方法是 pip install datasets[audio]
。
>>> from datasets import Dataset, Features, Audio
>>> audio = ["path/to/audio.wav"] * 10
>>> features = Features({"audio": Audio()})
>>> ds = Dataset.from_dict({"audio": audio}, features=features)
>>> ds = ds.with_format("jax")
>>> ds[0]["audio"]["array"]
DeviceArray([-0.059021 , -0.03894043, -0.00735474, ..., 0.0133667 ,
0.01809692, 0.00268555], dtype=float32)
>>> ds[0]["audio"]["sampling_rate"]
DeviceArray(44100, dtype=int32, weak_type=True)
数据加载
JAX 没有任何内置的数据加载功能,因此您需要使用像 PyTorch 这样的库使用 DataLoader
加载数据,或者使用 TensorFlow 使用 tf.data.Dataset
加载数据。引用关于此主题的 JAX 文档:“JAX 专注于程序转换和加速支持的 NumPy,因此我们没有在 JAX 库中包含数据加载或处理。已经有许多很棒的数据加载器,所以让我们直接使用它们,而不是重新发明轮子。我们将获取 PyTorch 的数据加载器,并制作一个微小的垫片,使其与 NumPy 数组一起工作。”
这就是为什么 datasets
中的 JAX 格式化如此有用的原因,因为它允许您使用 HuggingFace Hub 中的任何模型与 JAX 一起使用,而无需担心数据加载部分。
使用 with_format('jax')
从数据集中获取 JAX 数组的最简单方法是使用 with_format('jax')
方法。假设我们想要在 HuggingFace Hub 上的 MNIST 数据集 (https://huggingface.co/datasets/mnist) 上训练神经网络。
>>> from datasets import load_dataset
>>> ds = load_dataset("mnist")
>>> ds = ds.with_format("jax")
>>> ds["train"][0]
{'image': DeviceArray([[ 0, 0, 0, ...],
[ 0, 0, 0, ...],
...,
[ 0, 0, 0, ...],
[ 0, 0, 0, ...]], dtype=uint8),
'label': DeviceArray(5, dtype=int32)}
设置格式后,我们可以使用 Dataset.iter()
方法分批将数据集馈送到 JAX 模型。
>>> for epoch in range(epochs):
... for batch in ds["train"].iter(batch_size=32):
... x, y = batch["image"], batch["label"]
... ...