如何在JAX中运行Hugging Face模型(第一部分)

社区文章 发布于2025年7月20日

(原始内容位于此处:https://github.com/qihqi/learning_machine/blob/main/jax-huggingface/01-run-huggingface-model-in-jax.md

Hugging Face 最近从其 transformers 库中移除了对 JAX 和 TensorFlow 的原生支持,旨在精简其代码库。这一决定让许多 JAX 用户疑惑,如何在不重新实现所有内容的情况下,继续利用 Hugging Face 庞大的模型集合。

image/jpeg

这篇博客文章探讨了一个解决方案:使用 JAX 输入运行基于 PyTorch 的 Hugging Face 模型。这种方法为依赖 Hugging Face 模型的 JAX 用户提供了一条宝贵的“出路”。

背景与方法

作为 torchax(一个旨在实现 JAX 和 PyTorch 无缝互操作性的新兴库)的作者,这次探索将是 torchax 的一次绝佳压力测试。让我们深入了解!


设置

我们将从标准的 Hugging Face 快速启动设置开始。如果您尚未设置环境

# Create venv / conda env; activate etc.
pip install huggingface-cli
huggingface-cli login # Set up your Hugging Face token
pip install -U transformers datasets evaluate accelerate timm flax

接下来,直接从最新开发版本安装 torchax

pip install torchax
pip install jax[tpu]  # or jax[cuda12] if you are on GPU

首次尝试:即时执行模式

我们首先实例化一个模型和分词器。我们将创建一个名为 jax_hg_01.py 的脚本,其中包含以下代码:

from transformers import AutoModelForCausalLM, AutoTokenizer
import jax # Import jax here for later use

# Load a PyTorch model and tokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="bfloat16", device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Tokenize input, requesting JAX arrays
model_inputs = tokenizer(["The secret to baking a good cake is "], return_tensors="jax")
print(model_inputs)

请注意分词器调用中关键的 return_tensors="jax"。这指示 Hugging Face 直接返回 JAX 数组,这对于我们使用 JAX 输入和 PyTorch 模型的目标至关重要。运行上述脚本将输出:

{'input_ids': Array([[    1,   450,  7035,   304,   289,  5086,   263,  1781,   274,
         1296,   338, 29871]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}

现在,让我们使用 torchax 将此 PyTorch 模型转换为 JAX 可调用对象。修改您的脚本如下:

import torchax
# ... (previous code)

weights, func = torchax.extract_jax(model)

torchax.extract_jax 函数将模型的 forward 方法转换为 JAX 兼容的可调用对象。它还会将模型的权重作为 JAX 数组的 Pytree 返回(这本质上是转换为 JAX 数组的 model.state_dict())。

有了 funcweights,我们现在可以调用这个 JAX 函数了。约定是首先传递 weights 作为第一个参数,然后是位置参数的元组 (args),最后是可选的关键字参数字典 (kwargs)。

让我们将调用添加到脚本中:

# ... (previous code)

print(func(weights, (model_inputs.input_ids, )))

执行此操作将产生以下输出,演示了成功的即时执行模式:

In [2]: import torchax

In [3]: weights, func = torchax.extract_jax(model)
WARNING:root:Duplicate op registration for aten.__and__

In [4]: print(func(weights, (model_inputs.input_ids, )))
CausalLMOutputWithPast(loss=None, logits=Array([[[-12.950611  ,  -7.4854484 ,  -0.42371067, ...,  -6.819363  ,
          -8.073828  ,  -7.5583534 ],
        [-13.508438  , -11.716616  ,  -6.9578876 , ...,  -9.135823  ,
         -10.237023  ,  -8.56888   ],
        [-12.8517685 , -11.180469  ,  -4.0543456 , ...,  -7.9564795 ,
         -11.546011  , -10.686134  ],
        ...,
        [ -2.983235  ,  -5.621302  ,  11.553352  , ...,  -2.6286669 ,
          -2.8319468 ,  -1.9902805 ],
        [ -8.674949  , -10.042385  ,   3.4400458 , ...,  -3.7776647 ,
          -8.616567  ,  -5.7228904 ],
        [ -4.0748825 ,  -4.706395  ,   5.117742  , ...,   6.7174563 ,
           0.5748794 ,   2.506649  ]]], dtype=float32), past_key_values=DynamicCache(), hidden_states=None, attentions=None)

要将关键字参数 (kwargs) 传递给函数,只需将其作为第三个参数添加:

print(func(weights, (model_inputs.input_ids, ), {'use_cache': False}))

虽然这展示了基本功能,但 JAX 的真正强大之处在于其**JIT 编译**。即时 (JIT) 编译可以显著加速计算,尤其是在 GPU 和 TPU 等加速器上。因此,我们的下一个逻辑步骤是对函数进行 jax.jit


Jitting - 摆弄 Pytrees

在 JAX 中,JIT 编译就像用 jax.jit 包装函数一样简单。让我们试试:

import jax
# ... (previous code)

func_jit = jax.jit(func)
res = func_jit(weights, (model_inputs.input_ids,))

运行此代码可能会导致 TypeError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/hanq_google_com/learning_machine/jax-huggingface/script.py", line 18, in <module>
    res = func_jit(weights, (model_inputs.input_ids,))
TypeError: function jax_func at /home/hanq_google_com/pytorch/xla/torchax/torchax/__init__.py:52 traced for jit returned a value of type <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>, which is not a valid JAX type

错误信息表明 JAX 无法理解 CausalLMOutputWithPast 类型。当您对函数进行 jax.jit 时,JAX 要求所有输入和输出都是“JAX 类型”——这意味着它们可以使用 jax.tree.flatten 扁平化为 JAX 可理解的元素列表。

为了解决这个问题,我们需要在 **JAX 的 Pytree 系统**中注册这些自定义类型。Pytrees 是嵌套数据结构(如元组、列表和字典),JAX 可以遍历并对其应用转换。通过注册自定义类型,我们告诉 JAX 如何将其分解为组成部分(子节点)并进行重构。

将以下内容添加到您的脚本中:

from jax.tree_util import register_pytree_node
from transformers import modeling_outputs

def output_flatten(v):
  return v.to_tuple(), None

def output_unflatten(aux, children):
  return modeling_outputs.CausalLMOutputWithPast(*children)

register_pytree_node(
  modeling_outputs.CausalLMOutputWithPast,
  output_flatten,
  output_unflatten,
)

此代码片段定义了 CausalLMOutputWithPast 对象应如何扁平化(将其内部组件转换为元组)和反扁平化(从这些组件重构)。现在,JAX 可以正确处理此类型。

但是,再次运行脚本时,您会遇到类似的错误:

Traceback (most recent call last):
  File "/home/hanq_google_com/learning_machine/jax-huggingface/script.py", line 33, in <module>
    res = func_jit(weights, (model_inputs.input_ids,))
TypeError: function jax_func at /home/hanq_google_com/pytorch/xla/torchax/torchax/__init__.py:52 traced for jit returned a value of type <class 'transformers.cache_utils.DynamicCache'> at output component [1], which is not a valid JAX type

对于 transformers.cache_utils.DynamicCache,需要相同的 Pytree 注册技巧:

from transformers import cache_utils

def _flatten_dynamic_cache(dynamic_cache):
  return (
      dynamic_cache.key_cache,
      dynamic_cache.value_cache,
  ), None

def _unflatten_dynamic_cache(aux, children):
  cache = cache_utils.DynamicCache()
  cache.key_cache, cache.value_cache = children
  return cache

register_pytree_node(
  cache_utils.DynamicCache,
  _flatten_dynamic_cache,
  _unflatten_dynamic_cache,
)

有了这些注册,我们就克服了 Pytree 类型问题。然而,又出现了另一个常见的 JAX 错误:

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[]
This occurred in the item() method of jax.Array
The error occurred while tracing the function jax_func at /home/hanq_google_com/pytorch/xla/torchax/torchax/__init__.py:52 for jit. This concrete value was not available in Python because it depends on the value of the argument kwargs['use_cache'].

See https://jax.net.cn/en/latest/errors.html#jax.errors.ConcretizationTypeError

静态参数

这个 ConcretizationTypeError 是一个经典的 JAX 问题。当您对函数进行 jax.jit 时,JAX 会追踪其执行以构建计算图。在此追踪过程中,它将所有输入视为*跟踪器*——值的符号表示——而不是它们的具体值。错误产生的原因是 if use_cache and past_key_values is None: 条件尝试读取 use_cache 的实际布尔值,而该值在追踪期间不可用。

有两种主要方法可以解决此问题:

  1. jax.jit 中使用 static_argnums 明确告诉 JAX 哪些参数是编译时常量。
  2. 使用**闭包**“嵌入”常量值。

对于本例,我们将演示闭包方法。我们将定义一个新函数,该函数封装了常量 use_cache 值,然后对该函数进行 JIT 编译:

import time
# ... (previous code including jax.tree_util imports and pytree registrations)

def func_with_constant(weights, input_ids):
  res = func(weights, (input_inputs_ids, ), {'use_cache': False}) # Pass use_cache as a fixed value
  return res

jitted_func = jax.jit(func_with_constant)
res = jitted_func(weights, model_inputs.input_ids)
print(res)

运行此更新后的脚本最终会产生预期的输出,与我们的即时执行模式结果一致:

CausalLMOutputWithPast(loss=Array([[[-12.926737  ,  -7.455758  ,  -0.42932802, ...,  -6.822556  ,
          -8.060653  ,  -7.5620213 ],
        [-13.511845  , -11.716769  ,  -6.9498663 , ...,  -9.14628   ,
         -10.245605  ,  -8.572137  ],
        [-12.842418  , -11.174898  ,  -4.0682483 , ...,  -7.9594035 ,
         -11.54412   , -10.675278  ],
        ...,
        [ -2.9683495 ,  -5.5914016 ,  11.563716  , ...,  -2.6254666 ,
          -2.8206763 ,  -1.9780521 ],
        [ -8.675585  , -10.044738  ,   3.4449315 , ...,  -3.7793014 ,
          -8.6158495 ,  -5.729558  ],
        [ -4.0751734 ,  -4.69619   ,   5.111123  , ...,   6.733637  ,
           0.57132554,   2.524692  ]]], dtype=float32), logits=None, past_key_values=None, hidden_states=None, attentions=None)

我们已成功将 PyTorch 模型转换为 JAX 函数,使其与 jax.jit 兼容,并成功执行!

JIT 编译函数的一个关键特性是其性能表现:首次运行通常由于编译而较慢,但后续运行会显著加快。让我们通过计时几次运行来验证这一点:

for i in range(3):
  start = time.time()
  res = jitted_func(weights, model_inputs.input_ids)
  jax.block_until_ready(res) # Ensure computation is complete
  end = time.time()
  print(i, end - start, 'seconds')

在 Google Cloud TPU v6e 上,结果清楚地证明了 JIT 的优势:

0 4.365400552749634 seconds
1 0.01341700553894043 seconds
2 0.013022422790527344 seconds

第一次运行花费了超过 4 秒,而后续运行在几毫秒内完成。这就是 JAX 编译的强大之处!

此示例的完整脚本可在随附存储库中的 jax_hg_01.py 中找到。


结论

本次探索表明,在 JAX 中运行 Hugging Face 的 torch.nn.Module 确实可行,尽管这需要解决一些“粗糙的边缘”。主要的挑战包括在 JAX 的 Pytree 系统中注册 Hugging Face 的自定义输出类型,以及管理 JIT 编译的静态参数。

未来,一个适配器库可以预先注册常见的 Hugging Face pytrees,并为 JAX 用户提供更流畅的集成体验。

后续步骤

我们已经奠定了基础!在下一部分中,我们将深入探讨:

  • 句子解码: 演示如何在此 JAX-PyTorch 设置中使用 model.generate 进行文本生成。
  • 张量并行: 展示如何将此解决方案扩展到多个 TPU(例如 8 个 TPU)上运行,以实现加速推理。

敬请期待!

社区

文章作者

太棒了!我很喜欢将自定义类型注册为 Pytrees 的技巧。

如果我理解正确,KV 缓存将不会在 JIT 编译的代码中使用。您是否探索过使用固定大小的 StaticCache,或者这会是您后续序列生成部分的内容?

文章作者

太棒了!我很喜欢将自定义类型注册为 Pytrees 的技巧。

如果我理解正确,KV 缓存将不会在 JIT 编译的代码中使用。您是否探索过使用固定大小的 StaticCache,或者这会是您后续序列生成部分的内容?

嗨,佩德罗,

是的,这是个好主意。我不知道有 StaticCache,我会去探索一下。谢谢!

注册登录 发表评论