text-generation-inference 文档

文本生成推理架构

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始

文本生成推理架构

本文档旨在描述文本生成推理 (TGI) 的架构,通过描述各个组件之间的调用流程。

高级架构图可以在这里看到。

TGI architecture

此图清楚地显示了这些独立的组件。

  • 路由器,也称为 webserver,它接收客户端请求、缓冲它们、创建一些批次,并准备 gRPC 调用到模型服务器。
  • 启动器是一个辅助工具,能够启动一个或多个模型服务器(如果模型是分片的),并且它使用兼容的参数启动路由器。
  • 模型服务器,负责接收 gRPC 请求并在模型上处理推理。如果模型跨多个加速器(例如:多个 GPU)分片,则模型服务器分片可以通过 NCCL 或同等方式同步。

请注意,对于其他后端(例如 TRTLLM),模型服务器和启动器特定于后端。

路由器和模型服务器可以是两台不同的机器,它们不需要一起部署。

路由器

此组件是一个 rust Web 服务器二进制文件,它使用自定义 HTTP API 以及 OpenAI 的 Messages API 接受 HTTP 请求。路由器接收 API 调用并处理“批次”逻辑(批处理的介绍可以在这里找到)。它使用不同的策略来减少请求和响应之间的延迟,尤其是面向解码延迟。它将使用队列、调度器和块分配器来实现这一点,并生成批处理请求,然后将其发送到模型服务器。

路由器的命令行

路由器命令行将是向其传递参数的方式(它不依赖于配置文件)。

Text Generation Webserver

Usage: text-generation-router [OPTIONS]

Options:
      --max-concurrent-requests <MAX_CONCURRENT_REQUESTS>
          [env: MAX_CONCURRENT_REQUESTS=] [default: 128]
      --max-best-of <MAX_BEST_OF>
          [env: MAX_BEST_OF=] [default: 2]
      --max-stop-sequences <MAX_STOP_SEQUENCES>
          [env: MAX_STOP_SEQUENCES=] [default: 4]
      --max-top-n-tokens <MAX_TOP_N_TOKENS>
          [env: MAX_TOP_N_TOKENS=] [default: 5]
      --max-input-tokens <MAX_INPUT_TOKENS>
          [env: MAX_INPUT_TOKENS=] [default: 1024]
      --max-total-tokens <MAX_TOTAL_TOKENS>
          [env: MAX_TOTAL_TOKENS=] [default: 2048]
      --waiting-served-ratio <WAITING_SERVED_RATIO>
          [env: WAITING_SERVED_RATIO=] [default: 1.2]
      --max-batch-prefill-tokens <MAX_BATCH_PREFILL_TOKENS>
          [env: MAX_BATCH_PREFILL_TOKENS=] [default: 4096]
      --max-batch-total-tokens <MAX_BATCH_TOTAL_TOKENS>
          [env: MAX_BATCH_TOTAL_TOKENS=]
      --max-waiting-tokens <MAX_WAITING_TOKENS>
          [env: MAX_WAITING_TOKENS=] [default: 20]
      --max-batch-size <MAX_BATCH_SIZE>
          [env: MAX_BATCH_SIZE=]
      --hostname <HOSTNAME>
          [env: HOSTNAME=] [default: 0.0.0.0]
  -p, --port <PORT>
          [env: PORT=] [default: 3000]
      --master-shard-uds-path <MASTER_SHARD_UDS_PATH>
          [env: MASTER_SHARD_UDS_PATH=] [default: /tmp/text-generation-server-0]
      --tokenizer-name <TOKENIZER_NAME>
          [env: TOKENIZER_NAME=] [default: bigscience/bloom]
      --tokenizer-config-path <TOKENIZER_CONFIG_PATH>
          [env: TOKENIZER_CONFIG_PATH=]
      --revision <REVISION>
          [env: REVISION=]
      --validation-workers <VALIDATION_WORKERS>
          [env: VALIDATION_WORKERS=] [default: 2]
      --json-output
          [env: JSON_OUTPUT=]
      --otlp-endpoint <OTLP_ENDPOINT>
          [env: OTLP_ENDPOINT=]
      --otlp-service-name <OTLP_SERVICE_NAME>
          [env: OTLP_SERVICE_NAME=]
      --cors-allow-origin <CORS_ALLOW_ORIGIN>
          [env: CORS_ALLOW_ORIGIN=]
      --ngrok
          [env: NGROK=]
      --ngrok-authtoken <NGROK_AUTHTOKEN>
          [env: NGROK_AUTHTOKEN=]
      --ngrok-edge <NGROK_EDGE>
          [env: NGROK_EDGE=]
      --messages-api-enabled
          [env: MESSAGES_API_ENABLED=]
      --disable-grammar-support
          [env: DISABLE_GRAMMAR_SUPPORT=]
      --max-client-batch-size <MAX_CLIENT_BATCH_SIZE>
          [env: MAX_CLIENT_BATCH_SIZE=] [default: 4]
  -h, --help
          Print help
  -V, --version
          Print version

模型服务器

模型服务器是一个 python 服务器,能够启动一个服务器等待 gRPC 请求、加载给定模型、执行分片以提供张量并行,并在等待新请求时保持活动状态。模型服务器支持使用 Pytorch 实例化的模型,并主要在 CUDA/ROCM 上针对推理进行了优化。

模型服务器变体

存在模型服务器的几种变体,Hugging Face 积极支持这些变体。

  • 默认情况下,模型服务器将尝试构建一个针对带有 CUDA 的 Nvidia GPU 优化的服务器。此版本的代码托管在主 TGI 存储库中。
  • 针对带有 ROCm 的 AMD 优化的版本托管在主 TGI 存储库中。某些模型功能有所不同。
  • 针对 Intel GPU 优化的版本托管在主 TGI 存储库中。某些模型功能有所不同。
  • Intel Gaudi 的版本维护在一个分支存储库上,通常与主 TGI 存储库重新同步。
  • Neuron (AWS Inferentia2) 的版本维护在主 TGI 存储库中。某些模型功能有所不同。
  • Google TPU 的版本作为 Optimum TPU 的一部分进行维护。

并非所有变体都提供相同的功能,因为硬件和中间件功能不提供相同的优化。

命令行界面

服务器的官方命令行界面 (CLI) 支持三个子命令:download-weightsquantizeserve

  • download-weights 将从 hub 下载权重,在某些变体中,它会将权重转换为适应给定实现的格式;
  • quantize 将允许使用 qptq 包量化模型。此功能并非在所有变体上都可用或受支持;
  • serve 将启动服务器,该服务器加载模型(或模型分片)、接收来自路由器的 gRPC 调用、执行推理并为给定请求提供格式化的响应。

TGI 存储库上 Serve 的命令行参数如下:

 Usage: cli.py serve [OPTIONS] MODEL_ID

╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────╮
│ *    model_id      TEXT  [default: None] [required]                                                      │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
│ --revision                                       TEXT                        [default: None]             │
│ --sharded              --no-sharded                                          [default: no-sharded]       │
│ --quantize                                       [bitsandbytes|bitsandbytes  [default: None]             │
│                                                  -nf4|bitsandbytes-fp4|gptq                              │
│                                                  |awq|eetq|exl2|fp8]                                     │
│ --speculate                                      INTEGER                     [default: None]             │
│ --dtype                                          [float16|bfloat16]          [default: None]             │
│ --trust-remote-code    --no-trust-remote-code                                [default:                   │
│                                                                              no-trust-remote-code]       │
│ --uds-path                                       PATH                        [default:                   │
│                                                                              /tmp/text-generation-serve… │
│ --logger-level                                   TEXT                        [default: INFO]             │
│ --json-output          --no-json-output                                      [default: no-json-output]   │
│ --otlp-endpoint                                  TEXT                        [default: None]             │
│ --otlp-service-name                              TEXT                        [default:                   │
│                                                                              text-generation-inference...│
│ --help                                                                       Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯

请注意,某些变体可能支持不同的参数,并且它们可能会接受更多可以使用环境变量传递的选项。

调用流程

一旦两个组件都初始化、权重下载并且模型服务器已启动并运行,路由器和模型服务器就会通过 gRPC 调用交换数据和信息。目前支持两种模式 v2 和 v3。这两个版本几乎相同,除了:

  • 输入块支持,用于文本和图像数据,
  • 分页注意力支持

这是一个图表,显示了路由器和模型服务器启动后的交换。

sequenceDiagram

    Router->>Model Server: service discovery
    Model Server-->>Router: urls for other shards

    Router->>Model Server: get model info
    Model Server-->>Router: shard info

    Router->>Model Server: health check
    Model Server-->>Router: health OK

    Router->>Model Server: warmup(max_input_tokens, max_batch_prefill_tokens, max_total_tokens, max_batch_size)
    Model Server-->>Router: warmup result

完成这些后,路由器已准备好接收来自多个客户端的生成调用。这是一个例子。

sequenceDiagram
    participant Client 1
    participant Client 2
    participant Client 3
    participant Router
    participant Model Server

    Client 1->>Router: generate_stream
    Router->>Model Server: prefill(batch1)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 1

    Router->>Model Server: decode(cached_batch1)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 2

    Router->>Model Server: decode(cached_batch1)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 3

    Client 2->>Router: generate_stream
    Router->>Model Server: prefill(batch2)
    Note right of Model Server: This stops previous batch, that is restarted
    Model Server-->>Router: generations, cached_batch2, timings
    Router-->>Client 2: token 1'

    Router->>Model Server: decode(cached_batch1, cached_batch2)
    Model Server-->>Router: generations, cached_batch1, timings
    Router-->>Client 1: token 4
    Router-->>Client 2: token 2'

    Note left of Client 1: Client 1 leaves
    Router->>Model Server: filter_batch(cached_batch1, request_ids_to_keep=batch2)
    Model Server-->>Router: filtered batch

    Router->>Model Server: decode(cached_batch2)
    Model Server-->>Router: generations, cached_batch2, timings
    Router-->>Client 2: token 3'

    Client 3->>Router: generate_stream
    Note right of Model Server: This stops previous batch, that is restarted
    Router->>Model Server: prefill(batch3)
    Note left of Client 1: Client 3 leaves without receiving any batch
    Router->>Model Server: clear_cache(batch3)
    Note right of Model Server: This stops previous batch, that is restarted

    Router->>Model Server: decode(cached_batch3)
    Note right of Model Server: Last token (stopping criteria)
    Model Server-->>Router: generations, cached_batch3, timings
    Router-->>Client 2: token 4'

< > 更新 在 GitHub 上