文章连接:因果推理说泛化
文章代码:
文章目录
- 摘要
- 引言
- 相关工作
- 3 类-条件不变性的不足
-
- 3.1 一个简单的反例
- 3.2 类条件方法的经验研究
- 4 域泛化的因果视角
-
- 4.1 数据增强过程
- 4.2 确定不变性条件
- 4.3 “完美匹配”不变性
- 4.4 过去的工作:学习共同表示法
- 5 MatchDG:没有对象匹配
-
- 5.1 迭代匹配的两阶段方法
- 5.2 MDG 混合
- 6 验证
-
- 6.1 Rotated MNIST and Fashion MNIST
- 6.2 PACS 数据集
- 6.3 胸部X射线数据集
- 7 结论
摘要
在领域泛化文献中,一个共同的目标是在对类标签进行条件反射后学习独立于领域的表示。我们证明这个目标是不够的:存在反例,其中模型即使在满足类条件域不变性之后也不能推广到未见过的域。我们通过结构因果模型形式化了这一观察结果,并展示了类内变化建模对泛化的重要性。具体来说,类包含表征特定因果特征的对象,而域可以解释为对这些对象的干预,从而改变非因果特征。我们强调了另一种条件:如果来自相同对象,则跨域的输入应该具有相同的表示。基于这一目标,我们提出了在观察到基本对象时(例如,通过数据增强)基于匹配的算法,并在未观察到对象时近似目标(MatchDG)。我们简单的基于匹配的算法在旋转MNIST, Fashion-MNIST, PACS和胸部x射线数据集的域外精度方面与之前的工作具有竞争力。我们的方法MatchDG也恢复了真实的对象匹配:在MNIST和Fashion-MNIST上,MatchDG的前10个匹配与真实的匹配有超过50%的重叠。
引言
领域泛化是学习机器学习模型的任务,该模型可以在对多个数据分布进行训练后泛化到不可见的数据分布。例如,在一个地区的医院上训练的模型可以部署到另一个地区,或者可以在稍微旋转的图像上部署图像分类器。通常,假设不同的域共享一些“稳定”的特征,这些特征与输出的关系跨域是不变的,目标是学习这些特征。一类流行的方法旨在学习与类条件无关的域表示,基于其优越性的证据,以学习与领域略微独立的表示的方法。
在这种情况下,表示的类条件支配不变目标是不够的。我们提供了反例,其中特征表示满足目标,但仍然无法在理论上和经验上推广到新的领域。具体来说,当需要学习的稳定特征在不同领域的分布不同时,类条件目标不足以学习稳定特征(只有当稳定特征在不同领域的分布相同时,它们才是最优的)。在真实世界的数据集中,同一类标签内稳定特征的不同分布是很常见的,例如,在数字识别中,稳定特征的形状可能会根据人们的笔迹而不同,或者医学图像可能会根据人们身体特征的变化而不同。我们的调查揭示了在稳定特征中考虑类内变化的重要性。
为了获得更好的领域泛化目标,我们基于先前的单域泛化工作,使用结构因果模型表示稳定特征的类内变化。具体来说,我们为数据生成过程构建了一个模型,该模型假设每个输入都是由稳定(因果)和领域相关(非因果)特征的混合构造而成,并且只有稳定特征导致输出。我们认为域是一种特殊的干预,它改变了输入的非因果特征,并假设理想的分类器应该只基于因果特征。使用d-separation,我们展示了正确的目标是建立一个对每个对象都具有不变条件的表示,其中对象被定义为共享相同因果特征的一组输入(例如,从不同角度拍摄的同一个人的照片或在不同旋转、颜色或背景下对图像的增强)。当观察到对象变量时(例如,在自收集的数据中或通过数据集增强),我们提出了一个完美匹配正则化器域泛化,将相同对象跨域表示之间的距离最小化。
然而,在实践中,底层对象并不总是已知的。因此,我们提出了一个近似,旨在学习哪些输入共享相同的对象,假设来自同一类的输入比来自不同类的输入具有更多相似的因果特征。
我们在旋转的MNIST和Fashion-MNIST、PACS和胸部x射线数据集上评估了我们基于匹配的方法。在所有数据集上,简单的方法MatchDG和mdhybrid在域外精度方面与最先进的方法具有竞争力。在旋转的MNIST和FashionMNIST数据集上,已知的基础真值对象,Mat chDG学习使表示更类似于它们的基础真值匹配(前10个匹配大约有50%的重叠),即使该方法无法访问它们。我们用简单匹配方法的结果显示了强制正确的不变性条件的重要性。
关注接触环境的重要性。的贡献。总的来说,我们的贡献包括:1).领域泛化的对象不变条件,突出了先前方法的关键局限性;2).当对象信息不可用时,采用一种两阶段迭代算法来逼近基于对象的匹配。
相关工作
学习共同表征。
其中,几部作品表明类条件方法比那些强制特征的边缘域不变性,每当跨域的类标签分布不同时。我们证明了类条件不变量也不足以推广到不可见的域。
因果关系和领域泛化。
Y
→
X
Y→X
Y→X 方向和其他工作连接假设
X
→
Y
X→Y
X→Y 的因果关系。我们的SCM模型通过引入
Y
t
r
u
e
Y_{true}
Ytrue? 和标记为Y来统一这些流,并开发了在两种解释下都有效的领域泛化的不变性条件。也许最接近我们工作的是,他们在单域数据集中使用对象概念来更好地泛化。我们将他们的SCM扩展到多域设置,并用它来显示先前方法的不一致性。此外,(Heinze-Deml & Meinshausen, 2019) 假设对象总是被观察到,我们还提供了一个算法,当对象是未被观察到的情况下。
匹配和对比损失。基于匹配的正则化器被提出用于领域泛化。(Motiian et al., 2017)提出了来自同一类输入的匹配表示。(Dou et al., 2019)使用对比(三重)损失来正则化ERM目标。与基于对比损失的正则化相比,我们的算法MatchDG分两个阶段进行,并学习独立于ERM目标的表示。这种迭代的两阶段算法具有经验上的好处,我们将在附录D.4中展示。此外,我们还提出了一种理想的基于对象的目标匹配算法。
其他工作。领域泛化的其他方法包括元学习、数据集增强,参数分解,并强制执行最优
P
(
Y
∣
Φ
(
x
)
)
P(Y|Phi(x))
P(Y∣Φ(x)) 的域不变性 。我们根据经验将我们的算法与其中一些算法进行了比较。
3 类-条件不变性的不足
考虑一个分类任务,其中学习算法可以访问来自
m
m
m 个域
{
(
d
i
,
x
i
,
y
i
)
}
i
=
1
n
~
(
D
m
,
X
,
Y
)
{(d_i, x_i, y_i)}_{i=1}^nsim (D_m,mathcal{X}, mathcal{Y})
{(di?,xi?,yi?)}i=1n?~(Dm?,X,Y) 的 i.id 数据,其中
d
i
∈
D
m
d_iin D_m
di?∈Dm? 和
D
m
?
D
D_msubset D
Dm??D 是
m
m
m 个定义域的集合。每个训练输入
(
d
,
x
,
y
)
(d,x,y)
(d,x,y) 都是从一个未知分布
P
m
(
D
,
X
,
Y
)
mathcal{P}_m(D, X, Y)
Pm?(D,X,Y) 中采样的。域泛化任务是学习一个分类器,该分类器可以很好地泛化到不可见的域
d
′
?
D
m
d^{'}
otin D_m
d′∈/Dm? 和来自相同域的新数据。最优分类器可以写成:
f
?
=
a
r
g
m
i
n
f
∈
F
E
(
d
,
x
,
y
)
~
P
[
l
(
y
(
d
)
,
f
(
x
(
d
)
)
)
]
f* = arg min_{fin mathcal{F}} mathbb{E}_{(d,x,y)sim mathcal{P}}[l(y^{(d)}, f(x^{(d)}))]
f?=argminf∈F?E(d,x,y)~P?[l(y(d),f(x(d)))],其中
(
d
,
x
,
y
)
~
P
,
o
v
e
r
(
D
,
X
,
Y
)
(d,x,y)sim mathcal{P} , over (D,mathcal{X,Y})
(d,x,y)~P,over(D,X,Y)。
如上所述,一种流行的工作方式强制要求学习到的表示
Φ
(
x
)
Φ(x)
Φ(x) 独立于类的域条件,
Φ
(
x
)
⊥
????????
⊥
D
∣
Y
Φ(x) perp!!!!perp D|Y
Φ(x)⊥⊥D∣Y。下面我们给出两个反例,表明类条件目标是不充分的。
3.1 一个简单的反例
我们构建了一个示例,其中
Φ
(
x
)
⊥
????????
⊥
D
∣
Y
Φ(x) perp!!!!perp D|Y
Φ(x)⊥⊥D∣Y ,但分类器仍然没有泛化到新的领域。考虑一个二维问题,其中
x
1
=
x
e
+
α
d
;
x
2
=
α
d
x_1 = x_e+ alpha_d;x2 = alpha_d
x1?=xe?+αd?;x2=αd?,其中
x
e
x_e
xe? 和
α
d
alpha_d
αd? 是未观察到的变量,
α
d
alpha_d
αd? 随域变化(图1(a))。真实函数只依赖于稳定特征
x
c
,
y
=
f
(
x
c
)
=
I
(
x
e
≥
0
)
x_c, y = f(x_c) = I(x_ege 0)
xc?,y=f(xc?)=I(xe?≥0),假设有两个训练域,域1
α
1
=
1
α1 = 1
α1=1,域2
α
2
=
2
α2 = 2
α2=2,测试域
a
3
=
0
a3 = 0
a3=0 (见图1(a))。进一步设给定
Y
Y
Y 的
X
c
的
X_c的
Xc?的 条件分布是跨域变化的均匀分布:对于域1,
X
c
∣
Y
=
1
~
U
(
1
,
3
)
;
X
c
∣
Y
=
0
~
U
(
?
2
,
0
)
;
X_c|Y = 1 sim U(1,3);X_c|Y = 0sim U(-2,0);
Xc?∣Y=1~U(1,3);Xc?∣Y=0~U(?2,0); 对于域2,
X
c
∣
Y
=
1
U
(
0
,
2
)
;
X
c
∣
Y
=
0
U
(
?
3
,
?
1
)
X_c|Y = 1 ~ U(0,2);X_c|Y = 0 ~ U(-3,-1)
Xc?∣Y=1 U(0,2);Xc?∣Y=0 U(?3,?1)。注意,选择的分布使得
φ
(
x
1
,
x
2
)
=
x
1
φ(x_1,x_2) = x_1
φ(x1?,x2?)=x1? 满足条件分布不变量,
φ
(
x
)
⊥
????????
⊥
D
∣
Y
φ(x)perp!!!!perp D|Y
φ(x)⊥⊥D∣Y。基于此表示的最优ERM分类器(
I
(
x
1
≥
1.5
)
I(x_1 ge 1.5)
I(x1?≥1.5))在两个域中都具有100%的训练准确率。但是对于测试域
α
d
=
0
;
X
c
∣
Y
=
1
~
U
(
0
,
2
)
;
X
c
∣
Y
=
0
~
U
(
?
2
,
0
)
α_d = 0;X_c|Y =1 sim U(0,2);X_c|Y =0sim U(-2,0)
αd?=0;Xc?∣Y=1~U(0,2);Xc?∣Y=0~U(?2,0),分类器泛化失败。即使它的表示满足类条件域不变性,它也获得62.5%的测试准确率(在正类上获得25%的准确率)。相比之下,理想的表示为
x
1
?
x
2
x_1 - x_2
x1??x2?,达到100%的训练精度和100%的测试域精度,并且不满足类条件不变量。
上面的反例是由于
x
c
x_c
xc? 跨域分布的变化。
P
(
X
c
∣
Y
)
P(X_c|Y)
P(Xc?∣Y)跨域保持相同,那么类条件方法就不会错误地选择
x
i
x_i
xi? 作为表示。跟随(Akuzawa等人,2019),我们声明以下内容(在附件中的证明)。B.3)。
命题1。在如上所述的域泛化设置下,如果
P
(
X
c
∣
Y
)
P(X_c|Y)
P(Xc?∣Y) 在多个域中保持相同,其中xe是稳定特征,则用于学习表示的类条件域不变目标产生一个可泛化的分类器,使得学习到的表示
Φ
(
x
)
Φ(x)
Φ(x) 独立于给定
x
c
x_c
xc?的域。具体来说,熵
H
(
d
∣
x
c
)
=
H
(
d
∣
Φ
,
x
c
)
H(d|x_c) = H(d|Φ, x_c)
H(d∣xc?)=H(d∣Φ,xc?)。
但是,如果
P
(
X
c
∣
Y
)
P(X_c|Y)
P(Xc?∣Y) 跨域变化,那么我们不能保证相同:
H
(
d
∣
x
c
)
H(d|x_c)
H(d∣xc?) 和
H
(
d
∣
Φ
,
x
c
)
H(d|Φ,x_c)
H(d∣Φ,xc?)可能不相等。为了在这种情况下构建可泛化的分类器,这个例子表明我们需要对
Φ
,
H
(
d
∣
x
c
)
=
H
(
d
∣
Φ
,
x
c
)
Φ, H(d|x_c) = H(d|Φ, x_c)
Φ,H(d∣xc?)=H(d∣Φ,xc?)附加约束;即域和表示应该在
x
c
x_c
xc? 条件下独立。
3.2 类条件方法的经验研究
作为一个更现实的例子,考虑为检测神经网络中的简单性偏差而引入的 slab 数据集,该数据集包含一个具有虚假相关性的特征。它包括两个特征和一个二元标签;(1)与标签呈线性关系,其他特征(2)与标签呈分段线性关系,为稳定关系。线性特征与标签的关系随着域的变化而变化(A.1);我们通过在域1中添加概率为
?
=
0
epsilon=0
?=0 的噪声,在域2中添加概率为
?
=
0.1
epsilon=0.1
?=0.1 的噪声来做到这一点。在第三个(测试)域,我们以1的概率添加噪声(参见图1(b))。我们预计依赖于伪特征
x
1
x_1
x1? 的方法将无法在域外数据上表现良好。
结果见表1(实现细节见附录)A.1)表明,ERM无法学习slab特征,对目标域的泛化效果较差尽管在源域上有很好的性能。我们还表明,基于无条件(DANN, MMD, CORAL)和条件分布匹配(CDANN, C-MMD, C-CORAL)和匹配同类输入(RandomMatch) 学习不变表示的方法无法学习稳定板特征。请注意,命题1表明,当稳定特征(slab特征)在源域的分布不同时,条件分布匹配(CDM)算法会失败。然而,slab数据集在源域具有相似的稳定特征(slab)分布,但CDM算法无法推广到目标域。这可以通过考虑伪线性特征来解释,伪线性特征也可以通过沿线性特征“移动”y条件分布来满足CDM约束。我们推测,由于其简单性,该模型可能首先学习线性特征,然后在进一步优化时保留伪线性特征,因为它满足CDM约束。这表明CDM方法在经验上可能会失败,即使在稳定特征跨域均匀分布的情况下。
我们如何确保模型学习稳定的、可推广的特征
x
2
x_2
x2??
Ф
(
x
)
Ф(x)
Ф(x) 应该独立于给定稳定特征的域。我们运用这种直觉并构建了一个模型,该模型强制学习表征独立于给定的域
x
2
x_2
x2?. 我们通过最小化来自共享相同slab值的不同域的数据点的表示的
l
2
l_2
l2? 范数来做到这一点
在下一节中,我们将使用因果图形式化对稳定特征x的条件作用的直觉,并引入作为稳定特征代理的对象的概念。
4 域泛化的因果视角
4.1 数据增强过程
图2(a)显示了一个结构化因果模型(SCM),它描述了域泛化任务的数据生成过程。对于直觉来说,考虑一个分类项目类型或筛选医疗状况图像的任务。由于人为的可变性或设计(使用数据增强),数据生成过程为每个类生成各种图像,有时为同一对象生成多个视图。在这里,每个视图可以看作是不同的域
D
D
D,项目类型或医疗条件的标签作为类别
Y
Y
Y,图像像素作为特征
x
x
x。同一项目或同一个人的照片对应于一个共同的对象变量,表示为
O
O
O。为了创建图像,数据生成过程首先对对象和视图(域)进行采样,可能彼此相关(用虚线箭头表示)。照片中的像素是由物体和视图引起的,如两个指向x的箭头所示。物体也对应于高级因果特征
X
c
X_c
Xc?,这是同一物体的任何图像所共有的,而人类又用它来标记
Y
Y
Y 类。我们称
X
c
X_c
Xc? 为因果特征,因为它们直接导致
Y
Y
Y 类。
上面的例子是一个典型的领域泛化问题;一般SCM如图2(b)所示,与(Heinze-Deml & Meinshausen, 2019)中的图表相似。一般来说,可能无法观察到每个输入
x
i
(
d
)
x_i^{(d)}
xi(d)?的底层对象。与对象相关(因果)特性
X
C
X_C
XC?类似,我们为对象
X
A
X_A
XA? 的领域相关高级特性引入了一个节点。改变域可以看作是一种干预:对于每个观察到的
x
i
(
d
)
x_i^{(d)}
xi(d)?,都有一组(可能未观察到的)反事实输入
x
i
(
d
)
x_i^{(d)}
xi(d)?,其中
d
≠
d
′
d
eq d'
d=d′,这样所有输入都对应于相同的对象(因此共享相同的
X
C
X_C
XC?)。为了完整起见,我们还显示了导致其生成为
Y
t
r
u
e
Y_{true}
Ytrue?的对象的真实未观察到的标签(因果图的额外动机在 supp B.1)。与对象
O
O
O 一样,
Y
Y
Y 可能与域
d
d
d 相关。扩展(Heinze-Deml & Meinshausen, 2019)中的模型,我们允许对象可以与以
Y
t
r
u
e
Y_{true}
Ytrue? 为条件的域相关。我们将看到,考虑对象节点的关系成为开发不变条件的关键部分。SCM对应于以下非参数方程。
o
:
=
g
o
(
y
t
r
u
e
,
?
o
,
?
o
d
)
,
x
c
=
g
x
c
(
o
)
x
a
:
=
g
x
a
(
d
,
o
,
?
x
a
)
,
x
:
=
g
x
(
x
c
,
x
a
,
?
x
)
:
=
h
(
x
c
,
?
y
)
o:=g_o(y_{true},epsilon _o,epsilon _{od}), x_c=g_{xc}(o)\ x_a:=g_{xa}(d,o,epsilon_{xa}), x:=g_x(x_c,x_a,epsilon_x):=h(x_c,epsilon_y)
o:=go?(ytrue?,?o?,?od?),xc?=gxc?(o)xa?:=gxa?(d,o,?xa?),x:=gx?(xc?,xa?,?x?):=h(xc?,?y?)其中
g
o
,
g
x
c
,
g
x
a
,
g
x
和
h
g_o, g_{xc}, g_{xa}, g_x和 h
go?,gxc?,gxa?,gx?和h是一般的非参数函数。误差
?
o
d
epsilon_{od}
?od?与定义域
d
d
d 相关,而
?
o
,
?
x
a
,
?
x
,
?
y
epsilon_{o},epsilon_{xa},epsilon_{x},epsilon_{y}
?o?,?xa?,?x?,?y?是相互独立的误差项,与所有其他变量无关。因此,类标签中的噪声与域无关。由于
x
c
x_c
xc? 对于同一对象的所有输入是共同的,因此
g
x
c
g_{xc}
gxc? 是
o
o
o 的确定性函数。此外,通过d-separation的概念,SCM提供了所有数据分布
P
mathcal{P}
P 必须满足的条件独立条件。B.2)和完美映射假设(Pearl, 2009)。
4.2 确定不变性条件
从图2(b)中,
X
c
X_c
Xc? 是引起
Y
Y
Y 的节点。进一步,通过d-separation,类标签独立于以
X
c
,
Y
⊥
????????
⊥
D
∣
X
c
X_c, Y perp!!!!perp D|X_c
Xc?,Y⊥⊥D∣Xc?为条件的域。因此,我们的目标是学习
y
y
y 为
h
(
x
c
)
h(x_c)
h(xc?),其中
h
:
C
→
Y
h: mathcal{C→Y}
h:C→Y。理想的损失最小化函数
f
?
f^{*}
f? 可以重写为(假设
x
c
x_c
xc?已知):
a
r
g
m
i
n
f
E
(
d
,
x
,
y
)
l
(
y
,
f
(
x
)
)
=
a
r
g
m
i
n
h
E
[
l
(
y
,
h
(
x
c
)
)
]
(
1
)
argmin_fmathbb{E}_{(d, x, y)}l (y, f (x)) =argmin_hmathbb{E}[l(y, h (x_c))] (1)
argminf?E(d,x,y)?l(y,f(x))=argminh?E[l(y,h(xc?))](1)
f
f
f 由于
X
c
X_c
Xc? 是未观察到的,这意味着我们需要通过表示函数
Φ
:
X
→
C
Φ:mathcal{X→C}
Φ:X→C 来学习它。总之,
h
(
Φ
(
x
)
)
h(Φ(x))
h(Φ(x)) 导致期望的分类器
f
:
X
→
Y
f: mathcal{X→Y}
f:X→Y。
识别的消极结果。因果特征的识别是一个不容忽视的问题。我们首先证明,给定多个域上的观测数据
P
(
X
,
Y
,
D
,
O
)
,
x
c
P(X,Y,D,O), x_c
P(X,Y,D,O),xc?是不可识别的。给定相同的概率分布
P
(
X
,
Y
,
D
,
O
)
,
X
c
P(X, Y, D, O), X_c
P(X,Y,D,O),Xc? 可能有多个值。代入SCM方程中的
o
o
o,得到,
y
=
h
(
g
x
c
(
o
)
,
?
y
)
;
x
=
g
x
(
g
x
c
(
o
)
,
g
x
a
(
d
,
o
,
?
x
a
)
,
?
x
)
y = h(g_{xc}(o),epsilon_y);x = g_x(g_{xc}(o), g_{xa}(d, o,epsilon_{xa}),epsilon_{x})
y=h(gxc?(o),?y?);x=gx?(gxc?(o),gxa?(d,o,?xa?),?x?)。通过适当地选择
g
x
g_x
gx? 和
h
h
h,不同的
g
x
c
g_{xc}
gxc? 值(即从
o
o
o 确定
x
c
x_c
xc?值)可以导致
(
y
,
d
,
o
,
x
)
(y, d, o, x)
(y,d,o,x) 的观测值相同。
命题2。给定观测数据分布
P
(
Y
,
X
,
D
,
O
)
P(Y,X, D,O)
P(Y,X,D,O),其中可能还包括从
D
D
D 域干预
X
C
X_C
XC? 的多个值产生完全相同的观测和干预分布,因此
X
c
X_c
Xc?无法识别。(证明在附录B.4)
4.3 “完美匹配”不变性
在没有可辨识性的情况下,我们继续寻找一个可以表征
X
c
X_c
Xc? 的不变量。通过 d-separation, 我们看到
X
c
X_c
Xc?满足两个条件:1)
X
C
⊥
????????
⊥
D
∣
O
,
2
)
X
C
⊥
????????
⊥
O
;
其中
O
X_Cperp!!!!perp D|O,2) X_C perp!!!!perp O;其中O
XC?⊥⊥D∣O,2)XC?⊥⊥O;其中O
D
D
D 表示域。第一个是不变性条件:同一个对象的不同域不会改变
X
C
X_C
XC?。为了加强这一点,我们规定相同对象的跨域输入之间的平均成对距离
Φ
(
x
)
为
0
,
∑
Ω
(
j
,
k
)
=
1
;
d
≠
d
′
D
i
s
t
(
Φ
(
x
j
(
d
)
)
,
Φ
(
x
k
(
d
′
)
)
)
=
0
Φ(x)为0,sum_{Ω(j,k)=1;d
eq d}'Dist (Φ(x_j^{(d)}), Φ(x_k^{(d^{'})})) = 0
Φ(x)为0,∑Ω(j,k)=1;d=d′?Dist(Φ(xj(d)?),Φ(xk(d′)?))=0。这里
Ω
:
X
×
X
→
{
0
,
1
}
Ω: mathcal{X imes X}→{0,1}
Ω:X×X→{0,1} 是一个匹配函数,对于对应同一对象的跨域输入对,该函数为1,否则为0。
然而,仅仅是上面的不变性是行不通的:我们需要表示对象
O
O
O 的信息(否则即使是一个常数
Φ
Φ
Φ 也会最小化上面的损失)。因此,第二个条件规定
X
C
X_C
XC? 应该是对象的信息,因此是关于
Y
Y
Y 的信息。我们加入标准分类损失,导致约束优化,
f
p
e
r
f
e
c
t
m
a
t
c
h
=
a
r
g
m
i
n
h
,
Φ
Σ
d
=
1
m
L
d
(
h
(
Φ
(
X
)
)
,
Y
)
S
.
t
Σ
Ω
(
j
.
k
)
=
1
;
d
≠
d
′
d
i
s
t
(
Φ
(
x
j
(
d
)
)
,
Φ
(
x
k
(
d
′
)
)
)
=
0
(
2
)
f_{perfectmatch} = arg min_{h,Φ} Σ_{d=1}^m L_d(h(Φ(X)),Y)\ S.t quad Σ_{Ω(j.k)=1;d
eq d^{'}} dist(Φ(x_j^{(d)}), Φ(x_k^{(d^{'})})) = 0 quad(2)
fperfectmatch?=argminh,Φ?Σd=1m?Ld?(h(Φ(X)),Y)S.tΣΩ(j.k)=1;d=d′?dist(Φ(xj(d)?),Φ(xk(d′)?))=0(2)式中
L
d
(
h
(
Φ
(
X
)
,
Y
)
)
=
Σ
i
=
1
n
d
l
(
h
(
Φ
(
x
j
(
d
)
)
)
,
y
i
(
d
)
)
L_d(h(Φ(X), Y)) = Σ_{i=1}^{n_d} l(h(Φ(x_j^{(d)})),y_i^{(d)})
Ld?(h(Φ(X),Y))=Σi=1nd??l(h(Φ(xj(d)?)),yi(d)?),其中
f
f
f 表示组成
h
°
Φ
hcirc Φ
h°Φ。例如,以
Φ
(
x
)
Φ(x)
Φ(x) 为第
r
r
r 层的神经网络,
h
h
h 为其余层。
请注意,可以有多个
Φ
(
x
)
Φ(x)
Φ(x) (例如,线性变换)同样适合于预测任务。由于
x
c
x_c
xc? 是不可识别的,我们关注的是一组稳定的表示,它们与给定
O
O
O 的
D
D
D 分离,独立于给定对象的域,它们不能与直接依赖于域的高级特征
X
a
X_a
Xa? 有任何关联(图2b)。下一个定理的证明在附录B.5。
定理1。对于有限个数的域
m
m
m,作为每个域的样例个数
n
d
→
∞
n_d→infty
nd?→∞,
- 满足条件
Σ
Ω
(
j
,
k
)
=
1
;
d
≠
d
′
d
i
s
t
(
Φ
(
x
j
(
d
)
)
,
Φ
(
x
k
(
d
′
)
)
)
=
0
Σ_{Ω(j,k)=1;d
eq d^{'}} dist(Φ(x_j^{(d)}), Φ(x_k^{(d^{'})})) = 0ΣΩ(j,k)=1;d=d′?dist(Φ(xj(d)?),Φ(xk(d′)?))=0 的表示集合包含使 (1) 中的域泛化损失最小化的最优
Φ
(
x
)
=
X
C
Φ(x) = X_C
Φ(x)=XC?。
- 假设对于直接由域引起的每个高级特征
X
a
,
P
(
X
a
∣
O
,
D
)
<
1
X_a, P(X_a|O,D) < 1
Xa?,P(Xa?∣O,D)<1,并且对于最小化是条件期望(例如
l
2
l_2
l2? 或交叉熵)的 P-可容许损失函数,对于某个值为
λ
lambda
λ 的值,以下损失的最小化分类器为真函数
f
?
f^*
f?。
f
p
e
r
f
e
c
t
m
a
t
c
h
=
a
r
g
m
i
n
h
,
Φ
Σ
d
=
1
m
L
d
(
h
(
Φ
(
X
)
)
,
Y
)
+
λ
Σ
Ω
(
j
.
k
)
=
1
;
d
≠
d
′
d
i
s
t
(
Φ
(
x
j
(
d
)
)
,
Φ
(
x
k
(
d
′
)
)
)
=
0
(
3
)
f_{perfectmatch} = arg min_{h,Φ} Σ_{d=1}^m L_d(h(Φ(X)),Y)+\ lambda Σ_{Ω(j.k)=1;d
eq d^{'}} dist(Φ(x_j^{(d)}), Φ(x_k^{(d^{'})})) = 0quad (3)fperfectmatch?=argminh,Φ?Σd=1m?Ld?(h(Φ(X)),Y)+λΣΩ(j.k)=1;d=d′?dist(Φ(xj(d)?),Φ(xk(d′)?))=0(3)
4.4 过去的工作:学习共同表示法
使用SCM,我们现在将提出的不变性条件与域不变性和类条件域不变性目标进行比较。d-separation 结果表明这两个目标都是不正确的:特别是,类条件目标
Φ
(
x
)
⊥
????????
⊥
D
∣
Y
Φ(x)perp!!!!perp D|Y
Φ(x)⊥⊥D∣Y 不满足
X
c
,
(
X
c
?
??
⊥
??????
⊥
D
∣
Y
t
r
u
e
)
X_c, (X_c
ot!perp!!!perp D|Y_{true})
Xc?,(Xc?⊥⊥D∣Ytrue?),因为通过
O
O
O 的路径。即使有无限的跨域数据,他们也不会学到真正的
X
c
X_c
Xc?。证据在附录B.6。
命题3。因果表示Xc不满足域不变(
Φ
(
x
)
⊥
????????
⊥
D
Φ(x) perp!!!!perp D
Φ(x)⊥⊥D )或类条件域不变(
Φ
(
x
)
⊥
????????
⊥
D
∣
Y
Φ(x) perp!!!!perp D|Y
Φ(x)⊥⊥D∣Y)方法所强制的条件。因此,如果没有额外的假设,满足这些条件的表示集不包含
X
C
X_C
XC?,即使
n
→
∞
n→infty
n→∞。
5 MatchDG:没有对象匹配
当对象信息可用时,Eq.(3)提供了一个损失目标来使用因果特征构建分类器。然而,对象信息并不总是可用的,并且在许多数据集中,可能不存在基于跨域的相同对象的完美“反事实”匹配。因此,我们提出了一种两阶段迭代对比学习方法来近似对象匹配。
第4.2节中的对象不变条件可以解释为来自共享相同
X
c
X_c
Xc? 的不同域的输入对的匹配。为了接近它,我们的目标是学习一个匹配
Ω
:
X
×
X
→
{
0
,
1
}
Ω:mathcal{X imes X}→{0,1}
Ω:X×X→{0,1},使得
Ω
(
x
,
x
′
)
=
1
Ω(x, x') = 1
Ω(x,x′)=1 的配对在
x
c
x_c
xc? 和
x
′
x'
x′ 上的差异很小。我们做如下假设。
假设1。
(
x
i
(
d
)
,
y
)
(
x
j
(
d
′
)
,
y
)
(x_i^{(d)}, y) (x_j^{(d^{'})},y)
(xi(d)?,y)(xj(d′)?,y)是属于同一类的任意两个点,设
(
x
k
(
d
)
,
y
′
)
(x_k^{(d)}, y')
(xk(d)?,y′) 是具有不同类标号的任意其他点。然后是因果特征
x
i
x_i
xi? 和
x
j
x_j
xj?之间的距离;小于
x
i
x_i
xi? 与
x
k
x_k
xk? 或
x
j
x_j
xj? 与
x
k
x_k
xk?之间的关系;
d
i
s
t
(
x
c
,
i
(
d
)
,
x
c
,
j
(
d
′
)
)
<
d
i
s
t
(
d
′
)
d
i
s
t
(
x
c
,
i
(
d
)
,
x
c
,
k
(
d
′
)
)
dist(x_{c,i}^{(d)},x_{c,j}^{(d^{'})}) <dist(d') dist(x_{c,i}^{(d)},x_{c,k}^{(d^{'})})
dist(xc,i(d)?,xc,j(d′)?)<dist(d′)dist(xc,i(d)?,xc,k(d′)?)
d
i
s
t
(
x
c
,
j
(
d
)
,
x
c
,
i
(
d
′
)
)
<
d
i
s
t
(
d
′
)
d
i
s
t
(
x
c
,
j
(
d
)
,
x
c
,
k
(
d
′
)
)
dist(x_{c,j}^{(d)},x_{c,i}^{(d^{'})}) <dist(d') dist(x_{c,j}^{(d)},x_{c,k}^{(d^{'})})
dist(xc,j(d)?,xc,i(d′)?)<dist(d′)dist(xc,j(d)?,xc,k(d′)?)
5.1 迭代匹配的两阶段方法
为了学习匹配函数
Ω
Ω
Ω ,我们使用了(Chen et al., 2020;他等人,2019),并将其用于构建迭代MatchDG算法,该算法在每个 epoch 之后更新表示和匹配。该算法依赖于来自同一类的两个输入比来自不同类的输入具有更多相似因果特征的属性。
对比的损失。为了找到匹配,我们优化了一个对比表示学习损失,将来自不同领域的同类输入与来自不同领域的不同类别的输入之间的距离最小化。为了适应单个域的对比损失(Chen et al., 2020),我们将正匹配视为具有相同类别但不同域的两个输入,将负匹配视为具有不同类别的对。对于每一对正匹配对
(
x
j
,
x
k
)
(x_j, x_k)
(xj?,xk?) ,我们提出一个损失,其中的
τ
au
τ 是一个超参数,
B
B
B 是批大小,
s
i
m
(
a
,
b
)
Φ
(
x
a
)
T
Φ
(
x
b
)
/
∣
∣
Φ
(
x
a
)
∣
∣
∣
∣
Φ
(
x
b
)
∣
∣
sim(a, b)Φ(x_a)^TΦ(x_b)/ ||Φ(x_a)||||Φ(x_b)||
sim(a,b)Φ(xa?)TΦ(xb?)/∣∣Φ(xa?)∣∣∣∣Φ(xb?)∣∣ 是余弦相似度。
l
(
x
j
,
x
k
)
=
?
l
o
g
e
s
i
m
(
j
,
k
)
/
τ
e
s
i
m
(
j
,
k
)
/
τ
+
Σ
i
=
0
,
y
i
≠
y
j
B
e
s
i
m
(
j
,
i
)
/
τ
(
4
)
l(x_j,x_k) = -logfrac{e^{sim(j,k)/τ}}{e^{sim(j,k)/τ} +Σ_{i=0,y_i
eq y_j}^Be^{sim(j,i)/τ}}quad (4)
l(xj?,xk?)=?logesim(j,k)/τ+Σi=0,yi?=yj?B?esim(j,i)/τesim(j,k)/τ?(4)
迭代匹配。我们的关键观点是在训练中更新积极的比赛。我们从基于类的随机正匹配集开始训练,但在每t个epoch之后,我们基于表示空间中最接近的同类对更新正匹配并迭代直到收敛。因此,对于每个锚点,从一组初始的正匹配开始,在每个历元中使用对比学习来学习表征;之后,根据表示中跨域的最接近的同类数据点更新正匹配本身。因此,该方法可以区分同一类的数据点,而不是将所有数据点视为单个单元。通过对正匹配的迭代更新,目标是考虑跨域的类内方差,并跨更有可能共享相同基本对象的域匹配数据点。在附录D.6,我们比较了所提出的迭代匹配与标准对比训练的增益。
获得最终表示完成了算法的第一阶段。在第二阶段,我们使用这种表示来计算基于最接近的同类对的新匹配函数,并应用 Eq.(3)来获得对这些匹配进行正则化的分类器。
使用两个阶段的重要性。与之前的方法不同,我们将MatchDG实现为两阶段方法(Motiian et al., 2017;Dou等人,2019)使用基于类的对比损失作为ERM的正则化器。这是为了避免分类损失干扰跨域学习不变表示的目标(例如,在数据集中,其中一个域比其他域有更多的样本)。因此,我们首先学习match函数只使用对比损耗。我们的结果在附录 D.4表明,两阶段方法比同时优化分类和匹配提供了更好的地真完美匹配重叠。
为了实现 MatchDG,我们为每个输入构建一个包含
q
?
1
q-1
q?1 个正匹配的
p
×
q
p×q
p×q 数据矩阵,然后从这个矩阵中采样小批量。对比损失网络的最后一层被认为是学习到的表示(见算法1;详情见附录C.1)。
5.2 MDG 混合
虽然MatchDG不假设任何关于对象的信息,但它可以很容易地扩展以包含关于已知对象的信息。例如,在计算机视觉中,标准做法是通过执行旋转、水平翻转、颜色抖动等来增强数据。这些自增强为我们提供了对已知对象的访问,通过向Eq 3的损失添加另一个正则化器,可以将其作为MatchDG Phase-II中的完美匹配。我们将此方法命名为MDGhybrid,并将其与MatchDG一起评估我们可以执行自增强的数据集。
6 验证
6.1 Rotated MNIST and Fashion MNIST
表2显示了使用Resnet-18模型在测试域0°和90°的rotMNIST和rotFashionMNIST上的分类精度。在这两个数据集上,MatchDG都优于所有基线。最后一列显示了oracle方法ERM-PerfMatch的准确性,该方法可以访问跨域的真实完美匹配。Mat chDG的准确率介于ERM-RandMatch和ERM-PerfMatch之间,说明学习匹配函数的好处。随着训练域数量的减少,MatchDG与基线之间的差距突出:对于rotFashionMNIST,只有3个源域,MatchDG的准确率为43.8%,而第二好的方法ERM-RandMatch的准确率为38.4%。
我们还评估了更简单的 2 层 LeNet(Motiian 等人,2017 年)和 (Gulrajani & Lopez-Paz,2020 年)的模型,以将 MatchDG 与以前的工作进行比较,结果见附录D.1、D.2。
为什么MatchDG有效?我们比较返回的匹配垫chDG第一阶段(Resnet-18网络)groundtruth完美匹配,发现它有重叠明显高于匹配基于ERM损失(表3)。我们报告三个指标表示学习:垫chDG匹配的百分比是完美的匹配,%表示数量的输入中的完美匹配十大排名MatchDG比赛,完美的匹配和平均等级来衡量距离MatchDG表示。
在所有这三个指标上,Mat chDG找到了一个表示匹配更符合真实的完美匹配。对于rotMNIST和rotFashionMNIST数据集,根据MatchDG阶段1学习到的表示,大约50%的输入在排名前10位的匹配中有完美匹配。MatchDG学习到的所有匹配中约有25%是完美匹配。为了比较,我们还展示了用完全匹配初始化的(oracle) MatchDG方法的度量:它实现了更好的总体和前10个值。Mat chDG第二阶段的类似结果也在供应中。D.4。rotFashionMNIST的平均排名可能更高,因为样本量更大,每个域10000;2000个样本的训练指标在供应中。D.5。为了了解与完美匹配的重叠如何影响准确性,我们模拟了与完美匹配重叠25%、50%和75%的随机匹配。资源描述。D.3)。准确性随着完美匹配的比例的增加而增加,这表明捕获良好匹配的重要性。
MatchDG与IRM在零训练误差上的对比。由于神经网络通常达到零训练误差,我们也评估了在这种情况下Mat chDG正则化的有效性。图3显示了rotMNIST和rotFashionMNIST在训练过程中的匹配损失项。甚至当模型达到零训练误差后,我们看到普通的ERM目标无法最小化匹配损失(因此需要MatchDG惩罚)。这是因为MatchDG正则化依赖于比较(最后一层)表示,零训练误差并不意味着每个类中的表示是相同的。相比之下,当训练误差趋于零时,基于IRM惩罚等训练域之间损失比较的正则化可以通过普通ERM来满足(图3(b));类似于(Krueger et al., 2020)的图(5),其中ERM可以最小化有色MNIST的IRM处罚。
6.2 PACS 数据集
ResNet-18。在具有ResNet-18体系结构的PACS数据集上(表4),我们的方法与state-of-the art具有竞争力在所有领域的平均值。除了与DDEC和RSC相比,mdhybrid具有最高的跨域平均精度。这些工作没有透露他们的模型选择策略(无论结果是使用源验证还是测试领域验证)。因此,我们还报告了使用测试域验证的MatchDG和MDGHybrid的结果,其中MDGHybrid获得了与最佳性能方法相当的结果。此外,对于DDEC (Asadi等人,2019),这不是一个公平的比较,因为它们使用了来自Behance BAM的额外风格传输数据!数据集。
ResNet-50。我们在Resnet50模型(表5)上实现了Mat chDG,该模型由DomainBed中的ERM使用。加入MatchDG损失正则化后,mdhybrid的DomainBed精度从85.7提高到87.5。此外,MDGHybrid的性能也优于我们之前的使用 resnet-50 结构的方法。除了RSC (Huang et al., 2020),其结果(87.83)与我们的(87.52)接近。注意,我们为表4选择了性能最好的基线的一个子集。5;与其他作品的广泛比较在附录E.1。E.2给出了使用AlexNet网络的结果,t-SNE图(图5)显示了MatchDG学习到的表示质量。
6.3 胸部X射线数据集
表6提供了胸部x射线数据集的结果,其中垂直平移与源域中的类标签的虚假相关性可能导致模型学习不稳定的关系。以RSNA为目标域时,ERM在源域上的准确率为79.8%,在目标域上的准确率为81.8%,在目标域上的准确率为55.1%。相比之下,mdhybrid的分类准确率最高(高于ERM的8%),其次是CSD和Mat chDG;而像ERM和IRM这样的方法更容易受到伪相关的影响。然而,在ChexPert作为目标域时,CSD和IRM比ERM效果更好,而基于匹配的方法效果不佳。我们推测这些变化的趋势可能是由于源域中图像的固有可变性,这表明了为现实世界数据集构建域泛化方法的挑战。
7 结论
我们提出了一个领域泛化的因果观点,提供了一个客观条件的目标。简单的基于匹配的方法在PACS上的性能优于最先进的方法,这表明选择正确的不变性的重要性。提出的MatchDG在对象未知时使用某些假设。胸部x射线数据集的混合结果表明,需要做更多的工作来开发更好的匹配方法。