不让任何 GPU 掉队:TRL 中 vLLM 协同部署以释放效率

发布日期:2025 年 6 月 3 日
在 GitHub 上更新

🚀 引言

TRL 支持使用 GRPO 训练大型语言模型(LLM),GRPO 是一种最近在《DeepSeekMath 论文》中引入的在线学习算法。在 GRPO 中,模型从其自身输出中学习:它在训练期间生成响应,接收反馈,并利用该反馈随时间推移改进自身。

这使得生成成为训练循环中的关键步骤,也是一个主要的瓶颈。为了加快生成速度,TRL 与 vLLM 集成。这种组合允许您在 GRPO 设置中更高效地训练强大的模型。然而,这里有一个问题。

🧨 问题所在

在 TRL v0.18.0 之前,vLLM 仅支持**服务器模式**,作为独立进程在与训练作业不同的 GPU 上运行。它通过 HTTP 与训练脚本通信,这使得设置模块化且易于使用——但也引入了 GPU 效率低下问题。

这是发生的情况:

  • 在训练期间,模型需要频繁生成补全。
  • 训练器向 vLLM 服务器发送请求,该服务器运行在自己的 GPU 上。
  • 当 vLLM 生成时,**训练 GPU 处于空闲状态**并等待。
  • 一旦生成完成,**vLLM GPU 变为空闲状态**,训练恢复。

训练和生成之间的这种“乒乓效应”导致:

  • 双方 GPU 时间浪费
  • **额外 GPU** 需求增加,仅用于运行推理
  • 整体**吞吐量降低,成本更高**

在像 GRPO 这样的在线学习方法中——生成持续发生——这种低效率变得更加令人痛苦。您在硬件上花费更多,但却无法获得预期的性能。

**因此,关键问题是:**_我们能否将训练和生成共享同一批 GPU,而不是将它们分开?_

💡 机遇

主要问题在于训练和推理在不同的 GPU 上运行,导致空闲时间和资源利用不足。自然而然的解决方案是:两者都在同一批 GPU 上运行。如果 vLLM 可以与训练代码一起运行,在同一个分布式进程组中,而不是作为独立服务器在自己的进程和设备中运行呢?这将允许我们启动一个单一的分布式作业,其中训练和推理共享相同的设备,在任务之间高效切换而不会浪费资源。

这种方法我们称之为**协同部署**。训练和推理协同部署在同一批 GPU 上,并通过相同的进程组进行协调,允许它们平稳地轮流执行——无需额外的硬件。

以前,这在 TRL 中是不可能的,它依赖于 vLLM 作为外部 HTTP 服务器。通过我们的 PR #3394,这种情况发生了改变,该 PR 添加了对 vLLM 外部启动器和与训练过程的真正集成的支持。

它能实现什么

  • 统一执行:通过将 vLLM 嵌入到同一个进程组中,训练和推理任务可以共享相同的 GPU,轮流执行而不是相互等待。这减少了空闲时间并提高了整体效率。

  • 跳过 HTTP 通信:无需 REST API 调用或网络通信——vLLM 与训练循环内联运行,避免了开销和延迟。

  • Torchrun 兼容性:与 torchrun 无缝协作,因此易于通过最少的配置更改进行跨节点扩展。

  • TP 和 DP 支持:与张量并行 (Tensor Parallelism) 和数据并行 (Data Parallelism) 兼容,使其适用于大规模训练运行。

  • SPMD 执行模式:使用单程序多数据 (SPMD) 模型,其中每个 GPU 同步运行其自身的引擎实例。适用于分布式多 GPU、多节点设置。

  • 简化部署:您不再需要维护单独的服务器脚本——vLLM 直接在您的训练作业中启动和控制。

  • 提高吞吐量:通过避免 GPU 空闲和消除进程间通信,系统提供更快的训练和生成速度,这在 GRPO 等在线学习设置中尤为重要。

  • 健壮的进程间通信:这更健壮,因为它避免了像服务器模式中那样在独立进程之间设置分布式进程组的复杂性。

得益于此功能,协同训练和推理不再是权宜之计——它现在是**一流的、可扩展的、生产就绪的**。

🧩 设计:从独立服务器到共享 GPU

从服务器 TRL 到协同部署 TRL 的转变完全是为了更智能地利用 GPU。下图显示了差异:

gpus-design

服务器 TRL 设置(上排)

在服务器 TRL 设置中,训练和推理在不同的 GPU 上运行。例如:

  • GPU 0 到 2 用于训练。
  • GPU 3 完全用于运行 vLLM 作为独立服务器。

在训练步骤中,**GPU 3 处于空闲状态**。在生成步骤(推理)中,当 GPU 3 生成输出时,**GPU 0-2 处于空闲状态**。

这导致:

  • GPU 使用效率低下,设备经常相互等待
  • 额外配置 GPU 仅用于推理
  • 增加成本和复杂性

协同部署 TRL 设置(下排)

相反,协同部署 TRL 设置在**相同的 GPU** 上运行训练和 vLLM。每个 GPU:

  • 运行训练循环
  • 在**同一个进程**中启动 vLLM 引擎

训练和推理**轮流**使用 GPU 的资源——无需专用设备或独立进程。

此设计:

  • 减少空闲时间
  • 最小化进程间和 HTTP 通信
  • 充分利用可用的 GPU 内存和计算资源
  • 在不增加硬件需求的情况下提供**更快的吞吐量**

🛠️ 实施说明

现在,训练器启动 vLLM **进程内**,使用外部启动器,而不是将 vLLM 作为服务器启动,如下图所示:

self.llm = LLM(
    model=model.name_or_path,
    tensor_parallel_size=args.vllm_tensor_parallel_size,
    gpu_memory_utilization=self.vllm_gpu_memory_utilization,
    max_num_seqs=self.args.per_device_train_batch_size
        * self.vllm_tensor_parallel_size
        * self.args.gradient_accumulation_steps,
    max_model_len=self.max_prompt_length + self.max_completion_length,
    distributed_executor_backend="external_launcher",
    # Feed identical seed for tp groups to ensure sampling results are the same across workers
    seed=self.accelerator.process_index // self.vllm_tensor_parallel_size,
)

协同部署的 vLLM 遵循 `torch.distributed` 进程组和秩结构。这使得 vLLM 可以在训练的同时初始化而不会发生冲突,并使 TP/DP 设置无缝运行。

if self.vllm_tensor_parallel_size > 1:
    # Create subgroups of ranks for TP, each group with `vllm_tensor_parallel_size` ranks.
    self.tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
        [
            list(range(i * self.vllm_tensor_parallel_size, (i + 1) * self.vllm_tensor_parallel_size))
            for i in range(self.accelerator.num_processes // self.vllm_tensor_parallel_size)
        ]
    )

协同部署的 vLLM 不再依赖 REST API——它直接在内存中运行并通过原生 Python 调用进行通信。

if self.vllm_tensor_parallel_size > 1:
    orig_size = len(prompts_text)
    gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)]
    torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group)
    all_prompts_text = [p for sublist in gathered_prompts for p in sublist]
else:
    all_prompts_text = prompts_text

with profiling_context(self, "vLLM.generate"):
    all_outputs = self.llm.generate(all_prompts_text, sampling_params=sampling_params, use_tqdm=False)

completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]

if self.vllm_tensor_parallel_size > 1:
    local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
    tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
    completion_ids = completion_ids[tp_slice]

要使用此设置,只需在 GRPO 配置中将 `vllm_mode="colocate"`。

training_args = GRPOConfig(
    ...,
    use_vllm=True,
    vllm_mode="colocate",
)

注意:根据模型大小和训练所需的总 GPU 内存,您可能需要调整 `GRPOConfig` 中的 `vllm_gpu_memory_utilization` 参数,以避免资源利用不足或内存不足错误。

📊 展示:协同部署与普通 TRL 性能对比

为了衡量协同部署的影响,我们进行了一系列实验,比较了传统的**服务器模式**(vLLM 作为独立服务器在单独的 GPU 上运行)与新的**协同部署模式**(训练和推理共享相同的 GPU)。

在**服务器模式**下,仅使用 7 个 GPU 进行训练,因为 1 个 GPU 完全专用于 vLLM 推理服务器。

在**协同部署模式**下,所有 8 个 GPU 都用于训练——默认情况下增加了有效批量大小。

为了确保公平比较,我们**将服务器模式下的吞吐量标准化为 8/7**。此调整考虑了协同部署模式下更大的训练容量,并允许我们在相同的训练条件下比较两种设置。

实验 1:1.5B 模型 — 不同批量大小

  • 随着批量大小的增加,两种设置的吞吐量都有所改善。
  • **协同部署设置在最大批量大小下达到 1.43 倍加速。**
  • 更大的批量可以更好地利用协同部署模式下共享的 GPU 内存。small-b

实验 2:1.5B 模型 — 不同张量并行度 (TP)

  • 在协同部署设置中,增加 TP 会**降低性能**。
  • 更多的分片会引入更多的通信开销——这**不适合小型模型**。
  • **启示**:对于小型模型,在协同部署模式下避免过度分片。small-tp

实验 3:7B 模型 — 不同批量大小

  • 同样,协同部署模式**随着批量大小的增加而扩展性更好**。
  • 在测试的最大批量下,增益达到**1.35 倍加速**。med-b

实验 4:7B 模型 — 不同张量并行度 (TP)

  • 与 1.5B 模型相反的趋势。
  • 对于 7B 模型,**增加 TP 可以提高吞吐量**,最高可达**1.73 倍加速**。
  • **协同部署设置中,大型模型从分片中受益。** med-tp

📊 扩展到 72B 模型

在训练像 **Qwen2.5-Math-72B** 这样的大型模型时,采用正确的策略以确保在多 GPU 和多节点上实现高效、可扩展和稳定的训练至关重要。在我们的设置中,我们将**协同部署的 vLLM** 与多个关键优化相结合,以实现高效运行。

vLLM 中的休眠模式

在使用协同训练时,管理 GPU 内存至关重要,以便训练和推理都能在同一设备上平稳运行。为支持此功能,我们已将 vLLM 的 `sleep()` API 添加到 GRPO 训练循环中。

`sleep()` 函数暂时暂停 vLLM 引擎并释放 GPU 内存。它支持两个级别:

  • **级别 1**:从 GPU 卸载模型权重(保留在 CPU 内存中)并清除 KV 缓存。当同一模型即将被重复使用时很有用。

  • **级别 2**:完全卸载模型权重和 KV 缓存。最适合模型将更改或不会立即重复使用的情况。

在 GRPO 中,模型在每一步之后都会更新——因此我们使用**级别 2 休眠**。

级别 2 休眠的优势:

  • **最大化训练的空闲 GPU 内存**
  • **避免训练和生成之间的内存争用**
  • 即使对于像 Qwen2.5-72B 这样的大型模型,也能保持协同部署的效率

这个小小的改动在实现平稳、可扩展的协同训练方面发挥了**巨大作用**。

DeepSpeed 优化

为了训练像 Qwen2.5-72B 这样的大型模型,我们依赖于 **DeepSpeed ZeRO Stage 3**,这与普通 TRL 中使用的设置相同。

ZeRO 通过在 GPU 之间分配内存来帮助扩展大型模型。Stage 3 更进一步,通过分区:

  • 模型权重
  • 梯度
  • 优化器状态

这对于无法放入单个 GPU 的模型至关重要。使用 ZeRO Stage 3,每个 GPU 只处理模型的一部分。

我们启用的其他选项:

  • "offload_optimizer": {"device": "cpu"} 将优化器状态移动到 CPU 以释放 GPU 内存——这在协同部署设置中至关重要。

  • "overlap_comm": true 启用通信与计算重叠,加速训练。

  • "contiguous_gradients": true 在单个内存块中分配梯度,改善内存访问并减少碎片化。

这些优化有助于**高效训练 72B 模型**,并确保在严格的内存限制下协同部署保持稳定。

Accelerate 集成

正如 TRL 中推荐的那样,我们使用 **Accelerate**,一个轻量级库,它简化了分布式训练。它处理:

  • 多 GPU 和多节点作业启动
  • 数据并行
  • 梯度累积
  • 分布式数据加载

这使得设置简洁、可扩展且易于维护。

实验 5:Qwen2.5-Math-72B — 吞吐量、准确性和基准测试结果

吞吐量

即使**减少 4 个 GPU**,**协同部署设置仍比普通 TRL 快约 1.26 倍**。这突出了更智能的 GPU 共享和使用 `sleep()` 进行内存清理的有效性。72b-tput

奖励曲线

协同部署和普通设置的训练奖励图**几乎相同**,这表明:

  • 协同部署训练保持了准确性
  • **模型学习性能没有退步** blogpost_72b_rewards

Math500 基准测试

我们评估了三个模型:**基础模型**、**协同训练模型**和**普通训练模型**在 Math500 基准测试中的表现。两个训练模型都**优于基础模型**,并且**协同部署模型与普通训练模型表现相当**——证实了协同部署不会影响下游性能。blogpost_72b_math500

🎓 挑战、经验教训和后续步骤

通过我们利用协同部署 vLLM 扩展 GRPO 训练的工作,我们面临了几个关键挑战,并就大型模型训练的效率、灵活性和系统设计汲取了重要的经验教训。

挑战

  • vLLM ≥ 0.8.0 中的张量并行度 Bug。vLLM 0.8.0 及更高版本中的张量并行度 (TP) 与 external_launcher 停止工作。这在问题 #15895 中进行了跟踪。为了确定破坏点,我们遵循了这篇 vLLM 开发者博客文章中描述的方法,该文章提供了每个提交的轮子。经过二分法查找,我们确定破坏性提交为 cc10281。根本原因是确定性——新版本需要明确设置随机种子。一旦设置了种子,问题就消失了。

  • **二级休眠缓冲区 Bug。**最初,当我们尝试使用 `load_weights` 重新加载权重时,二级休眠无法正常工作。这个问题在 Issue #16564 中进行了跟踪。问题是模型缓冲区(例如 BatchNorm 中的运行均值/方差)在从休眠中唤醒后没有恢复。修复方法是 PR #16889,它添加了在从二级休眠唤醒时明确恢复缓冲区的逻辑。我们现在保留原始缓冲区的副本,并在加载新权重后手动重新应用它们。

  • **退出时发生段错误。**vLLM 休眠在训练结束时关闭进程时仍存在一个未解决的问题,会导致段错误。这在问题 #16993 中报告。此崩溃发生在关机期间,但不会中断训练本身,因此我们能够完成本博客中分享的所有演示和实验。但是,我们正在等待官方修复,然后才能将 sleep() 完全集成到 TRL 上游。

这些挑战并非阻碍,但它们需要仔细的调试、版本控制,以及对 vLLM 如何管理内存和并行性的更深入理解。

经验教训

  • 协同部署推理显著提高了 GPU 利用率。通过允许训练和生成共享相同的 GPU,我们消除了空闲时间并降低了硬件需求——即使使用更少的 GPU 也能实现更高的吞吐量。

  • vLLM 的 `sleep()` 功能对于大规模协同部署至关重要。它实现了对内存使用的细粒度控制,允许训练在生成步骤之间完全回收 GPU 内存——这是像 Qwen2.5-72B 这样的模型实现的关键。

  • DeepSpeed ZeRO Stage 3 对于训练大型模型至关重要。它通过在多个 GPU 上分配模型权重、梯度和优化器状态,使超大型网络能够适应内存。根据我们的经验,启用 `contiguous_gradients` 有助于减少内存碎片,而将优化器卸载到 CPU 则释放了关键的 GPU 内存——这两者在协同部署设置中都特别有用。

  • 协同部署功能强大,但也伴随着权衡。它在仔细管理 GPU 内存时效果最佳,通常需要手动调整内存使用参数,例如 `vllm_gpu_memory_utilization`。虽然它提供了明显的吞吐量优势并减少了 GPU 空闲时间,但协同部署可能不适合内存预算紧张或内存碎片控制不佳的模型。但是,如果做得好,它会带来显著的效率提升。

  • TP/DP 兼容性、Accelerate 和 torchrun 支持使部署无缝。尽管底层架构复杂,但整个系统可以使用标准分布式工具启动和扩展。

  • 协同训练保持模型质量。在多个基准测试(Math500、AIME24)中,协同部署和普通设置产生了可比较的结果,验证了性能不会因效率而牺牲。

✅ 结论

这篇博客文章探讨了将 vLLM 与 GRPO 训练协同部署如何在大语言模型训练(包括 Qwen2.5-72B 等大型模型)中实现显著的效率提升。

传统上,TRL 仅支持服务器模式下的 vLLM,这需要独立的推理进程和 GPU,导致计算浪费和空闲时间。随着 vLLM 外部启动器和 TRL 中协同部署 PR PR #3394 的引入,我们现在可以在同一分布式进程组、同一 GPU 上运行训练和推理,并完全支持 TP、DP 和 Accelerate。

尽管仍存在挑战——例如特定版本 vLLM bug 和 `sleep()` 等边缘情况——但总体结果表明,协同部署 GRPO 是高效训练大型模型的一种实用、可扩展的解决方案。我们很高兴继续完善此设置,集成 FSDP 等功能,并突破大型模型训练的极限——使其更快、更便宜、更易于所有人构建下一代 LLM。

✅ 试一试!

下面是一个尝试使用协同部署 vLLM 进行 GRPO 训练的示例。

📄 train_grpo_colocate.py

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

# Load dataset
dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

# Define training arguments
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    logging_steps=1,
    use_vllm=True,
    vllm_mode="colocate",
    vllm_tensor_parallel_size=1,
    vllm_gpu_memory_utilization=0.3,
    max_prompt_length=512,
    max_completion_length=1024,
    max_steps=2,
    num_generations=4,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    push_to_hub=False,
    report_to=None
)

# Create and run the trainer
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)

trainer.train()

社区

这项工作很棒,感谢详细的撰写。根据我们的经验,这种方法对于大规模多节点训练非常有效。我们已经看到训练 32B 模型时,训练速度提高了 3 倍。

·
文章作者

太棒了!感谢分享!

示例代码 `train_grpo_colocate.py` 需要使用 accelerate 启动吗?仅仅使用 `python3 train_grpo_colocate.py` 运行会抛出关于缺少环境变量("RANK", "LOCAL_RANK"...)的异常。

·
文章作者

是的!

`vllm_mode="colocate"` 能与 PEFT 配合使用吗?

·
文章作者

@lhkhiem28 实际上我们没有尝试过这个,但是它没有理由不工作,因为 LoRA 与模型训练相关,而我们的更改与生成相关。但是,似乎下面的 @ajinkya-tejankar 已经尝试过并且看起来是可行的。

很棒的文章!协同部署模式是否计划支持数据并行?

·
文章作者

DP是支持的。
例如,如果 GPU 数量 = 8 且 vllm_tensor_parallel_size = 2 → 组:[0,1], [2,3], [4,5], [6,7] -> 使 DP=4

DeepSpeed 是否计划成为未来支持 TRL 多 GPU 和多节点设置的主要引擎?我尝试了 FSDP,但它与许多 DeepSpeed 可用的配置不兼容。例如,我无法让 GRPO + FSDP + LoRA + VLLM colocate 协同工作,但将 FSDP 替换为 DeepSpeed 就可以。DeepSpeed 比 PyTorch 的普通 FSDP 更可靠吗?

附言:很棒的博客!非常感谢您的努力 :)

·
文章作者

@ajinkya-tejankar 在我们的内部实验中,我们尝试将 FSDP2 整合到 accelerate 中,并用 colocate 进行了测试。我认为仍然存在一些问题。1. TRL 的权重加载代码我认为只适用于 FSDP1。2. FSDP1 存在 NaN 问题,我之前提交了一个 bug 报告 https://github.com/vllm-project/vllm/issues/14443

请参阅之前的讨论
https://github.com/huggingface/trl/pull/3317#issuecomment-2842576427

非常感谢这篇精彩的文章。
您的文章对于在协同部署模式下训练 GRPO 帮助巨大。

顺便问一下,您是否曾使用 LoRA 训练过模型?
您提到训练了一个 72B 模型,但我无法访问 32 个 GPU,因此无法进行完全微调。

当使用 `DeepSpeed ZeRO-3` + `vLLM colocate` + `LoRA` + `GRPO` 的组合训练模型,并在 LoRA 配置中配置 `modules_to_save=["embed_tokens", "lm_head"]`(如下所示)时,我遇到了底部的错误。
如果您有任何用于训练 72B 模型的解决方案或技巧,我将不胜感激。

我使用的库版本是:

trl==0.18.2
peft==0.15.2
transformers==4.52.4
deepspeed==0.17.1 

LoRA 配置

lora_config = LoraConfig(
    r=training_config["rank"],
    lora_alpha=training_config["alpha"],
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj", 
        "up_proj",
        "down_proj",
    ],
    lora_dropout=training_config["dropout"],
    bias="none",
    task_type="CAUSAL_LM",
    modules_to_save=["embed_tokens", "lm_head"],
)

错误

AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'

完整的错误日志如下:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/LLMTrainFlow/./src/train/rl_gemma3.py", line 180, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2240, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2555, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3745, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1330, in compute_loss
[rank0]:     return self._compute_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1340, in _compute_loss
[rank0]:     per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 852, in _get_per_token_logps
[rank0]:     logits = model(
[rank0]:              ^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/engine.py", line 2087, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1757, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3/modeling_gemma3.py", line 880, in forward
[rank0]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1782, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:                   ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 378, in _post_backward_module_hook
[rank0]:     return apply_to_tensors_only(module.post_bwd_fn.apply,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 133, in apply_to_tensors_only
[rank0]:     touched_output = apply_to_tensors_only(function, elem)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 149, in apply_to_tensors_only
[rank0]:     touched_output = function(value)
[rank0]:                      ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 446, in forward
[rank0]:     module.ds_grads_remaining += 1
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1928, in __getattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/LLMTrainFlow/./src/train/rl_gemma3.py", line 180, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2240, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 2555, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/trainer.py", line 3745, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1330, in compute_loss
[rank0]:     return self._compute_loss(model, inputs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 1340, in _compute_loss
[rank0]:     per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/extras/profiling.py", line 96, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/trl/trainer/grpo_trainer.py", line 852, in _get_per_token_logps
[rank0]:     logits = model(
[rank0]:              ^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]:     ret_val = func(*args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/engine.py", line 2087, in forward
[rank0]:     loss = self.module(*inputs, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/peft_model.py", line 1757, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1793, in inner
[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/peft/tuners/tuners_utils.py", line 193, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/utils/generic.py", line 969, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/transformers/models/gemma3/modeling_gemma3.py", line 880, in forward
[rank0]:     logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1845, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1782, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:                   ^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 378, in _post_backward_module_hook
[rank0]:     return apply_to_tensors_only(module.post_bwd_fn.apply,
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 133, in apply_to_tensors_only
[rank0]:     touched_output = apply_to_tensors_only(function, elem)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/utils.py", line 149, in apply_to_tensors_only
[rank0]:     touched_output = function(value)
[rank0]:                      ^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/autograd/function.py", line 575, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/deepspeed/runtime/zero/parameter_offload.py", line 446, in forward
[rank0]:     module.ds_grads_remaining += 1
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1928, in __getattr__
[rank0]:     raise AttributeError(
[rank0]: AttributeError: 'Linear' object has no attribute 'ds_grads_remaining'

我注意到 vLLM 的休眠功能并未集成到 TRL 中,这是为什么?

·
文章作者

原因在 https://huggingface.co/blog/vllm-colocate#challenges 中的“段错误”讨论中有所说明。基本上,我们正在等待 bug (https://github.com/vllm-project/vllm/issues/16993) 的修复,然后才能将 sleep() 完全集成到 TRL 上游。

你们在 Qwen 72B 实验中是如何分配权重的?是仅在一个节点上以 TP=8 运行,还是每个节点都有自己的 Qwen 72B 副本?

·
文章作者

是的,我们设置了 TP=8,这意味着每个节点都拥有 72B 模型分片的副本。

很棒的文章!
我正在 Slurm 集群中使用 VLLM 协同部署进行 GRPO,我收到一个 TCP 异常:
TCP 客户端连接/验证主机 10.0.1.163:35345 失败
虽然我以为它与训练循环是内联运行的。这正常吗?:D

·

这不正常。请确保您设置了 `vllm_mode="colocate"`。

这很棒;文档应该更新吗?

注册登录以评论