开源 AI 食谱文档
使用向量嵌入和 Qdrant 进行代码搜索
并获得增强的文档体验
开始使用
使用向量嵌入和 Qdrant 进行代码搜索
作者:Qdrant 团队
在本 notebook 中,我们将演示如何使用向量嵌入来浏览代码库,并找到相关的代码片段。我们将使用自然的语义查询来搜索代码库,并基于相似逻辑搜索代码。
您可以查看此方法的在线部署,它通过一个 Web 界面开放了 Qdrant 代码库的搜索功能。
方法
我们需要两个模型来实现我们的目标。
用于自然语言处理 (NLP) 的通用神经编码器,在我们的例子中是 sentence-transformers/all-MiniLM-L6-v2。我们称之为 NLP 模型。
用于代码到代码相似性搜索的专门嵌入。我们将使用 jinaai/jina-embeddings-v2-base-code 模型来完成此任务。它支持英语和 30 种广泛使用的编程语言,序列长度为 8192。我们称之为代码模型。
为了让我们的代码适用于 NLP 模型,我们需要将代码预处理成一种与自然语言非常相似的格式。代码模型支持多种标准编程语言,因此无需对代码片段进行预处理。我们可以直接使用代码。
安装依赖
让我们安装我们将要使用的包。
- inflection - 一个字符串转换库。它可以将英文单词进行单复数转换,并将驼峰式命名转换为下划线命名。
- fastembed - 一个 CPU 优先、轻量级的向量嵌入生成库。支持 GPU。
- qdrant-client - 与 Qdrant 服务器交互的官方 Python 库。
%pip install inflection qdrant-client fastembed
数据准备
将应用程序源码分块成更小的部分是一项不小的任务。一般来说,函数、类方法、结构体、枚举以及所有其他特定于语言的构造都是分块的理想选择。它们足够大,可以包含一些有意义的信息,但又足够小,可以被上下文窗口有限的嵌入模型处理。您还可以使用文档字符串、注释和其他元数据来丰富分块的附加信息。

基于文本的搜索是基于函数签名的,但代码搜索可能会返回更小的片段,例如循环。因此,如果我们从 NLP 模型收到一个特定的函数签名,并从代码模型收到其实现的一部分,我们会将结果合并。
解析代码库
本演示将使用 Qdrant 代码库。虽然此代码库使用 Rust,但您也可以将此方法用于任何其他语言。您可以使用语言服务器协议 (LSP) 工具来构建代码库的图,然后提取分块。我们使用了 rust-analyzer 来完成这项工作。我们将解析后的代码库导出为 LSIF 格式,这是一种代码智能数据的标准。接下来,我们使用 LSIF 数据来浏览代码库并提取分块。
您也可以对其他语言使用相同的方法。有大量的实现可供选择。
然后,我们将分块导出为 JSON 文档,其中不仅包含代码本身,还包含代码在项目中的位置上下文。
您可以在我们的 Google Cloud Storage 存储桶中的 structures.jsonl 文件中查看解析为 JSON 的 Qdrant 结构。下载它并将其用作我们代码搜索的数据源。
!wget https://storage.googleapis.com/tutorial-attachments/code-search/structures.jsonl
接下来,加载文件并将各行解析为一个字典列表。
import json
structures = []
with open("structures.jsonl", "r") as fp:
for i, row in enumerate(fp):
entry = json.loads(row)
structures.append(entry)
让我们看看其中一个条目的样子。
structures[0]
{'name': 'InvertedIndexRam',
'signature': '# [doc = " Inverted flatten index from dimension id to posting list"] # [derive (Debug , Clone , PartialEq)] pub struct InvertedIndexRam { # [doc = " Posting lists for each dimension flattened (dimension id -> posting list)"] # [doc = " Gaps are filled with empty posting lists"] pub postings : Vec < PostingList > , # [doc = " Number of unique indexed vectors"] # [doc = " pre-computed on build and upsert to avoid having to traverse the posting lists."] pub vector_count : usize , }',
'code_type': 'Struct',
'docstring': '= " Inverted flatten index from dimension id to posting list"',
'line': 15,
'line_from': 13,
'line_to': 22,
'context': {'module': 'inverted_index',
'file_path': 'lib/sparse/src/index/inverted_index/inverted_index_ram.rs',
'file_name': 'inverted_index_ram.rs',
'struct_name': None,
'snippet': '/// Inverted flatten index from dimension id to posting list\n#[derive(Debug, Clone, PartialEq)]\npub struct InvertedIndexRam {\n /// Posting lists for each dimension flattened (dimension id -> posting list)\n /// Gaps are filled with empty posting lists\n pub postings: Vec<PostingList>,\n /// Number of unique indexed vectors\n /// pre-computed on build and upsert to avoid having to traverse the posting lists.\n pub vector_count: usize,\n}\n'}}
代码到自然语言的转换
每种编程语言都有其自己的语法,这些语法不属于自然语言。因此,通用模型可能无法直接理解代码。然而,我们可以通过移除代码特有的部分并包含额外的上下文(如模块、类、函数和文件名)来对数据进行归一化。我们采取以下步骤
- 提取函数、方法或其他代码构造的签名。
- 将驼峰式命名和蛇形命名分割成独立的单词。
- 提取文档字符串、注释和其他重要的元数据。
- 使用预定义的模板,根据提取的数据构建一个句子。
- 移除特殊字符,并用空格替换。
现在我们可以定义 textify
函数,该函数使用 inflection
库来执行我们的转换
import inflection
import re
from typing import Dict, Any
def textify(chunk: Dict[str, Any]) -> str:
# Get rid of all the camel case / snake case
# - inflection.underscore changes the camel case to snake case
# - inflection.humanize converts the snake case to human readable form
name = inflection.humanize(inflection.underscore(chunk["name"]))
signature = inflection.humanize(inflection.underscore(chunk["signature"]))
# Check if docstring is provided
docstring = ""
if chunk["docstring"]:
docstring = f"that does {chunk['docstring']} "
# Extract the location of that snippet of code
context = f"module {chunk['context']['module']} " f"file {chunk['context']['file_name']}"
if chunk["context"]["struct_name"]:
struct_name = inflection.humanize(inflection.underscore(chunk["context"]["struct_name"]))
context = f"defined in struct {struct_name} {context}"
# Combine all the bits and pieces together
text_representation = f"{chunk['code_type']} {name} " f"{docstring}" f"defined as {signature} " f"{context}"
# Remove any special characters and concatenate the tokens
tokens = re.split(r"\W", text_representation)
tokens = filter(lambda x: x, tokens)
return " ".join(tokens)
现在我们可以使用 textify
将所有分块转换为文本表示
text_representations = list(map(textify, structures))
让我们看看我们的一个表示形式是怎样的
text_representations[1000]
'Function Hnsw discover precision that does Checks discovery search precision when using hnsw index this is different from the tests in defined as Fn hnsw discover precision module integration file hnsw_discover_test rs'
自然语言嵌入
from fastembed import TextEmbedding
batch_size = 5
nlp_model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", threads=0)
nlp_embeddings = nlp_model.embed(text_representations, batch_size=batch_size)
代码嵌入
code_snippets = [structure["context"]["snippet"] for structure in structures]
code_model = TextEmbedding("jinaai/jina-embeddings-v2-base-code")
code_embeddings = code_model.embed(code_snippets, batch_size=batch_size)
构建 Qdrant 集合
Qdrant 支持多种部署模式。包括用于原型设计的内存模式、Docker 和 Qdrant Cloud。您可以参考安装说明获取更多信息。
我们将继续使用内存实例来进行本教程。
内存模式只能用于快速原型设计和测试。它是 Qdrant 服务器方法的 Python 实现。
让我们创建一个集合来存储我们的向量。
from qdrant_client import QdrantClient, models
COLLECTION_NAME = "qdrant-sources"
client = QdrantClient(":memory:") # Use in-memory storage
# client = QdrantClient("http://locahost:6333") # For Qdrant server
client.create_collection(
COLLECTION_NAME,
vectors_config={
"text": models.VectorParams(
size=384,
distance=models.Distance.COSINE,
),
"code": models.VectorParams(
size=768,
distance=models.Distance.COSINE,
),
},
)
我们新创建的集合已经准备好接收数据了。让我们上传嵌入向量吧。
from tqdm import tqdm
points = []
total = len(structures)
print("Number of points to upload: ", total)
for id, (text_embedding, code_embedding, structure) in tqdm(
enumerate(zip(nlp_embeddings, code_embeddings, structures)), total=total
):
# FastEmbed returns generators. Embeddings are computed as consumed.
points.append(
models.PointStruct(
id=id,
vector={
"text": text_embedding,
"code": code_embedding,
},
payload=structure,
)
)
# Upload points in batches
if len(points) >= batch_size:
client.upload_points(COLLECTION_NAME, points=points, wait=True)
points = []
# Ensure any remaining points are uploaded
if points:
client.upload_points(COLLECTION_NAME, points=points)
print(f"Total points in collection: {client.count(COLLECTION_NAME).count}")
上传的点立即可用于搜索。接下来,查询集合以查找相关的代码片段。
查询代码库
我们使用其中一个模型通过 Qdrant 的新查询 API 来搜索集合。从文本嵌入开始。运行以下查询:“如何计算集合中的点数?”。查看结果。
query = "How do I count points in a collection?"
hits = client.query_points(
COLLECTION_NAME,
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=3,
).points
现在,查看结果。下表列出了模块、文件名和得分。每一行都包含一个指向签名的链接。
模块 | 文件名 | 得分 | 签名 |
---|---|---|---|
operations | types.rs | 0.5493385 | pub struct CountRequestInternal |
map_index | types.rs | 0.49973965 | fn get_points_with_value_count |
map_index | mutable_map_index.rs | 0.49941066 | pub fn get_points_with_value_count |
看来我们已经找到了相关的代码结构。让我们用代码嵌入再试一次。
hits = client.query_points(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=3,
).points
输出
模块 | 文件名 | 得分 | 签名 |
---|---|---|---|
field_index | geo_index.rs | 0.7217579 | fn count_indexed_points |
numeric_index | mod.rs | 0.7113214 | fn count_indexed_points |
full_text_index | text_index.rs | 0.6993165 | fn count_indexed_points |
虽然不同模型检索到的分数不可比较,但我们可以看到结果是不同的。代码和文本嵌入可以捕捉到代码库的不同方面。我们可以同时使用两个模型来查询集合,然后结合结果以获得最相关的代码片段。
from qdrant_client import models
hits = client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
models.Prefetch(
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=5,
),
models.Prefetch(
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=5,
),
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
).points
>>> for hit in hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_path"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| operations | lib/collection/src/operations/types.rs | 0.5 | ` # [doc = " Count Request"] # [doc = " Counts the number of points which satisfy the given filter."] # [doc = " If filter is not provided, the count of all points in the collection will be returned."] # [derive (Debug , Deserialize , Serialize , JsonSchema , Validate)] # [serde (rename_all = "snake_case")] pub struct CountRequestInternal { # [doc = " Look only for points which satisfies this conditions"] # [validate] pub filter : Option < Filter > , # [doc = " If true, count exact number of points. If false, count approximate number of points faster."] # [doc = " Approximate count might be unreliable during the indexing process. Default: true"] # [serde (default = "default_exact_count")] pub exact : bool , } ` | | field_index | lib/segment/src/index/field_index/geo_index.rs | 0.5 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.33333334 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | numeric_index | lib/segment/src/index/field_index/numeric_index/mod.rs | 0.33333334 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | lib/segment/src/fixtures/payload_context_fixture.rs | 0.25 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mutable_map_index.rs | 0.25 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | id_tracker | lib/segment/src/id_tracker/simple_id_tracker.rs | 0.2 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.2 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.16666667 | ` fn count_indexed_points (& self) -> usize ` | | field_index | lib/segment/src/index/field_index/stat_tools.rs | 0.16666667 | ` fn number_of_selected_points (points : usize , values : usize) -> usize ` |
这是如何融合不同模型结果的一个例子。在实际场景中,您可能需要进行一些重排和去重,以及对结果进行额外的处理。
对结果进行分组
您可以通过根据有效载荷属性对搜索结果进行分组来改善搜索结果。在我们的案例中,我们可以按模块对结果进行分组。如果我们使用代码嵌入,我们可以看到来自 map_index
模块的多个结果。让我们对结果进行分组,并假设每个模块只显示一个结果
results = client.query_points_groups(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
group_by="context.module",
limit=5,
group_size=1,
)
>>> for group in results.groups:
... for hit in group.hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_name"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| field_index | geo_index.rs | 0.7217579 | ` fn count_indexed_points (& self) -> usize ` | | numeric_index | mod.rs | 0.7113214 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | payload_context_fixture.rs | 0.6993165 | ` fn total_point_count (& self) -> usize ` | | map_index | mod.rs | 0.68385994 | ` fn count_indexed_points (& self) -> usize ` | | full_text_index | text_index.rs | 0.6660142 | ` fn count_indexed_points (& self) -> usize ` |
我们的教程到此结束。感谢您花时间看到这里。我们才刚刚开始探索向量嵌入的可能性以及如何改进它。欢迎您自由尝试;您可能会创造出非常酷的东西!请与我们分享 🙏 我们在这里。
< > 在 GitHub 上更新