├── ConTNet.py
├── README.md
├── arch5.png
├── block2.png
├── block3.png
├── criterion.py
├── data.py
├── lr_scheduler.py
├── main.py
├── optimizer.py
└── utils.py
/ConTNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch
4 |
5 | from einops.layers.torch import Rearrange
6 | from einops import rearrange
7 |
8 | import numpy as np
9 |
10 | from typing import Any, List
11 | import math
12 | import warnings
13 | from collections import OrderedDict
14 |
15 | __all__ = ['ConTBlock', 'ConTNet']
16 |
17 |
18 | r""" The following trunc_normal method is pasted from timm https://github.com/rwightman/pytorch-image-models/tree/master/timm
19 | """
20 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
21 |
22 | # Cut & paste from PyTorch official master until it's in a few official releases - RW
23 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
24 | def norm_cdf(x):
25 | # Computes standard normal cumulative distribution function
26 | return (1. + math.erf(x / math.sqrt(2.))) / 2.
27 |
28 | if (mean < a - 2 * std) or (mean > b + 2 * std):
29 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
30 | "The distribution of values may be incorrect.",
31 | stacklevel=2)
32 |
33 | with torch.no_grad():
34 | # Values are generated by using a truncated uniform distribution and
35 | # then using the inverse CDF for the normal distribution.
36 | # Get upper and lower cdf values
37 | l = norm_cdf((a - mean) / std)
38 | u = norm_cdf((b - mean) / std)
39 |
40 | # Uniformly fill tensor with values from [l, u], then translate to
41 | # [2l-1, 2u-1].
42 | tensor.uniform_(2 * l - 1, 2 * u - 1)
43 |
44 | # Use inverse cdf transform for normal distribution to get truncated
45 | # standard normal
46 | tensor.erfinv_()
47 |
48 | # Transform to proper mean, std
49 | tensor.mul_(std * math.sqrt(2.))
50 | tensor.add_(mean)
51 |
52 | # Clamp to ensure it's in the proper range
53 | tensor.clamp_(min=a, max=b)
54 | return tensor
55 |
56 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
57 | # type: (Tensor, float, float, float, float) -> Tensor
58 | r"""Fills the input Tensor with values drawn from a truncated
59 | normal distribution. The values are effectively drawn from the
60 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
61 | with values outside :math:`[a, b]` redrawn until they are within
62 | the bounds. The method used for generating the random values works
63 | best when :math:`a \leq \text{mean} \leq b`.
64 | Args:
65 | tensor: an n-dimensional `torch.Tensor`
66 | mean: the mean of the normal distribution
67 | std: the standard deviation of the normal distribution
68 | a: the minimum cutoff value
69 | b: the maximum cutoff value
70 | Examples:
71 | >>> w = torch.empty(3, 5)
72 | >>> nn.init.trunc_normal_(w)
73 | """
74 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
75 |
76 | def fixed_padding(inputs, kernel_size, dilation):
77 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
78 | pad_total = kernel_size_effective - 1
79 | pad_beg = pad_total // 2
80 | pad_end = pad_total - pad_beg
81 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end))
82 | return padded_inputs
83 |
84 | class ConvBN(nn.Sequential):
85 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1, bn=True):
86 | padding = (kernel_size - 1) // 2
87 | if bn:
88 | super(ConvBN, self).__init__(OrderedDict([
89 | ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
90 | padding=padding, groups=groups, bias=False)),
91 | ('bn', nn.BatchNorm2d(out_planes))
92 | ]))
93 | else:
94 | super(ConvBN, self).__init__(OrderedDict([
95 | ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride,
96 | padding=padding, groups=groups, bias=False)),
97 | ]))
98 |
99 | class MHSA(nn.Module):
100 | r"""
101 | Build a Multi-Head Self-Attention:
102 | - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
103 | """
104 | def __init__(self,
105 | planes,
106 | head_num,
107 | dropout,
108 | patch_size,
109 | qkv_bias,
110 | relative):
111 | super(MHSA, self).__init__()
112 | self.head_num = head_num
113 | head_dim = planes // head_num
114 | self.qkv = nn.Linear(planes, 3*planes, bias=qkv_bias)
115 | self.relative = relative
116 | self.patch_size = patch_size
117 | self.scale = head_dim ** -0.5
118 |
119 | if self.relative:
120 | # print('### relative position embedding ###')
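            # One learnable bias per head is stored for each of the (2*patch_size-1)^2
            # possible relative offsets between two positions inside a patch; the index
            # built below maps every (query, key) position pair to its row in that table
            # (the same construction as window attention in Swin Transformer).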
121 | self.relative_position_bias_table = nn.Parameter(
122 | torch.zeros((2 * patch_size - 1) * (2 * patch_size - 1), head_num))
123 | coords_w = coords_h = torch.arange(patch_size)
124 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
125 | coords_flatten = torch.flatten(coords, 1)
126 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
127 | relative_coords = relative_coords.permute(1, 2, 0).contiguous()
128 | relative_coords[:, :, 0] += patch_size - 1
129 | relative_coords[:, :, 1] += patch_size - 1
130 | relative_coords[:, :, 0] *= 2 * patch_size - 1
131 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
132 | self.register_buffer("relative_position_index", relative_position_index)
133 | trunc_normal_(self.relative_position_bias_table, std=.02)
134 |
135 | self.attn_drop = nn.Dropout(p=dropout)
136 | self.proj = nn.Linear(planes, planes)
137 | self.proj_drop = nn.Dropout(p=dropout)
138 |
139 | def forward(self, x):
140 | B, N, C, H = *x.shape, self.head_num
141 | # print(x.shape)
142 |         qkv = self.qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4) # qkv: (3, B, H, N, C//H)
143 |         q, k, v = qkv[0], qkv[1], qkv[2] # q, k, v: (B, H, N, C//H)
144 |
145 | q = q * self.scale
146 | attn = (q @ k.transpose(-2, -1)) # attn: (B, H, N, N)
147 |
148 | if self.relative:
149 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
150 | self.patch_size ** 2, self.patch_size ** 2, -1)
151 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
152 | attn = attn + relative_position_bias.unsqueeze(0)
153 |
154 | attn = attn.softmax(dim=-1)
155 | attn = self.attn_drop(attn)
156 |
157 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
158 | x = self.proj(x)
159 | x = self.proj_drop(x)
160 |
161 | return x
162 |
163 | class MLP(nn.Module):
164 | r"""
165 | Build a Multi-Layer Perceptron
166 | - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
167 | """
168 | def __init__(self,
169 | planes,
170 | mlp_dim,
171 | dropout):
172 | super(MLP, self).__init__()
173 |
174 | self.fc1 = nn.Linear(planes, mlp_dim)
175 | self.act = nn.GELU()
176 | self.fc2 = nn.Linear(mlp_dim, planes)
177 | self.drop = nn.Dropout(dropout)
178 |
179 | def forward(self, x):
180 | x = self.fc1(x)
181 | x = self.act(x)
182 | x = self.drop(x)
183 | x = self.fc2(x)
184 | x = self.drop(x)
185 |
186 | return x
187 |
188 |
189 | class STE(nn.Module):
190 | r"""
191 | Build a Standard Transformer Encoder(STE)
192 | input: Tensor (b, c, h, w)
193 | output: Tensor (b, c, h, w)
194 | """
195 | def __init__(self,
196 | planes: int,
197 | mlp_dim: int,
198 | head_num: int,
199 | dropout: float,
200 | patch_size: int,
201 | relative: bool,
202 | qkv_bias: bool,
203 | pre_norm: bool,
204 | **kwargs):
205 | super(STE, self).__init__()
206 | self.patch_size = patch_size
207 | self.pre_norm = pre_norm
208 | self.relative = relative
209 |
210 | self.flatten = nn.Sequential(
211 | Rearrange('b c pnh pnw psh psw -> (b pnh pnw) psh psw c'),
212 | )
213 | if not relative:
214 | self.pe = nn.ParameterList(
215 | [nn.Parameter(torch.zeros(1, patch_size, 1, planes//2)), nn.Parameter(torch.zeros(1, 1, patch_size, planes//2))]
216 | )
217 | self.attn = MHSA(planes, head_num, dropout, patch_size, qkv_bias=qkv_bias, relative=relative)
218 | self.mlp = MLP(planes, mlp_dim, dropout=dropout)
219 | self.norm1 = nn.LayerNorm(planes)
220 | self.norm2 = nn.LayerNorm(planes)
221 |
222 | def forward(self, x):
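        # Partition the (b, c, h, w) feature map into non-overlapping patch_size x patch_size
        # windows, apply the transformer encoder (MHSA + MLP) inside each window, then fold
        # the windows back into a (b, c, h, w) feature map.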
223 | bs, c, h, w = x.shape
224 | patch_size = self.patch_size
225 | patch_num_h, patch_num_w = h // patch_size, w // patch_size
226 |
227 | x = (
228 | x.unfold(2, self.patch_size, self.patch_size)
229 | .unfold(3, self.patch_size, self.patch_size)
230 | ) # x: (b, c, patch_num, patch_num, patch_size, patch_size)
231 |         x = self.flatten(x) # x: (b*patch_num_h*patch_num_w, patch_size, patch_size, c)
232 |         ### add 2d position embedding ###
233 |         if not self.relative:
234 |             x_h, x_w = x.split(c // 2, dim=3)
235 |             x = torch.cat((x_h + self.pe[0], x_w + self.pe[1]), dim=3) # same shape; first half of channels adds the row embedding, second half the column embedding
236 |
237 | x = rearrange(x, 'b psh psw c -> b (psh psw) c')
238 |
239 | if self.pre_norm:
240 | x = x + self.attn(self.norm1(x))
241 | x = x + self.mlp(self.norm2(x))
242 | else:
243 | x = self.norm1(x + self.attn(x))
244 | x = self.norm2(x + self.mlp(x))
245 |
246 | x = rearrange(x, '(b pnh pnw) (psh psw) c -> b c (pnh psh) (pnw psw)', pnh=patch_num_h, pnw=patch_num_w, psh=patch_size, psw=patch_size)
247 |
248 | return x
249 |
250 | class ConTBlock(nn.Module):
251 | r"""
252 | Build a ConTBlock
253 | """
254 | def __init__(self,
255 | planes: int,
256 | out_planes: int,
257 | mlp_dim: int,
258 | head_num: int,
259 | dropout: float,
260 | patch_size: List[int],
261 | downsample: nn.Module = None,
262 | stride: int=1,
263 | last_dropout: float=0.3,
264 | **kwargs):
265 | super(ConTBlock, self).__init__()
266 | self.downsample = downsample
267 | self.identity = nn.Identity()
268 | self.dropout = nn.Identity()
269 |
270 | self.bn = nn.BatchNorm2d(planes)
271 | self.relu = nn.ReLU(inplace=True)
272 | self.ste1 = STE(planes=planes, mlp_dim=mlp_dim, head_num=head_num, dropout=dropout, patch_size=patch_size[0], **kwargs)
273 | self.ste2 = STE(planes=planes, mlp_dim=mlp_dim, head_num=head_num, dropout=dropout, patch_size=patch_size[1], **kwargs)
274 |
275 | if stride == 1 and downsample is not None:
276 | self.dropout = nn.Dropout(p=last_dropout)
277 | kernel_size = 1
278 | else:
279 | kernel_size = 3
280 |
281 | self.out_conv = ConvBN(planes, out_planes, kernel_size, stride, bn=False)
282 |
283 | def forward(self, x):
284 | x_preact = self.relu(self.bn(x))
285 | identity = self.identity(x)
286 |
287 | if self.downsample is not None:
288 | identity = self.downsample(x_preact)
289 |
290 | residual = self.ste1(x_preact)
291 | residual = self.ste2(residual)
292 | residual = self.out_conv(residual)
293 | out = self.dropout(residual+identity)
294 |
295 | return out
296 |
297 | class ConTNet(nn.Module):
298 | r"""
299 | Build a ConTNet backbone
300 | """
301 | def __init__(self,
302 | block,
303 | layers: List[int],
304 | mlp_dim: List[int],
305 | head_num: List[int],
306 | dropout: List[float],
307 | in_channels: int=3,
308 | inplanes: int=64,
309 | num_classes: int=1000,
310 | init_weights: bool=True,
311 | first_embedding: bool=False,
312 | tweak_C: bool=False,
313 | **kwargs):
314 | r"""
315 | Args:
316 | block: ConT Block
317 | layers: number of blocks at each layer
318 | mlp_dim: dimension of mlp in each stage
319 |             head_num: number of attention heads in each stage
320 |             dropout: dropout rate used in the last two stages
321 |             relative: if True, relative position embedding is used
322 |             groups: number of groups in each conv layer of the network
323 |             depthwise: if True, depthwise convolution is adopted
324 |             in_channels: number of channels of the input image
325 |             inplanes: number of channels of the first convolution layer
326 |             num_classes: number of classes for the classification task
327 |                 (the final FC layer maps features to this many classes)
328 |             init_weights: if True, conv/linear/norm layers are initialized
329 |                 (kaiming-normal for conv, truncated normal for linear)
330 |             first_embedding: if True, a patch-embedding conv with kernel size and stride of 4 is used as the stem
331 |             tweak_C: if True, the ResNet-C stem (three stacked 3x3 convs) replaces the original 7x7 conv stem
332 | """
333 |
334 | super(ConTNet, self).__init__()
335 | self.inplanes = inplanes
336 | self.block = block
337 |
338 | # build the top layer
339 | if tweak_C:
340 | self.layer0 = nn.Sequential(OrderedDict([
341 | ('conv_bn1', ConvBN(in_channels, inplanes//2, kernel_size=3, stride=2)),
342 | ('relu1', nn.ReLU(inplace=True)),
343 | ('conv_bn2', ConvBN(inplanes//2, inplanes//2, kernel_size=3, stride=1)),
344 | ('relu2', nn.ReLU(inplace=True)),
345 | ('conv_bn3', ConvBN(inplanes//2, inplanes, kernel_size=3, stride=1)),
346 | ('relu3', nn.ReLU(inplace=True)),
347 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
348 | ]))
349 | elif first_embedding:
350 | self.layer0 = nn.Sequential(OrderedDict([
351 | ('conv', nn.Conv2d(in_channels, inplanes, kernel_size=4, stride=4)),
352 | ('norm', nn.LayerNorm(inplanes))
353 | ]))
354 | else:
355 | self.layer0 = nn.Sequential(OrderedDict([
356 | ('conv', ConvBN(in_channels, inplanes, kernel_size=7, stride=2, bn=False)),
357 | # ('relu', nn.ReLU(inplace=True)),
358 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
359 | ]))
360 |
361 | # build cont layers
362 | self.cont_layers = []
363 | self.out_channels = OrderedDict()
364 |
365 | for i in range(len(layers)):
366 |             stride = 2
367 | patch_size = [7,14]
368 | if i == len(layers)-1:
369 | stride, patch_size[1] = 1, 7 # the last stage does not conduct downsampling
370 | cont_layer = self._make_layer(inplanes * 2**i, layers[i], stride=stride, mlp_dim=mlp_dim[i], head_num=head_num[i], dropout=dropout[i], patch_size=patch_size, **kwargs)
371 | layer_name = 'layer{}'.format(i + 1)
372 | self.add_module(layer_name, cont_layer)
373 | self.cont_layers.append(layer_name)
374 | self.out_channels[layer_name] = 2 * inplanes * 2**i
375 |
376 | self.last_out_channels = next(reversed(self.out_channels.values()))
377 | self.fc = nn.Linear(self.last_out_channels, num_classes)
378 |
379 | if init_weights:
380 | self._initialize_weights()
381 |
382 | def _make_layer(self,
383 | planes: int,
384 | blocks: int,
385 | stride: int,
386 | mlp_dim: int,
387 | head_num: int,
388 | dropout: float,
389 | patch_size: List[int],
390 | use_avgdown: bool=False,
391 | **kwargs):
392 |
393 | layers = OrderedDict()
394 | for i in range(0, blocks-1):
395 | layers[f'{self.block.__name__}{i}'] = self.block(
396 | planes, planes, mlp_dim, head_num, dropout, patch_size, **kwargs)
397 |
398 | downsample = None
399 | if stride != 1:
400 | if use_avgdown:
401 | downsample = nn.Sequential(OrderedDict([
402 | ('avgpool', nn.AvgPool2d(kernel_size=2, stride=2)),
403 | ('conv', ConvBN(planes, planes * 2, kernel_size=1, stride=1, bn=False))]))
404 | else:
405 | downsample = ConvBN(planes, planes * 2, kernel_size=1,
406 | stride=2, bn=False)
407 | else:
408 | downsample = ConvBN(planes, planes * 2, kernel_size=1, stride=1, bn=False)
409 |
410 | layers[f'{self.block.__name__}{blocks-1}'] = self.block(
411 | planes, planes*2, mlp_dim, head_num, dropout, patch_size, downsample, stride, **kwargs)
412 |
413 | return nn.Sequential(layers)
414 |
415 | def _initialize_weights(self):
416 | for m in self.modules():
417 | if isinstance(m, nn.Conv2d):
418 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
419 | elif isinstance(m, nn.Linear):
420 | trunc_normal_(m.weight, std=.02)
421 | if isinstance(m, nn.Linear) and m.bias is not None:
422 | nn.init.constant_(m.bias, 0)
423 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm):
424 | nn.init.constant_(m.weight, 1)
425 | nn.init.constant_(m.bias, 0)
426 |
427 |
428 | def forward(self, x):
429 | x = self.layer0(x)
430 |
431 | for _, layer_name in enumerate(self.cont_layers):
432 | cont_layer = getattr(self, layer_name)
433 | x = cont_layer(x)
434 |
435 | x = x.mean([2, 3])
436 | x = self.fc(x)
437 |
438 | return x
439 |
440 | def create_ConTNet_Ti(kwargs):
441 | return ConTNet(block=ConTBlock,
442 | mlp_dim=[196, 392, 768, 768],
443 | head_num=[1, 2, 4, 8],
444 | dropout=[0,0,0,0],
445 | inplanes=48,
446 | layers=[1,1,1,1],
447 | last_dropout=0,
448 | **kwargs)
449 |
450 | def create_ConTNet_S(kwargs):
451 | return ConTNet(block=ConTBlock,
452 | mlp_dim=[256, 512, 1024, 1024],
453 | head_num=[1, 2, 4, 8],
454 | dropout=[0,0,0,0],
455 | inplanes=64,
456 | layers=[1,1,1,1],
457 | last_dropout=0,
458 | **kwargs)
459 |
460 | def create_ConTNet_M(kwargs):
461 | return ConTNet(block=ConTBlock,
462 | mlp_dim=[256, 512, 1024, 1024],
463 | head_num=[1, 2, 4, 8],
464 | dropout=[0,0,0,0],
465 | inplanes=64,
466 | layers=[2,2,2,2],
467 | last_dropout=0,
468 | **kwargs)
469 |
470 | def create_ConTNet_B(kwargs):
471 | return ConTNet(block=ConTBlock,
472 | mlp_dim=[256, 512, 1024, 1024],
473 | head_num=[1, 2, 4, 8],
474 | dropout=[0,0,0.1,0.1],
475 | inplanes=64,
476 | layers=[3,4,6,3],
477 | last_dropout=0.2,
478 | **kwargs)
479 |
480 | def build_model(arch, use_avgdown, relative, qkv_bias, pre_norm):
481 |     arch_type = arch.split('-')[-1]  # e.g. 'ConT-M' -> 'M'
482 |     func = globals()[f'create_ConTNet_{arch_type}']
483 | kwargs = dict(use_avgdown=use_avgdown, relative=relative, qkv_bias=qkv_bias, pre_norm=pre_norm)
484 | return func(kwargs)
485 |
486 | if __name__ == "__main__":
487 | model = build_model(arch='ConT-Ti', use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)
488 |     input = torch.randn(4, 3, 224, 224)
489 | print(model)
490 | out = model(input)
491 | print(out.shape)
492 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConTNet
2 |
3 | ## Introduction
4 |
5 |
7 |
8 | **ConTNet** (**Con**volution-**T**ransformer Network) is a neural network built by stacking convolutional layers and transformers alternately. This architecture is proposed in response to the following two issues: **(1)** The receptive field of convolution is limited by a local window (e.g. 3x3), which potentially impairs the performance of ConvNets on downstream tasks. **(2)** Transformer-based models suffer from insufficient robustness; as a result, training them requires many tricks and heavy regularization. In ConTNet, these drawbacks are alleviated by combining convolution and transformer. The motivation can be understood from two perspectives. **From the view of ConvNet**, the transformer sub-layer is inserted between two conv layers to enhance the non-local interactions of the ConvNet. **From the view of Transformer**, the convolution layers reintroduce the inductive bias that pure transformer models lack. Through numerical experiments, we find that ConTNet achieves competitive performance on image recognition and downstream tasks. More notably, ConTNet can be optimized easily, even in the same way as ResNet.
9 |
10 | 
11 | 
12 | 
13 | ## Training & Validation with this Repo
14 | We give an example of single-machine multi-GPU training.
15 | ```
16 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 29501 main.py --arch ConT-M --batch_size 256 --save_path debug_trial_cont_m --save_best True
17 | ```
18 | To validate a model, add the arg ```--eval``` with the path to the checkpoint to evaluate.
19 | ```
20 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --master_port 29501 main.py --arch ConT-M --batch_size 256 --save_path debug_trial --eval ./debug_trial_cont_m/checkpoint_bestTop1.pth
21 | ```
22 | To resume training, add the arg ```--resume``` with the path to the checkpoint.
23 | ```
24 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 29501 main.py --arch ConT-M --batch_size 256 --save_path debug_trial --save_best True --resume ./debug_trial_cont_m/checkpoint_bestTop1.pth
25 | ```
26 | ## Pretrained Weights on ImageNet
27 | ImageNet-pretrained weights are available from [Google Drive][1] or [Baidu Cloud][2] (extraction code: 3k3s).
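A minimal loading sketch, assuming the released checkpoints follow the same format written by `utils.save_model` (the state dict of the DDP-wrapped model stored under the `'model'` key) and that the model hyper-parameters match those used for pretraining; the file name below is a placeholder:

```python
import torch
from ConTNet import build_model

model = build_model(arch='ConT-M', use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)
ckpt = torch.load('checkpoint_bestTop1.pth', map_location='cpu')  # placeholder path
state_dict = ckpt.get('model', ckpt)
# strip the 'module.' prefix added by DistributedDataParallel, if present
state_dict = {(k[len('module.'):] if k.startswith('module.') else k): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()
```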
28 |
29 | ## Main Results on ImageNet
30 |
31 | | name | resolution | acc@1 | #params(M) | FLOPs(G) | model |
32 | | ---- | ---- | ---- | ---- | ---- | ---- |
33 | | Res-18 | 224x224 | 71.5 | 11.7 | 1.8 | |
34 | | ConT-S | 224x224 | **74.9** | 10.1 | 1.5 | |
35 | | Res-50 | 224x224 | 77.1 | 25.6 | 4.0 | |
36 | | ConT-M | 224x224 | **77.6** | 19.2 | 3.1 | |
37 | | Res-101 | 224x224 | **78.2** | 44.5 | 7.6 | |
38 | | ConT-B | 224x224 | 77.9 | 39.6 | 6.4 | |
39 | | DeiT-Ti* | 224x224 | 72.2 | 5.7 | 1.3 | |
40 | | ConT-Ti* | 224x224 | **74.9**| 5.8 | 0.8 | |
41 | | Res-18* | 224x224 | 73.2 | 11.7 | 1.8 | |
42 | | ConT-S* | 224x224 | **76.5** | 10.1 | 1.5 | |
43 | | Res-50* | 224x224 | 78.6 | 25.6 | 4.0 | |
44 | | DeiT-S* | 224x224 | 79.8 | 22.1 | 4.6 | |
45 | | ConT-M* | 224x224 | **80.2** | 19.2 | 3.1 | |
46 | | Res-101* | 224x224 | 80.0 | 44.5 | 7.6 | |
47 | | DeiT-B* | 224x224 | **81.8** | 86.6 | 17.6| |
48 | | ConT-B* | 224x224 | **81.8** | 39.6 | 6.4 | |
49 |
50 | Note: * indicates training with strong augmentations (auto-augmentation and mixup).
51 |
52 | ## Main Results on Downstream Tasks
53 |
54 | Object detection results on COCO.
55 |
56 | | method | backbone | #params(M) | FLOPs(G) | AP | APs | APm | APl |
57 | | ---- | ---- | ---- | ---- | ---- | -------- | ----- | ----- |
58 | | RetinaNet | Res-50<br>ConTNet-M | 32.0<br>27.0 | 235.6<br>217.2 | 36.5<br>**37.9** | 20.4<br>**23.0** | 40.3<br>**40.6** | 48.1<br>**50.4** |
59 | | FCOS | Res-50<br>ConTNet-M | 32.2<br>27.2 | 242.9<br>228.4 | 38.7<br>**40.8** | 22.9<br>**25.1** | 42.5<br>**44.6** | 50.1<br>**53.0** |
60 | | Faster R-CNN | Res-50<br>ConTNet-M | 41.5<br>36.6 | 241.0<br>225.6 | 37.4<br>**40.0** | 21.2<br>**25.4** | 41.0<br>**43.0** | 48.1<br>**52.0** |
61 |
62 | Instance segmentation results on Cityscapes based on Mask R-CNN.
63 | | backbone | AP<sup>bb</sup> | AP<sub>s</sub><sup>bb</sup> | AP<sub>m</sub><sup>bb</sup> | AP<sub>l</sub><sup>bb</sup> | AP<sup>mk</sup> | AP<sub>s</sub><sup>mk</sup> | AP<sub>m</sub><sup>mk</sup> | AP<sub>l</sub><sup>mk</sup> |
64 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
65 | | Res-50<br>ConT-M | 38.2<br>**40.5** | 21.9<br>**25.1** | 40.9<br>**44.4** | 49.5<br>**52.7** | 34.7<br>**38.1** | 18.3<br>**20.9** | 37.4<br>**41.0** | 47.2<br>**50.3** |
66 |
67 | Semantic segmentation results on Cityscapes.
68 | | model | mIOU |
69 | | ----- | ---- |
70 | |PSP-Res50| 77.12 |
71 | |PSP-ConTM| **78.28** |
72 |
73 | ## Citation
74 | ```
75 | @article{yan2021contnet,
76 | title={ConTNet: Why not use convolution and transformer at the same time?},
77 | author={Haotian Yan and Zhe Li and Weijian Li and Changhu Wang and Ming Wu and Chuang Zhang},
78 | year={2021},
79 | journal={arXiv preprint arXiv:2104.13497}
80 | }
81 | ```
82 |
83 | [1]: https://drive.google.com/drive/folders/1ZXu--Bis3LTYLjf2pkmDtZH0TjuWWamO?usp=sharing
84 | [2]: https://pan.baidu.com/s/1thKK36jTFln1KcAuEkzleg
85 |
--------------------------------------------------------------------------------
/arch5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yan-hao-tian/ConTNet/a3699f49f5afbb9a9b264e9de270405ddef82f54/arch5.png
--------------------------------------------------------------------------------
/block2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yan-hao-tian/ConTNet/a3699f49f5afbb9a9b264e9de270405ddef82f54/block2.png
--------------------------------------------------------------------------------
/block3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yan-hao-tian/ConTNet/a3699f49f5afbb9a9b264e9de270405ddef82f54/block3.png
--------------------------------------------------------------------------------
/criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
3 | from timm.data import Mixup
4 |
5 |
6 | def build_criterion(mixup, label_smoothing):
7 | mixup_fn = None
8 | if mixup > 0.:
9 | criterion = SoftTargetCrossEntropy()
10 |
11 | mixup_fn = Mixup(
12 | mixup_alpha=mixup, cutmix_alpha=1, cutmix_minmax=None,
13 | prob=1, switch_prob=0.5, mode='batch',
14 | label_smoothing=label_smoothing, num_classes=1000)
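        # Mixup/CutMix turn hard labels into soft label distributions, so the
        # SoftTargetCrossEntropy above is required; label smoothing is folded into
        # the mixed targets by Mixup itself.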
15 |
16 | elif label_smoothing > 0.:
17 | criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
18 | else:
19 | criterion = nn.CrossEntropyLoss()
20 |
21 | return criterion, mixup_fn
22 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | from torchvision import datasets, transforms
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader, distributed
7 |
8 | from timm.data import create_transform
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 |
11 | import torchvision
12 |
13 | import os
14 | import json
15 | from PIL import Image
16 | import pandas as pd
17 | from torch.utils.data import Dataset
18 |
19 | class MyDataSet(Dataset):
20 |
21 | def __init__(self,
22 | root_dir: str,
23 | csv_name: str,
24 | json_path: str,
25 | transform=None):
26 | images_dir = os.path.join(root_dir, "images")
27 | assert os.path.exists(images_dir), "dir:'{}' not found.".format(images_dir)
28 |
29 | assert os.path.exists(json_path), "file:'{}' not found.".format(json_path)
30 | self.label_dict = json.load(open(json_path, "r"))
31 |
32 | csv_path = os.path.join(root_dir, csv_name)
33 | assert os.path.exists(csv_path), "file:'{}' not found.".format(csv_path)
34 | csv_data = pd.read_csv(csv_path)
35 | self.total_num = csv_data.shape[0]
36 | self.img_paths = [os.path.join(images_dir, i)for i in csv_data["filename"].values]
37 | self.img_label = [self.label_dict[i][0] for i in csv_data["label"].values]
38 | self.labels = set(csv_data["label"].values)
39 |
40 | self.transform = transform
41 |
42 | def __len__(self):
43 | return self.total_num
44 |
45 | def __getitem__(self, item):
46 | img = Image.open(self.img_paths[item])
47 | if img.mode != 'RGB':
48 | raise ValueError("image: {} isn't RGB mode.".format(self.img_paths[item]))
49 | label = self.img_label[item]
50 |
51 | if self.transform is not None:
52 | img = self.transform(img)
53 |
54 | return img, label
55 |
56 | @staticmethod
57 | def collate_fn(batch):
58 | images, labels = tuple(zip(*batch))
59 |
60 | images = torch.stack(images, dim=0)
61 | labels = torch.as_tensor(labels)
62 | return images, labels
63 |
64 |
65 | def build_loader(data_path, autoaug, batch_size, workers):
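    # Build DDP-aware ImageNet loaders: `batch_size` is the global batch size, so each
    # process receives batch_size // world_size samples per step from its DistributedSampler shard.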
66 |
67 | rank = dist.get_rank()
68 | world_size = dist.get_world_size()
69 |     assert batch_size % world_size == 0, f'batch size {batch_size} must be divisible by world size {world_size}'
70 |
71 | train_transform = create_transform(input_size=224,
72 | is_training=True,
73 | auto_augment=autoaug)
74 | # train_dataset = MyDataSet(root_dir='./mini-imagenet', csv_name='new_train.csv', json_path='./mini-imagenet/classes_name.json', transform=train_transform)
75 | train_dataset = datasets.ImageFolder(osp.join(data_path, 'train'), transform=train_transform)
76 |
77 | train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
78 | train_loader = DataLoader(train_dataset,
79 | batch_size=batch_size // world_size,
80 | shuffle=False,
81 | num_workers=workers,
82 | pin_memory=True,
83 | sampler=train_sampler)
84 |
85 | val_transform = transforms.Compose([
86 | transforms.Resize(256),
87 | transforms.CenterCrop(224),
88 | transforms.ToTensor(),
89 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
90 | ])
91 | val_dataset = datasets.ImageFolder(osp.join(data_path, 'val'), transform=val_transform)
92 | # val_dataset = MyDataSet(root_dir='./mini-imagenet', csv_name='new_val.csv', json_path='./mini-imagenet/classes_name.json', transform=val_transform)
93 | val_sampler = distributed.DistributedSampler(val_dataset, world_size, rank)
94 | val_loader = DataLoader(val_dataset,
95 | batch_size=batch_size // world_size,
96 | shuffle=False,
97 | num_workers=workers,
98 | pin_memory=True,
99 | sampler=val_sampler)
100 |
101 | return train_loader, val_loader
102 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from timm.scheduler.cosine_lr import CosineLRScheduler
3 |
4 | def build_lr_scheduler(epoch, warmup_epoch, optimizer, n_iter_per_epoch):
5 | num_steps = int(epoch * n_iter_per_epoch)
6 | warmup_steps = int(warmup_epoch * n_iter_per_epoch)
7 |
8 | scheduler = CosineLRScheduler(
9 | optimizer,
10 | t_initial=num_steps,
11 | t_mul=1.,
12 | lr_min=0,
13 | warmup_lr_init=0,
14 | warmup_t=warmup_steps,
15 | cycle_limit=1,
16 | t_in_epochs=False,
17 | )
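    # With t_in_epochs=False the schedule is expressed in optimizer steps, so the training
    # loop advances it once per iteration via scheduler.step_update(global_step) (see main.py).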
18 |
19 | return scheduler
20 |
21 |
22 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import argparse
4 |
5 | import numpy as np
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from torch.nn.parallel import DistributedDataParallel as DDP
10 |
11 | from ConTNet import build_model
12 | from optimizer import build_optimizer
13 | from lr_scheduler import build_lr_scheduler
14 | from criterion import build_criterion
15 | from data import build_loader
16 |
17 | from utils import accuracy, reduce_tensor, resume_model, save_model
18 | from timm.utils import AverageMeter
19 |
20 | import warnings
21 | warnings.filterwarnings("ignore")
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description='ConTNet')
26 |
27 | # data and model
28 | parser.add_argument('--data_path', type=str, help='path to dataset')
29 | parser.add_argument('--arch', type=str, default='ConT-M',
30 | choices=['ConT-M', 'ConT-B', 'ConT-S', 'ConT-Ti'],
31 | help='the architecture of ConTNet')
32 |
33 | # model hypeparameters
34 | parser.add_argument('--use_avgdown', type=bool, default=False,
35 | help='If True, using avgdown downsampling shortcut')
36 | parser.add_argument('--relative', type=bool, default=False,
37 | help='If True, using relative position embedding')
38 | parser.add_argument('--qkv_bias', type=bool, default=True)
39 | parser.add_argument('--pre_norm', type=bool, default=False)
40 |
41 | # base setting
42 | parser.add_argument('--eval', default=None, type=str,
43 | help='only validation')
44 | parser.add_argument('--batch_size', default=512, type=int,
45 | help='batch size')
46 | parser.add_argument('--workers', default=8, type=int,
47 | help='number of data loading workers')
48 | parser.add_argument('--epoch', default=200, type=int,
49 | help='number of total epochs to run')
50 | parser.add_argument('--warmup_epoch', default=10, type=int,
51 | help='the num of warmup epochs')
52 | parser.add_argument('--resume', default=None, type=str,
53 | help='resume file path')
54 | parser.add_argument('--init_lr', default=5e-4, type=float,
55 |                         help='a low initial learning rate for the AdamW optimizer')
56 |     parser.add_argument('--wd', default=0.5, type=float,
57 |                         help='a high weight decay setting for the AdamW optimizer')
58 | parser.add_argument('--momentum', default=0.9, type=float,
59 | help='momentum for sgd')
60 | parser.add_argument('--optim', default='AdamW', type=str, choices=['AdamW', 'SGD'],
61 | help='optimizer supported by PyTorch')
62 | parser.add_argument('--print_freq', default=100, type=int,
63 | help='frequency of printing train info')
64 | parser.add_argument('--save_path', default='weights', type=str,
65 | help='the path to saving the checkpoints')
66 | parser.add_argument('--save_best', default=True, type=bool,
67 |                         help='save the checkpoint with the best accuracy')
68 |
69 |     # augmentation & regularization
70 |     parser.add_argument('--mixup', default=0.8, type=float,
71 |                         help='mixup alpha value; 0 disables mixup')
72 |     parser.add_argument('--autoaug', default='rand-m9-mstd0.5-inc1', type=str,
73 |                         help='auto-augmentation policy (timm format)')
74 |     parser.add_argument('-ls','--label-smoothing', default=0.1, type=float,
75 |                         help='if > 0, use label smoothing')
76 |
77 |     # distributed parallel training
78 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DDP')
79 |
80 | return parser.parse_args()
81 |
82 |
83 | def launch_worker(local_rank):
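    # One process per GPU: bind this process to its local GPU and join the NCCL process
    # group (torch.distributed.launch provides the env:// rendezvous variables).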
84 | # print(local_rank)
85 | if not torch.cuda.is_available():
86 | raise ValueError(f'CPU-only training is not supported')
87 | torch.backends.cudnn.benchmark = True
88 | torch.cuda.set_device(local_rank)
89 | dist.init_process_group(backend='nccl', init_method='env://')
90 | dist.barrier()
91 |
92 | def train(loader, model, criterion, optimizer, mixup_fn, scheduler, print_freq, epoch):
93 | model.train()
94 | if dist.get_rank() == 0:
95 | print(f'\n=> Training epoch{epoch}')
96 |
97 | batch_time = AverageMeter()
98 | losses = AverageMeter()
99 | top1 = AverageMeter()
100 | top5 = AverageMeter()
101 |
102 | end = time.time()
103 | for i, (images, targets) in enumerate(loader):
104 | images = images.cuda(non_blocking=True)
105 | targets = targets.cuda(non_blocking=True)
106 |
107 |         # fall back to hard labels when mixup is disabled (mixup returns soft targets)
108 |         images, targets_ = mixup_fn(images, targets) if mixup_fn else (images, targets)
109 |
110 | # forward
111 | outputs = model(images)
112 |
113 | # update acc1, acc5
114 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
115 | acc1 = reduce_tensor(acc1)
116 | acc5 = reduce_tensor(acc5)
117 | top1.update(acc1.item(), targets.size(0))
118 | top5.update(acc5.item(), targets.size(0))
119 |
120 | # compute loss and backward
121 | loss = criterion(outputs, targets_)
122 | loss = reduce_tensor(loss)
123 | losses.update(loss.item(), targets_.size(0))
124 | optimizer.zero_grad()
125 | loss.backward()
126 | optimizer.step()
127 | scheduler.step_update(epoch * len(loader) + i)
128 |
129 | # update using time
130 | interval = torch.tensor([time.time() - end])
131 | interval = reduce_tensor(interval.cuda())
132 | batch_time.update(interval.item())
133 | end = time.time()
134 |
135 | if i % print_freq == 0 and dist.get_rank() == 0:
136 | lr = optimizer.param_groups[0]['lr']
137 | sep = '| '
138 | print(f'Epoch: [{epoch}] | [{i}/{len(loader)}] lr: {lr:.8f} '+ sep +
139 | f'loss {losses.val:.4f} ({losses.avg:.4f}) '+ sep +
140 | f'Top1.acc {top1.val:6.2f} ' + sep +
141 | f'Top5.acc {top5.val:6.2f} ' + sep +
142 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f}) ' + sep
143 | )
144 |
145 | @torch.no_grad()
146 | def validate(val_loader, model, criterion, epoch=None):
147 | model.eval()
148 |
149 | batch_time = AverageMeter()
150 | losses = AverageMeter()
151 | top1 = AverageMeter()
152 | top5 = AverageMeter()
153 |
154 | end = time.time()
155 | for i, (images, targets) in enumerate(val_loader):
156 | images = images.cuda(non_blocking=True)
157 | targets = targets.cuda(non_blocking=True)
158 |
159 | # forward
160 | outputs = model(images)
161 |
162 | loss = criterion(outputs, targets)
163 | loss = reduce_tensor(loss)
164 | losses.update(loss.item(), images.size(0))
165 |
166 | # update acc1, acc5
167 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
168 | acc1 = reduce_tensor(acc1)
169 | acc5 = reduce_tensor(acc5)
170 | top1.update(acc1.item(), targets.size(0))
171 | top5.update(acc5.item(), targets.size(0))
172 |
173 | # update using time
174 | interval = torch.tensor([time.time() - end])
175 | interval = reduce_tensor(interval.cuda())
176 | batch_time.update(interval.item())
177 | end = time.time()
178 |
179 |
180 | if dist.get_rank() == 0:
181 | stat = f"epoch {epoch}" if epoch is not None else "Only"
182 | print(f'=> Validation {stat}')
183 | sep = '| '
184 | print(f'loss {losses.avg:.4f} '+ sep +
185 | f'Top1.acc {top1.avg:6.2f} ' + sep +
186 | f'Top5.acc {top5.avg:6.2f} ' + sep +
187 | f'time {batch_time.avg:.4f} ' + sep
188 | )
189 |
190 | return top1.avg, top5.avg, losses.avg
191 |
192 | def main(config):
193 | # set up ddp
194 | launch_worker(config.local_rank)
195 | # build loader
196 | train_loader, val_loader = build_loader(config.data_path, config.autoaug, config.batch_size, config.workers)
197 | # build model
198 | model=build_model(config.arch, config.use_avgdown, config.relative, config.qkv_bias, config.pre_norm)
199 | model = DDP(model.cuda(), device_ids=[config.local_rank])
200 | # build optimizer
201 | optimizer=build_optimizer(model, config.optim, config.init_lr, config.wd, config.momentum)
202 | # build learning scheduler
203 | scheduler=build_lr_scheduler(config.epoch, config.warmup_epoch, optimizer, len(train_loader))
204 | # build criterion and mixup
205 | train_criterion, mixup_fn =build_criterion(config.mixup, config.label_smoothing)
206 | val_criterion = torch.nn.CrossEntropyLoss()
207 | # init acc1 and start epoch
208 | best_acc1 = 0.0
209 | start_epoch = 0
210 |
211 | # only validation
212 | if config.eval:
213 | if os.path.isfile(config.eval):
214 | model.load_state_dict(torch.load(config.eval)['model'])
215 | validate(val_loader, model, val_criterion)
216 | return
217 | else:
218 | print(f"=> !!!!!!! no checkpoint found at '{config.eval}'\n")
219 | print(f"=> !!!!!!! validation is stopped")
220 | return
221 |
222 | # resume training
223 | if not config.resume:
224 | print(f"=>Training is from scratch")
225 | else:
226 | if os.path.isfile(config.resume):
227 | model, optimizer, scheduler, start_epoch, best_acc1 = resume_model(config.resume, model, optimizer, scheduler)
228 | else:
229 | print(f"=> !!!!!!! no checkpoint found at '{config.resume}'\n")
230 |
231 | # training
232 |     for epoch in range(start_epoch, config.epoch):
233 | train_loader.sampler.set_epoch(epoch)
234 |
235 | train(train_loader, model, train_criterion, optimizer, mixup_fn, scheduler, config.print_freq, epoch)
236 |
237 | acc1, acc5, loss = validate(val_loader, model, val_criterion, epoch)
238 |
239 | best_acc1 = max(best_acc1, acc1)
240 | is_best = (best_acc1 == acc1)
241 |
242 | if dist.get_rank() == 0:
243 | print('\n******************\t',
244 | f'\nBest Top1.acc {best_acc1:6.2f}\t',
245 | '\n******************\t')
246 |
247 | # save model
248 | if not config.save_best or is_best:
249 | save_model(config.save_path, model, optimizer, scheduler, best_acc1, epoch, is_best)
250 |
251 |
252 |
253 | if __name__ == '__main__':
254 | # build configs
255 | args = parse_args()
256 | # launch
257 | main(config=args)
258 | print('=> Finished!')
259 |
260 |
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def build_optimizer(model, optim, lr, wd, momentum):
4 |
5 | def _no_bias_decay(model):
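        # Split parameters into two groups: biases, 1-D norm weights (BatchNorm/LayerNorm)
        # and position-embedding tables get no weight decay; all other weights keep the
        # weight decay passed to the optimizer.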
6 | has_decay = []
7 | no_decay = []
8 | skip_list = ['relative_position_bias_table', 'pe']
9 |
10 | for name, param in model.named_parameters():
11 | if not param.requires_grad:
12 | continue
13 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list):
14 | no_decay.append(param)
15 | else:
16 | has_decay.append(param)
17 |
18 | assert len(list(model.parameters())) == len(has_decay) + len(no_decay), '{} vs. {}'.format(
19 | len(list(model.parameters())), len(has_decay) + len(no_decay))
20 |
21 | return [{'params': has_decay},
22 | {'params': no_decay, 'weight_decay': 0.}]
23 |
24 | parameters = _no_bias_decay(model)
25 | kwargs = dict(lr=lr, weight_decay=wd)
26 |     if optim.lower() == 'sgd':
27 | kwargs['momentum'] = momentum
28 |
29 | optimizer = getattr(torch.optim, optim)(params=parameters, **kwargs)
30 |
31 | return optimizer
32 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import os
4 |
5 | def accuracy(output, target, topk=(1,)):
6 | """Computes the accuracy over the k top predictions for the specified values of k"""
7 | with torch.no_grad():
8 | maxk = max(topk)
9 | batch_size = target.size(0)
10 |
11 | _, pred = output.topk(maxk, 1, True, True)
12 |
13 | pred = pred.t()
14 | correct = pred.eq(target.view(1, -1).expand_as(pred))
15 |
16 | res = []
17 | for k in topk:
18 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
19 | res.append(correct_k.mul_(100.0 / batch_size))
20 | return res
21 |
22 | def reduce_tensor(tensor):
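    # All-reduce and average a tensor across DDP processes (used to aggregate loss/accuracy for logging).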
23 | rt = tensor.clone()
24 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
25 | rt /= dist.get_world_size()
26 | return rt
27 |
28 | def resume_model(resume_path, model, optimizer, scheduler):
29 | print(f"=> loading checkpoint '{resume_path}'")
30 | checkpoint = torch.load(resume_path)
31 | start_epoch = checkpoint['epoch']
32 | best_acc1 = checkpoint['best_acc1']
33 |     # note: only the epoch counter and best top-1 accuracy are stored by save_model
34 | model.load_state_dict(checkpoint['model'])
35 | optimizer.load_state_dict(checkpoint['optimizer'])
36 | scheduler.load_state_dict(checkpoint['scheduler'])
37 | print(f"=> loaded checkpoint successfully '{resume_path}' (epoch {start_epoch})")
38 |
39 |     return model, optimizer, scheduler, start_epoch, best_acc1
40 |
41 | def save_model(save_path, model, optimizer, scheduler, best_acc1, epoch, is_best):
42 | save_state = {'model': model.state_dict(),
43 | 'optimizer': optimizer.state_dict(),
44 | 'scheduler': scheduler.state_dict(),
45 | 'best_acc1': best_acc1,
46 | 'epoch': epoch}
47 |
48 | os.makedirs(save_path, exist_ok=True)
49 | checkpoint_name = f'checkpoint_bestTop1.pth' if is_best else f'checkpoint_{epoch}.pth'
50 | save_path = os.path.join(save_path, checkpoint_name)
51 | torch.save(save_state, save_path)
52 | print(f'=> Saved checkpoint of epoch {epoch} to {save_path}')
--------------------------------------------------------------------------------