使用Sentence Transformers v4训练和微调Reranker模型

发布于2025年3月26日
在 GitHub 上更新

Sentence Transformers是一个Python库,用于使用和训练嵌入和重排模型,应用于广泛的场景,例如检索增强生成、语义搜索、语义文本相似度、释义挖掘等。其v4.0更新引入了一种新的重排器(也称为交叉编码器模型)训练方法,类似于v3.0更新为嵌入模型引入的方法。在这篇博客文章中,我将向您展示如何使用它来微调一个重排器模型,该模型在您的数据上超越所有现有选项。此方法还可以从头开始训练极其强大的新重排器模型。

微调重排器模型涉及几个组件:数据集、损失函数、训练参数、评估器和训练器类本身。我将逐一探讨这些组件,并提供如何将它们用于微调强大重排器模型的实用示例。

最后,在评估部分,我将向您展示,我与这篇博客文章一起训练的我的小型微调tomaarsen/reranker-ModernBERT-base-gooaq-bce重排器模型,在我的评估数据集上轻松超越了13个最常用的公共重排器模型。它甚至击败了体积大4倍的模型。

使用更大的基础模型重复此方法,结果是tomaarsen/reranker-ModernBERT-large-gooaq-bce,一个在我的数据上超越所有现有通用重排器模型的重排器模型。

Model size vs NDCG for Rerankers on GooAQ

如果您对微调嵌入模型感兴趣,那么也可以阅读我之前的使用Sentence Transformers v3训练和微调嵌入模型博客文章。

目录

什么是Reranker模型?

重排器模型,通常使用交叉编码器架构实现,旨在评估文本对(例如,查询和文档,或两个句子)之间的相关性。与Sentence Transformers(又称双编码器、嵌入模型)不同,后者独立地将每个文本嵌入到向量中并通过距离度量计算相似度,交叉编码器通过共享神经网络同时处理配对文本,从而产生一个输出分数。通过让两个文本相互关注,交叉编码器模型可以胜过嵌入模型。

然而,这种优势也带来了权衡:交叉编码器模型速度较慢,因为它们处理所有可能的文本对(例如,10个查询和500个候选文档需要5,000次计算,而嵌入模型只需要510次)。这使得它们在大规模初始检索中效率较低,但非常适合重排:优化由更快的Sentence Transformer模型首先识别出的前k个结果。最强的搜索系统通常采用这种两阶段的“检索和重排”方法。

Embedding vs Reranker Models

在这篇博客文章中,我将互换使用“重排器模型”和“交叉编码器模型”。

为什么要微调?

重排器模型通常面临一个具有挑战性的问题:

在这些高度相关的文档中,哪一个最能回答问题?

通用重排器模型经过训练,可以在各种领域和主题中充分回答这个问题,这阻碍了它们在您的特定领域中发挥最大潜力。通过微调,模型可以学习专注于对您而言重要的领域和/或语言。

在这篇博客文章的评估部分,我将展示在您的领域中训练的模型可以超越任何通用重排器模型,即使这些基线模型大得多。不要低估在您的领域中进行微调的力量!

训练组件

训练重排器模型涉及以下组件:

  1. 数据集:用于训练和/或评估的数据。
  2. 损失函数:衡量模型性能并指导优化过程的函数。
  3. 训练参数(可选):影响训练性能、跟踪和调试的参数。
  4. 评估器(可选):用于在训练前、训练中或训练后评估模型的类。
  5. 训练器:将所有训练组件整合在一起。

让我们仔细看看每个组件。

数据集

CrossEncoderTrainer使用datasets.Datasetdatasets.DatasetDict实例进行训练和评估。您可以从Hugging Face Datasets Hub加载数据,或者使用您喜欢的任何格式(例如CSV、JSON、Parquet、Arrow或SQL)的本地数据。

注意: 许多可以直接与Sentence Transformers一起使用的公共数据集已在Hugging Face Hub上被标记为sentence-transformers,因此您可以在https://huggingface.co/datasets?other=sentence-transformers上轻松找到它们。考虑浏览这些数据集,以找到可能对您的任务、领域或语言有用的现成数据集。

Hugging Face Hub上的数据

您可以使用load_dataset函数从Hugging Face Hub中的数据集加载数据。

from datasets import load_dataset

train_dataset = load_dataset("sentence-transformers/natural-questions", split="train")

print(train_dataset)
"""
Dataset({
    features: ['query', 'answer'],
    num_rows: 100231
})
"""

一些数据集,如nthakur/swim-ir-monolingual,有多个不同数据格式的子集。您需要指定子集名称和数据集名称,例如dataset = load_dataset("nthakur/swim-ir-monolingual", "de", split="train")

本地数据(CSV、JSON、Parquet、Arrow、SQL)

您也可以使用load_dataset加载某些文件格式的本地数据

from datasets import load_dataset

dataset = load_dataset("csv", data_files="my_file.csv")
# or
dataset = load_dataset("json", data_files="my_file.json")

需要预处理的本地数据

如果您的本地数据需要预处理,您可以使用datasets.Dataset.from_dict。这允许您使用列表字典初始化数据集。

from datasets import Dataset

queries = []
documents = []
# Open a file, perform preprocessing, filtering, cleaning, etc.
# and append to the lists

dataset = Dataset.from_dict({
    "query": queries,
    "document": documents,
})

字典中的每个键都将成为结果数据集中的一列。

数据集格式

重要的是,您的数据集格式必须与您的损失函数匹配(或者您选择与您的数据集格式和模型匹配的损失函数)。验证数据集格式和模型是否与损失函数配合使用涉及三个步骤:

  1. 根据损失概述表,所有未命名为“label”、“labels”、“score”或“scores”的列都被视为*输入*。剩余列的数量必须与您选择的损失函数的有效输入数量匹配。
  2. 如果您的损失函数根据损失概述表需要一个*标签*,那么您的数据集必须有一个名为“label”、“labels”、“score”或“scores”的**列**。此列会自动作为标签。
  3. 模型输出标签的数量与损失概述表根据损失函数的要求匹配。

例如,给定一个包含列["text1", "text2", "label"]的数据集,其中“label”列的浮点相似度分数范围为0到1,以及一个输出1个标签的模型,我们可以将其与BinaryCrossEntropyLoss一起使用,因为:

  1. 数据集具有“label”列,这是此损失函数所必需的。
  2. 数据集有2个非标签列,正好是此损失函数所需的数量。
  3. 模型有1个输出标签,正好是此损失函数所需的数量。

如果您的列顺序不正确,请务必使用Dataset.select_columns重新排序您的数据集列。例如,如果您的数据集有["good_answer", "bad_answer", "question"]作为列,那么此数据集技术上可以与需要(锚点,正例,负例)三元组的损失一起使用,但good_answer列将作为锚点,bad_answer作为正例,question作为负例。

此外,如果您的数据集有多余的列(例如sample_id、metadata、source、type),您应该使用Dataset.remove_columns将其删除,否则它们将被用作输入。您还可以使用Dataset.select_columns仅保留所需列。

难例挖掘

训练重排器模型的成功通常取决于*负例*的质量,即查询-负例得分应低的段落。负例可以分为两种类型:

  • 软负例:完全不相关的段落。也称为简单负例。
  • 难负例:看起来可能与查询相关但实际上不相关的段落。

一个简洁的例子是:

  • 查询:Apple在哪里成立的?
  • 软负例:卡什河大桥是一座帕克桁架桥,横跨阿肯色州核桃岭和帕拉古尔德之间的卡什河。
  • 难负例:富士苹果是一种在20世纪30年代后期开发,并于1962年上市的苹果栽培品种。

最强的交叉编码器模型通常经过训练来识别难负例,因此能够“挖掘”难负例进行训练是很有价值的。Sentence Transformers支持一个强大的mine_hard_negatives函数,可以在给定查询-答案对数据集的情况下提供帮助:

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives

# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
train_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
print(train_dataset)

# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
    train_dataset,
    embedding_model,
    num_negatives=5,  # How many negatives per question-answer pair
    range_min=10,  # Skip the x most similar samples
    range_max=100,  # Consider only the x most similar samples
    max_score=0.8,  # Only consider samples with a similarity score of at most x
    margin=0.1,  # Similarity between query and negative samples should be x lower than query-positive similarity
    sampling_strategy="top",  # Randomly sample negatives from the range
    batch_size=4096,  # Use a batch size of 4096 for the embedding model
    output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
    use_faiss=True,  # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_train_dataset)
print(hard_train_dataset[1])
点击查看此脚本的输出。
Dataset({
    features: ['question', 'answer'],
    num_rows: 100000
})

Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 22/22 [00:01<00:00, 13.74it/s]
Batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:00<00:00, 36.49it/s]
Querying FAISS index: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:19<00:00,  2.80s/it]
Metric       Positive       Negative     Difference
Count         100,000        436,925
Mean           0.5882         0.4040         0.2157
Median         0.5989         0.4024         0.1836
Std            0.1425         0.0905         0.1013
Min           -0.0514         0.1405         0.1014
25%            0.4993         0.3377         0.1352
50%            0.5989         0.4024         0.1836
75%            0.6888         0.4681         0.2699
Max            0.9748         0.7486         0.7545
Skipped 2420871 potential negatives (23.97%) due to the margin of 0.1.
Skipped 43 potential negatives (0.00%) due to the maximum score of 0.8.
Could not find enough negatives for 63075 samples (12.62%). Consider adjusting the range_max, range_min, margin and max_score parameters if you'd like to find more valid negatives.
Dataset({
    features: ['question', 'answer', 'label'],
    num_rows: 536925
})

{
    'question': 'how to transfer bookmarks from one laptop to another?',
    'answer': 'Using an External Drive Just about any external drive, including a USB thumb drive, or an SD card can be used to transfer your files from one laptop to another. Connect the drive to your old laptop; drag your files to the drive, then disconnect it and transfer the drive contents onto your new laptop.',
    'label': 0
}

损失函数

损失函数有助于评估模型在一组数据上的性能并指导训练过程。适用于您任务的正确损失函数取决于您拥有的数据以及您想要实现的目标。您可以在损失概述中找到可用损失函数的完整列表。

大多数损失函数都易于设置——您只需提供您正在训练的CrossEncoder模型:

from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.losses import CachedMultipleNegativesRankingLoss

# Load a model to train/finetune
model = CrossEncoder("xlm-roberta-base", num_labels=1) # num_labels=1 is for rerankers

# Initialize the CachedMultipleNegativesRankingLoss, which requires pairs of
# related texts or triplets
loss = CachedMultipleNegativesRankingLoss(model)

# Load an example training dataset that works with our loss function:
train_dataset = load_dataset("sentence-transformers/gooaq", split="train")

...

训练参数

您可以使用CrossEncoderTrainingArguments类自定义训练过程。此类别允许您调整可能影响训练速度并帮助您了解训练期间发生的事情的参数。

有关最有用的训练参数的更多信息,请查看交叉编码器 > 训练概述 > 训练参数。值得一读,以充分利用您的训练。

以下是如何设置CrossEncoderTrainingArguments的示例:

from sentence_transformers.cross_encoder import CrossEncoderTrainingArguments

args = CrossEncoderTrainingArguments(
    # Required parameter:
    output_dir="models/reranker-MiniLM-msmarco-v1",
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # losses that use "in-batch negatives" benefit from no duplicates
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="reranker-MiniLM-msmarco-v1",  # Will be used in W&B if `wandb` is installed
)

评估器

为了在训练期间跟踪模型的性能,您可以将eval_dataset传递给CrossEncoderTrainer。但是,您可能需要除评估损失之外更详细的指标。这就是评估器可以帮助您在训练的不同阶段使用特定指标评估模型性能的地方。您可以根据需要使用评估数据集、评估器、两者或都不用。评估策略和频率由eval_strategyeval_steps训练参数控制。

Sentence Transformers包含以下内置评估器:

评估器 所需数据
CrossEncoderClassificationEvaluator 带有类标签的对(二分类或多分类)
CrossEncoderCorrelationEvaluator 带有相似度分数的对
CrossEncoderNanoBEIREvaluator 无需数据
CrossEncoderRerankingEvaluator {'query': '...', 'positive': [...], 'negative': [...]}字典列表。负例可以使用mine_hard_negatives挖掘。

您还可以使用SequentialEvaluator将多个评估器组合成一个,然后将其传递给CrossEncoderTrainer。您也可以直接将评估器列表传递给训练器。

有时您没有所需的评估数据来自行准备这些评估器,但您仍然希望跟踪模型在某些常见基准上的表现。在这种情况下,您可以将这些评估器与来自Hugging Face的数据一起使用。

使用STSb的CrossEncoderCorrelationEvaluator

STS基准测试(又称STSb)是一个常用基准数据集,用于衡量模型对“一个人正在给蛇喂老鼠”等短文本语义相似度的理解。

欢迎浏览Hugging Face上的sentence-transformers/stsb数据集。

from datasets import load_dataset
from sentence_transformers import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderCorrelationEvaluator

# Load a model
model = CrossEncoder("cross-encoder/stsb-TinyBERT-L4")

# Load the STSB dataset (https://huggingface.co/datasets/sentence-transformers/stsb)
eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
pairs = list(zip(eval_dataset["sentence1"], eval_dataset["sentence2"]))

# Initialize the evaluator
dev_evaluator = CrossEncoderCorrelationEvaluator(
    sentence_pairs=pairs,
    scores=eval_dataset["score"],
    name="sts_dev",
)
# You can run evaluation like so:
# results = dev_evaluator(model)

# Later, you can provide this evaluator to the trainer to get results during training

使用GooAQ挖掘负例的CrossEncoderRerankingEvaluator

CrossEncoderRerankingEvaluator准备数据可能很困难,因为除了查询-正例数据外,您还需要负例。

mine_hard_negatives函数有一个方便的include_positives参数,可以将其设置为True以同时挖掘正例文本。当将其作为documents(必须是1. 已排序且2. 包含正例)提供给CrossEncoderRerankingEvaluator时,评估器不仅会评估交叉编码器的重排性能,还会评估用于挖掘的嵌入模型原始排名。

例如:

CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries:  1000     Positives: Min 1.0, Mean 1.0, Max 1.0   Negatives: Min 49.0, Mean 49.1, Max 50.0
          Base  -> Reranked
MAP:      53.28 -> 67.28
MRR@10:   52.40 -> 66.65
NDCG@10:  59.12 -> 71.35

请注意,默认情况下,如果您使用带有documentsCrossEncoderRerankingEvaluator,评估器将使用*所有*正例进行重排,即使它们不在文档中。这对于从评估器中获得更强的信号很有用,但确实会给出略微不切实际的性能。毕竟,最大性能现在是100,而通常它的上限取决于第一阶段检索器是否实际检索到了正例。

您可以通过在初始化CrossEncoderRerankingEvaluator时设置always_rerank_positives=False来启用真实行为。使用这种真实的两阶段性能重复相同的脚本会得到:

CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries:  1000     Positives: Min 1.0, Mean 1.0, Max 1.0   Negatives: Min 49.0, Mean 49.1, Max 50.0
          Base  -> Reranked
MAP:      53.28 -> 66.12
MRR@10:   52.40 -> 65.61
NDCG@10:  59.12 -> 70.10
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator
from sentence_transformers.util import mine_hard_negatives

# Load a model
model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L6-v2")

# Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
full_dataset = load_dataset("sentence-transformers/gooaq", split=f"train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
print(eval_dataset)
"""
Dataset({
    features: ['question', 'answer'],
    num_rows: 1000
})
"""

# Mine hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_eval_dataset = mine_hard_negatives(
    eval_dataset,
    embedding_model,
    corpus=full_dataset["answer"],  # Use the full dataset as the corpus
    num_negatives=50,  # How many negatives per question-answer pair
    batch_size=4096,  # Use a batch size of 4096 for the embedding model
    output_format="n-tuple",  # The output format is (query, positive, negative1, negative2, ...) for the evaluator
    include_positives=True,  # Key: Include the positive answer in the list of negatives
    use_faiss=True,  # Using FAISS is recommended to keep memory usage low (pip install faiss-gpu or pip install faiss-cpu)
)
print(hard_eval_dataset)
"""
Dataset({
    features: ['question', 'answer', 'negative_1', 'negative_2', 'negative_3', 'negative_4', 'negative_5', 'negative_6', 'negative_7', 'negative_8', 'negative_9', 'negative_10', 'negative_11', 'negative_12', 'negative_13', 'negative_14', 'negative_15', 'negative_16', 'negative_17', 'negative_18', 'negative_19', 'negative_20', 'negative_21', 'negative_22', 'negative_23', 'negative_24', 'negative_25', 'negative_26', 'negative_27', 'negative_28', 'negative_29', 'negative_30', 'negative_31', 'negative_32', 'negative_33', 'negative_34', 'negative_35', 'negative_36', 'negative_37', 'negative_38', 'negative_39', 'negative_40', 'negative_41', 'negative_42', 'negative_43', 'negative_44', 'negative_45', 'negative_46', 'negative_47', 'negative_48', 'negative_49', 'negative_50'],
    num_rows: 1000
})
"""

reranking_evaluator = CrossEncoderRerankingEvaluator(
    samples=[
        {
            "query": sample["question"],
            "positive": [sample["answer"]],
            "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
        }
        for sample in hard_eval_dataset
    ],
    batch_size=32,
    name="gooaq-dev",
)
# You can run evaluation like so
results = reranking_evaluator(model)
"""
CrossEncoderRerankingEvaluator: Evaluating the model on the gooaq-dev dataset:
Queries:  1000     Positives: Min 1.0, Mean 1.0, Max 1.0   Negatives: Min 49.0, Mean 49.1, Max 50.0
          Base  -> Reranked
MAP:      53.28 -> 67.28
MRR@10:   52.40 -> 66.65
NDCG@10:  59.12 -> 71.35
"""
# {'gooaq-dev_map': 0.6728370126462222, 'gooaq-dev_mrr@10': 0.6665190476190477, 'gooaq-dev_ndcg@10': 0.7135068904582963, 'gooaq-dev_base_map': 0.5327714512001362, 'gooaq-dev_base_mrr@10': 0.5239674603174603, 'gooaq-dev_base_ndcg@10': 0.5912299141913905}

训练器

CrossEncoderTrainer是所有先前组件的集合。我们只需指定训练器与模型、训练参数(可选)、训练数据集、评估数据集(可选)、损失函数、评估器(可选),然后就可以开始训练了。让我们看看一个所有这些组件都结合在一起的脚本:

import logging
import traceback

import torch
from datasets import load_dataset

from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderModelCardData,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderNanoBEIREvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses.BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
from sentence_transformers.evaluation.SequentialEvaluator import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "answerdotai/ModernBERT-base"

    train_batch_size = 16
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives should be mined for each question-answer pair

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
            model_name="ModernBERT-base trained on GooAQ",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives per question-answer pair
        margin=0,  # Similarity between query and negative samples should be x lower than query-positive similarity
        range_min=0,  # Skip the x most similar samples
        range_max=100,  # Consider only the x most similar samples
        sampling_strategy="top",  # Sample the top negatives from the range
        batch_size=4096,  # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")

    # 3. Define our training loss.
    # pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset as the corpus
        num_negatives=30,  # How many documents to rerank
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
        always_rerank_positives=False,
    )

    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_gooaq-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=4000,
        save_strategy="steps",
        save_steps=4000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()

在此示例中,我正在从answerdotai/ModernBERT-base进行微调,这是一个尚未成为交叉编码器模型的基础模型。这通常比微调现有重排器模型(如Alibaba-NLP/gte-multilingual-reranker-base)需要更多的训练数据。我使用了来自GooAQ数据集的99k个查询-答案对,之后我使用sentence-transformers/static-retrieval-mrl-en-v1嵌入模型挖掘难负例。这导致了578k个带标签的对:99k个正例对(即标签=1)和479k个负例对(即标签=0)。

我使用了BinaryCrossEntropyLoss,它非常适合这些带标签的对。我还设置了两种评估形式:CrossEncoderNanoBEIREvaluator用于评估NanoBEIR基准,以及CrossEncoderRerankingEvaluator用于评估上述静态嵌入模型对前30个结果进行重排的性能。之后,我定义了一组相当标准的超参数,包括学习率、预热比率、bf16、最后加载最佳模型以及一些调试参数。最后,我运行了训练器,进行了训练后评估,并将模型保存到本地和Hugging Face Hub。

运行此脚本后,tomaarsen/reranker-ModernBERT-base-gooaq-bce模型已为我上传。请参阅即将到来的评估部分,其中有证据表明该模型优于13种常用开源替代方案,包括更大的模型。我还使用answerdotai/ModernBERT-large作为基础模型运行了该模型,结果是tomaarsen/reranker-ModernBERT-large-gooaq-bce

评估结果会自动存储在生成的模型卡中,模型卡中包含基础模型、语言、许可证、评估结果、训练和评估数据集信息、超参数、训练日志等。无需任何额外工作,您上传的模型将包含您的潜在用户确定模型是否适合他们所需的所有信息。

回调

交叉编码器训练器支持各种transformers.TrainerCallback子类,包括:

  • 如果安装了wandb,则使用WandbCallback将训练指标记录到W&B。
  • 如果可以访问tensorboard,则使用TensorBoardCallback将训练指标记录到TensorBoard。
  • 如果安装了codecarbon,则使用CodeCarbonCallback跟踪训练期间的碳排放。

只要安装了所需的依赖项,这些功能就会自动使用,您无需进行任何指定。

有关这些回调以及如何创建自己的回调的更多信息,请参阅Transformers回调文档

多数据集训练

通常,表现最佳的通用模型是同时在多个数据集上训练的。然而,由于每个数据集的格式不同,这种方法可能具有挑战性。幸运的是,CrossEncoderTrainer允许您在多个数据集上进行训练,而无需统一格式。此外,它提供了为每个数据集应用不同损失函数的灵活性。以下是同时使用多个数据集进行训练的步骤:

  • 使用datasets.Dataset实例字典(或datasets.DatasetDict)作为train_dataset(以及可选的eval_dataset)。
  • (可选)使用损失函数字典,将数据集名称映射到损失。仅当您希望为不同数据集使用不同损失函数时才需要。

每个训练/评估批次将只包含来自一个数据集的样本。从多个数据集中采样批次的顺序由MultiDatasetBatchSamplers枚举定义,该枚举可以通过multi_dataset_batch_sampler传递给CrossEncoderTrainingArguments。有效选项包括:

  • MultiDatasetBatchSamplers.ROUND_ROBIN:从每个数据集循环采样,直到其中一个耗尽。使用此策略,可能不会使用每个数据集中的所有样本,但每个数据集的采样频率相同。
  • MultiDatasetBatchSamplers.PROPORTIONAL(默认):根据每个数据集的大小按比例采样。使用此策略,将使用每个数据集中的所有样本,并且从较大的数据集中采样的频率更高。

训练技巧

交叉编码器模型有其独特的特点,因此这里有一些技巧可以帮助您:

  1. 交叉编码器模型很容易过拟合,因此建议使用像CrossEncoderNanoBEIREvaluatorCrossEncoderRerankingEvaluator这样的评估器,并结合load_best_model_at_endmetric_for_best_model训练参数,以便在训练结束后加载具有最佳评估性能的模型。

  2. 交叉编码器对强硬负例(mine_hard_negatives)特别敏感。它们教导模型非常严格,例如在区分回答问题和与问题相关的段落时很有用。

    1. 请注意,如果您只使用难负例,您的模型在较简单任务上的表现可能会出人意料地变差。这可能意味着,对第一阶段检索系统(例如使用SentenceTransformer模型)检索到的前200个结果进行重排,实际上可能比重排前100个结果得到更差的前10个结果。同时使用随机负例和难负例进行训练可以缓解这种情况。
  3. 不要低估BinaryCrossEntropyLoss的力量,尽管它比学习排序(LambdaLossListNetLoss)或批内负例(CachedMultipleNegativesRankingLossMultipleNegativesRankingLoss)损失更简单,但它仍然是一个非常强大的选择,并且其数据易于准备,尤其是在使用mine_hard_negatives时。

评估

我对我模型在训练器部分中的重排评估,与GooAQ开发集上的几个基线进行了比较,重排评估器中同时使用了always_rerank_positives=Falsealways_rerank_positives=True。这分别代表了真实(仅重排检索器找到的内容)和评估(重排所有正例,即使检索器未找到)格式。

提醒一下,我使用了极其高效的sentence-transformers/static-retrieval-mrl-en-v1静态嵌入模型来检索前30个用于重排。

模型 模型参数 重排前30后GooAQ NDCG@10 重排前30+所有正例后GooAQ NDCG@10
无重排,仅检索器 - 59.12 59.12
cross-encoder/ms-marco-MiniLM-L6-v2 22.7M 69.56 72.09
jinaai/jina-reranker-v1-tiny-en 33M 66.83 69.54
jinaai/jina-reranker-v1-turbo-en 37.8M 72.01 76.10
jinaai/jina-reranker-v2-base-multilingual 278M 74.87 78.88
BAAI/bge-reranker-base 278M 70.98 74.31
BAAI/bge-reranker-large 560M 73.20 77.46
BAAI/bge-reranker-v2-m3 568M 73.56 77.55
mixedbread-ai/mxbai-rerank-xsmall-v1 70.8M 66.63 69.41
mixedbread-ai/mxbai-rerank-base-v1 184M 70.43 74.39
mixedbread-ai/mxbai-rerank-large-v1 435M 74.03 78.66
mixedbread-ai/mxbai-rerank-base-v2 494M 73.03 76.76
mixedbread-ai/mxbai-rerank-large-v2 1.54B 75.40 80.04
Alibaba-NLP/gte-reranker-modernbert-base 150M 73.18 77.49
tomaarsen/reranker-ModernBERT-base-gooaq-bce 150M 77.14 83.51
tomaarsen/reranker-ModernBERT-large-gooaq-bce 396M 79.42 85.81
点击查看评估脚本和数据集

这是评估脚本:

import logging
from pprint import pprint
from datasets import load_dataset

from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CrossEncoderRerankingEvaluator

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "tomaarsen/reranker-ModernBERT-base-gooaq-bce"
    eval_batch_size = 64

    # 1. Load a model to evaluate
    model = CrossEncoder(model_name)

    # 2. Load the GooAQ dataset: https://huggingface.co/datasets/tomaarsen/gooaq-reranker-blogpost-datasets
    logging.info("Read the gooaq reranking dataset")
    hard_eval_dataset = load_dataset("tomaarsen/gooaq-reranker-blogpost-datasets", "rerank", split="eval")

    # 4. Create reranking evaluators. We use `always_rerank_positives=False` for a realistic evaluation
    # where only all top 30 documents are reranked, and `always_rerank_positives=True` for an evaluation
    # where the positive answer is always reranked as well.
    samples = [
        {
            "query": sample["question"],
            "positive": [sample["answer"]],
            "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
        }
        for sample in hard_eval_dataset
    ]
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=eval_batch_size,
        name="gooaq-dev-realistic",
        always_rerank_positives=False,
    )
    realistic_results = reranking_evaluator(model)
    pprint(realistic_results)

    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=samples,
        batch_size=eval_batch_size,
        name="gooaq-dev-evaluation",
        always_rerank_positives=True,
    )
    evaluation_results = reranking_evaluator(model)
    pprint(evaluation_results)


if __name__ == "__main__":
    main()

它使用了我的tomaarsen/gooaq-reranker-blogpost-datasets数据集中的rerank子集。该数据集包含:

  • pair子集,train分割:99k个训练样本直接取自GooAQ。这不直接用于训练,而是用于准备hard-labeled-pair子集,后者用于训练。
  • pair子集,eval分割:1k个训练样本直接取自GooAQ,与之前的99k没有重叠。这不直接用于评估,而是用于准备rerank子集,后者用于评估。
  • hard-labeled-pair子集,train分割:578k个带标签的对用于训练,通过使用来自pair子集和train分割的99k个样本与sentence-transformers/static-retrieval-mrl-en-v1进行挖掘。该数据集用于训练。
  • rerank子集,eval分割:1k个样本,包含问题、答案以及由sentence-transformers/static-retrieval-mrl-en-v1使用我的GooAQ子集中完整的100k训练和评估答案检索到的30个文档。该排名已经具有59.12的NDCG@10。

Model size vs NDCG for Rerankers on GooAQ

仅使用gooaq数据集中300万训练对中的9.9万对,并在我的RTX 3090上仅训练30分钟,我的小型1.5亿参数tomaarsen/reranker-ModernBERT-base-gooaq-bce模型就轻松超越了所有小于10亿参数的通用重排器。更大的tomaarsen/reranker-ModernBERT-large-gooaq-bce训练时间不到一小时,并在实际设置中以高达79.42的NDCG@10独占鳌头。GooAQ训练和评估数据集与这些基线模型的训练目标非常吻合,因此在更小众的领域进行训练时,差异应该更大。

请注意,这并不意味着tomaarsen/reranker-ModernBERT-large-gooaq-bce是*所有*领域中最强的模型:它只是*我们*领域中最强的。这完全没有问题,因为我们只需要这个重排器在我们的数据上表现良好。

不要低估在您的领域中微调重排器模型的力量。通过微调(小型)重排器,您可以同时提高搜索性能和搜索堆栈的延迟!

附加资源

训练示例

这些页面包含带解释的训练示例以及训练脚本代码链接。您可以使用它们来熟悉重排器训练循环:

文档

如需进一步学习,您可能还希望探索Sentence Transformers上的以下资源:

这里有一个您可能感兴趣的高级页面:

社区

@tomaarsen ,请告知我们如何进行qlora-peft组合的微调?这将非常有帮助

注册登录发表评论