NBDT: Neural-Backed Decision Trees
简介
论文标题
- NBDT: Neural-Backed Decision Trees
- NBDT:神经支持决策树
- 2020.1
贡献
- 我们提出了一种将任何分类神经网络作为决策树运行的方法,方法是定义一组嵌入的决策规则,这些规则可以从完全连通层构造出来。我们还设计了易于神经网络学习的诱导层次结构。
- 我们提出了树监督损失,它使神经网络的准确率提高了0.5个百分点,并产生了高精度的NBDT。我们在小型、中型和大型图像分类数据集上证明了我们的NBDT达到了与神经网络相当的精度。
- 我们为我们的模型决策提供了语义解释的定性和定量证据。
该工具可以直接在以下地址在线使用:
- Demo:http://nbdt.alvinwan.com/demo/
- Colab:http://nbdt.alvinwan.com/notebook/
- git : https://github.com/alvinwan/neural-backed-decision-trees
- 论文:https://arxiv.org/abs/2004.00221
论文中先讲的推理后讲的训练
摘要
深度学习正被用于需要准确和合理的预测的环境中,从金融到医学成像。虽然最近有为模型预测提供事后解释的工作,但探索更直接可解释的模型以匹配最先进的准确性的工作相对较少。从历史上看,决策树一直是平衡可解释性和准确性的黄金标准。然而,最近将决策树与深度学习相结合的尝试导致了以下模型:(1)即使在较小的数据集(例如MNIST)上,实现的精度也远远低于现代神经网络(例如ResNet),以及(2)需要显著不同的体系结构,迫使实践者在准确性和可解释性之间做出选择。我们通过创建神经支持的决策树(NBDs)来摆脱这一困境,它(1)实现了神经网络的准确性,(2)不需要对神经网络的体系结构进行任何改变。使用最新的WideResNet,NBDT在CIFAR10、CIFAR100、TinyImageNet上的基本神经网络的精确度在1%以内;在ImageNet上,NBDT在EfficientNet上的精确度在2%以内。这在ImageNet上产生了最先进的可解释模型,NBDT将基准∼提高了14%到75.30%TOP-1准确率。此外,我们还通过半自动过程展示了我们模型决策的定性和定量可解释性。代码和预先培训的NBDT可以在github.com/alvinwan/neuralbacked-decision-trees.上找到。
初步
在这项工作中,我们提出了神经支持决策树(NBDT)来使最先进的计算机视觉模型具有可解释性。这些NBDT不需要特殊的体系结构:任何用于图像分类的神经网络都可以通过微调和自定义损失转换成NBDT。此外,NBDT通过将图像分类分解成中间决策序列来执行推理。然后,这个决策序列可以映射到更多可解释的概念,并在底层类中揭示可感知的信息层次结构。关键的是,与计算机视觉中关于决策树的先前工作相比,NBDT在CIFAR10[18]、CIFAR100[18]、TinyImageNet[19]和ImageNet[8]上的最新结果具有竞争力,并且比基于可比决策树的方法精确度大大提高(高达18%),同时也更易于解释。
我们介绍了一种针对NBDT的两阶段训练过程。首先,我们计算一个层次,称为诱导层次(图1,步骤1)。该层次结构是从已经在目标数据集上训练的神经网络的权重导出的。其次,我们使用专门为该树设计的自定义损失来微调网络,称为树监督损失(图1,步骤2)。这种损失迫使模型在给定固定的树层次结构的情况下最大化决策树的准确性。然后,我们分两步进行推理:(1)使用网络主干(图1,步骤3)为每幅训练图像构造特征.然后,对于每个节点,我们在给定决策树层次结构的情况下,在网络的权重空间中计算最能代表其子树中的叶子的向量-我们将该向量称为代表性向量。(2)从根节点开始,将每个样本发送给与样本具有最相似代表向量的子节点。我们继续采摘和遍历这棵树,直到我们到达一片树叶。与此叶相关的类是我们的预测(图1,步骤4)。这与引入可解释性障碍的相关工作形成对比,例如不纯树叶[16]或模型集合[1,16]
图1:神经支持诊断树。在步骤1中,我们使用预先训练的网络的完全连接层权重来构建层次结构(第3.2节)。在步骤2中,我们使用自定义损耗(秒)微调网络。3.3)。在步骤3中,我们使用神经网络主干对样本进行特征化。在步骤4中,我们使用完全连接层的权重来运行决策规则(SEC。3.1)。如上所示,步骤3中的橙色箭头与步骤4中的树的橙色节点相关联。同样,绿色箭头映射到绿色节点。该树获取传入样本与橙色w1和绿色w2矢量中的每一个之间的内积;预测具有较高内积的叶子。
如何构建树,如何训练树,如何选择树分类
相关工作
从决策树到神经网络。最近的工作还用决策树[13]提供的权重播种神经网络,重新引起了人们对基于梯度的方法[29]的兴趣。这些方法在UCI数据集[9]上显示了非常稀疏特征和稀疏样本的经验证据。
神经网络到决策树 最近的工作[10]使用蒸馏,训练决策树来模拟神经网络的输入输出函数。所有这些工作都是在简单的数据集(如UCI[9]或MNIST[20])上进行评估,而我们的方法是在更复杂的数据集(如CIF
AR10[18]、CIF AR100[18]和TinyImageNet[19])上进行评估。
将神经网络与决策树相结合。最近的工作是将神经网络与决策树相结合,将推理扩展到有很多高维样本数据集。深度神经决策森林[16]的性能与ImageNet上的神经网络相匹配。然而,这发生在剩余网络开始之前,通过使用不纯的树叶和需要森林来牺牲模型的可解释性。Murthy等人。[23]提出为决策树中的每个节点建立一个新的神经网络,并给出可解释的输出。艾哈迈德等人。1通过在所有节点之间共享主干来修改这一点,但仅支持深度-2树;NofE认为ImageNet的性能与ResNet之前的架构相媲美。我们的方法进一步建立在此基础上,不仅共享主干,而且共享完全连接层;此外,我们在保持可解释性的同时,还显示了与最先进的神经网络(包括残差网络)的竞争性能。
一些工作没有明确地将神经网络和决策树相结合,而是从决策树中借鉴了神经网络的思想,反之亦然。特别地,几种重新设计的神经网络结构利用决策树分支结构[35,21,34]。虽然精确度提高了,但这种方法牺牲了决策树的可解释性。其他人使用决策树来分析神经网络权重[39,24]。这会带来相反的后果,要么牺牲准确性,要么不支持预测机制。正如我们假设和展示的那样,高精度的决策树对于解释和解释高精度的模型是必要的。此外,我们具有竞争力的性能表明,不需要牺牲准确性和可解释性。
视觉解释。一个正交但占主导地位的可解释性方向包括生成显著图,该图突出神经网络决策所使用的空间证据[30,37,28,38,27,26,25,31]。诸如引导反向传播[30]、去卷积[37,28]、GradCAM[27]和积分梯度[31]之类的白盒技术使用网络的梯度来确定图像中最显著的区域,而诸如LIME[26]和RISE[25]之类的黑盒技术通过扰动输入并测量预测中的变化来确定像素重要性。显著图只解释单个图像,当网络出于错误的原因(例如,一只鸟被错误地归类为飞机)查看正确的东西时,它是没有帮助的。另一方面,我们的方法在整个数据集上表示模型的先验,并显式地将每个分类分解为一系列中间决策。
方法
在本节中,我们描述了将任何分类神经网络转换为决策树的建议步骤,如图1所示:
(1)建立诱导层次结构(SEC。3.2)、
(2)使用树监督损失(SEC)对模型进行微调。3.3)。
对于推理,(3)使用神经网络主干对样本进行特征化,
以及(4)运行嵌入在完全连接层(SEC)中的决策规则。3.1)。
使用嵌入式决策规则进行推理
首先,我们的NBDT方法使用神经网络主干对每个样本进行特征化;主干由最终完全连接层之前的所有神经网络层组成。
其次,在每个节点,我们取特征化样本x∈Rd与每个子节点的代表向量ri之间的内积。请注意,所有代表向量都是从神经网络的完全连通层权重计算出来的。因此,这些决策规则被“嵌入”在神经网络中。
第三,我们使用这些内积来做出硬决策或软决策,如下所述。
为了激励我们为什么使用内积,我们将首先构建一个等价于完全连接层的退化决策树
完全连接层
全连接层的权重矩阵为 。用特征化样运行推理是矩阵向量的乘积:
其中, ,最大的就是预测的
决策树
考虑一棵最小树,它有一个根节点和k个子节点。每个子节点是叶,并且每个子节点具有代表向量,即行向量 。用特征化样本x运行推断意味着取x和每个子节点的代表向量ri之间的内积,其被写为 。与全连接层一样,最大乘积 的指标也是我们的类预测。图2(B.)说明了这一点。
尽管这两个计算的表示方式不同,但都是通过取最大内积 的索引来预测类。我们将决策树推理称为运行嵌入式决策规则。
接下来,我们将朴素决策树扩展到退化情况之外。我们的判决规则要求每个子节点具有代表性向量ri。因此,如果我们将一个非叶子子代添加到根,那么这个非叶子子代将需要一个代表性向量。我们天真地认为非叶的代表向量是所有子树的叶的代表向量的平均值。对于包含中间节点的更复杂的树结构,现在有两种方式来运行推理:
- 硬决策树。计算所有子节点上每个节点的argmax。对于每个节点,获取与最大内积对应的子节点,并遍历该子节点。这个过程选择一片叶子(图2,A.硬)。
- 软决策树。在每个节点上计算所有子节点上的Softmax,以获得每个节点上每个子节点的概率。对于每个叶,获取从其父级遍历该叶的概率。然后取遍历树叶的父代与其祖辈的概率。继续直到到达根部。这个乘积是那片叶子及其到根部的路径的概率。树遍历将为每个叶生成一个概率。在这个叶子分布上计算argmax,以选择一个叶子(图2,C.Soft)。
这允许我们将任何分类神经网络作为嵌入的决策规则序列来运行。然而,以这种方式简单地运行标准问题的预先训练的神经网络将导致较差的精度。在下一节中,我们将讨论如何通过微调神经网络使其在确定层次结构后执行良好,从而最大限度地提高精度。
笔记:
主干部分: 完全连接层之前的所有神经网络层
完全连接层用于拆成决策树
每个节点都有对应的一行向量,其中
叶子节点的特征向量对应原有的权重矩阵中的一行向量
非叶子节对应其子树叶子的所有特征向量的平均值,
树的结构是通过层次聚类或者是wordnet预定义层次结构而来的
如何选择分支是通过取最大内积 的方式来的,称为嵌入式决策规则
主要是讲如何通过内积选择分支
构建诱导层次结构
使用上述内积决策规则,网络可以直观地更容易地学习决策树层次结构。这些更容易的层次结构可以更准确地反映网络是如何达到高精度的。为此,我们对从完全连接的层权重W提取的类代表W运行分层聚集聚类,如上一节所述。3.1,每个叶是一个Wi(图3,步骤B),并且每个中间节点的代表向量是其子树叶子的所有代表的平均值(图3,步骤C)。我们把这个层次称为诱导层次(图3)。
此外,我们还使用另一种基于WordNet的层次结构进行了实验。Wordnet[22]提供了一个现有的名词层次结构,我们利用它在语言上将每个数据集中的类联系起来。我们找到了WordNet层次结构的最小子集,其中包括所有类作为叶子,修剪冗余的叶子和单子中间节点。因此,WordNet关系为该候选决策树提供了“自由”和可解释的标签,例如将一只猫也归类为哺乳动物和生物。为了利用这个“自由”的标签源,我们通过找到每个子树叶子的最早祖先,为诱导层次结构中的每个中间节点自动生成假设。
图3:构建诱导层次结构。步骤A,将预先训练好的神经网络最终的全连通层的权值加载到权重矩阵 ;步骤B,以W为代表的每一列作为每个叶节点的代表向量。例如,A中的红色w1被指定给B中的红色叶子。步骤C使用每对叶子的平均值作为父代的代表向量。例如,B中的w1和w2(红色和紫色)平均为C中的w5(蓝色)。步骤D。对于每个祖先,取其根所在的子树。子树中所有树叶的平均表示向量。这个平均值是祖先的代表性矢量。在这个图中,祖先是根,所以它的代表向量是所有叶子w1、w2、w3、w4的平均值。
这幅图说明了如何构建一棵树,注意构建树之前一般需要预训练的权重
用树木监督损失进行训练
上面提出的所有决策树都有一个主要问题:即使鼓励原始神经网络为每一类分离代表向量,但它没有被训练为为每个内部节点分离代表向量。图4说明了这一点。为了解决这个问题,我们添加了损失项,鼓励神经网络在训练期间分离内部节点的代表。现在我们依次解释硬决策规则和软决策规则的附加损失条款(图5)。
图4:病理性诊断树。在地块中,一簇点用绿色圆圈标记,另一簇用黄色标记。每个圆的中心由它的两个灰点的平均值给出。在每个绘图的右侧绘制相应的决策树。答:一旦给出一个点,决策树的根将计算具有最接近代表向量(绿色或黄色的点)的子节点。请注意,类4(红色)的所有样本将比正确的父级(绿色)更接近错误的父级(黄色)。这是因为A试图用4聚类2,用3聚1。因此,神经网络很难获得高精度,因为它需要大幅移动所有的点来区分黄色和绿色的点。B:对于相同的点,这棵树将1与2聚为一簇,而将3与4聚为一簇,从而产生更多可分离的簇。请注意,B中的决策边界(虚线)相对于绿点和黄点的边距要大得多。因此,对于神经网络来说,这棵树更容易对点进行正确分类。
图A:红色点是绿色类,但他更接近黄色点的圆心而不是绿色
图B:这种情况下划分好,不容易出错
由于直接使用平均值作为特征向量有误差,所以这里添加了损失函数
对于硬决策规则,我们使用硬树监督损失。原始神经网络的损失
最小化了跨类的交叉熵。对于k类数据集,这是k路交叉熵损失。每个内部节点的目标是相似的,最小化跨子节点的交叉熵损失。对于具有c个子节点的节点i,这是预测概率
和标签
之间的c路交叉熵损失。我们将这组新的损失术语称为硬树监督损失(等式2)。默认情况下,每个节点的单个交叉熵损失被缩放,使得原始交叉熵损失和树监督损失被相等地加权。我们在SEC中测试了各种权重方案。4.2.。如果我们假设树中有N个节点(不包括树叶),那么我们将有N+1个不同的交叉熵损失项-原始的交叉熵损失项和N个硬树监督损失项。这是
,其中:
对于软决策规则,我们使用软树监督损失。在3.1节中,我们描述了软决策树如何在树叶上提供单一分布Dpred。我们在这个分布上增加了交叉熵损失。总共有2个不同的交叉熵损失项-原始交叉熵损失项和软树监督损失项。这是
,其中:
硬决策树因为每个节点都是一个概率分布,所以每个节点都算一次交叉熵
软决策树因为整体服从一个概率分布,所以只在叶子上算一次交叉熵
图5:树监督损失有两种变体:硬树监督损失(A)定义了每个节点的交叉熵项。蓝色节点的蓝色框和橙色节点的橙色框说明了这一点。交叉熵被取而代之的是子节点概率。绿色节点是标签叶。虚线节点不包括在从标签到根的路径中,因此没有定义的损耗。软树监督损失(B)定义了所有叶概率上的交叉熵损失。绿叶的概率是通向根部的概率的乘积(在本例中, )。其他树叶的概率被类似地定义。每个叶概率用一个彩色方框表示。然后在该叶概率分布上计算交叉熵,该分布由坐在彼此直接相邻的彩色框表示。
实验结果分析
精度表格略
节点语义的可解释性
由于诱导层次是使用模型权重构建的,因此不会强制对特定属性进行拆分。虽然像WordNet这样的层次结构为节点的含义提供了假设,但图6显示WordNet是不够的,因为树可能会在上下文属性(如水下和陆地)上分裂。为了诊断节点含义,我们执行以下4步测试:
图6:使用(A)WordNet层次结构和(B)来自训练有素的ResNet10模型的诱导树,对TinyImageNet中的10个类进行树可视化。
- 假设节点的含义(例如,动物与车辆)。这一假设可以从给定的分类法(如WordNet)自动计算出来,也可以从手动检查每个孩子的叶子中推导出来(图7)。
- 收集一个数据集,其中包含测试步骤1中节点的假设含义的新的、看不见的类(例如,Elephant是一种看不见的动物)。此数据集中的样本称为分布外样本,因为它们是从单独标记的数据集中提取的。
- 将此数据集中的样本传递给相关节点。对于每个样本,检查所选子节点是否与假设一致。
- 假设的准确性是传递给正确孩子的样本的百分比。如果精确度较低,请使用不同的假设重复。
这个过程自动验证WordNet假设,但WordNet之外的假设需要人工干预。图7a描述了由在CIFAR10上训练的WideResNet28x10模型诱导的CIFAR10树。我们的假设是,根节点在Animal和Vehicle上分裂。我们从CIF
AR100收集在培训时间看不到的动物和车辆类的分发外图像。然后我们计算假设的准确性。图7b显示了我们的假设准确地预测了每个看不见的类的样本遍历的是哪个子类。
注意,诱导出来的树形结构需要检查或自动推导每个非叶子节点的含义
避免准确性与可解释性的权衡
入的分层结构在权重空间中对群集向量进行分类,但是在权重空间中接近的类可能不具有类似的语义含义:图8分别描绘了由WideResNet20x10和ResNet10诱导的树。虽然WideResNet诱导的层次结构(图8A)对语义相似的类进行分组,但ResNet(图8B)诱导的层次结构不是这样,将青蛙、猫和飞机等类分组。WideResNet的准确率提高了4%,这解释了语义意义上的差异:我们认为,准确度越高的模型在语义上表现出更多的声音权重空间。因此,与以前的工作不同的是,NBDT的特点是更好的可解释性和更高的准确性,而不是牺牲一个来换取另一个。此外,层次结构中的差异表明,低精度、可解释的模型不能提供对高精度决策的洞察力;需要可解释的、最先进的模型来解释最先进的神经网络。
左边的精度高,解释性也好,用来证明精度和解释性是可以相辅相成的
值得注意的是,在具有 10 个类(如 CIFAR10)的小型数据集中,研究者可以找到所有节点的 WordNet 假设。但是,在具有 1000 个类别的大型数据集(即 ImageNet)中,则只能找到节点子集中的 WordNet 假设。
小型分类上可以用wordnet去解释节点,但大型分类网络则不能全部都找到wordnet上的解释
树遍历的可视化实现
为了不仅解释树层次结构,也解释树遍历,我们可视化了通过每个节点的样本的百分比(图9)。这既突出显示了正确的路径(最频繁遍历的路径),又允许我们解释常见的错误路径(图9A)。具体地说,我们可以解释遍历节点的叶子之间共享的属性。这些属性可以是背景或场景,但也可以是颜色或形状。图9B描绘了描述上下文的样本的路径。在这种情况下,很少有动物在海滨环境中被认出,而船只几乎总是在那个环境中被看到。图9C描述了属于不符合假设节点的属性但保持路径一致性的非分布类的样本的路径。在这种情况下,泰迪倾向于动物类,特别是狗,因为它有相似的形状和视觉特征。
图9:三个不同类的路径遍历频率的可视化。(A)分配类:马使用在训练中发现的类样本。中间节点的假设含义来自WordNet。(B)背景类:海边使用训练时看不到的样本,表明对背景的依赖。©混淆类:Teddy使用在节点含义中识别边缘情况的样本。
况下,泰迪倾向于动物类,特别是狗,因为它有相似的形状和视觉特征。
[外链图片转存中…(img-zbTn0SCl-1588410850893)]
图9:三个不同类的路径遍历频率的可视化。(A)分配类:马使用在训练中发现的类样本。中间节点的假设含义来自WordNet。(B)背景类:海边使用训练时看不到的样本,表明对背景的依赖。©混淆类:Teddy使用在节点含义中识别边缘情况的样本。
https://mp.weixin.qq.com/s/WrfLMXfgFbk_SaMvy2pweg
转载:https://blog.csdn.net/qq_33935895/article/details/105892655