借助 Hugging Face 推理端点实现强大的 ASR + 对话分离 + 推测解码

发布于 2024 年 5 月 1 日
在 GitHub 上更新

Whisper 是最优秀的开源语音识别模型之一,也无疑是使用最广泛的模型。Hugging Face 推理端点 让部署任何 Whisper 模型都变得非常简单。然而,如果你想引入额外功能,比如用于识别说话者的对话分离(diarization)流水线,或者用于推测解码的辅助生成,事情就会变得复杂。原因是,你需要将 Whisper 与其他模型结合起来,同时仍然只暴露一个 API 端点。

我们将使用自定义推理处理程序来解决这个挑战,它将在推理端点上实现自动语音识别(ASR)和对话分离流水线,并支持推测解码。对话分离流水线的实现受到了著名的 Insanely Fast Whisper 的启发,并使用 Pyannote 模型进行对话分离。

这也将展示推理端点的灵活性,以及你几乎可以在上面托管任何东西。这里是可供参考的代码。请注意,在端点初始化期间,整个代码仓库都会被挂载,所以如果不想把所有逻辑都放在一个文件中,你的 `handler.py` 可以引用仓库中的其他文件。在这种情况下,我们决定将内容分成几个文件以保持整洁。

  • handler.py 包含初始化和推理代码
  • diarization_utils.py 包含所有与对话分离相关的预处理和后处理
  • config.py 包含 ModelSettingsInferenceConfigModelSettings 定义了流水线中将使用哪些模型(不必全部使用),而 InferenceConfig 定义了默认的推理参数

Pytorch 2.2 开始,SDPA 原生支持 Flash Attention 2,所以我们将使用该版本以加快推理速度。

主要模块

这是端点内部结构的高层示意图

pipeline_schema

ASR 和对话分离流水线的实现是模块化的,以适应更广泛的用例——对话分离流水线在 ASR 输出的基础上运行,如果不需要对话分离,你可以只使用 ASR 部分。对于对话分离,我们建议使用 Pyannote 模型,这是目前最先进的开源实现。

我们还将添加推测解码作为加速推理的一种方式。加速是通过使用一个更小、更快的模型来提议生成,然后由较大的模型进行验证来实现的。要了解更多关于它如何与 Whisper 协同工作的信息,请参阅这篇精彩的博客文章

推测解码有一些限制

  • 辅助模型的解码器部分至少应与主模型的架构相同
  • 批量大小必须为 1

请务必考虑以上因素。根据你的生产用例,支持更大的批量可能比推测解码更快。如果你不想使用辅助模型,只需在配置中将 `assistant_model` 保持为 `None`。

如果你确实使用辅助模型,对于 Whisper 来说,一个不错的选择是蒸馏版本

设置您自己的端点

最简单的开始方式是使用仓库复制器克隆自定义处理程序仓库。

这是来自 `handler.py` 的模型加载部分

from pyannote.audio import Pipeline
from transformers import pipeline, AutoModelForCausalLM

...

self.asr_pipeline = pipeline(
      "automatic-speech-recognition",
      model=model_settings.asr_model,
      torch_dtype=torch_dtype,
      device=device
  )

  self.assistant_model = AutoModelForCausalLM.from_pretrained(
      model_settings.assistant_model,
      torch_dtype=torch_dtype,
      low_cpu_mem_usage=True,
      use_safetensors=True
  ) 
  
  ...

  self.diarization_pipeline = Pipeline.from_pretrained(
      checkpoint_path=model_settings.diarization_model,
      use_auth_token=model_settings.hf_token,
  ) 
  
  ...

你可以根据需求定制流水线。位于 `config.py` 文件中的 `ModelSettings` 保存了用于初始化的参数,定义了在推理过程中使用的模型

class ModelSettings(BaseSettings):
    asr_model: str
    assistant_model: Optional[str] = None
    diarization_model: Optional[str] = None
    hf_token: Optional[str] = None

这些参数可以通过传递相应名称的环境变量进行调整——这对于自定义容器和推理处理程序都适用。这是 Pydantic 的一个特性。要在构建时将环境变量传递给容器,你需要通过 API 调用(而不是通过界面)创建端点。

你可以硬编码模型名称,而不是通过环境变量传递,但*请注意,对话分离流水线需要明确传递一个令牌(`hf_token`)。*出于安全原因,你不允许硬编码你的令牌,这意味着为了使用对话分离模型,你将需要通过 API 调用来创建端点。

提醒一下,所有与对话分离相关的预处理和后处理工具都在 `diarization_utils.py` 中。

唯一必需的组件是 ASR 模型。可选地,可以指定一个辅助模型用于推测解码,以及一个对话分离模型用于按说话者划分转录文本。

在推理端点上部署

如果你只需要 ASR 部分,你可以在 `config.py` 中指定 `asr_model`/`assistant_model`,然后一键部署

deploy_oneclick

要将环境变量传递给托管在推理端点上的容器,你需要使用提供的 API 以编程方式创建一个端点。以下是一个示例调用

body = {
    "compute": {
        "accelerator": "gpu",
        "instanceSize": "medium",
        "instanceType": "g5.2xlarge",
        "scaling": {
            "maxReplica": 1,
            "minReplica": 0
        }
    },
    "model": {
        "framework": "pytorch",
        "image": {
            # a default container
            "huggingface": {
                "env": {
            # this is where a Hub model gets mounted
                    "HF_MODEL_DIR": "/repository", 
                    "DIARIZATION_MODEL": "pyannote/speaker-diarization-3.1",
                    "HF_TOKEN": "<your_token>",
                    "ASR_MODEL": "openai/whisper-large-v3",
                    "ASSISTANT_MODEL": "distil-whisper/distil-large-v3"
                }
            }
        },
        # a model repository on the Hub
        "repository": "sergeipetrov/asrdiarization-handler",
        "task": "custom"
    },
    # the endpoint name
    "name": "asr-diarization-1",
    "provider": {
        "region": "us-east-1",
        "vendor": "aws"
    },
    "type": "private"
}

何时使用辅助模型

为了更好地说明何时使用辅助模型是有益的,这里有一个使用 k6 进行的基准测试

# Setup:
# GPU: A10
ASR_MODEL=openai/whisper-large-v3
ASSISTANT_MODEL=distil-whisper/distil-large-v3

# long: 60s audio; short: 8s audio
long_assisted..................: avg=4.15s    min=3.84s    med=3.95s    max=6.88s    p(90)=4.03s    p(95)=4.89s   
long_not_assisted..............: avg=3.48s    min=3.42s    med=3.46s    max=3.71s    p(90)=3.56s    p(95)=3.61s   
short_assisted.................: avg=326.96ms min=313.01ms med=319.41ms max=960.75ms p(90)=325.55ms p(95)=326.07ms
short_not_assisted.............: avg=784.35ms min=736.55ms med=747.67ms max=2s       p(90)=772.9ms  p(95)=774.1ms

如你所见,当音频较短(批量大小为1)时,辅助生成能带来显著的性能提升。如果音频较长,推理会自动将其分块成批,由于我们之前讨论过的限制,推测解码可能会损害推理时间。

推理参数

所有的推理参数都在 `config.py` 中

class InferenceConfig(BaseModel):
    task: Literal["transcribe", "translate"] = "transcribe"
    batch_size: int = 24
    assisted: bool = False
    chunk_length_s: int = 30
    sampling_rate: int = 16000
    language: Optional[str] = None
    num_speakers: Optional[int] = None
    min_speakers: Optional[int] = None
    max_speakers: Optional[int] = None

当然,你可以根据需要添加或删除参数。与说话者数量相关的参数会传递给对话分离流水线,而其他参数主要用于 ASR 流水线。`sampling_rate` 指示待处理音频的采样率,用于预处理;`assisted` 标志告诉流水线是否使用推测解码。请记住,对于辅助生成,`batch_size` 必须设置为 1。

有效负载

部署后,将您的音频连同推理参数发送到您的推理端点,如下所示(使用 Python):

import base64
import requests

API_URL = "<your endpoint URL>"
filepath = "/path/to/audio"

with open(filepath, "rb") as f:
    audio_encoded = base64.b64encode(f.read()).decode("utf-8")

data = {
    "inputs": audio_encoded,
    "parameters": {
        "batch_size": 24
    }
}

resp = requests.post(API_URL, json=data, headers={"Authorization": "Bearer <your token>"})
print(resp.json())

这里的 **"parameters"** 字段是一个字典,包含了你想从 `InferenceConfig` 中调整的所有参数。请注意,未在 `InferenceConfig` 中指定的参数将被忽略。

或者使用 InferenceClient(还有一个异步版本

from huggingface_hub import InferenceClient

client = InferenceClient(model = "<your endpoint URL>", token="<your token>")

with open("/path/to/audio", "rb") as f:
    audio_encoded = base64.b64encode(f.read()).decode("utf-8")
data = {
    "inputs": audio_encoded,
    "parameters": {
        "batch_size": 24
    }
}

res = client.post(json=data)

回顾

在这篇博客中,我们讨论了如何使用 Hugging Face 推理端点设置一个模块化的 ASR + 对话分离 + 推测解码流水线。我们尽力使其易于根据需要配置和调整,并且使用推理端点进行部署总是轻而易举!我们很幸运能够拥有社区公开提供的优秀模型和工具,并在实现中使用了它们

有一个仓库实现了相同的流水线以及服务器部分(FastAPI+Uvicorn)。如果你想进一步定制或在其他地方托管,它可能会派上用场。

社区

你好 naveen 你好吗

你好

Tffg

注册登录 以发表评论