数据集文档

图像分类

Hugging Face's logo
加入 Hugging Face 社区

并获得增强型文档体验

开始使用

图像分类

图像分类数据集用于训练模型对整张图像进行分类。这些数据集支持各种各样的应用,例如识别濒危野生动物物种或在医学图像中筛查疾病。本指南将向您展示如何将转换应用于图像分类数据集。

在开始之前,请确保您已安装最新版本的 albumentationscv2

pip install -U albumentations opencv-python

本指南使用 Beans 数据集,根据豆类植物叶片的图像识别豆类植物病害的类型。

加载数据集并查看一个示例

>>> from datasets import load_dataset

>>> dataset = load_dataset("beans")
>>> dataset["train"][10]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x7F8D2F4D7A10>,
 'image_file_path': '/root/.cache/huggingface/datasets/downloads/extracted/b0a21163f78769a2cf11f58dfc767fb458fc7cea5c05dccc0144a2c0f0bc1292/train/angular_leaf_spot/angular_leaf_spot_train.204.jpg',
 'labels': 0}

数据集包含三个字段

  • image:一个 PIL 图像对象。
  • image_file_path:图像文件的路径。
  • labels:图像的标签或类别。

接下来,查看一张图像

现在使用 albumentations 应用一些增强。您将随机裁剪图像、将其水平翻转并调整其亮度。

>>> import cv2
>>> import albumentations
>>> import numpy as np

>>> transform = albumentations.Compose([
...     albumentations.RandomCrop(width=256, height=256),
...     albumentations.HorizontalFlip(p=0.5),
...     albumentations.RandomBrightnessContrast(p=0.2),
... ])

创建一个函数,将转换应用于图像

>>> def transforms(examples):
...     examples["pixel_values"] = [
...         transform(image=np.array(image))["image"] for image in examples["image"]
...     ]
... 
...     return examples

使用 set_transform() 函数按需将转换应用于数据集的批次,以减少磁盘空间占用

>>> dataset.set_transform(transforms)

您可以通过索引第一个示例的 pixel_values 来验证转换是否成功

>>> import numpy as np
>>> import matplotlib.pyplot as plt

>>> img = dataset["train"][0]["pixel_values"]
>>> plt.imshow(img)

现在您已经了解了如何处理图像分类的数据集,请学习如何训练图像分类模型并将其用于推理。

< > GitHub 更新