SetFit documentation

Knowledge Distillation

If you have unlabeled data, you can use knowledge distillation to boost the performance of a small SetFit model. The approach involves training a larger model and then using the unlabeled data to distill its performance into your small SetFit model. As a result, your SetFit model becomes stronger.

Alternatively, you can also use knowledge distillation to replace a trained SetFit model with a more efficient one, at only a slight drop in performance.

This guide will show you how to perform knowledge distillation.

Data preparation

Let's consider a scenario where only a small amount of labeled training data is available (e.g. 64 sentences). This guide uses the ag_news dataset to simulate that scenario.

from datasets import load_dataset
from setfit import sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 64
# })

# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 7600
# })
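
If you want to double-check what the four labels stand for, the dataset's ClassLabel feature exposes the label names. This is an optional sanity check, not part of the original walkthrough:

# Optional: inspect the label names behind the integer labels
print(dataset["train"].features["label"].names)
# ['World', 'Sports', 'Business', 'Sci/Tech']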

Baseline model

We can prepare a model using the standard SetFit training approach:

from setfit import SetFitModel, TrainingArguments, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)
***** Running training *****
  Num examples = 48
  Num epochs = 5
  Total optimization steps = 240
  Total train batch size = 64
{'embedding_loss': 0.4173, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.02}                                                                                  
{'embedding_loss': 0.1756, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.04}                                                                                 
{'embedding_loss': 0.119, 'learning_rate': 1.2962962962962964e-05, 'epoch': 2.08}                                                                                  
{'embedding_loss': 0.0872, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.12}                                                                                  
{'embedding_loss': 0.0542, 'learning_rate': 3.7037037037037037e-06, 'epoch': 4.17}                                                                                 
{'train_runtime': 26.0837, 'train_samples_per_second': 588.873, 'train_steps_per_second': 9.201, 'epoch': 5.0}                                                     
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:20<00:00, 11.97it/s] 
***** Running evaluation *****
{'accuracy': 0.7818421052631579}

This model reaches 78.18% accuracy on our dataset. That is certainly commendable considering the tiny amount of training data, but we can squeeze more performance out of our model using knowledge distillation.
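
Before moving on, it can be useful to sanity-check the baseline with a quick prediction. This is an illustrative aside with made-up sentences, using SetFitModel.predict:

# Classify raw text with the trained baseline model
preds = model.predict([
    "Wall Street rallies as tech stocks surge",
    "The midfielder scored twice in the final",
])
print(preds)
# Tensor of predicted label ids (0=World, 1=Sports, 2=Business, 3=Sci/Tech)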

Unlabeled data preparation

Beyond the labeled training data, we may also have a large amount of unlabeled training data (e.g. 500 sentences). Let's prepare it:

# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
#     features: ['text'],
#     num_rows: 500
# })

Teacher model

Next, we'll prepare a large, trained SetFit model to act as the teacher for our smaller student model. The powerful sentence-transformers/paraphrase-mpnet-base-v2 Sentence Transformer model will be used to initialize the SetFit model.

from setfit import SetFitModel

teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

We first need to train this model on the labeled dataset:

from setfit import TrainingArguments, Trainer

teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
***** Running training *****
  Num examples = 192
  Num epochs = 2
  Total optimization steps = 384
  Total train batch size = 16
{'embedding_loss': 0.4093, 'learning_rate': 5.128205128205128e-07, 'epoch': 0.01}                                                                                  
{'embedding_loss': 0.1087, 'learning_rate': 1.9362318840579713e-05, 'epoch': 0.26}                                                                                 
{'embedding_loss': 0.001, 'learning_rate': 1.6463768115942028e-05, 'epoch': 0.52}                                                                                  
{'embedding_loss': 0.0006, 'learning_rate': 1.3565217391304348e-05, 'epoch': 0.78}                                                                                 
{'embedding_loss': 0.0003, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.04}                                                                                 
{'embedding_loss': 0.0004, 'learning_rate': 7.768115942028987e-06, 'epoch': 1.3}                                                                                   
{'embedding_loss': 0.0002, 'learning_rate': 4.869565217391305e-06, 'epoch': 1.56}                                                                                  
{'embedding_loss': 0.0003, 'learning_rate': 1.9710144927536233e-06, 'epoch': 1.82}                                                                                 
{'train_runtime': 84.3703, 'train_samples_per_second': 72.822, 'train_steps_per_second': 4.551, 'epoch': 2.0}                                                      
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [01:24<00:00,  4.55it/s] 
***** Running evaluation *****
{'accuracy': 0.8378947368421052}

This large teacher model reaches 83.79% accuracy, which is quite strong for this little data, and notably stronger than the 78.18% achieved by our smaller (but more efficient) model.
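
To put the size gap in numbers, you can count the parameters of each model's embedding body. This is a small sketch that assumes the SetFitModel.model_body attribute, which holds the underlying Sentence Transformer:

# Compare the parameter counts of the student and teacher bodies
def count_params(setfit_model):
    return sum(p.numel() for p in setfit_model.model_body.parameters())

print(f"Student: {count_params(model):,} parameters")
print(f"Teacher: {count_params(teacher_model):,} parameters")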

Knowledge distillation

The performance of the stronger teacher_model can be distilled into the smaller model using the DistillationTrainer. It accepts a teacher model and a student model, as well as an unlabeled dataset.

Note that this trainer uses pairs of sentences as training samples, so the number of potential training pairs grows quadratically with the number of unlabeled samples. To avoid overfitting, consider setting max_steps relatively low.
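
A quick back-of-the-envelope calculation illustrates the scaling: 500 unlabeled sentences already admit well over a hundred thousand unique pairs.

# Unique sentence pairs grow quadratically with the number of samples
n = 500
print(n * (n - 1) // 2)  # 124750 possible pairs from 500 sentences

With the step budget capped at 500, we can run the distillation: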

from setfit import DistillationTrainer

distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
***** Running training *****
  Num examples = 7829
  Num epochs = 1
  Total optimization steps = 7829
  Total train batch size = 16
{'embedding_loss': 0.5048, 'learning_rate': 2.554278416347382e-08, 'epoch': 0.0}                                                                                   
{'embedding_loss': 0.4514, 'learning_rate': 1.277139208173691e-06, 'epoch': 0.01}                                                                                  
{'embedding_loss': 0.33, 'learning_rate': 2.554278416347382e-06, 'epoch': 0.01}                                                                                    
{'embedding_loss': 0.1218, 'learning_rate': 3.831417624521073e-06, 'epoch': 0.02}                                                                                  
{'embedding_loss': 0.0213, 'learning_rate': 5.108556832694764e-06, 'epoch': 0.03}                                                                                  
{'embedding_loss': 0.016, 'learning_rate': 6.385696040868455e-06, 'epoch': 0.03}                                                                                   
{'embedding_loss': 0.0054, 'learning_rate': 7.662835249042147e-06, 'epoch': 0.04}                                                                                  
{'embedding_loss': 0.0049, 'learning_rate': 8.939974457215838e-06, 'epoch': 0.04}                                                                                  
{'embedding_loss': 0.002, 'learning_rate': 1.0217113665389528e-05, 'epoch': 0.05}                                                                                  
{'embedding_loss': 0.0019, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.06}                                                                                 
{'embedding_loss': 0.0012, 'learning_rate': 1.277139208173691e-05, 'epoch': 0.06}                                                                                  
{'train_runtime': 22.2725, 'train_samples_per_second': 359.188, 'train_steps_per_second': 22.449, 'epoch': 0.06}                                                   
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 22.45it/s] 
***** Running evaluation *****
{'accuracy': 0.8084210526315789}

Using knowledge distillation, we were able to improve the model's performance from 78.18% to 80.84% with just a few minutes of training time.
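
Once you are happy with the distilled student, it can be persisted like any other SetFit model. The local path and Hub repository id below are hypothetical placeholders:

# Save the distilled student model locally
model.save_pretrained("setfit-distilled-ag-news")

# Or share it on the Hugging Face Hub (hypothetical repository id)
# model.push_to_hub("your-username/setfit-distilled-ag-news")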

End-to-end

This snippet shows the full end-to-end knowledge distillation strategy:

from datasets import load_dataset
from setfit import sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 64
# })

# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 7600
# })

from setfit import SetFitModel, TrainingArguments, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)

# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
#     features: ['text'],
#     num_rows: 500
# })

from setfit import SetFitModel

teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

from setfit import TrainingArguments, Trainer

teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)

from setfit import DistillationTrainer

distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
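
Assuming the distilled student was saved as sketched earlier, it can be reloaded for inference later on; the path is again a hypothetical placeholder:

from setfit import SetFitModel

# Reload the distilled student and classify a new sentence
student = SetFitModel.from_pretrained("setfit-distilled-ag-news")
print(student.predict(["NASA announces a new mission to study the sun"]))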