Datasets 文档
Dataset 和 IterableDataset 之间的区别
并获得增强的文档体验
开始使用
Dataset 和 IterableDataset 之间的区别
有两种类型的数据集对象:Dataset 和 IterableDataset。您选择使用或创建哪种类型的数据集取决于数据集的大小。总的来说,由于其惰性行为和速度优势,IterableDataset 非常适合处理大型数据集(想想数百 GB!),而 Dataset 则适用于其他所有情况。本页将比较 Dataset 和 IterableDataset 之间的区别,以帮助您选择适合自己的数据集对象。
下载与流式传输
当您有一个常规的 Dataset 时,您可以使用 `my_dataset[0]` 来访问它。这提供了对行的随机访问。这类数据集也称为“映射式”(map-style)数据集。例如,您可以这样下载 ImageNet-1k 并访问任何行:
from datasets import load_dataset
imagenet = load_dataset("timm/imagenet-1k-wds", split="train") # downloads the full dataset
print(imagenet[0])
但一个缺点是,您必须将整个数据集存储在磁盘或内存中,这会阻止您访问比磁盘更大的数据集。由于这对大型数据集可能带来不便,因此存在另一种类型的数据集,即 IterableDataset。当您有一个 `IterableDataset` 时,您可以使用 `for` 循环在迭代数据集时逐步加载数据。这样,只有一小部分样本被加载到内存中,并且您不会在磁盘上写入任何内容。
例如,您可以流式传输 ImageNet-1k 数据集而无需将其下载到磁盘:
from datasets import load_dataset
imagenet = load_dataset("timm/imagenet-1k-wds", split="train", streaming=True) # will start loading the data when iterated over
for example in imagenet:
print(example)
break
流式传输可以读取在线数据而无需向磁盘写入任何文件。例如,您可以流式传输由多个分片(shard)组成的数据集,每个分片都可能有数百 GB,例如 C4 或 LAION-2B。有关如何流式传输数据集的更多信息,请参阅数据集流式传输指南。
但这并不是唯一的区别,因为 `IterableDataset` 的“惰性”行为在数据集的创建和处理方面也同样存在。
创建映射式数据集和可迭代数据集
您可以使用列表或字典创建 Dataset,数据会完全转换为 Arrow 格式,以便您轻松访问任何行:
my_dataset = Dataset.from_dict({"col_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]})
print(my_dataset[0])
另一方面,要创建 `IterableDataset`,您必须提供一种“惰性”加载数据的方式。在 Python 中,我们通常使用生成器函数。这些函数一次 `yield` 一个样本,这意味着您不能像常规 `Dataset` 那样通过切片来访问行:
def my_generator(n):
for i in range(n):
yield {"col_1": i}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs={"n": 10})
for example in my_iterable_dataset:
print(example)
break
完全加载与逐步加载本地文件
可以使用 load_dataset() 将本地或远程数据文件转换为 Arrow 格式的 Dataset:
data_files = {"train": ["path/to/data.csv"]}
my_dataset = load_dataset("csv", data_files=data_files, split="train")
print(my_dataset[0])
然而,这需要一个从 CSV 到 Arrow 格式的转换步骤,如果您的数据集很大,这会消耗时间和磁盘空间。
为了节省磁盘空间并跳过转换步骤,您可以通过直接从本地文件流式传输来定义一个 `IterableDataset`。这样,当您迭代数据集时,数据会逐步从本地文件中读取:
data_files = {"train": ["path/to/data.csv"]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
for example in my_iterable_dataset: # this reads the CSV file progressively as you iterate over the dataset
print(example)
break
支持多种文件格式,如 CSV、JSONL 和 Parquet,以及图像和音频文件。您可以在相应的指南中找到更多关于加载表格、文本、视觉和音频数据集的信息。
即时数据处理与惰性数据处理
当您使用 Dataset.map() 处理一个 Dataset 对象时,整个数据集会立即被处理并返回。这与 `pandas` 的工作方式类似。
my_dataset = my_dataset.map(process_fn) # process_fn is applied on all the examples of the dataset
print(my_dataset[0])
另一方面,由于 `IterableDataset` 的“惰性”特性,调用 IterableDataset.map() 并不会将您的 `map` 函数应用于整个数据集。相反,您的 `map` 函数是即时应用的。
因此,您可以链接多个处理步骤,当您开始迭代数据集时,它们会一次性全部运行:
my_iterable_dataset = my_iterable_dataset.map(process_fn_1)
my_iterable_dataset = my_iterable_dataset.filter(filter_fn)
my_iterable_dataset = my_iterable_dataset.map(process_fn_2)
# process_fn_1, filter_fn and process_fn_2 are applied on-the-fly when iterating over the dataset
for example in my_iterable_dataset:
print(example)
break
精确洗牌与快速近似洗牌
当您使用 Dataset.shuffle() 对 Dataset 进行洗牌时,您应用的是对数据集的精确洗牌。它的工作原理是获取一个索引列表 `[0, 1, 2, ... len(my_dataset) - 1]` 并对这个列表进行洗牌。然后,访问 `my_dataset[0]` 会返回由洗牌后的索引映射的第一个元素定义的行和索引:
my_dataset = my_dataset.shuffle(seed=42)
print(my_dataset[0])
由于在 `IterableDataset` 的情况下我们没有对行的随机访问权限,因此我们不能使用一个洗牌后的索引列表来访问任意位置的行。这使得精确洗牌无法实现。取而代之的是,在 IterableDataset.shuffle() 中使用了一种快速的近似洗牌方法。它使用一个洗牌缓冲区来迭代地从数据集中抽样随机样本。由于数据集仍然是迭代读取的,它提供了出色的速度性能。
my_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)
for example in my_iterable_dataset:
print(example)
break
但仅使用洗牌缓冲区不足以为机器学习模型训练提供令人满意的洗牌效果。因此,如果您的数据集由多个文件或来源组成,IterableDataset.shuffle() 也会对数据集的分片进行洗牌。
# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
my_iterable_dataset.num_shards # 39
# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
my_iterable_dataset.num_shards # 1024
# From a generator function
def my_generator(n, sources):
for source in sources:
for example_id_for_current_source in range(n):
yield {"example_id": f"{source}_{example_id_for_current_source}"}
gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
my_iterable_dataset.num_shards # 1024
速度差异
常规的 Dataset 对象基于 Arrow,它提供了对行的快速随机访问。得益于内存映射以及 Arrow 是一种内存格式的事实,从磁盘读取数据不会进行昂贵的系统调用和反序列化。通过在连续的 Arrow 记录批次上迭代,使用 `for` 循环迭代时,它提供了更快的数据加载速度。
然而,一旦您的 Dataset 有了索引映射(例如通过 Dataset.shuffle()),速度可能会慢 10 倍。这是因为有一个额外的步骤来使用索引映射获取要读取的行索引,更重要的是,您不再读取连续的数据块。要恢复速度,您需要使用 Dataset.flatten_indices() 再次将整个数据集重写到磁盘上,这会移除索引映射。不过,这可能需要很长时间,具体取决于您的数据集大小。
my_dataset[0] # fast
my_dataset = my_dataset.shuffle(seed=42)
my_dataset[0] # up to 10x slower
my_dataset = my_dataset.flatten_indices() # rewrite the shuffled dataset on disk as contiguous chunks of data
my_dataset[0] # fast again
在这种情况下,我们建议切换到 IterableDataset 并利用其快速的近似洗牌方法 IterableDataset.shuffle()。它只对分片顺序进行洗牌,并为您的数据集添加一个洗牌缓冲区,这能保持数据集的最佳速度。您也可以轻松地重新洗牌数据集。
for example in enumerate(my_iterable_dataset): # fast
pass
shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)
for example in enumerate(shuffled_iterable_dataset): # as fast as before
pass
shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=1337, buffer_size=100) # reshuffling using another seed is instantaneous
for example in enumerate(shuffled_iterable_dataset): # still as fast as before
pass
如果您在多个周期(epoch)上使用数据集,用于洗牌分片顺序和洗牌缓冲区的有效种子是 `seed + epoch`。这使得在不同周期之间轻松地重新洗牌数据集成为可能。
for epoch in range(n_epochs):
my_iterable_dataset.set_epoch(epoch)
for example in my_iterable_dataset: # fast + reshuffled at each epoch using `effective_seed = seed + epoch`
pass
要重新开始迭代一个映射式数据集,您只需跳过前面的样本即可:
my_dataset = my_dataset.select(range(start_index, len(dataset)))
但如果您使用带 `Sampler` 的 `DataLoader`,您应该保存您的采样器状态(您可能编写了一个允许恢复的自定义采样器)。
另一方面,可迭代数据集不提供对特定样本索引的随机访问以供恢复。但您可以使用 IterableDataset.state_dict() 和 IterableDataset.load_state_dict() 从检查点恢复,类似于您对模型和优化器可以做的那样:
>>> iterable_dataset = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)
>>> # save in the middle of training
>>> state_dict = iterable_dataset.state_dict()
>>> # and resume later
>>> iterable_dataset.load_state_dict(state_dict)
在底层,可迭代数据集会跟踪当前正在读取的分片和当前分片中的样本索引,并将此信息存储在 `state_dict` 中。
要从检查点恢复,数据集会跳过所有先前读取过的分片,以从当前分片重新开始。然后它会读取该分片并跳过样本,直到达到检查点中的确切样本位置。
因此,重新启动数据集相当快,因为它不会重新读取已经迭代过的分片。尽管如此,恢复数据集通常不是瞬时的,因为它必须从当前分片的开头重新开始读取并跳过样本,直到达到检查点位置。
这可以与 `torchdata` 中的 `StatefulDataLoader` 一起使用,请参阅使用 PyTorch DataLoader 进行流式传输。
从映射式切换到可迭代式
如果您想受益于 IterableDataset 的“惰性”行为或其速度优势,您可以将您的映射式 Dataset 切换为 IterableDataset。
my_iterable_dataset = my_dataset.to_iterable_dataset()
如果您想对数据集进行洗牌或将其与 PyTorch DataLoader 一起使用,我们建议生成一个分片的 IterableDataset。
my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
my_iterable_dataset.num_shards # 1024