TRL 文档

减少内存使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

减少内存使用

此部分正在建设中。欢迎贡献!

截断

数据集中的序列长度差异可能很大。当数据分批处理时,序列会被填充以匹配批次中最长的序列,这可能会导致高内存使用,即使大多数序列相对较短。

Truncation prompt-completion

为了减少内存使用,将序列截断到合理的长度非常重要。虽然 TRL 训练器默认会截断序列,但您可能需要调整默认的截断长度,以更好地适应您的特定用例。

DPO
SFT

DPO 截断首先通过 max_prompt_lengthmax_completion_length 参数应用于提示和补全。然后使用 max_length 参数来截断最终生成的序列。

Truncation prompt-completion

要设置截断参数,请使用以下代码片段

from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_length=...)

您也可以使用 max_completion_length 参数来截断补全,但这不太常见,因为目标通常是尽可能保留补全的完整长度。

from trl import DPOConfig

training_args = DPOConfig(..., max_completion_length=...)

如何选择 max_length 值?

如果 max_length 太小,大部分词元将被丢弃,无法对训练做出贡献。如果太大,内存使用量可能会激增,可能导致 OOM(内存不足)错误。如果没有打包或无填充,大的 max_length 也可能导致训练效率低下,因为许多词元将是填充词元。

为了帮助您选择一个合适的值,我们提供了一个工具来可视化数据集中序列长度的分布。

打包

此技术仅适用于 SFT。

截断有几个缺点

  1. 信息丢失:序列末尾的关键数据可能会被丢弃。
  2. 选择截断长度:太短会丢失数据;太长会影响效率。

打包(Packing)由 Raffel 等人于 2020 年提出,它通过组合序列而不是截断来解决这些问题。它将数据集序列连接并分割成所需的长度。

Packing

打包通过在可能的情况下将多个序列合并到一行中来减少填充。我们使用一种先进的方法,以近乎最优的方式打包数据集。要启用打包,请在 SFTConfig 中设置 packing=True

在 TRL 0.18 及更早版本中,打包使用了一种更激进的方法,将填充减少到几乎为零,但缺点是会破坏数据集中大部分序列的连续性。要恢复到此策略,请在 `SFTConfig` 中使用 `packing_strategy="wrapped"`。

from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_length=512)

打包可能会导致批次污染,即相邻序列相互影响。这对于某些应用可能是有问题的。更多详情,请参阅 #1230

使用 Liger 减少峰值内存使用

Liger Kernel 是一个专门为 LLM 训练设计的 Triton 内核集合。它可以有效地将多 GPU 训练吞吐量提高 20%,并减少 60% 的内存使用。

更多信息,请参阅 Liger Kernel 集成

DPO
GRPO
KTO

要使用 Liger 减少峰值内存使用,请使用以下代码片段

from trl import DPOConfig

training_args = DPOConfig(..., use_liger_loss=True)

无填充

无填充批处理是另一种减少内存使用的方法。在这种方法中,首先对一个批次进行采样,然后将其展平为单个序列,从而避免了填充。与打包(packing)可能通过组合不同样本的部分而导致序列不完整不同,无填充批处理确保所有序列保持完整。

Padding-free batching

强烈建议将无填充批处理与 FlashAttention 2FlashAttention 3 结合使用。否则,您可能会遇到批次污染问题。

DPO
SFT
from trl import DPOConfig

training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})

激活卸载

激活卸载是一种内存效率技术,它通过在前向传播期间将激活张量临时移动到 CPU RAM,并在反向传播需要时才将其移回,从而减少 GPU VRAM 的使用。这以略微增加训练时间为代价,显著减少了峰值内存使用。

要在您的 SFT 训练配置中启用激活卸载:

<hfoptions> <hfoption id="SFT">
from trl import SFTConfig

training_args = SFTConfig(..., activation_offloading=True)
</hfoption> </hfoptions>

当将激活卸载与使用 Liger 内核的模型一起使用时,由于兼容性问题,您必须禁用 Liger 交叉熵。该问题特别发生在 use_liger_kernel=True 的情况下,因为 Liger 交叉熵执行原地操作,这与激活卸载冲突。默认设置 (use_liger_kernel=False) 可以正常工作。

# When using activation offloading with a model that uses Liger kernels:
from trl import SFTConfig

training_args = SFTConfig(
    activation_offloading=True,
    use_liger_kernel=False,  # Disable Liger cross entropy
    # Other parameters...
)

在底层,激活卸载实现了 PyTorch 的 saved_tensors_hooks,以在前向传播期间拦截激活。它根据大小和上下文智能地管理要卸载的张量,避免卸载输出张量,因为这样做效率低下。为了优化性能,它可以选择使用 CUDA 流来将计算与 CPU-GPU 传输重叠。

在在线方法中禁用模型聚合以进行生成

当使用 DeepSpeed ZeRO-3 时,模型权重会分片到多个 GPU 上。在线方法涉及在训练过程中从模型生成补全。在此步骤中,模型权重会临时聚合到单个 GPU 上进行生成。对于非常大的模型,这种聚合可能会导致内存不足(OOM)错误,如这个问题所述:#2250

如果遇到此问题,您可以通过设置以下参数来禁用模型权重的生成聚合:

GRPO
在线 DPO
PPO
RLOO
from trl import GRPOConfig

training_args = GRPOConfig(..., ds3_gather_for_generation=False)

此调整可防止模型权重被聚合,从而避免 OOM 错误,但可能会导致生成速度变慢。

< > 在 GitHub 上更新