Transformers 文档

ViTMatte

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

ViTMatte

PyTorch

概述

ViTMatte 模型在 Boosting Image Matting with Pretrained Plain Vision Transformers 中被提出,作者是 Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang。ViTMatte 利用普通的 Vision Transformers 来完成图像抠图的任务,即准确估计图像和视频中的前景对象的过程。

以下是论文的摘要

最近,普通的 Vision Transformers (ViTs) 在各种计算机视觉任务上表现出令人印象深刻的性能,这归功于它们强大的建模能力和大规模预训练。然而,它们尚未征服图像抠图的问题。我们假设图像抠图也可以通过 ViTs 得到提升,并提出了一个新的高效且鲁棒的基于 ViT 的抠图系统,名为 ViTMatte。我们的方法利用了 (i) 一种混合注意力机制,结合卷积颈部,以帮助 ViTs 在抠图任务中实现出色的性能-计算权衡。(ii) 此外,我们引入了细节捕捉模块,该模块仅由简单的轻量级卷积组成,以补充抠图所需的详细信息。据我们所知,ViTMatte 是第一个通过简洁的适配来释放 ViT 在图像抠图方面的潜力。它从 ViT 继承了许多优越的属性到抠图,包括各种预训练策略、简洁的架构设计和灵活的推理策略。我们在 Composition-1k 和 Distinctions-646(图像抠图最常用的基准)上评估了 ViTMatte,我们的方法实现了最先进的性能,并且大大优于之前的抠图工作。

此模型由 nielsr 贡献。原始代码可以在 这里 找到。

drawing ViTMatte 高级概述。摘自原始论文。

资源

以下是官方 Hugging Face 和社区 (🌎 表示) 资源列表,可帮助您开始使用 ViTMatte。

该模型期望图像和 trimap(串联)作为输入。使用 ViTMatteImageProcessor 以达到此目的。

VitMatteConfig

class transformers.VitMatteConfig

< >

( backbone_config: PretrainedConfig = None backbone = None use_pretrained_backbone = False use_timm_backbone = False backbone_kwargs = None hidden_size: int = 384 batch_norm_eps: float = 1e-05 initializer_range: float = 0.02 convstream_hidden_sizes: typing.List[int] = [48, 96, 192] fusion_hidden_sizes: typing.List[int] = [256, 128, 64, 32] **kwargs )

参数

  • backbone_config (PretrainedConfigdict, 可选, 默认为 VitDetConfig()) — 骨干模型的配置。
  • backbone (str, 可选) — 当 backbone_configNone 时,要使用的骨干网络的名称。如果 use_pretrained_backboneTrue,这将从 timm 或 transformers 库加载相应的预训练权重。如果 use_pretrained_backboneFalse,这将加载骨干网络的配置,并使用它来初始化具有随机权重的骨干网络。
  • use_pretrained_backbone (bool, 可选, 默认为 False) — 是否对骨干网络使用预训练权重。
  • use_timm_backbone (bool, 可选, 默认为 False) — 是否从 timm 库加载 backbone。如果为 False,则从 transformers 库加载骨干网络。
  • backbone_kwargs (dict, 可选) — 从检查点加载时要传递给 AutoBackbone 的关键字参数,例如 {'out_indices': (0, 1, 2, 3)}。如果设置了 backbone_config,则无法指定。
  • hidden_size (int, 可选, 默认为 384) — 解码器的输入通道数。
  • batch_norm_eps (float, 可选, 默认为 1e-05) — 批归一化层使用的 epsilon 值。
  • initializer_range (float, 可选, 默认为 0.02) — 用于初始化所有权重矩阵的截断正态分布初始化器的标准差。
  • convstream_hidden_sizes (List[int], 可选, 默认为 [48, 96, 192]) — ConvStream 模块的输出通道数。
  • fusion_hidden_sizes (List[int], 可选, 默认为 [256, 128, 64, 32]) — Fusion 模块的输出通道数。

这是用于存储 VitMatteForImageMatting 配置的配置类。 它用于根据指定的参数实例化 ViTMatte 模型,定义模型架构。 使用默认值实例化配置将生成与 ViTMatte hustvl/vitmatte-small-composition-1k 架构类似的配置。

配置对象继承自 PretrainedConfig,可用于控制模型输出。 有关更多信息,请阅读 PretrainedConfig 的文档。

示例

>>> from transformers import VitMatteConfig, VitMatteForImageMatting

>>> # Initializing a ViTMatte hustvl/vitmatte-small-composition-1k style configuration
>>> configuration = VitMatteConfig()

>>> # Initializing a model (with random weights) from the hustvl/vitmatte-small-composition-1k style configuration
>>> model = VitMatteForImageMatting(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config

to_dict

< >

( )

将此实例序列化为 Python 字典。 覆盖默认的 to_dict()。 返回值: Dict[str, any]:构成此配置实例的所有属性的字典,

VitMatteImageProcessor

class transformers.VitMatteImageProcessor

< >

( do_rescale: bool = True rescale_factor: typing.Union[int, float] = 0.00392156862745098 do_normalize: bool = True image_mean: typing.Union[float, typing.List[float], NoneType] = None image_std: typing.Union[float, typing.List[float], NoneType] = None do_pad: bool = True size_divisibility: int = 32 **kwargs )

参数

  • do_rescale (bool, 可选, 默认为 True) — 是否按指定的比例 rescale_factor 缩放图像。 可以通过 preprocess 方法中的 do_rescale 参数覆盖。
  • rescale_factor (intfloat, 可选, 默认为 1/255) — 如果缩放图像,则使用的缩放因子。 可以通过 preprocess 方法中的 rescale_factor 参数覆盖。
  • do_normalize (bool, 可选, 默认为 True) — 是否标准化图像。 可以通过 preprocess 方法中的 do_normalize 参数覆盖。
  • image_mean (floatList[float], 可选, 默认为 IMAGENET_STANDARD_MEAN) — 如果标准化图像,则使用的均值。 这是一个浮点数或浮点数列表,其长度为图像中通道的数量。 可以通过 preprocess 方法中的 image_mean 参数覆盖。
  • image_std (floatList[float], 可选, 默认为 IMAGENET_STANDARD_STD) — 如果标准化图像,则使用的标准差。 这是一个浮点数或浮点数列表,其长度为图像中通道的数量。 可以通过 preprocess 方法中的 image_std 参数覆盖。
  • do_pad (bool, 可选, 默认为 True) — 是否填充图像以使宽度和高度可被 size_divisibility 整除。 可以通过 preprocess 方法中的 do_pad 参数覆盖。
  • size_divisibility (int, 可选, 默认为 32) — 图像的宽度和高度将被填充为可被该数字整除的尺寸。

构建 ViTMatte 图像处理器。

preprocess

< >

( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] trimaps: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), list['PIL.Image.Image'], list[numpy.ndarray], list['torch.Tensor']] do_rescale: typing.Optional[bool] = None rescale_factor: typing.Optional[float] = None do_normalize: typing.Optional[bool] = None image_mean: typing.Union[float, typing.List[float], NoneType] = None image_std: typing.Union[float, typing.List[float], NoneType] = None do_pad: typing.Optional[bool] = None size_divisibility: typing.Optional[int] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: typing.Union[str, transformers.image_utils.ChannelDimension] = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None )

参数

  • images (ImageInput) — 要预处理的图像。 期望单个或批量的图像,像素值范围为 0 到 255。 如果传入像素值介于 0 和 1 之间的图像,请设置 do_rescale=False
  • trimaps (ImageInput) — 要预处理的 Trimap。
  • do_rescale (bool, 可选, 默认为 self.do_rescale) — 是否将图像值重新缩放到 [0 - 1] 之间。
  • rescale_factor (float, 可选, 默认为 self.rescale_factor) — 如果 do_rescale 设置为 True,则用于重新缩放图像的缩放因子。
  • do_normalize (bool, 可选, 默认为 self.do_normalize) — 是否标准化图像。
  • image_mean (floatList[float], 可选, 默认为 self.image_mean) — 如果 do_normalize 设置为 True,则使用的图像均值。
  • image_std (floatList[float], 可选, 默认为 self.image_std) — 如果 do_normalize 设置为 True,则使用的图像标准差。
  • do_pad (bool, 可选, 默认为 self.do_pad) — 是否填充图像。
  • size_divisibility (int, optional, defaults to self.size_divisibility) — 如果 do_pad 设置为 True,则将图像填充到此尺寸可整除性。
  • return_tensors (strTensorType, optional) — 返回张量的类型。可以是以下之一:
    • Unset: 返回 np.ndarray 列表。
    • TensorType.TENSORFLOW'tf': 返回 tf.Tensor 类型的批次。
    • TensorType.PYTORCH'pt': 返回 torch.Tensor 类型的批次。
    • TensorType.NUMPY'np': 返回 np.ndarray 类型的批次。
    • TensorType.JAX'jax': 返回 jax.numpy.ndarray 类型的批次。
  • data_format (ChannelDimensionstr, optional, defaults to ChannelDimension.FIRST) — 输出图像的通道维度格式。可以是以下之一:
    • "channels_first"ChannelDimension.FIRST: 图像格式为 (num_channels, height, width)。
    • "channels_last"ChannelDimension.LAST: 图像格式为 (height, width, num_channels)。
    • Unset: 使用输入图像的通道维度格式。
  • input_data_format (ChannelDimensionstr, optional) — 输入图像的通道维度格式。如果未设置,则通道维度格式将从输入图像中推断。可以是以下之一:
    • "channels_first"ChannelDimension.FIRST: 图像格式为 (num_channels, height, width)。
    • "channels_last"ChannelDimension.LAST: 图像格式为 (height, width, num_channels)。
    • "none"ChannelDimension.NONE: 图像格式为 (height, width)。

预处理图像或图像批次。

VitMatteForImageMatting

class transformers.VitMatteForImageMatting

< >

( config )

参数

  • This 模型是 PyTorch [torch.nn.Module](https —//pytorch.ac.cn/docs/stable/nn.html#torch.nn.Module) 的子类。使用
  • it 作为一个常规 PyTorch 模块使用,并参考 PyTorch 文档以了解与通用用法和
  • behavior. — config (UperNetConfig): 带有模型所有参数的模型配置类。使用配置文件初始化不会加载与模型关联的权重,仅加载配置。查看 from_pretrained() 方法来加载模型权重。

ViTMatte 框架利用任何视觉骨干网络,例如用于 ADE20k、CityScapes。

forward

< >

( pixel_values: typing.Optional[torch.Tensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None labels: typing.Optional[torch.Tensor] = None return_dict: typing.Optional[bool] = None ) transformers.models.vitmatte.modeling_vitmatte.ImageMattingOutputtuple(torch.FloatTensor)

参数

  • pixel_values (torch.FloatTensor,形状为 (batch_size, num_channels, height, width)) — 像素值。如果您提供填充,默认情况下填充将被忽略。像素值可以使用 AutoImageProcessor 获得。有关详细信息,请参阅 VitMatteImageProcessor.call()
  • output_attentions (bool, optional) — 是否返回骨干网络中所有注意力层的注意力张量(如果骨干网络有)。有关更多详细信息,请参阅返回张量下的 attentions
  • output_hidden_states (bool, optional) — 是否返回骨干网络所有层的隐藏状态。有关更多详细信息,请参阅返回张量下的 hidden_states
  • return_dict (bool, optional) — 是否返回 ModelOutput 而不是普通元组。
  • labels (torch.LongTensor,形状为 (batch_size, height, width), optional) — 用于计算损失的真值图像抠图。

返回

transformers.models.vitmatte.modeling_vitmatte.ImageMattingOutputtuple(torch.FloatTensor)

一个 transformers.models.vitmatte.modeling_vitmatte.ImageMattingOutputtorch.FloatTensor 元组(如果传递 return_dict=False 或当 config.return_dict=False 时),包含各种元素,具体取决于配置 (VitMatteConfig) 和输入。

  • loss (torch.FloatTensor,形状为 (1,), optional, 当提供 labels 时返回) — 损失。

  • alphas (torch.FloatTensor,形状为 (batch_size, num_channels, height, width)) — 估计的 alpha 值。

  • hidden_states (tuple(torch.FloatTensor), optional, 当传递 output_hidden_states=True 或当 config.output_hidden_states=True 时返回) — torch.FloatTensor 元组(如果模型有嵌入层,则为嵌入输出 + 每个阶段的输出),形状为 (batch_size, sequence_length, hidden_size)。模型在每个阶段输出的隐藏状态(也称为特征图)。

  • attentions (tuple(torch.FloatTensor), optional, 当传递 output_attentions=True 或当 config.output_attentions=True 时返回) — torch.FloatTensor 元组(每层一个),形状为 (batch_size, num_heads, patch_size, sequence_length)

    注意力 softmax 之后的注意力权重,用于计算自注意力头中的加权平均值。

VitMatteForImageMatting 的 forward 方法,覆盖了 __call__ 特殊方法。

虽然前向传播的步骤需要在该函数中定义,但应该在之后调用 Module 实例而不是此函数,因为前者负责运行预处理和后处理步骤,而后者会默默地忽略它们。

示例

>>> from transformers import VitMatteImageProcessor, VitMatteForImageMatting
>>> import torch
>>> from PIL import Image
>>> from huggingface_hub import hf_hub_download

>>> processor = VitMatteImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
>>> model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")

>>> filepath = hf_hub_download(
...     repo_id="hf-internal-testing/image-matting-fixtures", filename="image.png", repo_type="dataset"
... )
>>> image = Image.open(filepath).convert("RGB")
>>> filepath = hf_hub_download(
...     repo_id="hf-internal-testing/image-matting-fixtures", filename="trimap.png", repo_type="dataset"
... )
>>> trimap = Image.open(filepath).convert("L")

>>> # prepare image + trimap for the model
>>> inputs = processor(images=image, trimaps=trimap, return_tensors="pt")

>>> with torch.no_grad():
...     alphas = model(**inputs).alphas
>>> print(alphas.shape)
torch.Size([1, 1, 640, 960])
< > 在 GitHub 上更新