通过使用 Flash Attention 提高 Hugging Face 训练效率

发布于 2024 年 8 月 21 日
在 GitHub 上更新

简而言之

通过最近的 PR 和新的 DataCollatorWithFlattening,现在 Hugging Face 中使用打包指令调优示例(无填充)的训练与 Flash Attention 2 兼容。

它可以在保持收敛质量的同时,将训练吞吐量提高多达 2 倍。请继续阅读详细信息!

引言

在训练期间,通常会采用在迷你批次中填充输入序列的方法来整理输入。然而,这会导致效率低下,因为存在不相关的填充标记。在不填充的情况下打包示例,并利用标记位置信息,是一种更有效的替代方案。但是,以前的打包实现在使用 Flash Attention 2 时没有考虑示例边界,导致出现不必要的跨示例注意力,从而降低质量和收敛性。

Hugging Face Transformers 现在通过一项新功能解决了这个问题,该功能在打包过程中保持了边界感知,同时引入了新的数据整理器 DataCollatorWithFlattening

通过选择 DataCollatorWithFlattening,Hugging Face Trainer 用户现在可以无缝地将序列连接成单个张量,同时在 Flash Attention 2 计算期间考虑序列边界。这是通过 flash_attn_varlen_func 实现的,它计算每个迷你批次中累积的序列长度(cu_seqlens)。

Hugging Face SFTTrainer 用户在调用数据整理器 DataCollatorForCompletionOnlyLM 时,通过设置一个新标志 padding_free=True,可以在 TRL 库中使用相同的功能。

吞吐量提升高达 2 倍

我们发现,使用此功能和新的 DataCollatorWithFlattening 后,训练吞吐量显著提高。下图显示了训练期间以 tokens/second 衡量的吞吐量。在此示例中,吞吐量是 8 个 A100-80 GPU 在一个 epoch 中,从两个不同的指令调优数据集(FLAN 和 OrcaMath)中随机选择 2 万个样本的每 GPU 平均值。

throughput

FLAN 的序列平均较短,但序列长度的方差较大,因此每个批次中示例的长度可能差异很大。这意味着填充的 FLAN 批次可能会因未使用的填充标记而产生显著的开销。使用新的 DataCollatorWithFlattening 在 FLAN 数据集上进行训练,在提高吞吐量方面显示出显著的优势。我们看到此处显示的模型:llama2-7B、mistral-7B 和 granite-8B-code 的吞吐量提高了 2 倍。

OrcaMath 具有更长的示例和更低的示例长度方差。因此,打包带来的改进较低。我们的实验表明,在 OrcaMath 数据集上使用这种形式的打包训练,这三种模型的吞吐量提高了 1.4 倍。

memory

使用新的 DataCollatorWithFlattening 进行打包也改善了内存使用。下图显示了相同三个模型在相同两个数据集上训练的峰值内存使用情况。FLAN 数据集上的峰值内存减少了 20%,这得益于打包的显著优势。

OrcaMath 数据集因其更同质的示例长度,峰值内存减少了 6%。

当减少优化步骤时,打包示例可能会损害训练收敛。然而,新功能保留了迷你批次,因此优化步骤的数量与使用填充示例时相同。因此,对训练收敛没有影响,正如我们在下图中所看到的,该图显示了相同三个模型在相同两个数据集上训练时,使用新的 DataCollatorWithFlattening 进行打包或使用填充训练的模型具有相同的验证损失。

ValLoss

工作原理

考虑一个批次大小为 4 的数据批次,其中四个序列如下:

batch

连接示例后,无填充整理器会返回每个示例的 input_idslabelsposition_ids。因此,对于此数据批次,整理器提供:

example

所需的修改很轻量,仅限于为 Flash Attention 2 提供 position_ids

然而,这依赖于模型公开 position_ids。截至本文撰写之时,有 14 个模型公开了它们并受此解决方案支持。具体来说,Llama 2 和 3、Mistral、Mixtral、Granite、DBRX、Falcon、Gemma、OLMo、Phi 1、2 和 3、phi3、Qwen 2 和 2 MoE、StableLM 和 StarCoder 2 都受此解决方案支持。

入门

利用 position_ids 打包的好处很容易实现。

如果您使用的是 Hugging Face Transformers 中的 Trainer,只需两个步骤:

  1. 使用 Flash Attention 2 实例化模型
  2. 使用新的 DataCollatorWithFlattening

如果您使用的是 TRL 中的 Hugging Face SFTTrainerDataCollatorForCompletionOnlyLM,则需要两个步骤:

  1. 使用 Flash Attention 2 实例化模型
  2. 调用 DataCollatorForCompletionOnlyLM 时,将 padding_free 设置为 True,如下所示:collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, padding_free=True)

如何使用

对于 Trainer 用户,以下示例说明了如何使用新功能。

# Example using DataCollatorWithFlattening
 
import torch

# load model as usual
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "instructlab/merlinite-7b-lab",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2"
)

# read dataset as usual
from datasets import load_dataset
train_dataset = load_dataset("json", data_files="path/to/my/dataset")["train"]

# use DataCollatorWithFlattening
from transformers import DataCollatorWithFlattening
data_collator = DataCollatorWithFlattening()

# train
from transformers import TrainingArguments, Trainer
train_args = TrainingArguments(output_dir="/save/path")
trainer = Trainer(
    args=train_args,
    model=model,
    train_dataset=train_dataset,
    data_collator=data_collator
)
trainer.train()

对于 TRL 用户,以下示例展示了如何将新功能与 SFTTrainer 结合使用。

# SFTTrainer example using DataCollatorForCompletionOnlyLM

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained(
    "instructlab/merlinite-7b-lab",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained("instructlab/merlinite-7b-lab")
tokenizer.pad_token = tokenizer.eos_token 

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

response_template = " ### Answer:"
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]
collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, padding_free=True)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(
        output_dir="./tmp",
        gradient_checkpointing=True,
        per_device_train_batch_size=8
    ),
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

trainer.train()

结论

得益于最近的 PR 和新的 DataCollatorWithFlattening,打包指令调优示例(而不是填充)现在与 Flash Attention 2 完全兼容。该方法与使用 position_ids 的模型兼容。在训练过程中,吞吐量和峰值内存使用方面可以看到好处,而训练收敛性没有下降。实际的吞吐量和内存改进取决于模型和训练数据中示例长度的分布。通过使用 DataCollatorWithFlattening,训练数据示例长度变化较大的模型将获得相对于填充的最大收益。TRL 库中的 SFTTrainer 用户可以通过在调用 DataCollatorForCompletionOnlyLM 时设置新标志 padding_free=True 来使用相同的功能。

有关更详细的分析,请参阅论文:https://huggingface.co/papers/2407.09105

社区

data_collator = DataCollatorWithFlattening() 这种方法无法从根本上防止注意力隔离。但它可以实现位置隔离。您意识到这个问题了吗? @RQlee @ArthurZ

注册登录 发表评论