概念
将大的教师模型的知识萃取出来,浓缩到一个小的学生模型,教师网络一般比较臃肿,学生网络较小,可以实现轻量化应用
轻量化方式
- 模型压缩: 知识蒸馏,权值量化、权重剪枝、通道剪枝、注意力迁移
- 轻量化模型: SqueezeNet、MobileNetv1v2v3、MnasNet、SHhffleNet、Xception、EfficientNet、EfficieentDet
- 加速卷积: im2col+GEMM、Wiongrad、低秩分解
- 硬件部署: TensorRT、Jetson、TensorFlow-Slim、openvino、FPGA集成电路
作用
精度与时延
通过教师、学生模型的训练,可以得到精度更高、时延更低的模型
标签域迁移
训练多个不同域的teacher模型,可以同时蒸馏得到一个多个域的student模型。实现不同域的标签、数据集成和迁移
分类
基于目标蒸馏(也称为Soft-target蒸馏或Logits方法蒸馏) 和 基于特征蒸馏
目标蒸馏
Student学习Teacher模型的泛化能力,可以直接通过softmax层的输出类别概率来作为“Soft-target”
soft target
“软标签”指的是大网络在每一层卷积后输出的特征映射。
Hard-target 和 Soft-target区别
Hard-target 求的是ground truth的极大似然(正标签为1,其他为0)
Soft-target 求的是类别的概率,正标签概率最高
软标签的方法可以让模型学习到概率本身携带的信息,例如类别之间有相似程度,分类结果有相互关系等,可以让结果更精确
即方差小,可以很快学习到Teacher模型的推理过程,信息量大,样本更少,学习率更大,泛化能力好
因为只学习结果(logits),所以称为logits蒸馏
特征蒸馏
不学习结果,而是学习中间层的特征,要求将Teacher的特征级知识迁移给Student
升温与损失
蒸馏过程存在一个关键参数,温度参数(Temperature)
教师的预测输出除以温度参数之后,做softmax可以得到软标签,温度参数越大,分布越缓和,越小错误分类的概率就越容易放大,引入噪声,对于困难的分类预测,通常取1
损失设计为软目标与硬目标对应的交叉熵加权平均(KD loss / CE loss)
软目标交叉熵的加权系数越大,表明迁移诱导越依赖教师网络的贡献,这对训练初期阶段是很有必要的,有助于让学生网络更轻松的鉴别简单样本,但训练后期需要适当减小软目标的比重,让真实标注帮助鉴别困难样本。另外,教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。
- 当想从负标签中学到一些信息量的时候,温度应调高一些;
- 当想减少负标签的干扰的时候,温度 应调低一些;
在分类网络中知识蒸馏的 Loss 计算
上部分教师网络,它进行预测的时候, softmax要进行升温,升温后的预测结果我们称为软标签(soft label) 学生网络一个分支softmax的时候也进行升温,在预测的时候得到软预测(soft predictions),然后对soft label和soft predictions 计算损失函数,称为distillation loss ,让学生网络的预测结果接近教师网络; 学生网络的另一个分支,在softmax的时候不进行升温T =1,此时预测的结果叫做hard prediction 。然后和hard label也就是 ground truth直接计算损失,称为student loss 。 总的损失结合了distilation loss和student loss ,并通过系数a加权,来平衡这两种Loss ,比如与教师网络通过MSE损失,学生网络与ground truth通过cross entropy损失, Loss的公式可表示如下: