使用 Huggingface Transformers 和 Ray 实现检索增强生成

发布于 2021 年 2 月 10 日
在 GitHub 上更新
来自 Anyscale 团队的 Amog Kamsetty 的客座博文

Huggingface Transformers 最近添加了 检索增强生成 (RAG) 模型,这是一种新的自然语言处理架构,它利用外部文档 (如维基百科) 来增强其知识,并在知识密集型任务上取得了最先进的结果。在这篇博文中,我们介绍了将用于构建可扩展应用程序的库 Ray 集成到 RAG 上下文文档检索机制中。这将检索调用的速度提高了 2 倍,并提高了 RAG 分布式微调的可扩展性。

什么是检索增强生成 (RAG)?

alt_text

RAG 概述。该模型在其执行过程中从外部数据集中检索上下文文档。这些上下文文档与原始输入结合使用以产生输出。该 GIF 取自 Facebook 的原始博文

最近,HuggingfaceFacebook AI 合作,在其 Transformers 库中引入了 RAG 模型。

RAG 的作用与任何其他 seq2seq 模型一样。但是,RAG 有一个中间组件,可以从外部知识库 (如维基百科文本语料库) 检索上下文文档。然后将这些文档与输入序列结合使用,并传递给底层的 seq2seq 生成器

此信息检索步骤允许 RAG 利用多种知识来源 —— 嵌入在模型参数中的知识以及包含在上下文段落中的信息,使其能够在问答等任务中超越其他最先进的模型。您可以使用 Huggingface 提供的此演示亲自尝试!

扩展微调

这种上下文文档的检索对于 RAG 取得最先进的结果至关重要,但也引入了额外的复杂性。当通过数据并行训练例程扩展训练过程时,文档查找的简单实现可能会成为训练的瓶颈。此外,检索组件中使用的**文档索引**通常非常大,这使得每个训练工作进程加载自己的索引副本变得不可行。

之前 RAG 微调的实现利用了 torch.distributed 通信包来进行文档检索。然而,这种实现在灵活性和可扩展性方面有时会受到限制。

因此,需要一种与框架无关且更灵活的实现方式来进行临时的并发编程。Ray 完美地满足了这一要求。Ray 是一个简单而强大的 Python 库,用于通用的分布式和并行编程。使用 Ray 进行分布式文档检索,我们实现了**每次检索调用比 `torch.distributed` 快 2 倍的速度**,并获得了更好的整体微调可扩展性。

使用 Ray 进行文档检索

alt_text 使用 torch.distributed 实现的文档检索

torch.distributed 实现文档检索的主要缺点是它依赖于用于训练的同一个进程组,并且只有 rank 0 的训练工作进程将索引加载到内存中。

因此,这种实现有一些局限性

  1. 同步瓶颈:rank 0 的工作进程必须接收所有工作进程的输入,执行索引查询,然后将结果发送回其他工作进程。这限制了多个训练工作进程的性能。
  2. PyTorch 特定:文档检索进程组必须依赖于用于训练的现有进程组,这意味着训练也必须使用 PyTorch。

alt_text 使用 Ray 实现的文档检索

为了克服这些限制,我们引入了一种基于 Ray 的新颖的分布式检索实现。通过 Ray 的有状态 actor 抽象,使用与训练进程分离的多个进程来加载索引并处理检索查询。有了多个 Ray actor,检索不再是瓶颈,PyTorch 也不再是 RAG 的必需品。

如下所示,使用基于 Ray 的实现可以为多 GPU 微调带来更好的检索性能。以下结果显示了每次检索调用的秒数,我们可以看到,随着我们增加用于训练的 GPU 数量,使用 Ray 的性能比 `torch.distributed` 更好。此外,如果我们增加执行检索的 Ray 进程数量,我们也能在有更多训练工作进程的情况下获得更好的性能,因为单个检索进程不再是瓶颈。

2 个 GPU 3 个 GPU 4 个 GPU
torch.distributed 2.12 秒/检索 2.62 秒/检索 3.438 秒/检索
Ray 2 个检索进程 1.49 秒/检索 1.539 秒/检索 2.029 秒/检索
Ray 4 个检索进程 1.145 秒/检索 1.484 秒/检索 1.66 秒/检索

不同检索实现的性能比较。对于每个文档检索实现,我们使用每个 GPU 批处理大小为 8 运行 500 个训练步骤,并测量 rank 0 训练工作进程上为每个批次检索上下文文档所需的时间。结果表明,使用多个检索进程可以提高性能,尤其是在我们将训练扩展到多个 GPU 时。

如何使用它?

Huggingface 提供了一个基于 PyTorch Lightning微调脚本,我们对其进行了扩展,将 Ray 检索实现作为一个选项添加了进去。

要试用它,请先安装必要的依赖项

pip install ray
pip install transformers
pip install -r transformers/examples/research_projects/rag/requirements.txt

然后,您可以指定数据路径和其他配置,并运行 finetune-rag-ray.sh

# Sample script to finetune RAG using Ray for distributed retrieval.

# Add parent directory to python path to access lightning_base.py
export PYTHONPATH="../":"${PYTHONPATH}"

# Start a single-node Ray cluster.
ray start --head

# A sample finetuning run, you need to specify data_dir, output_dir and model_name_or_path
# run ./examples/rag/finetune_rag_ray.sh --help to see all the possible options

python examples/rag/finetune_rag.py \
    --data_dir $DATA_DIR \
    --output_dir $OUTPUT_DIR \
    --model_name_or_path $MODEL_NAME_OR_PATH \
    --model_type rag_sequence \
    --fp16 \
    --gpus 8 \
    --profile \
    --do_train \
    --do_predict \
    --n_val -1 \
    --train_batch_size 8 \
    --eval_batch_size 1 \
    --max_source_length 128 \
    --max_target_length 25 \
    --val_max_target_length 25 \
    --test_max_target_length 25 \
    --label_smoothing 0.1 \
    --dropout 0.1 \
    --attention_dropout 0.1 \
    --weight_decay 0.001 \
    --adam_epsilon 1e-08 \
    --max_grad_norm 0.1 \
    --lr_scheduler polynomial \
    --learning_rate 3e-05 \
    --num_train_epochs 100 \
    --warmup_steps 500 \
    --gradient_accumulation_steps 1 \
    --distributed_retriever ray \
    --num_retrieval_workers 4

# Stop the Ray cluster.
ray stop

接下来是什么?

使用 Huggingface transformers 中的 RAG 和 Ray 检索实现进行更快的分布式微调,您可以在自己的知识密集型任务上利用 RAG 进行基于检索的生成。

此外,超参数调整是 transformer 微调的另一个方面,并且可以对准确性产生巨大影响。对于可扩展且简便的超参数调整,请查看 Ray Tune 库。通过使用 Ray Tune 与 PyTorch Lightning 的集成,或与 Huggingface transformers 的内置集成,您可以运行实验来为您的 RAG 模型找到完美的超参数。

最后,请继续关注 Huggingface 上可能推出的 RAG 的 Tensorflow 实现!

如果您计划试用 RAG+Ray 集成,请随时在 Ray Discourse 上分享您的经验,或加入 Ray 社区 Slack 进行进一步讨论——我们很乐意听到您的声音!

也发布于 https://medium.com/distributed-computing-with-ray/retrieval-augmented-generation-with-huggingface-transformers-and-ray-b09b56161b1e

社区

注册登录 发表评论