如何在 JAX 中运行 Hugging Face 模型(第 2 部分)
在上一篇文章中,我们探讨了如何使用 Hugging Face 和 JAX 对 Llama 模型进行前向传播。这次,我们将实现相同的目标,但同时利用八个设备。
张量并行性入门
我们将采用的并行化方案称为张量并行性,有时也称为 NeMo-Megatron 分片。
Lightning AI 的这份文档对此进行了出色的解释。要点是我们可以执行两次矩阵乘法(matmuls)——一次按列分片,另一次按行分片——仅需要一次集体操作(all-reduce)。
因此,我们可以按照以下方案对权重进行分片
对于注意力块
- Q、K 和 V 投影按列分片(因为它们代表第一次矩阵乘法)。
- O 投影按行分片(因为它是第二次矩阵乘法)。
- 注意力机制本身不需要通信,因为它纯粹是数据并行的(头数被分片)。
对于 FFN(前馈网络)
- Up 和 Gate 投影按列分片。
- Down 投影按行分片。
JAX 并行性支持入门
与 PyTorch 不同,JAX 的并行性支持使用 gSPMD(广义单程序多数据)模式。这意味着我们不需要为每个设备配备一个进程并手动管理集合操作,我们只需要指定一个 mesh
以及每个数组如何分片(通过分片约束)。然后 XLA 编译器会自动确定在哪里插入必要的集合操作。
这个过程在此处有详细描述:JAX 分布式数组和自动并行化。
本质上,要并行运行我们的模型,我们需要两个关键点
- 定义一个
mesh
:在我们的例子中,它只是jax.make_mesh((jax.device_count(), ), ('axis', ))
。请注意,我们给轴起的名称对功能没有显著影响。 - 知道模型的每个权重如何分片.
为了弄清第二点,我们打印出模型的权重并决定如何分片。
权重切分
让我们打印权重,以了解我们正在处理什么。在 weights, func = torchax.extract_jax(model)
之后添加以下代码:
for name, w in weights.items():
print(name, w.shape)
我们将得到类似这样的输出
model.rotary_emb.inv_freq (64,)
model.embed_tokens.weight (32000, 4096)
model.layers.0.self_attn.q_proj.weight (4096, 4096)
model.layers.0.self_attn.k_proj.weight (4096, 4096)
model.layers.0.self_attn.v_proj.weight (4096, 4096)
model.layers.0.self_attn.o_proj.weight (4096, 4096)
model.layers.0.mlp.gate_proj.weight (11008, 4096)
model.layers.0.mlp.up_proj.weight (11008, 4096)
model.layers.0.mlp.down_proj.weight (4096, 11008)
model.layers.0.input_layernorm.weight (4096,)
model.layers.0.post_attention_layernorm.weight (4096,)
model.layers.1.self_attn.q_proj.weight (4096, 4096)
...
权重跨越 32 层。根据我们之前的讨论,我们需要按如下方式对其进行分片
model.layers.0.self_attn.q_proj.weight (4096, 4096) -> ('axis', None)
model.layers.0.self_attn.k_proj.weight (4096, 4096) -> ('axis', None)
model.layers.0.self_attn.v_proj.weight (4096, 4096)-> ('axis', None)
model.layers.0.self_attn.o_proj.weight (4096, 4096)-> (None, 'axis')
model.layers.0.mlp.gate_proj.weight (11008, 4096)-> ('axis', None)
model.layers.0.mlp.up_proj.weight (11008, 4096)-> ('axis', None)
model.layers.0.mlp.down_proj.weight (4096, 11008)-> (None, 'axis')
除了讨论的权重之外,还有一个用于嵌入的权重和另一个用于最终输出投影的权重。对于这些,我们在分片方面有更大的灵活性。
现在,我们可以这样编写分片函数
def shard_weights_llama(mesh, weights):
result = {}
for k, v in weights.items():
if (('q_proj' in k) or
('k_proj' in k) or
('v_proj' in k) or
('gate_proj' in k) or
('up_proj' in k)):
sharding = P('axis', None)
elif(('o_proj' in k) or
('down_proj' in k) or
('lm_head.weight' in k) or
('embed_tokens' in k)):
sharding = P(None, 'axis')
else:
sharding = P() # replicated
result[k] = jax.device_put(v, NamedSharding(mesh, sharding))
return result
然后,我们可以使用 weights = shard_weights_llama(mesh, weights)
对权重进行分片。
再次运行
现在权重已分片,我们几乎可以分布式运行推理了!还有一个步骤:输入也需要在每个设备上可用,以便所有设备都可以用它进行计算。我们可以通过复制输入来完成此操作
model_inputs.input_ids = jax.device_put(
model_inputs.input_ids, NamedSharding(mesh, P())) # replicate
再次运行脚本会得到
0 5.062012195587158 seconds
1 0.0038039684295654297 seconds
2 0.0034346580505371094 seconds
这比单设备版本大约快 4.3 倍。🚀
我们如何确保它确实在 8 个设备上运行?
虽然我们已经看到了推理速度的提升,但它并没有达到完全 8 倍的加速。为了确认它确实利用了所有 8 个设备并理解为什么加速不是线性的,我们可以使用 JAX profiler。
要捕获配置文件,只需使用标准 JAX API 包装相关代码段即可
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=False):
# Your inference code here
我使用了 xprof 插件和 TensorBoard 而不是 Perfetto,因为我在远程机器上。无论如何,结果是一个如下所示的可视化表示
从这个输出中,您可以验证所有 8 个设备的活动,并识别每个设备上正在运行的操作。这有助于找出瓶颈并理解整体并行执行。
要重现此文章的内容,请运行
python jax_hg_02.py
来自 https://github.com/qihqi/learning_machine/blob/main/jax-huggingface/jax_hg_02.py
结论
我们已成功演示了如何在不更改模型核心代码的情况下,以分布式方式运行 Llama 模型的前向传播。关键只是指定权重应如何分片。我们还展示了标准 JAX 分析工具如何确认分布式执行并帮助进行性能分析。
在下一篇文章中,我们将对 HuggingFace diffusers 库中的模型做同样的事情。