Mamba出击
MambaOut 结构上类似于状态空间模型(SSM),但它褪去了那层外皮,露出了更精简、更凶猛的内核。它是机器学习模型中的新娘,是一位受过数据功夫训练的战士,随时准备对任何横亘在其路径上的数据集造成破坏。正如科比会后仰跳投命中那些不可能的投篮一样,MambaOut 会淘汰不必要的计算,只留下纯粹、致命的性能。
</LLM 关闭>
抱歉,我本想写一个与流行文化中“黑曼巴”相关的引人入胜的开场白,结果却太搞笑了。说正经的,这个模型现已作为 timm 1.0.11 版的一部分(https://github.com/huggingface/pytorch-image-models)被包含在内,值得一试。我花了一些时间对其进行实验,并从头训练了一些模型,在原始权重的基础上增加了我自己的几个调整。
那么这个模型到底是什么呢?正如其俏皮的名称所示,其核心模块的结构与其他 Mamba (SSM) 视觉模型类似,但并未包含 SSM 本身。没有 SSM 意味着没有自定义编译内核或额外的依赖项,太棒了!它只是一个 7x7 DW 卷积,在两个 FC 层之间夹着一个门控机制。请参阅下面论文中的图表(https://arxiv.org/abs/2405.07992)
在我看来,这个模型与 ConvNeXt 系列有很多共同点——没有 BatchNorm,纯卷积,使用 7x7 DW 卷积进行空间混合,一些 1x1 PW / FC 进行通道混合。事实上,在运行时性能和训练行为方面,它也感觉像是 ConvNeXt 系列的延伸。比较它们的每个核心构建块,我们就能明白为什么了
与 ConvNeXt 相比,MambaOut 的额外拆分、拼接等操作会增加一些开销,但通过使用 torch.compile(),这些开销被抵消了,并且它们具有相当的竞争力。在将模型引入 `timm` 时,我尝试了对模型结构和训练配方进行了一些小调整。你可以在下面带有 `_rw` 后缀的模型中看到它们。在基础尺寸范围内,这些更改产生了一些稍微更快的变体,参数更多,并且能够挤出更高的准确度。在 `tall`、`wide` 和 `short` 基础变体中,我认为 `tall`(稍微更深、稍微更窄)是最有价值的。
不过,我最有趣的补充是 `base_plus`。它在深度和宽度上都有所增加,并且在 ImageNet-12k 上进行了预训练,使其与 `timm` 中最好的 ImageNet-22k 预训练模型不相上下。在查看原始预训练权重时,我首先想到的问题之一是,较小的模型在其尺寸下表现相当不错,但 `base` 模型却几乎没有进步,这是怎么回事?是缩放出了问题吗?
不。在 1.02 亿参数下,`base_plus` 的准确率与 ImageNet-22k 预训练的 ConvNeXt-Large(约 2 亿参数)持平或更高,与最佳的 22k 训练的 ViT-Large(DeiT-III,约 3 亿参数)相差不远,并且在考虑运行时性能的情况下,它明显优于 Swin/SwinV2-Large。
所以,`pip install --upgrade timm`,然后试试看 :)
论文作者的原始研究代码库:https://github.com/yuweihao/MambaOut
@article{yu2024mambaout,
title={MambaOut: Do We Really Need Mamba for Vision?},
author={Yu, Weihao and Wang, Xinchao},
journal={arXiv preprint arXiv:2405.07992},
year={2024}
}