李宏毅机器学习课程笔记-13.2模型压缩之知识蒸馏

知识蒸馏就是Knowledge Distillation。

Knowledge Distillation:https://arxiv.org/abs/1503.02531

Do Deep Nets Really Need to be Deep?:https://arxiv.org/abs/1312.6184

熟悉YOLO的读者,可以根据这个仓库感受一下剪枝和知识蒸馏:https://github.com/tanluren/yolov3-channel-and-layer-pruning

Student and Teacher

什么是Knowledge Distillation?

我们可以让一个较小的Student Net向较大的Teacher Net学习,使得Student Net的输出尽可能接近Teacher Net的输出。

普通的训练方式为仅在数据集上训练Student Net,而Knowledge Distillation的思路是:即使Teacher Net的输出并不一定是正确的,但Teacher Net可以提供一些数据集无法提供的信息,比如手写数字图片分类模型Teacher Net的输出为“1:0.7,7:0.2,9:0.1”,这不仅说明这张图片像1,还可以说明1和7、9很相似

或者可以这么理解:学生自己直接做题目太难了,让学生学习下老师是怎么想的可能会更好。

Ensemble

Knowledge Distillation有什么用呢?

打Kaggle比赛时很多人的做法是ensemble(将多个model的结果进行平均)。Ensemble通常可以得到更好的精度,但现实中设备上不可能放这么多个model,这时就可以利用Knowledge Distillation让Student Net向Teacher Net学习,最终设备上只运行Student Net就可以。

Temperature

在分类任务中,网络的最后一般有个softmax函数:$y_i=\frac{e^{x_i}}{\sum_je^{x^j}}$,其中$y_i$是输入属于类别$i$的置信度。

在Knowledge Distillation中,我们需要对softmax函数进行调整:$y_i=\frac{e^{\frac{x_i}{T}}}{\sum_je^{\frac{x^j}{T}}}$,其中$T$为Temperature,一般是一个大于1的数,它的作用是使得Teacher Net输出的属于各个类别的置信度更加接近,如下图所示。

img


Github(github.com):@chouxianyu

Github Pages(github.io):@臭咸鱼

知乎(zhihu.com):@臭咸鱼

博客园(cnblogs.com):@臭咸鱼

B站(bilibili.com):@绝版臭咸鱼

微信公众号:@臭咸鱼

转载请注明出处,欢迎讨论和交流!