开源 AI 食谱文档
使用向量嵌入和 Qdrant 进行代码搜索
并获得增强的文档体验
开始使用
使用向量嵌入和 Qdrant 进行代码搜索
作者:Qdrant 团队
在这篇 notebook 中,我们将演示如何使用向量嵌入来浏览代码库,并找到相关的代码片段。我们将使用自然的语义查询来搜索代码库,并基于类似的逻辑搜索代码。
您可以查看此方法的在线部署,它通过 Web 界面公开了 Qdrant 代码库以供搜索。
方法
我们需要两个模型来实现我们的目标。
用于自然语言处理 (NLP) 的通用神经编码器,在我们的例子中是 sentence-transformers/all-MiniLM-L6-v2。我们将其称为 NLP 模型。
用于代码到代码相似性搜索的专用嵌入。我们将使用 jinaai/jina-embeddings-v2-base-code 模型来完成此任务。它支持英语和 30 种广泛使用的编程语言,序列长度为 8192。我们将其称为代码模型。
为了准备我们的代码以供 NLP 模型使用,我们需要将代码预处理成一种与自然语言非常相似的格式。代码模型支持各种标准编程语言,因此无需预处理代码片段。我们可以按原样使用代码。
安装依赖
让我们安装我们将要使用的软件包。
- inflection - 一个字符串转换库。它可以将英语单词单数和复数化,并将 CamelCase 转换为下划线字符串。
- fastembed - 一个 CPU 优先的轻量级库,用于生成向量嵌入。 GPU 支持可用。
- qdrant-client - 用于与 Qdrant 服务器交互的官方 Python 库。
%pip install inflection qdrant-client fastembed
数据准备
将应用程序源代码分块成更小的部分是一项重要的任务。一般来说,函数、类方法、结构体、枚举以及所有其他特定于语言的构造都是分块的良好候选对象。它们足够大,可以包含一些有意义的信息,但又足够小,可以被具有有限上下文窗口的嵌入模型处理。您还可以使用文档字符串、注释和其他元数据来丰富块,以提供更多信息。

基于文本的搜索基于函数签名,但代码搜索可能会返回更小的片段,例如循环。因此,如果我们从 NLP 模型收到特定的函数签名,并从代码模型收到其部分实现,我们将合并结果。
解析代码库
在此演示中,我们将使用 Qdrant 代码库。虽然此代码库使用 Rust,但您可以将此方法用于任何其他语言。您可以使用语言服务器协议 (LSP) 工具来构建代码库的图,然后提取块。我们使用 rust-analyzer 完成了这项工作。我们将解析的代码库导出为 LSIF 格式,这是代码智能数据的标准。接下来,我们使用 LSIF 数据来浏览代码库并提取块。
您可以将相同的方法用于其他语言。有大量的实现可用。
然后,我们将块导出到 JSON 文档中,其中不仅包含代码本身,还包含代码在项目中的位置的上下文。
您可以在我们的 Google Cloud Storage 存储桶中的 structures.jsonl 文件中查看以 JSON 格式解析的 Qdrant 结构。下载它并将其用作我们代码搜索的数据源。
!wget https://storage.googleapis.com/tutorial-attachments/code-search/structures.jsonl
接下来,加载文件并将行解析为字典列表
import json
structures = []
with open("structures.jsonl", "r") as fp:
for i, row in enumerate(fp):
entry = json.loads(row)
structures.append(entry)
让我们看看一个条目是什么样的。
structures[0]
{'name': 'InvertedIndexRam',
'signature': '# [doc = " Inverted flatten index from dimension id to posting list"] # [derive (Debug , Clone , PartialEq)] pub struct InvertedIndexRam { # [doc = " Posting lists for each dimension flattened (dimension id -> posting list)"] # [doc = " Gaps are filled with empty posting lists"] pub postings : Vec < PostingList > , # [doc = " Number of unique indexed vectors"] # [doc = " pre-computed on build and upsert to avoid having to traverse the posting lists."] pub vector_count : usize , }',
'code_type': 'Struct',
'docstring': '= " Inverted flatten index from dimension id to posting list"',
'line': 15,
'line_from': 13,
'line_to': 22,
'context': {'module': 'inverted_index',
'file_path': 'lib/sparse/src/index/inverted_index/inverted_index_ram.rs',
'file_name': 'inverted_index_ram.rs',
'struct_name': None,
'snippet': '/// Inverted flatten index from dimension id to posting list\n#[derive(Debug, Clone, PartialEq)]\npub struct InvertedIndexRam {\n /// Posting lists for each dimension flattened (dimension id -> posting list)\n /// Gaps are filled with empty posting lists\n pub postings: Vec<PostingList>,\n /// Number of unique indexed vectors\n /// pre-computed on build and upsert to avoid having to traverse the posting lists.\n pub vector_count: usize,\n}\n'}}
代码到自然语言的转换
每种编程语言都有自己的语法,这不是自然语言的一部分。因此,通用模型可能无法按原样理解代码。但是,我们可以通过删除代码的特定细节并包含其他上下文(例如模块、类、函数和文件名)来规范化数据。我们采取以下步骤
- 提取函数、方法或其他代码构造的签名。
- 将驼峰式和蛇形命名法名称分成单独的单词。
- 获取文档字符串、注释和其他重要的元数据。
- 使用预定义的模板从提取的数据构建句子。
- 删除特殊字符并将其替换为空格。
现在我们可以定义 textify
函数,该函数使用 inflection
库来执行我们的转换
import inflection
import re
from typing import Dict, Any
def textify(chunk: Dict[str, Any]) -> str:
# Get rid of all the camel case / snake case
# - inflection.underscore changes the camel case to snake case
# - inflection.humanize converts the snake case to human readable form
name = inflection.humanize(inflection.underscore(chunk["name"]))
signature = inflection.humanize(inflection.underscore(chunk["signature"]))
# Check if docstring is provided
docstring = ""
if chunk["docstring"]:
docstring = f"that does {chunk['docstring']} "
# Extract the location of that snippet of code
context = f"module {chunk['context']['module']} " f"file {chunk['context']['file_name']}"
if chunk["context"]["struct_name"]:
struct_name = inflection.humanize(inflection.underscore(chunk["context"]["struct_name"]))
context = f"defined in struct {struct_name} {context}"
# Combine all the bits and pieces together
text_representation = f"{chunk['code_type']} {name} " f"{docstring}" f"defined as {signature} " f"{context}"
# Remove any special characters and concatenate the tokens
tokens = re.split(r"\W", text_representation)
tokens = filter(lambda x: x, tokens)
return " ".join(tokens)
现在我们可以使用 textify
将所有块转换为文本表示形式
text_representations = list(map(textify, structures))
让我们看看我们的一个表示形式是什么样的
text_representations[1000]
'Function Hnsw discover precision that does Checks discovery search precision when using hnsw index this is different from the tests in defined as Fn hnsw discover precision module integration file hnsw_discover_test rs'
自然语言嵌入
from fastembed import TextEmbedding
batch_size = 5
nlp_model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", threads=0)
nlp_embeddings = nlp_model.embed(text_representations, batch_size=batch_size)
代码嵌入
code_snippets = [structure["context"]["snippet"] for structure in structures]
code_model = TextEmbedding("jinaai/jina-embeddings-v2-base-code")
code_embeddings = code_model.embed(code_snippets, batch_size=batch_size)
构建 Qdrant 集合
Qdrant 支持多种部署模式。包括用于原型设计的内存模式、Docker 和 Qdrant Cloud。您可以参考安装说明以获取更多信息。
我们将继续使用内存实例进行本教程。
内存模式只能用于快速原型设计和测试。它是 Qdrant 服务器方法的 Python 实现。
让我们创建一个集合来存储我们的向量。
from qdrant_client import QdrantClient, models
COLLECTION_NAME = "qdrant-sources"
client = QdrantClient(":memory:") # Use in-memory storage
# client = QdrantClient("http://locahost:6333") # For Qdrant server
client.create_collection(
COLLECTION_NAME,
vectors_config={
"text": models.VectorParams(
size=384,
distance=models.Distance.COSINE,
),
"code": models.VectorParams(
size=768,
distance=models.Distance.COSINE,
),
},
)
我们新创建的集合已准备好接收数据。让我们上传嵌入
from tqdm import tqdm
points = []
total = len(structures)
print("Number of points to upload: ", total)
for id, (text_embedding, code_embedding, structure) in tqdm(
enumerate(zip(nlp_embeddings, code_embeddings, structures)), total=total
):
# FastEmbed returns generators. Embeddings are computed as consumed.
points.append(
models.PointStruct(
id=id,
vector={
"text": text_embedding,
"code": code_embedding,
},
payload=structure,
)
)
# Upload points in batches
if len(points) >= batch_size:
client.upload_points(COLLECTION_NAME, points=points, wait=True)
points = []
# Ensure any remaining points are uploaded
if points:
client.upload_points(COLLECTION_NAME, points=points)
print(f"Total points in collection: {client.count(COLLECTION_NAME).count}")
上传的点立即可用于搜索。接下来,查询集合以查找相关的代码片段。
查询代码库
我们使用其中一个模型通过 Qdrant 的新 Query API 搜索集合。从文本嵌入开始。运行以下查询“How do I count points in a collection?”。查看结果。
query = "How do I count points in a collection?"
hits = client.query_points(
COLLECTION_NAME,
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=3,
).points
现在,查看结果。下表列出了模块、文件名和分数。每行都包含指向签名的链接。
模块 | 文件名 | 分数 | 签名 |
---|---|---|---|
operations | types.rs | 0.5493385 | pub struct CountRequestInternal |
map_index | types.rs | 0.49973965 | fn get_points_with_value_count |
map_index | mutable_map_index.rs | 0.49941066 | pub fn get_points_with_value_count |
看来我们能够找到一些相关的代码结构。让我们用代码嵌入尝试相同的操作
hits = client.query_points(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=3,
).points
输出
模块 | 文件名 | 分数 | 签名 |
---|---|---|---|
field_index | geo_index.rs | 0.7217579 | fn count_indexed_points |
numeric_index | mod.rs | 0.7113214 | fn count_indexed_points |
full_text_index | text_index.rs | 0.6993165 | fn count_indexed_points |
虽然不同模型检索到的分数不可比较,但我们可以看到结果是不同的。代码和文本嵌入可以捕获代码库的不同方面。我们可以使用这两个模型来查询集合,然后将结果组合起来以获得最相关的代码片段。
from qdrant_client import models
hits = client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
models.Prefetch(
query=next(nlp_model.query_embed(query)).tolist(),
using="text",
limit=5,
),
models.Prefetch(
query=next(code_model.query_embed(query)).tolist(),
using="code",
limit=5,
),
],
query=models.FusionQuery(fusion=models.Fusion.RRF),
).points
>>> for hit in hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_path"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| operations | lib/collection/src/operations/types.rs | 0.5 | ` # [doc = " Count Request"] # [doc = " Counts the number of points which satisfy the given filter."] # [doc = " If filter is not provided, the count of all points in the collection will be returned."] # [derive (Debug , Deserialize , Serialize , JsonSchema , Validate)] # [serde (rename_all = "snake_case")] pub struct CountRequestInternal { # [doc = " Look only for points which satisfies this conditions"] # [validate] pub filter : Option < Filter > , # [doc = " If true, count exact number of points. If false, count approximate number of points faster."] # [doc = " Approximate count might be unreliable during the indexing process. Default: true"] # [serde (default = "default_exact_count")] pub exact : bool , } ` | | field_index | lib/segment/src/index/field_index/geo_index.rs | 0.5 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.33333334 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | numeric_index | lib/segment/src/index/field_index/numeric_index/mod.rs | 0.33333334 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | lib/segment/src/fixtures/payload_context_fixture.rs | 0.25 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mutable_map_index.rs | 0.25 | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` | | id_tracker | lib/segment/src/id_tracker/simple_id_tracker.rs | 0.2 | ` fn total_point_count (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.2 | ` fn count_indexed_points (& self) -> usize ` | | map_index | lib/segment/src/index/field_index/map_index/mod.rs | 0.16666667 | ` fn count_indexed_points (& self) -> usize ` | | field_index | lib/segment/src/index/field_index/stat_tools.rs | 0.16666667 | ` fn number_of_selected_points (points : usize , values : usize) -> usize ` |
这是如何融合来自不同模型的结果的一个示例。在实际场景中,您可能会运行一些重新排序和去重,以及对结果进行额外的处理。
结果分组
您可以通过按有效负载属性对搜索结果进行分组来改进搜索结果。在我们的例子中,我们可以按模块对结果进行分组。如果我们使用代码嵌入,我们可以看到来自 map_index 模块的多个结果。让我们对结果进行分组,并假设每个模块只有一个结果
results = client.query_points_groups(
COLLECTION_NAME,
query=next(code_model.query_embed(query)).tolist(),
using="code",
group_by="context.module",
limit=5,
group_size=1,
)
>>> for group in results.groups:
... for hit in group.hits:
... print(
... "| ",
... hit.payload["context"]["module"],
... " | ",
... hit.payload["context"]["file_name"],
... " | ",
... hit.score,
... " | `",
... hit.payload["signature"],
... "` |",
... )
| field_index | geo_index.rs | 0.7217579 | ` fn count_indexed_points (& self) -> usize ` | | numeric_index | mod.rs | 0.7113214 | ` fn count_indexed_points (& self) -> usize ` | | fixtures | payload_context_fixture.rs | 0.6993165 | ` fn total_point_count (& self) -> usize ` | | map_index | mod.rs | 0.68385994 | ` fn count_indexed_points (& self) -> usize ` | | full_text_index | text_index.rs | 0.6660142 | ` fn count_indexed_points (& self) -> usize ` |
我们的教程到此结束。感谢您抽出时间来到这里。我们才刚刚开始探索向量嵌入的可能性以及如何改进它。随时以您的方式进行实验;您可以构建一些非常酷的东西!请与我们分享 🙏 我们在这里。
< > 在 GitHub 上更新