Mamba出击

社区文章 发布于2024年10月18日

MambaOut 结构上类似于状态空间模型(SSM),但它褪去了那层外皮,露出了更精简、更凶猛的内核。它是机器学习模型中的新娘,是一位受过数据功夫训练的战士,随时准备对任何横亘在其路径上的数据集造成破坏。正如科比会后仰跳投命中那些不可能的投篮一样,MambaOut 会淘汰不必要的计算,只留下纯粹、致命的性能。

</LLM 关闭>

抱歉,我本想写一个与流行文化中“黑曼巴”相关的引人入胜的开场白,结果却太搞笑了。说正经的,这个模型现已作为 timm 1.0.11 版的一部分(https://github.com/huggingface/pytorch-image-models)被包含在内,值得一试。我花了一些时间对其进行实验,并从头训练了一些模型,在原始权重的基础上增加了我自己的几个调整。

那么这个模型到底是什么呢?正如其俏皮的名称所示,其核心模块的结构与其他 Mamba (SSM) 视觉模型类似,但并未包含 SSM 本身。没有 SSM 意味着没有自定义编译内核或额外的依赖项,太棒了!它只是一个 7x7 DW 卷积,在两个 FC 层之间夹着一个门控机制。请参阅下面论文中的图表(https://arxiv.org/abs/2405.07992

image/png

在我看来,这个模型与 ConvNeXt 系列有很多共同点——没有 BatchNorm,纯卷积,使用 7x7 DW 卷积进行空间混合,一些 1x1 PW / FC 进行通道混合。事实上,在运行时性能和训练行为方面,它也感觉像是 ConvNeXt 系列的延伸。比较它们的每个核心构建块,我们就能明白为什么了

image/png

与 ConvNeXt 相比,MambaOut 的额外拆分、拼接等操作会增加一些开销,但通过使用 torch.compile(),这些开销被抵消了,并且它们具有相当的竞争力。在将模型引入 `timm` 时,我尝试了对模型结构和训练配方进行了一些小调整。你可以在下面带有 `_rw` 后缀的模型中看到它们。在基础尺寸范围内,这些更改产生了一些稍微更快的变体,参数更多,并且能够挤出更高的准确度。在 `tall`、`wide` 和 `short` 基础变体中,我认为 `tall`(稍微更深、稍微更窄)是最有价值的。

不过,我最有趣的补充是 `base_plus`。它在深度和宽度上都有所增加,并且在 ImageNet-12k 上进行了预训练,使其与 `timm` 中最好的 ImageNet-22k 预训练模型不相上下。在查看原始预训练权重时,我首先想到的问题之一是,较小的模型在其尺寸下表现相当不错,但 `base` 模型却几乎没有进步,这是怎么回事?是缩放出了问题吗?

不。在 1.02 亿参数下,`base_plus` 的准确率与 ImageNet-22k 预训练的 ConvNeXt-Large(约 2 亿参数)持平或更高,与最佳的 22k 训练的 ViT-Large(DeiT-III,约 3 亿参数)相差不远,并且在考虑运行时性能的情况下,它明显优于 Swin/SwinV2-Large。

模型 图像尺寸 top1 top5 参数数量
mambaout_base_plus_rw.sw_e150_r384_in12k_ft_in1k 384 87.506 98.428 101.66
mambaout_base_plus_rw.sw_e150_in12k_ft_in1k 288 86.912 98.236 101.66
mambaout_base_plus_rw.sw_e150_in12k_ft_in1k 224 86.632 98.156 101.66
mambaout_base_tall_rw.sw_e500_in1k 288 84.974 97.332 86.48
mambaout_base_wide_rw.sw_e500_in1k 288 84.962 97.208 94.45
mambaout_base_short_rw.sw_e500_in1k 288 84.832 97.27 88.83
mambaout_base.in1k 288 84.72 96.93 84.81
mambaout_small_rw.sw_e450_in1k 288 84.598 97.098 48.5
mambaout_small.in1k 288 84.5 96.974 48.49
mambaout_base_wide_rw.sw_e500_in1k 224 84.454 96.864 94.45
mambaout_base_tall_rw.sw_e500_in1k 224 84.434 96.958 86.48
mambaout_base_short_rw.sw_e500_in1k 224 84.362 96.952 88.83
mambaout_base.in1k 224 84.168 96.68 84.81
mambaout_small.in1k 224 84.086 96.63 48.49
mambaout_small_rw.sw_e450_in1k 224 84.024 96.752 48.5
mambaout_tiny.in1k 288 83.448 96.538 26.55
mambaout_tiny.in1k 224 82.736 96.1 26.55
mambaout_kobe.in1k 288 81.054 95.718 9.14
mambaout_kobe.in1k 224 79.986 94.986 9.14
mambaout_femto.in1k 288 79.848 95.14 7.3
mambaout_femto.in1k 224 78.87 94.408 7.3

所以,`pip install --upgrade timm`,然后试试看 :)

论文作者的原始研究代码库:https://github.com/yuweihao/MambaOut

@article{yu2024mambaout,
  title={MambaOut: Do We Really Need Mamba for Vision?},
  author={Yu, Weihao and Wang, Xinchao},
  journal={arXiv preprint arXiv:2405.07992},
  year={2024}
}

社区

注册登录以发表评论