使用Sentence Transformers训练速度提高400倍的静态嵌入模型

发布于2025年1月15日
在 GitHub 上更新

总结

本博客文章介绍了一种训练静态嵌入模型的方法,该模型在CPU上的运行速度比最先进的嵌入模型快100到400倍,同时保持了大部分质量。这开启了许多令人兴奋的应用场景,包括设备上和浏览器内执行、边缘计算、低功耗和嵌入式应用。

我们应用此方法训练了两个极其高效的嵌入模型:用于英语检索的sentence-transformers/static-retrieval-mrl-en-v1,以及用于多语言相似度任务的sentence-transformers/static-similarity-mrl-multilingual-v1。这些模型在CPU上比all-mpnet-base-v2multilingual-e5-small等常见模型快100到400倍,同时在各种基准测试中至少达到其性能的85%

今天,我们发布:

  • 上述两个模型(用于英语检索和多语言相似度)。
  • 我们遵循的详细训练策略,从构思到数据集选择,再到实施和评估。
  • 两个基于开源sentence transformers库的训练脚本。
  • 两份包含训练期间收集的训练和评估指标的Weights and Biases报告。
  • 我们使用的详细数据集列表:30个用于训练,13个用于评估。

我们还讨论了潜在的改进,并鼓励社区探索这些改进并在此工作的基础上进行构建!

点击查看已发布模型的使用片段

这些模型的使用非常简单,与常规的Sentence Transformers流程相同

英语检索

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7649, 0.3279]])

多语言相似度

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-similarity-mrl-multilingual-v1", device="cpu")
# Run inference
sentences = [
    'It is known for its dry red chili powder.',
    'It is popular for dried red chili powder.',
    'These monsters will move in large groups.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[ 1.0000,  0.8388, -0.0012],
#         [ 0.8388,  1.0000,  0.0445],
#         [-0.0012,  0.0445,  1.0000]])

NanoBEIR performance vs inference speed

目录

什么是嵌入?

嵌入是自然语言处理中最通用的工具之一,使从业者能够解决各种任务。本质上,嵌入是更复杂对象(如文本、图像、音频等)的数值表示。

embedding model

嵌入模型总是生成相同固定大小的嵌入。然后,您可以通过计算各个嵌入的相似度来计算复杂对象的相似度。

embedding similarity

这有大量的用例,并作为推荐系统、检索、异常检测、单次或少次学习、相似度搜索、聚类、释义检测、分类等的基础。

现代嵌入

当今许多嵌入模型都包含少量转换步骤。遵循这些步骤称为“推理”。

embedding pipeline

`Tokenizer` 和 `Pooler` 分别负责 `Encoder` 的预处理和后处理。前者将文本切分为 `Encoder` 可理解的标记(又称单词或子词),而后者将所有标记的嵌入组合成整个文本的一个嵌入。

在此管道中,`Encoder` 通常是一个带有注意力层的语言模型,它允许每个标记在其他标记的**上下文**中进行计算。例如,`bank` 可能是一个标记,但如果文本指的是“河岸”或金融机构,则该标记的嵌入可能会有所不同。

具有许多注意力层的大型编码器模型将有效地利用上下文来生成有用的嵌入,但这样做的代价是推理速度慢。值得注意的是,在管道中,`Encoder` 步骤通常占据了几乎所有计算时间。

静态嵌入

静态嵌入指不使用大型缓慢的基于注意力的模型,而是依赖预计算标记嵌入的`Encoder`模型组。静态嵌入在 Transformer 架构开发之前就已使用多年。常见示例包括GLoVeword2vec。最近,Model2Vec已被用于将预训练嵌入模型转换为静态嵌入模型。

对于静态嵌入,`Encoder` 步骤就像字典查找一样简单:给定标记,返回预计算的标记嵌入。因此,推理突然不再受 `Encoder` 阶段的瓶颈,从而使速度提高**几个数量级**。这篇博客文章表明,对质量的影响可以非常小!

我们的方法

我们着手使用现代技术重新审视静态嵌入模型并训练它们。我们的大部分收益来自对比学习损失函数的使用,我们将很快解释。此外,通过使用套娃表示学习,我们可以获得额外的速度改进,这使得使用嵌入向量的截断版本成为可能。

我们将使用Sentence Transformers库进行训练。有关此库如何用于训练嵌入模型的更一般概述,请考虑阅读使用Sentence Transformers v3训练和微调嵌入模型博客文章或Sentence Transformers训练概述文档

训练细节

重新构想静态嵌入的目标是,在这些高效嵌入模型上试验现代嵌入模型微调技术。特别是,与 GLoVe 和 word2vec 不同,我们将使用:

  1. **对比学习**:在大多数机器学习中,您输入 $X$ 并期望输出 $Y$,然后训练模型,使通过模型输入的 $X$ 产生接近 $Y$ 的结果。对于嵌入模型,我们没有 $Y$:我们事先不知道好的嵌入是什么。

    相反,在对比学习中,我们有多个输入 $X_1$ 和 $X_2$,以及一个相似度。我们将两个输入都通过模型,之后我们可以**对比**生成的两个嵌入,从而得到预测的相似度。如果真实相似度低,我们可以将嵌入推得更远;如果真实相似度高,则可以将嵌入拉得更近。

  2. **套娃表示学习(MRL)**:套娃嵌入模型(博客文章)是一种巧妙的训练方法,允许用户在性能损失最小的情况下将嵌入模型截断为更小的维度。它不仅使用正常大小的嵌入进行对比损失函数计算,还使用其截断版本。因此,模型学习将信息主要存储在嵌入的开头。

    截断后的嵌入将在下游应用(如检索、分类和聚类)中更快。

对于未来的研究,我们留下了各种其他现代训练方法以提高数据质量。请参阅下一步了解具体想法。

训练要求

如Sentence Transformers中的训练概述文档所示,训练由3到5个组件组成:

  1. 数据集
  2. 损失函数
  3. 训练参数(可选)
  4. 评估器(可选)
  5. 训练器

在以下部分中,我们将详细阐述我们对每个组件的思考过程。

模型灵感

根据我们的经验,嵌入模型要么1) 专门用于检索,要么2) 用于各种任务(分类、聚类、语义文本相似度等)。我们着手训练了这两种模型。

对于检索模型,可用的多语言检索训练数据量有限,因此我们选择仅使用英语模型。相反,我们决定训练一个多语言通用相似度模型,因为对于此任务来说,多语言数据更容易获取。

对于这些模型,我们希望使用StaticEmbedding模块,它实现了高效的tokenize方法以避免填充,以及高效的forward方法来计算和池化嵌入。这就像使用一个torchEmbeddingBag一样简单,它不过是一个高效的Embedding(即嵌入的查找表)加上平均池化。

我们可以通过几种方式初始化它:`StaticEmbedding.from_model2vec`加载Model2Vec 模型`StaticEmbedding.from_distillation`执行Model2Vec风格的蒸馏,或者使用`Tokenizer`和嵌入维度进行初始化以获得随机权重。

根据我们的发现,当使用大量数据进行完全训练时,最后一个选项效果最好。为了匹配all-mpnet-base-v2bge-large-en-v1.5等常见模型,我们选择将嵌入维度设置为1024,即我们的嵌入向量每个包含1024个值。

英语检索

对于英语检索模型,我们依赖google-bert/bert-base-uncased分词器。因此,模型初始化如下所示:

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)

model = SentenceTransformer(modules=[static_embedding])

`modules` 列表中的第一个条目必须实现 `tokenize`,最后一个必须生成池化嵌入。这里两者都符合,所以我们可以开始训练这个模型了。

多语言相似度

对于多语言相似度模型,我们转而依赖`google-bert/bert-base-multilingual-uncased`分词器,这是我们初始化代码中唯一改变的地方

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)

model = SentenceTransformer(modules=[static_embedding])

训练数据集选择

除了数十个Sentence Transformer模型之外,Hugging Face上的Sentence Transformers组织还托管了70多个数据集(截至撰写本文时)

除此之外,许多数据集已标记为`sentence-transformers`,以表明它们对训练嵌入模型有用

英语检索

对于英语检索数据集,我们主要寻找具有以下特征的任何数据集:

  • 问答对,可选地带有负例(即错误答案),以及
  • 与BEIR基准(即MTEB上的检索选项卡)没有重叠。我们的目标是避免在这些数据集上进行训练,以便我们可以将MTEB用作零样本基准。

我们选择了以下数据集:

多语言相似度

对于多语言相似度数据集,我们的目标是选择包含以下特征的数据集:

  • 跨语言的平行句子,即多种语言中的相同文本,或
  • 正例对,即具有高度相似性的对,可选地带有负例(即低相似性)。

我们选择了以下包含平行句子的数据集:

以及以下包含某种正例对的数据集:

代码

加载这些数据集相当简单,例如:

from datasets import load_dataset, Dataset

gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]

print(gooaq_train_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 3002496
})
"""

print(gooaq_eval_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 10000
})
"""

gooaq数据集尚未进行训练-评估划分,因此我们可以使用`train_test_split`创建一个。否则,我们可以直接加载预计算的划分,例如使用`split="eval"`。

请注意,`train_test_split`确实意味着数据集必须加载到内存中,而否则它只保留在磁盘上。这种增加的内存对于训练来说并不理想,因此建议1) 加载数据,2) 分割数据,3) 使用`save_to_disk`将其保存到磁盘。在训练之前,您可以然后使用`load_from_disk`再次加载它。

损失函数选择

在 Sentence Transformers 中,您的损失模型必须与您的训练数据格式匹配。损失概述旨在概述哪些损失与哪些格式兼容。

特别是,我们目前的数据有以下格式:

  • (锚点, 正例) 对,无标签
  • (锚点, 正例, 负例) 三元组,无标签
  • (锚点, 正例, 负例_1, ..., 负例_n) 元组,无标签

对于这些格式,我们有一些极佳的选择:

  1. `MultipleNegativesRankingLoss` (MNRL):也称为批内负样本损失或 InfoNCE 损失,这种损失已用于训练现代嵌入模型数年。简而言之,该损失优化以下目标:

    给定一个锚点(例如一个问题),在批次中的所有正例和负例(例如所有答案)中,将最高相似度分配给对应的正例(即答案)。

    如果您提供可选的负例,它们将仅用作额外选项(也称为批内负例),模型必须从中选择正确的正例。在合理范围内,这种“选择”越困难,模型就会变得越强大。因此,更大的批次大小会产生更多的批内负例,从而提高性能(达到一定程度)。

  2. `CachedMultipleNegativesRankingLoss` (CMNRL):这是 MNRL 的一个扩展,它实现了 GradCache,这种方法允许任意增加批次大小而不增加内存。

    除非您已经可以使用MNRL在内存中容纳足够大的批处理大小,否则建议使用此损失而非MNRL。在这种情况下,您可以使用MNRL来节省CMNRL带来的20%训练速度成本。

  3. `GISTEmbedLoss` (GIST):这也是 MNRL 的一个扩展,它使用一个 `guide` Sentence Transformer 模型从模型必须“选择”正确正例的选项列表中删除潜在的假负例。

    假负例会损害性能,但难的正负例(接近正确但不完全正确的文本)可以帮助提高性能,因此这种过滤需要谨慎权衡。

由于这些静态嵌入模型极其微小,我们可以在我们的硬件(一块24GB显存的RTX 3090)上轻松容纳我们期望的2048样本批次大小,因此我们不需要使用CMNRL。

此外,由于我们正在训练如此快的模型,来自`GISTEmbedLoss`的`guide`会使训练慢很多。因此,我们选择为我们的模型使用`MultipleNegativesRankingLoss`

如果我们要再次尝试这些实验,我们会选择更大的批量大小,例如使用 CMNRL 的 16384。如果您尝试,请告诉我们结果如何!

代码

用法相当简单:

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Prepare a model to train
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# Initialize the MNRL loss given the model
loss = MultipleNegativesRankingLoss(model)

套娃表示学习

除了常规损失函数之外,Sentence Transformers 还实现了一些损失修饰符。它们在标准损失函数之上工作,但以不同的方式应用它们,以试图向训练好的嵌入模型注入有用的特性。

一个非常有趣的例子是`MatryoshkaLoss`,它将训练好的模型转换为一个**套娃模型**。这允许用户在性能损失最小的情况下截断输出嵌入到更小的维度,这意味着由于维度更小,检索或聚类可以加速。

代码

`MatryoshkaLoss` 应用于正常的损失之上。建议在 `matryoshka_dims` 列表中也包含正常的嵌入维度。

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import MultipleNegativesRankingLoss, MatryoshkaLoss

# Prepare a model to train
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# Initialize the MNRL loss given the model
base_loss = MultipleNegativesRankingLoss(model)
loss = MatryoshkaLoss(model, base_loss, matryoshka_dims=[1024, 768, 512, 256, 128, 64, 32])

训练参数选择

Sentence Transformers 支持大量的训练参数,其中最有价值的参数已在训练概述 > 训练参数文档中列出。

我们使用相同的核心训练参数来训练两个模型:

  • 训练周期数: 1
    • 我们有足够的数据,如果想训练更多,可以添加更多数据,而不是多次训练相同的数据。
  • `per_device_train_batch_size`/`per_device_eval_batch_size`: 2048
    • 2048 维度可以轻松地在我们的 RTX 3090 上运行。多篇论文(Xiao 等Li 等)表明,即使更大的批量大小也能提高性能。对于未来的版本,我们将使用 `CachedMultipleNegativesRankingLoss` 和更大的批量大小,例如 16384。
  • `learning_rate`: 2e-1
    • 注意!这比正常嵌入模型训练的损失(通常约为 2e-5)**大得多**。
  • 预热比率: 0.1
    • 0.1 或 10% 是一个非常标准的预热比率,用于平滑地将高学习率引入模型。
  • `bf16`: True
    • 如果您的 GPU 支持 `bf16`,那么使用它进行训练通常是合理的。否则,如果支持 `fp16`,您可以使用 `fp16=True`。
  • `batch_sampler`: `BatchSamplers.NO_DUPLICATES`
    • 所有具有批内负例的损失(例如 MNRL)都受益于此批采样器,它避免了批内重复。重复通常会导致假负例,从而削弱训练后的模型。
  • `multi_dataset_batch_sampler`: `MultiDatasetBatchSamplers.PROPORTIONAL`
    • 当您使用多个数据集进行训练时,数据集大小通常不相同。发生这种情况时,您可以选择:
      • 循环:从每个数据集抽取相同数量的批次,直到其中一个耗尽。您将获得均匀的数据分布,但并非所有数据都将被使用。
      • 按比例:抽取每个数据集,直到所有数据集都耗尽。您将使用所有数据,但数据分布不均匀。我们选择了这种方式,因为我们不太关心数据不平衡问题。

除了这些核心参数之外,我们还设置了一些用于跟踪和调试的训练参数:`eval_strategy`、`eval_steps`、`save_strategy`、`save_steps`、`save_total_limit`、`logging_steps`、`logging_first_step` 和 `run_name`。

代码

最终,我们为这两个模型使用了这些`SentenceTransformerTrainingArguments`:

run_name = "static-retrieval-mrl-en-v1"
# or 
# run_name = "static-similarity-mrl-multilingual-v1"

args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=f"models/{run_name}",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=2048,
    per_device_eval_batch_size=2048,
    learning_rate=2e-1,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=True,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=1000,
    logging_first_step=True,
    run_name=run_name,  # Used if `wandb`, `tensorboard`, or `neptune`, etc. is installed
)

评估器选择

如果我们向 Sentence Transformer 训练器提供一个评估数据集,那么在评估时我们将得到一个评估损失。这对于跟踪我们是否过拟合很有用,但在实际下游性能方面意义不大。

因此,Sentence Transformers 还支持评估器。与训练损失不同,它们提供定性指标,例如信息检索的 NDCG、MAP、MRR,语义文本相似度的 Spearman 相关系数,或三元组准确率(`similarity(anchor, positive)` > `similarity(anchor, negative)` 的样本数量)。

由于其简单性,我们将为检索模型使用`NanoBEIREvaluator`。该评估器在NanoBEIR数据集集合上运行信息检索基准测试。该数据集是更大的(因此更慢的)BEIR基准的子集,BEIR基准通常用作 MTEB 排行榜中的检索选项卡。

代码

由于所有数据集都已预定义,我们可以无需任何参数加载评估器

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import NanoBEIREvaluator

# Load an example pre-trained model to finetune further
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Initialize the NanoBEIR Evaluator
evaluator = NanoBEIREvaluator()

# Run it on any Sentence Transformer model
evaluator(model)

硬件详情

我们正在消费级硬件上训练这些模型,具体如下:

  • GPU:RTX 3090
  • CPU:i7-13700K
  • 内存:32GB

总体训练脚本

本节包含两个模型的最终训练脚本,其中结合了所有先前描述的组件(数据集、损失函数、训练参数、评估器、训练器)。

英语检索

点击展开
import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.evaluation import NanoBEIREvaluator
from sentence_transformers.models.StaticEmbedding import StaticEmbedding

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading gooaq dataset...")
        gooaq_dataset = load_dataset("sentence-transformers/gooaq", split="train")
        gooaq_dataset_dict = gooaq_dataset.train_test_split(test_size=10_000, seed=12)
        gooaq_train_dataset: Dataset = gooaq_dataset_dict["train"]
        gooaq_eval_dataset: Dataset = gooaq_dataset_dict["test"]
        print("Loaded gooaq dataset.")

        print("Loading msmarco dataset...")
        msmarco_dataset = load_dataset("sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1", "triplet", split="train")
        msmarco_dataset_dict = msmarco_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_train_dataset: Dataset = msmarco_dataset_dict["train"]
        msmarco_eval_dataset: Dataset = msmarco_dataset_dict["test"]
        print("Loaded msmarco dataset.")

        print("Loading squad dataset...")
        squad_dataset = load_dataset("sentence-transformers/squad", split="train")
        squad_dataset_dict = squad_dataset.train_test_split(test_size=10_000, seed=12)
        squad_train_dataset: Dataset = squad_dataset_dict["train"]
        squad_eval_dataset: Dataset = squad_dataset_dict["test"]
        print("Loaded squad dataset.")

        print("Loading s2orc dataset...")
        s2orc_dataset = load_dataset("sentence-transformers/s2orc", "title-abstract-pair", split="train[:100000]")
        s2orc_dataset_dict = s2orc_dataset.train_test_split(test_size=10_000, seed=12)
        s2orc_train_dataset: Dataset = s2orc_dataset_dict["train"]
        s2orc_eval_dataset: Dataset = s2orc_dataset_dict["test"]
        print("Loaded s2orc dataset.")

        print("Loading allnli dataset...")
        allnli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        allnli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded allnli dataset.")

        print("Loading paq dataset...")
        paq_dataset = load_dataset("sentence-transformers/paq", split="train")
        paq_dataset_dict = paq_dataset.train_test_split(test_size=10_000, seed=12)
        paq_train_dataset: Dataset = paq_dataset_dict["train"]
        paq_eval_dataset: Dataset = paq_dataset_dict["test"]
        print("Loaded paq dataset.")

        print("Loading trivia_qa dataset...")
        trivia_qa = load_dataset("sentence-transformers/trivia-qa", split="train")
        trivia_qa_dataset_dict = trivia_qa.train_test_split(test_size=5_000, seed=12)
        trivia_qa_train_dataset: Dataset = trivia_qa_dataset_dict["train"]
        trivia_qa_eval_dataset: Dataset = trivia_qa_dataset_dict["test"]
        print("Loaded trivia_qa dataset.")

        print("Loading msmarco_10m dataset...")
        msmarco_10m_dataset = load_dataset("bclavie/msmarco-10m-triplets", split="train")
        msmarco_10m_dataset_dict = msmarco_10m_dataset.train_test_split(test_size=10_000, seed=12)
        msmarco_10m_train_dataset: Dataset = msmarco_10m_dataset_dict["train"]
        msmarco_10m_eval_dataset: Dataset = msmarco_10m_dataset_dict["test"]
        print("Loaded msmarco_10m dataset.")

        print("Loading swim_ir dataset...")
        swim_ir_dataset = load_dataset("nthakur/swim-ir-monolingual", "en", split="train").select_columns(["query", "text"])
        swim_ir_dataset_dict = swim_ir_dataset.train_test_split(test_size=10_000, seed=12)
        swim_ir_train_dataset: Dataset = swim_ir_dataset_dict["train"]
        swim_ir_eval_dataset: Dataset = swim_ir_dataset_dict["test"]
        print("Loaded swim_ir dataset.")

        # NOTE: 20 negatives
        print("Loading pubmedqa dataset...")
        pubmedqa_dataset = load_dataset("sentence-transformers/pubmedqa", "triplet-20", split="train")
        pubmedqa_dataset_dict = pubmedqa_dataset.train_test_split(test_size=100, seed=12)
        pubmedqa_train_dataset: Dataset = pubmedqa_dataset_dict["train"]
        pubmedqa_eval_dataset: Dataset = pubmedqa_dataset_dict["test"]
        print("Loaded pubmedqa dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading miracl dataset...")
        miracl_dataset = load_dataset("sentence-transformers/miracl", "en-triplet-all", split="train")
        miracl_dataset_dict = miracl_dataset.train_test_split(test_size=10_000, seed=12)
        miracl_train_dataset: Dataset = miracl_dataset_dict["train"]
        miracl_eval_dataset: Dataset = miracl_dataset_dict["test"]
        print("Loaded miracl dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mldr dataset...")
        mldr_dataset = load_dataset("sentence-transformers/mldr", "en-triplet-all", split="train")
        mldr_dataset_dict = mldr_dataset.train_test_split(test_size=10_000, seed=12)
        mldr_train_dataset: Dataset = mldr_dataset_dict["train"]
        mldr_eval_dataset: Dataset = mldr_dataset_dict["test"]
        print("Loaded mldr dataset.")

        # NOTE: A lot of overlap with anchor/positives
        print("Loading mr_tydi dataset...")
        mr_tydi_dataset = load_dataset("sentence-transformers/mr-tydi", "en-triplet-all", split="train")
        mr_tydi_dataset_dict = mr_tydi_dataset.train_test_split(test_size=10_000, seed=12)
        mr_tydi_train_dataset: Dataset = mr_tydi_dataset_dict["train"]
        mr_tydi_eval_dataset: Dataset = mr_tydi_dataset_dict["test"]
        print("Loaded mr_tydi dataset.")

        train_dataset = DatasetDict({
            "gooaq": gooaq_train_dataset,
            "msmarco": msmarco_train_dataset,
            "squad": squad_train_dataset,
            "s2orc": s2orc_train_dataset,
            "allnli": allnli_train_dataset,
            "paq": paq_train_dataset,
            "trivia_qa": trivia_qa_train_dataset,
            "msmarco_10m": msmarco_10m_train_dataset,
            "swim_ir": swim_ir_train_dataset,
            "pubmedqa": pubmedqa_train_dataset,
            "miracl": miracl_train_dataset,
            "mldr": mldr_train_dataset,
            "mr_tydi": mr_tydi_train_dataset,
        })
        eval_dataset = DatasetDict({
            "gooaq": gooaq_eval_dataset,
            "msmarco": msmarco_eval_dataset,
            "squad": squad_eval_dataset,
            "s2orc": s2orc_eval_dataset,
            "allnli": allnli_eval_dataset,
            "paq": paq_eval_dataset,
            "trivia_qa": trivia_qa_eval_dataset,
            "msmarco_10m": msmarco_10m_eval_dataset,
            "swim_ir": swim_ir_eval_dataset,
            "pubmedqa": pubmedqa_eval_dataset,
            "miracl": miracl_eval_dataset,
            "mldr": mldr_eval_dataset,
            "mr_tydi": mr_tydi_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")
        
        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
        quit()
    

def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-uncased"), embedding_dim=1024)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            language="en",
            license="apache-2.0",
            model_name="Static Embeddings with BERT uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = "static-retrieval-mrl-en-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=250,
        save_strategy="steps",
        save_steps=250,
        save_total_limit=2,
        logging_steps=250,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. (Optional) Create an evaluator & evaluate the base model
    evaluator = NanoBEIREvaluator()
    evaluator(model)

    # 7. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # (Optional) Evaluate the trained model on the evaluator after training
    evaluator(model)

    # 8. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)

if __name__ == "__main__":
    main()

该脚本在训练17.8小时后生成了sentence-transformers/static-retrieval-mrl-en-v1。总共消耗了2.6千瓦时能源,排放了1千克二氧化碳。这大致相当于一个人每天呼出的二氧化碳量。

请参阅我们的Weights and Biases报告,了解训练期间收集的训练和评估指标。

多语言相似度

点击展开
import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.models.StaticEmbedding import StaticEmbedding

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading wikititles dataset...")
        wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
        wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
        wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
        wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
        print("Loaded wikititles dataset.")

        print("Loading tatoeba dataset...")
        tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
        tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
        tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
        tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
        print("Loaded tatoeba dataset.")

        print("Loading talks dataset...")
        talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
        talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
        talks_train_dataset: Dataset = talks_dataset_dict["train"]
        talks_eval_dataset: Dataset = talks_dataset_dict["test"]
        print("Loaded talks dataset.")

        print("Loading europarl dataset...")
        europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
        europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
        europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
        europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
        print("Loaded europarl dataset.")

        print("Loading global voices dataset...")
        global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
        global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
        global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
        global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
        print("Loaded global voices dataset.")

        print("Loading jw300 dataset...")
        jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
        jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
        jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
        jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
        print("Loaded jw300 dataset.")

        print("Loading muse dataset...")
        muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
        muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
        muse_train_dataset: Dataset = muse_dataset_dict["train"]
        muse_eval_dataset: Dataset = muse_dataset_dict["test"]
        print("Loaded muse dataset.")

        print("Loading wikimatrix dataset...")
        wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
        wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
        wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
        wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
        print("Loaded wikimatrix dataset.")

        print("Loading opensubtitles dataset...")
        opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
        opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
        opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
        opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
        print("Loaded opensubtitles dataset.")

        print("Loading stackexchange dataset...")
        stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
        stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
        stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
        stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
        print("Loaded stackexchange dataset.")

        print("Loading quora dataset...")
        quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
        quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
        quora_train_dataset: Dataset = quora_dataset_dict["train"]
        quora_eval_dataset: Dataset = quora_dataset_dict["test"]
        print("Loaded quora dataset.")

        print("Loading wikianswers duplicates dataset...")
        wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
        wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
        wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
        wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
        print("Loaded wikianswers duplicates dataset.")

        print("Loading all nli dataset...")
        all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded all nli dataset.")

        print("Loading simple wiki dataset...")
        simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
        simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
        simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
        simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
        print("Loaded simple wiki dataset.")

        print("Loading altlex dataset...")
        altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
        altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
        altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
        altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
        print("Loaded altlex dataset.")

        print("Loading flickr30k captions dataset...")
        flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
        flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
        flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
        flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
        print("Loaded flickr30k captions dataset.")

        print("Loading coco captions dataset...")
        coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
        coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
        coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
        coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
        print("Loaded coco captions dataset.")

        print("Loading nli for simcse dataset...")
        nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
        nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
        nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
        nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
        print("Loaded nli for simcse dataset.")

        print("Loading negation dataset...")
        negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
        negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
        negation_train_dataset: Dataset = negation_dataset_dict["train"]
        negation_eval_dataset: Dataset = negation_dataset_dict["test"]
        print("Loaded negation dataset.")

        train_dataset = DatasetDict({
            "wikititles": wikititles_train_dataset,
            "tatoeba": tatoeba_train_dataset,
            "talks": talks_train_dataset,
            "europarl": europarl_train_dataset,
            "global_voices": global_voices_train_dataset,
            "jw300": jw300_train_dataset,
            "muse": muse_train_dataset,
            "wikimatrix": wikimatrix_train_dataset,
            "opensubtitles": opensubtitles_train_dataset,
            "stackexchange": stackexchange_train_dataset,
            "quora": quora_train_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
            "all_nli": all_nli_train_dataset,
            "simple_wiki": simple_wiki_train_dataset,
            "altlex": altlex_train_dataset,
            "flickr30k_captions": flickr30k_captions_train_dataset,
            "coco_captions": coco_captions_train_dataset,
            "nli_for_simcse": nli_for_simcse_train_dataset,
            "negation": negation_train_dataset,
        })
        eval_dataset = DatasetDict({
            "wikititles": wikititles_eval_dataset,
            "tatoeba": tatoeba_eval_dataset,
            "talks": talks_eval_dataset,
            "europarl": europarl_eval_dataset,
            "global_voices": global_voices_eval_dataset,
            "jw300": jw300_eval_dataset,
            "muse": muse_eval_dataset,
            "wikimatrix": wikimatrix_eval_dataset,
            "opensubtitles": opensubtitles_eval_dataset,
            "stackexchange": stackexchange_eval_dataset,
            "quora": quora_eval_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
            "all_nli": all_nli_eval_dataset,
            "simple_wiki": simple_wiki_eval_dataset,
            "altlex": altlex_eval_dataset,
            "flickr30k_captions": flickr30k_captions_eval_dataset,
            "coco_captions": coco_captions_eval_dataset,
            "nli_for_simcse": nli_for_simcse_eval_dataset,
            "negation": negation_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")
        
        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
        quit()

def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"), embedding_dim=1024)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            license="apache-2.0",
            model_name="Static Embeddings with BERT Multilingual uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = "static-similarity-mrl-multilingual-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    # 7. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 8. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)

if __name__ == "__main__":
    main()

这个模型只比流行但慢得多的multilingual-e5-small模型损失了大约8%的性能,正如即将到来的性能 > 多语言相似度部分所示。

请参阅我们的Weights and Biases 报告,了解训练期间收集的训练和评估损失。

用法

这些模型的使用非常简单,与常规的Sentence Transformers流程相同

英语检索

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7649, 0.3279]])

即将到来的性能 > 英语检索部分将显示,这些结果非常可靠,与常用的基于 Transformer 的编码器模型(如all-mpnet-base-v2)相差不到 15%。

多语言相似度

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer("sentence-transformers/static-similarity-mrl-multilingual-v1", device="cpu")
# Run inference
sentences = [
    'It is known for its dry red chili powder.',
    'It is popular for dried red chili powder.',
    'These monsters will move in large groups.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 1024]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings, embeddings)
print(similarities)
# tensor([[ 1.0000,  0.8388, -0.0012],
#         [ 0.8388,  1.0000,  0.0445],
#         [-0.0012,  0.0445,  1.0000]])

与流行的但速度慢得多的multilingual-e5-small相比,该模型仅损失约8%的性能,如即将到来的性能 > 多语言相似度部分所示。

套娃降维截断

要降低计算出的嵌入的维度,您只需传递 `truncate_dim` 参数即可。这适用于所有 Sentence Transformer 模型。

from sentence_transformers import SentenceTransformer

# Download from the 🤗 Hub
model = SentenceTransformer(
    "sentence-transformers/static-retrieval-mrl-en-v1",
    device="cpu",
    truncate_dim=256,
)
# Run inference
sentences = [
    'Gadofosveset-enhanced MR angiography of carotid arteries: does steady-state imaging improve accuracy of first-pass imaging?',
    'To evaluate the diagnostic accuracy of gadofosveset-enhanced magnetic resonance (MR) angiography in the assessment of carotid artery stenosis, with digital subtraction angiography (DSA) as the reference standard, and to determine the value of reading first-pass, steady-state, and "combined" (first-pass plus steady-state) MR angiograms.',
    'In a longitudinal study we investigated in vivo alterations of CVO during neuroinflammation, applying Gadofluorine M- (Gf) enhanced magnetic resonance imaging (MRI) in experimental autoimmune encephalomyelitis, an animal model of multiple sclerosis. SJL/J mice were monitored by Gadopentate dimeglumine- (Gd-DTPA) and Gf-enhanced MRI after adoptive transfer of proteolipid-protein-specific T cells. Mean Gf intensity ratios were calculated individually for different CVO and correlated to the clinical disease course. Subsequently, the tissue distribution of fluorescence-labeled Gf as well as the extent of cellular inflammation was assessed in corresponding histological slices.',
]
embeddings = model.encode(sentences)
print(embeddings.shape)
# [3, 256]

# Get the similarity scores for the embeddings
similarities = model.similarity(embeddings[0], embeddings[1:])
print(similarities)
# tensor([[0.7844, 0.3561]])

第三方库

该模型还可与各种第三方库开箱即用,例如LangChainLlamaIndexHaystacktxtai

LangChain

# pip install langchain langchain_huggingface
from langchain_huggingface import HuggingFaceEmbeddings

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
model_kwargs = {'device': 'cpu'} # you can use 'truncate_dim' here
model = HuggingFaceEmbeddings(
    model_name=model_name,
    model_kwargs=model_kwargs,
)

LlamaIndex

# pip install llama-index llama-index-embeddings-huggingface
from llama_index.core import Settings
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Set up the HuggingFaceEmbedding class with the required model to use with llamaindex core.
model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
device = "cpu"
embed_model = HuggingFaceEmbedding(
    model_name=model_name,
    device=device,
    # truncate_dim=256, # you can use 'truncate_dim' here
)
Settings.embed_model = embed_model

Haystack

# pip install haystack sentence-transformers
from haystack.components.embedders import (
    SentenceTransformersDocumentEmbedder,
    SentenceTransformersTextEmbedder,
)

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
device = "cpu"
document_embedder = SentenceTransformersDocumentEmbedder(
    model=model_name,
    device=device,
    # truncate_dim=256, # you can use 'truncate_dim' here
)
text_embedder = SentenceTransformersTextEmbedder(
    model=model_name,
    device=device,
    # truncate_dim=256, # you can use 'truncate_dim' here
)

txtai

# pip install txtai sentence-transformers
from txtai import Embeddings

model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
embeddings = Embeddings(path=model_name)

性能

英语检索

训练完成后,我们评估了最终模型sentence-transformers/static-retrieval-mrl-en-v1在NanoBEIR(正常维度和Matryoshka维度)以及BEIR上的性能。

NanoBEIR

我们评估了sentence-transformers/static-retrieval-mrl-en-v1在NanoBEIR上的性能,并将其与我们在硬件上计算的推理速度进行了对比。对于推理速度测试,我们计算了每秒在CPU或GPU上GooAQ数据集计算的查询嵌入数量。

我们评估了三种类型的模型:

  1. 基于注意力的密集嵌入模型,例如传统的 Sentence Transformer 模型,如`all-mpnet-base-v2``bge-base-en-v1.5``gte-large-en-v1.5`

  2. 基于静态嵌入的模型,例如 `static-retrieval-mrl-en-v1`, `potion-base-8M`, `M2V_base_output`, 和 `glove.6B.300d`

  3. 稀疏词袋模型,BM25,通常是一个强大的基线。

    点击展开BM25实现细节

    我们依赖于高效的bm25s实现,在标记化和使用英文`PyStemmer`进行词干提取后,对标记使用`model.get_scores()`。

**注意:**许多基于注意力的密集嵌入模型在(Nano)BEIR 评估数据集的训练集上进行了微调。这使得模型在此基准测试中具有不公平的优势,并可能导致实际检索任务的下游性能降低。

static-retrieval-mrl-en-v1 有意未在这些数据集上进行训练。

点击查看下一页两张图表的所有数值
模型 NanoBEIR NDCG@10 CPU (每秒句子数) GPU (每秒句子数)
zeta-alpha-ai/Zeta-Alpha-E5-Mistral 0.6860 0.00* 0.00*
Alibaba-NLP/gte-large-en-v1.5 0.6808 56.01 965.95
Salesforce/SFR-Embedding-Mistral 0.6800 0.00* 0.00*
mixedbread-ai/mxbai-embed-large-v1 0.6567 79.83 1376.80
BAAI/bge-large-en-v1.5 0.6592 80.94 1315.03
intfloat/e5-mistral-7b-instruct 0.6530 0.00* 0.00*
Alibaba-NLP/gte-base-en-v1.5 0.6411 197.85 3142.94
BAAI/bge-base-en-v1.5 0.6376 264.83 4363.04
BAAI/bge-small-en-v1.5 0.6267 888.46 10159.97
nomic-ai/nomic-embed-text-v1.5 0.6179 86.86 2843.03
jinaai/jina-embeddings-v3 0.6174 0.55 3377.56
BAAI/bge-m3 0.6054 80.63 1434.82
sentence-transformers/all-mpnet-base-v2 0.5757 270.40 4043.13
TaylorAI/gte-tiny 0.5692 1752.26 17215.15
sentence-transformers/all-MiniLM-L6-v2 0.5623 1739.31 16942.46
mixedbread-ai/mxbai-embed-xsmall-v1 0.5557 1749.42 16773.76
sentence-transformers/all-MiniLM-L12-v2 0.5533 909.72 9915.69
sentence-transformers/static-retrieval-mrl-en-v1 0.5032 107419.51 97171.47
bm25 0.4518 49706.77 49706.77
minishlab/potion-base-8M 0.4421 124029.91 122384.10
minishlab/potion-base-4M 0.4225 123082.88 123612.54
minishlab/M2V_base_glove 0.4077 142173.77 146154.73
minishlab/M2V_base_glove_subword 0.3914 127426.83 131412.56
minishlab/M2V_base_output 0.3851 84191.93 85738.36
minishlab/potion-base-2M 0.3666 128994.27 122358.16
sentence-transformers/glove.6B.300d 0.3293 76519.74 62782.23
sentence-transformers/glove.840B.300d 0.2899 86348.98 75350.36
  • *:对于 7B LLM,我们没有进行推理实验,因为它们的推理速度在图中将无法区分。
  • 我们进行了实验以确定每个模型的最佳批处理大小。
GPU

NanoBEIR performance vs inference speed

CPU

NanoBEIR performance vs inference speed

我们可以从这些数据中得出一些显著结论:

  1. `static-retrieval-mrl-en-v1` 的性能优于所有其他静态嵌入模型,如 GloVe 或 Model2Vec。
  2. `static-retrieval-mrl-en-v1` 是唯一优于 BM25 的静态嵌入模型。
  3. `static-retrieval-mrl-en-v1` 的性能:
    • 与常用模型`all-mpnet-base-v2`相比,性能达到**87.4%**,
    • 在GPU上快**24倍**,
    • 在CPU上快**397倍**。
  4. `static-retrieval-mrl-en-v1` 在 CPU 上比在 GPU 上更快:此模型可以在任何地方以极快的速度运行,包括消费级 PC、小型服务器、手机或浏览器中。

Matryoshka 评估

此外,我们通过将输出嵌入截断到较低维度,进行了 Matryoshka 式降维,并对 NanoBEIR 性能结果进行了实验。

NanoBEIR performance vs Matryoshka dimensionality reduction

这些发现表明,例如将维度减少 2 倍,性能仅下降 1.47%(0.5031 NDCG@10 对 0.4957 NDCG@10),而实际上检索速度却提高了 2 倍。

多语言相似度

我们还评估了最终的 sentence-transformers/static-similarity-mrl-multilingual-v1 模型在 5 种语言上的表现,这些语言在 MTEB 上有大量基准测试。

我们希望重申,此模型不适用于检索用例。相反,我们评估的是语义文本相似度 (STS)、分类和对分类。我们与出色的轻量级 multilingual-e5-small 模型进行了比较。

STS, Classification, Pair Classification on MTEB

在所有测试语言中,static-similarity-mrl-multilingual-v1 相对于 multilingual-e5-small 在 STS 上平均达到 92.3%,在对分类上达到 95.52%,在分类上达到 86.52%

Texts per second processed

为了弥补这种性能下降,static-similarity-mrl-multilingual-v1 在 CPU 设备上比 multilingual-e5-small 快约 125 倍,在 GPU 设备上快约 10 倍。由于注意力模型的超线性性质,与静态嵌入模型的线性性质相比,编码令牌数量的增加将使加速效果更大。

Matryoshka 评估

最后,我们通过将输出嵌入截断到较低维度,进行了 Matryoshka 式降维,并对英语 STS 在 MTEB 性能上的影响进行了实验。

English STS MTEB performance vs Matryoshka dimensionality reduction

如您所见,您可以轻松地将维度减少 2 倍或 4 倍,而性能损失很小(0.15% 或 0.56%)。如果您的下游任务的速度或存储成本是瓶颈,这应该可以帮助您缓解一些担忧。

结论

这篇博客文章描述了我们从构思到完成模型的所有步骤,以及关于两个结果模型(static-retrieval-mrl-en-v1static-similarity-mrl-multilingual-v1)的使用和评估的详细信息。

评估结果表明:

  • 基于静态嵌入的模型可以超过常见基于注意力稠密模型性能的 85%
  • 基于静态嵌入的模型在 GPU 上比常见的有效替代方案(如 all-mpnet-base-v2multilingual-e5-small)快 10 倍到 25 倍,在 CPU 上快 100 倍到 400 倍。文本越长,这种加速效果就越大。
  • 使用 Matryoshka 损失进行训练可以显著保持下游性能

如果您需要一个高效的仅支持 CPU 的稠密嵌入模型来执行检索或相似性任务,那么 static-retrieval-mrl-en-v1static-similarity-mrl-multilingual-v1 将是以最小成本提供极其高效的解决方案,并且其性能出人意料地接近基于注意力的稠密模型。

后续步骤

试一试!如果您已经在某个地方使用了 Sentence Transformer 模型,请随意将其替换为 static-retrieval-mrl-en-v1static-similarity-mrl-multilingual-v1。或者,更好的是:根据您感兴趣的任务和语言的代表性数据训练您自己的模型。

此外,关于训练好的模型,还有一些问题有待解决。

  1. 由于基于静态嵌入的模型不受位置嵌入或超线性时间复杂度的瓶颈,因此它们可以具有任意高的最大序列长度。然而,在某个时刻,大数定律可能会“规范化”所有真正长文档的嵌入,使其不再有用。

    需要进行更多的实验来确定一个好的截断点。目前,我们将最大序列长度、分块等留给用户。

此外,还有一些可能的扩展,可能会提高此模型的性能,我们很高兴将其留给其他模型作者。我们也欢迎合作。

  1. 困难负样本挖掘:搜索相似但不相关的文本以提高训练数据难度。
  2. 模型聚合:结合以相同方式训练的多个模型(使用不同种子或数据分布)的权重。
  3. 课程学习:从难度逐渐增加的示例中进行训练。
  4. 引导式批内假负样本过滤:通过高效的预训练嵌入模型排除假负样本。
  5. 随机权重初始化的种子优化:用各种种子训练最初的步骤,以找到一个有用的权重初始化。
  6. 分词器再训练:使用现代文本和学习成果对分词器进行再训练。
  7. 梯度缓存:通过 CachedMultipleNegativesRankingLoss 应用 GradCache 可以实现更大的批量,这通常会带来更好的性能。
  8. 模型蒸馏:除了仅使用有监督的训练数据进行训练外,我们还可以通过更大的嵌入模型输入无监督数据,并将这些嵌入蒸馏到基于静态嵌入的学生模型中。

致谢

我要感谢 Stéphan TulkensThomas van DongenThe Minish Lab 团队,他们通过 Model2Vec 工作让我关注到了静态嵌入模型。此外,我还要感谢 Vaibhav SrivastavPedro Cuenca 在这篇博客文章中的帮助,以及 Antoine Chaffin 提出的发布检查点。

最后,非常感谢所有致力于嵌入模型、数据集和开源 Python 包的研究人员。你们为行业增添了力量,我将站在你们的肩膀上。希望有一天,你们也能站在我的肩膀上。

社区

英伟达:你需要购买 GPU 和机器。
开发者:不,我们只需调整算法即可。
埃隆·马斯克:看看我买的 GPU,你们这些穷鬼。

非常感谢您的帖子,工作很棒,

我已经训练了一些英语和西班牙语模型

  • NickyNicky/StaticEmbedding-MatryoshkaLoss-gemma-2-2b-en-es
  • NickyNicky/StaticEmbedding-MatryoshkaLoss-gemma-2-2b-gooaq-en

我想知道如何增加或减少
“max_length 示例 371”

当我检查“print(model.max_seq_length) # -> Inf”时。

这可能吗,怎么做?我找不到相关文档

非常感谢

·
文章作者

你好!

这些模型做得真好!我是否理解正确,其中一个模型在所有数据集上达到了 NanoBEIR 的 0.5623 NDCG@10?这比 static-retrieval-mrl-en-v1 的 0.5032 NDCG@10 提升了很大。

我想知道如何增加或减少
“max_length 示例 371”

你指的是模型卡 这里 的“max”吗?
image.png

那只是关于训练数据的一些近似统计;取自前 1000 个样本。尽管不建议使用(远)大于训练数据的序列长度的文本,但实际的最大序列长度确实是无限的。它在这里定义:https://github.com/UKPLab/sentence-transformers/blob/cccab8303aaf6e18f069b0da578b3d162bf8442a/sentence_transformers/models/StaticEmbedding.py#L106-L108

简而言之:模型永远不会截断序列,因为该方法

  1. 具有线性复杂度(数据量增加 2 倍 -> 速度慢 2 倍),这与 Transformer 模型(数据量增加 2 倍 -> 速度慢(远)超过 2 倍)不同。
  2. 不受可能对最大序列长度施加限制的位置嵌入的影响。

所以,静态模型没有最大序列长度。它们只是要求用户注意不要输入过大的文档,因为所有文档如果足够长,最终都会嵌入得非常相似。

  • Tom Aarsen

这太酷了!我很惊讶你做得比 model2vec 更好——区别真的只是使用了(更好的)对比损失预训练公式吗?

·
文章作者

是的!架构是相同的。事实上,这篇博客文章中描述的模型所使用的 StaticEmbedding 模块与在 Sentence Transformers 中加载 Model2Vec 模型时使用的模块实际上是相同的。

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

# Pre-distilled embeddings:
static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")

model = SentenceTransformer(modules=[static_embedding])

embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
similarity = model.similarity(embeddings[0], embeddings[1])
# tensor([[0.9177]]) (If you use the distilled bge-base)

惊人的工作和出色的写作!

我尝试在部分数据(AllNLI, GooAQ, MSMacro, PAQ, S2ORC)上进行此训练,批处理大小为 16384。耗时 5 小时。

w&b
https://api.wandb.ai/links/arunarumugam411-sui/dkcwm6gs

很棒的工作。看起来很酷!

看起来很棒,但是
1
你能阐明其背后的想法吗?你是为每个词元计算嵌入然后取平均值吗?
2
你能分享 NanoBEIR 的链接吗?

·

你有像论文一样的详细描述吗,拜托

你能解释一下吗
来自
https://huggingface.co/blog/Pringled/model2vec
齐普夫
由于我们对空间中的词元进行了简单平均,因此正确加权向量非常重要。通常,一个句子转换器会在给定上下文的情况下为我们正确加权所有词元,但我们不再拥有这种奢侈。直观地,我们希望使用类似逆文档频率 (IDF) 的方法来降低非常频繁或无趣的词的权重。但是我们无法访问一个语料库来计算文档频率。

为了克服这个问题,我们选择使用语言科学中一个众所周知的原理,即,给定一个按频率排序的列表,列表中项目的频率遵循幂律分布。这被称为齐普夫定律。因此,如果我们假设词汇表是按频率排序的,我们就可以准确地降低非常频繁项目的权重,而无需访问实际频率。由于分词器词汇表是按频率排序的,我们已经可以访问一个排序列表,因此无需任何额外工作即可应用此优化。

所以对于假设的齐普夫输入
[ [ 0.2,0.5,0.7] , [1.2, 0.9,0.2], [0.4, 0.3, 0.2] ,[1.3, 2.4, 3.2]]

1
根据每个向量范数对输入进行排序
所以你得到
[ [0.4, 0.3, 0.2] , [ 0.2,0.5,0.7] , [1.2, 0.9,0.2],[1.3, 2.4, 3.2] ]
2
你将每个向量除以它的范数

[ [0.4, 0.3, 0.2]/n1 , [ 0.2,0.5,0.7]/n2 , [1.2, 0.9,0.2] /n3 ,[1.3, 2.4, 3.2]/n4 ]

3
那么最终的嵌入是这些降权向量的平均值吗?
( [0.4, 0.3, 0.2]/n1 + [ 0.2,0.5,0.7]/n2 + [1.2, 0.9,0.2] /n3 + [1.3, 2.4, 3.2]/n4) / 4

这是正确的算法吗?

太棒了!这项技术真的令人大开眼界^^

不过,对帖子标题提个小建议,一开始,我以为这篇帖子是关于更快地训练句子嵌入模型,而不是关于训练推理时间更快的句子嵌入模型。只是让你们知道。

这是一种很棒的方法!

我通过整合大量日语数据集,训练了一个静态嵌入日语模型(static-embedding-japanese),当我们在日语多语言文本嵌入基准(JMTEB)上进行比较时,我能够获得仅略低于 mE5-small 的分数。

JMTEB 结果

模型 平均(微观) 检索 STS 分类 重排 聚类 对分类
文本嵌入-3-小型 69.18 66.39 79.46 73.06 92.92 51.06 62.27
多语言-e5-小型 67.71 67.27 80.07 67.62 93.03 46.91 62.19
静态嵌入-日语 67.17 67.92 80.16 67.96 91.87 40.39 62.37

感谢您发表如此优秀的文章。

·
文章作者

这表现太棒了,工作出色!我也很感谢你非常详细的模型卡——我现在就用翻译器读一下!

随着一些研究方向转向无分词器模型,不知道字符级相似度训练模型能达到多远。

·

还在想,考虑到我们现在拥有的额外计算和时间,与其他静态嵌入模型进行集成是否会带来额外的质量改进。但是,最好的集成方式可能是什么,也许特定领域数据集会有帮助,不同损失训练的模型呢?

有没有计划让模型可用于文本嵌入推理?

https://github.com/huggingface/text-embeddings-inference

注册登录 以发表评论