文本生成推理架构
本文档旨在通过描述各个组件之间的调用流程来描述文本生成推理 (TGI) 的架构。
此处可以查看高级架构图
此图很好地展示了这些单独的组件
- 路由器,也称为
webserver
,它接收客户端请求,缓冲它们,创建一些批次,并准备对模型服务器的 gRPC 调用。 - 模型服务器,负责接收 gRPC 请求并在模型上处理推理。如果模型跨多个加速器(例如:多个 GPU)分片,则模型服务器分片可能会通过 NCCL 或等效方式同步。
- 启动器是一个助手,能够启动一个或多个模型服务器(如果模型已分片),并且它会使用兼容的参数启动路由器。
路由器和模型服务器可以是两台不同的机器,它们不需要一起部署。
路由器
此组件是一个 Rust Web 服务器二进制文件,它使用自定义HTTP API以及 OpenAI 的消息 API接受 HTTP 请求。路由器接收 API 调用并处理“批次”逻辑(可以在此处找到关于批处理的介绍)。它使用不同的策略来减少请求和响应之间的延迟,特别是面向解码延迟。它将使用队列、调度程序和块分配器来实现此目标,并生成批量请求,然后将其发送到模型服务器。
路由器的命令行
路由器命令行将是向其传递参数的方式(它不依赖于配置文件)
Text Generation Webserver
Usage: text-generation-router [OPTIONS]
Options:
--max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
[env: MAX_CONCURRENT_REQUESTS=] [default: 128]
--max-best-of <MAX_BEST_OF>
[env: MAX_BEST_OF=] [default: 2]
--max-stop-sequences <MAX_STOP_SEQUENCES>
[env: MAX_STOP_SEQUENCES=] [default: 4]
--max-top-n-tokens <MAX_TOP_N_TOKENS>
[env: MAX_TOP_N_TOKENS=] [default: 5]
--max-input-tokens <MAX_INPUT_TOKENS>
[env: MAX_INPUT_TOKENS=] [default: 1024]
--max-total-tokens <MAX_TOTAL_TOKENS>
[env: MAX_TOTAL_TOKENS=] [default: 2048]
--waiting-served-ratio <WAITING_SERVED_RATIO>
[env: WAITING_SERVED_RATIO=] [default: 1.2]
--max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
[env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]
--max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
[env: MAX_BATCH_TOTAL_TOKENS=]
--max-waiting-tokens <MAX_WAITING_TOKENS>
[env: MAX_WAITING_TOKENS=] [default: 20]
--max-batch-size <MAX_BATCH_SIZE>
[env: MAX_BATCH_SIZE=]
--hostname <HOSTNAME>
[env: HOSTNAME=] [default: 0.0.0.0]
-p, --port <PORT>
[env: PORT=] [default: 3000]
--master-shard-uds-path <MASTER_SHARD_UDS_PATH>
[env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]
--tokenizer-name <TOKENIZER_NAME>
[env: TOKENIZER_NAME=] [default: bigscience/bloom]
--tokenizer-config-path <TOKENIZER_CONFIG_PATH>
[env: TOKENIZER_CONFIG_PATH=]
--revision <REVISION>
[env: REVISION=]
--validation-workers <VALIDATION_WORKERS>
[env: VALIDATION_WORKERS=] [default: 2]
--json-output
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
--otlp-service-name <OTLP_SERVICE_NAME>
[env: OTLP_SERVICE_NAME=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=]
--ngrok
[env: NGROK=]
--ngrok-authtoken <NGROK_AUTHTOKEN>
[env: NGROK_AUTHTOKEN=]
--ngrok-edge <NGROK_EDGE>
[env: NGROK_EDGE=]
--messages-api-enabled
[env: MESSAGES_API_ENABLED=]
--disable-grammar-support
[env: DISABLE_GRAMMAR_SUPPORT=]
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
[env: MAX_CLIENT_BATCH_SIZE=] [default: 4]
-h, --help
Print help
-V, --version
Print version
模型服务器
模型服务器是一个 Python 服务器,能够启动一个等待 gRPC 请求的服务器,加载给定的模型,执行分片以提供张量并行,并在等待新请求时保持活动状态。模型服务器支持使用 Pytorch 实例化并在 CUDA/ROCM 上主要针对推理优化的模型。
模型服务器变体
Hugging Face 积极支持模型服务器的几个变体。
- 默认情况下,模型服务器将尝试构建针对 Nvidia GPU 和 CUDA 优化的服务器。此版本的代码托管在主 TGI 存储库中。
- 针对 AMD 和 ROCm 优化的版本托管在主 TGI 存储库中。一些模型功能有所不同。
- 针对英特尔 GPU 优化的版本托管在主 TGI 存储库中。一些模型功能有所不同。
- 针对英特尔 Gaudi 的版本在分叉的存储库中维护,通常与主TGI 存储库同步。
- 针对 Neuron(AWS Inferentia2)的版本作为Optimum Neuron的一部分进行维护。
- 针对 Google TPU 的版本作为Optimum TPU的一部分进行维护。
并非所有变体都提供相同的功能,因为硬件和中间件功能无法提供相同的优化。
命令行界面
服务器的官方命令行界面 (CLI) 支持三个子命令:download-weights
、quantize
和 serve
。
download-weights
将从 Hub 下载权重,并且在某些变体中,它会将权重转换为适合给定实现的格式;quantize
将允许使用qptq
包量化模型。此功能并非在所有变体中都可用或受支持;serve
将启动服务器,该服务器加载模型(或模型分片),接收来自路由器的 gRPC 调用,执行推理并向给定请求提供格式化的响应。
TGI 存储库上的 Serve 命令行参数如下所示
Usage: cli.py serve [OPTIONS] MODEL_ID
╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮
│ * model_id TEXT [default: None] [required] │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
│ --revision TEXT [default: None] │
│ --sharded --no-sharded [default: no-sharded] │
│ --quantize [bitsandbytes|bitsandbytes [default: None] │
│ -nf4|bitsandbytes-fp4|gptq │
│ |awq|eetq|exl2|fp8] │
│ --speculate INTEGER [default: None] │
│ --dtype [float16|bfloat16] [default: None] │
│ --trust-remote-code --no-trust-remote-code [default: │
│ no-trust-remote-code] │
│ --uds-path PATH [default: │
│ /tmp/text-generation-serve… │
│ --logger-level TEXT [default: INFO] │
│ --json-output --no-json-output [default: no-json-output] │
│ --otlp-endpoint TEXT [default: None] │
│ --otlp-service-name TEXT [default: │
│ text-generation-inference...│
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
请注意,某些变体可能支持不同的参数,并且它们可能接受更多可以使用环境变量传递的选项。
调用流程
一旦两个组件都初始化、权重下载且模型服务器正在运行,路由器和模型服务器就会通过 gRPC 调用交换数据和信息。目前支持两种架构,v2 和v3。这两个版本几乎相同,除了
- 文本和图像数据的输入分块支持,
- 分页注意力支持。
以下图表显示了路由器和模型服务器启动后发生的交换。
sequenceDiagram
Router->>Model Server: service discovery
Model Server-->>Router: urls for other shards
Router->>Model Server: get model info
Model Server-->>Router: shard info
Router->>Model Server: health check
Model Server-->>Router: health OK
Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)
Model Server-->>Router: warmup result
完成这些操作后,路由器即可准备接收来自多个客户端的生成调用。以下是一个示例。
sequenceDiagram
participant Client 1
participant Client 2
participant Client 3
participant Router
participant Model Server
Client 1->>Router: generate_stream
Router->>Model Server: prefill(batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 1
Router->>Model Server: decode(cached_batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 2
Router->>Model Server: decode(cached_batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 3
Client 2->>Router: generate_stream
Router->>Model Server: prefill(batch2)
Note right of Model Server: This stops previous batch, that is restarted
Model Server-->>Router: generations, cached_batch2, timings
Router-->>Client 2: token 1'
Router->>Model Server: decode(cached_batch1, cached_batch2)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 4
Router-->>Client 2: token 2'
Note left of Client 1: Client 1 leaves
Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2)
Model Server-->>Router: filtered batch
Router->>Model Server: decode(cached_batch2)
Model Server-->>Router: generations, cached_batch2, timings
Router-->>Client 2: token 3'
Client 3->>Router: generate_stream
Note right of Model Server: This stops previous batch, that is restarted
Router->>Model Server: prefill(batch3)
Note left of Client 1: Client 3 leaves without receiving any batch
Router->>Model Server: clear_cache(batch3)
Note right of Model Server: This stops previous batch, that is restarted
Router->>Model Server: decode(cached_batch3)
Note right of Model Server: Last token (stopping criteria)
Model Server-->>Router: generations, cached_batch3, timings
Router-->>Client 2: token 4'