SetFit 文档

知识蒸馏

Hugging Face's logo
加入 Hugging Face 社区

并获得增强的文档体验

开始使用

知识蒸馏

如果您可以访问未标记的数据,那么您可以使用知识蒸馏来提高您的小型 SetFit 模型的性能。 该方法涉及训练一个更大的模型,并使用未标记的数据将其性能提炼到您较小的 SetFit 模型中。 结果,您的 SetFit 模型将变得更强大。

此外,您还可以使用知识蒸馏来用更高效的模型替换您训练好的 SetFit 模型,同时性能下降较少。

本指南将向您展示如何进行知识蒸馏。

数据准备

让我们考虑一个只有少量标记训练数据(例如 64 个句子)的场景。 我们将使用 ag_news 数据集来模拟本指南中的场景。

from datasets import load_dataset
from setfit import sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 64
# })

# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 7600
# })

基线模型

我们可以使用标准的 SetFit 训练方法来准备模型。

from setfit import SetFitModel, TrainingArguments, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)
***** Running training *****
  Num examples = 48
  Num epochs = 5
  Total optimization steps = 240
  Total train batch size = 64
{'embedding_loss': 0.4173, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.02}                                                                                  
{'embedding_loss': 0.1756, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.04}                                                                                 
{'embedding_loss': 0.119, 'learning_rate': 1.2962962962962964e-05, 'epoch': 2.08}                                                                                  
{'embedding_loss': 0.0872, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.12}                                                                                  
{'embedding_loss': 0.0542, 'learning_rate': 3.7037037037037037e-06, 'epoch': 4.17}                                                                                 
{'train_runtime': 26.0837, 'train_samples_per_second': 588.873, 'train_steps_per_second': 9.201, 'epoch': 5.0}                                                     
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:20<00:00, 11.97it/s] 
***** Running evaluation *****
{'accuracy': 0.7818421052631579}

该模型在我们的数据集上达到了 78.18% 的准确率。 考虑到极少量的训练数据,这当然是值得尊敬的,但我们可以使用知识蒸馏来从我们的模型中挤出更多的性能。

未标记数据准备

除了我们标记的训练数据外,我们也可能有很多未标记的训练数据(例如 500 个句子)。 让我们准备它

# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
#     features: ['text'],
#     num_rows: 500
# })

教师模型

然后,我们将准备一个更大的训练好的 SetFit 模型,它将充当我们较小的学生模型的教师。 强大的 sentence-transformers/paraphrase-mpnet-base-v2 Sentence Transformer 模型将用于初始化 SetFit 模型。

from setfit import SetFitModel

teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

我们需要首先在标记数据集上训练此模型

from setfit import TrainingArguments, Trainer

teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
***** Running training *****
  Num examples = 192
  Num epochs = 2
  Total optimization steps = 384
  Total train batch size = 16
{'embedding_loss': 0.4093, 'learning_rate': 5.128205128205128e-07, 'epoch': 0.01}                                                                                  
{'embedding_loss': 0.1087, 'learning_rate': 1.9362318840579713e-05, 'epoch': 0.26}                                                                                 
{'embedding_loss': 0.001, 'learning_rate': 1.6463768115942028e-05, 'epoch': 0.52}                                                                                  
{'embedding_loss': 0.0006, 'learning_rate': 1.3565217391304348e-05, 'epoch': 0.78}                                                                                 
{'embedding_loss': 0.0003, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.04}                                                                                 
{'embedding_loss': 0.0004, 'learning_rate': 7.768115942028987e-06, 'epoch': 1.3}                                                                                   
{'embedding_loss': 0.0002, 'learning_rate': 4.869565217391305e-06, 'epoch': 1.56}                                                                                  
{'embedding_loss': 0.0003, 'learning_rate': 1.9710144927536233e-06, 'epoch': 1.82}                                                                                 
{'train_runtime': 84.3703, 'train_samples_per_second': 72.822, 'train_steps_per_second': 4.551, 'epoch': 2.0}                                                      
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [01:24<00:00,  4.55it/s] 
***** Running evaluation *****
{'accuracy': 0.8378947368421052}

这个大型教师模型达到了 83.79% 的准确率,对于这么少的数据来说已经相当不错了,而且明显强于我们较小(但更高效)的模型的 78.18%。

知识蒸馏

可以使用 DistillationTrainer 将更强大的 teacher_model 的性能提炼到较小的模型中。 它接受教师模型和学生模型,以及未标记的数据集。

请注意,此 trainer 使用句子对作为训练样本,因此训练步骤的数量呈指数级增长到未标记示例的数量。 为了避免过拟合,请考虑将 max_steps 设置得相对较低。

from setfit import DistillationTrainer

distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
***** Running training *****
  Num examples = 7829
  Num epochs = 1
  Total optimization steps = 7829
  Total train batch size = 16
{'embedding_loss': 0.5048, 'learning_rate': 2.554278416347382e-08, 'epoch': 0.0}                                                                                   
{'embedding_loss': 0.4514, 'learning_rate': 1.277139208173691e-06, 'epoch': 0.01}                                                                                  
{'embedding_loss': 0.33, 'learning_rate': 2.554278416347382e-06, 'epoch': 0.01}                                                                                    
{'embedding_loss': 0.1218, 'learning_rate': 3.831417624521073e-06, 'epoch': 0.02}                                                                                  
{'embedding_loss': 0.0213, 'learning_rate': 5.108556832694764e-06, 'epoch': 0.03}                                                                                  
{'embedding_loss': 0.016, 'learning_rate': 6.385696040868455e-06, 'epoch': 0.03}                                                                                   
{'embedding_loss': 0.0054, 'learning_rate': 7.662835249042147e-06, 'epoch': 0.04}                                                                                  
{'embedding_loss': 0.0049, 'learning_rate': 8.939974457215838e-06, 'epoch': 0.04}                                                                                  
{'embedding_loss': 0.002, 'learning_rate': 1.0217113665389528e-05, 'epoch': 0.05}                                                                                  
{'embedding_loss': 0.0019, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.06}                                                                                 
{'embedding_loss': 0.0012, 'learning_rate': 1.277139208173691e-05, 'epoch': 0.06}                                                                                  
{'train_runtime': 22.2725, 'train_samples_per_second': 359.188, 'train_steps_per_second': 22.449, 'epoch': 0.06}                                                   
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 22.45it/s] 
***** Running evaluation *****
{'accuracy': 0.8084210526315789}

使用知识蒸馏,我们在几分钟的训练中就能够将我们的模型从 78.18% 提高到 80.84%。

端到端

此代码片段显示了端到端示例中的整个知识蒸馏策略

from datasets import load_dataset
from setfit import sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 64
# })

# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 7600
# })

from setfit import SetFitModel, TrainingArguments, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)

# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
#     features: ['text'],
#     num_rows: 500
# })

from setfit import SetFitModel

teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

from setfit import TrainingArguments, Trainer

teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)

from setfit import DistillationTrainer

distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
< > 在 GitHub 上更新