使用视觉Transformer进行知识蒸馏
我们将学习知识蒸馏,这是 distilGPT 和 distilbert 背后的方法,这两个模型是Hugging Face Hub 上下载次数最多的模型!
可能我们都遇到过这样的老师,他们“教学”的方式仅仅是提供正确的答案,然后测试我们以前从未见过的题目,这类似于机器学习模型的监督学习,我们提供标记数据集进行训练。然而,与其让模型在标签上进行训练,我们可以采用知识蒸馏作为一种替代方法,从而得到一个更小的模型,其性能可以与更大的模型相媲美,并且速度更快。
知识蒸馏的直觉
想象一下,你被问到这样一个选择题
如果有人只是告诉你,“答案是德拉科·马尔福”,那并不能让你对每个角色与哈利·波特的相对关系有太多了解。
另一方面,如果有人告诉你,“我非常确定不是罗恩·韦斯莱,我有点确定不是纳威·隆巴顿,我非常确定是德拉科·马尔福”,这给了你一些关于这些角色与哈利·波特关系的信息!这正是知识蒸馏范式下传递给我们的学生模型的信息类型。
在神经网络中蒸馏知识
在论文在神经网络中蒸馏知识中,Hinton 等人介绍了一种名为知识蒸馏的训练方法,其灵感来自昆虫。就像昆虫从幼虫过渡到成虫,成虫针对不同的任务进行了优化一样,大规模机器学习模型最初可能很笨拙,就像幼虫一样,用于从数据中提取结构,但可以将其知识蒸馏到更小、更高效的模型中进行部署。
知识蒸馏的本质是使用教师网络的预测 logits 将信息传递给更小、更高效的学生模型。我们通过重新编写损失函数使其包含蒸馏损失来实现这一点,这鼓励学生模型在输出空间上的分布近似于教师模型的分布。
蒸馏损失的公式为
KL 损失指的是教师和学生输出分布之间的Kullback-Leibler 散度。然后,学生模型的总体损失被公式化为该蒸馏损失与地面实况标签的标准交叉熵损失的总和。
要查看此损失函数在 Python 中的实现以及 Python 中的完整示例,让我们查看本节的笔记本。
< > 在 GitHub 上更新