社区计算机视觉课程文档
视觉Transformer的知识蒸馏
并获得增强的文档体验
开始使用
视觉Transformer的知识蒸馏
我们将学习知识蒸馏,这是distilGPT和distilbert这两种Hugging Face Hub上下载量最高的模型背后的方法!
大概我们都遇到过这样的老师,他们“教书”的方式就是直接告诉我们正确答案,然后用我们没见过的问题来考我们,这类似于机器学习模型的监督学习,我们提供一个带标签的数据集进行训练。然而,我们不让模型在标签上训练,而是可以采用知识蒸馏作为替代方案,以获得一个性能与大型模型相当且速度快得多的更小型模型。
知识蒸馏背后的直觉
想象一下,你被问到这个多项选择题
如果有人直接告诉你:“答案是德拉科·马尔福,”这并不能让你对每个角色与哈利·波特的关系了解很多。
另一方面,如果有人告诉你:“我非常确定不是罗恩·韦斯莱,我有点确定不是纳威·隆巴顿,我非常确定是德拉科·马尔福,”这会给你一些关于这些角色与哈利·波特关系的信息!这正是知识蒸馏范式下传递给学生模型的信息类型。
神经网络中的知识蒸馏
在论文Distilling the Knowledge in a Neural Network中,Hinton等人受昆虫的启发,引入了知识蒸馏的训练方法。正如昆虫从幼虫形态过渡到为不同任务优化的成年形态一样,大型机器学习模型最初可能像幼虫一样笨重,难以从数据中提取结构,但它们可以将知识蒸馏到更小、更高效的模型中进行部署。
知识蒸馏的精髓是使用教师网络的预测逻辑来将信息传递给更小、更高效的学生模型。我们通过重新编写损失函数来实现这一点,使其包含一个蒸馏损失,该损失鼓励学生模型在输出空间上的分布近似于教师模型。
蒸馏损失的公式为
KL损失指的是教师和学生输出分布之间的Kullback-Leibler散度。学生模型的总损失被公式化为这个蒸馏损失与地面真实标签上的标准交叉熵损失之和。
要查看此损失函数在Python中的实现以及Python中的完整示例,请查看本节的notebook。
为边缘设备利用知识蒸馏
知识蒸馏在AI模型部署到边缘设备上时变得越来越重要。部署一个大型模型(例如大小为1GB、延迟为1秒的模型)对于实时应用来说是不切实际的,因为它需要大量的计算和内存。这些限制主要归因于模型的尺寸。因此,该领域已经接受了知识蒸馏,这项技术可以将模型参数减少90%以上,同时性能下降最小。
知识蒸馏的后果(好与坏)
1. 熵增益
在信息论中,熵类似于物理学中的熵,它衡量系统内部的“混沌”或无序程度。在我们的场景中,它量化了分布所包含的信息量。考虑以下示例
- 哪一个更难记住:
[0, 1, 0, 0]
还是[0.2, 0.5, 0.2, 0.1]
?
第一个向量[0, 1, 0, 0]
更容易记住和压缩,因为它包含的信息更少。这可以表示为第二个位置的“1”。另一方面,[0.2, 0.5, 0.2, 0.1]
包含更多的信息。在此基础上,例如,我们用ImageNet训练了一个80M参数的网络,然后将其蒸馏(如前所述)成一个5M参数的学生模型。我们会发现教师模型输出中包含的熵远低于学生模型。这意味着即使学生模型的输出是正确的,它也比教师模型的输出更混乱。这归结为一个简单的事实:教师模型的额外参数有助于它更容易区分类别,因为它提取了更多的特征。这种关于知识蒸馏的观点非常有趣,并且正在积极研究中,以通过将其用作损失函数或应用受物理学启发的类似度量(例如能量)来减少学生的熵。
2. 连贯的梯度更新
模型通过最小化损失函数并通过梯度下降更新其参数来迭代学习。考虑一组参数P = {w1, w2, w3, ..., wn}
,它们在教师模型中的作用是在检测到A类样本时激活。如果一个模糊的样本类似于A类但属于B类,模型在错误分类后会进行激进的梯度更新,导致不稳定。相比之下,蒸馏过程(使用教师模型的软目标)促进了训练过程中更稳定和连贯的梯度更新,从而使学生模型的学习过程更平滑。
3. 在无标签数据上训练的能力
教师模型使得学生模型能够在无标签数据上进行训练。教师模型可以为这些无标签样本生成伪标签,学生模型随后可以使用这些伪标签进行训练。这种方法显著增加了可用的训练数据量。
4. 视角转变
深度学习模型通常是基于这样的假设进行训练的:提供足够的数据将使它们能够近似一个准确表示底层现象的函数F
。然而,在许多情况下,数据稀缺使得这一假设不切实际。传统方法涉及构建更大的模型并迭代微调以实现最佳结果。相比之下,知识蒸馏改变了这种视角:鉴于我们已经有了一个训练良好的教师模型F
,目标是使用一个更小的模型f
来近似F
。