Improving Robustness of Vision Transformers by Reducing Sensitivity to Patch Corruptions

Improving Robustness of Vision Transformers by Reducing Sensitivity to Patch Corruptions

Transformer 对 Patch 的敏感性

英文表述为 Sensitivity of Vision Transformers to Pathes.

生成对抗扰动是非常耗费计算资源的,尤其对于大规模的数据集这几乎是不可行的。一种方法直接引入 Corruptions。然而,一些 Corruptions 的方法作用有线,如使用加性的随机噪声(右 2),只能降低模型的 2.4% 的置信度,即便其用的是 ImageNet-C 中 Serverity 最高的 Corruption 级别。这可以归因于受到 Corruption 后的 Patches 仍然留有足够的信息。

image-20240603210827312

这篇论文提出了一个新的 corruption 方法,将一些 patch 直接替换为完全随机的噪声(右 1),置信度直接降低 46.5%。数学形式可以表达为原始样本 xx 、蒙版 M(x)M(x) 和噪声 δ\delta 的乘法操作:

x^=M(x)x+(1M(x))δ\hat x = M(x) \cdot x + (1-M(x)) \cdot \delta

这些损坏的 patch 可以显著地分散模型的中间注意力层的注意力,使得模型对这些图片的预测难度增加,进而可以用这些图片作为提高整体鲁棒性的手段。并且这种方法非常高效,不用耗费非常多的计算资源。作者还提到了,与直接丢弃某些 patch 相比,这种方法生成的图片更难被模型准确地预测。

image-20240604100457038

寻找易受攻击的 Patch 进行损坏

用下面的目标训练一个 Patch Corruption Model,来预测哪些块是易受攻击的。Fl(x)\mathcal F_l(x) 表示的是样本 xx 在 ViT 获得的第 ll 层的特征。

maxCExDLalign (x,x^)where Lalign (x,x^)=1Ll=1LFl(x)Fl(x^)2\begin{aligned} &\max _{\mathcal{C}} \mathbb{E}_{x \sim \mathcal{D}} \mathcal{L}_{\text {align }}(x, \hat{x}) \\ \text{where } &\mathcal{L}_{\text {align }}(x, \hat{x})=\frac{1}{L} \sum_{l=1}^L\left\|\mathcal{F}_l(x)-\mathcal{F}_l(\hat{x})\right\|^2 \end{aligned}

image-20240604103435186

也就是对于 clean image 和 corrupted image,patch corruption model 的目标就是生成蒙版并根据蒙版为图像添加噪声,使得 corrupted 的图像与原始图像的中间特征对齐。其实也就是要让 ViT 在中间特征上尽可能区别开原始图像和损坏图像,进而达到预测错误的目的。

Patch Corruption Model 由一个卷积层以及后面的全连接层组成,在得出结果后进行阈值二值化,阈值的确定来自于预先设定的损坏块的比例 ρ\rho

Tip

感觉存在的缺点:

  • 需要额外训练一个神经网络,并且这个神经网络的损失还需要通过 ViT 的多次推理来确定;
  • 直接要求 Lalign\mathcal L_{align} 最大的方法缺乏可解释性,可能会出现一些关键特征被 occlude 的情况,此时如果还要求模型按照原本的预测进行输出,是否不妥?

训练策略

在训练上,对抗训练 patch corruption model C\mathcal C 和 classification model F\mathcal F

minFmaxCExD[Lce(x)+λLalign (x,x^)]\min _{\mathcal{F}} \max _{\mathcal{C}} \mathbb{E}_{x \sim \mathcal{D}}\left[\mathcal{L}_{\mathrm{ce}}(x)+\lambda \mathcal{L}_{\text {align }}(x, \hat{x})\right]

λ\lambda 决定 Lalign\mathcal L_{align} 的重要性。先构造损坏样本,用损坏样本去更新分类器,然后再根据分类器来更新损坏模型。

image-20240604113917654

实验

数据集:CIFAR-10 CIFAR-100 ImageNet CIFAR-10-C CIFAR-100-C ImageNet-A ImageNet-C ImageNet-P

对比:

  • 鲁棒性较高的 RVT FAN
  • 与 CNN 比较采用的是 ResNet50

从参数量上考虑,使用了 RVT-S 和 FAN-S-Hybrid 来作为 baselines,因为他们与流行的 CNN 和 Transformer 模型的参数量相近。

实验设置:使用 DeepAugment 增强数据并训练 200 轮模型。对于 RVT FAN 分别在加入和不加入该论文的方法的情况下进行训练。

image-20240604130225686