Transformers 文档

Web服务器推理

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

Web服务器推理

Web服务器是一种等待请求并按需提供服务的系统。这意味着你可以将Pipeline用作web服务器上的推理引擎,因为你可以使用迭代器(类似于你遍历数据集的方式)来处理每个传入的请求。

然而,用Pipeline设计web服务器是独特的,因为它们本质上是不同的。Web服务器是多路复用(多线程、异步等)以并发处理多个请求的。Pipeline及其底层模型并非为并行设计,因为它们占用大量内存。最好在Pipeline运行时或进行计算密集型任务时为其提供所有可用资源。

本指南展示了如何通过使用web服务器处理接收和发送请求的较轻负载,并使用单个线程处理运行Pipeline的较重负载来解决这种差异。

创建服务器

Starlette是一个用于构建web服务器的轻量级框架。你可以使用任何其他框架,但可能需要对下面的代码进行一些更改。

在开始之前,请确保已安装Starlette和uvicorn

!pip install starlette uvicorn

现在你可以在server.py文件中创建一个简单的web服务器。关键是只加载一次模型,以防止不必要的副本占用内存。

创建一个pipeline来填充掩码标记[MASK]

from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio

async def homepage(request):
    payload = await request.body()
    string = payload.decode("utf-8")
    response_q = asyncio.Queue()
    await request.app.model_queue.put((string, response_q))
    output = await response_q.get()
    return JSONResponse(output)

async def server_loop(q):
    pipe = pipeline(task="fill-mask",model="google-bert/bert-base-uncased")
    while True:
        (string, response_q) = await q.get()
        out = pipe(string)
        await response_q.put(out)

app = Starlette(
    routes=[
        Route("/", homepage, methods=["POST"]),
    ],
)

@app.on_event("startup")
async def startup_event():
    q = asyncio.Queue()
    app.model_queue = q
    asyncio.create_task(server_loop(q))

使用以下命令启动服务器。

uvicorn server:app

使用POST请求查询服务器。

curl -X POST -d "Paris is the [MASK] of France." https://:8000/

这将返回以下输出。

[{'score': 0.9969332218170166,
  'token': 3007,
  'token_str': 'capital',
  'sequence': 'paris is the capital of france.'},
 {'score': 0.0005914849461987615,
  'token': 2540,
  'token_str': 'heart',
  'sequence': 'paris is the heart of france.'},
 {'score': 0.00043787318281829357,
  'token': 2415,
  'token_str': 'center',
  'sequence': 'paris is the center of france.'},
 {'score': 0.0003378340043127537,
  'token': 2803,
  'token_str': 'centre',
  'sequence': 'paris is the centre of france.'},
 {'score': 0.00026995912776328623,
  'token': 2103,
  'token_str': 'city',
  'sequence': 'paris is the city of france.'}]

请求排队

服务器的排队机制可用于一些有趣的应用程序,例如动态批处理。动态批处理首先积累多个请求,然后使用Pipeline进行处理。

下面的示例用伪代码编写,以便于阅读而非性能,特别是,你会注意到

  1. 没有批处理大小限制。

  2. 每次队列获取时都会重置超时,因此在处理请求之前,你最终可能会等待比timeout值长得多的时间。这也会将第一次推理请求延迟相同的时间。即使队列为空,web服务器也会始终等待1毫秒,这是低效的,因为这些时间可以用于开始推理。但如果批处理对你的用例至关重要,那可能就有意义。

    最好设置一个1毫秒的单一截止时间,而不是每次获取时都重置,如下所示。

async def server_loop(q):
    pipe = pipeline(task="fill-mask", model="google-bert/bert-base-uncased")
    while True:
        (string, rq) = await q.get()
        strings = []
        queues = []
        strings.append(string)
        queues.append(rq)
        while True:
            try:
                (string, rq) = await asyncio.wait_for(q.get(), timeout=1)
            except asyncio.exceptions.TimeoutError:
                break
            strings.append(string)
            queues.append(rq)
        outs = pipe(strings, batch_size=len(strings))
        for rq, out in zip(queues, outs):
            await rq.put(out)

错误检查

在生产环境中可能会出现许多问题。你可能会内存不足、空间不足、模型加载失败、模型配置不正确、查询不正确等等。

添加try...except语句有助于将这些错误返回给用户进行调试。请记住,如果你不应该泄露某些信息,这可能是一个安全风险。

熔断

当服务器过载时,请尝试返回503或504错误,而不是强制用户无限期等待。

由于只有一个队列,因此实现这些错误类型相对简单。查看队列大小以确定何时在服务器因负载而崩溃之前开始返回错误。

阻塞主线程

PyTorch 不具备异步感知能力,因此计算会阻塞主线程的运行。

因此,最好在自己的独立线程或进程上运行 PyTorch。当单个请求的推理时间特别长(超过1秒)时,这一点更为重要,因为这意味着推理期间的每个查询都必须等待1秒才能收到错误。

动态批处理

动态批处理在正确的设置下非常有效,但当你一次只传递1个请求时,它并不是必需的(更多详细信息请参阅批处理推理)。

< > 在 GitHub 上更新