深度混合是潮流

这是一篇关于深度混合(MoD)的简短开发博客。它将简要解释这篇论文、代码和一些想法。这篇博客假设您了解专家混合(Mixture of Experts)。
如果您只对代码感兴趣,请点击此链接 链接,这篇博客也在 pythonstuff(我的网站) 上发布。如有任何问题、评论和更正,请通过 Twitter @shxf0072 或 00shxf@gmail.com
联系我
问题所在
现代语言模型建立在 Transformer 块之上,我们堆叠 n 个 X 块来构建大型模型。一个 token 通过这些大型模型并预测下一个 token。现在的问题是,对于每个 token,我们都花费相同的计算量。举两个小的文本补全例子,3+17 = 20
和 3*17 = 51
。两者拥有相同数量的 token,但预测加法比预测乘法简单得多。然而,如果我们使用 Transformer,无论问题的复杂程度如何,我们都会花费相同的计算量。另一个例子是,如果你和朋友玩游戏,一些随意的对话可能会是 你应该买这把步枪
或者 你真是个菜鸟
。在这两句话中,最后一个词需要不同的大脑周期,对于第一句话,你会思考,而对于第二句话,你可能不会。同样,我们希望在某些词上花费不同于其他词的计算量。
深度混合(MoD)
深度混合试图通过在 Transformer 中添加一个简单的路由器来解决这个问题。这个路由器将决定 token 是应该通过这个块(进行计算)还是跳过这个块直接传递。
好的,这很容易,为什么人们以前没有想到呢?
他们想过很多次了,问题在于训练时我们使用批量的样本。如果我们决定某些 token 通过块,而另一些不通过,这会动态改变输入数组的形状,从而使训练在某些情况下变得缓慢甚至不可能。
基于容量的路由
本文采用基于容量的路由策略。容量是我们定义的,假设我们有一个长度为 L=100 的序列,我们将容量因子设置为 0.12,那么只有 12 个 token 会通过 Block (attn+mlp),其余的将直接进入下一个 Block。
这种容量方法来自《Mixture-of-Experts with Expert Choice Routing》。如果有一个专家集合的容量,它将只接受那么多 token。在这里我们做同样的事情,但却是垂直的。我们将每个块的容量设置为最大序列长度的某个百分比,并且只允许那么多的 token。这样做的好处是,如何分割输入数组是静态定义的,并且提前定义形状可以使 GPU 运行得更快。
实现细节
我已对代码进行注释,不再赘述。您应该阅读代码。您应该始终阅读代码
class MoD(nn.Module):
"""
Paper: https://arxiv.org/abs/2404.02258
"""
def __init__(self, cfg: Config) -> None:
super().__init__()
self.seq_len = cfg.seq_len
self.capacity_factor = cfg.capacity_factor
self.dim = cfg.d_model
self.transformer_decoder_block = Block(cfg)
self.router = nn.Linear(self.dim, 1, bias=False)
self.aux_router = nn.Sequential(
nn.Linear(self.dim,self.dim//2),
nn.SiLU(),
nn.Linear(self.dim//2,1),
)
def forward(
self, x: Tensor, mask, freqs_cis, mode="train", auxiliary_loss=False, *args, **kwargs
):
batch_size, seq_len, dim = x.shape
if mode == "inference":
return self.inference(x, *args, **kwargs)
# S = seq_len, C = capacity , C = int(seq_length * capacity_factor)
# page 6 above eq 1 | ( C<S ) | here top_k = beta
top_k = int(seq_len * self.capacity_factor)
# eq1 page 6
# scaler weights for each token
router_logits = self.router(x) # (x) batch,seq_len,dim -> r batch,seq_len,1
# 𝑟𝑙> 𝑃𝛽 (R) ... eqution 1
token_weights, token_index = torch.topk(router_logits, top_k, dim=1, sorted=False)
# now we have idx, we can copy this weights to another tensor and pass them to attn+mlp
# since its auto regressive model we need to keep casual nature of it
# that why we need sort the tokens by idx before we pass it to attn
selected_tokens, index = torch.sort(token_index, dim=1)
# select idx for copying for original tensor
indices_expanded = selected_tokens.expand(-1, -1, dim)
# This are fillted topk tokens with capactiy C
filtered_x = torch.gather(input=x, dim=1, index=indices_expanded) # -> batch, capacity, dim
x_out, _ = self.transformer_decoder_block(filtered_x, mask, freqs_cis)
# softmax router weights, aaah
token_weights = F.softmax(token_weights, dim=1)
# selecting router wight by idx ( in sorted maner)
r_weights = torch.gather(token_weights, dim=1, index=index)
# muliply by router weights, this add router in gradient stream
xw_out = r_weights * x_out
# batch_indices = torch.arange(batch_size).unsqueeze(-1).expand(-1, top_k)
# # # https://discuss.pytorch.org/t/when-inplace-operation-are-allowed-and-when-not/169583/2
# out = x.clone()
# # add back to resuidal strean
# out[batch_indices, selected_tokens.squeeze(-1),: ] += xw_out
# # ^ this can be done with torch.scatter_add
out = torch.scatter_add(input=x, dim=1, index=indices_expanded, src=xw_out)
if auxiliary_loss:
aux_loss = self.aux_loss( , router_logits, selected_tokens)
return out, aux_loss
return out, _
def aux_loss(self, x: Tensor, router_logits: Tensor, selected_tokens: Tensor):
batch_size, seq_len, dim = x.shape
# Page 7, Section 3.5 sampling
router_targets = torch.zeros_like(router_logits).view(
-1
) # i think torch to scatter will work here TODO
router_targets[selected_tokens.view(-1)] = 1.0
aux_router_logits = self.aux_router(x.detach().view(batch_size * seq_len, -1))
# aux_router_logits = F.sigmoid(aux_router_logits) # keep output in range [0,1)
# RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
# so binary_cross_entropy_with_logits == sigmoid + bce_loss
return F.binary_cross_entropy_with_logits(aux_router_logits.view(-1), router_targets)
还有一件事,我们交替使用 [MOD, Block, MOD...] 块,在某些层中,仅使用 MOD 效果不佳。
Softmax、因果关系和辅助闲聊
现在你可能已经注意到我们在整个序列长度上使用了 Softmax。这打破了因果关系,意味着未来的 token 会影响过去的 token。这是错的吗?是的。我们修复它了吗?没有。它只是恰好有效。
这篇论文建议添加另一个路由器来预测 token 是否被选中。
最初,我以为附加的路由器是一个简单的线性路由器,但它不是。我用线性路由器训练了一个模型,并为它编写了推理代码,它一直在闲聊,虽然能工作,但一直在闲聊。关于这个第二个路由器,没有任何信息,比如它有多少层,使用什么激活函数,什么算是“小”,隐藏维度是多少?这让读者自己去猜了 😒。
我原以为我会训练另一个模型,但我没有。所以,为了不让它成为这篇博客中又一个没有推理代码的东西,如果我训练了它,我会更新。
顺便说一句,如果我们将主路由器的 Softmax 替换为 Sigmoid,模型仍然可以工作,并且不会破坏因果关系。我对这方面有改进的想法,但会在有时间的时候去实现。
损失曲线
MoD 是一种训练模型的卓越方式。对于序列长度为 512、模型大小为 300M 的模型,MoD 比基线模型快 30%,同时实现了更低的损失(容量因子为 0.12)。大部分加速来自注意力计算。由于注意力是二次的,对于完全注意力,我们需要一个 n^2 矩阵,而 MoD 的容量因子为 0.12,因此只需 0.0144 n^2。节省量可以通过 n^2 - (0.12n)^2 = n^2 - 0.0144n^2 = n^2(0.9856) 来计算,因此节省量也比基线模型呈二次增长。
几个陷阱
- 使用 MoD,您无法进行批量推理,因为每个批次中的每个 token 都可以绕过块。如果您使用掩码,则可以实现,但此时,它与在带有路由器开销的普通模型上进行推理相同。
- 将整个序列放入模型并不会带来很大的加速,问题和上面一样,一些 token 会通过块,一些不会,而在推理时我们不希望固定容量路由。
- 现有的推测解码加速技术将不起作用,或者不像在普通模型中那样有用。
总的来说,这是一个可靠的架构,即使我们像普通模型一样进行推理,仅仅训练速度的提升就使其值得。谷歌干得好,感谢你们使知识“开放”(咳咳)。