批量映射
将 Dataset.map() 的实用程序与批处理模式相结合非常强大。它允许您加速处理,并自由控制生成的数据集的大小。
速度需求
批量映射的主要目标是加快处理速度。通常,处理数据批次比处理单个样本更快。自然地,批量映射适合于分词。例如,🤗 分词器 库在处理批次时速度更快,因为它并行化了批次中所有样本的分词操作。
输入大小 != 输出大小
能够控制生成数据集的大小可以被用于许多有趣的用例。在如何 映射 部分,有一些使用批量映射的示例,例如
- 将长句子拆分成更短的片段。
- 用额外的标记扩充数据集。
理解它是如何工作的很有帮助,这样你就可以想出自己使用批量映射的方法。现在,你可能想知道如何控制生成数据集的大小。答案是:映射函数不必返回与输入批次大小相同的输出批次。
换句话说,你的映射函数输入可以是大小为 N
的批次,并返回大小为 M
的批次。输出 M
可以大于或小于 N
。这意味着你可以连接你的示例,将其分成多个部分,甚至添加更多示例!
但是,请记住,输出字典中的所有值必须包含与输出字典中其他字段相同数量的元素。否则,无法确定映射函数返回的输出中的示例数量。这个数量可能在映射函数处理的连续批次之间有所不同。但是对于单个批次,输出字典的所有值应该具有相同的长度(即,元素数量)。
例如,从一个包含 1 列和 3 行的数据集中,如果你使用 map
返回一个包含两倍行数的新列,那么你将遇到错误。在这种情况下,你最终将得到一个包含 3 行的列,以及一个包含 6 行的列。如你所见,该表将无效。
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset.map(lambda batch: {"b": batch["a"] * 2}, batched=True) # new column with 6 elements: [0, 1, 2, 0, 1, 2]
'ArrowInvalid: Column 1 named b expected length 3 but got length 6'
为了使其有效,你必须删除其中一列
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset_with_duplicates = dataset.map(lambda batch: {"b": batch["a"] * 2}, remove_columns=["a"], batched=True)
>>> len(dataset_with_duplicates)
6