概念

将大的教师模型的知识萃取出来,浓缩到一个小的学生模型,教师网络一般比较臃肿,学生网络较小,可以实现轻量化应用

轻量化方式

  1. 模型压缩: 知识蒸馏,权值量化、权重剪枝、通道剪枝、注意力迁移
  2. 轻量化模型: SqueezeNet、MobileNetv1v2v3、MnasNet、SHhffleNet、Xception、EfficientNet、EfficieentDet
  3. 加速卷积: im2col+GEMM、Wiongrad、低秩分解
  4. 硬件部署: TensorRT、Jetson、TensorFlow-Slim、openvino、FPGA集成电路

作用

精度与时延

通过教师、学生模型的训练,可以得到精度更高、时延更低的模型

标签域迁移

训练多个不同域的teacher模型,可以同时蒸馏得到一个多个域的student模型。实现不同域的标签、数据集成和迁移

分类

基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏) 和 基于特征蒸馏

目标蒸馏

Student学习Teacher模型的泛化能力,可以直接通过softmax层的输出类别概率来作为“Soft-target”

soft target“软标签”指的是大网络在每一层卷积后输出的特征映射。

Hard-target 和 Soft-target区别

Hard-target 求的是ground truth的极大似然(正标签为1,其他为0)

Soft-target 求的是类别的概率,正标签概率最高

软标签的方法可以让模型学习到概率本身携带的信息,例如类别之间有相似程度,分类结果有相互关系等,可以让结果更精确

即方差小,可以很快学习到Teacher模型的推理过程,信息量大,样本更少,学习率更大,泛化能力好

因为只学习结果(logits),所以称为logits蒸馏

特征蒸馏

不学习结果,而是学习中间层的特征,要求将Teacher的特征级知识迁移给Student

升温与损失

蒸馏过程存在一个关键参数,温度参数(Temperature)

教师的预测输出除以温度参数之后,做softmax可以得到软标签,温度参数越大,分布越缓和,越小错误分类的概率就越容易放大,引入噪声,对于困难的分类预测,通常取1

损失设计为软目标与硬目标对应的交叉熵加权平均(KD loss / CE loss)

软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。

  • 当想从负标签中学到一些信息量的时候,温度应调高一些;
  • 当想减少负标签的干扰的时候,温度 应调低一些;

在分类网络中知识蒸馏的 Loss 计算

上部分教师网络,它进行预测的时候, softmax要进行升温,升温后的预测结果我们称为软标签(soft label) 学生网络一个分支softmax的时候也进行升温,在预测的时候得到软预测(soft predictions),然后对soft label和soft predictions 计算损失函数,称为distillation loss ,让学生网络的预测结果接近教师网络; 学生网络的另一个分支,在softmax的时候不进行升温T =1,此时预测的结果叫做hard prediction 。然后和hard label也就是 ground truth直接计算损失,称为student loss 。 总的损失结合了distilation loss和student loss ,并通过系数a加权,来平衡这两种Loss ,比如与教师网络通过MSE损失,学生网络与ground truth通过cross entropy损失, Loss的公式可表示如下:

参考自:一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理_知识蒸馏算法-CSDN博客