timm 中的 NaFlex

社区文章 发布于 2025 年 4 月 9 日

在诸多干扰之下,我一直在努力编写一些新的代码——最终通过 SigLIP2 NaFlex 所采用的方法,解决了 timm 中可变图像尺寸/宽高比的问题。一年多前,我曾致力于 NaViT 方法,但被数据管道所阻碍。这在当前也是最大的挑战,但没那么棘手,不需要序列打包。那么,为什么数据加载会很痛苦呢?因为经典的 PyTorch 数据管道。

通常情况下,“数据集”、“转换”、“采样器”、“数据加载器”和“整理”的组织方式,导致数据处理管道的起始部分(数据集/转换)与采样器和整理函数脱节,而且在其中一些实体之间存在数据加载器工作进程边界。此外,这些实体之间内置的同步机制也有限……所以,如果你想改变序列长度,从而改变每个批次的批大小,那会很麻烦。我尝试了几次,但都失败了,然后我决定,算了,所有东西都放到数据集中去。

什么?是的,数据加载器中的批处理和采样功能被禁用,并且有一个数据集包装器(本身也是一个 IterableDataset),它确定序列长度和批次大小,采样索引(支持分布式),运行转换,然后批处理并自行完成大部分的切块和整理。

我目前在 timm 中有一个我认为是“alpha”版本的 PR。还有很多需要测试和完善的地方,尤其是分布式支持。已实现的功能:

  • 一种 NaFlex 风格的 ViT,应该能够在 NaFlex 序列模式和固定模式之间切换。

  • 模型中 NaFlex 位置嵌入插值的优化,至少比现有 PyTorch 实现中的有所改进。

  • 一组面向“序列”的转换,将图像大小限制到目标序列长度。

  • 一个用于映射式数据集的数据集包装器,迭代器包装器正在开发中。

  • 数据集包装器控制序列长度与批处理大小的权衡,并尝试最大化 GPU 利用率,以便随着序列长度从最大值减小,批处理大小增加。这在 PyTorch 和 GPU 上似乎运行良好,包括使用 torch compile,因为组合集是有限的。我想 XLA 也能应对?

  • 修改训练脚本以处理每个步骤的可变批处理大小,包括缩放梯度和处理每个分布式等级不同批处理大小的选项。

参考:论文中讨论的 SigLIP-2 NaFlex 详细信息:https://arxiv.org/html/2502.14786

社区

注册登录 发表评论