数据集文档

与 PyArrow 一起使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

与 PyArrow 一起使用

本文档是对如何将 datasets 与 PyArrow 一起使用的快速介绍,特别关注如何使用 Arrow 计算函数处理数据集,以及如何将数据集转换为 PyArrow 或从 PyArrow 转换数据集。

这非常有用,因为它允许快速的零拷贝操作,因为 datasets 在底层使用了 PyArrow。

数据集格式

默认情况下,datasets 返回常规 Python 对象:整数、浮点数、字符串、列表等。

要获取 PyArrow 表格或数组,你可以使用 Dataset.with_format() 将数据集的格式设置为 pyarrow

>>> from datasets import Dataset
>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
>>> ds = Dataset.from_dict(data)
>>> ds = ds.with_format("arrow")
>>> ds[0]       # pa.Table
pyarrow.Table
col_0: string
col_1: double
----
col_0: [["a"]]
col_1: [[0]]
>>> ds[:2]      # pa.Table
pyarrow.Table
col_0: string
col_1: double
----
col_0: [["a","b"]]
col_1: [[0,0]]
>>> ds["data"]  # pa.array
<pyarrow.lib.ChunkedArray object at 0x1394312a0>
[
  [
    "a",
    "b",
    "c",
    "d"
  ]
]

这也适用于通过例如 load_dataset(..., streaming=True) 获取的 IterableDataset 对象

>>> ds = ds.with_format("arrow")
>>> for table in ds.iter(batch_size=2):
...     print(table)
...     break
pyarrow.Table
col_0: string
col_1: double
----
col_0: [["a","b"]]
col_1: [[0,0]]

处理数据

PyArrow 函数通常比手写的 python 函数更快,因此它们是优化数据处理的好选择。你可以使用 Arrow 计算函数在 Dataset.map()Dataset.filter() 中处理数据集

>>> import pyarrow.compute as pc
>>> from datasets import Dataset
>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
>>> ds = Dataset.from_dict(data)
>>> ds = ds.with_format("arrow")
>>> ds = ds.map(lambda t: t.append_column("col_2", pc.add(t["col_1"], 1)), batched=True)
>>> ds[:2]
pyarrow.Table
col_0: string
col_1: double
col_2: double
----
col_0: [["a","b"]]
col_1: [[0,0]]
col_2: [[1,1]]
>>> ds = ds.filter(lambda t: pc.equal(t["col_0"], "b"), batched=True)
>>> ds[0]
pyarrow.Table
col_0: string
col_1: double
col_2: double
----
col_0: [["b"]]
col_1: [[0]]
col_2: [[1]]

我们使用 batched=True 是因为在 PyArrow 中批量处理数据比逐行处理更快。也可以在 map() 中使用 batch_size= 来设置每个 table 的大小。

这也适用于 IterableDataset.map()IterableDataset.filter()

从 PyArrow 导入或导出

Dataset 是 PyArrow Table 的包装器,你可以直接从 Table 实例化 Dataset

ds = Dataset(table)

你可以使用 Dataset.data 访问数据集的 PyArrow Table,它返回 MemoryMappedTableInMemoryTableConcatenationTable,具体取决于 Arrow 数据的来源和已应用的操作。

这些对象包装了底层 PyArrow 表,可以通过 Dataset.data.table 访问。此表包含数据集的所有数据,但也可能在 Dataset._indices 中存在索引映射,它将数据集行索引映射到 PyArrow Table 行索引。如果数据集已使用 Dataset.shuffle() 进行了混洗,或者仅使用了行的子集(例如,在 Dataset.select() 之后),则可能会发生这种情况。

在一般情况下,你可以使用 table = ds.with_format("arrow")[:] 将数据集导出到 PyArrow Table。

< > 在 GitHub 上更新