Swin Transformer
Swin Transformer 架构在 2021 年的论文 Swin Transformer: Hierarchical Vision Transformer using Shifted Windows 中提出,它使用移位窗口(而不是滑动窗口)方法优化延迟和性能,从而减少了所需的运算次数。Swin 被认为是计算机视觉的**分层主干**。Swin 可用于图像分类等任务。
在深度学习中,主干是指神经网络中进行特征提取的部分。可以在主干上添加额外的层来执行各种视觉任务。分层主干具有分层结构,有时具有不同的分辨率。这与 VitDet 模型中非分层的**普通主干**形成对比。
主要亮点
移位窗口
在原始 ViT 中,注意力是在每个 patch 和所有其他 patch 之间进行的,这在计算上非常密集。Swin 通过将 ViT 通常的二次复杂度降低到线性复杂度(相对于图像大小)来优化此过程。Swin 使用类似于 CNN 的技术实现这一点,其中 patch 只关注同一窗口中的其他 patch,而不是所有其他 patch,然后逐渐与相邻 patch 合并。这就是使 Swin 成为分层模型的原因。
图片取自 Swin Transformer 论文
优势
计算效率
Swin 的性能优于完全基于 patch 的方法,例如 ViT。
大型数据集
SwinV2 是首批 30 亿参数模型之一。随着训练规模的增加,Swin 的性能优于 CNN。大量的参数使学习能力和更复杂的表示能力增强。
Swin Transformer V2(论文)
Swin Transformer V2 是一种大型视觉模型,最多可以支持 30 亿个参数,并且能够使用高分辨率图像进行训练。它通过稳定训练、将使用低分辨率图像预训练的迁移模型应用于高分辨率任务以及使用 SimMIM(一种减少训练所需标记图像数量的自监督训练方法)改进了原始的 Swin Transformer。
图像恢复中的应用
SwinIR (论文)
SwinIR 是一种基于 Swin Transformer 的模型,用于将低分辨率图像转换为高分辨率图像。
Swin2SR (论文)
Swin2SR 是另一个图像恢复模型。它是 SwinIR 的改进版本,通过整合 Swin Transformer V2,利用 Swin V2 的优势,例如训练稳定性和更高的图像分辨率能力。
Swin 的 PyTorch 实现概述
下面概述了原始论文中 Swin 实现的关键部分
Swin Transformer 类
初始化参数。除了各种其他 dropout 和归一化参数外,这些参数包括
window_size
:局部自注意的窗口大小。ape (bool)
:如果为 True,则将绝对位置嵌入添加到 patch 嵌入中。fused_window_process
:可选的硬件优化
应用 Patch 嵌入:与 ViT 类似,图像被分割成不重叠的 patch,并使用
Conv2D
进行线性嵌入。应用位置嵌入:
SwinTransformer
可选择使用绝对位置嵌入 (ape
),将其添加到 patch 嵌入中。绝对位置嵌入通常有助于模型学习使用每个 patch 的位置信息来做出更明智的预测。应用深度衰减:深度衰减有助于正则化并防止过拟合。深度衰减通常通过在训练期间跳过层来完成。在这个 Swin 实现中,使用了**随机**深度衰减,这意味着层越深,跳过的可能性越高。
层构建:
- 该模型由多层 (
BasicLayer
) 的SwinTransformerBlock
组成,每个SwinTransformerBlock
使用PatchMerging
对特征图进行下采样以进行分层处理。 - 特征的维度和特征图的分辨率在各层之间发生变化。
- 该模型由多层 (
分类头:与 ViT 类似,它使用多层感知器 (MLP) 头进行分类任务,如
self.head
中定义的那样,作为最后一步
class SwinTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=4,
in_chans=3,
num_classes=1000,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
ape=False,
patch_norm=True,
use_checkpoint=False,
fused_window_process=False,
**kwargs,
):
super().__init__()
self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None,
)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2**i_layer),
input_resolution=(
patches_resolution[0] // (2**i_layer),
patches_resolution[1] // (2**i_layer),
),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
fused_window_process=fused_window_process,
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.head = (
nn.Linear(self.num_features, num_classes)
if num_classes > 0
else nn.Identity()
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {"absolute_pos_embed"}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {"relative_position_bias_table"}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
x = self.avgpool(x.transpose(1, 2)) # B C 1
x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
Swin Transformer 块
SwinTransformerBlock
封装了 Swin Transformer 的核心操作:局部窗口化注意力和随后的 MLP 处理。它在使 Swin Transformer能够有效地处理大型图像方面发挥着关键作用,因为它专注于局部 patch,同时保持学习全局表示的能力。
层组件:
- 归一化层 1 (
self.norm1
):在注意力机制之前应用。 - 窗口注意力 (
self.attn
):在局部窗口内计算自注意力。 - Drop Path (
self.drop_path
):实现随机深度以进行正则化。 - 归一化层 2 (
self.norm2
):在 MLP 层之前应用。 - MLP (
mlp
):用于处理注意力后特征的多层感知器。 - 注意力掩码 (
self.register_buffer
):注意力掩码在自注意力计算期间用于控制窗口化输入中的哪些元素允许交互(即相互关注)。移位窗口方法通过允许一些跨窗口交互来帮助模型捕获更广泛的上下文信息。
Swin Transformer 块的初始化
class SwinTransformerBlock(nn.Module):
r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False
"""
def __init__(
self,
dim,
input_resolution,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.0,
qkv_bias=True,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
fused_window_process=False,
):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
window_size=to_2tuple(self.window_size),
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
w_slices = (
slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None),
)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(
img_mask, self.window_size
) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
self.fused_window_process = fused_window_process
### New cell ###
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(
x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
)
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(
x, B, H, W, C, -self.shift_size, self.window_size
)
else:
shifted_x = x
# partition windows
x_windows = window_partition(
shifted_x, self.window_size
) # nW*B, window_size, window_size, C
x_windows = x_windows.view(
-1, self.window_size * self.window_size, C
) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(
x_windows, mask=self.attn_mask
) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(
attn_windows, self.window_size, H, W
) # B H' W' C
x = torch.roll(
shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
)
else:
x = WindowProcessReverse.apply(
attn_windows, B, H, W, C, self.shift_size, self.window_size
)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
# Feed-forward network (FFN)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
Swin Transformer 块的前向传递
有 4 个关键步骤
- 循环移位:特征图通过
window_partition
被划分为窗口。然后对分区应用循环移位。循环移位是通过将序列中的元素(在本例中为分区)向左或向右移动,并将超出某一端的部分环绕到另一端来完成。此过程更改了元素彼此之间的相对位置,但保持序列本身完整。例如,如果您将序列A, B, C, D
向右循环移位一个位置,它将变为D, A, B, C
。
循环移位允许模型捕获相邻窗口之间的关系,增强其学习超出单个窗口局部范围的空间上下文的能力。
窗口化注意力:使用基于窗口的多头自注意力 (W-MSA) 模块执行注意力
合并 Patch:Patch 通过
PatchMerging
合并反向循环移位:注意力完成后,窗口分区通过
reverse_window
取消,并且循环移位操作被反转,以便特征图保留其原始形式。
class WindowAttention(nn.Module):
"""
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(
self,
dim,
window_size,
num_heads,
qkv_bias=True,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = (
coords_flatten[:, :, None] - coords_flatten[:, None, :]
) # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B_, N, 3, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
q, k, v = (
qkv[0],
qkv[1],
qkv[2],
) # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[
self.relative_position_index.view(-1)
].view(
self.window_size[0] * self.window_size[1],
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
1
).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
窗口注意力
WindowAttention
是一个具有相对位置偏差的基于窗口的多头自注意力 (W-MSA) 模块。这可用于移位和非移位窗口。
class PatchMerging(nn.Module):
r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x