Transformers 文档

图像分割

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

入门

图像分割

图像分割模型将对应于图像中不同感兴趣区域的区域分开。这些模型通过为每个像素分配标签来工作。分割有几种类型:语义分割、实例分割和全景分割。

在本指南中,我们将

  1. 了解不同类型的分割.
  2. 提供语义分割的端到端微调示例.

在你开始之前,请确保你已安装所有必要的库

# uncomment to install the necessary libraries
!pip install -q datasets transformers evaluate accelerate

我们鼓励你登录你的 Hugging Face 帐户,以便你可以将你的模型上传并与社区分享。在系统提示时,输入你的令牌以登录

>>> from huggingface_hub import notebook_login

>>> notebook_login()

分割类型

语义分割为图像中的每个像素分配一个标签或类别。让我们看一下语义分割模型输出。它将为图像中遇到的每个对象的实例分配相同的类别,例如,所有猫将被标记为“猫”而不是“猫-1”、“猫-2”。我们可以使用 transformers 的图像分割管道快速推断语义分割模型。让我们看一下示例图像。

from transformers import pipeline
from PIL import Image
import requests

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/segmentation_input.jpg"
image = Image.open(requests.get(url, stream=True).raw)
image
Segmentation Input

我们将使用 nvidia/segformer-b1-finetuned-cityscapes-1024-1024

semantic_segmentation = pipeline("image-segmentation", "nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
results = semantic_segmentation(image)
results

分割管道输出包括每个预测类别的掩码。

[{'score': None,
  'label': 'road',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'sidewalk',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'building',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'wall',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'pole',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'traffic sign',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'vegetation',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'terrain',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'sky',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': None,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>}]

看一下汽车类别的掩码,我们可以看到每辆车都使用相同的掩码进行分类。

results[-1]["mask"]
Semantic Segmentation Output

在实例分割中,目标不是对每个像素进行分类,而是对给定图像中 **每个对象的实例** 预测一个掩码。它的工作原理与目标检测非常相似,目标检测对每个实例都有一个边界框,而实例分割则有一个分割掩码。我们将使用 facebook/mask2former-swin-large-cityscapes-instance 来完成此操作。

instance_segmentation = pipeline("image-segmentation", "facebook/mask2former-swin-large-cityscapes-instance")
results = instance_segmentation(image)
results

如下所示,有多辆汽车被分类,并且除了属于汽车和人物实例的像素之外,其他像素没有分类。

[{'score': 0.999944,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999945,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999652,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.903529,
  'label': 'person',
  'mask': <PIL.Image.Image image mode=L size=612x415>}]

查看下面的一个汽车掩码。

results[2]["mask"]
Semantic Segmentation Output

全景分割结合了语义分割和实例分割,其中每个像素都分类为一个类别和该类别的实例,并且每个类别实例都有多个掩码。我们可以使用 facebook/mask2former-swin-large-cityscapes-panoptic 来完成此操作。

panoptic_segmentation = pipeline("image-segmentation", "facebook/mask2former-swin-large-cityscapes-panoptic")
results = panoptic_segmentation(image)
results

如下所示,我们有更多的类别。稍后我们将说明每个像素如何被分类到一个类别中。

[{'score': 0.999981,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999958,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.99997,
  'label': 'vegetation',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999575,
  'label': 'pole',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999958,
  'label': 'building',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999634,
  'label': 'road',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.996092,
  'label': 'sidewalk',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.999221,
  'label': 'car',
  'mask': <PIL.Image.Image image mode=L size=612x415>},
 {'score': 0.99987,
  'label': 'sky',
  'mask': <PIL.Image.Image image mode=L size=612x415>}]

让我们并排比较所有类型的分割。

Segmentation Maps Compared

在看到所有类型的分割之后,让我们深入了解如何微调语义分割模型。

语义分割的常见现实世界应用包括训练自动驾驶汽车识别行人和重要的交通信息,识别医学图像中的细胞和异常,以及通过卫星图像监测环境变化。

微调分割模型

我们现在将

  1. SegFormerSceneParse150 数据集上进行微调。
  2. 使用你微调的模型进行推断。

要查看与该任务兼容的所有架构和检查点,我们建议查看 任务页面

加载 SceneParse150 数据集

首先从 🤗 Datasets 库加载 SceneParse150 数据集的一个较小的子集。这将让你有机会进行实验并确保一切正常,然后再花更多时间在完整数据集上进行训练。

>>> from datasets import load_dataset

>>> ds = load_dataset("scene_parse_150", split="train[:50]")

使用 train_test_split 方法将数据集的 train 拆分拆分为训练集和测试集

>>> ds = ds.train_test_split(test_size=0.2)
>>> train_ds = ds["train"]
>>> test_ds = ds["test"]

然后看一下示例

>>> train_ds[0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x683 at 0x7F9B0C201F90>,
 'annotation': <PIL.PngImagePlugin.PngImageFile image mode=L size=512x683 at 0x7F9B0C201DD0>,
 'scene_category': 368}

# view the image
>>> train_ds[0]["image"]
  • image: 场景的 PIL 图像。
  • annotation: 分割图的 PIL 图像,这也是模型的目标。
  • scene_category: 描述图像场景的类别 ID,例如“厨房”或“办公室”。在本指南中,你只需要 imageannotation,它们都是 PIL 图像。

你还需要创建一个字典,将标签 ID 映射到标签类别,这在你稍后设置模型时很有用。从 Hub 下载映射并创建 id2labellabel2id 字典

>>> import json
>>> from pathlib import Path
>>> from huggingface_hub import hf_hub_download

>>> repo_id = "huggingface/label-files"
>>> filename = "ade20k-id2label.json"
>>> id2label = json.loads(Path(hf_hub_download(repo_id, filename, repo_type="dataset")).read_text())
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}
>>> num_labels = len(id2label)

自定义数据集

如果你更喜欢使用 run_semantic_segmentation.py 脚本而不是笔记本实例进行训练,你也可以创建和使用自己的数据集。该脚本需要

  1. 一个具有两个 Image 列的 DatasetDict,“image”和“label”

    from datasets import Dataset, DatasetDict, Image
    
    image_paths_train = ["path/to/image_1.jpg/jpg", "path/to/image_2.jpg/jpg", ..., "path/to/image_n.jpg/jpg"]
    label_paths_train = ["path/to/annotation_1.png", "path/to/annotation_2.png", ..., "path/to/annotation_n.png"]
    
    image_paths_validation = [...]
    label_paths_validation = [...]
    
    def create_dataset(image_paths, label_paths):
        dataset = Dataset.from_dict({"image": sorted(image_paths),
                                    "label": sorted(label_paths)})
        dataset = dataset.cast_column("image", Image())
        dataset = dataset.cast_column("label", Image())
        return dataset
    
    # step 1: create Dataset objects
    train_dataset = create_dataset(image_paths_train, label_paths_train)
    validation_dataset = create_dataset(image_paths_validation, label_paths_validation)
    
    # step 2: create DatasetDict
    dataset = DatasetDict({
         "train": train_dataset,
         "validation": validation_dataset,
         }
    )
    
    # step 3: push to Hub (assumes you have ran the huggingface-cli login command in a terminal/notebook)
    dataset.push_to_hub("your-name/dataset-repo")
    
    # optionally, you can push to a private repo on the Hub
    # dataset.push_to_hub("name of repo on the hub", private=True)
  2. 一个 id2label 字典,将类别整数映射到它们的类别名称

    import json
    # simple example
    id2label = {0: 'cat', 1: 'dog'}
    with open('id2label.json', 'w') as fp:
    json.dump(id2label, fp)

例如,看一下这个 示例数据集,它使用上面所示的步骤创建。

预处理

下一步是加载 SegFormer 图像处理器,为模型准备图像和标注。一些数据集,比如这个,使用零索引作为背景类。但是,背景类实际上并不包含在 150 个类中,因此您需要设置 `do_reduce_labels=True` 来从所有标签中减去 1。零索引被替换为 `255`,因此 SegFormer 的损失函数会忽略它。

>>> from transformers import AutoImageProcessor

>>> checkpoint = "nvidia/mit-b0"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint, do_reduce_labels=True)
Pytorch
隐藏 Pytorch 内容

通常的做法是对图像数据集应用一些数据增强,使模型更能抵抗过拟合。在本指南中,您将使用来自 torchvisionColorJitter 函数随机更改图像的颜色属性,但您也可以使用任何您喜欢的图像库。

>>> from torchvision.transforms import ColorJitter

>>> jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)

现在创建两个预处理函数来为模型准备图像和标注。这些函数将图像转换为 `pixel_values`,并将标注转换为 `labels`。对于训练集,在将图像提供给图像处理器之前应用 `jitter`。对于测试集,图像处理器裁剪和归一化 `images`,并且只裁剪 `labels`,因为在测试期间不应用数据增强。

>>> def train_transforms(example_batch):
...     images = [jitter(x) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs


>>> def val_transforms(example_batch):
...     images = [x for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs

要在整个数据集上应用 `jitter`,请使用 🤗 Datasets 的 set_transform 函数。变换是动态应用的,这更快且占用的磁盘空间更少。

>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)
TensorFlow
隐藏 TensorFlow 内容

通常的做法是对图像数据集应用一些数据增强,使模型更能抵抗过拟合。在本指南中,您将使用 tf.image 随机更改图像的颜色属性,但您也可以使用任何您喜欢的图像库。定义两个独立的转换函数

  • 包括图像增强的训练数据变换
  • 仅转置图像的验证数据变换,因为 🤗 Transformers 中的计算机视觉模型期望通道优先布局
>>> import tensorflow as tf


>>> def aug_transforms(image):
...     image = tf.keras.utils.img_to_array(image)
...     image = tf.image.random_brightness(image, 0.25)
...     image = tf.image.random_contrast(image, 0.5, 2.0)
...     image = tf.image.random_saturation(image, 0.75, 1.25)
...     image = tf.image.random_hue(image, 0.1)
...     image = tf.transpose(image, (2, 0, 1))
...     return image


>>> def transforms(image):
...     image = tf.keras.utils.img_to_array(image)
...     image = tf.transpose(image, (2, 0, 1))
...     return image

接下来,创建两个预处理函数来为模型准备图像和标注的批次。这些函数应用图像变换,并使用之前加载的 `image_processor` 将图像转换为 `pixel_values`,并将标注转换为 `labels`。`ImageProcessor` 还负责调整图像大小和归一化图像。

>>> def train_transforms(example_batch):
...     images = [aug_transforms(x.convert("RGB")) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs


>>> def val_transforms(example_batch):
...     images = [transforms(x.convert("RGB")) for x in example_batch["image"]]
...     labels = [x for x in example_batch["annotation"]]
...     inputs = image_processor(images, labels)
...     return inputs

要在整个数据集上应用预处理变换,请使用 🤗 Datasets 的 set_transform 函数。变换是动态应用的,这更快且占用的磁盘空间更少。

>>> train_ds.set_transform(train_transforms)
>>> test_ds.set_transform(val_transforms)

评估

在训练期间包含指标通常有助于评估模型的性能。您可以使用 🤗 Evaluate 库快速加载评估方法。对于此任务,请加载 平均交并比 (IoU) 指标(参见 🤗 Evaluate 快速教程,了解有关如何加载和计算指标的更多信息)。

>>> import evaluate

>>> metric = evaluate.load("mean_iou")

然后创建一个函数来 `compute` 指标。您的预测需要先转换为 logits,然后重新整形以匹配标签的大小,才能调用 `compute`。

Pytorch
隐藏 Pytorch 内容
>>> import numpy as np
>>> import torch
>>> from torch import nn

>>> def compute_metrics(eval_pred):
...     with torch.no_grad():
...         logits, labels = eval_pred
...         logits_tensor = torch.from_numpy(logits)
...         logits_tensor = nn.functional.interpolate(
...             logits_tensor,
...             size=labels.shape[-2:],
...             mode="bilinear",
...             align_corners=False,
...         ).argmax(dim=1)

...         pred_labels = logits_tensor.detach().cpu().numpy()
...         metrics = metric.compute(
...             predictions=pred_labels,
...             references=labels,
...             num_labels=num_labels,
...             ignore_index=255,
...             reduce_labels=False,
...         )
...         for key, value in metrics.items():
...             if isinstance(value, np.ndarray):
...                 metrics[key] = value.tolist()
...         return metrics
TensorFlow
隐藏 TensorFlow 内容
>>> def compute_metrics(eval_pred):
...     logits, labels = eval_pred
...     logits = tf.transpose(logits, perm=[0, 2, 3, 1])
...     logits_resized = tf.image.resize(
...         logits,
...         size=tf.shape(labels)[1:],
...         method="bilinear",
...     )

...     pred_labels = tf.argmax(logits_resized, axis=-1)
...     metrics = metric.compute(
...         predictions=pred_labels,
...         references=labels,
...         num_labels=num_labels,
...         ignore_index=-1,
...         reduce_labels=image_processor.do_reduce_labels,
...     )

...     per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
...     per_category_iou = metrics.pop("per_category_iou").tolist()

...     metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
...     metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})
...     return {"val_" + k: v for k, v in metrics.items()}

您的 `compute_metrics` 函数现在已经准备就绪,您将在设置训练时回到它。

训练

Pytorch
隐藏 Pytorch 内容

如果您不熟悉使用 Trainer 微调模型,请查看 此处 的基本教程!

您现在可以开始训练您的模型了!使用 AutoModelForSemanticSegmentation 加载 SegFormer,并将标签 ID 与标签类别之间的映射传递给模型。

>>> from transformers import AutoModelForSemanticSegmentation, TrainingArguments, Trainer

>>> model = AutoModelForSemanticSegmentation.from_pretrained(checkpoint, id2label=id2label, label2id=label2id)

此时,只剩下三个步骤。

  1. TrainingArguments 中定义您的训练超参数。重要的是不要删除未使用的列,因为这将删除 `image` 列。如果没有 `image` 列,您将无法创建 `pixel_values`。设置 `remove_unused_columns=False` 以防止此行为!另一个必需的参数是 `output_dir`,它指定保存模型的位置。通过设置 `push_to_hub=True` 将此模型推送到 Hub(您需要登录 Hugging Face 才能上传您的模型)。在每个纪元的末尾,Trainer 将评估 IoU 指标并保存训练检查点。
  2. 将训练参数传递给 Trainer,以及模型、数据集、tokenizer、数据整理器和 `compute_metrics` 函数。
  3. 调用 train() 来微调您的模型。
>>> training_args = TrainingArguments(
...     output_dir="segformer-b0-scene-parse-150",
...     learning_rate=6e-5,
...     num_train_epochs=50,
...     per_device_train_batch_size=2,
...     per_device_eval_batch_size=2,
...     save_total_limit=3,
...     eval_strategy="steps",
...     save_strategy="steps",
...     save_steps=20,
...     eval_steps=20,
...     logging_steps=1,
...     eval_accumulation_steps=5,
...     remove_unused_columns=False,
...     push_to_hub=True,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=train_ds,
...     eval_dataset=test_ds,
...     compute_metrics=compute_metrics,
... )

>>> trainer.train()

训练完成后,使用 push_to_hub() 方法将您的模型共享到 Hub,以便每个人都可以使用您的模型。

>>> trainer.push_to_hub()
TensorFlow
隐藏 TensorFlow 内容

如果您不熟悉使用 Keras 微调模型,请先查看 基本教程

要使用 TensorFlow 微调模型,请按照以下步骤操作。

  1. 定义训练超参数,并设置优化器和学习率调度器。
  2. 实例化一个预训练模型。
  3. 将 🤗 Dataset 转换为 `tf.data.Dataset`。
  4. 编译您的模型。
  5. 添加回调以计算指标并将您的模型上传到 🤗 Hub。
  6. 使用 `fit()` 方法运行训练。

首先定义超参数、优化器和学习率调度器。

>>> from transformers import create_optimizer

>>> batch_size = 2
>>> num_epochs = 50
>>> num_train_steps = len(train_ds) * num_epochs
>>> learning_rate = 6e-5
>>> weight_decay_rate = 0.01

>>> optimizer, lr_schedule = create_optimizer(
...     init_lr=learning_rate,
...     num_train_steps=num_train_steps,
...     weight_decay_rate=weight_decay_rate,
...     num_warmup_steps=0,
... )

然后,使用 TFAutoModelForSemanticSegmentation 以及标签映射加载 SegFormer,并使用优化器对其进行编译。请注意,Transformers 模型都具有默认的任务相关损失函数,因此您无需指定损失函数,除非您想要。

>>> from transformers import TFAutoModelForSemanticSegmentation

>>> model = TFAutoModelForSemanticSegmentation.from_pretrained(
...     checkpoint,
...     id2label=id2label,
...     label2id=label2id,
... )
>>> model.compile(optimizer=optimizer)  # No loss argument!

使用 to_tf_datasetDefaultDataCollator 将您的数据集转换为 `tf.data.Dataset` 格式。

>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator(return_tensors="tf")

>>> tf_train_dataset = train_ds.to_tf_dataset(
...     columns=["pixel_values", "label"],
...     shuffle=True,
...     batch_size=batch_size,
...     collate_fn=data_collator,
... )

>>> tf_eval_dataset = test_ds.to_tf_dataset(
...     columns=["pixel_values", "label"],
...     shuffle=True,
...     batch_size=batch_size,
...     collate_fn=data_collator,
... )

要从预测中计算准确度并将您的模型推送到 🤗 Hub,请使用 Keras 回调。将您的 `compute_metrics` 函数传递给 KerasMetricCallback,并使用 PushToHubCallback 上传模型。

>>> from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback

>>> metric_callback = KerasMetricCallback(
...     metric_fn=compute_metrics, eval_dataset=tf_eval_dataset, batch_size=batch_size, label_cols=["labels"]
... )

>>> push_to_hub_callback = PushToHubCallback(output_dir="scene_segmentation", tokenizer=image_processor)

>>> callbacks = [metric_callback, push_to_hub_callback]

最后,您就可以训练您的模型了!使用您的训练和验证数据集、纪元数量和回调调用 `fit()` 来微调模型。

>>> model.fit(
...     tf_train_dataset,
...     validation_data=tf_eval_dataset,
...     callbacks=callbacks,
...     epochs=num_epochs,
... )

恭喜!您已经微调了模型并在 🤗 Hub 上共享了它。您现在可以使用它进行推断了!

推断

太好了,既然您已经微调了模型,您就可以使用它进行推断了!

重新加载数据集并加载图像以进行推断。

>>> from datasets import load_dataset

>>> ds = load_dataset("scene_parse_150", split="train[:50]")
>>> ds = ds.train_test_split(test_size=0.2)
>>> test_ds = ds["test"]
>>> image = ds["test"][0]["image"]
>>> image
Image of bedroom
Pytorch
隐藏 Pytorch 内容

我们现在将了解如何在没有 pipeline 的情况下进行推断。使用图像处理器处理图像并将 `pixel_values` 放置在 GPU 上。

>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # use GPU if available, otherwise use a CPU
>>> encoding = image_processor(image, return_tensors="pt")
>>> pixel_values = encoding.pixel_values.to(device)

将您的输入传递给模型并返回 `logits`。

>>> outputs = model(pixel_values=pixel_values)
>>> logits = outputs.logits.cpu()

接下来,将 logits 重新缩放到原始图像大小。

>>> upsampled_logits = nn.functional.interpolate(
...     logits,
...     size=image.size[::-1],
...     mode="bilinear",
...     align_corners=False,
... )

>>> pred_seg = upsampled_logits.argmax(dim=1)[0]
TensorFlow
隐藏 TensorFlow 内容

加载图像处理器以预处理图像并将输入作为 TensorFlow 张量返回。

>>> from transformers import AutoImageProcessor

>>> image_processor = AutoImageProcessor.from_pretrained("MariaK/scene_segmentation")
>>> inputs = image_processor(image, return_tensors="tf")

将您的输入传递给模型并返回 `logits`。

>>> from transformers import TFAutoModelForSemanticSegmentation

>>> model = TFAutoModelForSemanticSegmentation.from_pretrained("MariaK/scene_segmentation")
>>> logits = model(**inputs).logits

接下来,将 logits 重新缩放到原始图像大小,并在类维度上应用 argmax。

>>> logits = tf.transpose(logits, [0, 2, 3, 1])

>>> upsampled_logits = tf.image.resize(
...     logits,
...     # We reverse the shape of `image` because `image.size` returns width and height.
...     image.size[::-1],
... )

>>> pred_seg = tf.math.argmax(upsampled_logits, axis=-1)[0]

要可视化结果,请加载 数据集颜色调色板 作为 `ade_palette()`,它将每个类映射到其 RGB 值。

def ade_palette():
  return np.asarray([
      [0, 0, 0],
      [120, 120, 120],
      [180, 120, 120],
      [6, 230, 230],
      [80, 50, 50],
      [4, 200, 3],
      [120, 120, 80],
      [140, 140, 140],
      [204, 5, 255],
      [230, 230, 230],
      [4, 250, 7],
      [224, 5, 255],
      [235, 255, 7],
      [150, 5, 61],
      [120, 120, 70],
      [8, 255, 51],
      [255, 6, 82],
      [143, 255, 140],
      [204, 255, 4],
      [255, 51, 7],
      [204, 70, 3],
      [0, 102, 200],
      [61, 230, 250],
      [255, 6, 51],
      [11, 102, 255],
      [255, 7, 71],
      [255, 9, 224],
      [9, 7, 230],
      [220, 220, 220],
      [255, 9, 92],
      [112, 9, 255],
      [8, 255, 214],
      [7, 255, 224],
      [255, 184, 6],
      [10, 255, 71],
      [255, 41, 10],
      [7, 255, 255],
      [224, 255, 8],
      [102, 8, 255],
      [255, 61, 6],
      [255, 194, 7],
      [255, 122, 8],
      [0, 255, 20],
      [255, 8, 41],
      [255, 5, 153],
      [6, 51, 255],
      [235, 12, 255],
      [160, 150, 20],
      [0, 163, 255],
      [140, 140, 140],
      [250, 10, 15],
      [20, 255, 0],
      [31, 255, 0],
      [255, 31, 0],
      [255, 224, 0],
      [153, 255, 0],
      [0, 0, 255],
      [255, 71, 0],
      [0, 235, 255],
      [0, 173, 255],
      [31, 0, 255],
      [11, 200, 200],
      [255, 82, 0],
      [0, 255, 245],
      [0, 61, 255],
      [0, 255, 112],
      [0, 255, 133],
      [255, 0, 0],
      [255, 163, 0],
      [255, 102, 0],
      [194, 255, 0],
      [0, 143, 255],
      [51, 255, 0],
      [0, 82, 255],
      [0, 255, 41],
      [0, 255, 173],
      [10, 0, 255],
      [173, 255, 0],
      [0, 255, 153],
      [255, 92, 0],
      [255, 0, 255],
      [255, 0, 245],
      [255, 0, 102],
      [255, 173, 0],
      [255, 0, 20],
      [255, 184, 184],
      [0, 31, 255],
      [0, 255, 61],
      [0, 71, 255],
      [255, 0, 204],
      [0, 255, 194],
      [0, 255, 82],
      [0, 10, 255],
      [0, 112, 255],
      [51, 0, 255],
      [0, 194, 255],
      [0, 122, 255],
      [0, 255, 163],
      [255, 153, 0],
      [0, 255, 10],
      [255, 112, 0],
      [143, 255, 0],
      [82, 0, 255],
      [163, 255, 0],
      [255, 235, 0],
      [8, 184, 170],
      [133, 0, 255],
      [0, 255, 92],
      [184, 0, 255],
      [255, 0, 31],
      [0, 184, 255],
      [0, 214, 255],
      [255, 0, 112],
      [92, 255, 0],
      [0, 224, 255],
      [112, 224, 255],
      [70, 184, 160],
      [163, 0, 255],
      [153, 0, 255],
      [71, 255, 0],
      [255, 0, 163],
      [255, 204, 0],
      [255, 0, 143],
      [0, 255, 235],
      [133, 255, 0],
      [255, 0, 235],
      [245, 0, 255],
      [255, 0, 122],
      [255, 245, 0],
      [10, 190, 212],
      [214, 255, 0],
      [0, 204, 255],
      [20, 0, 255],
      [255, 255, 0],
      [0, 153, 255],
      [0, 41, 255],
      [0, 255, 204],
      [41, 0, 255],
      [41, 255, 0],
      [173, 0, 255],
      [0, 245, 255],
      [71, 0, 255],
      [122, 0, 255],
      [0, 255, 184],
      [0, 92, 255],
      [184, 255, 0],
      [0, 133, 255],
      [255, 214, 0],
      [25, 194, 194],
      [102, 255, 0],
      [92, 0, 255],
  ])

然后您可以组合和绘制图像和预测的分割图。

>>> import matplotlib.pyplot as plt
>>> import numpy as np

>>> color_seg = np.zeros((pred_seg.shape[0], pred_seg.shape[1], 3), dtype=np.uint8)
>>> palette = np.array(ade_palette())
>>> for label, color in enumerate(palette):
...     color_seg[pred_seg == label, :] = color
>>> color_seg = color_seg[..., ::-1]  # convert to BGR

>>> img = np.array(image) * 0.5 + color_seg * 0.5  # plot the image with the segmentation map
>>> img = img.astype(np.uint8)

>>> plt.figure(figsize=(15, 10))
>>> plt.imshow(img)
>>> plt.show()
Image of bedroom overlaid with segmentation map
< > 更新 on GitHub