使用 WebAssembly 解释器反馈训练大型语言模型

社区文章 发布日期:2025年4月3日

使用 WebAssembly 和基于解释器的奖励,为代码训练 LLM 提供一种快速、本地且安全的方法

引言

我们很高兴分享一个用于安全且完全本地训练 Python 代码生成模型的开源工具,该工具基于我们最近在 axolotl 中支持组相对策略优化 (GRPO) 的工作。我们使自托管沙盒代码解释器环境变得容易,因此您可以零设置进行微调!

通过利用 WebAssembly,我们可以在隔离且资源受限的环境中执行不受信任的代码。在此基础上,我们实现了多进程,以最大程度地减少执行代码解释器的开销,从而实现强大且快速的训练过程。要立即开始,请查看我们的 grpo_code 存储库

image/png

使用代码解释器反馈进行训练

为了复制 DeepSeek R1 训练过程最近的成功,人们付出了巨大的努力。DeepSeek R1 放弃了典型的监督微调 (SFT) 阶段,转而采用强化学习,通过在模型输出可自动验证的领域(例如数学问题和编码任务)进行训练。虽然验证数学问题的正确性相对简单,但验证代码的正确性并非易事,需要在一个安全受控的环境中执行代码。

GRPO 的有效性取决于完善的奖励函数,这些函数能够可靠地引导模型朝着我们期望的行为发展。为了在代码解释器的输出上训练模型,我们需要将解释器的输出转换为某种形式的成功标准。幸运的是,现有的一些编码数据集提供了可以在运行时验证代码行为的测试用例。

让我们以 TIGER-Lab/AceCode-87K 为例。该数据集包含一个用 Python 实现函数或类的问题陈述,以及用于验证此功能行为的可执行断言。

{
    "question": "Given a list of strings, implement a function `find_palindromes` that returns a new list containing only the strings that are palindromes. A palindrome is defined as a word that reads the same forward and backward, ignoring case and spaces. For example, if the input list is `['radar', 'hello', 'level', 'world', 'Anna']`, the function should return `['radar', 'level', 'Anna']`. Your function should handle empty strings and variations in case appropriately.",
    "test_cases": [
        "assert find_palindromes(['radar', 'hello', 'level', 'world', 'Anna']) == ['radar', 'level', 'Anna']",
        "assert find_palindromes(['racecar', 'civic', 'deified', 'hello', 'world']) == ['racecar', 'civic', 'deified']",
        "assert find_palindromes(['noon', 'test', 'rotor', 'python']) == ['noon', 'rotor']",
        "assert find_palindromes(['']) == ['']",
        "assert find_palindromes(['Able was I ere I saw Elba', 'Hello']) == ['Able was I ere I saw Elba']",
        "assert find_palindromes(['12321', '12345', '121']) == ['12321', '121']",
        "assert find_palindromes(['Madam', 'Hello World', 'noon']) == ['Madam', 'noon']",
        "assert find_palindromes(['123321', 'abcba', 'xyz']) == ['123321', 'abcba']",
        "assert find_palindromes(['']) == ['']",
        "assert find_palindromes(['level', 'Level', 'LEVEL']) == ['level', 'Level', 'LEVEL']",
        "assert find_palindromes(['abccba', 'abc', 'a']) == ['abccba', 'a']",
        "assert find_palindromes(['step on no pets', 'sample text']) == ['step on no pets']",
        "assert find_palindromes(['no lemon, no melon']) == ['no lemon, no melon']",
        "assert find_palindromes(['racecar', 'not a palindrome']) == ['racecar']",
        "assert find_palindromes(['']) == ['']",
        "assert find_palindromes(['Wow', 'wow', 'WOW']) == ['Wow', 'wow', 'WOW']",
        "assert find_palindromes(['abc', 'def', 'ghi']) == []"
    ]
}

太棒了!只要我们有一种方法可以在训练期间执行模型预测的代码以及相关测试用例列表中的每个断言,我们就可以提供一个奖励信号来优化我们的模型,以最大化测试用例的准确性。然而,我们不希望随意执行 LLM 生成的代码——虽然模型在像 AceCode-87K 这样的竞技编程风格数据集上训练时不太可能生成彻头彻尾的恶意代码,但我们确实希望减轻执行不受信任的代码可能产生的任何意外副作用,这些副作用可能会干扰训练环境或消耗大量系统资源。

现有解决方案可以沙盒化不受信任的代码。像 E2B 这样的云提供商支持在安全的云环境中执行不受信任的代码,但有限使用后需要付费订阅;而容器化解决方案(如 piston)和内核级沙盒(如 isolate)在自托管和配置方面并非易事。

WebAssembly (Wasm) 是一种二进制指令格式,旨在作为编程语言的可移植编译目标。Wasm 代码在一个安全、沙盒化的虚拟环境中运行,该环境与宿主系统隔离,并具有明确定义的资源限制(我们称之为“燃料”)。得益于 VMware Labs 预编译的 Python 3.12.0 Wasm 运行时二进制文件,我们可以在本地训练环境中安全地执行 Python 代码,并且设置和开销极小。我们的 Wasm 运行时实现改编自 Simon Wilson 的博客文章,该文章介绍了如何在 Wasm 沙盒中执行 Python。

奖励函数

让我们通过定义利用我们安全 Wasm 运行时的奖励函数来将所有这些结合起来。

  • grpo_code.code_execution_reward_func 为可以成功执行且没有错误的代码完成提供奖励信号。

  • grpo_code.answer_execution_reward_func 根据代码完成对所提供测试用例的准确性提供奖励信号。我们不只是简单地使用通过的测试用例的百分比,而是应用一个 2 * (准确性)^3 的幂律,以对准确性越来越高的代码完成提供更高的奖励。

  • grpo_code.soft_format_reward_func 强制对代码完成施加格式约束——这对于从模型的输出中正确提取预测代码是必要的。

image/png

多进程

利用多进程异步执行环境更新是强化学习中常见的范式,可以最大程度地减少 GPU 空闲时间。我们提供了一个简单的多进程实现,它使用可重用的进程工作池来异步执行 Wasm 代码——在我们的基准测试中,我们发现这可以使代码执行时间加快近 10 倍。您可以通过设置 MAX_WORKERS 环境变量来配置进程数量——我们建议使用的 worker 数量少于机器上的可用核心数量,并根据您的批处理大小、生成次数和世界大小来扩展 worker。

训练

为了方便大家开始,我们来看看在 AceCode-87K 数据集上训练模型的过程——你可以在这里找到我们训练好的模型

安装

git clone https://github.com/axolotl-ai-cloud/grpo_code
cd grpo_code
pip install -e .
pip install axolotl==0.8.0[vllm]

训练

以下环境变量可用于修改训练期间奖励函数的行为。

WASM_FUEL - 控制分配给 Wasm 环境的燃料量(计算资源)(默认值:10000000000)。WebAssembly 中的燃料机制会计算已执行操作的数量,如果超出配置限制,则返回给调用者。

WASM_PATH - Python Wasm 运行时文件的路径(默认值:“./wasm/python-3.12.0.wasm”)

TIMEOUT - 代码评估的最大执行时间(秒)(默认值:1)

MAX_WORKERS - 多进程奖励函数的并行 worker 数量。默认情况下,多进程是禁用的。当在多个 GPU 上进行训练时,MAX_WORKERS 会除以总世界大小。

我们使用 4 块 A100 GPU 训练了我们的模型约 16 小时。首先,使用我们的四块 GPU 中的两块来设置 vLLM 服务器——通过设置 CUDA_VISIBLE_DEVICES 确保您使用的是最后一块 GPU。

CUDA_VISIBLE_DEVICES=2,3 axolotl vllm-serve r1_acecode.yaml

然后,在另一个终端中运行训练脚本

CUDA_VISIBLE_DEVICES=0,1 MAX_WORKERS=64 axolotl train r1_acecode.yaml --num-processes 2

您应该会看到以下训练图表,显示所有奖励函数的稳定收敛,最终测试用例准确率奖励约为 85%。

image/png

评估

我们已经在我们的代码库的 eval_plus 文件夹中包含了一个示例评估脚本,该脚本改编自 Qwen2.5-Coder 存储库,并更新了依赖项以更好地重现结果。

要运行评估,请首先安装所需的包。在全新的 Python 环境中,运行

cd eval_plus
pip install evalplus --upgrade
pip install -r requirements.txt

然后,运行评估脚本:,确保您的文件系统上有所需评估模型的权重。

bash test.sh {path_to_your_local_model_checkpoint} {tensor_parallel_size} {output_dir}

您应该会在您模型的 output_dir 文件夹中看到 HumanEval 和 MBPP 的评估结果。我们在下面列出了我们训练模型的测试结果,我们将其与 Qwen2.5-Coder-3B-Instruct 以及我们训练运行的基础模型 Qwen2.5-3B-Instruct 进行了基准测试

image/png

我们使用 Qwen2.5-3B-Instruct 和 Qwen2.5-Coder-3B-Instruct 作为评估的基线。Qwen2.5-3B-Instruct 被用作我们微调的基础模型,有助于我们衡量使用解释器反馈进行微调的直接影响。

Qwen2.5-Coder-3B-Instruct 使用 Qwen2.5-3B(非指令)作为基础模型,并经过大量基于代码的持续预训练(CPT)、指令监督微调(SFT)和使用代码执行反馈的直接偏好优化。这些训练步骤使用了大量数据:5.2 万亿个代码专用持续预训练令牌,以及数千万个指令样本。与 Qwen2.5-Coder-3B-Instruct 进行评估有助于我们了解使用 CPT、SFT 和对齐微调进行大量训练与使用相对少量基于强化学习的代码解释器微调的相对性能提升。

我们观察到,在 Mostly Basic Python Programming (MBPP) 和 MBPP+ 基准测试中,我们的模型相对于基础模型 Qwen2.5-3B-Instruct 和 Qwen2.5-Coder-3B-Instruct 有显著改进,这些基准测试包含仅使用标准库函数即可解决的入门级编程问题。我们还发现,在 Human EvalHuman Eval+ 基准测试中,我们的模型相对于基础模型也有所改进。我们认为这些结果显示了使用强化学习和代码解释器反馈进行微调,以生产高性能领域特定模型的巨大潜力。

后续步骤

我们引入了一个轻量级且可扩展的框架,用于通过解释器反馈训练代码生成模型。我们认为这是训练更可靠、更健壮的代码生成模型的一个有前途的方向,并且很高兴看到社区将在此工作基础上进行构建。这可能包括支持更具挑战性的多编码任务数据集、跨语言支持或跨数学和代码数据集的多任务训练。

我们最近还在 axolotl 中添加了对序列并行的支持——我们认为这可以利用 GRPO 解锁长上下文微调,以解决需要更长推理轨迹的复杂编码任务。

社区

非常棒的实现——干得好!您的沙盒是否支持除 Python 之外的其他语言?另外,您是否有关于可以同时执行多少个程序的基准数据?

·

谢谢 Lewis!

目前我们只支持 Python。然而,许多其他语言都有 WebAssembly 运行时,所以我们认为这种方法也适用于其他语言。如果社区对此感兴趣,我肯定可以添加支持。

有关基准测试,请参见此脚本 https://gist.github.com/SalmanMohammadi/ea25d12340851ce033885564dc4bc720
安装 grpo_code 仓库后,您应该可以使用 `MAX_WORKERS={max_workers} test_grpo_mp.py` 运行此脚本。

作为参考,我的笔记本电脑只有 10 个微薄的处理器核心,我得到的结果是

image.png

处理速度的提升与 `num_generations * batch_size * n_test_cases` 成比例——对于这篇博客文章,我们发现在高计算场景中,当处理器数量足够时,`MAX_WORKERS=64` 可以提供约 10 倍的速度提升。

注册登录 发表评论