小言_互联网的博客

最优传输论文(九):Multi-source Domain Adaptation via WJDOT论文原理

472人阅读  评论(0)

前言

  • 本文属于我最优传输系列里的九篇,该专栏用于记录本人研究生阶段相关最优传输论文的原理阐述以及复现工作。
  • 本专栏的文章主要内容为解释原理,论文具体的翻译及复现代码在文章的github中。

原理阐述

介绍

  • 本文发表于2020年的arXiv。
  • 本文的作者从一个崭新的角度来试图完成多源域的领域自适应。作者不是在源域和目标域之间寻找潜在的表示不变量(representation invariant),而是利用了源域分布的多样性,根据手头的任务来调整其权重。
  • 作者的模型:Weighted Joint Distribution Optimal Transport (WJDOT),旨在同时找到一个基于最优传输的源和目标分配之间的对齐和源域的新的权重。
  • 在多源域领域自适应(MSDA)方面,涌现了很多解决办法。例如,[13,14]提供了如何使用代理度量(如假设的准确性)组合多个来源预测器的理论保证。该方法在假设目标分布可以写成源分布的凸组合的前提下,可以实现目标域上的低误差预测。其他最近的方法[15,16,17]寻找一个唯一的假设,使其误差在所有源域上的凸组合最小化,并提供该假设在目标域上的误差的理论边界。这些保证通常涉及到一些取决于每个源分布和目标分布之间的距离的术语,建议使用对抗学习[16,18,19]或时刻匹配[15]找到源和目标之间特征分布尽可能接近的嵌入。然而,当源/目标边界之间的距离如图1所示很小时,就不可能找到保持嵌入判别的方法,其中源之间的旋转阻止了[20]中理论上的不变嵌入的存在。
  • 而在本文,作者的角度与上述常见方法不同。作者没有寻找所有源分布都与目标分布相似的潜在表示,而是拥抱源分布的多样性,寻找与目标距离最小的源的联合分布的凸组合(a convex combination of the joint distribution of sources with minimal distance to the target one,),而不是参考一个代理度量(proxy measure),比如源预测器的准确性。在推导了包含该距离的目标的新泛化界之后,作者提出优化Wasserstein距离,该距离定义在特征/标签产品空间上,类似于[10]中提出的,但在目标域和标记源的加权和之间。
  • 作者的方法的一个独特的特点是,权值是与分类函数同时学习的,这使得我们可以根据源与目标的相似度,在特征和输出空间中分配质量(mass,其实就是根据相似度设置概率向量)。
  • 所谓凸组合:设向量 { x i } , i = 1 , … , n \{x_i\},i=1,…,n { xi},i=1,,n,如果有实数 λ i > = 0 λ_i>=0 λi>=0,且 ∑ i = 1 n λ i = 1 \displaystyle\sum^n_{i=1}λ_i=1 i=1nλi=1,则称 ∑ i = 1 n λ i x i \displaystyle\sum^n_{i=1}λ_ix_i i=1nλixi为向量 { x i } \{x_i\} { xi}的一个凸组合。
  • 符号定义: S S S设置为源域的数量,每个源域都有对应的特征和标签。假设存在一个可微的嵌入函数(differentiable embedding function): g : X → G g: X → G gXG G G G是嵌入空间(embedding space),假设所有的输入分布都在这个嵌入空间。假设 p s p_s ps是源域 s s s的真实分布, p T p^{T} pT是目标域的真实分布,被包含在乘积空间 G × Y G×Y G×Y,其中 Y Y Y是标签空间。在实践中我们只获得源域有限数量 { N s } S = 1 S \{N_s\}^S_{S = 1} { Ns}S=1S的样本来获得经验源域分布 p s ′ = ( 1 / N s ) ∑ i = 1 N s δ g ( x s i ) , y s i p'_s = (1/Ns) \displaystyle\sum_{i = 1}^{N_s}δ_{g (x^i_s),y^i_s} ps=(1/Ns)i=1Nsδg(xsi),ysi。在目标域我们只能使用特征空间中有限数量的未标记样本得到 µ ′ = ( 1 / N ) ∑ i = 1 N δ g ( x i ) µ' = (1/N) \displaystyle\sum_{i = 1}^{N}δ_{g(x^i)} µ=(1/N)i=1Nδg(xi),即经验目标域边缘分布。给定一个损失函数 L L L和一个联合分布 p p p,函数 f f f的期望损失定义为 ε p ( f ) = E ( x , y ) ∞ p [ L ( y , f ( x ) ] ε_p(f) = E_{(x,y)∞p}[L(y,f(x)] εp(f)=E(x,y)p[L(y,f(x)]

Optimal Transport and Domain Adaptation

  • 在本节中,作者回顾了最优运输问题和Wasserstein距离的概念,在其方法中起着核心作用。然后,讨论了如何在联合配送最优运输(JDOT)公式中利用它们进行域调整,这将是本方法的核心。
  • OT问题的回顾我就不提了,每一篇论文都先回顾一下。
  • 至于JDOT的话,我还是再重述一遍吧。
  • Joint Distribution Optimal Transport (JDOT): OT问题的下式:

    可以通过考虑联合概率分布,而不是边缘概率分布来求得。即变成下述问题:

    然而目标域的标签是不可获取的,所以使用预测值来表示,从而计算经验概率分布,即
    p ′ f = ( 1 / N ) ∑ i = 1 N δ g ( x i ) , f ( g ( x i ) ) p'^f=(1/N)\displaystyle\sum^N_{i=1}δ_{g(x^i),f(g(x^i))} pf=(1/N)i=1Nδg(xi),f(g(xi))
    所以最终的优化目标如下:

    其中,
    D ( ) = β ∣ ∣ g ( x 1 ) − g ( x 2 ) ∣ ∣ 2 + L ( y 1 , f ( g ( x 2 ) ) ) D()=β||g(x_1) − g(x_2)||^2+L(y_1,f(g(x_2))) D()=βg(x1)g(x2)2+L(y1,f(g(x2))),其中 L L L是损失函数,β是调整feature loss和label loss的参数。随后将其扩展到深度学习框架,在该框架中,通过[11](该论文也十分经典,讲解:https://blog.csdn.net/qq_41076797/article/details/116698770)中的一个有效的随机优化过程,同时估计嵌入g和分类器f。
  • 作者提到,优化问题其实是涉及到联合嵌入和标签分布(joint embedding/label distribution)的,而现在的大多数DA方法往往只考虑边缘分布。

Multi-source DA with Weighted JDOT (WJDOT)

  • 在本节中,作者提出了一个基于源分布权重的MSDA问题的新泛化界。然后,介绍了WJDOT优化问题,并提出了一种求解该问题的算法。最后,讨论了WJDOT与现有方法之间的关系。

Generalization bound for multi-source DA

  • 领域适应的理论极限得到了很好的研究和理解,因为[27]的工作提供了一个“不可能定理”,表明如果目标分布与源分布太不同,适应是不可能的。然而,在MSDA的情况下,可以利用源域的多样性,只使用接近目标分布的源,从而获得更好的泛化界限。为此目的,在ML[13]中已经考虑了一个相关的假设,即假设目标分布是源分布的凸组合。下面的引理说明了这种方法的合理性。

    其中 h h h是从输入到结果的映射(在此处作者用符号 f f f表示), H H H h h h(即 f f f)的集合,有
    其中 f D f_D fD是真实的映射,缩写:

    其中 B B B是泛化界generalization bound, p α = ∑ s = 1 S α s p s p^α=\displaystyle\sum^S_{s=1}α_sp_s pα=s=1Sαsps是对 S S S个概率分布的耦合,其中 α s α_s αs是凸组合, D T V ( p α , p T ) D_{TV}(p^α,p^T) DTV(pα,pT)是两个概率分布的距离。
    再重新看引理下的公式:

    作者提到,该公式告诉我们,对于目标域泛化界问题的关键,就是找到一个低误差的映射 f f f,另外对于目标域的耦合也应该与目标域“相近”,此处作者提到的相近的意思就是耦合的效果更好,更近于目标域的分布。该原则同样适用于单源域问题。
  • 从引理1的结果中得到启发,作者提出了一个从多个领域来源学习的理论基础框架。作者的方法基于这样一种思想,即可以通过使用假设标签函数 f f f来弥补目标标签的不足,该函数提供了一个联合分布 p f ( 2 ) p^f(2) pf(2),其中 f f f被搜索,以便将 p f p^f pf与源分布的加权组合对齐。在此基础上,我们引入了以下泛化界。
  • 定理1:

    H H H是映射函数集合,然后假设输入空间 ∀ f ∈ H , ∣ f ( x ) − f ( x ′ ) ∣ ≤ M ∀f ∈ H,|f(x) − f(x')| ≤ M fH,f(x)f(x)M,考虑以下相似性测度:
    Λ ( p α , p T ) = m i n f ∈ H ε p α ( f ) + ε p T ( f ) Λ(p^α, p^T) = min_{f∈H}ε_pα(f) + ε_{p^T}(f) Λ(pα,pT)=minfHεpα(f)+εpT(f)
    另外还要假设最小化得到的的函数 f ∗ f^{*} f要满足Probabilistic Transfer Lipschitzness(PTL)property,作者指出关于这个PTL是引用的joint distribution optimal transportation for domain adaptation中的内容,其原文如下:

    当时读joint distribution optimal transportation for domain adaptation时这一块就没弄明白,只要知道它是提供了泛化界就好了。PS.这篇文章感觉和上文在理论上几乎没有区别,emm,感觉本文只是考虑了多源域。

Weighted Joint distribution OT problem

  • WJDOT优化问题: 作者的方法旨在找到一个函数 f f f,该函数将分布 p f p^f pf和具有凸组合的权重 ∑ s = 1 S α s p s ′ \displaystyle\sum ^S_{s=1} α_sp'_s s=1Sαsps的分布对齐,作者将多域适应问题表示为:

    这里的优化目标是 α α α f f f,见下图:

    最左边是四个源域的分布图,每个领域有两个类别。中间偏左的图像展示了2d下的源域和目标域样本分布图,由于目标域样本没有标签,所以都是黑色的。中间偏右使用了最优的 α α α权重=[0,0.5,0.5,0],也就是说,只有源2和源3的权重> 0,这是因为它们在Wasserstein意义上最接近目标分布。右图是分类器预测标签。
  • Optimization algorithm最优化算法 :首先这个最优化问题是可以借鉴之前那篇论文里的做法的,即使用块坐标下降法,那篇文章是这么描述该算法的:

    但作者发现,加入了权重 α α α之后,很容易陷入表现很差的局部最小值。于是作者使用了投影梯度下降法:

    关于这个方法,网上的资料很少,我确实是不太清楚,这里待补充。。。
    然后需要注意的是,我们不要忘记了最小化 W W W距离,还要考虑一个参数,就是转移矩阵 r r r,joint distribution optimal transportation for domain adaptation文章就是考虑的 r r r f f f使用坐标下降法, 本文还加入了一个参数 α α α,但是我们并没有使用三个参数进行坐标下降,而是首先固定另外两个参数,对 r r r求得最优,这是一个OT问题,前面重复过很多次了,解法很多。然后我们使用固定的 r r r,对 α α α f f f进行投影梯度下降法(因为坐标下降法表现不好)。
  • 作者厚着脸皮提到,WJDOT和JDOT很相似,但是WJDOT将JDOT应用到了多源域领域自适应问题上(MSDA)。它提到,有两种方法可以在MSDA中实现WJDOT,首先可以将所有的源域样本连接成一个源分布,然后在这个分布上使用JDOT,本文重点讨论的应该就是这个。第二种方法就是,对每一个源域分布的JDOT计算总和。
  • 然后作者开始进行横向比较:很明显,当一些源分布与目标分布非常不同时(在WJDOT中,这些源分布的权重很小),这两种方法都不稳健。在[9]的基础上,有一种叫做JCPOT [30]的MSDA方法,该方法被提议只处理目标转移(类之间的比例变化),并满足一个一般化界限,该界限表明估计目标分布中的类比例是恢复良好性能的关键。作者说:虽然我们没有遵循这个观点,但我们称WJDOT也可以作为一种特殊情况来处理目标转移,因为重新加权 α α α与类的比例直接相关。主要区别在于JCPOT仅使用特征余量来估计类别的比例,而WJDOT通过优化联合嵌入/标签空间中的瓦瑟斯坦距离来同时估计比例和分类器。还要注意,WJDOT依赖于样本的权重,其中权重在源域内共享。这是一种类似于领域适应方法的方法,如重要性加权经验风险最小化(IWERM) [31],设计用于协变量转换,使用所有样本的重新加权。一个主要的区别是,我们只估计 α α α中相对少量的权重,这导致了更好的统计估计。众所周知,对样本进行适当的个体重新加权所必需的连续密度的估计在高维空间中是一个非常困难的问题。

Numerical experiments

  • 在本节中,作者首先提供一些WJDOT的实现细节。然后,评估了所提出的方法,并将其与最先进的MSDA方法进行了比较,包括模拟数据和真实数据。
  • Practical implementation of WJDOT(WJDOT的实际实现): 作者在所有数值实验中使用算法1的WJDOT来求解。还记得之前作者提到的嵌入 g g g嘛,这里作者要对其进行估计,采用的是多任务学习框架(Multi-Task Learning framework)。

    f s f_s fs是最后的分类器,看上去其实就是个标签分类损失嘛。也就是说,作者将这个对于 g g g的生成是单独拿出来进行预训练的,而不是在模型中边进行域适应边进行学习。
  • 然后作者提到另一个重要的任务是如何进行参数验证和早期停止(early stop),在无监督的数据采集中,由于缺乏用于验证的目标样本,这一直是一个难题。为了克服这个问题,作者使用估计的输出 f ( X ) f(X) f(X)和它们在目标数据上的估计聚类质心之间的平方误差之和(sum of squared errors,SSE)。
  • Compared methods: 作者将它的方法与下面的MSDA方法进行比较,其中两种方法是JDOT公式的不明显的扩展。其中CJDOT法将所有的源样本连接成一个源分布。MJDOT正好对应第二种方法,对于所有的源域都优化该式 ∑ s W ( p s , p f ) \displaystyle\sum _s W(p_s,p^f) sW(ps,pf)。对于这两种方法,作者使用上面提到的SSE方法进行提前止损。重要性加权经验风险最小化(Importance Weighted Empirical Risk Minimization,IWERM) [31],这是ERM的一种变体,其中样本通过目标和源密度的比率加权,最小化每个源的IWERM目标之和(minimizing the sum of the IWERM objective for each sources)。DCTN是[18]的深层混合网络,其中对抗学习被用来学习特征提取器、领域鉴别器和源分类器。领域鉴别器提供多个源-目标特定的困惑分数(perplexity scores),用于加权源特定的分类器预测并产生目标估计。最后, M 3 S D A M^3SDA M3SDA是在[15]中为MSDA提出的矩匹配方法,其中嵌入是通过对准源和目标分布的矩来学习的。请注意,在DCTN和 M 3 S D A M^3SDA M3SDA中,嵌入学习是方法的核心,因此它们对于固定的嵌入g是不可行的。因此,我们仅在必须估计g时才与这些方法进行比较。然后作者训练了一个baseline,就是专门使用源域的数据进行训练一个分类器,使得其在多个源域数据上的表现做好,这个方法是为了检测这个源域分类器是否具有对域的鲁棒性,是不是能在目标域上表现良好(当然基本是不可能的,不过这也是一种对比实验,就相当于不使用任何域适应手段,看看仅仅对源域训练分类器能否在目标域表现良好,基本可以作为最差模型)。既然提供了“最差的”情况,当然还要准备最好的模型,就是将目标域样本及标签也加入到训练过程中,即baseline+target,其中target部分训练一个分类器仅仅用于目标域样本及标签,当然这容易过拟合,因为目标域样本及标签不多。因为最后这两种方法是有标签的(即baseline和baseline+target),所以我们使用验证集,比例为7:2:1。
  • Simulated data: 我们考虑一个类似于图1所示的分类问题,但是有3个类,即Y = {0,1,2}。对于源域和目标域,作者是这样生成数据的:

    作者实施了多组实验,以观察多种参数的影响,比如,源域数量,源样本数量,目标样本数量等。
    然后作者在图片2中显示了精度图,左边源域数量为3,另外 N s = N = 300 N_s=N=300 Ns=N=300,右边源域数量为30,即下图:

    WJDOT在性能和方差方面明显优于所有竞争方法,即使是对有限数量的源。有趣的是,WJDOT甚至可以超越Target,因为它可以访问更多的样本。WJDOT的另一个重要方面是获得了可用于解释的权重 α α α。在图2中显示,估计的权重趋向于稀疏,并且将更多的质量放在具有相似角度的源上。
  • Object recognition: The Caltech-Office 数据集包含四个不同的域:Amazon , Caltech [37], Webcam and DSLR。不同领域的不同来自几个因素:背景的存在/不存在、光照条件、噪声等。作者使用预训练DeCAF模型[38]第7层的输出作为嵌入函数 G G G,类似于[9]中所做的,得到嵌入空间 G ∈ R 4096 G ∈ R^{4096} GR4096
    后面都是关于实验的一些配置,不多说了。

Conclusion

  • 作者提到,未来的工作将研究 α α α的正则化,并用WJDOT同时估计嵌入 g g g,而不是用多任务学习对其进行预训练。

转载:https://blog.csdn.net/qq_41076797/article/details/117214665
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场