数据集文档
批量映射
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
批量映射
将 Dataset.map() 的实用性与批量模式相结合非常强大。它允许您加速处理,并自由控制生成数据集的大小。
速度需求
批量映射的主要目的是加速处理。通常,处理批量数据比处理单个示例更快。自然地,批量映射适用于分词。例如,🤗 Tokenizers 库在处理批次时速度更快,因为它并行化批次中所有示例的分词。
输入大小 != 输出大小
控制生成数据集大小的能力可以用于许多有趣的用例。在 How-to map 部分,有一些使用批量映射的示例,用于
- 将长句子分割成较短的片段。
- 使用额外的 token 扩充数据集。
了解这是如何工作的很有帮助,这样您就可以想出自己使用批量映射的方法。此时,您可能想知道如何控制生成数据集的大小。答案是:映射函数不必返回相同大小的输出批次。
换句话说,您的映射函数输入可以是大小为 N 的批次,并返回大小为 M 的批次。输出 M 可以大于或小于 N。这意味着您可以连接您的示例,将其分割,甚至添加更多示例!
但是,请记住,输出字典中的所有值必须包含与输出字典中其他字段相同数量的元素。否则,无法定义映射函数返回的输出中示例的数量。数量可能在映射函数处理的连续批次之间变化。但是对于单个批次,输出字典的所有值都应具有相同的长度(即,元素数量)。
例如,从一个包含 1 列和 3 行的数据集开始,如果您使用 map 返回一个具有两倍行数的新列,则会出现错误。在这种情况下,您最终会得到一列包含 3 行,另一列包含 6 行。如您所见,该表将无效
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset.map(lambda batch: {"b": batch["a"] * 2}, batched=True) # new column with 6 elements: [0, 1, 2, 0, 1, 2]
'ArrowInvalid: Column 1 named b expected length 3 but got length 6'
为了使其有效,您必须删除其中一列
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset_with_duplicates = dataset.map(lambda batch: {"b": batch["a"] * 2}, remove_columns=["a"], batched=True)
>>> len(dataset_with_duplicates)
6