Pytorch Image Models (timm) 使用备忘

怎么查看某个模型的代码

1
site_packages/timm/models

添加代理

1
2
os.environ["http_proxy"] = "http://127.0.0.1:7890"
os.environ["https_proxy"] = "http://127.0.0.1:7890"

ViT 代码解读

File: timm.vision_transformer

Forward Pass

每个 Block 的 forward pass:

1
2
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))

self.attn 是注意力模块,ls 是 LayerScale,drop_path 是 DropPath

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def forward_features(self, x):
x = self.patch_embed(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
# 蒸馏时的 token
if self.dist_token is None:
x = torch.cat((cls_token, x), dim=1)
else:
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
# 经过每个 Transformer Block 后进行规范化
x = self.blocks(x)
x = self.norm(x)
if self.dist_token is None:
# self.pre_logits 学习输入最终MLP的表示
# 一般由一个线性连接层和一个激活函数组成
# 也可以直接使用 nn.Identity() 来跳过这一步
return self.pre_logits(x[:, 0])
else:
# x[:, 0] Classification token
# x[:, 1] Distillation token
return x[:, 0], x[:, 1]

def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None:
x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
if self.training and not torch.jit.is_scripting():
# during inference, return the average of both classifier predictions
return x, x_dist
else:
return (x + x_dist) / 2
else:
x = self.head(x)
return x

Attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5

self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x

dim: 输入到多头自注意力机制的每个 token 的特征数量。

num_heads: 多头自注意力中的头数。

qkv_bias: 控制是否为 Query、Key 和 Value(注意力机制中的三个主要向量)引入偏差项。偏差项是一个可学习的常数,用来解决对称性问题。

偏差项(bias)通常在神经网络中的线性变换中使用。它是一个与输入无关的可学习参数,加到线性变换的结果上。在注意力机制中,Query、Key 和 Value 向量是通过线性变换生成的(即:输入向量乘以权重矩阵),而偏差项就是这个过程中加到结果上的一部分。其目的是帮助模型在初始状态下避免对称性问题,使得每个神经元在初始时拥有不同的输出,即便输入是相同的,这有助于更好地训练网络。

什么是对称性问题?

神经元的初始化相同,且接收相同的输入,会导致输出以及反向传播的梯度相同,使得这些神经元在训练中始终保持相同,这就是所谓的对称性问题:神经元无法学习到不同的特征或表达方式,因为它们从头到尾都在执行相同的计算

qk_norm: 是否对 Query 和 Key 向量进行归一化。归一化可以帮助提高训练的稳定性,并减少梯度爆炸或消失的风险。

attn_drop: 注意力权重的 dropout 概率。在计算注意力分数之后,有时会对其应用 dropout 来防止模型过拟合,并增强模型的泛化能力。这个参数控制了应用 dropout 的概率。

proj_drop: 投影层的 dropout 概率。在计算完注意力权重并进行投影操作后,会对投影结果应用 dropout,这个参数控制该 dropout 的概率。

norm_layer: 控制规范化层的类型。 规范化层通常用来提高训练稳定性,常见的规范化方式有 LayerNorm 或 BatchNorm。这个参数定义了 Attention 模块中使用哪种类型的规范化操作。

LayerScale

1
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

self.ls1 是一个 LayerScale 机制,用来对自注意力的输出进行缩放,起到了层间归一化的作用。具体而言,LayerScale 引入了一个可学习的缩放系数,对自注意力的输出进行缩放后再输入 MLP,可以理解为一种可学习的归一化。如果不使用 LayerScale 则使用 nn.Identity(),即不进行缩放。

LayerScale 可以被定义为以下形式:

Output=F(x)A\text{Output} = F(x) A

其中:

  • $ F(x) $ 表示当前层中的变换(如自注意力层或 MLP 层)的输出;
  • $ A $ 是一个可学习的缩放因子对角矩阵,其维度与 $ F(x) $ 的输出维度相同;

DropPath

1
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

在 Transformer 或 ResNet 等架构中,计算流中的路径可以是残差连接、子模块之间的连接等。通过随机丢弃这些路径,可以让模型在训练过程中使用不同的计算路径,从而提高模型的泛化能力。

正常的残差连接:

 Output =x+F(x)\text { Output }=x+F(x)

Drop Path 后的:

 Output =x+pF(x)\text { Output }=x+p \cdot F(x)

pp 是一个在训练过程中随机生成的二值变量,它的取值为 0 或 1,表示当前的残差路径是否被丢弃。

Attention Dropout

在计算出 QK 的点积之后进行 Dropout 再乘以 V

1
2
3
4
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v

常见的数字及其含义

197: Patches 总数,196 + 1 (cls_token)

196: 14x14 14是 长宽/patchsize

64: 单个头的 q 的长度,k,v 同

12: heads 数量

768: 12 x 64,每个头的注意力图的 concat

训练

Vision Transformer (google/vit-base-patch16-224)在 ImageNet-21K (大约1400万张,2.1万个类)上预训练,然后再在 ISLVRC2012 (ImageNet-1k,训练集大约是1281167张+标签,验证集是50000张图片加标签,最终打分的测试集是10w 张图片,一共1000个类别) 上微调。

img

PiT

Pooling-based Transformer

image-20240921201705539

池化层和 transformer 组成一个块:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class Transformer(nn.Module):
def __init__(
self,
base_dim,
depth,
heads,
mlp_ratio,
pool=None,
proj_drop=.0,
attn_drop=.0,
drop_path_prob=None,
norm_layer=None,
):
super(Transformer, self).__init__()
embed_dim = base_dim * heads
# 池化层
self.pool = pool
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
# 普通的 Transformer 块,包含 attn 和 mlp
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim,
num_heads=heads,
mlp_ratio=mlp_ratio,
qkv_bias=True,
proj_drop=proj_drop,
attn_drop=attn_drop,
drop_path=drop_path_prob[i],
norm_layer=partial(nn.LayerNorm, eps=1e-6)
)
for i in range(depth)])

def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# 先池化一遍,然后再经过 Transformer Blocks
x, cls_tokens = x
token_length = cls_tokens.shape[1]
if self.pool is not None:
# Pool 内部对 x 进行池化,对 cls_token 进行线性映射
# x = self.conv(x)
# cls_token = self.fc(cls_token)
x, cls_tokens = self.pool(x, cls_tokens)

B, C, H, W = x.shape
# x.flatten(2) transforms x into shape (B, C, H * W), where the height and width are combined into a single dimension.
# .transpose(1, 2) swaps the channel and spatial dimensions to get shape (B, H * W, C) because transformers typically expect input where the second dimension represents the sequence length (here, the number of spatial tokens).
x = x.flatten(2).transpose(1, 2)
x = torch.cat((cls_tokens, x), dim=1)

x = self.norm(x)
x = self.blocks(x)

cls_tokens = x[:, :token_length]
x = x[:, token_length:]
x = x.transpose(1, 2).reshape(B, C, H, W)

return x, cls_tokens

多个上面的块结合成 PiT:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
...
for i in range(len(depth)):
pool = None
embed_dim = base_dims[i] * heads[i]
if i > 0: # 第 0 层的前面没有池化
pool = Pooling(
prev_dim,
embed_dim,
stride=2,
)
transformers += [Transformer(
base_dims[i],
depth[i], # 两个池化层之间的 Transformer Block 数量
heads[i],
mlp_ratio,
pool=pool,
proj_drop=proj_drop_rate,
attn_drop=attn_drop_rate,
drop_path_prob=dpr[i],
)]
prev_dim = embed_dim
self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')]
...

Cait

self.blocks 是 正常的 self attention

self.blocks_token_only 是 class attention

先经过 self attention,再经过 class attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class ClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to do CA
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5

self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x):
B, N, C = x.shape
# x[:, 0] 是 cls_token,在这部分中,q 是只由 cls_token 决定的
# q 的形状是 Batch x 1 x heads x head_length (C // heads)
# C 是 embedding 的长度
# 虽然 3 维的 q 也能完整地存储 cls_token 的 q,但将 q 作为 4 维可以方便后面 q@k 的运算,因为 k 一定是 4 维,k 多一个维度是为了存储各种 patch 的 k
# permute(0, 2, 1, 3) 后形状变成了 B x heads x 1 x head_length
q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# k 是正常的 self attention 的 k,是 cls_token 和 patch 共同决定的
# k,v: B x heads x N x head_length
k = self.k(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
q = q * self.scale
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# attn: B x heads x 1 x N
attn = (q @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
x_cls = self.proj(x_cls)
x_cls = self.proj_drop(x_cls)

return x_cls