如何在JAX中运行Hugging Face模型(第一部分)
(原始内容位于此处:https://github.com/qihqi/learning_machine/blob/main/jax-huggingface/01-run-huggingface-model-in-jax.md)
Hugging Face 最近从其 transformers
库中移除了对 JAX 和 TensorFlow 的原生支持,旨在精简其代码库。这一决定让许多 JAX 用户疑惑,如何在不重新实现所有内容的情况下,继续利用 Hugging Face 庞大的模型集合。
这篇博客文章探讨了一个解决方案:使用 JAX 输入运行基于 PyTorch 的 Hugging Face 模型。这种方法为依赖 Hugging Face 模型的 JAX 用户提供了一条宝贵的“出路”。
背景与方法
作为 torchax(一个旨在实现 JAX 和 PyTorch 无缝互操作性的新兴库)的作者,这次探索将是 torchax
的一次绝佳压力测试。让我们深入了解!
设置
我们将从标准的 Hugging Face 快速启动设置开始。如果您尚未设置环境
# Create venv / conda env; activate etc.
pip install huggingface-cli
huggingface-cli login # Set up your Hugging Face token
pip install -U transformers datasets evaluate accelerate timm flax
接下来,直接从最新开发版本安装 torchax
pip install torchax
pip install jax[tpu] # or jax[cuda12] if you are on GPU
首次尝试:即时执行模式
我们首先实例化一个模型和分词器。我们将创建一个名为 jax_hg_01.py
的脚本,其中包含以下代码:
from transformers import AutoModelForCausalLM, AutoTokenizer
import jax # Import jax here for later use
# Load a PyTorch model and tokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype="bfloat16", device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# Tokenize input, requesting JAX arrays
model_inputs = tokenizer(["The secret to baking a good cake is "], return_tensors="jax")
print(model_inputs)
请注意分词器调用中关键的 return_tensors="jax"
。这指示 Hugging Face 直接返回 JAX 数组,这对于我们使用 JAX 输入和 PyTorch 模型的目标至关重要。运行上述脚本将输出:
{'input_ids': Array([[ 1, 450, 7035, 304, 289, 5086, 263, 1781, 274,
1296, 338, 29871]], dtype=int32), 'attention_mask': Array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)}
现在,让我们使用 torchax
将此 PyTorch 模型转换为 JAX 可调用对象。修改您的脚本如下:
import torchax
# ... (previous code)
weights, func = torchax.extract_jax(model)
torchax.extract_jax
函数将模型的 forward
方法转换为 JAX 兼容的可调用对象。它还会将模型的权重作为 JAX 数组的 Pytree 返回(这本质上是转换为 JAX 数组的 model.state_dict()
)。
有了 func
和 weights
,我们现在可以调用这个 JAX 函数了。约定是首先传递 weights
作为第一个参数,然后是位置参数的元组 (args
),最后是可选的关键字参数字典 (kwargs
)。
让我们将调用添加到脚本中:
# ... (previous code)
print(func(weights, (model_inputs.input_ids, )))
执行此操作将产生以下输出,演示了成功的即时执行模式:
In [2]: import torchax
In [3]: weights, func = torchax.extract_jax(model)
WARNING:root:Duplicate op registration for aten.__and__
In [4]: print(func(weights, (model_inputs.input_ids, )))
CausalLMOutputWithPast(loss=None, logits=Array([[[-12.950611 , -7.4854484 , -0.42371067, ..., -6.819363 ,
-8.073828 , -7.5583534 ],
[-13.508438 , -11.716616 , -6.9578876 , ..., -9.135823 ,
-10.237023 , -8.56888 ],
[-12.8517685 , -11.180469 , -4.0543456 , ..., -7.9564795 ,
-11.546011 , -10.686134 ],
...,
[ -2.983235 , -5.621302 , 11.553352 , ..., -2.6286669 ,
-2.8319468 , -1.9902805 ],
[ -8.674949 , -10.042385 , 3.4400458 , ..., -3.7776647 ,
-8.616567 , -5.7228904 ],
[ -4.0748825 , -4.706395 , 5.117742 , ..., 6.7174563 ,
0.5748794 , 2.506649 ]]], dtype=float32), past_key_values=DynamicCache(), hidden_states=None, attentions=None)
要将关键字参数 (kwargs) 传递给函数,只需将其作为第三个参数添加:
print(func(weights, (model_inputs.input_ids, ), {'use_cache': False}))
虽然这展示了基本功能,但 JAX 的真正强大之处在于其**JIT 编译**。即时 (JIT) 编译可以显著加速计算,尤其是在 GPU 和 TPU 等加速器上。因此,我们的下一个逻辑步骤是对函数进行 jax.jit
。
Jitting - 摆弄 Pytrees
在 JAX 中,JIT 编译就像用 jax.jit
包装函数一样简单。让我们试试:
import jax
# ... (previous code)
func_jit = jax.jit(func)
res = func_jit(weights, (model_inputs.input_ids,))
运行此代码可能会导致 TypeError
:
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/hanq_google_com/learning_machine/jax-huggingface/script.py", line 18, in <module>
res = func_jit(weights, (model_inputs.input_ids,))
TypeError: function jax_func at /home/hanq_google_com/pytorch/xla/torchax/torchax/__init__.py:52 traced for jit returned a value of type <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>, which is not a valid JAX type
错误信息表明 JAX 无法理解 CausalLMOutputWithPast
类型。当您对函数进行 jax.jit
时,JAX 要求所有输入和输出都是“JAX 类型”——这意味着它们可以使用 jax.tree.flatten
扁平化为 JAX 可理解的元素列表。
为了解决这个问题,我们需要在 **JAX 的 Pytree 系统**中注册这些自定义类型。Pytrees 是嵌套数据结构(如元组、列表和字典),JAX 可以遍历并对其应用转换。通过注册自定义类型,我们告诉 JAX 如何将其分解为组成部分(子节点)并进行重构。
将以下内容添加到您的脚本中:
from jax.tree_util import register_pytree_node
from transformers import modeling_outputs
def output_flatten(v):
return v.to_tuple(), None
def output_unflatten(aux, children):
return modeling_outputs.CausalLMOutputWithPast(*children)
register_pytree_node(
modeling_outputs.CausalLMOutputWithPast,
output_flatten,
output_unflatten,
)
此代码片段定义了 CausalLMOutputWithPast
对象应如何扁平化(将其内部组件转换为元组)和反扁平化(从这些组件重构)。现在,JAX 可以正确处理此类型。
但是,再次运行脚本时,您会遇到类似的错误:
Traceback (most recent call last):
File "/home/hanq_google_com/learning_machine/jax-huggingface/script.py", line 33, in <module>
res = func_jit(weights, (model_inputs.input_ids,))
TypeError: function jax_func at /home/hanq_google_com/pytorch/xla/torchax/torchax/__init__.py:52 traced for jit returned a value of type <class 'transformers.cache_utils.DynamicCache'> at output component [1], which is not a valid JAX type
对于 transformers.cache_utils.DynamicCache
,需要相同的 Pytree 注册技巧:
from transformers import cache_utils
def _flatten_dynamic_cache(dynamic_cache):
return (
dynamic_cache.key_cache,
dynamic_cache.value_cache,
), None
def _unflatten_dynamic_cache(aux, children):
cache = cache_utils.DynamicCache()
cache.key_cache, cache.value_cache = children
return cache
register_pytree_node(
cache_utils.DynamicCache,
_flatten_dynamic_cache,
_unflatten_dynamic_cache,
)
有了这些注册,我们就克服了 Pytree 类型问题。然而,又出现了另一个常见的 JAX 错误:
jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape bool[]
This occurred in the item() method of jax.Array
The error occurred while tracing the function jax_func at /home/hanq_google_com/pytorch/xla/torchax/torchax/__init__.py:52 for jit. This concrete value was not available in Python because it depends on the value of the argument kwargs['use_cache'].
See https://jax.net.cn/en/latest/errors.html#jax.errors.ConcretizationTypeError
静态参数
这个 ConcretizationTypeError
是一个经典的 JAX 问题。当您对函数进行 jax.jit
时,JAX 会追踪其执行以构建计算图。在此追踪过程中,它将所有输入视为*跟踪器*——值的符号表示——而不是它们的具体值。错误产生的原因是 if use_cache and past_key_values is None:
条件尝试读取 use_cache
的实际布尔值,而该值在追踪期间不可用。
有两种主要方法可以解决此问题:
- 在
jax.jit
中使用static_argnums
明确告诉 JAX 哪些参数是编译时常量。 - 使用**闭包**“嵌入”常量值。
对于本例,我们将演示闭包方法。我们将定义一个新函数,该函数封装了常量 use_cache
值,然后对该函数进行 JIT 编译:
import time
# ... (previous code including jax.tree_util imports and pytree registrations)
def func_with_constant(weights, input_ids):
res = func(weights, (input_inputs_ids, ), {'use_cache': False}) # Pass use_cache as a fixed value
return res
jitted_func = jax.jit(func_with_constant)
res = jitted_func(weights, model_inputs.input_ids)
print(res)
运行此更新后的脚本最终会产生预期的输出,与我们的即时执行模式结果一致:
CausalLMOutputWithPast(loss=Array([[[-12.926737 , -7.455758 , -0.42932802, ..., -6.822556 ,
-8.060653 , -7.5620213 ],
[-13.511845 , -11.716769 , -6.9498663 , ..., -9.14628 ,
-10.245605 , -8.572137 ],
[-12.842418 , -11.174898 , -4.0682483 , ..., -7.9594035 ,
-11.54412 , -10.675278 ],
...,
[ -2.9683495 , -5.5914016 , 11.563716 , ..., -2.6254666 ,
-2.8206763 , -1.9780521 ],
[ -8.675585 , -10.044738 , 3.4449315 , ..., -3.7793014 ,
-8.6158495 , -5.729558 ],
[ -4.0751734 , -4.69619 , 5.111123 , ..., 6.733637 ,
0.57132554, 2.524692 ]]], dtype=float32), logits=None, past_key_values=None, hidden_states=None, attentions=None)
我们已成功将 PyTorch 模型转换为 JAX 函数,使其与 jax.jit
兼容,并成功执行!
JIT 编译函数的一个关键特性是其性能表现:首次运行通常由于编译而较慢,但后续运行会显著加快。让我们通过计时几次运行来验证这一点:
for i in range(3):
start = time.time()
res = jitted_func(weights, model_inputs.input_ids)
jax.block_until_ready(res) # Ensure computation is complete
end = time.time()
print(i, end - start, 'seconds')
在 Google Cloud TPU v6e 上,结果清楚地证明了 JIT 的优势:
0 4.365400552749634 seconds
1 0.01341700553894043 seconds
2 0.013022422790527344 seconds
第一次运行花费了超过 4 秒,而后续运行在几毫秒内完成。这就是 JAX 编译的强大之处!
此示例的完整脚本可在随附存储库中的 jax_hg_01.py
中找到。
结论
本次探索表明,在 JAX 中运行 Hugging Face 的 torch.nn.Module
确实可行,尽管这需要解决一些“粗糙的边缘”。主要的挑战包括在 JAX 的 Pytree 系统中注册 Hugging Face 的自定义输出类型,以及管理 JIT 编译的静态参数。
未来,一个适配器库可以预先注册常见的 Hugging Face pytrees,并为 JAX 用户提供更流畅的集成体验。
后续步骤
我们已经奠定了基础!在下一部分中,我们将深入探讨:
- 句子解码: 演示如何在此 JAX-PyTorch 设置中使用
model.generate
进行文本生成。 - 张量并行: 展示如何将此解决方案扩展到多个 TPU(例如 8 个 TPU)上运行,以实现加速推理。
敬请期待!