Datasets 文档

与 PyArrow 一起使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

与 PyArrow 一起使用

本文档简要介绍了如何将 datasets 与 PyArrow 结合使用,重点关注如何使用 Arrow 计算函数处理数据集,以及如何将数据集转换为 PyArrow 或从 PyArrow 转换。

这特别有用,因为它允许快速的零拷贝操作,因为 datasets 在底层使用了 PyArrow。

数据集格式

默认情况下,数据集返回常规的Python对象:整数、浮点数、字符串、列表等。

要获取 PyArrow Tables 或 Arrays,您可以使用 Dataset.with_format() 将数据集的格式设置为 pyarrow

>>> from datasets import Dataset
>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
>>> ds = Dataset.from_dict(data)
>>> ds = ds.with_format("arrow")
>>> ds[0]       # pa.Table
pyarrow.Table
col_0: string
col_1: double
----
col_0: [["a"]]
col_1: [[0]]
>>> ds[:2]      # pa.Table
pyarrow.Table
col_0: string
col_1: double
----
col_0: [["a","b"]]
col_1: [[0,0]]
>>> ds["data"]  # pa.array
<pyarrow.lib.ChunkedArray object at 0x1394312a0>
[
  [
    "a",
    "b",
    "c",
    "d"
  ]
]

这也适用于例如使用 `load_dataset(..., streaming=True)` 获取的 `IterableDataset` 对象。

>>> ds = ds.with_format("arrow")
>>> for table in ds.iter(batch_size=2):
...     print(table)
...     break
pyarrow.Table
col_0: string
col_1: double
----
col_0: [["a","b"]]
col_1: [[0,0]]

处理数据

PyArrow 函数通常比手写的 Python 函数快,因此它们是优化数据处理的好选择。您可以使用 Arrow 计算函数在 Dataset.map()Dataset.filter() 中处理数据集。

>>> import pyarrow.compute as pc
>>> from datasets import Dataset
>>> data = {"col_0": ["a", "b", "c", "d"], "col_1": [0., 0., 1., 1.]}
>>> ds = Dataset.from_dict(data)
>>> ds = ds.with_format("arrow")
>>> ds = ds.map(lambda t: t.append_column("col_2", pc.add(t["col_1"], 1)), batched=True)
>>> ds[:2]
pyarrow.Table
col_0: string
col_1: double
col_2: double
----
col_0: [["a","b"]]
col_1: [[0,0]]
col_2: [[1,1]]
>>> ds = ds.filter(lambda t: pc.equal(t["col_0"], "b"), batched=True)
>>> ds[0]
pyarrow.Table
col_0: string
col_1: double
col_2: double
----
col_0: [["b"]]
col_1: [[0]]
col_2: [[1]]

我们使用 batched=True,因为在 PyArrow 中处理数据批次比逐行处理更快。您也可以在 map() 中使用 batch_size= 来设置每个 table 的大小。

这也适用于 IterableDataset.map()IterableDataset.filter()

从 PyArrow 导入或导出

Dataset 是 PyArrow Table 的包装器,您可以直接从 Table 实例化 Dataset。

ds = Dataset(table)

您可以使用 Dataset.data 访问数据集的 PyArrow Table,它会返回一个 MemoryMappedTableInMemoryTableConcatenationTable,具体取决于 Arrow 数据的来源以及执行的操作。

这些对象包装了可在 Dataset.data.table 访问的底层 PyArrow 表。此表包含数据集的所有数据,但可能还存在一个在 Dataset._indices 的索引映射,它将数据集行的索引映射到 PyArrow 表的行索引。如果数据集已使用 Dataset.shuffle() 进行了 shuffle,或者只使用了部分行(例如,在 Dataset.select() 之后),则可能会发生这种情况。

在一般情况下,您可以使用 table = ds.with_format("arrow")[:] 将数据集导出到 PyArrow Table。

在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.