SetFit 文档
知识蒸馏
并获得增强的文档体验
开始使用
知识蒸馏
如果您可以访问未标记的数据,那么您可以使用知识蒸馏来提高您的小型 SetFit 模型的性能。 该方法涉及训练一个更大的模型,并使用未标记的数据将其性能提炼到您较小的 SetFit 模型中。 结果,您的 SetFit 模型将变得更强大。
此外,您还可以使用知识蒸馏来用更高效的模型替换您训练好的 SetFit 模型,同时性能下降较少。
本指南将向您展示如何进行知识蒸馏。
数据准备
让我们考虑一个只有少量标记训练数据(例如 64 个句子)的场景。 我们将使用 ag_news 数据集来模拟本指南中的场景。
from datasets import load_dataset
from setfit import sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
# features: ['text', 'label'],
# num_rows: 64
# })
# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
# features: ['text', 'label'],
# num_rows: 7600
# })
基线模型
我们可以使用标准的 SetFit 训练方法来准备模型。
from setfit import SetFitModel, TrainingArguments, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
args = TrainingArguments(
batch_size=64,
num_epochs=5,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)
***** Running training *****
Num examples = 48
Num epochs = 5
Total optimization steps = 240
Total train batch size = 64
{'embedding_loss': 0.4173, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.02}
{'embedding_loss': 0.1756, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.04}
{'embedding_loss': 0.119, 'learning_rate': 1.2962962962962964e-05, 'epoch': 2.08}
{'embedding_loss': 0.0872, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.12}
{'embedding_loss': 0.0542, 'learning_rate': 3.7037037037037037e-06, 'epoch': 4.17}
{'train_runtime': 26.0837, 'train_samples_per_second': 588.873, 'train_steps_per_second': 9.201, 'epoch': 5.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:20<00:00, 11.97it/s]
***** Running evaluation *****
{'accuracy': 0.7818421052631579}
该模型在我们的数据集上达到了 78.18% 的准确率。 考虑到极少量的训练数据,这当然是值得尊敬的,但我们可以使用知识蒸馏来从我们的模型中挤出更多的性能。
未标记数据准备
除了我们标记的训练数据外,我们也可能有很多未标记的训练数据(例如 500 个句子)。 让我们准备它
# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
# features: ['text'],
# num_rows: 500
# })
教师模型
然后,我们将准备一个更大的训练好的 SetFit 模型,它将充当我们较小的学生模型的教师。 强大的 sentence-transformers/paraphrase-mpnet-base-v2
Sentence Transformer 模型将用于初始化 SetFit 模型。
from setfit import SetFitModel
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
我们需要首先在标记数据集上训练此模型
from setfit import TrainingArguments, Trainer
teacher_args = TrainingArguments(
batch_size=16,
num_epochs=2,
)
teacher_trainer = Trainer(
model=teacher_model,
args=teacher_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
***** Running training *****
Num examples = 192
Num epochs = 2
Total optimization steps = 384
Total train batch size = 16
{'embedding_loss': 0.4093, 'learning_rate': 5.128205128205128e-07, 'epoch': 0.01}
{'embedding_loss': 0.1087, 'learning_rate': 1.9362318840579713e-05, 'epoch': 0.26}
{'embedding_loss': 0.001, 'learning_rate': 1.6463768115942028e-05, 'epoch': 0.52}
{'embedding_loss': 0.0006, 'learning_rate': 1.3565217391304348e-05, 'epoch': 0.78}
{'embedding_loss': 0.0003, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.04}
{'embedding_loss': 0.0004, 'learning_rate': 7.768115942028987e-06, 'epoch': 1.3}
{'embedding_loss': 0.0002, 'learning_rate': 4.869565217391305e-06, 'epoch': 1.56}
{'embedding_loss': 0.0003, 'learning_rate': 1.9710144927536233e-06, 'epoch': 1.82}
{'train_runtime': 84.3703, 'train_samples_per_second': 72.822, 'train_steps_per_second': 4.551, 'epoch': 2.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [01:24<00:00, 4.55it/s]
***** Running evaluation *****
{'accuracy': 0.8378947368421052}
这个大型教师模型达到了 83.79% 的准确率,对于这么少的数据来说已经相当不错了,而且明显强于我们较小(但更高效)的模型的 78.18%。
知识蒸馏
可以使用 DistillationTrainer 将更强大的 teacher_model 的性能提炼到较小的模型中。 它接受教师模型和学生模型,以及未标记的数据集。
请注意,此 trainer 使用句子对作为训练样本,因此训练步骤的数量呈指数级增长到未标记示例的数量。 为了避免过拟合,请考虑将 max_steps
设置得相对较低。
from setfit import DistillationTrainer
distillation_args = TrainingArguments(
batch_size=16,
max_steps=500,
)
distillation_trainer = DistillationTrainer(
teacher_model=teacher_model,
student_model=model,
args=distillation_args,
train_dataset=unlabeled_train_dataset,
eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
***** Running training *****
Num examples = 7829
Num epochs = 1
Total optimization steps = 7829
Total train batch size = 16
{'embedding_loss': 0.5048, 'learning_rate': 2.554278416347382e-08, 'epoch': 0.0}
{'embedding_loss': 0.4514, 'learning_rate': 1.277139208173691e-06, 'epoch': 0.01}
{'embedding_loss': 0.33, 'learning_rate': 2.554278416347382e-06, 'epoch': 0.01}
{'embedding_loss': 0.1218, 'learning_rate': 3.831417624521073e-06, 'epoch': 0.02}
{'embedding_loss': 0.0213, 'learning_rate': 5.108556832694764e-06, 'epoch': 0.03}
{'embedding_loss': 0.016, 'learning_rate': 6.385696040868455e-06, 'epoch': 0.03}
{'embedding_loss': 0.0054, 'learning_rate': 7.662835249042147e-06, 'epoch': 0.04}
{'embedding_loss': 0.0049, 'learning_rate': 8.939974457215838e-06, 'epoch': 0.04}
{'embedding_loss': 0.002, 'learning_rate': 1.0217113665389528e-05, 'epoch': 0.05}
{'embedding_loss': 0.0019, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.06}
{'embedding_loss': 0.0012, 'learning_rate': 1.277139208173691e-05, 'epoch': 0.06}
{'train_runtime': 22.2725, 'train_samples_per_second': 359.188, 'train_steps_per_second': 22.449, 'epoch': 0.06}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 22.45it/s]
***** Running evaluation *****
{'accuracy': 0.8084210526315789}
使用知识蒸馏,我们在几分钟的训练中就能够将我们的模型从 78.18% 提高到 80.84%。
端到端
此代码片段显示了端到端示例中的整个知识蒸馏策略
from datasets import load_dataset
from setfit import sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
# features: ['text', 'label'],
# num_rows: 64
# })
# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
# features: ['text', 'label'],
# num_rows: 7600
# })
from setfit import SetFitModel, TrainingArguments, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
args = TrainingArguments(
batch_size=64,
num_epochs=5,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)
# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
# features: ['text'],
# num_rows: 500
# })
from setfit import SetFitModel
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
from setfit import TrainingArguments, Trainer
teacher_args = TrainingArguments(
batch_size=16,
num_epochs=2,
)
teacher_trainer = Trainer(
model=teacher_model,
args=teacher_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
from setfit import DistillationTrainer
distillation_args = TrainingArguments(
batch_size=16,
max_steps=500,
)
distillation_trainer = DistillationTrainer(
teacher_model=teacher_model,
student_model=model,
args=distillation_args,
train_dataset=unlabeled_train_dataset,
eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)