Patch-fool: Are vision transformers always robust against adversarial perturbations?

Conclusion

ViT 的对抗鲁棒性并不总高于 CNN

每个 patch 的干扰像素的数量高度影响 ViT 和 CNN 之间的鲁棒性排名

提出了 Patch Fool 方法及其两个变种,总结就是对部分patch、patch 内的部分像素做干扰,在交叉熵损失函数后添加了一项用于提高所选patch对其它patch的注意力值的项

一些工作指出,在足够数量的数据集上训练的 ViT,在大部分的情况下至少与和它相对应的 ResNet 一样鲁棒,无论是在自然损坏、分布偏移还是对抗性干扰的场景中 [1]。在 Lp 攻击下,ViT 比 CNN 更加鲁棒 [2]。ViT 学习更少的低级特征并且更具有泛化性,使得他们更加鲁棒。如果在 ViT 中加入卷积块,则会降低它的鲁棒性 [3]。

[1] Bhojanapalli S, Chakrabarti A, Glasner D, et al. Understanding robustness of transformers for image classification[C]//Proceedings of the IEEE/CVF international conference on computer vision. 2021: 10231-10241.

[2] Aldahdooh, Ahmed, Wassim Hamidouche, and Olivier Deforges. “Reveal of vision transformers robustness against adversarial attacks.” arXiv preprint arXiv:2106.03734 (2021).

[3] Shao R, Shi Z, Yi J, et al. On the adversarial robustness of vision transformers[J]. arXiv preprint arXiv:2103.15670, 2021.

Patch Fool

Attack setup

符号:

  • 由多个 patch 组成的图片 X=[x1,,xn]R[n×d]n 是 patch 的数量,d 是每个 patch 的维度
  • 对抗性干扰 E
  • one hot vector 1pRn,只有第 p 个元素是 1
  • 损失函数 J

攻击的目标是:

argmax1pn,ERn×dJ(X+1pE,y)

其中 是 penetrating face product,就是用 1p 乘以 E 的每个列向量(相同维度下的逐元素相乘)理解:对于每个图片,不止需要确定我们要干扰哪个 patch,而且要确定我们的具体干扰。

根据 Attention 确定 1p

aj(l,h,i) 表示第 j 个 patch 在第 l 层上第 h 个头对第 i 个 patch 的注意力值。用 sj(l) 表示第 j 个 patch 在第 l 层对其它 patches 的注意力值的总和。在某一层,挑选最大的 sj(l) 中的 j 作为进行干扰的 patch。

sj(l)=h,iaj(l,h,i)

论文中,将 l 设置为 5,因为在后面的层中第 j 个 patch 可能混杂了很多其它 patch 的信息导致其不能代表我们输入时的那个 patch。

Attention-aware Loss

为了最大化选定的 patch 的影响,论文选择用一个损失函数来最大化这个 patch 对其它 patch 的注意力,也就是在 l 层中,我们要最大化以下的损失函数:

JATTN (l)(X,p)=h,iap(l,h,i)

将上面的损失与最大化分类差距的交叉熵 JCE 相结合,形成了:

J(X~,y,p)=JCE(X~,y)+αlJATTN(l)(X~,p)

为了解决这两个损失的梯度冲突,论文使用了 PCGrad 来更新干扰,也就是对于 E 的梯度 δE,我们计算其为:

δE=EJ(X~,y,p)αlβlEJCE(X~,y)βl={0,EJCE(X~,y),EJATTN(l)(X~,p)>0EJCE(X~,y),EJATTN (l)(X~,p)EJCE(X~,y)2, otherwise 

作者还进一步使用了 Adam 优化器来根据梯度确定更新 E 的方向

Et+1=Et+ηAdam(δEt)

Variant

需要对 patch 中的多少像素做干扰才能可以?

  • 只对很少的像素做干扰 (sparse patch fool)
  • 对整个 patch 做干扰(前文所述的 patch fool)

Sparse patch fool 使用 mask M 来决定只对哪些像素做干扰,并把它作为一个可学习的参数,在反向传播的时候,使用直通估计器来更新 M 中的元素,也就是使用损失对 (ME) 的梯度作为对 M 的梯度。

argmax1pn,ERn×d,M{0,1}n×dJ(X+1p(ME), y) s.t. M0k

MILD Patch-Fool

为了符合 LP 攻击方法的范式,可以在每次更新 E 后对其进行缩放(L2) 或 Clip (L\infin

Experiments

PR 为 patch 内干扰的像素比例,#Patch 为干扰的 patch 数量,指标是鲁班准确率。可以看到 DeiT 和 ResNet 在不同的 PR 和 #Patch 下的鲁棒性会有差异,没有谁恒优于谁。加粗的是比较差的。

image-20241021151856500

下面是 L\infin 的 MILD Patch-Fool 的结果

image-20241021152110118