Fine-tuning a semantic segmentation model on a custom dataset and using it via the Inference API

Authored by: Sergio Paniego

In this notebook, we will walk through the process of fine-tuning a semantic segmentation model on a custom dataset. The model we will use is a pretrained SegFormer, a powerful and flexible Transformer-based architecture well suited for segmentation tasks.

Segformer architecture

For our dataset, we will use segments/sidewalk-semantic, which contains labeled images of sidewalks, making it ideal for applications in urban environments.

Example use case: this model could be deployed in a delivery robot that navigates sidewalks autonomously to bring a pizza to your doorstep 🍕

After fine-tuning the model, we will demonstrate how to deploy it with the Serverless Inference API, making it accessible through a simple API endpoint.

1. Install dependencies

First, we install the libraries needed to fine-tune the semantic segmentation model.

!pip install -q datasets transformers evaluate wandb
# Tested with datasets==3.0.0, transformers==4.44.2, evaluate==0.4.3, wandb==0.18.1

2. Load the dataset 📁

We will use the sidewalk-semantic dataset, which contains images of sidewalks collected in Belgium during the summer of 2021.

The dataset includes:

  • 1,000 images with their corresponding semantic segmentation masks 🖼
  • 34 different categories 📦

Since this dataset is gated, you will need to log in and accept the license to gain access. We also need to be authenticated so that we can upload the fine-tuned model to the Hub after training.

from huggingface_hub import notebook_login

notebook_login()
sidewalk_dataset_identifier = "segments/sidewalk-semantic"
from datasets import load_dataset

dataset = load_dataset(sidewalk_dataset_identifier)

Take a look inside to get familiar with its structure!

dataset

Since the dataset only contains a train split, we will manually divide it into training and test sets. We'll allocate 80% of the data for training and keep the remaining 20% for evaluation and testing. ➗

dataset = dataset.shuffle(seed=42)
dataset = dataset["train"].train_test_split(test_size=0.2)
train_ds = dataset["train"]
test_ds = dataset["test"]
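
A quick, illustrative check (not part of the original flow) confirms the 80/20 split of the 1,000 images:

print(len(train_ds), len(test_ds))  # expected: 800 200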

Let's inspect the types of objects present in an example. We can see that pixel_values holds the RGB image, while label contains the ground-truth mask. The mask is a single-channel image in which each pixel represents the category of the corresponding pixel in the RGB image.

image = train_ds[0]
image
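
To confirm this, here is a small sanity-check sketch (illustrative, not part of the original flow): it converts the mask to a NumPy array and lists the class ids it contains, which should match the id2label mapping loaded in the next section.

import numpy as np

mask = np.array(image["label"])
print("Image size:", image["pixel_values"].size)   # (width, height) of the RGB image
print("Mask shape:", mask.shape)                   # (height, width) – a single channel
print("Class ids in this mask:", np.unique(mask))  # subset of the ids in id2label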

3. Visualize some examples! 👀

Now that the dataset is loaded, let's visualize a few examples and their masks to better understand its structure.

The dataset includes a JSON file with the id2label mapping. We'll open this file to read the category label associated with each id.

>>> import json
>>> from huggingface_hub import hf_hub_download

>>> filename = "id2label.json"
>>> id2label = json.load(
...     open(hf_hub_download(repo_id=sidewalk_dataset_identifier, filename=filename, repo_type="dataset"), "r")
... )
>>> id2label = {int(k): v for k, v in id2label.items()}
>>> label2id = {v: k for k, v in id2label.items()}

>>> num_labels = len(id2label)
>>> print("Id2label:", id2label)
Id2label: {0: 'unlabeled', 1: 'flat-road', 2: 'flat-sidewalk', 3: 'flat-crosswalk', 4: 'flat-cyclinglane', 5: 'flat-parkingdriveway', 6: 'flat-railtrack', 7: 'flat-curb', 8: 'human-person', 9: 'human-rider', 10: 'vehicle-car', 11: 'vehicle-truck', 12: 'vehicle-bus', 13: 'vehicle-tramtrain', 14: 'vehicle-motorcycle', 15: 'vehicle-bicycle', 16: 'vehicle-caravan', 17: 'vehicle-cartrailer', 18: 'construction-building', 19: 'construction-door', 20: 'construction-wall', 21: 'construction-fenceguardrail', 22: 'construction-bridge', 23: 'construction-tunnel', 24: 'construction-stairs', 25: 'object-pole', 26: 'object-trafficsign', 27: 'object-trafficlight', 28: 'nature-vegetation', 29: 'nature-terrain', 30: 'sky', 31: 'void-ground', 32: 'void-dynamic', 33: 'void-static', 34: 'void-unclear'}

Let's assign a color to each category 🎨. This will help us visualize the segmentation results more effectively and make it easier to interpret the different categories in the images.

sidewalk_palette = [
    [0, 0, 0],  # unlabeled
    [216, 82, 24],  # flat-road
    [255, 255, 0],  # flat-sidewalk
    [125, 46, 141],  # flat-crosswalk
    [118, 171, 47],  # flat-cyclinglane
    [161, 19, 46],  # flat-parkingdriveway
    [255, 0, 0],  # flat-railtrack
    [0, 128, 128],  # flat-curb
    [190, 190, 0],  # human-person
    [0, 255, 0],  # human-rider
    [0, 0, 255],  # vehicle-car
    [170, 0, 255],  # vehicle-truck
    [84, 84, 0],  # vehicle-bus
    [84, 170, 0],  # vehicle-tramtrain
    [84, 255, 0],  # vehicle-motorcycle
    [170, 84, 0],  # vehicle-bicycle
    [170, 170, 0],  # vehicle-caravan
    [170, 255, 0],  # vehicle-cartrailer
    [255, 84, 0],  # construction-building
    [255, 170, 0],  # construction-door
    [255, 255, 0],  # construction-wall
    [33, 138, 200],  # construction-fenceguardrail
    [0, 170, 127],  # construction-bridge
    [0, 255, 127],  # construction-tunnel
    [84, 0, 127],  # construction-stairs
    [84, 84, 127],  # object-pole
    [84, 170, 127],  # object-trafficsign
    [84, 255, 127],  # object-trafficlight
    [170, 0, 127],  # nature-vegetation
    [170, 84, 127],  # nature-terrain
    [170, 170, 127],  # sky
    [170, 255, 127],  # void-ground
    [255, 0, 127],  # void-dynamic
    [255, 84, 127],  # void-static
    [255, 170, 127],  # void-unclear
]

We can visualize some examples from the dataset, including the RGB image, the corresponding mask, and the mask overlaid on the image. This will help us better understand the dataset and how the masks map onto the images. 📸

>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from PIL import Image
>>> import matplotlib.patches as patches

>>> # Create and show the legend separately
>>> fig, ax = plt.subplots(figsize=(18, 2))

>>> legend_patches = [
...     patches.Patch(color=np.array(color) / 255, label=label)
...     for label, color in zip(id2label.values(), sidewalk_palette)
... ]

>>> ax.legend(handles=legend_patches, loc="center", bbox_to_anchor=(0.5, 0.5), ncol=5, fontsize=8)
>>> ax.axis("off")

>>> plt.show()

>>> for i in range(5):
...     image = train_ds[i]

...     fig, ax = plt.subplots(1, 3, figsize=(18, 6))

...     # Show the original image
...     ax[0].imshow(image["pixel_values"])
...     ax[0].set_title("Original Image")
...     ax[0].axis("off")

...     mask_np = np.array(image["label"])

...     # Create a new empty RGB image
...     colored_mask = np.zeros((mask_np.shape[0], mask_np.shape[1], 3), dtype=np.uint8)

...     # Assign colors to each value in the mask
...     for label_id, color in enumerate(sidewalk_palette):
...         colored_mask[mask_np == label_id] = color

...     colored_mask_img = Image.fromarray(colored_mask, "RGB")

...     # Show the segmentation mask
...     ax[1].imshow(colored_mask_img)
...     ax[1].set_title("Segmentation Mask")
...     ax[1].axis("off")

...     # Convert the original image to RGBA to support transparency
...     image_rgba = image["pixel_values"].convert("RGBA")
...     colored_mask_rgba = colored_mask_img.convert("RGBA")

...     # Adjust transparency of the mask
...     alpha = 128  # Transparency level (0 fully transparent, 255 fully opaque)
...     image_2_with_alpha = Image.new("RGBA", colored_mask_rgba.size)
...     for x in range(colored_mask_rgba.width):
...         for y in range(colored_mask_rgba.height):
...             r, g, b, a = colored_mask_rgba.getpixel((x, y))
...             image_2_with_alpha.putpixel((x, y), (r, g, b, alpha))

...     superposed = Image.alpha_composite(image_rgba, image_2_with_alpha)

...     # Show the mask overlay
...     ax[2].imshow(superposed)
...     ax[2].set_title("Mask Overlay")
...     ax[2].axis("off")

...     plt.show()

4. Visualize the number of occurrences per category 📊

To gain deeper insight into the dataset, let's plot the number of occurrences of each category. This shows us the class distribution and helps identify any potential biases or imbalances in the dataset.

import matplotlib.pyplot as plt
import numpy as np

class_counts = np.zeros(len(id2label))

for example in train_ds:
    mask_np = np.array(example["label"])
    unique, counts = np.unique(mask_np, return_counts=True)
    for u, c in zip(unique, counts):
        class_counts[u] += c
>>> from matplotlib import pyplot as plt
>>> import numpy as np
>>> from matplotlib import patches

>>> labels = list(id2label.values())

>>> # Normalize colors to be in the range [0, 1]
>>> normalized_palette = [tuple(c / 255 for c in color) for color in sidewalk_palette]

>>> # Visualization
>>> fig, ax = plt.subplots(figsize=(12, 8))

>>> bars = ax.bar(range(len(labels)), class_counts, color=[normalized_palette[i] for i in range(len(labels))])

>>> ax.set_xticks(range(len(labels)))
>>> ax.set_xticklabels(labels, rotation=90, ha="right")

>>> ax.set_xlabel("Categories", fontsize=14)
>>> ax.set_ylabel("Number of Occurrences", fontsize=14)
>>> ax.set_title("Number of Occurrences by Category", fontsize=16)

>>> ax.grid(axis="y", linestyle="--", alpha=0.7)

>>> # Adjust the y-axis limit
>>> y_max = max(class_counts)
>>> ax.set_ylim(0, y_max * 1.25)

>>> for bar in bars:
...     height = int(bar.get_height())
...     offset = 10  # Adjust the text location
...     ax.text(
...         bar.get_x() + bar.get_width() / 2.0,
...         height + offset,
...         f"{height}",
...         ha="center",
...         va="bottom",
...         rotation=90,
...         fontsize=10,
...         color="black",
...     )

>>> fig.legend(
...     handles=legend_patches, loc="center left", bbox_to_anchor=(1, 0.5), ncol=1, fontsize=8
... )  # Adjust ncol as needed

>>> plt.tight_layout()
>>> plt.show()

5. Initialize the image processor and add data augmentation with Albumentations 📸

We'll start by initializing the image processor and then apply data augmentation using Albumentations 🪄. This helps enrich our dataset and improve the performance of the semantic segmentation model.

import albumentations as A
from transformers import SegformerImageProcessor

image_processor = SegformerImageProcessor()

albumentations_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.7),
        A.RandomResizedCrop(height=512, width=512, scale=(0.8, 1.0), ratio=(0.75, 1.33), p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.25, contrast_limit=0.25, p=0.5),
        A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=25, val_shift_limit=20, p=0.5),
        A.GaussianBlur(blur_limit=(3, 5), p=0.3),
        A.GaussNoise(var_limit=(10, 50), p=0.4),
    ]
)


def train_transforms(example_batch):
    augmented = [
        albumentations_transform(image=np.array(image), mask=np.array(label))
        for image, label in zip(example_batch["pixel_values"], example_batch["label"])
    ]
    augmented_images = [item["image"] for item in augmented]
    augmented_labels = [item["mask"] for item in augmented]
    inputs = image_processor(augmented_images, augmented_labels)
    return inputs


def val_transforms(example_batch):
    images = [x for x in example_batch["pixel_values"]]
    labels = [x for x in example_batch["label"]]
    inputs = image_processor(images, labels)
    return inputs


# Set transforms
train_ds.set_transform(train_transforms)
test_ds.set_transform(val_transforms)

6. Initialize the model from a checkpoint

We will use a pretrained SegFormer model from the nvidia/mit-b0 checkpoint. This architecture is described in detail in the paper SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers and was pretrained on ImageNet-1k.

from transformers import SegformerForSemanticSegmentation

pretrained_model_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(pretrained_model_name, id2label=id2label, label2id=label2id)
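
As an optional sanity check (a minimal sketch, not part of the original notebook), we can push a dummy tensor through the freshly initialized model and confirm that it outputs one logit map per class at 1/4 of the input resolution, which is what SegFormer's decode head produces:

import torch

with torch.no_grad():
    dummy = torch.randn(1, 3, 512, 512)    # fake batch with one 512x512 RGB image
    outputs = model(pixel_values=dummy)
print(outputs.logits.shape)  # expected: torch.Size([1, 35, 128, 128])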

7. Set the training arguments and connect to Weights & Biases 📉

Next, we'll configure the training arguments and connect to Weights & Biases (W&B). W&B will help us track experiments, visualize metrics, and manage the training workflow, providing valuable insights throughout the process.

from transformers import TrainingArguments

output_dir = "test-segformer-b0-segments-sidewalk-finetuned"

training_args = TrainingArguments(
    output_dir=output_dir,
    learning_rate=6e-5,
    num_train_epochs=20,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    save_total_limit=2,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=20,
    eval_steps=20,
    logging_steps=1,
    eval_accumulation_steps=5,
    load_best_model_at_end=True,
    push_to_hub=True,
    report_to="wandb",
)
import wandb

wandb.init(
    project="test-segformer-b0-segments-sidewalk-finetuned",  # change this
    name="test-segformer-b0-segments-sidewalk-finetuned",  # change this
    config=training_args,
)

8. Set up a custom compute_metrics method for enhanced logging with evaluate

We will use mean Intersection over Union (mean IoU) as the main metric for evaluating the model's performance. It also lets us track per-category performance in detail.
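
As a reminder of what the metric computes, for every class c the predicted pixels are compared against the ground-truth pixels, and the per-class scores are then averaged:

\[
\mathrm{IoU}_c = \frac{TP_c}{TP_c + FP_c + FN_c}, \qquad
\mathrm{mIoU} = \frac{1}{|C|} \sum_{c \in C} \mathrm{IoU}_c
\]

When a class appears neither in the prediction nor in the ground truth, the union is empty and the division is undefined, which is exactly what triggers the warning shown below.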

Additionally, we'll adjust the logging level of the evaluate module to minimize warnings in the output. If a category is not detected in an image, you may see warnings like the following:

RuntimeWarning: invalid value encountered in divide iou = total_area_intersect / total_area_union

If you'd rather see these warnings, you can skip this cell and move on to the next step.

import evaluate

evaluate.logging.set_verbosity_error()
import torch
from torch import nn
import multiprocessing

metric = evaluate.load("mean_iou")


def compute_metrics(eval_pred):
    with torch.no_grad():
        logits, labels = eval_pred
        logits_tensor = torch.from_numpy(logits)
        # scale the logits to the size of the label
        logits_tensor = nn.functional.interpolate(
            logits_tensor,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False,
        ).argmax(dim=1)

        # currently using _compute instead of compute: https://github.com/huggingface/evaluate/pull/328#issuecomment-1286866576
        pred_labels = logits_tensor.detach().cpu().numpy()
        import warnings

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            metrics = metric._compute(
                predictions=pred_labels,
                references=labels,
                num_labels=len(id2label),
                ignore_index=0,
                reduce_labels=image_processor.do_reduce_labels,
            )

        # add per category metrics as individual key-value pairs
        per_category_accuracy = metrics.pop("per_category_accuracy").tolist()
        per_category_iou = metrics.pop("per_category_iou").tolist()

        metrics.update({f"accuracy_{id2label[i]}": v for i, v in enumerate(per_category_accuracy)})
        metrics.update({f"iou_{id2label[i]}": v for i, v in enumerate(per_category_iou)})

        return metrics

9. Train the model on our dataset 🏋

Now it's time to train the model on our custom dataset. We'll use the training arguments we prepared, together with the Weights & Biases integration, to monitor the training process and make adjustments as needed. Let's kick off training and watch the model improve!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)
trainer.train()
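
Because push_to_hub=True, checkpoints are already uploaded during training. If you also want to push the final weights, the image processor, and an auto-generated model card to your repository, a minimal sketch:

# Push the final model, processor, and training metadata to the Hub.
trainer.push_to_hub(commit_message="End of training")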

10. Evaluate the model's performance on new images 📸

Once training is complete, we'll evaluate the model's performance on new images. We'll take a test image and use a pipeline to see how the model performs on unseen data.

import requests
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw

url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"

image = Image.open(requests.get(url, stream=True).raw)

image_segmentator = pipeline(
    "image-segmentation",
    model="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
)

results = image_segmentator(image)
>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()

The model has generated some masks, so we can visualize them to assess and understand its performance. This helps us see how well the model segments the image and spot areas for improvement.

>>> image_array = np.array(image)

>>> segmentation_map = np.zeros_like(image_array)

>>> for result in results:
...     mask = np.array(result["mask"])
...     label = result["label"]

...     label_index = list(id2label.values()).index(label)

...     color = sidewalk_palette[label_index]

...     for c in range(3):
...         segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()

11. Evaluate the model's performance on the test set 📊

>>> metrics = trainer.evaluate(test_ds)
>>> print(metrics)
{'eval_loss': 0.6063494086265564, 'eval_mean_iou': 0.26682655949637757, 'eval_mean_accuracy': 0.3233445959272099, 'eval_overall_accuracy': 0.834762670692357, 'eval_accuracy_unlabeled': nan, 'eval_accuracy_flat-road': 0.8794976463015708, 'eval_accuracy_flat-sidewalk': 0.9287807675111692, 'eval_accuracy_flat-crosswalk': 0.5247038032656313, 'eval_accuracy_flat-cyclinglane': 0.795399495199148, 'eval_accuracy_flat-parkingdriveway': 0.4010852199852775, 'eval_accuracy_flat-railtrack': nan, 'eval_accuracy_flat-curb': 0.4902816930389514, 'eval_accuracy_human-person': 0.5913439011934908, 'eval_accuracy_human-rider': 0.0, 'eval_accuracy_vehicle-car': 0.9253204043875328, 'eval_accuracy_vehicle-truck': 0.0, 'eval_accuracy_vehicle-bus': 0.0, 'eval_accuracy_vehicle-tramtrain': 0.0, 'eval_accuracy_vehicle-motorcycle': 0.0, 'eval_accuracy_vehicle-bicycle': 0.0013499147866290941, 'eval_accuracy_vehicle-caravan': 0.0, 'eval_accuracy_vehicle-cartrailer': 0.0, 'eval_accuracy_construction-building': 0.8815560533904696, 'eval_accuracy_construction-door': 0.0, 'eval_accuracy_construction-wall': 0.4455930603622635, 'eval_accuracy_construction-fenceguardrail': 0.3431640802292688, 'eval_accuracy_construction-bridge': 0.0, 'eval_accuracy_construction-tunnel': nan, 'eval_accuracy_construction-stairs': 0.0, 'eval_accuracy_object-pole': 0.24341265579591848, 'eval_accuracy_object-trafficsign': 0.0, 'eval_accuracy_object-trafficlight': 0.0, 'eval_accuracy_nature-vegetation': 0.9478392425169023, 'eval_accuracy_nature-terrain': 0.8560970005175594, 'eval_accuracy_sky': 0.9530036096232858, 'eval_accuracy_void-ground': 0.0, 'eval_accuracy_void-dynamic': 0.0, 'eval_accuracy_void-static': 0.13859852156564748, 'eval_accuracy_void-unclear': 0.0, 'eval_iou_unlabeled': nan, 'eval_iou_flat-road': 0.7270368663334998, 'eval_iou_flat-sidewalk': 0.8484429155310914, 'eval_iou_flat-crosswalk': 0.3716762279636531, 'eval_iou_flat-cyclinglane': 0.6983685965068486, 'eval_iou_flat-parkingdriveway': 0.3073600964845036, 'eval_iou_flat-railtrack': nan, 'eval_iou_flat-curb': 0.3781660047058077, 'eval_iou_human-person': 0.38559031115261033, 'eval_iou_human-rider': 0.0, 'eval_iou_vehicle-car': 0.7473290757373612, 'eval_iou_vehicle-truck': 0.0, 'eval_iou_vehicle-bus': 0.0, 'eval_iou_vehicle-tramtrain': 0.0, 'eval_iou_vehicle-motorcycle': 0.0, 'eval_iou_vehicle-bicycle': 0.0013499147866290941, 'eval_iou_vehicle-caravan': 0.0, 'eval_iou_vehicle-cartrailer': 0.0, 'eval_iou_construction-building': 0.6637240016649857, 'eval_iou_construction-door': 0.0, 'eval_iou_construction-wall': 0.3336225132267832, 'eval_iou_construction-fenceguardrail': 0.3131070176565442, 'eval_iou_construction-bridge': 0.0, 'eval_iou_construction-tunnel': nan, 'eval_iou_construction-stairs': 0.0, 'eval_iou_object-pole': 0.17741310577170807, 'eval_iou_object-trafficsign': 0.0, 'eval_iou_object-trafficlight': 0.0, 'eval_iou_nature-vegetation': 0.837720086429597, 'eval_iou_nature-terrain': 0.7272281817316115, 'eval_iou_sky': 0.9005169994943569, 'eval_iou_void-ground': 0.0, 'eval_iou_void-dynamic': 0.0, 'eval_iou_void-static': 0.11979798870649179, 'eval_iou_void-unclear': 0.0, 'eval_runtime': 30.5276, 'eval_samples_per_second': 6.551, 'eval_steps_per_second': 0.819, 'epoch': 20.0}

12. Access the model with the Inference API and visualize the results 🔌

Hugging Face 🤗 provides a Serverless Inference API that lets you test models for free directly through an API endpoint. For a detailed guide on using this API, check out this recipe.

We will use this API to explore its capabilities and see how we can leverage it to test our model.

Important note

Before using the Serverless Inference API, you need to set up the model task via the model card. When creating the model card for your fine-tuned model, make sure the task is specified correctly.

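If you prefer to set the task programmatically instead of editing the model card by hand, one option is a short sketch using metadata_update from huggingface_hub (the repository id below is the one used throughout this notebook; replace it with your own):

from huggingface_hub import metadata_update

# Add (or overwrite) the pipeline_tag so the Inference API knows which task to serve.
metadata_update(
    repo_id="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # change to your model
    metadata={"pipeline_tag": "image-segmentation"},
    overwrite=True,
)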

Once the model task is set up, we can download an image and use the InferenceClient to test the model. This client lets us send images to the model through the API and retrieve the results for evaluation.

>>> url = "https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
>>> image = Image.open(requests.get(url, stream=True).raw)

>>> plt.imshow(image)
>>> plt.axis("off")
>>> plt.show()

We'll use the image_segmentation method of the InferenceClient. It takes the model and an image as input and returns the predicted masks, so we can see how well the model performs on new images.

from huggingface_hub import InferenceClient

client = InferenceClient()

response = client.image_segmentation(
    model="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # Change with your model name
    image="https://images.unsplash.com/photo-1594098742644-314fedf61fb6?q=80&w=2672&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D",
)

print(response)

Using the predicted masks, we can display the results.

>>> image_array = np.array(image)
>>> segmentation_map = np.zeros_like(image_array)

>>> for result in response:
...     mask = np.array(result["mask"])
...     label = result["label"]

...     label_index = list(id2label.values()).index(label)

...     color = sidewalk_palette[label_index]

...     for c in range(3):
...         segmentation_map[:, :, c] = np.where(mask, color[c], segmentation_map[:, :, c])

>>> plt.figure(figsize=(10, 10))
>>> plt.imshow(image_array)
>>> plt.imshow(segmentation_map, alpha=0.5)
>>> plt.axis("off")
>>> plt.show()

The Inference API can also be used from JavaScript. Here is an example of how to call it with JavaScript:

import { HfInference } from "@huggingface/inference";

const inference = new HfInference(HF_TOKEN);
await inference.imageSegmentation({
    data: await (await fetch("https://picsum.photos/300/300")).blob(),
    model: "sergiopaniego/segformer-b0-segments-sidewalk-finetuned",
});

Bonus

You can also deploy the fine-tuned model in a Hugging Face Space. For example, I created a custom Space to showcase this: SegFormer fine-tuned on Segments/Sidewalk for semantic segmentation

from IPython.display import IFrame

IFrame(src="https://sergiopaniego-segformer-b0-segments-sidewalk-finetuned.hf.space", width=1000, height=800)
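
If you want to build a similar Space yourself, a minimal Gradio sketch (illustrative; the Space setup and requirements are assumptions, and you should swap in your own model name) could wrap the segmentation pipeline like this:

import gradio as gr
import numpy as np
from PIL import Image
from transformers import pipeline

segmenter = pipeline(
    "image-segmentation",
    model="sergiopaniego/test-segformer-b0-segments-sidewalk-finetuned",  # change to your model
)


def segment(image: Image.Image) -> Image.Image:
    # Overlay each predicted mask on the input image with a random color.
    overlay = np.array(image).copy()
    for result in segmenter(image):
        mask = np.array(result["mask"]) > 0
        color = np.random.randint(0, 255, size=3)
        overlay[mask] = (0.5 * overlay[mask] + 0.5 * color).astype(np.uint8)
    return Image.fromarray(overlay)


demo = gr.Interface(fn=segment, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"))
demo.launch()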

Conclusion

In this guide, we successfully fine-tuned a semantic segmentation model on a custom dataset and tested it with the Serverless Inference API. This shows how easily you can integrate models into all kinds of applications and leverage Hugging Face tools for deployment.

I hope this guide has given you the tools and knowledge to fine-tune and deploy your own models with confidence! 🚀
