Invariant Risk Minimization (IRM)

Invariant Risk Minimization

入门

机器学习虽然可以学习到复杂的预测规则,但是容易受到数据中的偏差和虚假关联(spurious)的影响,使其难以达成人工智能的最终目标。这个问题已经被前人使用因果推断研究过,IRM 利用因果工具进一步发展了虚假关联和不变关联(invariant correlation)的数学理论,目标是降低机器学习对数据的过度依赖,提高机器学习模型在新的测试分布上的泛化性。

**虚假关联(spurious correlation):在未来不应该保持和过去一样方式的关联,不表现为稳定的性质(stable properties)。**而现有的数据集都不能够在发现稳定性质上起到amenable(经得起考验的、负责的)的作用。因为在我们利用这些数据时,我们常常将其随机地打乱并划分为训练集和测试集,这种划分基于机器学习通用的假设,即测试数据和训练数据独立地采样自同一分布。然而,打乱数据导致了我们无法观测到在分布变化的过程中,哪些性质是不变的,哪些是虚假的。

本文提出了帮助实现 OOD 泛化的范例:Invariant Risk Minimization (IRM),IRM 基于这样的原则:为了学习到在所有环境中的不变性,我们需要找到一种数据表示,使在这种数据表示上最优的分类器能够匹配所有的环境

环境与泛化

De:={(xie,yie)}i=1ne 表示在环境 eEtr 下采样得到的数据集,其服从分布 P(Xe,Ye)。我们的目标是使用多个这样的数据集,学习到一个预测器 Yf(X),其在未预见但相关联的更大规模的环境(all possible experimental conditions concerning our system of variables) Eall Etr  中表现良好。

Re(f):=EXe,Ye[(f(Xe),Ye)] 视作在环境 e 下的风险(比如训练损失),

ROOD(f)=maxeEallRe(f)

例 1: 考虑一个 X1YX2 的 structual equation model 如下:

X1Gaussian(0,σ2),YX1+Gaussian(0,σ2),X2Y+Gaussian(0,1).

对于这样的系统,环境可以是对 X1X2 变量方程的所有可能的修改,也可以是对 Y 方程中的 σ 的所有可能的取值 [0,σMAX2]。比如,我们可以在环境 e 中将其中的 X2 (记作 X2e)设置为常量,又或者是更改其高斯分布项的方差,以此来表示某个特定环境下 X2e 的取值。

为了预测 Y,我们可以建立回归模型 Y=a^1X1+a^2X2,有多种方法可以求解这个回归模型:

  • X1e 回归,得到 a^1=1a^2=0
  • X2e 回归,得到 a^1=0a^2=σ(e)2σ(e)2+12
  • (X1e,X2e) 回归,得到 a^1=1σ(e)2+1a^2=σ(e)2σ(e)2+1

可以看到,只有第一种情况得到的系数是不受环境影响的(即不变性)。而第二种和第三种回归方式得到的系数都会受到环境影响,进而使其不满足在新的环境下的预测规则。当然,我们也可以不依靠任何特征来直接猜测 Y,其也具有不变性,但这不是我们感兴趣的,因为其预测性能极低。

为什么现有的技术学习不到不变性?

  • 大部分的机器学习普遍使用的是经验风险最小化(Empirical Risk Minimization)原则,在这种情况下,如果环境中的方差 σ 较大,会导致模型对 X2 赋予更大的系数(因为 X2 的第二项的方差只有 1 而 X1 的第二项方差 σ 很大,导致使用 X1 来预测时误差很大),而这违背了不变性原则,因为对 X2 的系数是受到 e 影响的。

  • 即便我们使用鲁棒性学习目标,即 minimize Rrob(f)=maxeEtrRe(f)re,,其中 re 表示环境基线(enviroment baselines)。这样的目标用于最小化跨环境误差的最大值(minimizing the maximum error across enviroments)。然而,其等价于最小化环境的加权平均误差。和使用 ERM 的方法一样,它也没有办法发现我们期望的不变性。

    给定 Karush-Kuhn-Tucker (KKT)条件, 存在 λe0 使得 Rrob 的最小值是 eEtrλeRe(f) 的一阶驻点。

为了解决现有的机器学习无法建立 invariant predictors 的问题,作者提出了 Invariant Risk Minimization (IRM)

Invariant Risk Minimization

统计用语上,IRM 的目标是学习在不同训练环境中不变的关联。也就是找到一种数据表示(data representation),使在这种数据表示上最优的分类器能够匹配所有的环境

Invariant predictor 定义 如果存在一个分类器 w:HY 对所有环境同时是最优的,即 wargminw¯:HYRe(w¯Φ),对所有 eE 都成立,那我们说数据表征 Φ:XH 在环境集合 E 上引发了一个不变的预测器(Invariant predictor) wΦ

上诉定义等价于学习与标签变量拥有稳定关联(stable correlation)的特征

IRM 就是一种用来得到能够引发invariant predictor的数据表征学习方式,其目标可以在数学上表示为一个带约束的优化问题如下:

minΦ:XH,w:HYeEtrRe(wΦ)subject towargminw¯:HYRe(w¯Φ),for all eEtr

这个问题难以求解,进一步发展出了IRMv1:

minΦ:XYeEtrRe(Φ)+λw|w=1.0Re(wΦ)2

也就是变成了一个多目标优化的形式,其中 λ 用于权衡预测性能(也就是 ERM 损失)和 invariance。和其中 Φ 从数据表征变成了一个预测器,也就是此时 w=1,直接输出了数据表征 Φ 的标签变量。第二项的梯度惩罚项用于衡量分类器在不同环境下的最优性。如果在一个环境中分类器仍有较大的梯度,那说明它在这个环境中仍有很大的优化空间。第二项的目的是让不同环境中尽可能达到相同的优化程度,以此说明分类器所发现的特征是跨环境不变的,能够在不同环境中得到最优性相似的解。

IRMv1 将 IRM 变成了一个惩罚项,可以用损失函数的形式表示如下:

LIRM(Φ,w)=eEtrRe(wΦ)+λD(w,Φ,e)

D 是一个用于衡量给定 w,Φ 的情况下,w 与最小化 e 中的风险的距离,其作用是鼓励分类器使用具有跨环境不变性的特征进行预测:如果此时分类器试图通过极端 match 某些环境来降低第一项损失,那他必然导致第二项的增大