Datasets 文档

与 Spark 共同使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

与 Spark 共同使用

本文档简要介绍了如何将 🤗 Datasets 与 Spark 结合使用,特别关注如何将 Spark DataFrame 加载到 Dataset 对象中。

之后,您就可以快速访问任何元素,并将其用作数据加载器来训练模型。

从 Spark 加载

Dataset 对象是 Arrow 表的包装器,它允许从数据集中的数组快速读取数据到 PyTorch、TensorFlow 和 JAX 的张量中。Arrow 表是从磁盘进行内存映射的,这可以加载比您可用 RAM 更大的数据集。

您可以使用 Dataset.from_spark() 从 Spark DataFrame 获取一个 Dataset 对象。

>>> from datasets import Dataset
>>> df = spark.createDataFrame(
...     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
...     columns=["id", "name"],
... )
>>> ds = Dataset.from_spark(df)

Spark 工作节点将数据集以 Arrow 文件的形式写入磁盘上的缓存目录中,然后 Dataset 从那里加载。

或者,您可以通过使用 IterableDataset.from_spark() 来跳过物化过程,该方法会返回一个 IterableDataset

>>> from datasets import IterableDataset
>>> df = spark.createDataFrame(
...     data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
...     columns=["id", "name"],
... )
>>> ds = IterableDataset.from_spark(df)
>>> print(next(iter(ds)))
{"id": 1, "name": "Elia"}

缓存

当使用 Dataset.from_spark() 时,生成的 Dataset 会被缓存;如果您在同一个 DataFrame 上多次调用 Dataset.from_spark(),它不会重新运行将数据集写入磁盘为 Arrow 文件的 Spark 作业。

您可以通过向 Dataset.from_spark() 传递 cache_dir= 来设置缓存位置。请确保使用对您的工作节点和当前机器(驱动程序)都可用的磁盘。

在不同的会话中,Spark DataFrame 不具有相同的语义哈希值,它会重新运行一个 Spark 作业并将其存储在一个新的缓存中。

特征类型

如果您的数据集由图像、音频数据或 N 维数组组成,您可以在 Dataset.from_spark() (或 IterableDataset.from_spark()) 中指定 features= 参数。

>>> from datasets import Dataset, Features, Image, Value
>>> data = [(0, open("image.png", "rb").read())]
>>> df = spark.createDataFrame(data, "idx: int, image: binary")
>>> # Also works if you have arrays
>>> # data = [(0, np.zeros(shape=(32, 32, 3), dtype=np.int32).tolist())]
>>> # df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
>>> features = Features({"idx": Value("int64"), "image": Image()})
>>> dataset = Dataset.from_spark(df, features=features)
>>> dataset[0]
{'idx': 0, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>}

您可以查阅 Features 文档来了解所有可用的特征类型。

< > 在 GitHub 上更新