EdgeNeXt is one of the most recent lightweight networks. Its overall structure is clean, the code is easy to reproduce, and its design ideas are worth studying. It applies self-attention along the channel dimension, which reduces the attention complexity to linear in the number of tokens, while depthwise separable convolutions build the spatial context, so the two kinds of information are modeled separately. The experimental gains are significant: the model keeps high accuracy while greatly reducing latency and computation.

# EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications

## Overall Architecture

(Figure: overall architecture of EdgeNeXt, from the paper)

## Conv Encoder

```python
import math

import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

# Note: LayerNorm here is the ConvNeXt-style LayerNorm (supporting channels_first/channels_last)
# and PositionalEncodingFourier is the Fourier positional encoding; both are defined elsewhere
# in the official EdgeNeXt repository and are assumed to be importable in this file.


class ConvEncoder(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7):
        super().__init__()
        # Depthwise conv builds spatial context; the pointwise (1x1) convs are implemented as Linear layers.
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim)
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, expan_ratio * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(expan_ratio * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x
```
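
Since the depthwise convolution uses `padding=kernel_size // 2` and the two pointwise layers expand to `expan_ratio * dim` and back, the block preserves both the spatial resolution and the channel count, so it can be stacked freely inside a stage. A minimal shape check of my own (not from the repo), assuming the class above and its `LayerNorm` dependency are importable:

```python
# Hypothetical smoke test, not part of the original code.
x = torch.randn(2, 48, 32, 32)          # (N, C, H, W)
block = ConvEncoder(dim=48, kernel_size=7)
y = block(x)
print(y.shape)                           # torch.Size([2, 48, 32, 32]) -- shape-preserving
```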

## SDTA Encoder

```python
class SDTAEncoder(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4,
                 use_pos_emb=True, num_heads=8, qkv_bias=True, attn_drop=0., drop=0., scales=1):
        super().__init__()
        # width: number of channels in each branch; scales: number of branches
        width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales)))
        self.width = width
        if scales == 1:
            self.nums = 1
        else:
            self.nums = scales - 1  # number of convolutions needed = number of branches - 1
        convs = []
        for i in range(self.nums):
            convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, groups=width))
        self.convs = nn.ModuleList(convs)

        self.pos_embd = None
        if use_pos_emb:
            self.pos_embd = PositionalEncodingFourier(dim=dim)
        self.norm_xca = LayerNorm(dim, eps=1e-6)
        self.gamma_xca = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                      requires_grad=True) if layer_scale_init_value > 0 else None
        self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, expan_ratio * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()  # TODO: MobileViT is using 'swish'
        self.pwconv2 = nn.Linear(expan_ratio * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x

        # Split the channels into branches and fuse them hierarchically with depthwise convs
        spx = torch.split(x, self.width, 1)
        # e.g. in Stage 2, self.nums = 1, so i only takes the value 0
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        x = torch.cat((out, spx[self.nums]), 1)

        # XCA: cross-covariance (channel-wise) attention over the flattened tokens
        B, C, H, W = x.shape
        x = x.reshape(B, C, H * W).permute(0, 2, 1)
        if self.pos_embd:
            pos_encoding = self.pos_embd(B, H, W).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding
        x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
        x = x.reshape(B, H, W, C)

        # Inverted bottleneck
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)

        return x


# Cross-covariance attention (XCA): self-attention along the channel dimension
class XCA(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # (H*W, C) -> (C, H*W): attend over channels instead of spatial tokens
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        # attention map is (C/h) x (C/h) per head, independent of the number of tokens
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'temperature'}
```
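
Because `q` and `k` are transposed before the dot product, the attention map has shape `(num_heads, C/h, C/h)` per sample instead of `(N, N)`, so its cost is independent of the number of tokens `N = H*W`; this is the linear-complexity channel attention mentioned in the introduction. A quick sanity check of my own (not from the repo; the numbers are purely illustrative and only the `XCA` class above is needed):

```python
# Hypothetical sanity check, not part of the original code.
xca = XCA(dim=64, num_heads=8)
tokens = torch.randn(2, 56 * 56, 64)     # (B, N, C) with N = 3136 spatial tokens
out = xca(tokens)
print(out.shape)                          # torch.Size([2, 3136, 64])
# Inside the block, attn has shape (2, 8, 8, 8): batch x heads x (C/h) x (C/h).
# It does not grow with N, unlike the (N, N) map of standard spatial self-attention.
```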

## EdgeNeXt

```python
class EdgeNeXt(nn.Module):
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[24, 48, 88, 168],
                 global_block=[0, 0, 0, 3], global_block_type=['None', 'None', 'None', 'SDTA'],
                 drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4,
                 kernel_sizes=[7, 7, 7, 7], heads=[8, 8, 8, 8], use_pos_embd_xca=[False, False, False, False],
                 use_pos_embd_global=False, d2_scales=[2, 3, 4, 5], **kwargs):
        super().__init__()
        for g in global_block_type:
            assert g in ['None', 'SDTA']
        if use_pos_embd_global:
            self.pos_embd = PositionalEncodingFourier(dim=dims[0])
        else:
            self.pos_embd = None

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage_blocks = []
            for j in range(depths[i]):
                # the last global_block[i] blocks of each stage are SDTA encoders, the rest are conv encoders
                if j > depths[i] - global_block[i] - 1:
                    if global_block_type[i] == 'SDTA':
                        stage_blocks.append(SDTAEncoder(dim=dims[i], drop_path=dp_rates[cur + j],
                                                        expan_ratio=expan_ratio, scales=d2_scales[i],
                                                        use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i]))
                    else:
                        raise NotImplementedError
                else:
                    stage_blocks.append(ConvEncoder(dim=dims[i], drop_path=dp_rates[cur + j],
                                                    layer_scale_init_value=layer_scale_init_value,
                                                    expan_ratio=expan_ratio, kernel_size=kernel_sizes[i]))

            self.stages.append(nn.Sequential(*stage_blocks))
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        self.head = nn.Linear(dims[-1], num_classes)

        self.apply(self._init_weights)
        self.head_dropout = nn.Dropout(kwargs["classifier_dropout"])  # expects classifier_dropout in kwargs
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):  # TODO: MobileViT is using 'kaiming_normal' for initializing conv layers
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x):
        x = self.downsample_layers[0](x)
        x = self.stages[0](x)
        if self.pos_embd:
            B, C, H, W = x.shape
            x = x + self.pos_embd(B, H, W)
        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)

        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(self.head_dropout(x))
        return x
```
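
With the constructor's default arguments (dims=[24, 48, 88, 168], depths=[3, 3, 9, 3], three SDTA encoders in the last stage) the model can be built and run end to end. A hedged smoke test of my own, not from the repo: `classifier_dropout` must be passed explicitly because `__init__` reads it from `kwargs`, and the custom `LayerNorm` (and, if positional embeddings are enabled, `PositionalEncodingFourier`) is assumed to come from the official repository.

```python
# Hypothetical usage sketch, not part of the original code.
model = EdgeNeXt(num_classes=1000, classifier_dropout=0.0)
model.eval()
img = torch.randn(1, 3, 256, 256)        # 256x256 input, the resolution used in the paper
with torch.no_grad():
    logits = model(img)
print(logits.shape)                       # torch.Size([1, 1000])
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```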

## Experiments

(Figures: experimental results reported in the paper)

Paper

source code