├── ConTNet.py ├── README.md ├── arch5.png ├── block2.png ├── block3.png ├── criterion.py ├── data.py ├── lr_scheduler.py ├── main.py ├── optimizer.py └── utils.py /ConTNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | from einops.layers.torch import Rearrange 6 | from einops import rearrange 7 | 8 | import numpy as np 9 | 10 | from typing import Any, List 11 | import math 12 | import warnings 13 | from collections import OrderedDict 14 | 15 | __all__ = ['ConTBlock', 'ConTNet'] 16 | 17 | 18 | r""" The following trunc_normal method is pasted from timm https://github.com/rwightman/pytorch-image-models/tree/master/timm 19 | """ 20 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 21 | 22 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 23 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 24 | def norm_cdf(x): 25 | # Computes standard normal cumulative distribution function 26 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 27 | 28 | if (mean < a - 2 * std) or (mean > b + 2 * std): 29 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 30 | "The distribution of values may be incorrect.", 31 | stacklevel=2) 32 | 33 | with torch.no_grad(): 34 | # Values are generated by using a truncated uniform distribution and 35 | # then using the inverse CDF for the normal distribution. 36 | # Get upper and lower cdf values 37 | l = norm_cdf((a - mean) / std) 38 | u = norm_cdf((b - mean) / std) 39 | 40 | # Uniformly fill tensor with values from [l, u], then translate to 41 | # [2l-1, 2u-1]. 42 | tensor.uniform_(2 * l - 1, 2 * u - 1) 43 | 44 | # Use inverse cdf transform for normal distribution to get truncated 45 | # standard normal 46 | tensor.erfinv_() 47 | 48 | # Transform to proper mean, std 49 | tensor.mul_(std * math.sqrt(2.)) 50 | tensor.add_(mean) 51 | 52 | # Clamp to ensure it's in the proper range 53 | tensor.clamp_(min=a, max=b) 54 | return tensor 55 | 56 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 57 | # type: (Tensor, float, float, float, float) -> Tensor 58 | r"""Fills the input Tensor with values drawn from a truncated 59 | normal distribution. The values are effectively drawn from the 60 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 61 | with values outside :math:`[a, b]` redrawn until they are within 62 | the bounds. The method used for generating the random values works 63 | best when :math:`a \leq \text{mean} \leq b`. 
64 | Args: 65 | tensor: an n-dimensional `torch.Tensor` 66 | mean: the mean of the normal distribution 67 | std: the standard deviation of the normal distribution 68 | a: the minimum cutoff value 69 | b: the maximum cutoff value 70 | Examples: 71 | >>> w = torch.empty(3, 5) 72 | >>> nn.init.trunc_normal_(w) 73 | """ 74 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 75 | 76 | def fixed_padding(inputs, kernel_size, dilation): 77 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 78 | pad_total = kernel_size_effective - 1 79 | pad_beg = pad_total // 2 80 | pad_end = pad_total - pad_beg 81 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 82 | return padded_inputs 83 | 84 | class ConvBN(nn.Sequential): 85 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, groups=1, bn=True): 86 | padding = (kernel_size - 1) // 2 87 | if bn: 88 | super(ConvBN, self).__init__(OrderedDict([ 89 | ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride, 90 | padding=padding, groups=groups, bias=False)), 91 | ('bn', nn.BatchNorm2d(out_planes)) 92 | ])) 93 | else: 94 | super(ConvBN, self).__init__(OrderedDict([ 95 | ('conv', nn.Conv2d(in_planes, out_planes, kernel_size, stride, 96 | padding=padding, groups=groups, bias=False)), 97 | ])) 98 | 99 | class MHSA(nn.Module): 100 | r""" 101 | Build a Multi-Head Self-Attention: 102 | - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 103 | """ 104 | def __init__(self, 105 | planes, 106 | head_num, 107 | dropout, 108 | patch_size, 109 | qkv_bias, 110 | relative): 111 | super(MHSA, self).__init__() 112 | self.head_num = head_num 113 | head_dim = planes // head_num 114 | self.qkv = nn.Linear(planes, 3*planes, bias=qkv_bias) 115 | self.relative = relative 116 | self.patch_size = patch_size 117 | self.scale = head_dim ** -0.5 118 | 119 | if self.relative: 120 | # print('### relative position embedding ###') 121 | self.relative_position_bias_table = nn.Parameter( 122 | torch.zeros((2 * patch_size - 1) * (2 * patch_size - 1), head_num)) 123 | coords_w = coords_h = torch.arange(patch_size) 124 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 125 | coords_flatten = torch.flatten(coords, 1) 126 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 127 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 128 | relative_coords[:, :, 0] += patch_size - 1 129 | relative_coords[:, :, 1] += patch_size - 1 130 | relative_coords[:, :, 0] *= 2 * patch_size - 1 131 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 132 | self.register_buffer("relative_position_index", relative_position_index) 133 | trunc_normal_(self.relative_position_bias_table, std=.02) 134 | 135 | self.attn_drop = nn.Dropout(p=dropout) 136 | self.proj = nn.Linear(planes, planes) 137 | self.proj_drop = nn.Dropout(p=dropout) 138 | 139 | def forward(self, x): 140 | B, N, C, H = *x.shape, self.head_num 141 | # print(x.shape) 142 | qkv = self.qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4) # x: (3, B, H, N, C//H) 143 | q, k, v = qkv[0], qkv[1], qkv[2] # x: (B, H, N, C//N) 144 | 145 | q = q * self.scale 146 | attn = (q @ k.transpose(-2, -1)) # attn: (B, H, N, N) 147 | 148 | if self.relative: 149 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 150 | self.patch_size ** 2, self.patch_size ** 2, -1) 151 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 152 
| attn = attn + relative_position_bias.unsqueeze(0) 153 | 154 | attn = attn.softmax(dim=-1) 155 | attn = self.attn_drop(attn) 156 | 157 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 158 | x = self.proj(x) 159 | x = self.proj_drop(x) 160 | 161 | return x 162 | 163 | class MLP(nn.Module): 164 | r""" 165 | Build a Multi-Layer Perceptron 166 | - https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 167 | """ 168 | def __init__(self, 169 | planes, 170 | mlp_dim, 171 | dropout): 172 | super(MLP, self).__init__() 173 | 174 | self.fc1 = nn.Linear(planes, mlp_dim) 175 | self.act = nn.GELU() 176 | self.fc2 = nn.Linear(mlp_dim, planes) 177 | self.drop = nn.Dropout(dropout) 178 | 179 | def forward(self, x): 180 | x = self.fc1(x) 181 | x = self.act(x) 182 | x = self.drop(x) 183 | x = self.fc2(x) 184 | x = self.drop(x) 185 | 186 | return x 187 | 188 | 189 | class STE(nn.Module): 190 | r""" 191 | Build a Standard Transformer Encoder(STE) 192 | input: Tensor (b, c, h, w) 193 | output: Tensor (b, c, h, w) 194 | """ 195 | def __init__(self, 196 | planes: int, 197 | mlp_dim: int, 198 | head_num: int, 199 | dropout: float, 200 | patch_size: int, 201 | relative: bool, 202 | qkv_bias: bool, 203 | pre_norm: bool, 204 | **kwargs): 205 | super(STE, self).__init__() 206 | self.patch_size = patch_size 207 | self.pre_norm = pre_norm 208 | self.relative = relative 209 | 210 | self.flatten = nn.Sequential( 211 | Rearrange('b c pnh pnw psh psw -> (b pnh pnw) psh psw c'), 212 | ) 213 | if not relative: 214 | self.pe = nn.ParameterList( 215 | [nn.Parameter(torch.zeros(1, patch_size, 1, planes//2)), nn.Parameter(torch.zeros(1, 1, patch_size, planes//2))] 216 | ) 217 | self.attn = MHSA(planes, head_num, dropout, patch_size, qkv_bias=qkv_bias, relative=relative) 218 | self.mlp = MLP(planes, mlp_dim, dropout=dropout) 219 | self.norm1 = nn.LayerNorm(planes) 220 | self.norm2 = nn.LayerNorm(planes) 221 | 222 | def forward(self, x): 223 | bs, c, h, w = x.shape 224 | patch_size = self.patch_size 225 | patch_num_h, patch_num_w = h // patch_size, w // patch_size 226 | 227 | x = ( 228 | x.unfold(2, self.patch_size, self.patch_size) 229 | .unfold(3, self.patch_size, self.patch_size) 230 | ) # x: (b, c, patch_num, patch_num, patch_size, patch_size) 231 | x = self.flatten(x) # x: (b, patch_size, patch_size, c) 232 | ### add 2d position embedding ### 233 | if not self.relative: 234 | x_h, x_w = x.split(c // 2, dim=3) 235 | x = torch.cat((x_h + self.pe[0], x_w + self.pe[1]), dim=3) # x: (b, patch_size, patch_size, c) 236 | 237 | x = rearrange(x, 'b psh psw c -> b (psh psw) c') 238 | 239 | if self.pre_norm: 240 | x = x + self.attn(self.norm1(x)) 241 | x = x + self.mlp(self.norm2(x)) 242 | else: 243 | x = self.norm1(x + self.attn(x)) 244 | x = self.norm2(x + self.mlp(x)) 245 | 246 | x = rearrange(x, '(b pnh pnw) (psh psw) c -> b c (pnh psh) (pnw psw)', pnh=patch_num_h, pnw=patch_num_w, psh=patch_size, psw=patch_size) 247 | 248 | return x 249 | 250 | class ConTBlock(nn.Module): 251 | r""" 252 | Build a ConTBlock 253 | """ 254 | def __init__(self, 255 | planes: int, 256 | out_planes: int, 257 | mlp_dim: int, 258 | head_num: int, 259 | dropout: float, 260 | patch_size: List[int], 261 | downsample: nn.Module = None, 262 | stride: int=1, 263 | last_dropout: float=0.3, 264 | **kwargs): 265 | super(ConTBlock, self).__init__() 266 | self.downsample = downsample 267 | self.identity = nn.Identity() 268 | self.dropout = nn.Identity() 269 | 270 | self.bn = nn.BatchNorm2d(planes) 271 | self.relu = 
nn.ReLU(inplace=True) 272 | self.ste1 = STE(planes=planes, mlp_dim=mlp_dim, head_num=head_num, dropout=dropout, patch_size=patch_size[0], **kwargs) 273 | self.ste2 = STE(planes=planes, mlp_dim=mlp_dim, head_num=head_num, dropout=dropout, patch_size=patch_size[1], **kwargs) 274 | 275 | if stride == 1 and downsample is not None: 276 | self.dropout = nn.Dropout(p=last_dropout) 277 | kernel_size = 1 278 | else: 279 | kernel_size = 3 280 | 281 | self.out_conv = ConvBN(planes, out_planes, kernel_size, stride, bn=False) 282 | 283 | def forward(self, x): 284 | x_preact = self.relu(self.bn(x)) 285 | identity = self.identity(x) 286 | 287 | if self.downsample is not None: 288 | identity = self.downsample(x_preact) 289 | 290 | residual = self.ste1(x_preact) 291 | residual = self.ste2(residual) 292 | residual = self.out_conv(residual) 293 | out = self.dropout(residual+identity) 294 | 295 | return out 296 | 297 | class ConTNet(nn.Module): 298 | r""" 299 | Build a ConTNet backbone 300 | """ 301 | def __init__(self, 302 | block, 303 | layers: List[int], 304 | mlp_dim: List[int], 305 | head_num: List[int], 306 | dropout: List[float], 307 | in_channels: int=3, 308 | inplanes: int=64, 309 | num_classes: int=1000, 310 | init_weights: bool=True, 311 | first_embedding: bool=False, 312 | tweak_C: bool=False, 313 | **kwargs): 314 | r""" 315 | Args: 316 | block: ConT Block 317 | layers: number of blocks at each layer 318 | mlp_dim: dimension of mlp in each stage 319 | head_num: number of head in each stage 320 | dropout: dropout in the last two stage 321 | relative: if True, relative Position Embedding is used 322 | groups: nunmber of group at each conv layer in the Network 323 | depthwise: if True, depthwise convolution is adopted 324 | in_channels: number of channels of input image 325 | inplanes: channel of the first convolution layer 326 | num_classes: number of classes for classification task 327 | only useful when `with_classifier` is True 328 | with_avgpool: if True, an average pooling is added at the end of resnet stage5 329 | with_classifier: if True, FC layer is registered for classification task 330 | first_embedding: if True, a conv layer with both stride and kernel of 7 is placed at the top 331 | tweakC: if true, the first layer of ResNet-C replace the ori layer 332 | """ 333 | 334 | super(ConTNet, self).__init__() 335 | self.inplanes = inplanes 336 | self.block = block 337 | 338 | # build the top layer 339 | if tweak_C: 340 | self.layer0 = nn.Sequential(OrderedDict([ 341 | ('conv_bn1', ConvBN(in_channels, inplanes//2, kernel_size=3, stride=2)), 342 | ('relu1', nn.ReLU(inplace=True)), 343 | ('conv_bn2', ConvBN(inplanes//2, inplanes//2, kernel_size=3, stride=1)), 344 | ('relu2', nn.ReLU(inplace=True)), 345 | ('conv_bn3', ConvBN(inplanes//2, inplanes, kernel_size=3, stride=1)), 346 | ('relu3', nn.ReLU(inplace=True)), 347 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 348 | ])) 349 | elif first_embedding: 350 | self.layer0 = nn.Sequential(OrderedDict([ 351 | ('conv', nn.Conv2d(in_channels, inplanes, kernel_size=4, stride=4)), 352 | ('norm', nn.LayerNorm(inplanes)) 353 | ])) 354 | else: 355 | self.layer0 = nn.Sequential(OrderedDict([ 356 | ('conv', ConvBN(in_channels, inplanes, kernel_size=7, stride=2, bn=False)), 357 | # ('relu', nn.ReLU(inplace=True)), 358 | ('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) 359 | ])) 360 | 361 | # build cont layers 362 | self.cont_layers = [] 363 | self.out_channels = OrderedDict() 364 | 365 | for i in range(len(layers)): 366 | stride = 2, 
367 | patch_size = [7,14] 368 | if i == len(layers)-1: 369 | stride, patch_size[1] = 1, 7 # the last stage does not conduct downsampling 370 | cont_layer = self._make_layer(inplanes * 2**i, layers[i], stride=stride, mlp_dim=mlp_dim[i], head_num=head_num[i], dropout=dropout[i], patch_size=patch_size, **kwargs) 371 | layer_name = 'layer{}'.format(i + 1) 372 | self.add_module(layer_name, cont_layer) 373 | self.cont_layers.append(layer_name) 374 | self.out_channels[layer_name] = 2 * inplanes * 2**i 375 | 376 | self.last_out_channels = next(reversed(self.out_channels.values())) 377 | self.fc = nn.Linear(self.last_out_channels, num_classes) 378 | 379 | if init_weights: 380 | self._initialize_weights() 381 | 382 | def _make_layer(self, 383 | planes: int, 384 | blocks: int, 385 | stride: int, 386 | mlp_dim: int, 387 | head_num: int, 388 | dropout: float, 389 | patch_size: List[int], 390 | use_avgdown: bool=False, 391 | **kwargs): 392 | 393 | layers = OrderedDict() 394 | for i in range(0, blocks-1): 395 | layers[f'{self.block.__name__}{i}'] = self.block( 396 | planes, planes, mlp_dim, head_num, dropout, patch_size, **kwargs) 397 | 398 | downsample = None 399 | if stride != 1: 400 | if use_avgdown: 401 | downsample = nn.Sequential(OrderedDict([ 402 | ('avgpool', nn.AvgPool2d(kernel_size=2, stride=2)), 403 | ('conv', ConvBN(planes, planes * 2, kernel_size=1, stride=1, bn=False))])) 404 | else: 405 | downsample = ConvBN(planes, planes * 2, kernel_size=1, 406 | stride=2, bn=False) 407 | else: 408 | downsample = ConvBN(planes, planes * 2, kernel_size=1, stride=1, bn=False) 409 | 410 | layers[f'{self.block.__name__}{blocks-1}'] = self.block( 411 | planes, planes*2, mlp_dim, head_num, dropout, patch_size, downsample, stride, **kwargs) 412 | 413 | return nn.Sequential(layers) 414 | 415 | def _initialize_weights(self): 416 | for m in self.modules(): 417 | if isinstance(m, nn.Conv2d): 418 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 419 | elif isinstance(m, nn.Linear): 420 | trunc_normal_(m.weight, std=.02) 421 | if isinstance(m, nn.Linear) and m.bias is not None: 422 | nn.init.constant_(m.bias, 0) 423 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.LayerNorm): 424 | nn.init.constant_(m.weight, 1) 425 | nn.init.constant_(m.bias, 0) 426 | 427 | 428 | def forward(self, x): 429 | x = self.layer0(x) 430 | 431 | for _, layer_name in enumerate(self.cont_layers): 432 | cont_layer = getattr(self, layer_name) 433 | x = cont_layer(x) 434 | 435 | x = x.mean([2, 3]) 436 | x = self.fc(x) 437 | 438 | return x 439 | 440 | def create_ConTNet_Ti(kwargs): 441 | return ConTNet(block=ConTBlock, 442 | mlp_dim=[196, 392, 768, 768], 443 | head_num=[1, 2, 4, 8], 444 | dropout=[0,0,0,0], 445 | inplanes=48, 446 | layers=[1,1,1,1], 447 | last_dropout=0, 448 | **kwargs) 449 | 450 | def create_ConTNet_S(kwargs): 451 | return ConTNet(block=ConTBlock, 452 | mlp_dim=[256, 512, 1024, 1024], 453 | head_num=[1, 2, 4, 8], 454 | dropout=[0,0,0,0], 455 | inplanes=64, 456 | layers=[1,1,1,1], 457 | last_dropout=0, 458 | **kwargs) 459 | 460 | def create_ConTNet_M(kwargs): 461 | return ConTNet(block=ConTBlock, 462 | mlp_dim=[256, 512, 1024, 1024], 463 | head_num=[1, 2, 4, 8], 464 | dropout=[0,0,0,0], 465 | inplanes=64, 466 | layers=[2,2,2,2], 467 | last_dropout=0, 468 | **kwargs) 469 | 470 | def create_ConTNet_B(kwargs): 471 | return ConTNet(block=ConTBlock, 472 | mlp_dim=[256, 512, 1024, 1024], 473 | head_num=[1, 2, 4, 8], 474 | dropout=[0,0,0.1,0.1], 475 | inplanes=64, 476 | layers=[3,4,6,3], 477 | 
last_dropout=0.2,
478 |                    **kwargs)
479 | 
480 | def build_model(arch, use_avgdown, relative, qkv_bias, pre_norm):
481 |     type = arch.split('-')[-1]
482 |     func = eval(f'create_ConTNet_{type}')
483 |     kwargs = dict(use_avgdown=use_avgdown, relative=relative, qkv_bias=qkv_bias, pre_norm=pre_norm)
484 |     return func(kwargs)
485 | 
486 | if __name__ == "__main__":
487 |     model = build_model(arch='ConT-Ti', use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)
488 |     input = torch.Tensor(4, 3, 224, 224)
489 |     print(model)
490 |     out = model(input)
491 |     print(out.shape)
492 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConTNet
2 | 
3 | ## Introduction
4 | 
5 | 
7 | 
8 | **ConTNet** (**Con**volution-**T**ransformer Network) is a neural network built by stacking convolutional layers and transformers alternately. This architecture is proposed in response to the following two issues: **(1)** The receptive field of convolution is limited by a local window (3x3), which potentially impairs the performance of ConvNets on downstream tasks. **(2)** Transformer-based models suffer from insufficient robustness; as a result, training them requires many tricks and heavy regularization. In our ConTNet, these drawbacks are alleviated by combining convolution and transformer. Two perspectives are offered to understand the motivation. **From the view of ConvNet**, the transformer sub-layer is inserted between any two conv layers to enhance the non-local interactions of the ConvNet. **From the view of Transformer**, the presence of convolution layers reintroduces the inductive bias whose absence is a cause of under-fitting. Through numerical experiments, we find that ConTNet achieves competitive performance on image recognition and downstream tasks. More notably, ConTNet can be optimized easily, even in the same way as ResNet.
9 | 
10 | ![image](https://github.com/yan-hao-tian/ConTNet/blob/main/arch5.png)
11 | ![image](https://github.com/yan-hao-tian/ConTNet/blob/main/block2.png)
12 | ![image](https://github.com/yan-hao-tian/ConTNet/blob/main/block3.png)
13 | ## Training & Validation with this Repo
14 | We give an example of multi-GPU training on a single machine.
15 | ```
16 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 29501 main.py --arch ConT-M --batch_size 256 --save_path debug_trial_cont_m --save_best True
17 | ```
18 | To validate a model, please add the arg ```--eval```.
19 | ```
20 | CUDA_VISIBLE_DEVICES=0 python3 -m torch.distributed.launch --nproc_per_node=1 --master_port 29501 main.py --arch ConT-M --batch_size 256 --save_path debug_trial --eval ./debug_trial_cont_m/checkpoint_bestTop1.pth
21 | ```
22 | To resume training, please add the arg ```--resume```.
23 | ```
24 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 29501 main.py --arch ConT-M --batch_size 256 --save_path debug_trial --save_best True --resume ./debug_trial_cont_m/checkpoint_bestTop1.pth
25 | ```
26 | ## Pretrained Weights on ImageNet
27 | ImageNet-pretrained weights are available from [Google Drive][1] or [Baidu Cloud][2] (the extraction code is 3k3s).
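As a quick sanity check after downloading, the weights can be loaded into a freshly built model. The sketch below assumes the checkpoint layout produced by `save_model` in `utils.py` (a DDP-wrapped state dict stored under the `'model'` key, so the `module.` prefix is stripped); the released files may be packaged slightly differently, and the checkpoint path is a placeholder.
```
import torch
from ConTNet import build_model

model = build_model(arch='ConT-M', use_avgdown=False, relative=False, qkv_bias=True, pre_norm=False)

# save_model() stores the DDP-wrapped weights under the 'model' key,
# so the 'module.' prefix is removed before loading into a plain model.
ckpt = torch.load('checkpoint_bestTop1.pth', map_location='cpu')  # path to a downloaded checkpoint
state_dict = {k.replace('module.', '', 1): v for k, v in ckpt['model'].items()}
model.load_state_dict(state_dict)
model.eval()
```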
28 | 
29 | ## Main Results on ImageNet
30 | 
31 | | name | resolution | acc@1 | #params(M) | FLOPs(G) | model |
32 | | ---- | ---- | ---- | ---- | ---- | ---- |
33 | | Res-18 | 224x224 | 71.5 | 11.7 | 1.8 | |
34 | | ConT-S | 224x224 | **74.9** | 10.1 | 1.5 | |
35 | | Res-50 | 224x224 | 77.1 | 25.6 | 4.0 | |
36 | | ConT-M | 224x224 | **77.6** | 19.2 | 3.1 | |
37 | | Res-101 | 224x224 | **78.2** | 44.5 | 7.6 | |
38 | | ConT-B | 224x224 | 77.9 | 39.6 | 6.4 | |
39 | | DeiT-Ti* | 224x224 | 72.2 | 5.7 | 1.3 | |
40 | | ConT-Ti* | 224x224 | **74.9** | 5.8 | 0.8 | |
41 | | Res-18* | 224x224 | 73.2 | 11.7 | 1.8 | |
42 | | ConT-S* | 224x224 | **76.5** | 10.1 | 1.5 | |
43 | | Res-50* | 224x224 | 78.6 | 25.6 | 4.0 | |
44 | | DeiT-S* | 224x224 | 79.8 | 22.1 | 4.6 | |
45 | | ConT-M* | 224x224 | **80.2** | 19.2 | 3.1 | |
46 | | Res-101* | 224x224 | 80.0 | 44.5 | 7.6 | |
47 | | DeiT-B* | 224x224 | **81.8** | 86.6 | 17.6 | |
48 | | ConT-B* | 224x224 | **81.8** | 39.6 | 6.4 | |
49 | 
50 | Note: * indicates training with strong augmentations (auto-augmentation and mixup).
51 | 
52 | ## Main Results on Downstream Tasks
53 | 
54 | Object detection results on COCO.
55 | 
56 | | method | backbone | #params(M) | FLOPs(G) | AP | APs | APm | APl |
57 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
58 | | RetinaNet | Res-50 <br> ConTNet-M | 32.0 <br> 27.0 | 235.6 <br> 217.2 | 36.5 <br> **37.9** | 20.4 <br> **23.0** | 40.3 <br> **40.6** | 48.1 <br> **50.4** |
59 | | FCOS | Res-50 <br> ConTNet-M | 32.2 <br> 27.2 | 242.9 <br> 228.4 | 38.7 <br> **40.8** | 22.9 <br> **25.1** | 42.5 <br> **44.6** | 50.1 <br> **53.0** |
60 | | faster rcnn | Res-50 <br> ConTNet-M | 41.5 <br> 36.6 | 241.0 <br> 225.6 | 37.4 <br> **40.0** | 21.2 <br> **25.4** | 41.0 <br> **43.0** | 48.1 <br> **52.0** |
61 | 
62 | Instance segmentation results on Cityscapes based on Mask-RCNN.
63 | | backbone | APbb | APsbb | APmbb | APlbb | APmk | APsmk | APmmk | APlmk |
64 | | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- | ---- |
65 | | Res-50 <br> ConT-M | 38.2 <br> **40.5** | 21.9 <br> **25.1** | 40.9 <br> **44.4** | 49.5 <br> **52.7** | 34.7 <br> **38.1** | 18.3 <br> **20.9** | 37.4 <br> **41.0** | 47.2 <br> **50.3** |
66 | 
67 | Semantic segmentation results on Cityscapes.
68 | | model | mIOU |
69 | | ----- | ---- |
70 | | PSP-Res50 | 77.12 |
71 | | PSP-ConTM | **78.28** |
72 | 
73 | ## Bib Citing
74 | ```
75 | @article{yan2021contnet,
76 |   title={ConTNet: Why not use convolution and transformer at the same time?},
77 |   author={Haotian Yan and Zhe Li and Weijian Li and Changhu Wang and Ming Wu and Chuang Zhang},
78 |   year={2021},
79 |   journal={arXiv preprint arXiv:2104.13497}
80 | }
81 | ```
82 | 
83 | [1]: https://drive.google.com/drive/folders/1ZXu--Bis3LTYLjf2pkmDtZH0TjuWWamO?usp=sharing
84 | [2]: https://pan.baidu.com/s/1thKK36jTFln1KcAuEkzleg
85 | 
--------------------------------------------------------------------------------
/arch5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yan-hao-tian/ConTNet/a3699f49f5afbb9a9b264e9de270405ddef82f54/arch5.png
--------------------------------------------------------------------------------
/block2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yan-hao-tian/ConTNet/a3699f49f5afbb9a9b264e9de270405ddef82f54/block2.png
--------------------------------------------------------------------------------
/block3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yan-hao-tian/ConTNet/a3699f49f5afbb9a9b264e9de270405ddef82f54/block3.png
--------------------------------------------------------------------------------
/criterion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
3 | from timm.data import Mixup
4 | 
5 | 
6 | def build_criterion(mixup, label_smoothing):
7 |     mixup_fn = None
8 |     if mixup > 0.:
9 |         criterion = SoftTargetCrossEntropy()
10 | 
11 |         mixup_fn = Mixup(
12 |             mixup_alpha=mixup, cutmix_alpha=1, cutmix_minmax=None,
13 |             prob=1, switch_prob=0.5, mode='batch',
14 |             label_smoothing=label_smoothing, num_classes=1000)
15 | 
16 |     elif label_smoothing > 0.:
17 |         criterion = LabelSmoothingCrossEntropy(smoothing=label_smoothing)
18 |     else:
19 |         criterion = nn.CrossEntropyLoss()
20 | 
21 |     return criterion, mixup_fn
22 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | 
3 | import torch
4 | from torchvision import datasets, transforms
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader, distributed
7 | 
8 | from timm.data import create_transform
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 | 
11 | import torchvision
12 | 
13 | import os
14 | import json
15 | from PIL import Image
16 | import pandas as pd
17 | from torch.utils.data import Dataset
18 | 
19 | class MyDataSet(Dataset):
20 | 
21 |     def __init__(self,
22 |                  root_dir: str,
23 |                  csv_name: str,
24 |                  json_path: str,
25 |                  transform=None):
26 |         images_dir = os.path.join(root_dir, "images")
27 |         assert os.path.exists(images_dir), "dir:'{}' not found.".format(images_dir)
28 | 
29 |         assert os.path.exists(json_path), "file:'{}' not found.".format(json_path)
30 |         self.label_dict = json.load(open(json_path, "r"))
31 | 
32 |         csv_path = os.path.join(root_dir, csv_name)
33 |         assert os.path.exists(csv_path), "file:'{}' not found.".format(csv_path)
34 |         csv_data = 
pd.read_csv(csv_path) 35 | self.total_num = csv_data.shape[0] 36 | self.img_paths = [os.path.join(images_dir, i)for i in csv_data["filename"].values] 37 | self.img_label = [self.label_dict[i][0] for i in csv_data["label"].values] 38 | self.labels = set(csv_data["label"].values) 39 | 40 | self.transform = transform 41 | 42 | def __len__(self): 43 | return self.total_num 44 | 45 | def __getitem__(self, item): 46 | img = Image.open(self.img_paths[item]) 47 | if img.mode != 'RGB': 48 | raise ValueError("image: {} isn't RGB mode.".format(self.img_paths[item])) 49 | label = self.img_label[item] 50 | 51 | if self.transform is not None: 52 | img = self.transform(img) 53 | 54 | return img, label 55 | 56 | @staticmethod 57 | def collate_fn(batch): 58 | images, labels = tuple(zip(*batch)) 59 | 60 | images = torch.stack(images, dim=0) 61 | labels = torch.as_tensor(labels) 62 | return images, labels 63 | 64 | 65 | def build_loader(data_path, autoaug, batch_size, workers): 66 | 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | assert batch_size % world_size == 0, f'The batch size is indivisible by world size {batch_size} // {world_size}' 70 | 71 | train_transform = create_transform(input_size=224, 72 | is_training=True, 73 | auto_augment=autoaug) 74 | # train_dataset = MyDataSet(root_dir='./mini-imagenet', csv_name='new_train.csv', json_path='./mini-imagenet/classes_name.json', transform=train_transform) 75 | train_dataset = datasets.ImageFolder(osp.join(data_path, 'train'), transform=train_transform) 76 | 77 | train_sampler = distributed.DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) 78 | train_loader = DataLoader(train_dataset, 79 | batch_size=batch_size // world_size, 80 | shuffle=False, 81 | num_workers=workers, 82 | pin_memory=True, 83 | sampler=train_sampler) 84 | 85 | val_transform = transforms.Compose([ 86 | transforms.Resize(256), 87 | transforms.CenterCrop(224), 88 | transforms.ToTensor(), 89 | transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 90 | ]) 91 | val_dataset = datasets.ImageFolder(osp.join(data_path, 'val'), transform=val_transform) 92 | # val_dataset = MyDataSet(root_dir='./mini-imagenet', csv_name='new_val.csv', json_path='./mini-imagenet/classes_name.json', transform=val_transform) 93 | val_sampler = distributed.DistributedSampler(val_dataset, world_size, rank) 94 | val_loader = DataLoader(val_dataset, 95 | batch_size=batch_size // world_size, 96 | shuffle=False, 97 | num_workers=workers, 98 | pin_memory=True, 99 | sampler=val_sampler) 100 | 101 | return train_loader, val_loader 102 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.scheduler.cosine_lr import CosineLRScheduler 3 | 4 | def build_lr_scheduler(epoch, warmup_epoch, optimizer, n_iter_per_epoch): 5 | num_steps = int(epoch * n_iter_per_epoch) 6 | warmup_steps = int(warmup_epoch * n_iter_per_epoch) 7 | 8 | scheduler = CosineLRScheduler( 9 | optimizer, 10 | t_initial=num_steps, 11 | t_mul=1., 12 | lr_min=0, 13 | warmup_lr_init=0, 14 | warmup_t=warmup_steps, 15 | cycle_limit=1, 16 | t_in_epochs=False, 17 | ) 18 | 19 | return scheduler 20 | 21 | 22 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | 
import torch 8 | import torch.distributed as dist 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | 11 | from ConTNet import build_model 12 | from optimizer import build_optimizer 13 | from lr_scheduler import build_lr_scheduler 14 | from criterion import build_criterion 15 | from data import build_loader 16 | 17 | from utils import accuracy, reduce_tensor, resume_model, save_model 18 | from timm.utils import AverageMeter 19 | 20 | import warnings 21 | warnings.filterwarnings("ignore") 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='ConTNet') 26 | 27 | # data and model 28 | parser.add_argument('--data_path', type=str, help='path to dataset') 29 | parser.add_argument('--arch', type=str, default='ConT-M', 30 | choices=['ConT-M', 'ConT-B', 'ConT-S', 'ConT-Ti'], 31 | help='the architecture of ConTNet') 32 | 33 | # model hypeparameters 34 | parser.add_argument('--use_avgdown', type=bool, default=False, 35 | help='If True, using avgdown downsampling shortcut') 36 | parser.add_argument('--relative', type=bool, default=False, 37 | help='If True, using relative position embedding') 38 | parser.add_argument('--qkv_bias', type=bool, default=True) 39 | parser.add_argument('--pre_norm', type=bool, default=False) 40 | 41 | # base setting 42 | parser.add_argument('--eval', default=None, type=str, 43 | help='only validation') 44 | parser.add_argument('--batch_size', default=512, type=int, 45 | help='batch size') 46 | parser.add_argument('--workers', default=8, type=int, 47 | help='number of data loading workers') 48 | parser.add_argument('--epoch', default=200, type=int, 49 | help='number of total epochs to run') 50 | parser.add_argument('--warmup_epoch', default=10, type=int, 51 | help='the num of warmup epochs') 52 | parser.add_argument('--resume', default=None, type=str, 53 | help='resume file path') 54 | parser.add_argument('--init_lr', default=5e-4, type=float, 55 | help='a low initial learning rata for adamw optimizer') 56 | parser.add_argument('--wd', default=0.5, type=float, 57 | help='a high weight decay setting for adamw optimizer') 58 | parser.add_argument('--momentum', default=0.9, type=float, 59 | help='momentum for sgd') 60 | parser.add_argument('--optim', default='AdamW', type=str, choices=['AdamW', 'SGD'], 61 | help='optimizer supported by PyTorch') 62 | parser.add_argument('--print_freq', default=100, type=int, 63 | help='frequency of printing train info') 64 | parser.add_argument('--save_path', default='weights', type=str, 65 | help='the path to saving the checkpoints') 66 | parser.add_argument('--save_best', default=True, type=bool, 67 | help='saveing the checkpoint has the best acc') 68 | 69 | # aug® 70 | parser.add_argument('--mixup', default=0.8, type=float, 71 | help='using mixup and set alpha value') 72 | parser.add_argument('--autoaug', default='rand-m9-mstd0.5-inc1', type=str, 73 | help='using auto-augmentation') 74 | parser.add_argument('-ls','--label-smoothing', default=0.1, type=float, 75 | help='if > 0, using label-smothing') 76 | 77 | # distributed parallel triaining 78 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DDP') 79 | 80 | return parser.parse_args() 81 | 82 | 83 | def launch_worker(local_rank): 84 | # print(local_rank) 85 | if not torch.cuda.is_available(): 86 | raise ValueError(f'CPU-only training is not supported') 87 | torch.backends.cudnn.benchmark = True 88 | torch.cuda.set_device(local_rank) 89 | dist.init_process_group(backend='nccl', init_method='env://') 90 | 
dist.barrier() 91 | 92 | def train(loader, model, criterion, optimizer, mixup_fn, scheduler, print_freq, epoch): 93 | model.train() 94 | if dist.get_rank() == 0: 95 | print(f'\n=> Training epoch{epoch}') 96 | 97 | batch_time = AverageMeter() 98 | losses = AverageMeter() 99 | top1 = AverageMeter() 100 | top5 = AverageMeter() 101 | 102 | end = time.time() 103 | for i, (images, targets) in enumerate(loader): 104 | images = images.cuda(non_blocking=True) 105 | targets = targets.cuda(non_blocking=True) 106 | 107 | if mixup_fn: 108 | images, targets_ = mixup_fn(images, targets) 109 | 110 | # forward 111 | outputs = model(images) 112 | 113 | # update acc1, acc5 114 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 115 | acc1 = reduce_tensor(acc1) 116 | acc5 = reduce_tensor(acc5) 117 | top1.update(acc1.item(), targets.size(0)) 118 | top5.update(acc5.item(), targets.size(0)) 119 | 120 | # compute loss and backward 121 | loss = criterion(outputs, targets_) 122 | loss = reduce_tensor(loss) 123 | losses.update(loss.item(), targets_.size(0)) 124 | optimizer.zero_grad() 125 | loss.backward() 126 | optimizer.step() 127 | scheduler.step_update(epoch * len(loader) + i) 128 | 129 | # update using time 130 | interval = torch.tensor([time.time() - end]) 131 | interval = reduce_tensor(interval.cuda()) 132 | batch_time.update(interval.item()) 133 | end = time.time() 134 | 135 | if i % print_freq == 0 and dist.get_rank() == 0: 136 | lr = optimizer.param_groups[0]['lr'] 137 | sep = '| ' 138 | print(f'Epoch: [{epoch}] | [{i}/{len(loader)}] lr: {lr:.8f} '+ sep + 139 | f'loss {losses.val:.4f} ({losses.avg:.4f}) '+ sep + 140 | f'Top1.acc {top1.val:6.2f} ' + sep + 141 | f'Top5.acc {top5.val:6.2f} ' + sep + 142 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f}) ' + sep 143 | ) 144 | 145 | @torch.no_grad() 146 | def validate(val_loader, model, criterion, epoch=None): 147 | model.eval() 148 | 149 | batch_time = AverageMeter() 150 | losses = AverageMeter() 151 | top1 = AverageMeter() 152 | top5 = AverageMeter() 153 | 154 | end = time.time() 155 | for i, (images, targets) in enumerate(val_loader): 156 | images = images.cuda(non_blocking=True) 157 | targets = targets.cuda(non_blocking=True) 158 | 159 | # forward 160 | outputs = model(images) 161 | 162 | loss = criterion(outputs, targets) 163 | loss = reduce_tensor(loss) 164 | losses.update(loss.item(), images.size(0)) 165 | 166 | # update acc1, acc5 167 | acc1, acc5 = accuracy(outputs, targets, topk=(1, 5)) 168 | acc1 = reduce_tensor(acc1) 169 | acc5 = reduce_tensor(acc5) 170 | top1.update(acc1.item(), targets.size(0)) 171 | top5.update(acc5.item(), targets.size(0)) 172 | 173 | # update using time 174 | interval = torch.tensor([time.time() - end]) 175 | interval = reduce_tensor(interval.cuda()) 176 | batch_time.update(interval.item()) 177 | end = time.time() 178 | 179 | 180 | if dist.get_rank() == 0: 181 | stat = f"epoch {epoch}" if epoch is not None else "Only" 182 | print(f'=> Validation {stat}') 183 | sep = '| ' 184 | print(f'loss {losses.avg:.4f} '+ sep + 185 | f'Top1.acc {top1.avg:6.2f} ' + sep + 186 | f'Top5.acc {top5.avg:6.2f} ' + sep + 187 | f'time {batch_time.avg:.4f} ' + sep 188 | ) 189 | 190 | return top1.avg, top5.avg, losses.avg 191 | 192 | def main(config): 193 | # set up ddp 194 | launch_worker(config.local_rank) 195 | # build loader 196 | train_loader, val_loader = build_loader(config.data_path, config.autoaug, config.batch_size, config.workers) 197 | # build model 198 | model=build_model(config.arch, config.use_avgdown, config.relative, 
config.qkv_bias, config.pre_norm) 199 | model = DDP(model.cuda(), device_ids=[config.local_rank]) 200 | # build optimizer 201 | optimizer=build_optimizer(model, config.optim, config.init_lr, config.wd, config.momentum) 202 | # build learning scheduler 203 | scheduler=build_lr_scheduler(config.epoch, config.warmup_epoch, optimizer, len(train_loader)) 204 | # build criterion and mixup 205 | train_criterion, mixup_fn =build_criterion(config.mixup, config.label_smoothing) 206 | val_criterion = torch.nn.CrossEntropyLoss() 207 | # init acc1 and start epoch 208 | best_acc1 = 0.0 209 | start_epoch = 0 210 | 211 | # only validation 212 | if config.eval: 213 | if os.path.isfile(config.eval): 214 | model.load_state_dict(torch.load(config.eval)['model']) 215 | validate(val_loader, model, val_criterion) 216 | return 217 | else: 218 | print(f"=> !!!!!!! no checkpoint found at '{config.eval}'\n") 219 | print(f"=> !!!!!!! validation is stopped") 220 | return 221 | 222 | # resume training 223 | if not config.resume: 224 | print(f"=>Training is from scratch") 225 | else: 226 | if os.path.isfile(config.resume): 227 | model, optimizer, scheduler, start_epoch, best_acc1 = resume_model(config.resume, model, optimizer, scheduler) 228 | else: 229 | print(f"=> !!!!!!! no checkpoint found at '{config.resume}'\n") 230 | 231 | # training 232 | for epoch in range(start_epoch, args.epoch): 233 | train_loader.sampler.set_epoch(epoch) 234 | 235 | train(train_loader, model, train_criterion, optimizer, mixup_fn, scheduler, config.print_freq, epoch) 236 | 237 | acc1, acc5, loss = validate(val_loader, model, val_criterion, epoch) 238 | 239 | best_acc1 = max(best_acc1, acc1) 240 | is_best = (best_acc1 == acc1) 241 | 242 | if dist.get_rank() == 0: 243 | print('\n******************\t', 244 | f'\nBest Top1.acc {best_acc1:6.2f}\t', 245 | '\n******************\t') 246 | 247 | # save model 248 | if not config.save_best or is_best: 249 | save_model(config.save_path, model, optimizer, scheduler, best_acc1, epoch, is_best) 250 | 251 | 252 | 253 | if __name__ == '__main__': 254 | # build configs 255 | args = parse_args() 256 | # launch 257 | main(config=args) 258 | print('=> Finished!') 259 | 260 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def build_optimizer(model, optim, lr, wd, momentum): 4 | 5 | def _no_bias_decay(model): 6 | has_decay = [] 7 | no_decay = [] 8 | skip_list = ['relative_position_bias_table', 'pe'] 9 | 10 | for name, param in model.named_parameters(): 11 | if not param.requires_grad: 12 | continue 13 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list): 14 | no_decay.append(param) 15 | else: 16 | has_decay.append(param) 17 | 18 | assert len(list(model.parameters())) == len(has_decay) + len(no_decay), '{} vs. 
{}'.format(
19 |             len(list(model.parameters())), len(has_decay) + len(no_decay))
20 | 
21 |         return [{'params': has_decay},
22 |                 {'params': no_decay, 'weight_decay': 0.}]
23 | 
24 |     parameters = _no_bias_decay(model)
25 |     kwargs = dict(lr=lr, weight_decay=wd)
26 |     if optim.lower() == 'sgd':  # lowercase comparison so momentum is actually applied when --optim SGD is chosen
27 |         kwargs['momentum'] = momentum
28 | 
29 |     optimizer = getattr(torch.optim, optim)(params=parameters, **kwargs)
30 | 
31 |     return optimizer
32 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import os
4 | 
5 | def accuracy(output, target, topk=(1,)):
6 |     """Computes the accuracy over the k top predictions for the specified values of k"""
7 |     with torch.no_grad():
8 |         maxk = max(topk)
9 |         batch_size = target.size(0)
10 | 
11 |         _, pred = output.topk(maxk, 1, True, True)
12 | 
13 |         pred = pred.t()
14 |         correct = pred.eq(target.view(1, -1).expand_as(pred))
15 | 
16 |         res = []
17 |         for k in topk:
18 |             correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
19 |             res.append(correct_k.mul_(100.0 / batch_size))
20 |         return res
21 | 
22 | def reduce_tensor(tensor):
23 |     rt = tensor.clone()
24 |     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
25 |     rt /= dist.get_world_size()
26 |     return rt
27 | 
28 | def resume_model(resume_path, model, optimizer, scheduler):
29 |     print(f"=> loading checkpoint '{resume_path}'")
30 |     checkpoint = torch.load(resume_path)
31 |     start_epoch = checkpoint['epoch']
32 |     best_acc1 = checkpoint['best_acc1']
33 |     # save_model does not store 'best_epoch', so only the fields below are restored
34 |     model.load_state_dict(checkpoint['model'])
35 |     optimizer.load_state_dict(checkpoint['optimizer'])
36 |     scheduler.load_state_dict(checkpoint['scheduler'])
37 |     print(f"=> loaded checkpoint successfully '{resume_path}' (epoch {start_epoch})")
38 | 
39 |     return model, optimizer, scheduler, start_epoch, best_acc1
40 | 
41 | def save_model(save_path, model, optimizer, scheduler, best_acc1, epoch, is_best):
42 |     save_state = {'model': model.state_dict(),
43 |                   'optimizer': optimizer.state_dict(),
44 |                   'scheduler': scheduler.state_dict(),
45 |                   'best_acc1': best_acc1,
46 |                   'epoch': epoch}
47 | 
48 |     os.makedirs(save_path, exist_ok=True)
49 |     checkpoint_name = f'checkpoint_bestTop1.pth' if is_best else f'checkpoint_{epoch}.pth'
50 |     save_path = os.path.join(save_path, checkpoint_name)
51 |     torch.save(save_state, save_path)
52 |     print(f'=> Saved checkpoint of epoch {epoch} to {save_path}')
--------------------------------------------------------------------------------
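For reference, a minimal single-process smoke test that ties the modules above together. This is hypothetical usage on random data (no DistributedDataParallel, no dataloaders), the hyperparameter values are illustrative only, and it assumes the timm version this repo was written against (whose `CosineLRScheduler` still accepts `t_mul`).
```
import torch
from ConTNet import build_model
from optimizer import build_optimizer
from lr_scheduler import build_lr_scheduler

# Build the smallest variant and take one optimization step on random inputs.
model = build_model(arch='ConT-Ti', use_avgdown=True, relative=True, qkv_bias=True, pre_norm=True)
optimizer = build_optimizer(model, optim='AdamW', lr=5e-4, wd=0.05, momentum=0.9)
scheduler = build_lr_scheduler(epoch=1, warmup_epoch=0, optimizer=optimizer, n_iter_per_epoch=10)

x = torch.randn(2, 3, 224, 224)
loss = model(x).mean()          # dummy scalar objective for the smoke test
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step_update(0)        # step_update is also how main.py drives the timm scheduler
print('one step done, loss =', loss.item())
```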