Transformers 文档
Web服务器推理
并获得增强的文档体验
开始使用
Web服务器推理
Web服务器是一种等待请求并按需提供服务的系统。这意味着你可以将Pipeline用作web服务器上的推理引擎,因为你可以使用迭代器(类似于你遍历数据集的方式)来处理每个传入的请求。
然而,用Pipeline设计web服务器是独特的,因为它们本质上是不同的。Web服务器是多路复用(多线程、异步等)以并发处理多个请求的。Pipeline及其底层模型并非为并行设计,因为它们占用大量内存。最好在Pipeline运行时或进行计算密集型任务时为其提供所有可用资源。
本指南展示了如何通过使用web服务器处理接收和发送请求的较轻负载,并使用单个线程处理运行Pipeline的较重负载来解决这种差异。
创建服务器
Starlette是一个用于构建web服务器的轻量级框架。你可以使用任何其他框架,但可能需要对下面的代码进行一些更改。
在开始之前,请确保已安装Starlette和uvicorn。
!pip install starlette uvicorn
现在你可以在server.py
文件中创建一个简单的web服务器。关键是只加载一次模型,以防止不必要的副本占用内存。
创建一个pipeline来填充掩码标记[MASK]
。
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio
async def homepage(request):
payload = await request.body()
string = payload.decode("utf-8")
response_q = asyncio.Queue()
await request.app.model_queue.put((string, response_q))
output = await response_q.get()
return JSONResponse(output)
async def server_loop(q):
pipe = pipeline(task="fill-mask",model="google-bert/bert-base-uncased")
while True:
(string, response_q) = await q.get()
out = pipe(string)
await response_q.put(out)
app = Starlette(
routes=[
Route("/", homepage, methods=["POST"]),
],
)
@app.on_event("startup")
async def startup_event():
q = asyncio.Queue()
app.model_queue = q
asyncio.create_task(server_loop(q))
使用以下命令启动服务器。
uvicorn server:app
使用POST请求查询服务器。
curl -X POST -d "Paris is the [MASK] of France." https://:8000/
这将返回以下输出。
[{'score': 0.9969332218170166,
'token': 3007,
'token_str': 'capital',
'sequence': 'paris is the capital of france.'},
{'score': 0.0005914849461987615,
'token': 2540,
'token_str': 'heart',
'sequence': 'paris is the heart of france.'},
{'score': 0.00043787318281829357,
'token': 2415,
'token_str': 'center',
'sequence': 'paris is the center of france.'},
{'score': 0.0003378340043127537,
'token': 2803,
'token_str': 'centre',
'sequence': 'paris is the centre of france.'},
{'score': 0.00026995912776328623,
'token': 2103,
'token_str': 'city',
'sequence': 'paris is the city of france.'}]
请求排队
服务器的排队机制可用于一些有趣的应用程序,例如动态批处理。动态批处理首先积累多个请求,然后使用Pipeline进行处理。
下面的示例用伪代码编写,以便于阅读而非性能,特别是,你会注意到
没有批处理大小限制。
每次队列获取时都会重置超时,因此在处理请求之前,你最终可能会等待比
timeout
值长得多的时间。这也会将第一次推理请求延迟相同的时间。即使队列为空,web服务器也会始终等待1毫秒,这是低效的,因为这些时间可以用于开始推理。但如果批处理对你的用例至关重要,那可能就有意义。最好设置一个1毫秒的单一截止时间,而不是每次获取时都重置,如下所示。
async def server_loop(q):
pipe = pipeline(task="fill-mask", model="google-bert/bert-base-uncased")
while True:
(string, rq) = await q.get()
strings = []
queues = []
strings.append(string)
queues.append(rq)
while True:
try:
(string, rq) = await asyncio.wait_for(q.get(), timeout=1)
except asyncio.exceptions.TimeoutError:
break
strings.append(string)
queues.append(rq)
outs = pipe(strings, batch_size=len(strings))
for rq, out in zip(queues, outs):
await rq.put(out)
错误检查
在生产环境中可能会出现许多问题。你可能会内存不足、空间不足、模型加载失败、模型配置不正确、查询不正确等等。
添加try...except
语句有助于将这些错误返回给用户进行调试。请记住,如果你不应该泄露某些信息,这可能是一个安全风险。
熔断
当服务器过载时,请尝试返回503或504错误,而不是强制用户无限期等待。
由于只有一个队列,因此实现这些错误类型相对简单。查看队列大小以确定何时在服务器因负载而崩溃之前开始返回错误。
阻塞主线程
PyTorch 不具备异步感知能力,因此计算会阻塞主线程的运行。
因此,最好在自己的独立线程或进程上运行 PyTorch。当单个请求的推理时间特别长(超过1秒)时,这一点更为重要,因为这意味着推理期间的每个查询都必须等待1秒才能收到错误。
动态批处理
动态批处理在正确的设置下非常有效,但当你一次只传递1个请求时,它并不是必需的(更多详细信息请参阅批处理推理)。
< > 在 GitHub 上更新