Datasets 文档
批量映射
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
批量映射
结合 Dataset.map() 的实用性和批量模式,非常强大。它允许你加速处理,并自由控制生成数据集的大小。
追求速度
批量映射的主要目标是加快处理速度。通常,处理数据批比处理单个示例要快。自然,批量映射非常适合分词。例如,🤗 Tokenizers 库通过批处理可以更快地工作,因为它并行处理批次中所有示例的分词。
输入大小 != 输出大小
控制生成数据集大小的能力可以用于许多有趣的用例。在“如何操作”的 map 部分,有使用批量映射的示例,用于
- 将长句子分割成更短的块。
- 使用其他 token 增强数据集。
理解其工作原理很有帮助,这样你就可以想出自己的批量映射使用方法。此时,你可能想知道如何控制生成数据集的大小。答案是:**映射函数不必返回与输入批次相同大小的输出**。
换句话说,你的映射函数输入可以是大小为 N 的批次,并返回大小为 M 的批次。输出 M 可以大于或小于 N。这意味着你可以连接示例,分割示例,甚至添加更多示例!
但是,请记住,输出字典中的所有值必须包含与输出字典中其他字段**相同数量的元素**。否则,无法定义映射函数返回的输出中的示例数量。对于映射函数处理的连续批次,数量可能有所不同。但是,对于单个批次,输出字典的所有值应具有相同的长度(即元素数量)。
例如,对于一个有 1 列和 3 行的数据集,如果你使用 map 返回一个有两倍行数的新列,那么你将得到一个错误。在这种情况下,你最终得到一列 3 行,另一列 6 行。如你所见,该表将无效。
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset.map(lambda batch: {"b": batch["a"] * 2}, batched=True) # new column with 6 elements: [0, 1, 2, 0, 1, 2]
'ArrowInvalid: Column 1 named b expected length 3 but got length 6'为了使其有效,你必须删除其中一列。
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset_with_duplicates = dataset.map(lambda batch: {"b": batch["a"] * 2}, remove_columns=["a"], batched=True)
>>> len(dataset_with_duplicates)
6或者,你可以覆盖现有列以达到相同的结果。例如,以下是如何通过覆盖列 "a" 来复制数据集中每个行的方法。
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
# overwrites the existing "a" column with duplicated values
>>> duplicated_dataset = dataset.map(
... lambda batch: {"a": [x for x in batch["a"] for _ in range(2)]},
... batched=True
... )
>>> duplicated_dataset
Dataset({
features: ['a'],
num_rows: 6
})
>>> duplicated_dataset["a"]
[0, 0, 1, 1, 2, 2]