开源 AI 食谱文档

在自定义数据集上微调语义分割模型并在推理 API 中使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

Open In Colab

在自定义数据集上微调语义分割模型并在推理 API 中使用

作者:Sergio Paniego

在本笔记本中,我们将逐步介绍在自定义数据集上微调语义分割模型的过程。我们将使用的模型是预训练的Segformer,这是一种功能强大且灵活的基于 Transformer 的分割任务架构。

Segformer architecture

对于我们的数据集,我们将使用segments/sidewalk-semantic,其中包含人行道的标记图像,使其成为城市环境中应用的理想选择。

示例用例:此模型可以部署在自主导航人行道送披萨到您家门口的送货机器人中 🍕

微调模型后,我们将演示如何使用无服务器推理 API部署它,使其可以通过简单的 API 端点访问。

1. 安装依赖项

首先,我们将安装微调语义分割模型所需的必要库。

!pip install -q datasets transformers evaluate wandb
# Tested with datasets==3.0.0, transformers==4.44.2, evaluate==0.4.3, wandb==0.18.1

2. 加载数据集 📁

我们将使用sidewalk-semantic数据集,该数据集包含 2021 年夏季在比利时收集的人行道图像。

数据集包括

  • 1000 张图像及其对应的语义分割掩码 🖼
  • 34 个不同的类别 📦

由于此数据集受限,您需要登录并接受许可才能访问。我们还需要身份验证才能在训练后将微调的模型上传到 Hub。

from huggingface_hub import notebook_login

notebook_login()
sidewalk_dataset_identifier = "segments/sidewalk-semantic"
from datasets import load_dataset

dataset = load_dataset(sidewalk_dataset_identifier)

查看内部结构以熟悉它!

dataset

由于数据集仅包含训练集,我们将手动将其划分为训练集和测试集。我们将分配 80% 的数据用于训练,并将剩余的 20% 保留用于评估和测试。➗

dataset = dataset.shuffle(seed=42)
dataset = dataset["train"].train_test_split(test_size=0.2)
train_ds = dataset["train"]
test_ds = dataset["test"]

让我们检查示例中存在的对象类型。我们可以看到pixels_values保存 RGB 图像,而label包含地面实况掩码。掩码是单通道图像,其中每个像素表示 RGB 图像中对应像素的类别。

image = train_ds[0]
image

3. 可视化示例! 👀

现在我们已经加载了数据集,让我们可视化一些示例及其掩码,以更好地了解其结构。

数据集包含一个包含id2label映射的 JSON文件。我们将打开此文件以读取与每个 ID 关联的类别标签。

>>> import json
>>> from huggingface_hub import hf_hub_download

>>> filename = "id2label.json"
>>> id2label = json.load(
...     open(hf_hub_download(repo_id=sidewalk_dataset_identifier, filename=filename, repo_type="dataset"), "r")
... )
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}

>>> num_labels = len(id2label)
>>> print("Id2label:", id2label)
Id2label: {0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}

让我们为每个类别分配颜色🎨。这将帮助我们更有效地可视化分割结果,并使我们更容易解释图像中的不同类别。

sidewalk_palette = [
    [0, 0, 0],  # unlabeled
    [216, 82, 24],  # flat-road
    [255, 255, 0],  # flat-sidewalk
    [125, 46, 141],  # flat-crosswalk
    [118, 171, 47],  # flat-cyclinglane
    [161, 19, 46],  # flat-parkingdriveway
    [255, 0, 0],  # flat-railtrack
    [0, 128, 128],  # flat-curb
    [190, 190, 0],  # human-person
    [0, 255, 0],  # human-rider
    [0, 0, 255],  # vehicle-car
    [170, 0, 255],  # vehicle-truck
    [84, 84, 0],  # vehicle-bus
    [84, 170, 0],  # vehicle-tramtrain
    [84, 255, 0],  # vehicle-motorcycle
    [170, 84, 0],  # vehicle-bicycle
    [170, 170, 0],  # vehicle-caravan
    [170, 255, 0],  # vehicle-cartrailer
    [255, 84, 0],  # construction-building
    [255, 170, 0],  # construction-door
    [255, 255, 0],  # construction-wall
    [33, 138, 200],  # construction-fenceguardrail
    [0, 170, 127],  # construction-bridge
    [0, 255, 127],  # construction-tunnel
    [84, 0, 127],  # construction-stairs
    [84, 84, 127],  # object-pole
    [84, 170, 127],  # object-trafficsign
    [84, 255, 127],  # object-trafficlight
    [170, 0, 127],  # nature-vegetation
    [170, 84, 127],  # nature-terrain
    [170, 170, 127],  # sky
    [170, 255, 127],  # void-ground
    [255, 0, 127],  # void-dynamic
    [255, 84, 127],  # void-static
    [255, 170, 127],  # void-unclear
]

我们可以可视化数据集中的某些示例,包括 RGB 图像、相应的掩码以及掩码在图像上的叠加。这将帮助我们更好地了解数据集以及掩码如何与图像对应。📸

>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from PIL import Image
>>> import matplotlib.patches as patches

>>> # Create and show the legend separately
>>> fig, ax = plt.subplots(figsize=(18, 2))

>>> legend_patches = [
...     patches.Patch(color=np.array(color) / 255, label=label)
...     for label, color in zip(id2label.values(), sidewalk_palette)
... ]

>>> ax.legend(handles=legend_patches, loc="center", bbox_to_anchor=(0.5, 0.5), ncol=5, fontsize=8)
>>> ax.axis("off")

>>> plt.show()

>>> for i in range(5):
...     image = train_ds[i]

...     fig, ax = plt.subplots(1, 3, figsize=(18, 6))

...     # Show the original image
...     ax[0].imshow(image["pixel_values"])
...     ax[0].set_title("Original Image")
...     ax[0].axis("off")

...     mask_np = np.array(image["label"])

...     # Create a new empty RGB image
...     colored_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)

...     # Assign colors to each value in the mask
...     for label_id, color in enumerate(sidewalk_palette):
...         colored_mask[mask_np == label_id] = color

...     colored_mask_img = Image.fromarray(colored_mask, "RGB")

...     # Show the segmentation mask
...     ax[1].imshow(colored_mask_img)
...     ax[1].set_title("Segmentation Mask")
...     ax[1].axis("off")

...     # Convert the original image to RGBA to support transparency
...     image_rgba = image["pixel_values"].convert("RGBA")
...     colored_mask_rgba = colored_mask_img.convert("RGBA")

...     # Adjust transparency of the mask
...     alpha = 128  # Transparency level (0 fully transparent, 255 fully opaque)
...     image_2_with_alpha = Image.new("RGBA", colored_mask_rgba.size)
...     for x in range(colored_mask_rgba.width):
...         for y in range(colored_mask_rgba.height):
...             r, g, b, a = colored_mask_rgba.getpixel((x, y))
...             image_2_with_alpha.putpixel((x, y), (r, g, b, alpha))

...     superposed = Image.alpha_composite(image_rgba, image_2_with_alpha)

...     # Show the mask overlay
...     ax[2].imshow(superposed)
...     ax[2].set_title("Mask Overlay")
...     ax[2].axis("off")

...     plt.show()

4. 可视化类别出现次数 📊

为了更深入地了解数据集,让我们绘制每个类别的出现次数。这将使我们能够了解类别的分布,并识别数据集中任何潜在的偏差或不平衡。

import matplotlib.pyplot as plt
import numpy as np

class_counts = np.zeros(len(id2label))

for example in train_ds:
    mask_np = np.array(example["label"])
    unique, counts = np.unique(mask_np, return_counts=True)
    for u, c in zip(unique, counts):
        class_counts[u] += c
>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from matplotlib import patches

>>> labels = list(id2label.values())

>>> # Normalize colors to be in the range [0, 1]
>>> normalized_palette = [tuple(c / 255 for c in color) for color in sidewalk_palette]

>>> # Visualization
>>> fig, ax = plt.subplots(figsize=(12, 8))

>>> bars = ax.bar(range(len(labels)), class_counts, color=[normalized_palette[i] for i in range(len(labels))])

>>> ax.set_xticks(range(len(labels)))
>>> ax.set_xticklabels(labels, rotation=90, ha="right")

>>> ax.set_xlabel("Categories", fontsize=14)
>>> ax.set_ylabel("Number of Occurrences", fontsize=14)
>>> ax.set_title("Number of Occurrences by Category", fontsize=16)

>>> ax.grid(axis="y", linestyle="--", alpha=0.7)

>>> # Adjust the y-axis limit
>>> y_max = max(class_counts)
>>> ax.set_ylim(0, y_max * 1.25)

>>> for bar in bars:
...     height = int(bar.get_height())
...     offset = 10  # Adjust the text location
...     ax.text(
...         bar.get_x() + bar.get_width() / 2.0,
...         height + offset,
...         f"{height}",
...         ha="center",
...         va="bottom",
...         rotation=90,
...         fontsize=10,
...         color="black",
...     )

>>> fig.legend(
...     handles=legend_patches, loc="center left", bbox_to_anchor=(1, 0.5), ncol=1, fontsize=8
... )  # Adjust ncol as needed

>>> plt.tight_layout()
>>> plt.show()

5. 初始化图像处理器并使用 Albumentations 添加数据增强 📸

我们将首先初始化图像处理器,然后使用 Albumentations 应用数据增强🪄。这将有助于增强我们的数据集并提高语义分割模型的性能。

import albumentations as A
from transformers import SegformerImageProcessor

image_processor = SegformerImageProcessor()

albumentations_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.7),
        A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.5),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=25, val_shift_limit=20, p=0.5),
        A.GaussianBlur(blur_limit=(3, 5), p=0.3),
        A.GaussNoise(var_limit=(10, 50), p=0.4),
    ]
)


def train_transforms(example_batch):
    augmented_images = [albumentations_transform(image=np.array(x))["image"] for x in example_batch["pixel_values"]]
    labels = [x for x in example_batch["label"]]
    inputs = image_processor(augmented_images, labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["pixel_values"]]
    labels = [x for x in example_batch["label"]]
    inputs = image_processor(images, labels)
    return inputs


# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

6. 从检查点初始化模型

我们将使用来自检查点 nvidia/mit-b0 的预训练 Segformer 模型。该架构在论文 SegFormer: 基于 Transformer 的简单高效语义分割设计 中进行了详细介绍,并在 ImageNet-1k 上进行了训练。

from transformers import SegformerForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model_name, id2label=id2label, label2id=label2id)

7. 设置训练参数并连接到 Weights & Biases 📉

接下来,我们将配置训练参数并连接到 Weights & Biases (W&B)。W&B 将帮助我们跟踪实验、可视化指标和管理模型训练工作流,在整个过程中提供宝贵的见解。

from transformers import TrainingArguments

output_dir = "segformer-b0-segments-sidewalk-finetuned"

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=6e-5,
    num_train_epochs=20,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_total_limit=2,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=True,
    report_to="wandb",
)
import wandb

wandb.init(
    project="segformer-b0-segments-sidewalk-finetuned",  # change this
    name="segformer-b0-segments-sidewalk-finetuned",  # change this
    config=training_args,
)

8. 使用 evaluate 设置自定义 compute_metrics 方法以增强日志记录

我们将使用 平均交并比 (mean IoU) 作为评估模型性能的主要指标。这将使我们能够详细跟踪每个类别的性能。

此外,我们将调整评估模块的日志记录级别,以最大程度地减少输出中的警告。如果在图像中未检测到某个类别,您可能会看到以下警告

RuntimeWarning: invalid value encountered in divide iou = total_area_intersect / total_area_union

如果您希望查看这些警告并继续执行下一步,可以跳过此单元格。

import evaluate

evaluate.logging.set_verbosity_error()
import torch
from torch import nn
import multiprocessing

metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        # currently using _compute instead of compute: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        pred_labels = logits_tensor.detach().cpu().numpy()
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            metrics = metric._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )

        # add per category metrics as individual key-value pairs
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

        return metrics

9. 在我们的数据集上训练模型 🏋

现在是时候在我们自定义的数据集上训练模型了。我们将使用准备好的训练参数和连接的 Weights & Biases 集成来监控训练过程并根据需要进行调整。让我们开始训练并观察模型如何提高其性能!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)
trainer.train()

10. 评估模型在新图像上的性能 📸

训练后,我们将评估模型在新图像上的性能。我们将使用测试图像并利用 管道 来评估模型在未见数据上的表现。

import requests
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw

url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"

image = Image.open(requests.get(url, stream=True).raw)

image_segmentator = pipeline(
    "image-segmentation", model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned"  # Change with your model name
)

results = image_segmentator(image)
>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()

模型已经生成了一些掩码,因此我们可以对其进行可视化以评估和理解其性能。这将帮助我们了解模型分割图像的效果如何,并识别任何需要改进的区域。

>>> image_array = np.array(image)

>>> segmentation_map = np.zeros_like(image_array)

>>> for result in results:
...     mask = np.array(result["mask"])
...     label = result["label"]

...     label_index = list(id2label.values()).index(label)

...     color = sidewalk_palette[label_index]

...     for c in range(3):
...         segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()

11. 评估测试集上的性能 📊

metrics = trainer.evaluate(test_ds)
print(metrics)

12. 使用推理 API 访问模型并可视化结果 🔌

Hugging Face 🤗 提供了一个 无服务器推理 API,允许您免费通过 API 端点直接测试模型。有关使用此 API 的详细指南,请查看此 指南

我们将使用此 API 来探索其功能,并了解如何将其用于测试我们的模型。

重要

在使用无服务器推理 API 之前,您需要通过创建模型卡来设置模型任务。在为您的微调模型创建模型卡时,请确保您适当地指定了任务。

image.png

设置模型任务后,我们可以下载图像并使用 InferenceClient 来测试模型。此客户端将允许我们通过 API 将图像发送到模型并检索结果以进行评估。

>>> url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()

我们将使用 InferenceClient 中的 image_segmentation 方法。此方法将模型和图像作为输入,并返回预测的掩码。这将使我们能够测试模型在新图像上的性能。

from huggingface_hub import InferenceClient

client = InferenceClient()

response = client.image_segmentation(
    model="sergiopaniego/segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
    image="https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
)

print(response)

有了预测的掩码,我们就可以显示结果了。

>>> image_array = np.array(image)
>>> segmentation_map = np.zeros_like(image_array)

>>> for result in response:
...     mask = np.array(result["mask"])
...     label = result["label"]

...     label_index = list(id2label.values()).index(label)

...     color = sidewalk_palette[label_index]

...     for c in range(3):
...         segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()

也可以使用 带有 JavaScript 的推理 API。以下是如何使用 JavaScript 使用 API 的示例

import { HfInference } from "@huggingface/inference";

const inference = new HfInference(HF_TOKEN);
await inference.imageSegmentation({
    data: await (await fetch("https://picsum.photos/300/300")).blob(),
    model: "sergiopaniego/segformer-b0-segments-sidewalk-finetuned",
});

额外要点

您还可以使用 Hugging Face Space 部署微调的模型。例如,我创建了一个自定义 Space 来展示这一点:在 Segments/Sidewalk 上微调的 SegFormer 的语义分割

HF Spaces logo
from IPython.display import IFrame

IFrame(src="https://sergiopaniego-segformer-b0-segments-sidewalk-finetuned.hf.space", width=1000, height=800)

结论

在本指南中,我们成功地在自定义数据集上微调了语义分割模型,并利用了无服务器推理 API 来对其进行测试。这演示了如何轻松地将模型集成到各种应用程序中,以及如何利用 Hugging Face 工具进行部署。

希望本指南能为您提供工具和知识,让您能够自信地微调和部署自己的模型! 🚀

< > 在 GitHub 上更新