Datasets 文档

加载

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

加载

您的数据可以存储在各种地方;它们可以在本地机器的磁盘上、在Github仓库中,也可以在内存数据结构中,例如Python字典和Pandas DataFrames。无论数据集存储在哪里,🤗 Datasets 都可以帮助您加载它。

本指南将向您展示如何从以下位置加载数据集:

  • Hugging Face Hub
  • 本地文件
  • 内存数据
  • 离线
  • 特定切分

有关加载其他数据集模态的更多详细信息,请参阅加载音频数据集指南加载图像数据集指南加载视频数据集指南加载文本数据集指南

Hugging Face Hub

您还可以从 Hub 上的任何数据集仓库加载数据集!首先创建数据集仓库并上传您的数据文件。现在您可以使用load_dataset()函数加载数据集。

例如,尝试通过提供仓库命名空间和数据集名称从这个演示仓库加载文件。此数据集仓库包含 CSV 文件,下面的代码从 CSV 文件加载数据集:

>>> from datasets import load_dataset
>>> dataset = load_dataset("lhoestq/demo1")

有些数据集可能有多个版本,基于 Git 标签、分支或提交。使用 revision 参数来指定您要加载的数据集版本:

>>> dataset = load_dataset(
...   "lhoestq/custom_squad",
...   revision="main"  # tag name, or branch name, or commit hash
... )

有关如何在 Hub 上创建数据集仓库以及如何上传数据文件的更多详细信息,请参阅将数据集上传到 Hub教程。

数据集默认将所有数据加载到 train 分割中,或者检查数据文件名称中是否有提及或分割名称(例如“train”、“test”和“validation”)。使用 data_files 参数将数据文件映射到 trainvalidationtest 等分割:

>>> data_files = {"train": "train.csv", "test": "test.csv"}
>>> dataset = load_dataset("namespace/your_dataset_name", data_files=data_files)

如果您没有指定要使用哪些数据文件,load_dataset() 将返回所有数据文件。如果您加载像 C4 这样的大型数据集(大约 13TB 数据),这可能需要很长时间。

您还可以使用 data_filesdata_dir 参数加载文件的特定子集。这些参数可以接受一个相对路径,该路径解析为数据集加载的基础路径。

>>> from datasets import load_dataset

# load files that match the grep pattern
>>> c4_subset = load_dataset("allenai/c4", data_files="en/c4-train.0000*-of-01024.json.gz")

# load dataset from the en directory on the Hub
>>> c4_subset = load_dataset("allenai/c4", data_dir="en")

split 参数也可以将数据文件映射到特定分割:

>>> data_files = {"validation": "en/c4-validation.*.json.gz"}
>>> c4_validation = load_dataset("allenai/c4", data_files=data_files, split="validation")

本地和远程文件

数据集可以从存储在您计算机上的本地文件以及远程文件加载。数据集很可能存储为 csvjsontxtparquet 文件。load_dataset() 函数可以加载这些文件类型。

CSV

🤗 Datasets 可以读取由一个或多个 CSV 文件组成的数据集(在这种情况下,将您的 CSV 文件作为列表传递):

>>> from datasets import load_dataset
>>> dataset = load_dataset("csv", data_files="my_file.csv")

有关更多详细信息,请查看如何从 CSV 文件加载表格数据集指南。

JSON

JSON 文件可以直接使用load_dataset()加载,如下所示:

>>> from datasets import load_dataset
>>> dataset = load_dataset("json", data_files="my_file.json")

JSON 文件有多种格式,但我们认为最有效的格式是包含多个 JSON 对象;每行代表一个单独的数据行。例如:

{"a": 1, "b": 2.0, "c": "foo", "d": false}
{"a": 4, "b": -5.5, "c": null, "d": true}

您可能会遇到另一种 JSON 格式是嵌套字段,在这种情况下,您需要指定 field 参数,如下所示:

{"version": "0.1.0",
 "data": [{"a": 1, "b": 2.0, "c": "foo", "d": false},
          {"a": 4, "b": -5.5, "c": null, "d": true}]
}

>>> from datasets import load_dataset
>>> dataset = load_dataset("json", data_files="my_file.json", field="data")

要通过 HTTP 加载远程 JSON 文件,请传递 URL:

>>> base_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/"
>>> dataset = load_dataset("json", data_files={"train": base_url + "train-v1.1.json", "validation": base_url + "dev-v1.1.json"}, field="data")

虽然这些是最常见的 JSON 格式,但您会看到其他格式不同的数据集。🤗 Datasets 识别这些其他格式,并将相应地回退到 Python JSON 加载方法来处理它们。

Parquet

Parquet 文件以列式格式存储,与 CSV 等基于行的文件不同。大型数据集可以存储在 Parquet 文件中,因为它更高效且查询速度更快。

加载 Parquet 文件:

>>> from datasets import load_dataset
>>> dataset = load_dataset("parquet", data_files={'train': 'train.parquet', 'test': 'test.parquet'})

要通过 HTTP 加载远程 Parquet 文件,请传递 URL:

>>> base_url = "https://huggingface.co/datasets/wikimedia/wikipedia/resolve/main/20231101.ab/"
>>> data_files = {"train": base_url + "train-00000-of-00001.parquet"}
>>> wiki = load_dataset("parquet", data_files=data_files, split="train")

Arrow

Arrow 文件以内存中的列式格式存储,与 CSV 等基于行的格式和 Parquet 等未压缩格式不同。

加载 Arrow 文件:

>>> from datasets import load_dataset
>>> dataset = load_dataset("arrow", data_files={'train': 'train.arrow', 'test': 'test.arrow'})

要通过 HTTP 加载远程 Arrow 文件,请传递 URL:

>>> base_url = "https://huggingface.co/datasets/croissantllm/croissant_dataset/resolve/main/english_660B_11/"
>>> data_files = {"train": base_url + "train/data-00000-of-00080.arrow"}
>>> wiki = load_dataset("arrow", data_files=data_files, split="train")

Arrow 是 🤗 Datasets 在底层使用的文件格式,因此您可以直接使用Dataset.from_file()加载本地 Arrow 文件:

>>> from datasets import Dataset
>>> dataset = Dataset.from_file("data.arrow")

load_dataset()不同,Dataset.from_file()会将 Arrow 文件内存映射,而无需在缓存中准备数据集,从而节省磁盘空间。在这种情况下,用于存储中间处理结果的缓存目录将是 Arrow 文件目录。

目前仅支持 Arrow 流式格式。不支持 Arrow IPC 文件格式(也称为 Feather V2)。

SQL

使用from_sql()读取数据库内容,通过指定连接数据库的 URI。您可以读取表名和查询:

>>> from datasets import Dataset
# load entire table
>>> dataset = Dataset.from_sql("data_table_name", con="sqlite:///sqlite_file.db")
# load from query
>>> dataset = Dataset.from_sql("SELECT text FROM table WHERE length(text) > 100 LIMIT 10", con="sqlite:///sqlite_file.db")

有关更多详细信息,请查看如何从 SQL 数据库加载表格数据集指南。

WebDataset

WebDataset 格式基于 TAR 归档,适用于大型图像数据集。由于其大小,WebDataset 通常以流式模式加载(使用 streaming=True)。

您可以像这样加载 WebDataset:

>>> from datasets import load_dataset
>>>
>>> path = "path/to/train/*.tar"
>>> dataset = load_dataset("webdataset", data_files={"train": path}, split="train", streaming=True)

要通过 HTTP 加载远程 WebDataset,请传递 URL:

>>> from datasets import load_dataset
>>>
>>> base_url = "https://huggingface.co/datasets/lhoestq/small-publaynet-wds/resolve/main/publaynet-train-{i:06d}.tar"
>>> urls = [base_url.format(i=i) for i in range(4)]
>>> dataset = load_dataset("webdataset", data_files={"train": urls}, split="train", streaming=True)

多进程

当数据集由多个文件(我们称之为“分片”)组成时,可以显著加快数据集下载和准备步骤。

您可以选择使用多少个进程并行准备数据集,使用 num_proc。在这种情况下,每个进程被分配一个分片子集进行准备:

from datasets import load_dataset

imagenet = load_dataset("timm/imagenet-1k-wds", num_proc=8)
ml_librispeech_spanish = load_dataset("facebook/multilingual_librispeech", "spanish", num_proc=8)

内存数据

🤗 Datasets 还允许您直接从内存数据结构(如 Python 字典和 Pandas DataFrames)创建Dataset

Python 字典

使用from_dict()加载 Python 字典:

>>> from datasets import Dataset
>>> my_dict = {"a": [1, 2, 3]}
>>> dataset = Dataset.from_dict(my_dict)

Python 字典列表

使用 from_list() 加载 Python 字典列表:

>>> from datasets import Dataset
>>> my_list = [{"a": 1}, {"a": 2}, {"a": 3}]
>>> dataset = Dataset.from_list(my_list)

Python 生成器

使用from_generator()从 Python 生成器创建数据集:

>>> from datasets import Dataset
>>> def my_gen():
...     for i in range(1, 4):
...         yield {"a": i}
...
>>> dataset = Dataset.from_generator(my_gen)

这种方法支持加载大于可用内存的数据。

您还可以通过将列表传递给 gen_kwargs 来定义分片数据集:

>>> def gen(shards):
...     for shard in shards:
...         with open(shard) as f:
...             for line in f:
...                 yield {"line": line}
...
>>> shards = [f"data{i}.txt" for i in range(32)]
>>> ds = IterableDataset.from_generator(gen, gen_kwargs={"shards": shards})
>>> ds = ds.shuffle(seed=42, buffer_size=10_000)  # shuffles the shards order + uses a shuffle buffer
>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(ds.with_format("torch"), num_workers=4)  # give each worker a subset of 32/4=8 shards

Pandas DataFrame

使用from_pandas()加载 Pandas DataFrames:

>>> from datasets import Dataset
>>> import pandas as pd
>>> df = pd.DataFrame({"a": [1, 2, 3]})
>>> dataset = Dataset.from_pandas(df)

有关更多详细信息,请查看如何从 Pandas DataFrames 加载表格数据集指南。

离线

即使没有互联网连接,仍然可以加载数据集。只要您之前从 Hub 仓库下载过数据集,它就应该被缓存。这意味着您可以从缓存重新加载数据集并离线使用它。

如果您知道没有互联网连接,可以以完全离线模式运行 🤗 Datasets。这可以节省时间,因为无需等待数据集构建器下载超时,🤗 Datasets 将直接在缓存中查找。将环境变量 HF_HUB_OFFLINE 设置为 1 以启用完全离线模式。

切分

您还可以选择只加载分割的特定切片。切分分割有两种选择:使用字符串或ReadInstruction API。对于简单情况,字符串更紧凑易读,而ReadInstruction更易于与可变切分参数一起使用。

连接 traintest 切分:

>>> train_test_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train+test")

选择 train 切分的特定行:

>>> train_10_20_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[10:20]")

或者使用以下命令选择分割的百分比:

>>> train_10pct_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[:10%]")

选择每个分割的百分比组合:

>>> train_10_80pct_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[:10%]+train[-80%:]")

最后,您甚至可以创建交叉验证的分割。下面的示例创建了 10 折交叉验证分割。每个验证数据集是 10% 的块,训练数据集构成剩余的互补 90% 的块:

>>> val_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split=[f"train[{k}%:{k+10}%]" for k in range(0, 100, 10)]) >>> train_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split=[f"train[:{k}%]+train[{k+10}%:]" for k in range(0, 100, 10)])

百分比切片和四舍五入

默认行为是将边界四舍五入到最接近的整数,对于请求的切片边界不能被 100 整除的数据集。如下所示,有些切片可能包含比其他切片更多的示例。例如,如果以下训练集包含 999 条记录,则:

# 19 records, from 500 (included) to 519 (excluded).
>>> train_50_52_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[50%:52%]")
# 20 records, from 519 (included) to 539 (excluded).
>>> train_52_54_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[52%:54%]")

如果您想要大小相等的分割,请使用 pct1_dropremainder 四舍五入。这会将指定的百分比边界视为 1% 的倍数。

# 18 records, from 450 (included) to 468 (excluded).
>>> train_50_52pct1_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split=datasets.ReadInstruction("train", from_=50, to=52, unit="%", rounding="pct1_dropremainder"))
# 18 records, from 468 (included) to 486 (excluded).
>>> train_52_54pct1_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split=datasets.ReadInstruction("train",from_=52, to=54, unit="%", rounding="pct1_dropremainder"))
# Or equivalently:
>>> train_50_52pct1_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[50%:52%](pct1_dropremainder)")
>>> train_52_54pct1_ds = datasets.load_dataset("ajibawa-2023/General-Stories-Collection", split="train[52%:54%](pct1_dropremainder)")

如果数据集中的示例数量不能被 100 整除,pct1_dropremainder 四舍五入可能会截断数据集中的最后一个示例。

故障排除

有时,加载数据集时可能会得到意想不到的结果。您可能会遇到的两个最常见问题是手动下载数据集和指定数据集的特征。

指定特征

当您从本地文件创建数据集时,Features 会由 Apache Arrow 自动推断。然而,数据集的特征可能并不总是与您的预期一致,或者您可能希望自己定义特征。以下示例展示了如何使用 ClassLabel 特征添加自定义标签。

首先使用 Features 类定义您自己的标签:

>>> class_names = ["sadness", "joy", "love", "anger", "fear", "surprise"]
>>> emotion_features = Features({'text': Value('string'), 'label': ClassLabel(names=class_names)})

接下来,在load_dataset()中指定 features 参数,并使用您刚刚创建的特征:

>>> dataset = load_dataset('csv', data_files=file_dict, delimiter=';', column_names=['text', 'label'], features=emotion_features)

现在,当您查看数据集特征时,您会发现它使用了您定义的自定义标签:

>>> dataset['train'].features
{'text': Value('string'),
'label': ClassLabel(names=['sadness', 'joy', 'love', 'anger', 'fear', 'surprise'])}
< > 在 GitHub 上更新