开源 AI 食谱文档
如何使用推理端点嵌入文档
并获得增强的文档体验
开始使用
如何使用推理端点嵌入文档
作者:Derek Thomas
目标
我有一个想要嵌入用于语义搜索(或问答,或 RAG)的数据集,我希望以最简单的方式嵌入它并将其放入新的数据集中。
方法
我正在使用我最喜欢的 subreddit r/bestofredditorupdates 中的数据集。由于它有很长的条目,我将使用新的 jinaai/jina-embeddings-v2-base-en,因为它具有 8k 的上下文长度。我将使用 推理端点 进行部署以节省时间和金钱。要遵循本教程,您需要已经添加了付款方式。如果您还没有,可以在 账单 中添加。为了让它更容易,我将使其完全基于 API。
为了更快地完成这项工作,我将使用 文本嵌入推理 镜像。这有许多好处,例如:
- 无模型图编译步骤
- 小巧的 Docker 镜像和快速启动时间。准备好迎接真正的无服务器!
- 基于 token 的动态批处理
- 使用 Flash Attention、Candle 和 cuBLASLt 优化了 Transformers 推理代码
- Safetensors 权重加载
- 生产就绪(通过 Open Telemetry 进行分布式追踪,Prometheus 指标)
要求
!pip install -q aiohttp==3.8.3 datasets==2.14.6 pandas==1.5.3 requests==2.31.0 tqdm==4.66.1 huggingface-hub>=0.20
导入
import asyncio
from getpass import getpass
import json
from pathlib import Path
import time
from typing import Optional
from aiohttp import ClientSession, ClientTimeout
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import notebook_login, create_inference_endpoint, list_inference_endpoints, whoami
import numpy as np
import pandas as pd
import requests
from tqdm.auto import tqdm
配置
DATASET_IN
是您的文本数据所在的位置,DATASET_OUT
是您的嵌入将存储的位置。
请注意,我将 MAX_WORKERS
设置为 5,因为 jina-embeddings-v2
非常占用内存。
DATASET_IN = "derek-thomas/dataset-creator-reddit-bestofredditorupdates"
DATASET_OUT = "processed-subset-bestofredditorupdates"
ENDPOINT_NAME = "boru-jina-embeddings-demo-ie"
MAX_WORKERS = 5 # This is for how many async workers you want. Choose based on the model and hardware
ROW_COUNT = 100 # Choose None to use all rows, Im using 100 just for a demo
推理端点提供多种 GPU 供您选择。请查阅文档以获取有关 GPU 和其他加速器的信息。
您可能需要给我们发送电子邮件以获取某些架构的访问权限。
提供商 | 实例类型 | 实例大小 | 每小时费率 | GPU | 内存 | 架构 |
---|---|---|---|---|---|---|
AWS | nvidia-a10g | x1 | \$1 | 1 | 24GB | NVIDIA A10G |
AWS | nvidia-t4 | x1 | \$0.5 | 1 | 14GB | NVIDIA T4 |
AWS | nvidia-t4 | x4 | \$3 | 4 | 56GB | NVIDIA T4 |
GCP | nvidia-l4 | x1 | \$0.8 | 1 | 24GB | NVIDIA L4 |
GCP | nvidia-l4 | x4 | \$3.8 | 4 | 96GB | NVIDIA L4 |
AWS | nvidia-a100 | x1 | \$4 | 1 | 80GB | NVIDIA A100 |
AWS | nvidia-a10g | x4 | \$5 | 4 | 96GB | NVIDIA A10G |
AWS | nvidia-a100 | x2 | \$8 | 2 | 160GB | NVIDIA A100 |
AWS | nvidia-a100 | x4 | \$16 | 4 | 320GB | NVIDIA A100 |
AWS | nvidia-a100 | x8 | \$32 | 8 | 640GB | NVIDIA A100 |
GCP | nvidia-t4 | x1 | \$0.5 | 1 | 16GB | NVIDIA T4 |
GCP | nvidia-l4 | x1 | \$1 | 1 | 24GB | NVIDIA L4 |
GCP | nvidia-l4 | x4 | \$5 | 4 | 96GB | NVIDIA L4 |
GCP | nvidia-a100 | x1 | \$6 | 1 | 80 GB | NVIDIA A100 |
GCP | nvidia-a100 | x2 | \$12 | 2 | 160 GB | NVIDIA A100 |
GCP | nvidia-a100 | x4 | \$24 | 4 | 320 GB | NVIDIA A100 |
GCP | nvidia-a100 | x8 | \$48 | 8 | 640 GB | NVIDIA A100 |
GCP | nvidia-h100 | x1 | \$12.5 | 1 | 80 GB | NVIDIA H100 |
GCP | nvidia-h100 | x2 | \$25 | 2 | 160 GB | NVIDIA H100 |
GCP | nvidia-h100 | x4 | \$50 | 4 | 320 GB | NVIDIA H100 |
GCP | nvidia-h100 | x8 | \$100 | 8 | 640 GB | NVIDIA H100 |
AWS | inf2 | x1 | \$0.75 | 1 | 32GB | AWS Inferentia2 |
AWS | inf2 | x12 | \$12 | 12 | 384GB | AWS Inferentia2 |
# GPU Choice
VENDOR = "aws"
REGION = "us-east-1"
INSTANCE_SIZE = "x1"
INSTANCE_TYPE = "nvidia-a10g"
notebook_login()
一些用户可能在组织中注册了付款。这允许您使用付款方式连接到您所属的组织。
如果您想使用您的用户名,请将其留空。
>>> who = whoami()
>>> organization = getpass(
... prompt="What is your Hugging Face 🤗 username or organization? (with an added payment method)"
... )
>>> namespace = organization or who["name"]
What is your Hugging Face 🤗 username or organization? (with an added payment method) ········
获取数据集
dataset = load_dataset(DATASET_IN)
dataset["train"]
documents = dataset["train"].to_pandas().to_dict("records")[:ROW_COUNT]
len(documents), documents[0]
推理端点
创建推理端点
- 方便(无需点击)
- 可重复(我们有代码可以轻松运行它)
- 更便宜(无需等待加载,并自动关闭它)
try:
endpoint = create_inference_endpoint(
ENDPOINT_NAME,
repository="jinaai/jina-embeddings-v2-base-en",
revision="7302ac470bed880590f9344bfeee32ff8722d0e5",
task="sentence-embeddings",
framework="pytorch",
accelerator="gpu",
instance_size=INSTANCE_SIZE,
instance_type=INSTANCE_TYPE,
region=REGION,
vendor=VENDOR,
namespace=namespace,
custom_image={
"health_route": "/health",
"env": {
"MAX_BATCH_TOKENS": str(MAX_WORKERS * 2048),
"MAX_CONCURRENT_REQUESTS": "512",
"MODEL_ID": "/repository",
},
"url": "ghcr.io/huggingface/text-embeddings-inference:0.5.0",
},
type="protected",
)
except:
endpoint = [ie for ie in list_inference_endpoints(namespace=namespace) if ie.name == ENDPOINT_NAME][0]
print("Loaded endpoint")
这里有几个设计选择
- 如前所述,我们使用
jinaai/jina-embeddings-v2-base-en
作为我们的模型。- 为了可重现性,我们将其固定到特定的修订版。
- 如果您对更多模型感兴趣,请在此处查看支持列表:here。
- 请注意,大多数嵌入模型都基于 BERT 架构。
MAX_BATCH_TOKENS
的选择基于我们的 worker 数量和嵌入模型的上下文窗口。type="protected"
利用了此处详细说明的推理端点安全性。- 我使用的是 1x Nvidia A10,因为
jina-embeddings-v2
内存占用大(请记住 8k 上下文长度)。 - 如果您的工作负载较高,您应该考虑进一步调整
MAX_BATCH_TOKENS
和MAX_CONCURRENT_REQUESTS
等待它运行
>>> %%time
>>> endpoint.wait()
CPU times: user 48.1 ms, sys: 15.7 ms, total: 63.8 ms Wall time: 52.6 s
当我们使用 endpoint.client.post
时,我们会得到一个字节字符串。这有点麻烦,因为我们需要将其转换为 np.array
,但这在 Python 中只需要几行代码。
response = endpoint.client.post(
json={
"inputs": "This sound track was beautiful! It paints the senery in your mind so well I would recomend it even to people who hate vid. game music!",
"truncate": True,
},
task="feature-extraction",
)
response = np.array(json.loads(response.decode()))
response[0][:20]
您输入的文本可能超过了上下文长度。在这种情况下,如何处理这些文本取决于您。在我的例子中,我宁愿截断它们而不是报错。让我们测试一下它是否有效。
>>> embedding_input = "This input will get multiplied" * 10000
>>> print(f"The length of the embedding_input is: {len(embedding_input)}")
>>> response = endpoint.client.post(json={"inputs": embedding_input, "truncate": True}, task="feature-extraction")
>>> response = np.array(json.loads(response.decode()))
>>> response[0][:20]
The length of the embedding_input is: 300000
获取嵌入
在这里,我发送一个文档,用嵌入更新它,然后返回它。这与 MAX_WORKERS
并行进行。
async def request(document, semaphore):
# Semaphore guard
async with semaphore:
result = await endpoint.async_client.post(
json={"inputs": document["content"], "truncate": True}, task="feature-extraction"
)
result = np.array(json.loads(result.decode()))
document["embedding"] = result[0] # Assuming the API's output can be directly assigned
return document
async def main(documents):
# Semaphore to limit concurrent requests. Adjust the number as needed.
semaphore = asyncio.BoundedSemaphore(MAX_WORKERS)
# Creating a list of tasks
tasks = [request(document, semaphore) for document in documents]
# Using tqdm to show progress. It's been integrated into the async loop.
for f in tqdm(asyncio.as_completed(tasks), total=len(documents)):
await f
>>> start = time.perf_counter()
>>> # Get embeddings
>>> await main(documents)
>>> # Make sure we got it all
>>> count = 0
>>> for document in documents:
... if "embedding" in document.keys() and len(document["embedding"]) == 768:
... count += 1
>>> print(f"Embeddings = {count} documents = {len(documents)}")
>>> # Print elapsed time
>>> elapsed_time = time.perf_counter() - start
>>> minutes, seconds = divmod(elapsed_time, 60)
>>> print(f"{int(minutes)} min {seconds:.2f} sec")
Embeddings = 100 documents = 100 0 min 21.33 sec
暂停推理端点
现在我们已经完成,我们可以暂停端点,这样就不会产生额外的费用,这也能让我们分析成本。
>>> endpoint = endpoint.pause()
>>> print(f"Endpoint Status: {endpoint.status}")
Endpoint Status: paused
将更新后的数据集推送到 Hub
现在我们已经用我们想要的嵌入更新了我们的文档。首先我们需要将其转换回 Dataset
格式。我发现最简单的方法是从字典列表 -> pd.DataFrame
-> Dataset
。
df = pd.DataFrame(documents)
dd = DatasetDict({"train": Dataset.from_pandas(df)})
我默认将其上传到用户帐户(而不是上传到组织),但您可以随意通过在 repo_id
中设置用户或在配置中设置 DATASET_OUT
来推送到任何您想要的位置。
dd.push_to_hub(repo_id=DATASET_OUT)
>>> print(f'Dataset is at https://huggingface.co/datasets/{who["name"]}/{DATASET_OUT}')
Dataset is at https://huggingface.co/datasets/derek-thomas/processed-subset-bestofredditorupdates
分析使用情况
- 转到下方打印的
dashboard_url
- 点击“使用情况和费用”选项卡
- 查看您的花费
>>> dashboard_url = f"https://ui.endpoints.huggingface.co/{namespace}/endpoints/{ENDPOINT_NAME}"
>>> print(dashboard_url)
https://ui.endpoints.huggingface.co/HF-test-lab/endpoints/boru-jina-embeddings-demo-ie
>>> input("Hit enter to continue with the notebook")
Hit enter to continue with the notebook
我们可以看到这只花了 $0.04
!
删除端点
现在我们已经完成,我们不再需要我们的端点。我们可以通过编程方式删除我们的端点。
>>> endpoint = endpoint.delete()
>>> if not endpoint:
... print("Endpoint deleted successfully")
>>> else:
... print("Delete Endpoint in manually")
Endpoint deleted successfully< > 在 GitHub 上更新