与 Spark 协同使用
本文档简要介绍了如何将 🤗 数据集与 Spark 协同使用,重点介绍了如何将 Spark DataFrame 加载到 Dataset 对象中。
从那里,您可以快速访问任何元素,并将其用作数据加载器来训练模型。
从 Spark 加载
一个 Dataset 对象是 Arrow 表的包装器,它允许从数据集中快速读取数组到 PyTorch、TensorFlow 和 JAX 张量。Arrow 表从磁盘内存映射,这可以加载比可用内存更大的数据集。
你可以使用 Dataset.from_spark()
从 Spark DataFrame 获取 Dataset
>>> from datasets import Dataset
>>> df = spark.createDataFrame(
... data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
... columns=["id", "name"],
... )
>>> ds = Dataset.from_spark(df)
Spark 工作器将数据集写入磁盘上的缓存目录中,作为 Arrow 文件,并且 Dataset 从那里加载。
或者,你可以使用 IterableDataset.from_spark()
跳过物化,它返回一个 IterableDataset
>>> from datasets import IterableDataset
>>> df = spark.createDataFrame(
... data=[[1, "Elia"], [2, "Teo"], [3, "Fang"]],
... columns=["id", "name"],
... )
>>> ds = IterableDataset.from_spark(df)
>>> print(next(iter(ds)))
{"id": 1, "name": "Elia"}
缓存
使用 Dataset.from_spark()
时,生成的 Dataset 被缓存;如果你对同一个 DataFrame 多次调用 Dataset.from_spark()
,它不会重新运行将数据集作为 Arrow 文件写入磁盘的 Spark 作业。
你可以通过将 cache_dir=
传递给 Dataset.from_spark()
来设置缓存位置。确保使用你的工作器和当前机器(驱动程序)都可以访问的磁盘。
在不同的会话中,Spark DataFrame 没有相同的 语义哈希,它将重新运行一个 Spark 作业并将其存储在新的缓存中。
特征类型
如果你的数据集由图像、音频数据或 N 维数组组成,你可以在 Dataset.from_spark()
(或 IterableDataset.from_spark()
)中指定 features=
参数
>>> from datasets import Dataset, Features, Image, Value
>>> data = [(0, open("image.png", "rb").read())]
>>> df = spark.createDataFrame(data, "idx: int, image: binary")
>>> # Also works if you have arrays
>>> # data = [(0, np.zeros(shape=(32, 32, 3), dtype=np.int32).tolist())]
>>> # df = spark.createDataFrame(data, "idx: int, image: array<array<array<int>>>")
>>> features = Features({"idx": Value("int64"), "image": Image()})
>>> dataset = Dataset.from_spark(df, features=features)
>>> dataset[0]
{'idx': 0, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=32x32>}
你可以查看 Features 文档,了解所有可用的特征类型。
< > 在 GitHub 上更新