手推 XGBoost 损失函数

前言

集成学习的博客中,我推导了一次 XGBoost,但是那篇的重点是集成学习,所以写得比较宽泛,以至于在面试的时候忽然间不清楚面试官问的问题,即 XGBoost 损失函数的推导。一直纠结于二阶泰勒展开与正则项,却不明白面试官所问的是进一步的推导,这包括:

  1. 叶子节点最优分数推导;
  2. 叶子节点代入损失函数后的损失函数推导;
  3. 某个节点分裂后的损失

故新开一个博客,用于记录 XGBoost 推导的全过程。

二阶泰勒展开

Note

二阶泰勒展开:

f(x+Δx)f(x)+f(x)Δx+12f(x)(Δx)2

对于第 t 棵树 ft(x),其损失函数 L(t) 为:

(1)L(t)=i=1nl(yi,y^i(t1)+ft(xi))+Ω(ft)

其中:

  • 当前预测:之前 t1 棵树的预测结果 y^(t1) 加上当前这棵树的预测结果 ft(xi)
  • 标签: yi
  • 正则项:Ω(ft)=γT+12λj=1Twj2,其中 T 是当前第 t 棵树的叶子树,可以看到正则项只和当前这棵树有关,且与样本无关;
  • 每个样本的损失累加作为最终损失。

因为 y^(t1) 是已经确定的了,我们可以计算得到 l(yi,y^(t1))。所以,对 (1) 中的 ly^i(t1) 处做泰勒展开,得到了:

L(t)i=1n[l(yi,y^(t1))+gift(xi)+12hift2(xi)]+Ω(ft)

其中

gi=l(yi,y^(t1))y^(t1)hi=2l(yi,y^(t1))y^(t1)2

后面我们将直接忽略上面的近似等于符号,直接将损失函数视为展开后的形式。我们可以看到,现在损失被拆分为了三种项:

  1. t1 棵树的累加的预测结果的损失 l(yi,y^(t1))

  2. 一阶项 gift(xi),二阶项 12hift2(xi),其中一阶导和二阶导是已经确定的常数

    Caution

    1. 这俩都是损失函数对 y^(t1)的偏导乘以 ft(xi)。在泰勒展开原始形式中,Δx 就是此时的 ft(xi)

    2. 偏导是对先前的预测结果做偏导,而不是对样本做偏导;

  3. 正则化项;

可以看出,上面的项中,1 是已经确定的了,在训练第 t 棵树的时候,我们不会再去修改前面的 t1 棵树。我们在优化的时候无需考虑这个常数项。再将正则化项展开,得到了:

(2)L(t)=i=1n[gift(xi)+12hift2(xi)]+γT+12λj=1Twj2

叶子最优分数推导

(2) 式子中,我们可以看出,式子中存在样本的累加项和叶子节点的累加项。我们可以根据样本被分配到的叶子节点,建立起叶子节点和样本之间的联系:

  • 对于第 j 个叶子,其所包含的样本集合是 Ij,即 Ij={iq(xi)=j}q(x) 表示样本被划分到的叶子节点,即决策树的结构;
  • 对于第 i 个样本,如果它被划分到 Ij,则其预测值是该节点的分数 wj,即对于 iIj,我们有 ft(xi)=wj
  • 在之后,我们都用 i 作为样本的编号,用 j 作为叶子的编号;

我们可以将 (2) 进一步合并同类项,其思路是以叶子节点作为外部累加,对于每个叶子节点,将样本的预测值 ft(xi) 变成叶子的分数 wj,这样就和正则项有重合的项。于是,变成:

(3)L(t)=i=1n[gift(xi)+12hift2(xi)]+γT+12λj=1Twj2=j=1T[wjiIjgi+12wj2iIjhi]+γT+12λj=1Twj2=j=1T[wjiIjgi+12wj2(iIjhi+λ)]+γT

对于一个固定的结构 q(x),我们可以计算出 L(t) 对每个叶子节点分数 wj 的极值点,也就是另 L(t)wj 的偏导为 0。对 wj 的偏导形式可以非常简化,因为:

  1. 其他叶子节点的项与 wj 无关,求导后可以直接删去;
  2. 叶子结点数量与 wj 无关,可以直接删除;
  3. 一阶项求导只剩下系数;
  4. 二阶项求导后平方与 1/2 相乘变成 1

可得:

L(t)wj=iIjgi+wj(iIjhi+λ)

另其为 0,易得:

(4)wj=iIjgiiIjhi+λ

此即为叶子节点的最优分数推导结果,可以看到其与一些因素有关:

  1. 正则项中的 λ
  2. 损失函数对叶子中的所有节点先前预测的一阶导和二阶导。

叶子节点代入损失函数

将每个叶子结点的最优分数代入损失函数,也就是 (4) 代入 (3)

简化一些累加项方便计算:

  • Gj=iIjgi
  • Hj=iIjhi

可以得到:

(5)L(t)=j=1T[wjGj+12wj2(Hj+λ)]+γT=j=1T[(GjHj+λ)Gj+12(GjHj+λ)2(Hj+λ)]+γT=j=1T[(GjHj+λ)Gj+12(GjHj+λ)2(Hj+λ)]+γT=12j=1T[Gj2Hj+λ]+γT

故 XGBoost 第 t 棵树的损失函数是:

L(t)=12j=1TGj2Hj+λ+γT

可以看到,其损失函数只受到以下因素的影响:

  1. 损失函数对 y^(t1) 的偏导,一阶偏导越大,损失越大,二阶偏导越小,损失越小;
  2. 叶子数量越多,损失越大;

因此,其计算非常快,因为损失函数对 y^(t1) 的偏导是可以提前计算好的,只需要根据树当前的叶子节点内的样本集合,快速累加即可。

节点分裂后的损失

一个节点 I,分裂后变成 ILIR,即 I=ILIR,于是乎,对于某个节点,其分裂后,整棵树的损失变化为

Lsplit=12[(iILgi)2iILhi+λ+(iIRgi)2iIRhi+λ(iIgi)2iIhi+λ]γ

分裂后,叶子节点从 1 个变成了 2 个,所以要多减去一个 γ。哪个属性带来的分裂损失最小,就选择这个属性分裂。

参考资料

XGBoost: A Scalable Tree Boosting System