使用管道进行 Web 服务器推理
需要理解的关键点是,我们可以使用迭代器,就像您在数据集上一样,因为 Web 服务器基本上是一个等待请求并按顺序处理它们的系统。
通常,Web 服务器是多路复用的(多线程、异步等)以并发处理各种请求。另一方面,管道(以及大多数底层模型)并不真正适合并行处理;它们占用大量 RAM,因此在运行时最好为它们提供所有可用资源,或者它是一个计算密集型作业。
我们将通过让 Web 服务器处理接收和发送请求的轻负载,并让单个线程处理实际工作来解决这个问题。此示例将使用 `starlette`。实际的框架并不重要,但如果您使用的是另一个框架,则可能需要调整或更改代码以实现相同的效果。
创建 `server.py`
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from transformers import pipeline
import asyncio
async def homepage(request):
payload = await request.body()
string = payload.decode("utf-8")
response_q = asyncio.Queue()
await request.app.model_queue.put((string, response_q))
output = await response_q.get()
return JSONResponse(output)
async def server_loop(q):
pipe = pipeline(model="google-bert/bert-base-uncased")
while True:
(string, response_q) = await q.get()
out = pipe(string)
await response_q.put(out)
app = Starlette(
routes=[
Route("/", homepage, methods=["POST"]),
],
)
@app.on_event("startup")
async def startup_event():
q = asyncio.Queue()
app.model_queue = q
asyncio.create_task(server_loop(q))
现在您可以使用以下命令启动它:
uvicorn server:app
您可以查询它:
curl -X POST -d "test [MASK]" https://127.0.0.1:8000/
#[{"score":0.7742936015129089,"token":1012,"token_str":".","sequence":"test."},...]
好了,现在你对如何创建一个 Web 服务器有了一个很好的了解!
真正重要的是我们只加载模型**一次**,这样 Web 服务器上就不会有多个模型副本。这样可以避免不必要的 RAM 使用。然后,排队机制允许你做一些花哨的事情,比如在推断之前累积一些项目以使用动态批处理。
下面的代码示例有意以伪代码的形式编写,以提高可读性。在运行此代码之前,请务必检查它是否适合你的系统资源!
(string, rq) = await q.get()
strings = []
queues = []
while True:
try:
(string, rq) = await asyncio.wait_for(q.get(), timeout=0.001) # 1ms
except asyncio.exceptions.TimeoutError:
break
strings.append(string)
queues.append(rq)
strings
outs = pipe(strings, batch_size=len(strings))
for rq, out in zip(queues, outs):
await rq.put(out)
同样,建议的代码是为了可读性而优化的,而不是为了成为最佳代码。首先,它没有批大小限制,这通常不是一个好主意。其次,超时在每次队列获取时都会重置,这意味着你可能需要等待超过 1 毫秒才能运行推理(将第一个请求延迟这么多)。
最好设置一个单一的 1 毫秒截止时间。
即使队列为空,它也会始终等待 1 毫秒,这可能不是最好的,因为你可能希望在队列中没有任何内容时就开始进行推理。但如果批处理对你用例至关重要,那么它可能是有意义的。同样,真的没有一个最好的解决方案。
一些你可能需要考虑的事情
错误检查
在生产环境中,可能会发生很多错误:内存不足、空间不足、模型加载失败、查询错误、查询正确但由于模型配置错误而无法运行等等。
通常,如果服务器将错误输出给用户,那将是很好的,因此添加大量 try..except
语句来显示这些错误是一个好主意。但请记住,根据你的安全环境,公开所有这些错误也可能存在安全风险。
断路器
Web 服务器在进行断路器操作时通常看起来更好。这意味着当它们过载时,它们会返回适当的错误,而不是无限期地等待查询。返回 503 错误而不是长时间等待,或者在长时间后返回 504 错误。
在建议的代码中,这相对容易实现,因为它只有一个队列。查看队列大小是开始在 Web 服务器在负载下出现故障之前返回错误的基本方法。
阻塞主线程
目前 PyTorch 并没有异步感知能力,计算会在运行时阻塞主线程。这意味着如果 PyTorch 被强制在它自己的线程/进程中运行会更好。这里没有这样做,因为代码会变得更加复杂(主要是因为线程、异步和队列无法很好地协同工作)。但最终它实现了相同的功能。
如果单个项目的推理时间很长(> 1 秒),这将非常重要,因为在这种情况下,这意味着在推理期间的每个查询都必须等待 1 秒才能收到错误。
动态批处理
通常,批处理不一定比一次传递一个项目更好(有关更多信息,请参阅 批处理详细信息)。但在正确的设置中,它可以非常有效。在 API 中,默认情况下没有动态批处理(太多导致速度下降的机会)。但是对于 BLOOM 推理(这是一个非常大的模型),动态批处理对于为每个人提供良好的体验**至关重要**。
< > GitHub 上的更新