使用 LoRAX 在一个 GPU 上部署数百个开源模型

社区文章 发布于2024年7月18日

目录

  1. 引言
  2. 先决条件
  3. 启动服务器
  4. 在您的服务器上执行推理
  5. 创建一个简单界面使其动态化
  6. 使其真实化(某种程度上)
  7. 成本分析
  8. 结论
  9. 资源
  10. 引用

引言

什么是 LoRA?

LoRA(低秩适应)是一种通过向现有权重添加小的、可训练的秩分解矩阵来实现大型语言模型高效适应的技术。这种方法显著减少了可训练参数的数量,从而可以用最少的计算资源对模型进行特定任务的微调。

image/png

LoRAX 如何利用 LoRA?

LoRAX 是一个基于 text-generation inference (v0.9.4) 构建的生产就绪推理服务器,旨在为许多 LoRA 适配器提供一个基础模型服务。它利用 LoRA 的效率来处理具有不同 LoRA 适配器的多个用户,动态加载每个请求的相应适配器。这种方法大大提高了吞吐量和 GPU 利用率。

image/webp

此可视化展示了 LoRAX 如何处理具有不同 LoRA 适配器的多个用户,动态加载每个请求的相应适配器。这大大提高了吞吐量和 GPU 利用率。

为了进一步优化推理速度,LoRAX 整合了预填充和 KV(键值)缓存技术。预填充阶段处理初始输入 token,计算它们的注意力模式并将结果存储在 KV 缓存中。然后,这些缓存信息可以在后续推理步骤中重复使用,从而无需重新计算已见 token 的注意力。

因此,模型只需要处理新的 token,大大减少了计算负载。这种优化在为具有不同 LoRA 适配器的多个用户提供服务时特别有效,因为它允许高效处理增量请求和长序列。

为什么要使用 LoRAX?

在您可能需要几个模型来处理聊天应用程序不同方面的情况下,LoRAX 尤其有价值。也许您想在另一个模型响应之前,根据 OpenAI 的内容审核数据对传入的聊天消息进行分类,或者您想拥有几种不同的工具,例如法院案件摘要器以及将文档分类为特定类型法律文档的分类模型。也许您想要所有这些,但托管所有这些模型的成本是一个问题。如果您正在生产应用程序中提供模型服务,LoRAX 可能会为您省钱。

KV 缓存和预填充解码

使用 KV 缓存对于提高您的输出生成时间至关重要。这是一个如何使用 KV 缓存来提高推理速度的代码示例。

import matplotlib.pyplot as plt
import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "./models/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "The quick brown fox jumped over the"
inputs = tokenizer(prompt, return_tensors="pt")

def generate_token_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    last_logits = logits[0, -1, :]
    next_token_id = last_logits.argmax()
    return next_token_id, outputs.past_key_values

generated_tokens = []
next_inputs = inputs
durations_cached_s = []
for _ in range(10):
    t0 = time.time()
    next_token_id, past_key_values = \
        generate_token_with_past(next_inputs)
    durations_cached_s += [time.time() - t0]
    
    next_inputs = {
        "input_ids": next_token_id.reshape((1, 1)),
        "attention_mask": torch.cat(
            [next_inputs["attention_mask"], torch.tensor([[1]])],
            dim=1),
        "past_key_values": past_key_values,
    }
    
    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)

print(f"{sum(durations_cached_s)} s")
print(generated_tokens)

这些原则对于提高推理速度非常重要。本文末尾提供的演示包括预填充和 KV 缓存 token 计数,因此您可以跟踪每个属性的益处。

先决条件

本指南将涵盖本地免费部署门控模型的端到端过程(假设您拥有一台符合先决条件的 GPU)。如果您愿意,也可以使用云配置。

  • 使用的硬件:Nvidia 4090,i9,128GB 内存
  • 使用的软件:dockerLoRAX
pip install lorax-client transformers

启动服务器

本地

要启动 LoRAX 服务器,您有几种选择。首先,您可以在本地启动服务器,这也是我将在本指南中进行的操作。第二个选择是通过 AWS sagemaker 部署以实现生产就绪的解决方案。本节之后的所有内容适用于任何一种部署选项。

如果您计划使用门控模型或适配器,请确保细粒度访问令牌具有访问正确存储库/组织的权限。在此处查看完整的参数列表:here。将此脚本写入 launch_lorax.sh

# Define variables
MODEL="google/gemma-2b"
VOLUME="$PWD/data"
HUGGING_FACE_HUB_TOKEN="your_fine_grained_access_token"

# Export the HuggingFace token
export HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN

# Run the Docker container with the HuggingFace token
docker run --gpus all --shm-size 1g -p 8080:80 -v $VOLUME:/data \
    -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
    ghcr.io/predibase/lorax:main --model-id $MODEL \
    --max-concurrent-requests 128 \ # defaults at 128, requests that exceed this limit will fail without retry
    --max-input-length 1024 \ # this is how large the user message can be 
    --max-batch-prefill-tokens 2048 \ # defaults at 2048, this is the most important option for your memory usage

    ### optional args (model must be quantized before deployment)
    # --quantize eetq
    # --quantize hqq-2bit # 2,3, 4 available
    # --quantize awq

如果您想使用Prompt Lookup Decoding启动,这是一种简单的方法,通过将输入与之前生成的标记进行字符串匹配来查找可能的 N-gram。这在 RAG 用例中特别有用。存储库创建者提供的最小 colab 实现可在此处获得。

docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD:/data \
    ghcr.io/predibase/lorax:main \
    --model-id $MODEL \
    --speculative-tokens 3

然后我们使文件可执行并启动

chmod +x launch_lorax.sh
./launch_lorax.sh

您的服务器现在应该可以在以下地址访问

http://127.0.0.1:8080

SkyPilot

如果您想部署到各种云提供商,可以使用 SkyPilot

首先安装 SkyPilot 并检查您的云凭据是否已正确设置。这将使用您所需平台的默认凭据

pip install skypilot
sky check

创建名为 lorax.yaml 的 YAML 配置文件

resources:
  cloud: aws # gcp
  accelerators: A10G:1 # {T4:2} is $0.7 for 32GB VRAM rather than $1.20 for 24GB
  memory: 32+ # system memory
  ports: 
    - 8080

envs:
  MODEL_ID: google/gemma-2b
  HUGGING_FACE_HUB_TOKEN: your_fine_grained_token

run: |
  docker run --gpus all --shm-size 1g -p 8080:80 -v /data \
      -e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
      ghcr.io/predibase/lorax:main --model-id $MODEL_ID

在上述示例中,我们请求 SkyPilot 配置一个配备 1 个 Nvidia A10G GPU 和至少 32GB RAM 的 AWS 实例。在尝试此操作之前,请确保您的服务配额已满足。更多信息可在此处找到。

让我们启动 LoRAX 任务

sky launch -c lorax-cluster lorax.yaml
Expected  output:
  I 06-27 14:19:04 optimizer.py:695] == Optimizer ==
  I 06-27 14:19:04 optimizer.py:706] Target: minimizing cost
  I 06-27 14:19:04 optimizer.py:718] Estimated cost: $1.2 / hour
  I 06-27 14:19:04 optimizer.py:718] 
  I 06-27 14:19:04 optimizer.py:843] Considered resources (1 node):
  I 06-27 14:19:04 optimizer.py:913] -----------------------------------------------------------------------------------------
  I 06-27 14:19:04 optimizer.py:913]  CLOUD   INSTANCE     vCPUs   Mem(GB)   ACCELERATORS   REGION/ZONE   COST ($)   CHOSEN   
  I 06-27 14:19:04 optimizer.py:913] -----------------------------------------------------------------------------------------
  I 06-27 14:19:04 optimizer.py:913]  AWS     g5.2xlarge   8       32        A10G:1         us-east-1     1.21          ✔     
  I 06-27 14:19:04 optimizer.py:913] -----------------------------------------------------------------------------------------
  I 06-27 14:19:04 optimizer.py:913] 
  I 06-27 14:19:04 optimizer.py:931] Multiple AWS instances satisfy A10G:1. The cheapest AWS(g5.2xlarge, {'A10G': 1}, ports=['8080']) is considered among:
  I 06-27 14:19:04 optimizer.py:931] ['g5.2xlarge', 'g5.4xlarge', 'g5.8xlarge', 'g5.16xlarge'].
  I 06-27 14:19:04 optimizer.py:931] 
  I 06-27 14:19:04 optimizer.py:937] To list more details, run 'sky show-gpus A10G'.
  Launching a new cluster 'lorax-cluster'. Proceed? [Y/n]: 

 (Y) ---> 

提示 LoRAX

在单独的窗口中,获取新创建实例的 IP 地址

sky status --ip lorax-cluster

现在我们可以像往常一样提示 LoRAX 部署

IP=$(sky status --ip lorax-cluster)

TEMPLATE = """
<|im_start|>system
You are a medical classification assistant<|im_end|>
<|im_start|>user
{medical document content}<|im_end|>
<|im_start|>assistant
"""

ADAPTER_ID="macadeliccc/gemma-2b-pubmed-classifier"
curl http://$IP:8080/generate \
    -X POST \
    -d '{"inputs": $TEMPLATE, "parameters": {"max_new_tokens": 64, "adapter_id": $ADAPTER_ID}}' \
    -H 'Content-Type: application/json'

AWS SageMaker

关于如何部署 SageMaker 所需所有组件的详细示例,请参见 此处

此部署方法与 SkyPilot 方法类似,但更为详细。本文的聊天演示与笔记本中的所有内容兼容,因此如果您使用此方法,在演示中使用您的 API URL 应该没有问题。

在您的服务器上执行推理

现在我们已经部署了容器,我们可以开始进行预测。

from lorax import Client

endpoint_url = "http://127.0.0.1:8080"

template = """
<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{ctx}<|im_end|>
<|im_start|>assistant
"""
system = """
You are a helpful assistant.
"""

query = "What is the capital of France?"
prompt = template.format(ctx=query, system=system)

client = Client(endpoint_url)

# Token Streaming
text = ""
for response in client.generate_stream(
    prompt, 
    adapter_id="macadeliccc/gemma-2b-pubmed-classifier",
    adapter_source="hub",
    api_token="your_fine_grained_access_token"

    ):
    if not response.token.special:
        text += response.token.text
print(text)

请特别注意提示模板和格式。您必须提供模型所需的模板,包括 bos 和 eos token。对于本指南,它将是 chatml。

本地推理

在本地路径中指定适配器时,adapter_id 应对应于包含以下文件的适配器根目录

root_adapter_path/
    adapter_config.json
    adapter_model.bin
    adapter_model.safetensors

用法:

text = ""
for response in client.generate_stream(
    prompt, 
    adapter_id="path/to/your/adapter/bin",
    adapter_source="local",
    ):
    if not response.token.special:
        text += response.token.text

print(text)

创建一个简单界面使其动态化

此界面与代码功能相同,只是元素设计为动态的,因此您可以根据用例热插拔适配器和源。

以下是演示代码的一个略微修改版本,它不需要任何额外的文件

pip install streamlit
import streamlit as st
from lorax import Client
import logging
from typing import Dict, Generator, Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Constants
DEFAULT_ENDPOINT = "http://127.0.0.1:8080"
DEFAULT_ADAPTER_SOURCE = "hub"
DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant"
DEFAULT_MAX_TOKENS = 315

# Template options
TEMPLATE_OPTIONS = {
    "Base Model (Completion)": "{ctx}",
    "ChatML": """
        <|im_start|>system
        {system}<|im_end|>
        <|im_start|>user
        {ctx}<|im_end|>
        <|im_start|>assistant
        """,
}

def generate_response(client: Client, prompt: str, **kwargs) -> Generator[str, None, None]:
    """Generate response from the Lorax client."""
    try:
        for response in client.generate_stream(prompt, **kwargs):
            if not response.token.special:
                yield response.token.text
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        yield f"An error occurred: {str(e)}"

def fetch_metrics(endpoint_url: str) -> tuple[Optional[int], Optional[int]]:
    """Fetch metrics from the Lorax endpoint."""
    try:
        # Implement metric fetching logic here
        # This is a placeholder as the original function wasn't provided
        return 100, 200  # Example values
    except Exception as e:
        logger.error(f"Error fetching metrics: {e}")
        return None, None

def setup_sidebar() -> Dict[str, any]:
    """Setup and return sidebar configuration."""
    st.sidebar.title("Lorax Chat Demo")
    st.sidebar.header("Configuration")
    
    config = {
        "endpoint_url": st.sidebar.text_input("Endpoint URL", value=DEFAULT_ENDPOINT),
        "adapter_source": st.sidebar.text_input("Adapter Source", value=DEFAULT_ADAPTER_SOURCE),
        "adapter_id": st.sidebar.text_input("Adapter ID", value=""),
        "api_token": st.sidebar.text_input("API Token", value="", type="password"),
        "system_prompt": st.sidebar.text_area("System Prompt", value=DEFAULT_SYSTEM_PROMPT, height=3),
        "max_new_tokens": st.sidebar.number_input("Max New Tokens", value=DEFAULT_MAX_TOKENS, min_value=1, max_value=1024),
        "selected_template": st.sidebar.selectbox("Select Template", list(TEMPLATE_OPTIONS.keys())),
    }
    
    with st.sidebar.expander("Advanced Settings"):
        config.update({
            "temperature": st.sidebar.slider("Temperature", 0.0, 1.0, 0.7),
            "top_p": st.sidebar.slider("Top-p", 0.0, 1.0, 0.95),
            "top_k": st.sidebar.slider("Top-k", 1, 10, 10),
            "typical_p": st.sidebar.slider("Typical-p", 0.0, 1.0, 0.95),
        })
    
    return config

def main():
    st.set_page_config(page_title="Lorax Chat Demo", page_icon="🦁", layout="wide")
    
    config = setup_sidebar()
    
    if "last_message" not in st.session_state:
        st.session_state.last_message = None

    if st.session_state.last_message:
        with st.chat_message(st.session_state.last_message["role"]):
            st.markdown(st.session_state.last_message["content"])

    if prompt := st.chat_input("What's your question?"):
        with st.chat_message("user"):
            st.markdown(prompt)

        with st.chat_message("assistant"):
            try:
                client = Client(config["endpoint_url"])
                template = TEMPLATE_OPTIONS[config["selected_template"]]
                full_prompt = template.format(ctx=prompt, system=config["system_prompt"])
                response_container = st.empty()
                full_response = ""

                kwargs = {
                    "adapter_source": config["adapter_source"],
                    "api_token": config["api_token"],
                    "max_new_tokens": config["max_new_tokens"],
                    "temperature": config["temperature"],
                    "top_k": config["top_k"],
                    "top_p": config["top_p"],
                    "typical_p": config["typical_p"],
                    "stop_sequences": ["<|im_end|>"]
                }

                if config["adapter_id"]:
                    kwargs["adapter_id"] = config["adapter_id"]

                for response_chunk in generate_response(client, full_prompt, **kwargs):
                    full_response += response_chunk
                    response_container.markdown(full_response + "▌")
                response_container.markdown(full_response)

                st.session_state.last_message = {"role": "assistant", "content": full_response}
            
            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
                logger.error(f"Error in chat response generation: {e}")

    decode_success, prefill_success = fetch_metrics(config["endpoint_url"])
    if decode_success is not None and prefill_success is not None:
        metrics_info = f"""
        Inference Metrics:
        - Decode Success: {decode_success}
        - Prefill Success: {prefill_success}
        """
        st.sidebar.info(metrics_info)
    else:
        st.sidebar.warning("Unable to fetch metrics. Please check the endpoint URL.")

if __name__ == "__main__":
    main()

该演示跟踪服务器的提示查找令牌和 KV 缓存计数,因此您可以准确地知道它为您节省了多少令牌。这很有用,因为随着服务器的增长,您可以直观地看到已保存的令牌预测数量。这是 LoRAX 能够以如此高的速度和吞吐量提供模型服务的主要原因。

使其真实化(某种程度上)

既然我们有了这个推理服务器。我们需要人们在其上执行推理。

假设您以聊天机器人应用程序的形式向用户提供模型服务,您可能希望模拟用户。为此,您可以使用另一个开源项目,名为 locust

pip install locust

touch locustfile.py

使用 locustfile,您可以将 `@task` 装饰器放置在您想要监视的请求上。在此示例中,我们使用以下设置进行评估:

import time
from locust import HttpUser, task, between
from lorax import AsyncClient
import os
from dotenv import load_dotenv
from datasets import load_dataset
from itertools import islice

load_dotenv(override=True)
hf_token = os.getenv("HF_TOKEN")

dataset = load_dataset('sentence-transformers/natural-questions', split='train', streaming=True)
question_stream = (item['query'] for item in dataset)  # Adjust 'questions' to match your dataset's field name
questions = list(islice(question_stream, 1000))  # Stream 1000 questions

class ChatStressTest(HttpUser):
    host = "http://127.0.0.1:8080" # this can also be your sagemaker URL if you deployed for prodcution
    wait_time = between(1, 5)  # Random wait between tasks to simulate real user behavior

    def run(self, prompt):
        output_text = ""
        start_time = time.time()
        async_client = AsyncClient(self.host)
        async for resp in async_client.generate_stream(
            prompt, 
            adapter_id="macadeliccc/gemma-2b-pubmed-classifier",
            adapter_source="hub",
            max_new_tokens=512,
            api_token=hf_token
        ):
            if not resp.token.special:
                output_text += resp.token.text
        end_time = time.time()
        duration = end_time - start_time
        self.environment.events.request.fire(
            request_type="HTTP",
            name="/generate-stream",
            response_time=duration * 1000,  # in milliseconds
            response_length=len(output_text),
            context={},
            exception=None
        )
        return output_text, duration

    @task
    def chat_task(self):
        query = questions.pop(0) if questions else "Default question if list is empty"
         # Load the prompt from a file or just use it directly and comment this out
        with open("prompts/your_prompt.txt", 'r') as file:
            system_prompt = file.read() 
        
        result, duration = self.run(llama_obs.format(ctx=query))
        print(result)
        print(f"\nTime taken: {duration:.2f} seconds")
        print("done")

locust 代码中的提示文件期望您的提示文件采用提示模板结构。

<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{ctx}<|im_end|>
<|im_start|>assistant

设置好 locustfile 后,您可以使用以下命令运行它

locust -f locustfile.py

测试服务器将在 https://:8089 上可用。在此实验中,我将用户设置为 500,每秒增加 10 个。该服务器处理了所有请求,没有失败,并且延迟在生产环境中非常可接受。这是我的运行结果 locustfile 报告。

image/png

成本分析

在此成本分析中,我们可以看到,在一年内,LoRAX 容器比将每个模型托管在自己的硬件上部署 5 个模型更具成本效益。

image/png

如果您受限于 AWS 等单一提供商,那么成本可能会高一点。您可以根据 1 个 A10G 执行相同的计算。对于大多数应用程序来说,2 个 T4 的成本更低,性能也优于 1 个 A10G。

这份成本分析由 LoRAX 提供,表示与 gpt-3.5-turbo 相比,每百万 token 的成本。

image/png

结论

一旦您完成了 locustfile 中的压力测试,您就可以将其部署到您的生产应用程序中 🤗。LoRAX 具有 CORS 设置以进一步提高安全性。这在生产环境中推荐使用,但为了让更多人能够使用本指南,此处未包含。鉴于服务器托管在云环境中,要实现 CORS,您需要一个域。

我与 LoRAX 无关

资源

  1. LoRAX 文档
  2. SkyPilot
  3. AWS SageMaker
  4. AWS 部署笔记本
  5. Locust

参考

@misc{hu2021loralowrankadaptationlarge,
      title={LoRA: Low-Rank Adaptation of Large Language Models}, 
      author={Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Lu Wang and Weizhu Chen},
      year={2021},
      eprint={2106.09685},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2106.09685}, 
}
@misc{zhao2024loraland310finetuned,
      title={LoRA Land: 310 Fine-tuned LLMs that Rival GPT-4, A Technical Report}, 
      author={Justin Zhao and Timothy Wang and Wael Abid and Geoffrey Angus and Arnav Garg and Jeffery Kinnison and Alex Sherstinsky and Piero Molino and Travis Addair and Devvret Rishi},
      year={2024},
      eprint={2405.00732},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2405.00732}, 
}

社区

注册登录 发表评论