将 Pi0-FAST 从 JAX 移植到 LeRobot 的 PyTorch 版本:挑战、修复和未决问题

引言
LeRobot 团队最近专注于将 Physical Intelligence 最初实现的 Pi0+FAST 移植到 Lerobot 仓库。本文概述了我们所做的主要修改、遇到的挑战以及与原始实现的关键差异。我们的目标是开启讨论,并为社区贡献提供足够的背景。
背景
论文 | Jax 代码 | 我们在 Lerobot 中的实现
π0-FAST 是 π0 的 自回归版本,引入了 FAST(频域动作序列标记化)——一种新的标记化方案,可提高效率和性能。
π0-FAST 的主要优势:
- 与基于扩散的 VLA 相比,训练速度快 5 倍。
- 改进的动作表示,减少了动作序列中的冗余。
- 在未见过的环境和机器人形态之间实现更强的泛化能力。
🔗 π0-FAST 分词器 可在此处访问:FAST 分词器
🔗 预训练权重可在此处访问:Pytorch Pi0+FAST
PyTorch 实现中的主要修改
- 向量化填充和标记化
- 与 Hugging Face 的 `transformers` 库中的分词器对齐。
- 使用了 Paligemma 内置的 `transformers` 实现
- 主要区别在于使用了类似于 Pi0 的块因果掩码,但与 Paligemma 不同(Paligemma 对前缀使用完整的双向掩码)。
- 添加了自定义的 `prepare_inputs_for_generation`
- 此项添加是为了正确处理注意力掩码、位置 ID 和其他输入处理细节。
- 添加了一些前缀调整
- 与原始 Pi0+FAST 实现不同的是,原始实现会在输出序列中生成单词 `‘Action: ’`,我们则将 `"Action: "` 添加到前缀并在训练期间传递它。
没有使用指数移动平均 (EMA)。此实现不使用 EMA。
动作填充和掩码调整
- 为动作反标记化添加了填充/截断,以确保稳定的解码,从而解决此讨论中提出的问题。
- 嵌入了动作损失掩码,而不是显式使用它们并将其传递给模型。
- 我们使用标记类型 ID 来区分不同的组件(前缀/后缀),确保训练和推理都使用正确的 4D 注意力掩码。
结果
- 所有这些修改使得 LIBERO 上的成功率达到了 40%,并使用了优化的超参数(现在是 `configuration_pi0fast.py` 中的默认值)。
问题和未决问题?
1. 输出动作标记不一致
- 将相同的输入传递给 JAX 和 PyTorch 实现并不总能产生相同的标记。
2. 训练稳定性和成功率
- 为了简化,在没有 EMA 的情况下进行训练——EMA 对于微调有多重要?
- JAX 的 Pi0 不关注填充图像,而 JAX 的 Pi0+FAST 关注。这种设计选择背后的原理是什么?
- 保持特定的图像顺序(例如,外部、左手腕、右手腕)可能会影响性能。是否应该保留此顺序?
- Pi0 使用 块因果掩码,而 Pi0-FAST 对提示使用 双向注意力。
- JAX 模型使用 分位数归一化,而 PyTorch 使用 均值/标准差。分位数归一化对于更好的性能是否必要?
3. 调试微调不稳定性
- 在 DROID 示例上测试推理显示 MSE 为 0.14(JAX 版本为 0.01)。
- 有几个生成的标记匹配,但不匹配可能归因于
- 训练不一致。
- 实现差异。
呼吁社区贡献
总而言之,尽管多次尝试和修复,SR 仍然低于预期/报告值
- 从 基础检查点 微调达到 30% SR。
- 从 LIBERO 检查点 微调在降级前达到 60% SR。
- 早期训练步骤(5k)产生最高的 SR,表明训练配方存在差异。
该模型理想情况下应达到 80% SR,正如原始论文中报告的那样。我们邀请贡献以完善和改进此实现。
资源
- 预训练基础检查点:Hugging Face: lerobot/pi0fast_base
- GitHub PR:支持 Pi0+FAST