论文解读:在神经网络中提取知识(知识蒸馏)

摘要

提高几乎所有机器学习算法性能的一种非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是,使用整个模型集合进行预测是很麻烦的,并且可能在计算上过于昂贵,无法部署到大量用户,特别是如果单个模型是大型神经网络。Caruana和他的合作者[1]已经证明,可以将集成中的知识压缩到一个更容易部署的单一模型中,并且我们使用不同的压缩技术进一步开发了这种方法。我们在MNIST上取得了一些令人惊讶的结果,并且我们表明,通过将模型集合中的知识提取到单个模型中,我们可以显着改善大量使用的商业系统的声学模型。我们还介绍了一种由一个或多个完整模型和许多专业模型组成的新型集成,这些模型学习区分完整模型所混淆的细粒度类。与混合专家不同,这些专家模型可以快速并行地训练。

介绍

许多昆虫都有一个幼虫形态,它最擅长从环境中获取能量和营养,而一个完全不同的成虫形态,它最擅长于旅行和繁殖的不同需求。在大规模机器学习中,我们通常在训练阶段和部署阶段使用非常相似的模型,尽管它们的要求非常不同:对于语音和对象识别等任务,训练必须从非常大的、高度冗余的数据集中提取结构,但它不需要实时操作,它可以使用大量的计算量。

然而,部署到大量用户时,对延迟和计算资源的要求要严格得多。与昆虫的类比表明,如果能更容易地从数据中提取结构,我们应该愿意训练非常繁琐的模型。繁琐的模型可以是单独训练的模型的集合,也可以是使用dropout等非常强的正则化器训练的单个非常大的模型[9]。一旦繁琐的模型得到训练,我们就可以使用另一种训练,我们称之为“蒸馏”,将知识从繁琐的模型转移到更适合部署的小模型中。这种策略的一个版本已经由Rich Caruana和他的合作者开创[1]。在他们的重要论文中,他们令人信服地证明了由大量模型集合获得的知识可以转移到单个小模型中。

一个概念上的障碍可能阻碍了对这种非常有前途的方法进行更多的研究,那就是我们倾向于用学习到的参数值来识别训练模型中的知识,这使得我们很难看到如何在保持相同知识的情况下改变模型的形式。一种更抽象的知识观点,将其从任何特定实例中解放出来,即它是一种可学习的从输入向量到输出向量的映射。对于学习区分大量类别的繁琐模型,正常的训练目标是最大化正确答案的平均对数概率,但学习的副作用是训练模型为所有不正确答案分配概率,即使这些概率非常小,其中一些概率也比其他概率大得多。错误答案的相对概率告诉我们很多关于这个繁琐的模型是如何泛化的。例如,宝马的图像可能只有很小的机会被误认为是垃圾车,但这种错误仍然比将其误认为胡萝卜的可能性高很多倍。

人们普遍认为,用于训练的目标函数应该尽可能地反映用户的真实目标。尽管如此,当真正的目标是很好地泛化到新数据时,通常训练模型来优化训练数据的性能。显然,训练模型更好地泛化,但这需要关于正确泛化方法的信息,而这些信息通常是不可用的。然而,当我们将知识从一个大模型提炼成一个小模型时,我们可以训练小模型以与大模型相同的方式进行泛化。如果繁琐的模型泛化得很好,例如,它是不同模型的大集合的平均值,那么以相同方式训练泛化的小模型通常会比在用于训练集合的相同训练集上以正常方式训练的小模型在测试数据上做得更好。

将繁琐模型的泛化能力转移到小模型上的一个显而易见的方法是将繁琐模型产生的类概率作为训练小模型的“软目标”。对于这个迁移阶段,我们可以使用相同的训练集或单独的“迁移”集

当繁琐的模型是由许多简单模型组成的大集合时,我们可以使用单个预测分布的算术或几何平均值作为软目标。当软目标具有高熵时,每个训练案例提供的信息量比硬目标大得多,训练案例之间的梯度方差也小得多,因此小模型通常可以在比原始繁琐模型少得多的数据上进行训练,并使用更高的学习率。

对于像MNIST这样的任务,繁琐的模型几乎总是产生非常高置信度的正确答案,关于学习函数的大部分信息存在于软目标中非常小的概率比率中。例如,一个版本的2可能有10^{-6}的概率是3,
10^{-9}的概率是7,而另一个版本可能是相反的。这是有价值的信息,它定义了数据的丰富相似性结构(即,它表示哪些2看起来像3,哪些看起来像7),但它在传递阶段对交叉熵成本函数的影响很小,因为概率非常接近于零。

Caruana和他的合作者通过使用logits(最终softmax的输入)而不是softmax产生的概率作为学习小模型的目标来规避这个问题,他们最小化了繁琐模型产生的logits和小模型产生的logits之间的平方差。我们更通用的解决方案,称为“蒸馏”,是提高最终软最大值的温度,直到繁琐的模型产生合适的软目标集。然后,我们在训练小模型时使用相同的高温来匹配这些软目标。稍后我们将说明,匹配繁琐模型的对数实际上是蒸馏的一种特殊情况。

用于训练小模型的转移集可以完全由未标记的数据组成[1],或者我们可以使用原始训练集。我们发现,使用原始的训练集效果很好,特别是如果我们在目标函数中添加一个小项,可以鼓励小模型预测真实目标,并匹配繁琐模型提供的软目标。

通常,小模型不能完全匹配软目标,在正确答案的方向上犯错误是有帮助的。

蒸馏

神经网络通常通过使用“softmax”输出层来产生类概率,该输出层通过将z_i与其他logit进行比较,将为每个类计算的logitz_i转换为概率q_i

T是温度,通常设为1。使用更高的T值会产生更柔和的类概率分布。

在最简单的蒸馏形式中,通过在转移集上训练知识,并对转移集中的每个情况使用软目标分布,将知识转移到蒸馏模型中,该转移集中使用具有高温软最大值的繁琐模型产生的软目标分布。在训练蒸馏模型时使用相同的高温,但在训练完成后,它使用的温度为1。

当所有或部分转移集的正确标签已知时,可以通过训练蒸馏模型来产生正确的标签来显著改进该方法。一种方法是使用正确的标签来修改软目标,但我们发现更好的方法是简单地使用两个不同目标函数的加权平均值。第一个目标函数是与软目标的交叉熵,该交叉熵是在蒸馏模型的软最大值中使用与从繁琐模型生成软目标相同的高温来计算的。第二个目标函数是带有正确标签的交叉熵。这是在蒸馏模型的softmax中使用完全相同的logits计算的,但温度为1。我们发现,在第二个目标函数上使用相当低的权重通常可以获得最佳结果。由于软目标产生的梯度大小为\frac{1}{T^2},因此在使用硬目标和软目标时,将它们乘以T^2是很重要的。这确保了在使用元参数进行实验时,如果用于蒸馏的温度发生变化,则硬目标和软目标的相对贡献大致保持不变。

匹配逻辑是蒸馏的一种特殊情况

传递集中的每一种情况都相对于蒸馏模型的每一个logitz_i贡献了一个交叉熵梯度\frac{dC}{dz_i}。如果繁琐模型的logits v_i产生软目标概率p_i,并且迁移训练在温度T下进行,则该梯度为:

如果温度比对数的大小高,我们可以近似:

如果我们现在假设每个转移情况的对数分别为零,那么\sum_j{z_j}=\sum_j{v_j}=0, Eq. 3化简为:

因此,在高温极限下,蒸馏相当于最小化1/2(z_i-v_i)^2,前提是每个转移情况的对数分别为零。在较低的温度下,蒸馏很少注意匹配比平均值负得多的对数。这是潜在的优势,因为这些逻辑几乎完全不受用于训练繁琐模型的成本函数的约束,因此它们可能非常嘈杂。另一方面,非常负的对数可以传达关于繁琐模型所获得的知识的有用信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获繁琐模型中的所有知识时,中间温度效果最好,这强烈表明忽略大的负对数可能是有帮助的。

MNIST的初步实验

为了了解蒸馏的效果如何,我们在所有60,000个训练案例上训练了一个具有两个隐藏层的大型神经网络,包含1200个整流线性隐藏单元。使用dropout和权重约束对网络进行强正则化,如[5]所述。Dropout可以看作是一种训练共享权重的指数级大模型集合的方法。此外,输入图像为在任何方向上抖动最多两个像素。该网络实现了67个测试误差,而具有两个隐藏层(800个校正线性隐藏单元)且未进行正则化的较小网络实现了146个误差。但是,如果仅仅通过增加在20℃温度下匹配大网产生的软目标的额外任务来正则化较小的网,它将获得74个测试误差。这表明软目标可以将大量知识转移到蒸馏模型中,包括从翻译的训练数据中学习到的关于如何泛化的知识,即使转移集不包含任何翻译。

当蒸馏网的两个隐藏层中每层都有300个或更多的单位时,所有温度高于8度的结果都相当相似。但是,当从根本上减少到每层30个单位时,温度范围在2.5到4之间的效果明显优于更高或更低的温度。

然后,我们尝试从转移集中删除数字3的所有示例。因此,从蒸馏模型的角度来看,3是一个从未见过的神话数字。尽管如此,蒸馏模型只产生206个测试错误,其中133个是在测试集中的1010个3上。大多数错误是由于3类的学习偏差太低造成的。如果这个偏差增加3.5(优化测试集上的整体性能),则蒸馏模型产生109个错误,其中14个是3。因此,在正确的偏差下,尽管在训练中从未看到过3,但蒸馏模型的测试3的正确率为98.6%。如果迁移集只包含训练集中的7和8,则蒸馏模型的测试误差为47.3%,但当将7和8的偏差减少7.6以优化测试性能时,测试误差降至13.2%。

语音识别实验

在本节中,我们研究了用于自动语音识别(ASR)的集成深度神经网络(DNN)声学模型的效果。我们表明,我们在本文中提出的蒸馏策略达到了将模型集合提炼成单个模型的预期效果,该模型的工作效果明显优于直接从相同训练数据中学习的相同大小的模型。

最先进的ASR系统目前使用dnn将从波形中导出的特征(短)时间上下文映射到隐马尔可夫模型(HMM)离散状态上的概率分布[4]。更具体地说,DNN每次在三部手机状态的集群上产生一个概率分布,然后解码器在HMM状态中找到一条路径,这是使用高概率状态和产生在语言模型下可能的转录之间的最佳折衷。

虽然有可能(也是可取的)以这样一种方式训练DNN,即通过边缘化所有可能的路径来考虑解码器(以及语言模型),但通常训练DNN通过(局部)最小化网络预测与每个观测值的基本真实状态序列强制对齐所给出的标签之间的交叉熵来执行逐帧分类:

其中θ是声学模型P的参数,该模型将时间t,s_t的声学观测映射到概率P(h_t|s_t;\theta'),“正确”HMM状态h_t,这是通过与正确的单词序列强制对齐来确定的。采用分布式随机梯度下降法对模型进行训练。

我们使用具有8个隐藏层的架构,每个隐藏层包含2560个整流线性单元和一个具有14,000个标签(HMM目标h_t)的最终softmax层。输入是26帧40 mel比例滤波器组系数,每帧提前10ms,我们预测了第21帧的HMM状态。参数总数约为85M。这是Android语音搜索使用的声学模型的一个稍微过时的版本,应该被认为是一个非常强大的基线。为了训练DNN声学模型,我们使用了大约2000小时的英语口语数据,这产生了大约700M个训练样本。该系统在我们的开发集上实现了58.P(h_t|s_t;\theta')9%的帧准确率和10.9%的单词错误率。

表1:框架分类精度和WER,显示蒸馏的单个模型的性能与用于创建软目标的10个模型的平均预测一样好。

结果

我们训练了10个独立的模型来预测P(h_t|s_t;\theta'),使用完全相同的架构和训练程序作为基线。用不同的初始参数值随机初始化模型,我们发现这在训练模型中产生了足够的多样性,使得集合的平均预测明显优于单个模型。我们已经探索了通过改变每个模型看到的数据集来增加模型的多样性,但是我们发现这不会显著改变我们的结果,所以我们选择了更简单的方法。对于蒸馏,我们尝试了[1,2,5,10]的温度,并对硬目标的交叉熵使用了0.5的相对权重,其中粗体表示表1中使用的最佳值。

表1显示,实际上,我们的蒸馏方法能够从训练集中提取更多有用的信息,而不是简单地使用硬标签来训练单个模型。使用10个模型的集合所获得的帧分类精度提高的80%以上被转移到蒸馏模型上,这与我们在MNIST上的初步实验中观察到的改进相似。由于目标函数不匹配,集成对WER的最终目标(在23k个单词的测试集上)给出了较小的改进,但同样,集成实现的WER改进被转移到蒸馏模型上。

我们最近意识到通过匹配已经训练好的大型模型的类概率来学习小型声学模型的相关工作[8]。然而,他们使用大型未标记数据集在温度为1的情况下进行蒸馏,他们的最佳蒸馏模型仅将小模型的错误率降低了28%,当它们都使用硬标签训练时,大小模型的错误率之间的差距。

在非常大的数据集上训练专家团队

训练模型集合是利用并行计算的一种非常简单的方法,并且通常认为在测试时集合需要太多计算的反对意见可以通过使用蒸馏来处理。然而,对于集成还有另一个重要的反对意见:如果单个模型是大型神经网络,并且数据集非常大,那么即使很容易并行化,训练时所需的计算量也会过多。

在本节中,我们给出了这样一个数据集的示例,并展示了如何学习专家模型(每个模型专注于类的不同可混淆子集)可以减少学习集成所需的总计算量。专注于细粒度区分的专家的主要问题是他们很容易过拟合,我们描述了如何通过使用软目标来防止这种过拟合。

JFT数据集

JFT是一个内部的谷歌数据集,它有1亿个带有15000个标签的带标签的图像。当我们做这项工作时,Google的JFT基线模型是一个深度卷积神经网络[7],它已经在大量核心上使用异步随机梯度下降进行了大约六个月的训练。该训练使用了两种类型的并行性[2]。首先,神经网络的许多副本运行在不同的核心集上,处理来自训练集的不同小批量。每个副本计算当前mini-batch上的平均梯度,并将该梯度发送给分片参数服务器,该服务器将返回参数的新值。这些新值反映了自上次向副本发送参数以来参数服务器接收到的所有梯度。其次,通过在每个核心上放置不同的神经元子集,每个副本分布在多个核心上。集成训练是第三种可以包装的并行性其他两种类型,但前提是有更多的内核可用。等待几年时间来训练一个模型集合是不可能的,所以我们需要一种更快的方法来改进基线模型。

表2:由我们的协方差矩阵聚类算法计算的聚类的示例类

专家模式

当类的数量非常大时,将麻烦的模型作为一个集成是有意义的,它包含一个在所有数据上训练的通才模型和许多“专家”模型,每个模型都是在高度丰富的数据上训练的,这些数据来自一个非常容易混淆的类子集(如不同类型的蘑菇)。通过将所有它不关心的类组合到一个垃圾箱类中,可以使这种类型的专家的softmax更小。

为了减少过拟合并分担学习低级特征检测器的工作,每个专家模型都使用通才模型的权重初始化。然后,通过训练专家,将其一半的样本来自其特殊子集,另一半从训练集的剩余部分随机抽样,对这些权重进行稍微修改。训练后,我们可以通过将垃圾箱类的logit增加到专家类的过采样比例的log来纠正有偏差的训练集。

为专家分配课程

为了为专家导出对象类别的分组,我们决定将重点放在我们整个网络经常混淆的类别上。尽管我们可以计算混淆矩阵并将其用作查找此类聚类的方法,但我们选择了一种更简单的方法,该方法不需要真实标签来构建聚类。

特别是,我们将聚类算法应用于通才模型预测的协方差矩阵,以便经常一起预测的一组类S^m

将用作我们的一个专家模型m的目标。我们对协方差矩阵的列应用了K-means算法的在线版本,并获得了合理的聚类(见表2)。我们尝试了几种聚类算法,产生了类似的结果。

与专家团队一起进行推理

在研究当专家模型被提取时会发生什么之前,我们想看看包含专家的集成执行得有多好。除了专家模型,我们总是有一个通才模型,这样我们就可以处理没有专家的类,这样我们就可以决定使用哪个专家。给定一个输入图像x,我们分两步进行top-1分类:

步骤1:对于每个测试用例,我们根据通才模型找到n个最可能的类。

称这组类为k。在我们的实验中,我们使用n = 1。

步骤2:然后我们取所有的专家模型m,其可混淆类的特殊子集S m与k有一个非空相交,并将其称为专家的活动集A_k(注意该集可能是空的)。然后求出所有类的完整概率分布q。

式中KL表示KL散度,p^mp^g表示专家模型或通才全模型的概率分布。分布p^m是m的所有专业类加上一个垃圾桶类的分布,所以当计算它的KL散度时从满q分布我们把满q分布分配给m的垃圾桶中所有类的所有概率相加。

表4:JFT测试集上覆盖正确类别的专家模型的准确率提高排名前1。

Eq. 5没有一般的封闭形式解,尽管当所有模型为每个类别产生单个概率时,解是算术或几何平均值,这取决于我们是使用KL(p, q)还是KL(q, p))。我们参数化q = softmax (z)(其中T = 1),并使用梯度下降来优化logits z w.r.t. eq. 5。请注意,必须对每个图像执行此优化。

结果

从经过训练的基线完整网络开始,专家们的训练速度非常快(几天而不是JFT的几周)。此外,所有的专家都是完全独立训练的。表3显示了基线系统和结合专家模型的基线系统的绝对测试精度。有了61个专业模型,总体测试精度相对提高了4.4%。我们还报告了条件测试精度,这是只考虑属于专家类的示例的精度,并将我们的预测限制在类的子集上。

对于JFT专家实验,我们训练了61个专家模型,每个模型有300个类(加上垃圾箱类)。因为专家的类集合不是不相交的,我们经常有多个专家覆盖一个特定的图像类。表4显示了测试集示例的数量,使用专家时在位置1正确的示例数量的变化,以及JFT数据集的top1精度的相对改进百分比,这些百分比按涵盖类的专家数量进行细分。由于训练独立的专家模型非常容易并行化,所以当我们有更多的专家覆盖一个特定的类时,准确度的提高会更大,这一普遍趋势让我们感到鼓舞。

作为正则化器的软目标

我们关于使用软目标而不是硬目标的一个主要主张是,软目标可以携带许多有用的信息,这些信息不可能被单个硬目标编码。在本节中,我们通过使用更少的数据来拟合前面描述的基线语音模型的85M参数来证明这是一个非常大的影响。表5显示,仅使用3%的数据(约20M例),使用硬目标训练基线模型会导致严重的过拟合(我们提前停止,因为准确率在达到44.5%后急剧下降),而使用软目标训练的相同模型能够恢复整个训练集中几乎所有的信息(约2%)。更值得注意的是,我们不需要提前停止:具有软目标的系统简单地“收敛”到57%。这表明,软目标是一种非常有效的方式,可以将在所有数据上训练的模型发现的规律传递给另一个模型。

表5:软目标允许新模型仅从训练集的3%进行良好的泛化。通过对完整训练集的训练得到软目标。

使用软目标来防止专家过度拟合

我们在JFT数据集上的实验中使用的专家将他们所有的非专业类折叠到一个垃圾箱类中。如果我们允许专家对所有类别都有一个完整的软最大值,可能会有比使用早期停止更好的方法来防止他们过拟合。专家接受的是在其特殊课程中高度丰富的数据培训。这意味着它的训练集的有效大小要小得多,并且在它的特殊类上有很强的过拟合倾向。这个问题不能通过把专家类做得更小来解决,因为那样我们就失去了从对所有非专业类建模中获得的非常有用的转移效应。

我们使用3%的语音数据进行的实验强烈表明,如果用通才的权重初始化一个专家,除了用硬目标训练它之外,我们还可以用软目标训练它,使它保留几乎所有关于非特殊类的知识。软目标可以由通才提供。我们目前正在探索这种方法。

与专家混合的关系

使用在数据子集上训练的专家与使用门控网络来计算将每个示例分配给每个专家的概率的专家混合[6]有一些相似之处。在专家学习处理分配给他们的示例的同时,门控网络正在学习根据专家对该示例的相对判别性能选择将每个示例分配给哪些专家。使用专家的判别性能来确定学习分配比简单地将输入向量聚类并为每个聚类分配专家要好得多,但它使训练难以并行化:首先,每个专家的加权训练集以一种依赖于所有其他专家的方式不断变化;其次,门控网络需要比较不同专家在同一示例上的表现,以知道如何修改其分配概率。这些困难意味着专家组合很少被用于可能最有益的领域:包含明显不同子集的庞大数据集的任务。

对多个专家进行并行培训要容易得多。我们首先训练一个通才模型,然后使用混淆矩阵来定义训练专家的子集。一旦定义了这些子集,专家就可以完全独立地进行训练。在测试时,我们可以使用通才模型的预测来决定哪些专家是相关的,只有这些专家需要运行。

讨论

我们已经证明,对于将知识从集成或从大型高度正则化模型转移到较小的蒸馏模型,蒸馏工作非常有效。在MNIST上,即使用于训练蒸馏模型的转移集缺乏一个或多个类的任何示例,蒸馏也能非常好地工作。对于一个深度声学模型(Android语音搜索使用的模型),我们已经证明,几乎所有通过训练深度神经网络集合实现的改进都可以被提炼成一个相同大小的神经网络,这更容易部署。

对于真正的大神经网络,甚至训练一个完整的集合都是不可行的,但我们已经证明,通过学习大量的专业网络,可以显著提高训练了很长时间的单个真正的大网络的性能,每个专业网络都学会了在高度混淆的集群中区分类别。我们还没有证明,我们可以把专家们的知识提炼成一张大网。

论文下载(NIPS 2014)

https://arxiv.org/abs/1503.02531

📎Distilling the Knowledge in a Neural Network.pdf

相关推荐

  1. 神经网络的先验知识

    2024-02-18 07:28:02       21 阅读
  2. yolov5知识蒸馏

    2024-02-18 07:28:02       34 阅读
  3. 简单的知识蒸馏

    2024-02-18 07:28:02       12 阅读
  4. 知识蒸馏——讨论区

    2024-02-18 07:28:02       7 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-02-18 07:28:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-02-18 07:28:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-02-18 07:28:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-02-18 07:28:02       20 阅读

热门阅读

  1. redis的hash数据结构底层简记

    2024-02-18 07:28:02       30 阅读
  2. 机器人学环境配置(VM-16 + Ubuntu-20.04 + ROS-noetic)

    2024-02-18 07:28:02       35 阅读
  3. Unity-HDRP-Sense-4

    2024-02-18 07:28:02       28 阅读
  4. 力扣代码学习日记三

    2024-02-18 07:28:02       31 阅读
  5. 配置Vite+React+TS项目

    2024-02-18 07:28:02       26 阅读
  6. Docker 第十五章 : Docker 三剑客之 Compose(一)

    2024-02-18 07:28:02       26 阅读