🐯 Liger GRPO meets TRL

Published May 25, 2025
Update on GitHub

TL;DR: Liger supercharges TRL's Group Relative Policy Optimization (GRPO) Trainer, cutting memory usage by 40% with no regression in model quality. We've also added support for FSDP and PEFT, making it easier than ever to scale GRPO across multiple GPUs.

Motivation

Fine-tuning language models with reinforcement learning (RL) is a key step in the model training lifecycle: it steers models toward more sophisticated, desirable behaviors that conventional supervised fine-tuning struggles to achieve. Traditionally, RL has been used to optimize large language models (LLMs) with the Proximal Policy Optimization (PPO) algorithm. This approach, usually associated with Reinforcement Learning from Human Feedback (RLHF), relies on a separately trained reward model to guide fine-tuning of the main model.

However, RLHF with PPO is very resource-intensive: PPO needs several models in memory at once (the policy, value, reward, and reference models), and it also requires multiple rounds of fine-tuning the reward model and the base model to reach the desired results. RLHF's success further hinges on the reward model's ability to reliably distinguish desired from undesired behavior.

Group Relative Policy Optimization (GRPO) has recently drawn a lot of attention with the release of DeepSeek's R1 model. GRPO drops the pretrained reward and value models used in RLHF, relying instead on *verifiable reward functions*: functions that can check the correctness of a model's output in closed form, with no external reward model. This has yielded large improvements over PPO when fine-tuning with GRPO in domains that are easy to verify, such as teaching models to reason and to perform well on math and coding tasks.
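To make this concrete, here is a minimal sketch of what a verifiable reward function can look like, written against TRL's reward-function interface. The function name and the assumed ground-truth answer column are ours, for illustration only:

# Hypothetical verifiable reward: score a completion 1.0 if it contains the
# known ground-truth answer, else 0.0. No reward model is involved.
# TRL forwards extra dataset columns (here an assumed `answer` column)
# to reward functions as keyword arguments.
def math_correctness_reward(completions, answer, **kwargs):
    return [
        1.0 if str(gt) in completion else 0.0
        for completion, gt in zip(completions, answer)
    ]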

The figure below compares the training flows of GRPO and PPO (source: Figure 4 of the paper DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models).

PPO-vs-GRPO

That said, RL training still consumes a lot of GPU memory, so there is plenty of room for optimization. In this post, we discuss an optimization we recently added to TRL that cuts peak memory usage by 40% during GRPO training, and we dig into scaling GRPO across multiple GPUs and nodes without sacrificing performance or correctness.

How Liger Kernel Slashes Memory for GRPO

We extended Liger's chunked-loss approach to the GRPO loss computation, which lets us avoid storing the full logits in memory at every training step. Computing the logits, which involves the model's output head, is the main source of peak memory usage, especially with large vocabularies, long sequences, or large batches. We address this by chunking the input to the lm_head along the batch dimension and running the forward pass one chunk at a time.

But if you implement this naively, it doesn't actually reduce peak memory, because you still need to keep all the logits in GPU memory for the backward pass. To fix that, we compute the gradients of each loss chunk (with respect to the input chunk and the lm_head weights) during the forward pass itself, and accumulate those gradients as each chunk is processed.
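In code, the idea looks roughly like the sketch below. This is a simplified illustration of the chunking-plus-in-forward-gradient strategy, not the actual Liger kernel (which fuses these steps in optimized Triton code); all names here are ours:

import torch

def chunked_grpo_style_loss(hidden, lm_head_weight, targets, loss_fn, chunk_size=1):
    """Simplified sketch of a chunked loss with in-forward gradient accumulation.
    Full logits are never materialized: each chunk's logits live only long
    enough to compute that chunk's loss and gradients."""
    total_loss = torch.zeros((), device=hidden.device)
    grad_hidden = torch.zeros_like(hidden)          # grad w.r.t. lm_head input
    grad_weight = torch.zeros_like(lm_head_weight)  # grad w.r.t. lm_head weight

    def chunk_loss(h_chunk, weight, t_chunk):
        logits = h_chunk @ weight.T  # logits exist for this chunk only
        return loss_fn(logits, t_chunk)

    for i in range(0, hidden.shape[0], chunk_size):
        h, t = hidden[i:i + chunk_size], targets[i:i + chunk_size]
        # Take this chunk's gradients immediately, during the forward pass,
        # so its logits can be freed before the next chunk is processed.
        (g_h, g_w), loss = torch.func.grad_and_value(
            chunk_loss, argnums=(0, 1)
        )(h, lm_head_weight, t)
        grad_hidden[i:i + chunk_size] = g_h
        grad_weight += g_w          # accumulate lm_head grads across chunks
        total_loss += loss.detach()

    return total_loss, grad_hidden, grad_weight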

Here's a visualization of the optimization (source: Byron Hsu).

liger-chunked-loss

Plug-and-Play Integration with TRL

We recently integrated Liger GRPO with TRL in PR #3184, so you can now just set use_liger_loss to True in your GRPOConfig to use the Liger GRPO loss and enjoy the memory savings!

Note: these features aren't in the latest TRL release yet, so for now you'll need to install TRL from source:

pip install "trl[liger] @ git+https://github.com/huggingface/trl.git"

Then you can use it like this:

from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset


train_dataset = load_dataset("trl-lib/tldr", split="train")
# use_liger_loss=True switches on the memory-efficient Liger GRPO loss
training_args = GRPOConfig(output_dir="Qwen3-0.6B-GRPO", use_liger_loss=True)

# Toy reward: completions closer to 20 characters score higher
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()

Benchmarks

We ran a series of GRPO experiments with and without the Liger GRPO loss to compare the two. For the policy model we used Qwen3-0.6B and experimented with different batch sizes. All experiments were run on the gsm8k dataset with its reward functions.

Here's a plot of peak memory usage versus batch size for FP32 and BF16 training. As expected, the memory savings grow with batch size, since we chunk along the batch dimension. So as the batch size increases, the Liger chunked loss ends up using much less memory than the regular (non-Liger) version, with savings of up to 40%.

Quick note: we currently only support FP32, but we're working on upstreaming BF16 support for Liger GRPO into TRL. The BF16 results shown here come from an internal patch we've been testing.

Mem-vs-batch-size-fp32

Mem-vs-batch-size-bf16

We also verified that the Liger loss is exact in its effect: as shown below, the reward curve during training closely tracks what we observe with the standard TRL implementation.

reward-vs-step

Scaling Further with FSDP and PEFT

We also added FSDP and PEFT support for the Liger GRPO loss in PR #3260 and PR #3355 respectively, so users can easily scale their experiments across multiple GPUs or nodes. PEFT techniques like LoRA and QLoRA reduce the number of trainable parameters by tuning only small adapter weights on top of the original model, which significantly lowers memory pressure since gradients, activations, and optimizer states for the full model no longer need to be kept in memory. On top of that, using PEFT with GRPO removes the need to load a separate reference model during training: simply disabling the LoRA adapters recovers the original, unmodified model. A sketch of what this looks like in code follows below.
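As a rough sketch (the LoRA hyperparameters here are illustrative, not tuned recommendations), combining the Liger loss with LoRA via TRL's peft_config looks like this:

from datasets import load_dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

train_dataset = load_dataset("trl-lib/tldr", split="train")

def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = GRPOConfig(output_dir="Qwen3-0.6B-GRPO-LoRA", use_liger_loss=True)
trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=train_dataset,
    # With a PEFT config, no separate reference model is loaded: disabling
    # the adapters recovers the frozen base policy.
    peft_config=LoraConfig(task_type="CAUSAL_LM", r=16, lora_alpha=32),
)
trainer.train()

To scale the same script across GPUs with FSDP, you would launch it through Accelerate with an FSDP configuration, e.g. accelerate launch --config_file fsdp_config.yaml train.py.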

Here we show a plot of multi-GPU GRPO training with FSDP and PEFT, comparing the maximum feasible training batch size with and without the Liger loss across different Qwen3 model sizes. We found that with Liger we could push batch sizes roughly 1.5x to 1.8x higher.

peft-batch-size-vs-model-size

Scaling Even Further with vLLM

To speed up text generation during training, the Liger loss combines effectively with TRL's integrated vLLM server. This significantly accelerates rollout collection with minimal overhead and provides a seamless integration experience.

Here's how to set it up:

  1. Start the vLLM server: First, launch the vLLM server. It will handle generation requests from your training script. Open a terminal and run:

    CUDA_VISIBLE_DEVICES=1 trl vllm-serve --model "Qwen/Qwen3-0.6B"
    

    Note: we set CUDA_VISIBLE_DEVICES=1 to run the vLLM server on a specific GPU (GPU 1 in this example), leaving the other GPUs free for training.

  2. Configure and run your training script: Next, modify your training script to use the vLLM server. The key change is setting use_vllm=True in your GRPOConfig:

    from trl import GRPOConfig, GRPOTrainer
    from datasets import load_dataset
    
    
    def reward_len(completions, **kwargs):
        return [-abs(20 - len(completion)) for completion in completions]
    
    dataset = load_dataset("trl-lib/tldr", split="train[:1%]")
    training_args = GRPOConfig(
        output_dir="Qwen3-0.6B-GRPO", 
        use_liger_loss=True, 
        use_vllm=True, # Enable vLLM integration
        logging_steps=10
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen3-0.6B", # Ensure this matches the model served by vLLM
        reward_funcs=reward_len,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()
    
  3. Launch training: Finally, run your training script with accelerate launch (or plain python if you're not using Accelerate for multi-GPU/distributed training). If your vLLM server occupies one GPU, make sure to point training at a different GPU.

    CUDA_VISIBLE_DEVICES=0 accelerate launch train.py 
    

    (This assumes your script is named train.py and you want to run training on GPU 0.)

By following these steps, you can leverage vLLM for faster generation turnaround while training with the Liger loss for GRPO.

Conclusion

With the Liger-GRPO integration in TRL, along with FSDP and PEFT support, fine-tuning language models with GRPO is now more memory-efficient and scalable than ever. We encourage the community to try out these new features and share feedback to help us further improve RL training for LLMs.

Community

Does the Liger kernel affect training speed? Is it faster, slower, or about the same compared to regular GRPO?


It usually depends on the setup; there may or may not be a speedup!

Article author

In our experiments, we observed no significant difference in training speed with and without Liger.


Sounds perfect!

Thanks for the great work.

By the way, I tested the liger loss with Qwen/Qwen2.5-0.5B-Instruct in bf16 mode, combined with deepspeed zero3.
I ran into the shape-mismatch issue shown below:


[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/temp.py", line 22, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2238, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2553, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3730, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 87, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1187, in compute_loss
[rank0]:     return self.compute_liger_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1160, in compute_liger_loss
[rank0]:     loss, metrics = self.liger_grpo_loss(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/grpo_loss.py", line 249, in forward
[rank0]:     return LigerFusedLinearGRPOFunction.apply(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/grpo_loss.py", line 142, in forward
[rank0]:     return super().forward(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/fused_linear_ppo.py", line 219, in forward
[rank0]:     accumulate_chunk(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/liger_kernel/chunked_loss/fused_linear_ppo.py", line 132, in accumulate_chunk
[rank0]:     (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
[rank0]:                                                                                            ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 1164, in __call__
[rank0]:     result = self._inner_convert(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
[rank0]:     return _compile(
[rank0]:            ^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_utils_internal.py", line 95, in wrapper_function
[rank0]:     return function(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
[rank0]:     transformations(instructions, code_options)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/convert_frame.py", line 662, in transform
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
[rank0]:     super().run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1736, in CALL_FUNCTION_EX
[rank0]:     self.call_function(fn, argsvars.items, kwargsvars)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 858, in call_function
[rank0]:     return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/misc.py", line 1022, in call_function
[rank0]:     return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/misc.py", line 778, in call_method
[rank0]:     .call_function(tx, args, kwargs)
[rank0]:      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 317, in call_function
[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/functions.py", line 118, in call_function
[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 903, in inline_user_function_return
[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3072, in inline_call
[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 3198, in inline_call_
[rank0]:     tracer.run()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
[rank0]:     while self.step():
[rank0]:           ^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 659, in wrapper
[rank0]:     return inner_fn(self, inst)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2341, in CALL
[rank0]:     self._call(inst)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 2335, in _call
[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/symbolic_convert.py", line 897, in call_function
[rank0]:     self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/torch.py", line 953, in call_function
[rank0]:     tensor_variable = wrap_fx_proxy(
[rank0]:                       ^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2153, in wrap_fx_proxy
[rank0]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2219, in wrap_fx_proxy_cls
[rank0]:     return _wrap_fx_proxy(
[rank0]:            ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/variables/builder.py", line 2315, in _wrap_fx_proxy
[rank0]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2536, in get_fake_value
[rank0]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2471, in get_fake_value
[rank0]:     ret_val = wrap_fake_exception(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2017, in wrap_fake_exception
[rank0]:     return fn()
[rank0]:            ^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2472, in <lambda>
[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2604, in run_node
[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/utils.py", line 2586, in run_node
[rank0]:     return node.target(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 289, in _fn
[rank0]:     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 4444, in matmul
[rank0]:     return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)
[rank0]:                                        ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/utils/_stats.py", line 21, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1276, in __torch_dispatch__
[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1816, in dispatch
[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 1377, in _cached_dispatch_impl
[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_subclasses/fake_tensor.py", line 2290, in _dispatch_impl
[rank0]:     decomposition_table[func](*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_prims_common/wrappers.py", line 291, in _fn
[rank0]:     result = fn(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 83, in inner
[rank0]:     r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_decomp/decompositions.py", line 4336, in mv
[rank0]:     torch._check(
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1656, in _check
[rank0]:     _check_with(RuntimeError, cond, message)
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/__init__.py", line 1638, in _check_with
[rank0]:     raise error_type(message_evaluated)
[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in method matmul of type object at 0x7f2e2a41ff00>(*(GradTrackingTensor(lvl=1, value=
[rank0]:     FakeTensor(..., device='cuda:0', size=(1, s0, 896), dtype=torch.bfloat16,
[rank0]:                requires_grad=True)
[rank0]: ), GradTrackingTensor(lvl=1, value=
[rank0]:     FakeTensor(..., device='cuda:0', size=(0,), dtype=torch.bfloat16,
[rank0]:                requires_grad=True)
[rank0]: )), **{}):
[rank0]: size mismatch, got input (s0x896), vec (0)

Does Liger GRPO support multi-GPU training with deepspeed zero3?
