text-generation-inference 文档
文本生成推理架构
并获得增强的文档体验
开始使用
文本生成推理架构
本文档旨在通过描述独立组件之间的调用流程来阐述文本生成推理(TGI)的架构。
以下是一个高级架构图
此图很好地展示了这些独立的组件
- 路由器,也称为
webserver
,它接收客户端请求,缓冲它们,创建批次,并准备对模型服务器的 gRPC 调用。 - 启动器是一个辅助程序,能够启动一个或多个模型服务器(如果模型是分片的),并使用兼容的参数启动路由器。
- 模型服务器,负责接收 gRPC 请求并对模型执行推理。如果模型分片到多个加速器(例如:多个 GPU),模型服务器分片可以通过 NCCL 或等效方式进行同步。
请注意,对于其他后端(例如 TRTLLM),模型服务器和启动器是特定于后端的。
路由器和模型服务器可以是两台不同的机器,它们无需部署在一起。
路由器
此组件是一个 Rust Web 服务器二进制文件,它接受使用自定义 HTTP API 以及 OpenAI 的 Messages API 的 HTTP 请求。路由器接收 API 调用并处理“批处理”逻辑(有关批处理的介绍可在此处找到)。它使用不同的策略来减少请求和响应之间的延迟,特别是针对解码延迟。它将使用队列、调度器和块分配器来实现这一点,并生成批处理请求,然后将其发送到模型服务器。
路由器的命令行
路由器的命令行将是向其传递参数的方式(它不依赖于配置文件)
Text Generation Webserver
Usage: text-generation-router [OPTIONS]
Options:
--max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
[env: MAX_CONCURRENT_REQUESTS=] [default: 128]
--max-best-of <MAX_BEST_OF>
[env: MAX_BEST_OF=] [default: 2]
--max-stop-sequences <MAX_STOP_SEQUENCES>
[env: MAX_STOP_SEQUENCES=] [default: 4]
--max-top-n-tokens <MAX_TOP_N_TOKENS>
[env: MAX_TOP_N_TOKENS=] [default: 5]
--max-input-tokens <MAX_INPUT_TOKENS>
[env: MAX_INPUT_TOKENS=] [default: 1024]
--max-total-tokens <MAX_TOTAL_TOKENS>
[env: MAX_TOTAL_TOKENS=] [default: 2048]
--waiting-served-ratio <WAITING_SERVED_RATIO>
[env: WAITING_SERVED_RATIO=] [default: 1.2]
--max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
[env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]
--max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
[env: MAX_BATCH_TOTAL_TOKENS=]
--max-waiting-tokens <MAX_WAITING_TOKENS>
[env: MAX_WAITING_TOKENS=] [default: 20]
--max-batch-size <MAX_BATCH_SIZE>
[env: MAX_BATCH_SIZE=]
--hostname <HOSTNAME>
[env: HOSTNAME=] [default: 0.0.0.0]
-p, --port <PORT>
[env: PORT=] [default: 3000]
--master-shard-uds-path <MASTER_SHARD_UDS_PATH>
[env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]
--tokenizer-name <TOKENIZER_NAME>
[env: TOKENIZER_NAME=] [default: bigscience/bloom]
--tokenizer-config-path <TOKENIZER_CONFIG_PATH>
[env: TOKENIZER_CONFIG_PATH=]
--revision <REVISION>
[env: REVISION=]
--validation-workers <VALIDATION_WORKERS>
[env: VALIDATION_WORKERS=] [default: 2]
--json-output
[env: JSON_OUTPUT=]
--otlp-endpoint <OTLP_ENDPOINT>
[env: OTLP_ENDPOINT=]
--otlp-service-name <OTLP_SERVICE_NAME>
[env: OTLP_SERVICE_NAME=]
--cors-allow-origin <CORS_ALLOW_ORIGIN>
[env: CORS_ALLOW_ORIGIN=]
--ngrok
[env: NGROK=]
--ngrok-authtoken <NGROK_AUTHTOKEN>
[env: NGROK_AUTHTOKEN=]
--ngrok-edge <NGROK_EDGE>
[env: NGROK_EDGE=]
--messages-api-enabled
[env: MESSAGES_API_ENABLED=]
--disable-grammar-support
[env: DISABLE_GRAMMAR_SUPPORT=]
--max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
[env: MAX_CLIENT_BATCH_SIZE=] [default: 4]
-h, --help
Print help
-V, --version
Print version
模型服务器
模型服务器是一个 Python 服务器,能够启动一个等待 gRPC 请求的服务器,加载给定模型,执行分片以提供张量并行,并在等待新请求时保持活动状态。模型服务器支持使用 Pytorch 实例化并主要为 CUDA/ROCM 推理优化的模型。
模型服务器变体
Hugging Face 积极支持模型服务器的几种变体
- 默认情况下,模型服务器将尝试构建针对 Nvidia GPU 和 CUDA 优化的服务器。此版本的代码托管在主 TGI 仓库中。
- 一个针对 AMD 和 ROCm 优化的版本托管在主 TGI 仓库中。某些模型功能有所不同。
- 一个针对 Intel GPU 优化的版本托管在主 TGI 仓库中。某些模型功能有所不同。
- Intel Gaudi 版本在分叉仓库中维护,并经常与主TGI 仓库同步。
- Neuron (AWS Inferentia2) 版本在主 TGI 仓库中维护。某些模型功能有所不同。
- Google TPU 版本作为 Optimum TPU 的一部分进行维护。
并非所有变体都提供相同的功能,因为硬件和中间件功能不提供相同的优化。
命令行界面
服务器的官方命令行界面 (CLI) 支持三个子命令:download-weights
、quantize
和 serve
download-weights
将从 Hub 下载权重,在某些变体中,它会将权重转换为适用于给定实现的格式;quantize
将允许使用qptq
包量化模型。此功能并非在所有变体上都可用或受支持;serve
将启动服务器,该服务器加载模型(或模型分片),接收来自路由器的 gRPC 调用,执行推理并为给定请求提供格式化响应。
TGI 仓库中 Serve 的命令行参数如下:
Usage: cli.py serve [OPTIONS] MODEL_ID
╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮
│ * model_id TEXT [default: None] [required] │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
│ --revision TEXT [default: None] │
│ --sharded --no-sharded [default: no-sharded] │
│ --quantize [bitsandbytes|bitsandbytes [default: None] │
│ -nf4|bitsandbytes-fp4|gptq │
│ |awq|eetq|exl2|fp8] │
│ --speculate INTEGER [default: None] │
│ --dtype [float16|bfloat16] [default: None] │
│ --trust-remote-code --no-trust-remote-code [default: │
│ no-trust-remote-code] │
│ --uds-path PATH [default: │
│ /tmp/text-generation-serve… │
│ --logger-level TEXT [default: INFO] │
│ --json-output --no-json-output [default: no-json-output] │
│ --otlp-endpoint TEXT [default: None] │
│ --otlp-service-name TEXT [default: │
│ text-generation-inference...│
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
请注意,某些变体可能支持不同的参数,并且它们可能接受更多可以通过环境变量传递的选项。
调用流程
一旦这两个组件初始化完毕,权重下载完成,模型服务器启动并运行,路由器和模型服务器通过 gRPC 调用交换数据和信息。目前支持两种模式,v2 和 v3。这两个版本几乎相同,除了:
- 支持文本和图像数据的输入块,
- 支持分页注意力。
以下图表显示了路由器和模型服务器启动后的交换过程。
sequenceDiagram
Router->>Model Server: service discovery
Model Server-->>Router: urls for other shards
Router->>Model Server: get model info
Model Server-->>Router: shard info
Router->>Model Server: health check
Model Server-->>Router: health OK
Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)
Model Server-->>Router: warmup result
完成这些后,路由器即可接收来自多个客户端的生成调用。以下是一个示例。
sequenceDiagram
participant Client 1
participant Client 2
participant Client 3
participant Router
participant Model Server
Client 1->>Router: generate_stream
Router->>Model Server: prefill(batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 1
Router->>Model Server: decode(cached_batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 2
Router->>Model Server: decode(cached_batch1)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 3
Client 2->>Router: generate_stream
Router->>Model Server: prefill(batch2)
Note right of Model Server: This stops previous batch, that is restarted
Model Server-->>Router: generations, cached_batch2, timings
Router-->>Client 2: token 1'
Router->>Model Server: decode(cached_batch1, cached_batch2)
Model Server-->>Router: generations, cached_batch1, timings
Router-->>Client 1: token 4
Router-->>Client 2: token 2'
Note left of Client 1: Client 1 leaves
Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2)
Model Server-->>Router: filtered batch
Router->>Model Server: decode(cached_batch2)
Model Server-->>Router: generations, cached_batch2, timings
Router-->>Client 2: token 3'
Client 3->>Router: generate_stream
Note right of Model Server: This stops previous batch, that is restarted
Router->>Model Server: prefill(batch3)
Note left of Client 1: Client 3 leaves without receiving any batch
Router->>Model Server: clear_cache(batch3)
Note right of Model Server: This stops previous batch, that is restarted
Router->>Model Server: decode(cached_batch3)
Note right of Model Server: Last token (stopping criteria)
Model Server-->>Router: generations, cached_batch3, timings
Router-->>Client 2: token 4'