Transformers 文档

Web 服务器推理

Hugging Face's logo
加入 Hugging Face 社区

并获取增强的文档体验

开始使用

Web 服务器推理

Web 服务器是一个系统,它等待请求并在请求到达时提供服务。这意味着你可以在 Web 服务器上使用 Pipeline 作为推理引擎,因为你可以使用迭代器(类似于你如何 迭代数据集)来处理每个传入的请求。

但是,使用 Pipeline 设计 Web 服务器是独特的,因为它们从根本上是不同的。Web 服务器是多路复用的(多线程、异步等),以并发处理多个请求。另一方面,Pipeline 及其底层模型并非为并行性而设计,因为它们占用大量内存。最好在 Pipeline 运行时或执行计算密集型作业时为其提供所有可用资源。

本指南展示了如何通过使用 Web 服务器来处理接收和发送请求的较轻负载,并使用单线程来处理运行 Pipeline 的较重负载,从而解决这种差异。

创建服务器

Starlette 是一个用于构建 Web 服务器的轻量级框架。你可以使用任何你喜欢的其他框架,但你可能需要对下面的代码进行一些更改。

在开始之前,请确保已安装 Starlette 和 uvicorn

!pip install starlette uvicorn

现在你可以在 server.py 文件中创建一个简单的 Web 服务器。关键是仅加载模型一次,以防止不必要的模型副本消耗内存。

创建一个 pipeline 来填充掩码标记 [MASK]

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio

async def homepage(request):
    payload = await request.body()
    string = payload.decode("utf-8")
    response_q = asyncio.Queue()
    await request.app.model_queue.put((string, response_q))
    output = await response_q.get()
    return JSONResponse(output)

async def server_loop(q):
    pipeline = pipeline(task="fill-mask",model="google-bert/bert-base-uncased")
    while True:
        (string, response_q) = await q.get()
        out = pipeline(string)
        await response_q.put(out)

app = Starlette(
    routes=[
        Route("/", homepage, methods=["POST"]),
    ],
)

@app.on_event("startup")
async def startup_event():
    q = asyncio.Queue()
    app.model_queue = q
    asyncio.create_task(server_loop(q))

使用以下命令启动服务器。

uvicorn server:app

使用 POST 请求查询服务器。

curl -X POST -d "Paris is the [MASK] of France." http://localhost:8000/
[{'score': 0.9969332218170166,
  'token': 3007,
  'token_str': 'capital',
  'sequence': 'paris is the capital of france.'},
 {'score': 0.0005914849461987615,
  'token': 2540,
  'token_str': 'heart',
  'sequence': 'paris is the heart of france.'},
 {'score': 0.00043787318281829357,
  'token': 2415,
  'token_str': 'center',
  'sequence': 'paris is the center of france.'},
 {'score': 0.0003378340043127537,
  'token': 2803,
  'token_str': 'centre',
  'sequence': 'paris is the centre of france.'},
 {'score': 0.00026995912776328623,
  'token': 2103,
  'token_str': 'city',
  'sequence': 'paris is the city of france.'}]

请求排队

服务器的排队机制可以用于一些有趣的应用,例如动态批处理。动态批处理首先累积多个请求,然后再使用 Pipeline 处理它们。

下面的示例以伪代码编写,以提高可读性而不是性能,特别是,你会注意到

  1. 没有批量大小限制。

  2. 超时在每次队列获取时都会重置,因此你最终可能会等待比 timeout 值长得多的时间才处理请求。 这也会将第一个推理请求延迟该时间量。即使队列为空,Web 服务器始终等待 1 毫秒,这是效率低下的,因为这段时间可以用来启动推理。但是,如果批处理对你的用例至关重要,这可能是有意义的。

    最好只有一个 1 毫秒的截止时间,而不是在每次获取时都重置它。

(string, rq) = await q.get()
strings = []
queues = []
while True:
    try:
        (string, rq) = await asyncio.wait_for(q.get(), timeout=0.001)
    except asyncio.exceptions.TimeoutError:
        break
    strings.append(string)
    queues.append(rq)
strings
outs = pipeline(strings, batch_size=len(strings))
for rq, out in zip(queues, outs):
    await rq.put(out)

错误检查

在生产环境中,很多事情都可能出错。你可能会耗尽内存、空间不足、无法加载模型、模型配置不正确、查询不正确等等。

添加 try...except 语句有助于将这些错误返回给用户以进行调试。请记住,如果你不应该泄露某些信息,这可能存在安全风险。

熔断

当服务器过载时,尝试返回 503 或 504 错误,而不是强迫用户无限期等待。

实现这些错误类型相对简单,因为它只是一个队列。查看队列大小以确定何时开始返回错误,以避免服务器在负载下崩溃。

阻塞主线程

PyTorch 不支持异步感知,因此计算将阻塞主线程的运行。

因此,最好在自己单独的线程或进程中运行 PyTorch。当单个请求的推理时间特别长(超过 1 秒)时,这一点更为重要,因为这意味着推理期间的每个查询都必须等待 1 秒才能收到错误。

动态批处理

动态批处理在正确的设置中使用时可能非常有效,但当你一次只传递 1 个请求时,它不是必需的(有关更多详细信息,请参阅批量推理)。

< > 在 GitHub 上更新