如何在 JAX 中运行 Hugging Face 模型(第 2 部分)

社区文章 发布于 2025 年 7 月 29 日

上一篇文章中,我们探讨了如何使用 Hugging Face 和 JAX 对 Llama 模型进行前向传播。这次,我们将实现相同的目标,但同时利用八个设备


张量并行性入门

我们将采用的并行化方案称为张量并行性,有时也称为 NeMo-Megatron 分片

tensor parallelism

Lightning AI 的这份文档对此进行了出色的解释。要点是我们可以执行两次矩阵乘法(matmuls)——一次按列分片,另一次按行分片——仅需要一次集体操作(all-reduce)

因此,我们可以按照以下方案对权重进行分片

  • 对于注意力块

    1. Q、K 和 V 投影分片(因为它们代表第一次矩阵乘法)。
    2. O 投影分片(因为它是第二次矩阵乘法)。
    3. 注意力机制本身不需要通信,因为它纯粹是数据并行的(头数被分片)。
  • 对于 FFN(前馈网络)

    1. Up 和 Gate 投影分片。
    2. Down 投影分片。

JAX 并行性支持入门

与 PyTorch 不同,JAX 的并行性支持使用 gSPMD(广义单程序多数据)模式。这意味着我们不需要为每个设备配备一个进程并手动管理集合操作,我们只需要指定一个 mesh 以及每个数组如何分片(通过分片约束)。然后 XLA 编译器会自动确定在哪里插入必要的集合操作。

这个过程在此处有详细描述:JAX 分布式数组和自动并行化

本质上,要并行运行我们的模型,我们需要两个关键点

  1. 定义一个 mesh:在我们的例子中,它只是 jax.make_mesh((jax.device_count(), ), ('axis', ))。请注意,我们给轴起的名称对功能没有显著影响。
  2. 知道模型的每个权重如何分片.

为了弄清第二点,我们打印出模型的权重并决定如何分片。


权重切分

让我们打印权重,以了解我们正在处理什么。在 weights, func = torchax.extract_jax(model) 之后添加以下代码:

for name, w in weights.items():
  print(name, w.shape)

我们将得到类似这样的输出

model.rotary_emb.inv_freq (64,)
model.embed_tokens.weight (32000, 4096)
model.layers.0.self_attn.q_proj.weight (4096, 4096)
model.layers.0.self_attn.k_proj.weight (4096, 4096)
model.layers.0.self_attn.v_proj.weight (4096, 4096)
model.layers.0.self_attn.o_proj.weight (4096, 4096)
model.layers.0.mlp.gate_proj.weight (11008, 4096)
model.layers.0.mlp.up_proj.weight (11008, 4096)
model.layers.0.mlp.down_proj.weight (4096, 11008)
model.layers.0.input_layernorm.weight (4096,)
model.layers.0.post_attention_layernorm.weight (4096,)
model.layers.1.self_attn.q_proj.weight (4096, 4096)
...

权重跨越 32 层。根据我们之前的讨论,我们需要按如下方式对其进行分片

  model.layers.0.self_attn.q_proj.weight (4096, 4096) -> ('axis', None)
  model.layers.0.self_attn.k_proj.weight (4096, 4096) -> ('axis', None)
  model.layers.0.self_attn.v_proj.weight (4096, 4096)-> ('axis', None)
  model.layers.0.self_attn.o_proj.weight (4096, 4096)-> (None, 'axis')
  model.layers.0.mlp.gate_proj.weight (11008, 4096)-> ('axis', None)
  model.layers.0.mlp.up_proj.weight (11008, 4096)-> ('axis', None)
  model.layers.0.mlp.down_proj.weight (4096, 11008)-> (None, 'axis')

除了讨论的权重之外,还有一个用于嵌入的权重和另一个用于最终输出投影的权重。对于这些,我们在分片方面有更大的灵活性。

现在,我们可以这样编写分片函数

def shard_weights_llama(mesh, weights):
  result = {}
  for k, v in weights.items():
    if (('q_proj' in k) or
        ('k_proj' in k) or
        ('v_proj' in k) or
        ('gate_proj' in k) or
        ('up_proj' in k)):
      sharding = P('axis', None)
    elif(('o_proj' in k) or
        ('down_proj' in k) or
        ('lm_head.weight' in k) or
        ('embed_tokens' in k)):
      sharding = P(None, 'axis')
    else:
      sharding = P() # replicated
    result[k] = jax.device_put(v, NamedSharding(mesh, sharding))
  return result

然后,我们可以使用 weights = shard_weights_llama(mesh, weights) 对权重进行分片。


再次运行

现在权重已分片,我们几乎可以分布式运行推理了!还有一个步骤:输入也需要在每个设备上可用,以便所有设备都可以用它进行计算。我们可以通过复制输入来完成此操作

model_inputs.input_ids = jax.device_put(
  model_inputs.input_ids, NamedSharding(mesh, P())) # replicate

再次运行脚本会得到

0 5.062012195587158 seconds
1 0.0038039684295654297 seconds
2 0.0034346580505371094 seconds

这比单设备版本大约快 4.3 倍。🚀


我们如何确保它确实在 8 个设备上运行?

虽然我们已经看到了推理速度的提升,但它并没有达到完全 8 倍的加速。为了确认它确实利用了所有 8 个设备并理解为什么加速不是线性的,我们可以使用 JAX profiler

要捕获配置文件,只需使用标准 JAX API 包装相关代码段即可

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=False):
  # Your inference code here

我使用了 xprof 插件和 TensorBoard 而不是 Perfetto,因为我在远程机器上。无论如何,结果是一个如下所示的可视化表示

image.png

从这个输出中,您可以验证所有 8 个设备的活动,并识别每个设备上正在运行的操作。这有助于找出瓶颈并理解整体并行执行。

要重现此文章的内容,请运行

python jax_hg_02.py

来自 https://github.com/qihqi/learning_machine/blob/main/jax-huggingface/jax_hg_02.py


结论

我们已成功演示了如何在不更改模型核心代码的情况下,以分布式方式运行 Llama 模型的前向传播。关键只是指定权重应如何分片。我们还展示了标准 JAX 分析工具如何确认分布式执行并帮助进行性能分析。

在下一篇文章中,我们将对 HuggingFace diffusers 库中的模型做同样的事情。

社区

注册登录以评论