TRL 文档

减少内存使用

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

减少内存使用

本节正在建设中。欢迎贡献!

截断

数据集中序列的长度可能差异很大。当数据被批处理时,序列会被填充以匹配批次中最长的序列,即使大多数序列相对较短,也可能导致高内存使用。

Truncation prompt completion

为了减少内存使用,将序列截断到合理的长度非常重要。虽然 TRL 训练器默认会截断序列,但您可能需要调整默认截断长度,以更好地适应您的特定用例。

DPO
SFT

DPO 截断首先通过 max_prompt_length 和 max_completion_length 参数应用于提示和补全。然后使用 max_length 参数截断生成的序列。

Truncation prompt completion

要设置截断参数,请使用以下代码片段

from trl import DPOConfig

training_args = DPOConfig(..., max_prompt_length=..., max_length=...)

您也可以使用 max_completion_length 参数来截断补全,但这不太常见,因为目标通常是在可能的情况下保留补全的完整长度。

from trl import DPOConfig

training_args = DPOConfig(..., max_completion_length=...)

打包

此技术仅适用于 SFT。

截断 有几个缺点

  1. 信息丢失:序列末尾的关键数据可能会被丢弃。
  2. 选择截断长度:太短会丢失数据;太长会降低效率。

打包,由 Raffel 等人在 2020 年提出,通过对序列进行分组而不是截断来解决这些问题。它将数据集序列连接并拆分为所需的长度。

Packing

打包消除了填充,保留了所有序列信息,并允许灵活的序列长度,使其成为比截断更有效的替代方案。要启用打包,请在 SFTConfig 中使用 packing=True

from trl import SFTConfig

training_args = SFTConfig(..., packing=True, max_length=512)

打包可能会导致批次污染,其中相邻序列相互影响。这对于某些应用程序来说可能存在问题。有关更多详细信息,请参阅 #1230

无填充

无填充批处理是减少内存使用的另一种方法。在此方法中,首先对批次进行采样,然后将其展平为单个序列,从而避免填充。与打包不同,打包可能会通过组合不同样本的部分而导致序列不完整,而无填充批处理可确保所有序列保持完整和完整。

Padding-free batching

强烈建议将无填充批处理与 Flash Attention 2 一起使用。否则,您可能会遇到批次污染问题。

DPO
SFT
from trl import DPOConfig

training_args = DPOConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention2"})

禁用在线方法中用于生成的模型收集

当使用 DeepSpeed ZeRO-3 时,模型权重会跨多个 GPU 分片。在线方法涉及在训练过程中从模型生成补全。在此步骤中,模型权重会临时收集在单个 GPU 上以进行生成。对于非常大的模型,这种收集可能会导致内存不足 (OOM) 错误,如本问题 #2250 中所述。

如果您遇到此问题,可以通过设置以下参数来禁用模型权重收集以进行生成

GRPO
Online DPO
PPO
RLOO
from trl import GRPOConfig

training_args = GRPOConfig(..., ds3_gather_for_generation=False)

此调整可防止模型权重被收集,从而避免 OOM 错误,但可能会导致生成速度变慢。

< > 更新 在 GitHub 上