高效的LLM预训练:打包序列和掩码注意力

社区文章 发布于 2024年10月7日
Description of the image
图片来源:https://huggingface.co/blog/poedator/4d-masks

训练大型语言模型 (LLM) 是一项计算密集型任务。它需要大量数据、强大的硬件和巧妙的优化技术。其中一项不常被提及的技术是使用打包序列,以充分利用每个训练步骤中选择的上下文长度。

想象一下,向 Transformer 模型输入一批长度不一的文本序列。为了保持一致的输入维度,较短的序列会用特殊标记进行填充。虽然这看起来无害,但它会通过关注无意义的填充标记来浪费宝贵的 GPU 内存。

解决方案:序列打包

打包序列提供了一个优雅的解决方案。我们不是进行填充,而是将多个较短的序列连接成一个更长的序列。这最大限度地减少了计算浪费(通过填充标记)。它还允许我们在每个批次中处理更多的标记,从而缩短训练时间。然而,这里有一个问题:我们需要确保模型不会跨序列边界进行关注。让我们看一个简单的例子。我们将以下三个句子打包成一个由EOS标记分隔的单个序列。

# Setup
import torch; torch.set_printoptions(linewidth=200)
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)
sentence1 = "The cat sat on the mat"
sentence2 = "The dog ate my homework"
sentence3 = "My aunt is a teacher"

sentences = [sentence1, sentence2, sentence3]
tokenized_sentences = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
tokenized_sentences = [t for s in tokenized_sentences for t in s + [tokenizer.eos_token_id]]
tokenizer.decode(tokenized_sentences)

如果再次解码打包后的序列,它会是这样的

猫坐在垫子上<|endoftext|>狗吃了我的家庭作业<|endoftext|>我的阿姨是老师<|endoftext|>

打包序列的因果语言建模标准注意力掩码将如下所示。

tokenized_sentences = torch.tensor(tokenized_sentences)
attn_mask = torch.ones(tokenized_sentences.size(0), tokenized_sentences.size(0), dtype=torch.int).tril()
attn_mask

image/png

然而,使用这个掩码,在处理第二个句子时,模型仍然可以关注第一个句子中的标记,这不理想,因为这两个例子是独立的。为了解决这个问题,我们可以以某种方式截断注意力掩码。当批次中只有一个样本时,在 PyTorch 中这样做相对容易。

def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
    # store sequence length in variable for easier readability
    T = tokenized_sentences.size(0)
    # get indices of all EOS tokens
    eos_indices = (tokenized_sentences == tokenizer.eos_token_id).nonzero().squeeze()
    # from indices, get length of each sequence
    reps = torch.cat([eos_indices[[0]]+1, eos_indices[1:] - eos_indices[:-1]])
    # repeat each eos index n times along dimension 1 (n is the number of tokens in the sequence)
    repeated_idx = torch.repeat_interleave(eos_indices, reps).view(1,-1).expand(T, -1)
    # create tensor with all indices from 0 to T-1 repeated T times along dimesion 1
    mask_indices = torch.arange(T).view(-1,1).expand(-1, T)
    # create causal mask and additionally mask out all tokens from preceeding sequences
    mask = torch.ones(T, T, dtype=torch.bool).tril().expand(-1, -1)
    mask.masked_fill_(mask_indices > repeated_idx, False)
    return mask

get_attention_mask_for_packed_sequence(tokenized_sentences, tokenizer.eos_token_id)

image/png

如您所见,标准因果掩码被截断以掩盖前一个句子中的标记。

相应调整位置ID

将序列打包在一起时,相应调整用于创建位置嵌入的位置ID非常重要。序列中的每个标记通常都有一个关联的位置ID,可以帮助模型理解标记的相对位置。当我们打包多个序列时,我们需要确保每个序列的位置ID从头开始(通常是0或1),而不是从上一个序列结束的地方继续。

通过调整位置ID,我们还清晰地标记了序列边界。这对于模型区分不同序列,而不将打包数据视为一个连续序列至关重要。

我们可以利用上述函数的代码来生成带位置ID的张量

pos_ids = torch.arange(T) - torch.repeat_interleave(torch.cat([torch.tensor([0]), eos_indices+1], dim=0)[:-1], reps)
pos_ids

用于批量序列打包的注意力掩码

通常在训练期间,我们希望处理整个序列批次。对于上述示例代码,我们将不得不使用循环实现来获取批次的截断注意力掩码。在不使用循环的情况下完成它会更具挑战性,因为有额外的批次维度。为了向您展示如何做到这一点,我们首先创建第二个打包序列项目,以获得大小为2的批次。

sentence4 = "Rome wasn't built in a day"
sentence5 = "My hovercraft is full of eels"

sentences = [sentence4, sentence5]
tokenized_sentences2 = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
tokenized_sentences2 = torch.tensor([t for s in tokenized_sentences2 for t in s + [tokenizer.eos_token_id]])

batch = torch.nn.utils.rnn.pad_sequence(
  [tokenized_sentences, tokenized_sentences2],
  batch_first=True, padding_value=tokenizer.eos_token_id
)

我们将批次的形状分配给两个变量 B 和 T。这使得以下代码更具可读性。

B, T = batch.shape

批处理实现的主要挑战是构建与上述示例中相同的“repeated_index”张量。首先,我们需要 EOS 标记的全局索引。

eos_idx = (batch.view(-1) == tokenizer.eos_token_id) \
  .nonzero(as_tuple=True)[0] + 1

image/png

我们将 0 索引和每个批次项的最后一个标记索引添加到这个索引向量中。这样做是为了稍后能够再次分离批次项。然后我们删除重复项(如果第一个或最后一个索引对于批次项已经存在),然后进行排序。

eos_idx_expanded = torch.cat(
  [eos_idx, torch.arange(0,B*T+1,T)]
).unique().sort()[0]

image/png

接下来,由于我们的索引向量包含批次中 EOS 标记的全局索引(例如,第二个批次项的第一个索引 = T),我们需要按序列长度对索引进行归一化。对于归一化的索引,我们将零替换为 T。这在下一步中是必需的。

normalized_idx = eos_idx_expanded - (eos_idx_expanded // T) * T
normalized_idx = torch.where(normalized_idx == 0, T, normalized_idx)

image/png

通过归一化索引,我们可以检查每个 EOS 标记索引需要重复多少次才能获得正确的序列长度。为了实现这一点,我们需要存在每个序列的最后一个索引。如果我们在上一步中没有将 0 替换为 T,则每个批次中最后一个 EOS 索引的重复次数将是错误的。

reps = normalized_idx[1:] - normalized_idx[:-1]
reps = torch.where(reps < 1, normalized_idx[1:], reps)

image/png

现在我们可以创建批量重复索引张量

repeated_idx = torch.repeat_interleave(
  normalized_idx[1:], reps
).view(B,1,T).expand(-1,T,-1)

image/png

其余部分与批次大小为 1 的示例类似。我们构造一个张量,其中包含沿维度 1 重复 T 次的从 0 到 T-1 的索引,并创建因果掩码。然后我们掩盖前一个序列中的所有标记。

mask_indices = torch.arange(T).view(1,-1,1).expand(B, -1, T)
# create mask
mask = torch.ones(T, T, dtype=torch.bool).tril().expand(B, -1, -1)
mask = mask.masked_fill(mask_indices >= repeated_idx, False)

这是完整的函数。我添加了选择检查 EOS 标记或 BOS 标记的功能。

def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
    B, T = x.shape
    eos_idx = (x.view(-1) == token_id).nonzero(as_tuple=True)[0] + eos
    eos_idx_expanded = torch.cat([eos_idx, torch.arange(0,B*T+1,T)]).unique().sort()[0]
    normalized_idx = eos_idx_expanded - (eos_idx_expanded // T) * T
    normalized_idx = torch.where(normalized_idx == 0, T, normalized_idx)
    reps = normalized_idx[1:] - normalized_idx[:-1]
    reps = torch.where(reps < 1, normalized_idx[1:], reps)
    repeated_idx = torch.repeat_interleave(normalized_idx[1:], reps).view(B,1,T).expand(-1,T,-1)
    mask_indices = torch.arange(T).view(1,-1,1).expand(B, -1, T)
    mask = torch.ones(T, T, dtype=torch.bool).tril().expand(B, -1, -1)
    mask = mask.masked_fill(mask_indices >= repeated_idx, False)
    return mask

与上述不带批处理维度的示例类似,您可以重用创建注意力掩码的代码来获取正确的位置 ID。

pos_ids = (torch.arange(B*T) - torch.repeat_interleave(eos_idx_expanded[:-1], reps)).view(B,T)
pos_ids

您还可以在此 notebook 中找到所有代码片段。

社区

@sirluk ,感谢您的精彩文章。您知道上述掩码技术是否适用于某些注意力实现,并且与某些其他实现不兼容吗?

例如,上述掩码是否适用于 SDPA/flash_attention_2 和 eager(例如,这些实现在 https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L666 中处理方式略有不同)?

·
文章作者

@shantanuagarwal ,很高兴您喜欢这篇文章!尽管我没有亲自尝试过,但您应该能够利用 PyTorch FlexAttention API 来实现这一点。请查看此处的教程 https://pytorch.ac.cn/blog/flexattention/。“文档掩码/锯齿序列”部分讨论了这些打包序列掩码。

谢谢 @sirluk 的精彩文章!
有一点我不清楚

  • TRL 中的 SFTTrainer 包含一个“packing”选项。它是否处理您上面提到的掩码注意力(masked attention)和位置 ID(position ids)?
  • 如果使用 SFTTrainer,上述评论中 @shantanuagarwal 提出的担忧是否仍然相关?
·

几天后,我将根据我在代码中看到的内容回答我自己的问题。如果我错过了什么,请随时补充。

据我所知,SFTTrainer 当前的代码中

  • SFTTrainer 的当前选项(packing=True)不处理注意力掩码。
  • 它也不处理位置编码。

请参见此处 https://github.com/huggingface/trl/blob/64aa06499b2e71537a8e701fad076873b0f3603f/trl/trainer/sft_trainer.py#L351:使用 packing 选项准备数据集
此处 https://github.com/huggingface/trl/blob/64aa06499b2e71537a8e701fad076873b0f3603f/trl/trainer/sft_trainer.py#L663:仅对 input_ids 进行实际打包,使用 pack_dataset 函数。
此处 https://github.com/huggingface/trl/blob/e0dd5250217305f7f8c2f4a153a6939a2f16e2bf/trl/data_utils.py#L475:pack_dataset 函数本身。

据我理解,“不关注当前句子以外的标记”仅从分隔每个句子的 eos 标记推断。SFTTrainer 在这方面遵循了 GPT3 文章(“语言模型是少样本学习者”)作者选择的方法。请看以下摘录

Capture d’écran 2025-05-20 à 12.46.43.png

我运行了本文提供的代码。它一直正常运行,直到我得到

eos_idx:  tensor([ 7, 13, 19, 28, 37, 38])
eos_idx_expanded:  tensor([ 0,  7, 13, 19, 28, 37, 38])

这与您的不符

eos_idx: 7,13,23,35,44,45,46

注册登录 发表评论