Datasets 文档

批量映射

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

批量映射

结合 Dataset.map() 的实用性和批量模式,非常强大。它允许你加速处理,并自由控制生成数据集的大小。

追求速度

批量映射的主要目标是加快处理速度。通常,处理数据批比处理单个示例要快。自然,批量映射非常适合分词。例如,🤗 Tokenizers 库通过批处理可以更快地工作,因为它并行处理批次中所有示例的分词。

输入大小 != 输出大小

控制生成数据集大小的能力可以用于许多有趣的用例。在“如何操作”的 map 部分,有使用批量映射的示例,用于

  • 将长句子分割成更短的块。
  • 使用其他 token 增强数据集。

理解其工作原理很有帮助,这样你就可以想出自己的批量映射使用方法。此时,你可能想知道如何控制生成数据集的大小。答案是:**映射函数不必返回与输入批次相同大小的输出**。

换句话说,你的映射函数输入可以是大小为 N 的批次,并返回大小为 M 的批次。输出 M 可以大于或小于 N。这意味着你可以连接示例,分割示例,甚至添加更多示例!

但是,请记住,输出字典中的所有值必须包含与输出字典中其他字段**相同数量的元素**。否则,无法定义映射函数返回的输出中的示例数量。对于映射函数处理的连续批次,数量可能有所不同。但是,对于单个批次,输出字典的所有值应具有相同的长度(即元素数量)。

例如,对于一个有 1 列和 3 行的数据集,如果你使用 map 返回一个有两倍行数的新列,那么你将得到一个错误。在这种情况下,你最终得到一列 3 行,另一列 6 行。如你所见,该表将无效。

>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset.map(lambda batch: {"b": batch["a"] * 2}, batched=True)  # new column with 6 elements: [0, 1, 2, 0, 1, 2]
'ArrowInvalid: Column 1 named b expected length 3 but got length 6'

为了使其有效,你必须删除其中一列。

>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset_with_duplicates = dataset.map(lambda batch: {"b": batch["a"] * 2}, remove_columns=["a"], batched=True)
>>> len(dataset_with_duplicates)
6

或者,你可以覆盖现有列以达到相同的结果。例如,以下是如何通过覆盖列 "a" 来复制数据集中每个行的方法。

>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
# overwrites the existing "a" column with duplicated values
>>> duplicated_dataset = dataset.map(
...     lambda batch: {"a": [x for x in batch["a"] for _ in range(2)]},
...     batched=True
... )
>>> duplicated_dataset
Dataset({
    features: ['a'],
    num_rows: 6
})
>>> duplicated_dataset["a"]
[0, 0, 1, 1, 2, 2]
在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.