高效的LLM预训练:打包序列和掩码注意力

训练大型语言模型 (LLM) 是一项计算密集型任务。它需要大量数据、强大的硬件和巧妙的优化技术。其中一项不常被提及的技术是使用打包序列,以充分利用每个训练步骤中选择的上下文长度。
想象一下,向 Transformer 模型输入一批长度不一的文本序列。为了保持一致的输入维度,较短的序列会用特殊标记进行填充。虽然这看起来无害,但它会通过关注无意义的填充标记来浪费宝贵的 GPU 内存。
解决方案:序列打包
打包序列提供了一个优雅的解决方案。我们不是进行填充,而是将多个较短的序列连接成一个更长的序列。这最大限度地减少了计算浪费(通过填充标记)。它还允许我们在每个批次中处理更多的标记,从而缩短训练时间。然而,这里有一个问题:我们需要确保模型不会跨序列边界进行关注。让我们看一个简单的例子。我们将以下三个句子打包成一个由EOS标记分隔的单个序列。
# Setup
import torch; torch.set_printoptions(linewidth=200)
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("gpt2")
config = AutoConfig.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_config(config)
sentence1 = "The cat sat on the mat"
sentence2 = "The dog ate my homework"
sentence3 = "My aunt is a teacher"
sentences = [sentence1, sentence2, sentence3]
tokenized_sentences = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
tokenized_sentences = [t for s in tokenized_sentences for t in s + [tokenizer.eos_token_id]]
tokenizer.decode(tokenized_sentences)
如果再次解码打包后的序列,它会是这样的
猫坐在垫子上<|endoftext|>狗吃了我的家庭作业<|endoftext|>我的阿姨是老师<|endoftext|>
打包序列的因果语言建模标准注意力掩码将如下所示。
tokenized_sentences = torch.tensor(tokenized_sentences)
attn_mask = torch.ones(tokenized_sentences.size(0), tokenized_sentences.size(0), dtype=torch.int).tril()
attn_mask
然而,使用这个掩码,在处理第二个句子时,模型仍然可以关注第一个句子中的标记,这不理想,因为这两个例子是独立的。为了解决这个问题,我们可以以某种方式截断注意力掩码。当批次中只有一个样本时,在 PyTorch 中这样做相对容易。
def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
# store sequence length in variable for easier readability
T = tokenized_sentences.size(0)
# get indices of all EOS tokens
eos_indices = (tokenized_sentences == tokenizer.eos_token_id).nonzero().squeeze()
# from indices, get length of each sequence
reps = torch.cat([eos_indices[[0]]+1, eos_indices[1:] - eos_indices[:-1]])
# repeat each eos index n times along dimension 1 (n is the number of tokens in the sequence)
repeated_idx = torch.repeat_interleave(eos_indices, reps).view(1,-1).expand(T, -1)
# create tensor with all indices from 0 to T-1 repeated T times along dimesion 1
mask_indices = torch.arange(T).view(-1,1).expand(-1, T)
# create causal mask and additionally mask out all tokens from preceeding sequences
mask = torch.ones(T, T, dtype=torch.bool).tril().expand(-1, -1)
mask.masked_fill_(mask_indices > repeated_idx, False)
return mask
get_attention_mask_for_packed_sequence(tokenized_sentences, tokenizer.eos_token_id)
如您所见,标准因果掩码被截断以掩盖前一个句子中的标记。
相应调整位置ID
将序列打包在一起时,相应调整用于创建位置嵌入的位置ID非常重要。序列中的每个标记通常都有一个关联的位置ID,可以帮助模型理解标记的相对位置。当我们打包多个序列时,我们需要确保每个序列的位置ID从头开始(通常是0或1),而不是从上一个序列结束的地方继续。
通过调整位置ID,我们还清晰地标记了序列边界。这对于模型区分不同序列,而不将打包数据视为一个连续序列至关重要。
我们可以利用上述函数的代码来生成带位置ID的张量
pos_ids = torch.arange(T) - torch.repeat_interleave(torch.cat([torch.tensor([0]), eos_indices+1], dim=0)[:-1], reps)
pos_ids
用于批量序列打包的注意力掩码
通常在训练期间,我们希望处理整个序列批次。对于上述示例代码,我们将不得不使用循环实现来获取批次的截断注意力掩码。在不使用循环的情况下完成它会更具挑战性,因为有额外的批次维度。为了向您展示如何做到这一点,我们首先创建第二个打包序列项目,以获得大小为2的批次。
sentence4 = "Rome wasn't built in a day"
sentence5 = "My hovercraft is full of eels"
sentences = [sentence4, sentence5]
tokenized_sentences2 = tokenizer(sentences, return_attention_mask=False, add_special_tokens=False)["input_ids"]
tokenized_sentences2 = torch.tensor([t for s in tokenized_sentences2 for t in s + [tokenizer.eos_token_id]])
batch = torch.nn.utils.rnn.pad_sequence(
[tokenized_sentences, tokenized_sentences2],
batch_first=True, padding_value=tokenizer.eos_token_id
)
我们将批次的形状分配给两个变量 B 和 T。这使得以下代码更具可读性。
B, T = batch.shape
批处理实现的主要挑战是构建与上述示例中相同的“repeated_index”张量。首先,我们需要 EOS 标记的全局索引。
eos_idx = (batch.view(-1) == tokenizer.eos_token_id) \
.nonzero(as_tuple=True)[0] + 1
我们将 0 索引和每个批次项的最后一个标记索引添加到这个索引向量中。这样做是为了稍后能够再次分离批次项。然后我们删除重复项(如果第一个或最后一个索引对于批次项已经存在),然后进行排序。
eos_idx_expanded = torch.cat(
[eos_idx, torch.arange(0,B*T+1,T)]
).unique().sort()[0]
接下来,由于我们的索引向量包含批次中 EOS 标记的全局索引(例如,第二个批次项的第一个索引 = T),我们需要按序列长度对索引进行归一化。对于归一化的索引,我们将零替换为 T。这在下一步中是必需的。
normalized_idx = eos_idx_expanded - (eos_idx_expanded // T) * T
normalized_idx = torch.where(normalized_idx == 0, T, normalized_idx)
通过归一化索引,我们可以检查每个 EOS 标记索引需要重复多少次才能获得正确的序列长度。为了实现这一点,我们需要存在每个序列的最后一个索引。如果我们在上一步中没有将 0 替换为 T,则每个批次中最后一个 EOS 索引的重复次数将是错误的。
reps = normalized_idx[1:] - normalized_idx[:-1]
reps = torch.where(reps < 1, normalized_idx[1:], reps)
现在我们可以创建批量重复索引张量
repeated_idx = torch.repeat_interleave(
normalized_idx[1:], reps
).view(B,1,T).expand(-1,T,-1)
其余部分与批次大小为 1 的示例类似。我们构造一个张量,其中包含沿维度 1 重复 T 次的从 0 到 T-1 的索引,并创建因果掩码。然后我们掩盖前一个序列中的所有标记。
mask_indices = torch.arange(T).view(1,-1,1).expand(B, -1, T)
# create mask
mask = torch.ones(T, T, dtype=torch.bool).tril().expand(B, -1, -1)
mask = mask.masked_fill(mask_indices >= repeated_idx, False)
这是完整的函数。我添加了选择检查 EOS 标记或 BOS 标记的功能。
def get_attention_mask_for_packed_sequence(x, token_id, eos: bool = True):
B, T = x.shape
eos_idx = (x.view(-1) == token_id).nonzero(as_tuple=True)[0] + eos
eos_idx_expanded = torch.cat([eos_idx, torch.arange(0,B*T+1,T)]).unique().sort()[0]
normalized_idx = eos_idx_expanded - (eos_idx_expanded // T) * T
normalized_idx = torch.where(normalized_idx == 0, T, normalized_idx)
reps = normalized_idx[1:] - normalized_idx[:-1]
reps = torch.where(reps < 1, normalized_idx[1:], reps)
repeated_idx = torch.repeat_interleave(normalized_idx[1:], reps).view(B,1,T).expand(-1,T,-1)
mask_indices = torch.arange(T).view(1,-1,1).expand(B, -1, T)
mask = torch.ones(T, T, dtype=torch.bool).tril().expand(B, -1, -1)
mask = mask.masked_fill(mask_indices >= repeated_idx, False)
return mask
与上述不带批处理维度的示例类似,您可以重用创建注意力掩码的代码来获取正确的位置 ID。
pos_ids = (torch.arange(B*T) - torch.repeat_interleave(eos_idx_expanded[:-1], reps)).view(B,T)
pos_ids
您还可以在此 notebook 中找到所有代码片段。