在服务器上运行推理
推理是使用训练好的模型对新数据进行预测的过程。由于此过程可能计算密集型,因此在专用服务器上运行可能是一个不错的选择。该 huggingface_hub
库提供了一种简单的方法来调用运行托管模型推理的服务。您可以连接到多种服务
- 推理 API:一项允许您在 Hugging Face 的基础设施上免费运行加速推理的服务。此服务是快速入门、测试不同模型和原型设计 AI 产品的快速方法。
- 推理端点:一个轻松将模型部署到生产环境的产品。推理由 Hugging Face 在您选择的云提供商的专用、完全托管的基础设施上运行。
可以使用 InferenceClient 对象调用这些服务。它充当旧版 InferenceApi 客户端的替代品,增加了对任务的特定支持,并在 推理 API 和 推理端点 上处理推理。了解如何在 旧版 InferenceAPI 客户端 部分中迁移到新客户端。
InferenceClient 是一个 Python 客户端,它对我们的 API 发出 HTTP 调用。如果您想使用您首选的工具(curl、postman 等)直接发出 HTTP 调用,请参阅 推理 API 或 推理端点 文档页面。
开始使用
让我们从一个文本到图像的任务开始
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> image = client.text_to_image("An astronaut riding a horse on the moon.")
>>> image.save("astronaut.png") # 'image' is a PIL.Image object
在上面的示例中,我们使用默认参数初始化了一个InferenceClient。您只需要知道要执行的任务。默认情况下,客户端将连接到推理 API 并选择一个模型来完成任务。在我们的示例中,我们根据文本提示生成了一个图像。返回值是一个PIL.Image
对象,可以保存到文件中。有关更多详细信息,请查看text_to_image()文档。
现在让我们看一个使用[~InferenceClient.chat_completion
] API 的示例。此任务使用 LLM 根据消息列表生成回复
>>> from huggingface_hub import InferenceClient
>>> messages = [{"role": "user", "content": "What is the capital of France?"}]
>>> client = InferenceClient("meta-llama/Meta-Llama-3-8B-Instruct")
>>> client.chat_completion(messages, max_tokens=100)
ChatCompletionOutput(
choices=[
ChatCompletionOutputComplete(
finish_reason='eos_token',
index=0,
message=ChatCompletionOutputMessage(
role='assistant',
content='The capital of France is Paris.',
name=None,
tool_calls=None
),
logprobs=None
)
],
created=1719907176,
id='',
model='meta-llama/Meta-Llama-3-8B-Instruct',
object='text_completion',
system_fingerprint='2.0.4-sha-f426a33',
usage=ChatCompletionOutputUsage(
completion_tokens=8,
prompt_tokens=17,
total_tokens=25
)
)
在此示例中,我们指定了要使用的模型("meta-llama/Meta-Llama-3-8B-Instruct"
)。您可以在此页面上找到兼容模型的列表。然后,我们提供了一个要完成的消息列表(此处为单个问题)并向 API 传递了一个额外的参数(max_token=100
)。输出是一个遵循 OpenAI 规范的ChatCompletionOutput
对象。生成的內容可以通过output.choices[0].message.content
访问。有关更多详细信息,请查看chat_completion()文档。
API 的设计非常简单。并非所有参数和选项都可供最终用户使用或描述。如果您有兴趣了解每个任务可用的所有参数,请查看此页面。
使用特定模型
如果您想使用特定模型怎么办?您可以将其作为参数指定,也可以在实例级别直接指定。
>>> from huggingface_hub import InferenceClient
# Initialize client for a specific model
>>> client = InferenceClient(model="prompthero/openjourney-v4")
>>> client.text_to_image(...)
# Or use a generic client but pass your model as an argument
>>> client = InferenceClient()
>>> client.text_to_image(..., model="prompthero/openjourney-v4")
Hugging Face Hub 上有超过 200k 个模型!InferenceClient 中的每个任务都带有一个推荐模型。请注意,HF 推荐可能会随时更改,恕不另行通知。因此,最好在您做出决定后明确设置模型。此外,在大多数情况下,您会感兴趣的是找到满足您特定需求的模型。访问 Hub 上的模型页面以探索您的可能性。
使用特定 URL
我们上面看到的示例使用了无服务器推理 API。这对于快速原型设计和测试非常有用。一旦您准备好将模型部署到生产环境中,您将需要使用专用的基础设施。这就是推理端点发挥作用的地方。它允许您部署任何模型并将其公开为私有 API。部署后,您将获得一个 URL,您可以使用与之前完全相同的代码连接到该 URL,只需更改model
参数即可
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient(model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if")
# or
>>> client = InferenceClient()
>>> client.text_to_image(..., model="https://uu149rez6gw9ehej.eu-west-1.aws.endpoints.huggingface.cloud/deepfloyd-if")
身份验证
使用InferenceClient 进行的调用可以使用用户访问令牌进行身份验证。默认情况下,如果您已登录,它将使用保存在您计算机上的令牌(请查看如何进行身份验证)。如果您未登录,则可以将您的令牌作为实例参数传递。
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient(token="hf_***")
使用推理 API 时,身份验证不是强制性的。但是,经过身份验证的用户可以获得更高的免费套餐来使用该服务。如果您想对私有模型或私有端点运行推理,则令牌也是强制性的。
OpenAI 兼容性
chat_completion
任务遵循OpenAI 的 Python 客户端语法。这对您意味着什么?这意味着如果您习惯于使用OpenAI
的 API,则只需更新两行代码即可切换到huggingface_hub.InferenceClient
来使用开源模型!
- from openai import OpenAI
+ from huggingface_hub import InferenceClient
- client = OpenAI(
+ client = InferenceClient(
base_url=...,
api_key=...,
)
output = client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Count to 10"},
],
stream=True,
max_tokens=1024,
)
for chunk in output:
print(chunk.choices[0].delta.content)
就是这样!唯一需要的更改是将from openai import OpenAI
替换为from huggingface_hub import InferenceClient
,并将client = OpenAI(...)
替换为client = InferenceClient(...)
。您可以通过将模型 ID 作为model
参数传递来从 Hugging Face Hub 中选择任何 LLM 模型。这是一个受支持模型的列表。对于身份验证,您应该将有效的用户访问令牌作为api_key
传递,或使用huggingface_hub
进行身份验证(请参阅身份验证指南)。
所有输入参数和输出格式都严格相同。特别是,您可以传递stream=True
以在生成时接收令牌。您还可以使用AsyncInferenceClient使用asyncio
运行推理。
import asyncio
- from openai import AsyncOpenAI
+ from huggingface_hub import AsyncInferenceClient
- client = AsyncOpenAI()
+ client = AsyncInferenceClient()
async def main():
stream = await client.chat.completions.create(
model="meta-llama/Meta-Llama-3-8B-Instruct",
messages=[{"role": "user", "content": "Say this is a test"}],
stream=True,
)
async for chunk in stream:
print(chunk.choices[0].delta.content or "", end="")
asyncio.run(main())
您可能想知道为什么使用InferenceClient而不是 OpenAI 的客户端?这有几个原因
- InferenceClient配置用于 Hugging Face 服务。您无需提供
base_url
即可在无服务器推理 API 上运行模型。如果您的计算机已正确登录,您也不需要提供token
或api_key
。 - InferenceClient专为文本生成推理 (TGI) 和
transformers
框架而设计,这意味着您可以确保它始终与最新更新保持一致。 - InferenceClient与我们的推理端点服务集成,使启动推理端点、检查其状态并在其上运行推理变得更加容易。有关更多详细信息,请查看推理端点指南。
InferenceClient.chat.completions.create
只是InferenceClient.chat_completion
的别名。有关更多详细信息,请查看chat_completion()的包引用。在实例化客户端时,base_url
和api_key
参数也是model
和token
的别名。这些别名已定义为减少从OpenAI
切换到InferenceClient
时的摩擦。
支持的任务
InferenceClient 的目标是提供最简单的接口来在 Hugging Face 模型上运行推理。它拥有一个简单的 API,支持最常见的任务。以下是当前支持的任务列表
查看任务页面,以了解有关每个任务、如何使用它们以及每个任务最流行的模型的更多信息。
自定义请求
但是,并非总是能够涵盖所有用例。对于自定义请求,InferenceClient.post() 方法为您提供了灵活性,可以向推理 API 发送任何请求。例如,您可以指定如何解析输入和输出。在下面的示例中,生成的图像作为原始字节返回,而不是将其解析为PIL Image
。如果您在设置中没有安装Pillow
并且只关心图像的二进制内容,这将非常有用。InferenceClient.post() 也可用于处理尚未正式支持的任务。
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> response = client.post(json={"inputs": "An astronaut riding a horse on the moon."}, model="stabilityai/stable-diffusion-2-1")
>>> response.content # raw bytes
b'...'
异步客户端
还提供了基于asyncio
和aiohttp
的客户端的异步版本。您可以直接安装aiohttp
或使用[inference]
额外功能
pip install --upgrade huggingface_hub[inference]
# or
# pip install aiohttp
安装后,所有异步 API 端点都可通过AsyncInferenceClient访问。它的初始化和 API 与仅同步版本完全相同。
# Code must be run in a asyncio concurrent context.
# $ python -m asyncio
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()
>>> image = await client.text_to_image("An astronaut riding a horse on the moon.")
>>> image.save("astronaut.png")
>>> async for token in await client.text_generation("The Huggingface Hub is", stream=True):
... print(token, end="")
a platform for sharing and discussing ML-related content.
有关asyncio
模块的更多信息,请参阅官方文档。
高级提示
在上一节中,我们看到了InferenceClient的主要方面。让我们深入了解一些更高级的提示。
超时
在进行推理时,超时主要有两个原因
- 推理过程需要很长时间才能完成。
- 模型不可用,例如当推理 API 第一次加载它时。
InferenceClient 有一个全局timeout
参数来处理这两个方面。默认情况下,它设置为None
,这意味着客户端将无限期地等待推理完成。如果您希望在工作流程中获得更多控制,可以将其设置为以秒为单位的特定值。如果超时延迟到期,则会引发InferenceTimeoutError。您可以捕获它并在代码中进行处理
>>> from huggingface_hub import InferenceClient, InferenceTimeoutError
>>> client = InferenceClient(timeout=30)
>>> try:
... client.text_to_image(...)
... except InferenceTimeoutError:
... print("Inference timed out after 30s.")
二进制输入
某些任务需要二进制输入,例如,在处理图像或音频文件时。InferenceClient 尝试尽可能地宽容并接受不同的类型
- 原始
bytes
- 文件类对象,以二进制方式打开(
with open("audio.flac", "rb") as f: ...
) - 指向本地文件的路径(
str
或Path
) - 指向远程文件的 URL(
str
)(例如https://...
)。在这种情况下,文件将在发送到推理 API 之前在本地下载。
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> client.image_classification("https://upload.wikimedia.org/wikipedia/commons/thumb/4/43/Cute_dog.jpg/320px-Cute_dog.jpg")
[{'score': 0.9779096841812134, 'label': 'Blenheim spaniel'}, ...]
旧版 InferenceAPI 客户端
InferenceClient 充当旧版InferenceApi客户端的替代品。它为任务添加了特定支持,并处理推理 API和推理端点上的推理。
这是一个简短指南,可帮助您从InferenceApi迁移到InferenceClient。
初始化
从
>>> from huggingface_hub import InferenceApi
>>> inference = InferenceApi(repo_id="bert-base-uncased", token=API_TOKEN)
到
>>> from huggingface_hub import InferenceClient
>>> inference = InferenceClient(model="bert-base-uncased", token=API_TOKEN)
在特定任务上运行
从
>>> from huggingface_hub import InferenceApi
>>> inference = InferenceApi(repo_id="paraphrase-xlm-r-multilingual-v1", task="feature-extraction")
>>> inference(...)
到
>>> from huggingface_hub import InferenceClient
>>> inference = InferenceClient()
>>> inference.feature_extraction(..., model="paraphrase-xlm-r-multilingual-v1")
这是将您的代码适配到InferenceClient的推荐方法。它允许您受益于特定于任务的方法,例如feature_extraction
。
运行自定义请求
从
>>> from huggingface_hub import InferenceApi
>>> inference = InferenceApi(repo_id="bert-base-uncased")
>>> inference(inputs="The goal of life is [MASK].")
[{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
到
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> response = client.post(json={"inputs": "The goal of life is [MASK]."}, model="bert-base-uncased")
>>> response.json()
[{'sequence': 'the goal of life is life.', 'score': 0.10933292657136917, 'token': 2166, 'token_str': 'life'}]
使用参数运行
从
>>> from huggingface_hub import InferenceApi
>>> inference = InferenceApi(repo_id="typeform/distilbert-base-uncased-mnli")
>>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
>>> params = {"candidate_labels":["refund", "legal", "faq"]}
>>> inference(inputs, params)
{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}
到
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()
>>> inputs = "Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!"
>>> params = {"candidate_labels":["refund", "legal", "faq"]}
>>> response = client.post(json={"inputs": inputs, "parameters": params}, model="typeform/distilbert-base-uncased-mnli")
>>> response.json()
{'sequence': 'Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!', 'labels': ['refund', 'faq', 'legal'], 'scores': [0.9378499388694763, 0.04914155602455139, 0.013008488342165947]}