Datasets 文档

Dataset 和 IterableDataset 的区别

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Dataset 和 IterableDataset 的区别

有两种类型的数据集对象:DatasetIterableDataset。你选择使用或创建哪种类型的数据集取决于数据集的大小。一般来说,IterableDataset 因其惰性行为和速度优势而非常适合超大数据集(想想几百 GB!),而 Dataset 则适用于其他所有场景。本页将对比 DatasetIterableDataset 的区别,以帮助你选择合适的数据集对象。

下载与流式传输

当你拥有一个常规的 Dataset 时,你可以使用 my_dataset[0] 来访问它。这提供了对数据行的随机访问。此类数据集也称为“映射样式” (map-style) 数据集。例如,你可以像这样下载 ImageNet-1k 并访问任何一行:

from datasets import load_dataset

imagenet = load_dataset("timm/imagenet-1k-wds", split="train")  # downloads the full dataset
print(imagenet[0])

但一个注意事项是,你必须将整个数据集存储在磁盘或内存中,这会阻止你访问比磁盘容量更大的数据集。由于这对大数据集来说可能会变得不方便,因此存在另一种类型的数据集,即 IterableDataset。当你拥有 IterableDataset 时,你可以使用 for 循环在迭代数据集时逐步加载数据。这样,只有一小部分示例被加载到内存中,而且你不需要在磁盘上写入任何内容。

例如,你可以流式传输 ImageNet-1k 数据集,而无需将其下载到磁盘:

from datasets import load_dataset

imagenet = load_dataset("timm/imagenet-1k-wds", split="train", streaming=True)  # will start loading the data when iterated over
for example in imagenet:
    print(example)
    break

流式传输可以读取在线数据,而无需向磁盘写入任何文件。例如,你可以流式传输由多个分片组成的数据集,每个分片都有数百 GB,如 C4LAION-2B。在 数据集流式传输指南 中了解有关如何流式传输数据集的更多信息。

不过,这并不是唯一的区别,因为 IterableDataset 的“惰性”行为在数据集创建和处理时也会体现出来。

创建映射样式数据集和可迭代数据集

你可以使用列表或字典创建一个 Dataset,数据会完全转换为 Arrow 格式,以便你可以轻松访问任何一行:

my_dataset = Dataset.from_dict({"col_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]})
print(my_dataset[0])

另一方面,要创建 IterableDataset,你必须提供一种“惰性”加载数据的方法。在 Python 中,我们通常使用生成器函数。这些函数一次 yield(产生)一个示例,这意味着你不能像常规 Dataset 那样通过切片来访问某一行:

def my_generator(n):
    for i in range(n):
        yield {"col_1": i}

my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs={"n": 10})
for example in my_iterable_dataset:
    print(example)
    break

完整加载与逐步加载本地文件

可以使用 load_dataset() 将本地或远程数据文件转换为 Arrow Dataset

data_files = {"train": ["path/to/data.csv"]}
my_dataset = load_dataset("csv", data_files=data_files, split="train")
print(my_dataset[0])

然而,这需要一个从 CSV 到 Arrow 格式的转换步骤,如果你的数据集很大,这会消耗时间和磁盘空间。

为了节省磁盘空间并跳过转换步骤,你可以通过直接从本地文件流式传输来定义 IterableDataset。这样,当你迭代数据集时,数据会从本地文件中逐步读取:

data_files = {"train": ["path/to/data.csv"]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
for example in my_iterable_dataset:  # this reads the CSV file progressively as you iterate over the dataset
    print(example)
    break

支持许多文件格式,如 CSV、JSONL 和 Parquet,以及图像和音频文件。你可以在加载 表格文本视觉音频 数据集的相应指南中找到更多信息。

急切数据处理与惰性数据处理

当你使用 Dataset.map() 处理 Dataset 对象时,整个数据集会立即处理并返回。这类似于 pandas 的工作方式。

my_dataset = my_dataset.map(process_fn)  # process_fn is applied on all the examples of the dataset
print(my_dataset[0])

另一方面,由于 IterableDataset 的“惰性”本质,调用 IterableDataset.map() 不会将 map 函数应用于整个数据集。相反,你的 map 函数是即时应用的。

正因为如此,你可以将多个处理步骤链接起来,当你开始迭代数据集时,它们会一次性全部运行:

my_iterable_dataset = my_iterable_dataset.map(process_fn_1)
my_iterable_dataset = my_iterable_dataset.filter(filter_fn)
my_iterable_dataset = my_iterable_dataset.map(process_fn_2)

# process_fn_1, filter_fn and process_fn_2 are applied on-the-fly when iterating over the dataset
for example in my_iterable_dataset:  
    print(example)
    break

精确打乱与快速近似打乱

当你使用 Dataset.shuffle() 打乱 Dataset 时,你会对数据集应用精确打乱。它的工作原理是获取索引列表 [0, 1, 2, ... len(my_dataset) - 1] 并打乱该列表。然后,访问 my_dataset[0] 会返回由已打乱索引映射的第一个元素所定义的行和索引:

my_dataset = my_dataset.shuffle(seed=42)
print(my_dataset[0])

由于在 IterableDataset 的情况下我们无法随机访问行,因此我们不能使用打乱的索引列表并在任意位置访问行。这阻止了精确打乱的使用。相反,IterableDataset.shuffle() 中使用了快速近似打乱。它使用打乱缓冲区从数据集中迭代地采样随机示例。由于数据集仍然是迭代读取的,因此它提供了出色的速度性能:

my_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)
for example in my_iterable_dataset:
    print(example)
    break

但仅使用打乱缓冲区不足以为机器学习模型训练提供令人满意的打乱效果。因此,如果你的数据集由多个文件或来源组成,IterableDataset.shuffle() 还会打乱数据集分片:

# Stream from the internet
my_iterable_dataset = load_dataset("deepmind/code_contests", split="train", streaming=True)
my_iterable_dataset.num_shards  # 39

# Stream from local files
data_files = {"train": [f"path/to/data_{i}.csv" for i in range(1024)]}
my_iterable_dataset = load_dataset("csv", data_files=data_files, split="train", streaming=True)
my_iterable_dataset.num_shards  # 1024

# From a generator function
def my_generator(n, sources):
    for source in sources:
        for example_id_for_current_source in range(n):
            yield {"example_id": f"{source}_{example_id_for_current_source}"}

gen_kwargs = {"n": 10, "sources": [f"path/to/data_{i}" for i in range(1024)]}
my_iterable_dataset = IterableDataset.from_generator(my_generator, gen_kwargs=gen_kwargs)
my_iterable_dataset.num_shards  # 1024

速度差异

常规 Dataset 对象基于 Arrow,它提供了对行的快速随机访问。由于内存映射以及 Arrow 是内存格式的事实,从磁盘读取数据不会进行昂贵的系统调用和反序列化。通过在连续的 Arrow 记录批次上迭代,它在使用 for 循环迭代时提供了更快的数据加载速度。

然而,一旦你的 Dataset 具有索引映射(例如通过 Dataset.shuffle()),速度可能会慢 10 倍。这是因为有一个额外的步骤来使用索引映射获取要读取的行索引,最重要的是,你不再读取连续的数据块。为了恢复速度,你需要使用 Dataset.flatten_indices() 再次在磁盘上重写整个数据集,这会移除索引映射。不过,这可能需要很多时间,具体取决于数据集的大小:

my_dataset[0]  # fast
my_dataset = my_dataset.shuffle(seed=42)
my_dataset[0]  # up to 10x slower
my_dataset = my_dataset.flatten_indices()  # rewrite the shuffled dataset on disk as contiguous chunks of data
my_dataset[0]  # fast again

在这种情况下,我们建议切换到 IterableDataset 并利用其快速近似打乱方法 IterableDataset.shuffle()。它仅打乱分片顺序并为你的数据集添加一个打乱缓冲区,这使你的数据集保持最佳速度。你还可以轻松地重新打乱数据集:

for example in enumerate(my_iterable_dataset):  # fast
    pass

shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=42, buffer_size=100)

for example in enumerate(shuffled_iterable_dataset):  # as fast as before
    pass

shuffled_iterable_dataset = my_iterable_dataset.shuffle(seed=1337, buffer_size=100)  # reshuffling using another seed is instantaneous

for example in enumerate(shuffled_iterable_dataset):  # still as fast as before
    pass

如果你在多个 epoch(轮次)上使用数据集,打乱缓冲区中打乱分片顺序的有效种子是 seed + epoch。这使得在 epoch 之间重新打乱数据集变得容易:

for epoch in range(n_epochs):
    my_iterable_dataset.set_epoch(epoch)
    for example in my_iterable_dataset:  # fast + reshuffled at each epoch using `effective_seed = seed + epoch`
        pass

要重新开始映射样式数据集的迭代,你可以简单地跳过前面的示例:

my_dataset = my_dataset.select(range(start_index, len(dataset)))

但如果你将 DataLoaderSampler 一起使用,你应该改为保存采样器的状态(你可能已经编写了允许恢复的自定义采样器)。

另一方面,可迭代数据集不提供用于从中恢复的特定示例索引的随机访问。但你可以使用 IterableDataset.state_dict()IterableDataset.load_state_dict() 改为从检查点恢复,类似于你可以为模型和优化器所做的操作:

>>> iterable_dataset = Dataset.from_dict({"a": range(6)}).to_iterable_dataset(num_shards=3)
>>> # save in the middle of training
>>> state_dict = iterable_dataset.state_dict()
>>> # and resume later
>>> iterable_dataset.load_state_dict(state_dict)

在底层,可迭代数据集会跟踪当前正在读取的分片以及当前分片中的示例索引,并将此信息存储在 state_dict 中。

要从检查点恢复,数据集会跳过之前读取的所有分片,从当前分片重新开始。然后它读取该分片并跳过示例,直到达到检查点中的确切示例。

因此,重启数据集是相当快的,因为它不会重新读取已经迭代过的分片。尽管如此,恢复数据集通常不是瞬时的,因为它必须从当前分片的开头重新开始读取并跳过示例,直到到达检查点位置。

这可以与 torchdata 中的 StatefulDataLoader 一起使用,请参阅 使用 PyTorch DataLoader 进行流式传输

从映射样式切换到可迭代样式

如果你想从 IterableDataset 的“惰性”行为或其速度优势中获益,可以将映射样式的 Dataset 切换为 IterableDataset

my_iterable_dataset = my_dataset.to_iterable_dataset()

如果你想打乱数据集或将其与 PyTorch DataLoader 一起使用,我们建议生成一个分片的 IterableDataset

my_iterable_dataset = my_dataset.to_iterable_dataset(num_shards=1024)
my_iterable_dataset.num_shards  # 1024
在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.