将 Pi0-FAST 从 JAX 移植到 LeRobot 的 PyTorch 版本:挑战、修复和未决问题

社区文章 发布于 2025 年 4 月 2 日
Pi0+FAST is in LeRobot

引言

LeRobot 团队最近专注于将 Physical Intelligence 最初实现的 Pi0+FAST 移植到 Lerobot 仓库。本文概述了我们所做的主要修改、遇到的挑战以及与原始实现的关键差异。我们的目标是开启讨论,并为社区贡献提供足够的背景。

背景

论文 | Jax 代码 | 我们在 Lerobot 中的实现

π0-FAST 是 π0 的 自回归版本,引入了 FAST(频域动作序列标记化)——一种新的标记化方案,可提高效率和性能。

π0-FAST 的主要优势:

  • 与基于扩散的 VLA 相比,训练速度快 5 倍
  • 改进的动作表示,减少了动作序列中的冗余。
  • 在未见过的环境和机器人形态之间实现更强的泛化能力

🔗 π0-FAST 分词器 可在此处访问:FAST 分词器

🔗 预训练权重可在此处访问:Pytorch Pi0+FAST

PyTorch 实现中的主要修改

  1. 向量化填充和标记化
  • 与 Hugging Face 的 `transformers` 库中的分词器对齐。
  1. 使用了 Paligemma 内置的 `transformers` 实现
  • 主要区别在于使用了类似于 Pi0 的块因果掩码,但与 Paligemma 不同(Paligemma 对前缀使用完整的双向掩码)。
  1. 添加了自定义的 `prepare_inputs_for_generation`
  • 此项添加是为了正确处理注意力掩码、位置 ID 和其他输入处理细节。
  1. 添加了一些前缀调整
  • 与原始 Pi0+FAST 实现不同的是,原始实现会在输出序列中生成单词 `‘Action: ’`,我们则将 `"Action: "` 添加到前缀并在训练期间传递它。
  1. 没有使用指数移动平均 (EMA)。此实现不使用 EMA。

  2. 动作填充和掩码调整

  • 为动作反标记化添加了填充/截断,以确保稳定的解码,从而解决此讨论中提出的问题。
  • 嵌入了动作损失掩码,而不是显式使用它们并将其传递给模型。
  • 我们使用标记类型 ID 来区分不同的组件(前缀/后缀),确保训练和推理都使用正确的 4D 注意力掩码。

结果

  • 所有这些修改使得 LIBERO 上的成功率达到了 40%,并使用了优化的超参数(现在是 `configuration_pi0fast.py` 中的默认值)。

问题和未决问题?

1. 输出动作标记不一致

  • 将相同的输入传递给 JAX 和 PyTorch 实现并不总能产生相同的标记。

2. 训练稳定性和成功率

  • 为了简化,在没有 EMA 的情况下进行训练——EMA 对于微调有多重要?
  • JAX 的 Pi0 不关注填充图像,而 JAX 的 Pi0+FAST 关注。这种设计选择背后的原理是什么?
  • 保持特定的图像顺序(例如,外部、左手腕、右手腕)可能会影响性能。是否应该保留此顺序?
  • Pi0 使用 块因果掩码,而 Pi0-FAST 对提示使用 双向注意力
  • JAX 模型使用 分位数归一化,而 PyTorch 使用 均值/标准差。分位数归一化对于更好的性能是否必要?

3. 调试微调不稳定性

  • 在 DROID 示例上测试推理显示 MSE 为 0.14(JAX 版本为 0.01)。
  • 有几个生成的标记匹配,但不匹配可能归因于
    • 训练不一致。
    • 实现差异。

呼吁社区贡献

总而言之,尽管多次尝试和修复,SR 仍然低于预期/报告值

  • 基础检查点 微调达到 30% SR
  • LIBERO 检查点 微调在降级前达到 60% SR
  • 早期训练步骤(5k)产生最高的 SR,表明训练配方存在差异。

该模型理想情况下应达到 80% SR,正如原始论文中报告的那样。我们邀请贡献以完善和改进此实现。

资源

额外资源

社区

值得等待。LeRobot 团队加油!

我测试了推理速度,流匹配版本在迭代 50 次时大约需要 1.4 秒,而快速版本需要大约 30 秒……

感谢您的出色工作。这真的帮助了社区!

我想指出当前 Pi0FAST 实现中的一个问题。当前的实现通过调用 normalize_actions 函数将动作空间限制在 [-1,1]:https://github.com/huggingface/lerobot/blob/d2645cb19fc521e5b117fe03d90a84f698d3d3f6/src/lerobot/policies/pi0fast/modeling_pi0fast.py#L594

此归一化在推理期间没有被反归一化。此外,此函数也是不可逆的。它通过使用每个动作块中的最小和最大标量来归一化每个动作块。此信息对于不同的动作块将是不同的,并且在推理时不可用。

注册登录 进行评论