Datasets 文档
批处理映射
加入 Hugging Face 社区
并获得增强的文档体验
开始使用
批处理映射
将 Dataset.map() 的实用性与批处理模式相结合是非常强大的。它能让您加快处理速度,并自由控制生成数据集的大小。
对速度的追求
批处理映射的主要目标是加快处理速度。通常,处理批量数据比处理单个样本要快。自然而然地,批处理映射非常适合用于分词。例如,🤗 Tokenizers 库在处理批量数据时速度更快,因为它会并行处理批次中的所有样本的分词任务。
输入大小 != 输出大小
控制生成数据集大小的能力可以用于许多有趣的用例。在“如何操作”映射部分,有一些使用批处理映射的例子:
- 将长句子分割成更短的块。
- 通过添加额外的词元(token)来扩充数据集。
了解其工作原理有助于您想出自己使用批处理映射的方法。此时,您可能想知道如何控制生成数据集的大小。答案是:映射函数返回的输出批次大小不必与输入批次大小相同。
换句话说,您的映射函数输入可以是一个大小为 `N` 的批次,并返回一个大小为 `M` 的批次。输出大小 `M` 可以大于或小于 `N`。这意味着您可以连接您的样本,将其分割,甚至添加更多样本!
但是,请记住,输出字典中的所有值必须包含与该输出字典中其他字段相同数量的元素。否则,将无法定义映射函数返回的输出中的样本数量。这个数量可以在映射函数处理的连续批次之间变化。但是,对于单个批次,输出字典的所有值应具有相同的长度(即,元素数量)。
例如,对于一个包含1列3行的数据集,如果您使用 `map` 返回一个行数翻倍的新列,那么您将会遇到错误。在这种情况下,您最终会得到一列有3行,另一列有6行。如您所见,表格将是无效的。
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset.map(lambda batch: {"b": batch["a"] * 2}, batched=True) # new column with 6 elements: [0, 1, 2, 0, 1, 2]
'ArrowInvalid: Column 1 named b expected length 3 but got length 6'
要使其有效,您必须删除其中一列。
>>> from datasets import Dataset
>>> dataset = Dataset.from_dict({"a": [0, 1, 2]})
>>> dataset_with_duplicates = dataset.map(lambda batch: {"b": batch["a"] * 2}, remove_columns=["a"], batched=True)
>>> len(dataset_with_duplicates)
6