图像分类
图像分类数据集用于训练模型对整张图像进行分类。这些数据集支持各种各样的应用,例如识别濒危野生动物物种或在医学图像中筛查疾病。本指南将向您展示如何将转换应用于图像分类数据集。
在开始之前,请确保您已安装最新版本的 albumentations
和 cv2
pip install -U albumentations opencv-python
本指南使用 Beans 数据集,根据豆类植物叶片的图像识别豆类植物病害的类型。
加载数据集并查看一个示例
>>> from datasets import load_dataset
>>> dataset = load_dataset("beans")
>>> dataset["train"][10]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=500x500 at 0x7F8D2F4D7A10>,
'image_file_path': '/root/.cache/huggingface/datasets/downloads/extracted/b0a21163f78769a2cf11f58dfc767fb458fc7cea5c05dccc0144a2c0f0bc1292/train/angular_leaf_spot/angular_leaf_spot_train.204.jpg',
'labels': 0}
数据集包含三个字段
image
:一个 PIL 图像对象。image_file_path
:图像文件的路径。labels
:图像的标签或类别。
接下来,查看一张图像
现在使用 albumentations
应用一些增强。您将随机裁剪图像、将其水平翻转并调整其亮度。
>>> import cv2
>>> import albumentations
>>> import numpy as np
>>> transform = albumentations.Compose([
... albumentations.RandomCrop(width=256, height=256),
... albumentations.HorizontalFlip(p=0.5),
... albumentations.RandomBrightnessContrast(p=0.2),
... ])
创建一个函数,将转换应用于图像
>>> def transforms(examples):
... examples["pixel_values"] = [
... transform(image=np.array(image))["image"] for image in examples["image"]
... ]
...
... return examples
使用 set_transform() 函数按需将转换应用于数据集的批次,以减少磁盘空间占用
>>> dataset.set_transform(transforms)
您可以通过索引第一个示例的 pixel_values
来验证转换是否成功
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> img = dataset["train"][0]["pixel_values"]
>>> plt.imshow(img)
现在您已经了解了如何处理图像分类的数据集,请学习如何训练图像分类模型并将其用于推理。