开源 AI 食谱文档

使用向量嵌入和 Qdrant 进行代码搜索

Hugging Face's logo
加入 Hugging Face 社区

并获得增强文档体验

开始使用

Open In Colab

使用向量嵌入和 Qdrant 进行代码搜索

作者:Qdrant 团队

在本笔记本中,我们演示了如何使用向量嵌入来浏览代码库并查找相关的代码片段。我们将使用自然语义查询搜索代码库,并根据类似的逻辑搜索代码。

您可以查看此方法的实时部署,它通过 Web 界面公开 Qdrant 代码库进行搜索。

方法

我们需要两个模型来实现我们的目标。

  • 用于自然语言处理 (NLP) 的通用用途神经编码器,在本例中为sentence-transformers/all-MiniLM-L6-v2。我们将此模型称为 NLP 模型。

  • 用于代码到代码相似性搜索的专用嵌入。我们将使用jinaai/jina-embeddings-v2-base-code模型来完成此任务。它支持英语和 30 种广泛使用的编程语言,序列长度为 8192。我们将其称为代码模型。

为了为 NLP 模型准备代码,我们需要将代码预处理成更接近自然语言的格式。代码模型支持各种标准编程语言,因此无需预处理代码片段。我们可以按原样使用代码。

安装依赖项

让我们安装我们将要使用的包。

  • inflection - 一个字符串转换库。它可以将英文单词进行单复数转换,并将 CamelCase 转换为下划线分隔的字符串。
  • fastembed - 一个以 CPU 为先的轻量级向量嵌入生成库。 支持 GPU
  • qdrant-client - 与 Qdrant 服务器交互的官方 Python 库。
%pip install inflection qdrant-client fastembed

数据准备

将应用程序源代码分割成更小的部分是一项非平凡的任务。通常,函数、类方法、结构体、枚举以及所有其他特定于语言的构造都是良好的代码块候选。它们足够大,包含一些有意义的信息,但又足够小,可以由具有有限上下文窗口的嵌入模型进行处理。您还可以使用文档字符串、注释和其他元数据来为代码块提供更多信息。

基于文本的搜索依赖于函数签名,但代码搜索可能会返回更小的代码片段,例如循环。因此,如果我们从 NLP 模型中接收到特定的函数签名,以及从代码模型中接收到其部分实现,我们将合并结果。

代码库解析

我们将使用 Qdrant 代码库 进行此演示。虽然此代码库使用 Rust,但您可以将此方法应用于任何其他语言。您可以使用 语言服务器协议 (LSP) 工具构建代码库的图,然后提取代码块。我们使用 rust-analyzer 完成了这项工作。我们将解析后的代码库导出为 LSIF 格式,这是代码智能数据的标准格式。接下来,我们使用 LSIF 数据遍历代码库并提取代码块。

您可以对其他语言使用相同的方法。有 大量实现 可供使用。

然后,我们将代码块导出到 JSON 文档中,这些文档不仅包含代码本身,还包含代码在项目中的位置上下文。

您可以在我们的 Google Cloud Storage 存储桶中的 structures.jsonl 文件 中查看以 JSON 格式解析的 Qdrant 结构。下载它并将其用作我们代码搜索的数据源。

!wget https://storage.googleapis.com/tutorial-attachments/code-search/structures.jsonl

接下来,加载文件并将行解析为字典列表。

import json

structures = []
with open("structures.jsonl", "r") as fp:
    for i, row in enumerate(fp):
        entry = json.loads(row)
        structures.append(entry)

让我们看看一个条目是什么样的。

structures[0]
{'name': 'InvertedIndexRam',
 'signature': '# [doc = " Inverted flatten index from dimension id to posting list"] # [derive (Debug , Clone , PartialEq)] pub struct InvertedIndexRam { # [doc = " Posting lists for each dimension flattened (dimension id -> posting list)"] # [doc = " Gaps are filled with empty posting lists"] pub postings : Vec < PostingList > , # [doc = " Number of unique indexed vectors"] # [doc = " pre-computed on build and upsert to avoid having to traverse the posting lists."] pub vector_count : usize , }',
 'code_type': 'Struct',
 'docstring': '= " Inverted flatten index from dimension id to posting list"',
 'line': 15,
 'line_from': 13,
 'line_to': 22,
 'context': {'module': 'inverted_index',
  'file_path': 'lib/sparse/src/index/inverted_index/inverted_index_ram.rs',
  'file_name': 'inverted_index_ram.rs',
  'struct_name': None,
  'snippet': '/// Inverted flatten index from dimension id to posting list\n#[derive(Debug, Clone, PartialEq)]\npub struct InvertedIndexRam {\n    /// Posting lists for each dimension flattened (dimension id -> posting list)\n    /// Gaps are filled with empty posting lists\n    pub postings: Vec<PostingList>,\n    /// Number of unique indexed vectors\n    /// pre-computed on build and upsert to avoid having to traverse the posting lists.\n    pub vector_count: usize,\n}\n'}}

代码到自然语言转换

每种编程语言都有其自身的语法,这并不是自然语言的一部分。因此,通用模型可能无法理解代码本身。但是,我们可以通过删除代码细节并包含其他上下文(例如模块、类、函数和文件名)来规范化数据。我们将采取以下步骤:

  1. 提取函数、方法或其他代码构造的签名。
  2. 将驼峰命名法和蛇形命名法名称分解成单独的单词。
  3. 获取文档字符串、注释和其他重要的元数据。
  4. 使用预定义的模板从提取的数据构建句子。
  5. 删除特殊字符并将其替换为空格。

现在,我们可以定义使用 `inflection` 库执行转换的 `textify` 函数。

import inflection
import re

from typing import Dict, Any


def textify(chunk: Dict[str, Any]) -> str:
    # Get rid of all the camel case / snake case
    # - inflection.underscore changes the camel case to snake case
    # - inflection.humanize converts the snake case to human readable form
    name = inflection.humanize(inflection.underscore(chunk["name"]))
    signature = inflection.humanize(inflection.underscore(chunk["signature"]))

    # Check if docstring is provided
    docstring = ""
    if chunk["docstring"]:
        docstring = f"that does {chunk['docstring']} "

    # Extract the location of that snippet of code
    context = f"module {chunk['context']['module']} " f"file {chunk['context']['file_name']}"
    if chunk["context"]["struct_name"]:
        struct_name = inflection.humanize(inflection.underscore(chunk["context"]["struct_name"]))
        context = f"defined in struct {struct_name} {context}"

    # Combine all the bits and pieces together
    text_representation = f"{chunk['code_type']} {name} " f"{docstring}" f"defined as {signature} " f"{context}"

    # Remove any special characters and concatenate the tokens
    tokens = re.split(r"\W", text_representation)
    tokens = filter(lambda x: x, tokens)
    return " ".join(tokens)

现在我们可以使用 `textify` 将所有代码块转换为文本表示。

text_representations = list(map(textify, structures))

让我们看看其中一个表示是什么样的。

text_representations[1000]
'Function Hnsw discover precision that does Checks discovery search precision when using hnsw index this is different from the tests in defined as Fn hnsw discover precision module integration file hnsw_discover_test rs'

自然语言嵌入

from fastembed import TextEmbedding

batch_size = 5

nlp_model = TextEmbedding("sentence-transformers/all-MiniLM-L6-v2", threads=0)
nlp_embeddings = nlp_model.embed(text_representations, batch_size=batch_size)

代码嵌入

code_snippets = [structure["context"]["snippet"] for structure in structures]

code_model = TextEmbedding("jinaai/jina-embeddings-v2-base-code")

code_embeddings = code_model.embed(code_snippets, batch_size=batch_size)

构建 Qdrant 集合

Qdrant 支持多种部署模式,包括用于原型设计的内存模式、Docker 和 Qdrant 云。您可以参考 安装说明 获取更多信息。

我们将继续使用内存实例进行本教程。

内存模式仅可用于快速原型设计和测试。它是 Qdrant 服务器方法的 Python 实现。

让我们创建一个集合来存储我们的向量。

from qdrant_client import QdrantClient, models

COLLECTION_NAME = "qdrant-sources"

client = QdrantClient(":memory:")  # Use in-memory storage
# client = QdrantClient("http://locahost:6333")  # For Qdrant server

client.create_collection(
    COLLECTION_NAME,
    vectors_config={
        "text": models.VectorParams(
            size=384,
            distance=models.Distance.COSINE,
        ),
        "code": models.VectorParams(
            size=768,
            distance=models.Distance.COSINE,
        ),
    },
)

我们新创建的集合已准备好接收数据。让我们上传嵌入。

from tqdm import tqdm

points = []
total = len(structures)
print("Number of points to upload: ", total)

for id, (text_embedding, code_embedding, structure) in tqdm(
    enumerate(zip(nlp_embeddings, code_embeddings, structures)), total=total
):
    # FastEmbed returns generators. Embeddings are computed as consumed.
    points.append(
        models.PointStruct(
            id=id,
            vector={
                "text": text_embedding,
                "code": code_embedding,
            },
            payload=structure,
        )
    )

    # Upload points in batches
    if len(points) >= batch_size:
        client.upload_points(COLLECTION_NAME, points=points, wait=True)
        points = []

# Ensure any remaining points are uploaded
if points:
    client.upload_points(COLLECTION_NAME, points=points)

print(f"Total points in collection: {client.count(COLLECTION_NAME).count}")

上传的点可立即用于搜索。接下来,查询集合以查找相关的代码片段。

查询代码库

我们使用其中一个模型通过 Qdrant 的新 查询 API 搜索集合。首先使用文本嵌入。运行以下查询:“如何计算集合中的点数?”。查看结果。

query = "How do I count points in a collection?"

hits = client.query_points(
    COLLECTION_NAME,
    query=next(nlp_model.query_embed(query)).tolist(),
    using="text",
    limit=3,
).points

现在,查看结果。下表列出了模块、文件名和分数。每行都包含指向签名的链接。

模块 文件名 分数 签名
operations types.rs 0.5493385 pub struct CountRequestInternal
map_index types.rs 0.49973965 fn get_points_with_value_count
map_index mutable_map_index.rs 0.49941066 pub fn get_points_with_value_count

看起来我们能够找到一些相关的代码结构。让我们尝试使用代码嵌入进行同样的操作。

hits = client.query_points(
    COLLECTION_NAME,
    query=next(code_model.query_embed(query)).tolist(),
    using="code",
    limit=3,
).points

输出

模块 文件名 分数 签名
field_index geo_index.rs 0.7217579 fn count_indexed_points
numeric_index mod.rs 0.7113214 fn count_indexed_points
full_text_index text_index.rs 0.6993165 fn count_indexed_points

虽然不同模型检索的分数不可比,但我们可以看到结果是不同的。代码和文本嵌入可以捕捉代码库的不同方面。我们可以使用这两种模型来查询集合,然后组合结果以获得最相关的代码片段。

from qdrant_client import models

hits = client.query_points(
    collection_name=COLLECTION_NAME,
    prefetch=[
        models.Prefetch(
            query=next(nlp_model.query_embed(query)).tolist(),
            using="text",
            limit=5,
        ),
        models.Prefetch(
            query=next(code_model.query_embed(query)).tolist(),
            using="code",
            limit=5,
        ),
    ],
    query=models.FusionQuery(fusion=models.Fusion.RRF),
).points
>>> for hit in hits:
...     print(
...         "| ",
...         hit.payload["context"]["module"],
...         " | ",
...         hit.payload["context"]["file_path"],
...         " | ",
...         hit.score,
...         " | `",
...         hit.payload["signature"],
...         "` |",
...     )
|  operations  |  lib/collection/src/operations/types.rs  |  0.5  | ` # [doc = " Count Request"] # [doc = " Counts the number of points which satisfy the given filter."] # [doc = " If filter is not provided, the count of all points in the collection will be returned."] # [derive (Debug , Deserialize , Serialize , JsonSchema , Validate)] # [serde (rename_all = "snake_case")] pub struct CountRequestInternal &#123; # [doc = " Look only for points which satisfies this conditions"] # [validate] pub filter : Option < Filter > , # [doc = " If true, count exact number of points. If false, count approximate number of points faster."] # [doc = " Approximate count might be unreliable during the indexing process. Default: true"] # [serde (default = "default_exact_count")] pub exact : bool , } ` |
|  field_index  |  lib/segment/src/index/field_index/geo_index.rs  |  0.5  | ` fn count_indexed_points (& self) -> usize ` |
|  map_index  |  lib/segment/src/index/field_index/map_index/mod.rs  |  0.33333334  | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` |
|  numeric_index  |  lib/segment/src/index/field_index/numeric_index/mod.rs  |  0.33333334  | ` fn count_indexed_points (& self) -> usize ` |
|  fixtures  |  lib/segment/src/fixtures/payload_context_fixture.rs  |  0.25  | ` fn total_point_count (& self) -> usize ` |
|  map_index  |  lib/segment/src/index/field_index/map_index/mutable_map_index.rs  |  0.25  | ` fn get_points_with_value_count < Q > (& self , value : & Q) -> Option < usize > where Q : ? Sized , N : std :: borrow :: Borrow < Q > , Q : Hash + Eq , ` |
|  id_tracker  |  lib/segment/src/id_tracker/simple_id_tracker.rs  |  0.2  | ` fn total_point_count (& self) -> usize ` |
|  map_index  |  lib/segment/src/index/field_index/map_index/mod.rs  |  0.2  | ` fn count_indexed_points (& self) -> usize ` |
|  map_index  |  lib/segment/src/index/field_index/map_index/mod.rs  |  0.16666667  | ` fn count_indexed_points (& self) -> usize ` |
|  field_index  |  lib/segment/src/index/field_index/stat_tools.rs  |  0.16666667  | ` fn number_of_selected_points (points : usize , values : usize) -> usize ` |

这是一个关于如何融合来自不同模型的结果的示例。在现实场景中,您可能需要进行一些重新排序和去重,以及对结果进行其他处理。

结果分组

您可以通过根据有效负载属性对结果进行分组来改进搜索结果。在我们的例子中,我们可以根据模块对结果进行分组。如果我们使用代码嵌入,我们可以看到来自 `map_index` 模块的多个结果。让我们对结果进行分组,并假设每个模块只有一个结果。

results = client.query_points_groups(
    COLLECTION_NAME,
    query=next(code_model.query_embed(query)).tolist(),
    using="code",
    group_by="context.module",
    limit=5,
    group_size=1,
)
>>> for group in results.groups:
...     for hit in group.hits:
...         print(
...             "| ",
...             hit.payload["context"]["module"],
...             " | ",
...             hit.payload["context"]["file_name"],
...             " | ",
...             hit.score,
...             " | `",
...             hit.payload["signature"],
...             "` |",
...         )
|  field_index  |  geo_index.rs  |  0.7217579  | ` fn count_indexed_points (& self) -> usize ` |
|  numeric_index  |  mod.rs  |  0.7113214  | ` fn count_indexed_points (& self) -> usize ` |
|  fixtures  |  payload_context_fixture.rs  |  0.6993165  | ` fn total_point_count (& self) -> usize ` |
|  map_index  |  mod.rs  |  0.68385994  | ` fn count_indexed_points (& self) -> usize ` |
|  full_text_index  |  text_index.rs  |  0.6660142  | ` fn count_indexed_points (& self) -> usize ` |

我们的教程到此结束。感谢您抽出时间学习到这里。我们刚刚开始探索向量嵌入的可能性以及如何改进它。请随意尝试;您可以构建非常酷的东西!请与我们分享🙏 我们在这里

< > 在 GitHub 上更新