Transformers 文档

使用管道进行 Web 服务器

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

使用管道进行 Web 服务器推理

创建推理引擎是一个复杂的话题,而“最佳”解决方案很可能取决于您的问题领域。您是在 CPU 上还是 GPU 上?您希望获得最低延迟、最高吞吐量、对许多模型的支持,还是只高度优化 1 个特定模型?解决此主题的方法有很多,因此我们将要介绍的是一个良好的默认起点,它不一定是您最优的解决方案。

需要理解的关键点是,我们可以使用迭代器,就像您在数据集上一样,因为 Web 服务器基本上是一个等待请求并按顺序处理它们的系统。

通常,Web 服务器是多路复用的(多线程、异步等)以并发处理各种请求。另一方面,管道(以及大多数底层模型)并不真正适合并行处理;它们占用大量 RAM,因此在运行时最好为它们提供所有可用资源,或者它是一个计算密集型作业。

我们将通过让 Web 服务器处理接收和发送请求的轻负载,并让单个线程处理实际工作来解决这个问题。此示例将使用 `starlette`。实际的框架并不重要,但如果您使用的是另一个框架,则可能需要调整或更改代码以实现相同的效果。

创建 `server.py`

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio


async def homepage(request):
    payload = await request.body()
    string = payload.decode("utf-8")
    response_q = asyncio.Queue()
    await request.app.model_queue.put((string, response_q))
    output = await response_q.get()
    return JSONResponse(output)


async def server_loop(q):
    pipe = pipeline(model="google-bert/bert-base-uncased")
    while True:
        (string, response_q) = await q.get()
        out = pipe(string)
        await response_q.put(out)


app = Starlette(
    routes=[
        Route("/", homepage, methods=["POST"]),
    ],
)


@app.on_event("startup")
async def startup_event():
    q = asyncio.Queue()
    app.model_queue = q
    asyncio.create_task(server_loop(q))

现在您可以使用以下命令启动它:

uvicorn server:app

您可以查询它:

curl -X POST -d "test [MASK]" https://127.0.0.1:8000/
#[{"score":0.7742936015129089,"token":1012,"token_str":".","sequence":"test."},...]

好了,现在你对如何创建一个 Web 服务器有了一个很好的了解!

真正重要的是我们只加载模型**一次**,这样 Web 服务器上就不会有多个模型副本。这样可以避免不必要的 RAM 使用。然后,排队机制允许你做一些花哨的事情,比如在推断之前累积一些项目以使用动态批处理。

下面的代码示例有意以伪代码的形式编写,以提高可读性。在运行此代码之前,请务必检查它是否适合你的系统资源!

(string, rq) = await q.get()
strings = []
queues = []
while True:
    try:
        (string, rq) = await asyncio.wait_for(q.get(), timeout=0.001)  # 1ms
    except asyncio.exceptions.TimeoutError:
        break
    strings.append(string)
    queues.append(rq)
strings
outs = pipe(strings, batch_size=len(strings))
for rq, out in zip(queues, outs):
    await rq.put(out)

同样,建议的代码是为了可读性而优化的,而不是为了成为最佳代码。首先,它没有批大小限制,这通常不是一个好主意。其次,超时在每次队列获取时都会重置,这意味着你可能需要等待超过 1 毫秒才能运行推理(将第一个请求延迟这么多)。

最好设置一个单一的 1 毫秒截止时间。

即使队列为空,它也会始终等待 1 毫秒,这可能不是最好的,因为你可能希望在队列中没有任何内容时就开始进行推理。但如果批处理对你用例至关重要,那么它可能是有意义的。同样,真的没有一个最好的解决方案。

一些你可能需要考虑的事情

错误检查

在生产环境中,可能会发生很多错误:内存不足、空间不足、模型加载失败、查询错误、查询正确但由于模型配置错误而无法运行等等。

通常,如果服务器将错误输出给用户,那将是很好的,因此添加大量 try..except 语句来显示这些错误是一个好主意。但请记住,根据你的安全环境,公开所有这些错误也可能存在安全风险。

断路器

Web 服务器在进行断路器操作时通常看起来更好。这意味着当它们过载时,它们会返回适当的错误,而不是无限期地等待查询。返回 503 错误而不是长时间等待,或者在长时间后返回 504 错误。

在建议的代码中,这相对容易实现,因为它只有一个队列。查看队列大小是开始在 Web 服务器在负载下出现故障之前返回错误的基本方法。

阻塞主线程

目前 PyTorch 并没有异步感知能力,计算会在运行时阻塞主线程。这意味着如果 PyTorch 被强制在它自己的线程/进程中运行会更好。这里没有这样做,因为代码会变得更加复杂(主要是因为线程、异步和队列无法很好地协同工作)。但最终它实现了相同的功能。

如果单个项目的推理时间很长(> 1 秒),这将非常重要,因为在这种情况下,这意味着在推理期间的每个查询都必须等待 1 秒才能收到错误。

动态批处理

通常,批处理不一定比一次传递一个项目更好(有关更多信息,请参阅 批处理详细信息)。但在正确的设置中,它可以非常有效。在 API 中,默认情况下没有动态批处理(太多导致速度下降的机会)。但是对于 BLOOM 推理(这是一个非常大的模型),动态批处理对于为每个人提供良好的体验**至关重要**。

< > GitHub 上的更新