宣布 Hugging Face 和 KerasHub 新集成
Hugging Face Hub 是一个庞大的存储库,目前托管着 75 万多个公共模型,为各种机器学习框架提供了多样化的预训练模型。其中,346,268 个模型(截至撰写本文时)是使用流行的 Transformers 库构建的。KerasHub 库最近新增了一个与 Hub 的集成,首批兼容 33 个模型。
在第一个版本中,KerasHub 用户*仅限于*使用 Hugging Face Hub 上可用的基于 KerasHub 的模型。
from keras_hub.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2b-keras"
)
他们能够训练/微调模型并将其上传回 Hub(请注意,该模型仍然是 Keras 模型)。
model.save_to_preset("./gemma-2b-finetune")
keras_hub.upload_preset(
"hf://username/gemma-2b-finetune",
"./gemma-2b-finetune"
)
他们错过了使用 transformers 库创建的超过 30 万个模型的庞大集合。图 1 展示了 Hub 中的 4k Gemma 模型。
![]() |
---|
图 1:Hugging Face Hub 中的 Gemma 模型(来源:https://huggingface.co/models?other=gemma) |
然而,如果现在我们告诉您,您可以使用 KerasHub 访问和使用这 30 多万个模型,这将显著扩展您的模型选择和功能,您会作何感想?
from keras_hub.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2b" # this is not a keras model!
)
我们很高兴地宣布 Hub 社区迈出了重要一步:Transformers 和 KerasHub 现在拥有**共享**的模型保存格式。这意味着 Hugging Face Hub 上的 transformers 库模型现在也可以直接加载到 KerasHub 中——立即为 KerasHub 用户提供了大量微调模型。最初,此集成侧重于启用 Gemma(1 和 2)、Llama 3 和 PaliGemma 模型的使用,并计划在不久的将来将兼容性扩展到更广泛的架构。
使用更广泛的框架
由于 KerasHub 模型可以无缝使用 **TensorFlow**、**JAX** 或 **PyTorch** 后端,这意味着大量的模型检查点现在可以通过一行代码加载到任何这些框架中。在 Hugging Face 上找到了一个很棒的检查点,但您希望将其部署到 TFLite 进行服务或将其移植到 JAX 进行研究?现在您可以了!
如何使用
使用此集成需要更新您的 Keras 版本
$ pip install -U -q keras-hub
$ pip install -U keras>=3.3.3
更新后,尝试集成就像以下代码一样简单:
from keras_hub.models import Llama3CausalLM
# this model was not fine-tuned with Keras but can still be loaded
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
causal_lm.summary()
幕后:工作原理
Transformers 模型以 JSON 格式的配置文件的形式存储,一个分词器(通常也是一个 .JSON 文件),以及一组 safetensors 权重文件。实际的模型代码包含在 Transformers 库本身中。这意味着,只要两个库都有相关架构的模型代码,将 Transformers 检查点交叉加载到 KerasHub 中就相对简单。我们所需要做的就是将配置变量、权重名称和分词器词汇从一种格式映射到另一种格式,然后我们就可以从 Transformers 检查点创建 KerasHub 检查点,反之亦然。
所有这些都在内部为您处理,因此您可以专注于尝试模型,而不是转换它们!
常见用例
生成
语言模型的第一个用例是生成文本。下面是一个示例,演示如何加载 transformer 模型并使用 KerasHub 的 .generate
方法生成新 token。
from keras_hub.models import Llama3CausalLM
# Get the model
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
prompts = [
"""<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]
# Generate from the model
causal_lm.generate(prompts, max_length=200)[0]
更改精度
您可以使用 keras.config
更改模型的精度,如下所示
import keras
keras.config.set_dtype_policy("bfloat16")
from keras_hub.models import Llama3CausalLM
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
在 JAX 后端使用检查点
要使用 JAX 试用模型,您可以利用 Keras 在 JAX 后端运行模型。这可以通过简单地将 Keras 的后端切换到 JAX 来实现。以下是您在 JAX 环境中使用模型的方法。
import os
os.environ["KERAS_BACKEND"] = "jax"
from keras_hub.models import Llama3CausalLM
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
Gemma 2
我们很高兴地通知您,Gemma 2 模型也与此集成兼容。
from keras_hub.models import GemmaCausalLM
causal_lm = keras_hub.models.GemmaCausalLM.from_preset(
"hf://google/gemma-2-9b" # This is Gemma 2!
)
PaliGemma
您还可以在 KerasHub 管道中使用任何 PaliGemma safetensor 检查点。
from keras_hub.models import PaliGemmaCausalLM
pali_gemma_lm = PaliGemmaCausalLM.from_preset(
"hf://gokaygokay/sd3-long-captioner" # A finetuned version of PaliGemma
)
接下来呢?
这仅仅是个开始。我们设想将此集成扩展到更广泛的 Hugging Face 模型和架构。请继续关注更新,并务必探索此次合作带来的巨大潜力!
我想借此机会感谢 Matthew Carrigan 和 Matthew Watson 在整个过程中给予的帮助。