├── ReadME.md
├── data
│   ├── ImagenetDataset.py
│   ├── __init__.py
│   └── samplers.py
├── imgs
│   └── model.png
├── model
│   └── Transformers
│       ├── CMT
│       │   ├── __init__.py
│       │   └── cmt.py
│       └── __init__.py
├── test.py
├── test.sh
├── train.py
├── train.sh
└── utils
    ├── __init__.py
    ├── augments.py
    ├── calculate_acc.py
    ├── optimizer_step.py
    └── precise_bn.py

/ReadME.md:
--------------------------------------------------------------------------------
1 | # CMT: Convolutional Neural Networks Meet Vision Transformers
2 | 
3 | [[arxiv](https://arxiv.org/abs/2107.06263)]
4 | 
5 | ### 1. Introduction
6 | 
7 | ![model](imgs/model.png)
8 | This repo is a PyTorch implementation of the CMT model. No reference source code was released, so this is a **non-official** version.
9 | 
10 | ### 2. Environments
11 | - python 3.7+
12 | - pytorch 1.7.1
13 | - pillow
14 | - apex
15 | - opencv-python
16 | 
17 | See this [repo](https://github.com/NVIDIA/apex) for how to install apex.
18 | 
19 | ### 3. Dataset
20 | - **Training**
21 | ```
22 | /data/home/imagenet/train/xxx.jpeg, 0
23 | /data/home/imagenet/train/xxx.jpeg, 1
24 | ...
25 | /data/home/imagenet/train/xxx.jpeg, 999
26 | ```
27 | - **Testing**
28 | ```
29 | /data/home/imagenet/test/xxx.jpeg, 0
30 | /data/home/imagenet/test/xxx.jpeg, 1
31 | ...
32 | /data/home/imagenet/test/xxx.jpeg, 999
33 | ```
34 | 
35 | ### 4. Training & Inference
36 | 
37 | 1. Training
38 | 
39 | **CMT-Tiny**
40 | ```bash
41 | #!/bin/bash
42 | OMP_NUM_THREADS=1
43 | MKL_NUM_THREADS=1
44 | export OMP_NUM_THREADS
45 | export MKL_NUM_THREADS
46 | cd CMT-pytorch;
47 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch --nproc_per_node 8 train.py --batch_size 512 --num_workers 48 --lr 6e-3 --optimizer_name "adamw" --tf_optimizer 1 --cosine 1 --model_name cmtti --max_epochs 300 \
48 | --warmup_epochs 5 --num-classes 1000 --input_size 184 --crop_size 160 --weight_decay 1e-1 --grad_clip 0 --repeated-aug 0 --max_grad_norm 5.0 \
49 | --drop_path_rate 0.1 --FP16 0 --qkv_bias 1 \
50 | --ape 0 --rpe 1 --pe_nd 0 --mode O2 --amp 1 --apex 0 \
51 | --train_file $file_folder$/train.txt \
52 | --val_file $file_folder$/val.txt \
53 | --log-dir $save_folder$/log_dir \
54 | --checkpoints-path $save_folder$/checkpoints
55 | 
56 | ```
57 | 
58 | **Note**: Using a batch size of 128 * 8 may give higher accuracy; pick the batch size to balance accuracy and speed.
59 | 
60 | 
61 | 2. Inference
62 | ```bash
63 | #!/bin/bash
64 | cd CMT-pytorch;
65 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore test.py \
66 | --dist-url 'tcp://127.0.0.1:9966' --dist-backend 'nccl' --multiprocessing-distributed=1 --world-size=1 --rank=0 \
67 | --batch-size 128 --num-workers 48 --num-classes 1000 --input_size 184 --crop_size 160 \
68 | --ape 0 --rpe 1 --pe_nd 0 --qkv_bias 1 --swin 0 --model_name cmtti --dropout 0.1 --emb_dropout 0.1 \
69 | --test_file $file_folder$/val.txt \
70 | --checkpoints-path $save_folder$/checkpoints/xxx.pth.tar \
71 | --save_folder $save_folder$/acc_logits/
72 | ```
73 | 
74 | 3. calculate acc
75 | ```bash
76 | python utils/calculate_acc.py --logits_file $save_folder$/acc_logits/
77 | ```
78 | 
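For reference, each rank in `test.py` dumps one JSON object per line into `$save_folder$/acc_logits/`, with the keys `path`, `pred_logits` and `real_label`. The snippet below is a minimal sketch of how top-1 accuracy can be recovered from those files; the bundled `utils/calculate_acc.py` may differ in its details.

```python
import argparse
import json
import os

parser = argparse.ArgumentParser()
parser.add_argument('--logits_file', type=str,
                    help="folder that holds the per-rank *.log logits files written by test.py")
args = parser.parse_args()

correct, total = 0, 0
for name in sorted(os.listdir(args.logits_file)):
    if not name.endswith('.log'):
        continue
    with open(os.path.join(args.logits_file, name)) as f:
        for line in f:
            # one record per line: {"path": ..., "pred_logits": [...], "real_label": ...}
            record = json.loads(line)
            # top-1 prediction is the argmax over the dumped logits
            pred = record["pred_logits"].index(max(record["pred_logits"]))
            correct += int(pred == record["real_label"])
            total += 1

print(f"top-1 accuracy: {correct / total:.5f} ({correct}/{total})")
```

79 | ### 5. 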
Imagenet Result 80 | 81 | |model-name|input_size|FLOPs|Params|acc@one_crop(ours)|acc(papers)|weights| 82 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:| 83 | |CMT-T|160x160|516M|11.3M|75.124%|79.2%|[weights](https://drive.google.com/file/d/1YngcCchrJ43bVWxuy4OiTfwy76gQyIBk/view?usp=sharing)| 84 | |CMT-T|224x224|1.01G|11.3M|78.4%|-|[weights](https://drive.google.com/file/d/11fK2rYxPPvFZOZPd1VpJ0mOK0sLB99OS/view?usp=sharing)| 85 | |CMT-XS|192x192|-|-|-|81.8%|-| 86 | |CMT-S|224x224|-|-|-|83.5%|-| 87 | |CMT-L|256x256|-|-|-|84.5%|-| 88 | 89 | 90 | ### 6. TODO 91 | - [ ] Other result may comming sonn if someone need. 92 | - [ ] Release the CMT-XS result on the imagenet. 93 | - [x] Check the diff with papers, **author give the hyparameters on the issue** 94 | - [x] Adjusting the best hyperparameters for CMT or transformers 95 | 96 | ### Supplementary 97 | If you want to know more, I give the CMT explanation, as well as the tuning and training process on [here](https://zhuanlan.zhihu.com/p/398019698). 98 | 99 | 100 | 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /data/ImagenetDataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | @datetime: 2021-06-28 4 | @author : jiangmingchao@joyy.sg 5 | @describe: Imagenet dataset 6 | """ 7 | from math import e 8 | import torch 9 | import random 10 | import numpy as np 11 | import urllib.request as urt 12 | 13 | from PIL import Image 14 | from io import BytesIO 15 | from torch.utils.data.dataset import Dataset 16 | from torchvision.transforms import transforms as imagenet_transforms 17 | from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform 18 | from timm.data.transforms import _pil_interp 19 | 20 | 21 | class ImageDataset(Dataset): 22 | def __init__(self, 23 | image_file, 24 | train_phase, 25 | input_size, 26 | crop_size, 27 | shuffle=True, 28 | interpolation='random', 29 | auto_augment="rand", 30 | color_prob=0.4, 31 | hflip_prob=0.5, 32 | vflip_prob=0.0, 33 | 34 | ) -> None: 35 | """image dataset 36 | Args: 37 | image_file (str) : the image 38 | train_phase (bool) : train or eval mode 39 | input_size (int) : resize size 40 | crop_size (int) : crop size 41 | shuffle (bool) : shuffle or use the origin order 42 | interpolation (str): 'bilinear' , 'nearst' , 'bicbic' or 'random' default is 'bilinear' 43 | auto_augment (str): 'original', 'rand', 'automix' or None 44 | color_prob (float): None or prob for colorjitter 45 | hflip_prob (float) : prob for horizional flip 46 | vflip_prob (float) : prob for vertical flip 47 | Returns: 48 | dataset 49 | """ 50 | 51 | super(ImageDataset, self).__init__() 52 | self.image_file = image_file 53 | self.image_list = [x.strip() 54 | for x in open(self.image_file).readlines()] 55 | self.length = [x for x in range(len(self.image_list))] 56 | self.train_phase = train_phase 57 | self.input_size = input_size 58 | self.crop_size = crop_size 59 | self.shuffle = shuffle 60 | self.mean = [0.485, 0.456, 0.406] 61 | self.std = [0.229, 0.224, 0.225] 62 | self.hflip_prob = hflip_prob 63 | self.vflip_prob = vflip_prob 64 | 65 | if self.shuffle and self.train_phase: 66 | for _ in range(10): 67 | random.shuffle(self.image_list) 68 | 69 | self.colorjitter_prob = None if color_prob is None else ( 70 | color_prob, )*3 71 | self.auto_augment = auto_augment 72 | self.interpolation = interpolation 73 | 74 | # train 75 | if self.train_phase: 76 | basic_tf = [ 77 | 
imagenet_transforms.RandomResizedCrop( 78 | (self.crop_size, self.crop_size)), 79 | imagenet_transforms.RandomHorizontalFlip(self.hflip_prob), 80 | imagenet_transforms.RandomVerticalFlip(self.vflip_prob), 81 | ] 82 | 83 | auto_tf = [] 84 | if self.auto_augment: 85 | assert isinstance(auto_augment, str) 86 | if isinstance(self.crop_size, (tuple, list)): 87 | img_size_min = min(self.crop_size) 88 | else: 89 | img_size_min = self.crop_size 90 | 91 | aa_params = dict( 92 | translate_dict=int(img_size_min * 0.45), 93 | img_mean=tuple([min(255, round(255 * x)) 94 | for x in self.mean]) 95 | ) 96 | if self.interpolation and self.interpolation != "random": 97 | aa_params['interpolation'] = _pil_interp( 98 | self.interpolation) 99 | # rand aug 100 | if auto_augment.startswith('rand'): 101 | auto_tf += [rand_augment_transform( 102 | auto_augment, aa_params)] 103 | # augmix 104 | elif auto_augment.startswith('augmix'): 105 | aa_params['translate_pct'] = 0.3 106 | auto_tf += [augment_and_mix_transform( 107 | auto_augment, aa_params)] 108 | # auto aug 109 | else: 110 | auto_tf += [auto_augment_transform( 111 | auto_augment, aa_params)] 112 | 113 | if self.colorjitter_prob is not None: 114 | auto_tf += [ 115 | imagenet_transforms.ColorJitter(*self.colorjitter_prob) 116 | ] 117 | 118 | final_tf = [ 119 | imagenet_transforms.ToTensor(), 120 | imagenet_transforms.Normalize( 121 | mean=self.mean, 122 | std=self.std 123 | ) 124 | ] 125 | self.data_aug = imagenet_transforms.Compose( 126 | basic_tf + auto_tf + final_tf 127 | ) 128 | 129 | print(self.data_aug) 130 | 131 | # test 132 | else: 133 | self.data_aug = imagenet_transforms.Compose([ 134 | imagenet_transforms.Resize(int(256 / 224 * self.crop_size)), 135 | imagenet_transforms.CenterCrop( 136 | (self.crop_size, self.crop_size)), 137 | imagenet_transforms.ToTensor(), 138 | imagenet_transforms.Normalize( 139 | mean=self.mean, 140 | std=self.std 141 | ) 142 | ]) 143 | 144 | def _decode_image(self, image_path): 145 | if "http" in image_path: 146 | image = Image.open(BytesIO(urt.urlopen(image_path).read())) 147 | else: 148 | image = Image.open(image_path) 149 | 150 | if image.mode != "RGB": 151 | image = image.convert("RGB") 152 | 153 | return image 154 | 155 | def __getitem__(self, index): 156 | for _ in range(10): 157 | try: 158 | line = self.image_list[index] 159 | image_path, image_label = line.split( 160 | ',')[0], line.split(',')[1] 161 | image = self._decode_image(image_path) 162 | image = self.data_aug(image) 163 | label = torch.from_numpy(np.array(int(image_label))).long() 164 | return image, label, image_path 165 | except Exception as e: 166 | index = random.choice(self.length) 167 | print(f"The exception is {e}, image path is {image_path}!!!") 168 | 169 | def __len__(self): 170 | return len(self.image_list) 171 | 172 | # val 173 | 174 | 175 | class ImageDatasetTest(Dataset): 176 | def __init__(self, 177 | image_file, 178 | train_phase, 179 | input_size, 180 | crop_size, 181 | shuffle=True, 182 | mode="cnn" 183 | ) -> None: 184 | super(ImageDatasetTest, self).__init__() 185 | self.image_file = image_file 186 | self.image_list = [x.strip() 187 | for x in open(self.image_file).readlines()] 188 | self.length = [x for x in range(len(self.image_list))] 189 | self.train_phase = train_phase 190 | self.input_size = input_size 191 | self.crop_size = crop_size 192 | self.shuffle = shuffle 193 | self.mean = [0.485, 0.456, 0.406] 194 | self.std = [0.229, 0.224, 0.225] 195 | self.mode = mode 196 | if self.shuffle and self.train_phase: 197 | for _ in range(10): 
198 | random.shuffle(self.image_list) 199 | 200 | if self.mode == "cnn": 201 | self.data_aug = imagenet_transforms.Compose( 202 | [ 203 | imagenet_transforms.Resize(int(256 / 224 * self.crop_size)), 204 | imagenet_transforms.CenterCrop(self.crop_size), 205 | imagenet_transforms.ToTensor(), 206 | imagenet_transforms.Normalize( 207 | mean=self.mean, 208 | std=self.std 209 | ) 210 | ] 211 | ) 212 | elif self.mode == "transformers": 213 | self.data_aug = imagenet_transforms.Compose( 214 | [ 215 | imagenet_transforms.Resize( 216 | (self.crop_size, self.crop_size)), 217 | imagenet_transforms.ToTensor(), 218 | imagenet_transforms.Normalize( 219 | mean=self.mean, 220 | std=self.std 221 | ) 222 | ] 223 | ) 224 | 225 | def _decode_image(self, image_path): 226 | if "http" in image_path: 227 | image = Image.open(BytesIO(urt.urlopen(image_path).read())) 228 | else: 229 | image = Image.open(image_path) 230 | 231 | if image.mode != "RGB": 232 | image = image.convert("RGB") 233 | 234 | return image 235 | 236 | def __getitem__(self, index): 237 | for _ in range(10): 238 | try: 239 | line = self.image_list[index] 240 | if len(line.split(',')) >= 2: 241 | image_path, image_label = line.split( 242 | ',')[0], line.split(',')[1] 243 | label = torch.from_numpy(np.array(int(image_label))).long() 244 | else: 245 | image_path = line 246 | label = torch.from_numpy(np.array(0)).long() 247 | 248 | image = self._decode_image(image_path) 249 | image = self.data_aug(image) 250 | 251 | return image, label, image_path 252 | 253 | except Exception as e: 254 | index = random.choice(self.length) 255 | print(f"The exception is {e}, image path is {image_path}!!!") 256 | 257 | def __len__(self): 258 | return len(self.image_list) 259 | 260 | 261 | if __name__ == "__main__": 262 | train_file = "/data/jiangmingchao/data/dataset/imagenet/val_oss_imagenet_128w.txt" 263 | train_dataset = ImageDataset( 264 | image_file=train_file, 265 | train_phase=True, 266 | input_size=224, 267 | crop_size=224, 268 | shuffle=True, 269 | interpolation='bilinear', 270 | auto_augment="rand" 271 | ) 272 | print(train_dataset) 273 | print(len(train_dataset)) 274 | for idx, data in enumerate(train_dataset): 275 | print(f"{idx}", data[0].shape, data[1]) 276 | break 277 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/CMT-pytorch/1ce3c8c9732b09634fffe973b43193ada09c0844/data/__init__.py -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /imgs/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/CMT-pytorch/1ce3c8c9732b09634fffe973b43193ada09c0844/imgs/model.png -------------------------------------------------------------------------------- /model/Transformers/CMT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/CMT-pytorch/1ce3c8c9732b09634fffe973b43193ada09c0844/model/Transformers/CMT/__init__.py -------------------------------------------------------------------------------- /model/Transformers/CMT/cmt.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author : jiangmingchao 3 | @datetime : 20210716 4 | @paper : CMT: Convolutional Neural Networks Meet Vision Transformers 5 | @email : jiangmingchao@joyy.sg 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import numpy as np 12 | from einops import rearrange, repeat, reduce 13 | from timm.models.layers import DropPath, trunc_normal_ 14 | 15 | 16 | def make_pairs(x): 17 | """make the int -> tuple 18 | """ 19 | return x if isinstance(x, tuple) else (x, x) 20 | 21 | 22 | def generate_relative_distance(number_size): 23 | """return relative distance, (number_size**2, number_size**2, 2) 24 | """ 25 | indices = torch.tensor(np.array([[x, y] for x in range(number_size) for y in range(number_size)])) 26 | distances = indices[None, :, :] - indices[:, None, :] 27 | distances = 
distances + number_size - 1 # shift the zeros postion 28 | return distances 29 | 30 | 31 | class CMTLayers(nn.Module): 32 | def __init__(self, dim, num_heads=8, ffn_ratio = 4., 33 | relative_pos_embeeding=True, no_distance_pos_embeeding=False, 34 | features_size=56, qkv_bias=False, qk_scale=None, 35 | attn_drop=0., proj_drop=0., sr_ratio=1. , drop_path_rate=0.): 36 | super(CMTLayers, self).__init__() 37 | 38 | self.dim = dim 39 | self.ffn_ratio = ffn_ratio 40 | 41 | self.norm1 = nn.LayerNorm(self.dim) 42 | self.norm2 = nn.LayerNorm(self.dim) 43 | self.LPU = LocalPerceptionUint(self.dim) 44 | self.LMHSA = LightMutilHeadSelfAttention( 45 | dim = self.dim, 46 | num_heads = num_heads, 47 | relative_pos_embeeding = relative_pos_embeeding, 48 | no_distance_pos_embeeding = no_distance_pos_embeeding, 49 | features_size = features_size, 50 | qkv_bias = qkv_bias, 51 | qk_scale = qk_scale, 52 | attn_drop= attn_drop, 53 | proj_drop=proj_drop, 54 | sr_ratio=sr_ratio 55 | ) 56 | self.IRFFN = InvertedResidualFeedForward(self.dim, self.ffn_ratio) 57 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() 58 | 59 | def forward(self, x): 60 | lpu = self.LPU(x) 61 | x = x + lpu 62 | 63 | b, c, h, w = x.shape 64 | x_1 = rearrange(x, 'b c h w -> b ( h w ) c ') 65 | norm1 = self.norm1(x_1) 66 | norm1 = rearrange(norm1, 'b ( h w ) c -> b c h w', h=h, w=w) 67 | attn = self.LMHSA(norm1) 68 | x = x + attn 69 | 70 | b, c, h, w = x.shape 71 | x_2 = rearrange(x, 'b c h w -> b ( h w ) c ') 72 | norm2 = self.norm2(x_2) 73 | norm2 = rearrange(norm2, 'b ( h w ) c -> b c h w', h=h, w=w) 74 | ffn = self.IRFFN(norm2) 75 | x = x + self.drop_path(ffn) 76 | 77 | return x 78 | 79 | 80 | class CMTBlock(nn.Module): 81 | def __init__(self, dim, num_heads=8, ffn_ratio=4., 82 | relative_pos_embeeding=True, no_distance_pos_embeeding=False, 83 | features_size=56, qkv_bias=False, qk_scale=None, 84 | attn_drop=0., proj_drop=0., sr_ratio=1., num_layers=1, drop_path_rate=[0.1]): 85 | super(CMTBlock, self).__init__() 86 | self.dim = dim 87 | self.num_layers = num_layers 88 | self.ffn_ratio = ffn_ratio 89 | 90 | self.block_list = nn.ModuleList([CMTLayers( 91 | dim = self.dim, 92 | ffn_ratio = self.ffn_ratio, 93 | relative_pos_embeeding = relative_pos_embeeding, 94 | no_distance_pos_embeeding = no_distance_pos_embeeding, 95 | features_size = features_size, 96 | num_heads = num_heads, 97 | qkv_bias = qkv_bias, 98 | qk_scale = qk_scale, 99 | attn_drop = attn_drop, 100 | proj_drop = proj_drop, 101 | sr_ratio = sr_ratio, 102 | drop_path_rate = drop_path_rate[i] 103 | ) for i in range(num_layers)] 104 | ) 105 | 106 | def forward(self, x): 107 | for block in self.block_list: 108 | x = block(x) 109 | return x 110 | 111 | 112 | class LocalPerceptionUint(nn.Module): 113 | def __init__(self, dim, act=False): 114 | super(LocalPerceptionUint, self).__init__() 115 | self.act = act 116 | self.conv_3x3_dw = ConvDW3x3(dim) 117 | if self.act: 118 | self.actation = nn.Sequential( 119 | nn.GELU(), 120 | nn.BatchNorm2d(dim) 121 | ) 122 | 123 | def forward(self, x): 124 | if self.act: 125 | out = self.actation(self.conv_3x3_dw(x)) 126 | return out 127 | else: 128 | out = self.conv_3x3_dw(x) 129 | return out 130 | 131 | 132 | class LightMutilHeadSelfAttention(nn.Module): 133 | """calculate the self attention with down sample the resolution for k, v, add the relative position bias before softmax 134 | Args: 135 | dim (int) : features map channels or dims 136 | num_heads (int) : attention heads numbers 137 | relative_pos_embeeding 
(bool) : relative position embeeding 138 | no_distance_pos_embeeding (bool): no_distance_pos_embeeding 139 | features_size (int) : features shape 140 | qkv_bias (bool) : if use the embeeding bias 141 | qk_scale (float) : qk scale if None use the default 142 | attn_drop (float) : attention dropout rate 143 | proj_drop (float) : project linear dropout rate 144 | sr_ratio (float) : k, v resolution downsample ratio 145 | Returns: 146 | x : LMSA attention result, the shape is (B, H, W, C) that is the same as inputs. 147 | """ 148 | def __init__(self, dim, num_heads=8, features_size=56, 149 | relative_pos_embeeding=False, no_distance_pos_embeeding=False, qkv_bias=False, qk_scale=None, 150 | attn_drop=0., proj_drop=0., sr_ratio=1.): 151 | super(LightMutilHeadSelfAttention, self).__init__() 152 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}" 153 | self.dim = dim 154 | self.num_heads = num_heads 155 | head_dim = dim // num_heads # used for each attention heads 156 | self.scale = qk_scale or head_dim ** -0.5 157 | 158 | self.relative_pos_embeeding = relative_pos_embeeding 159 | self.no_distance_pos_embeeding = no_distance_pos_embeeding 160 | 161 | self.features_size = features_size 162 | 163 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 164 | self.kv = nn.Linear(dim, dim*2, bias=qkv_bias) 165 | self.attn_drop = nn.Dropout(attn_drop) 166 | self.proj = nn.Linear(dim, dim) 167 | self.proj_drop = nn.Dropout(proj_drop) 168 | self.softmax = nn.Softmax(dim=-1) 169 | 170 | self.sr_ratio = sr_ratio 171 | if sr_ratio > 1: 172 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 173 | self.norm = nn.LayerNorm(dim) 174 | 175 | if self.relative_pos_embeeding: 176 | self.relative_indices = generate_relative_distance(self.features_size) 177 | self.position_embeeding = nn.Parameter(torch.randn(2 * self.features_size - 1, 2 * self.features_size - 1)) 178 | elif self.no_distance_pos_embeeding: 179 | self.position_embeeding = nn.Parameter(torch.randn(self.features_size ** 2, self.features_size ** 2)) 180 | else: 181 | self.position_embeeding = None 182 | 183 | if self.position_embeeding is not None: 184 | trunc_normal_(self.position_embeeding, std=0.2) 185 | 186 | def forward(self, x): 187 | B, C, H, W = x.shape 188 | N = H*W 189 | x_q = rearrange(x, 'B C H W -> B (H W) C') # translate the B,C,H,W to B (H X W) C 190 | q = self.q(x_q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B,N,H,DIM -> B,H,N,DIM 191 | 192 | # conv for down sample the x resoution for the k, v 193 | if self.sr_ratio > 1: 194 | x_reduce_resolution = self.sr(x) 195 | x_kv = rearrange(x_reduce_resolution, 'B C H W -> B (H W) C ') 196 | x_kv = self.norm(x_kv) 197 | else: 198 | x_kv = rearrange(x, 'B C H W -> B (H W) C ') 199 | 200 | kv_emb = rearrange(self.kv(x_kv), 'B N (dim h l ) -> l B h N dim', h=self.num_heads, l=2) # 2 B H N DIM 201 | k, v = kv_emb[0], kv_emb[1] 202 | 203 | attn = (q @ k.transpose(-2, -1)) * self.scale # (B H Nq DIM) @ (B H DIM Nk) -> (B H NQ NK) 204 | 205 | # TODO: add the relation position bias, because the k_n != q_n, we need to split the position embeeding matrix 206 | q_n, k_n = q.shape[1], k.shape[2] 207 | 208 | if self.relative_pos_embeeding: 209 | attn = attn + self.position_embeeding[self.relative_indices[:, :, 0], self.relative_indices[:, :, 1]][:, :k_n] 210 | elif self.no_distance_pos_embeeding: 211 | attn = attn + self.position_embeeding[:, :k_n] 212 | 213 | attn = self.softmax(attn) 214 | attn = self.attn_drop(attn) 215 | 216 | x = 
(attn @ v).transpose(1, 2).reshape(B, N, C) # (B H NQ NK) @ (B H NK dim) -> (B NQ H*DIM) 217 | x = self.proj(x) 218 | x = self.proj_drop(x) 219 | 220 | x = rearrange(x, 'B (H W) C -> B C H W ', H=H, W=W) 221 | return x 222 | 223 | 224 | class InvertedResidualFeedForward(nn.Module): 225 | def __init__(self, dim, dim_ratio=4.): 226 | super(InvertedResidualFeedForward, self).__init__() 227 | output_dim = int(dim_ratio * dim) 228 | self.conv1x1_gelu_bn = ConvGeluBN( 229 | in_channel=dim, 230 | out_channel=output_dim, 231 | kernel_size=1, 232 | stride_size=1, 233 | padding=0 234 | ) 235 | self.conv3x3_dw = ConvDW3x3(dim=output_dim) 236 | self.act = nn.Sequential( 237 | nn.GELU(), 238 | nn.BatchNorm2d(output_dim) 239 | ) 240 | self.conv1x1_pw = nn.Sequential( 241 | nn.Conv2d(output_dim, dim, 1, 1, 0), 242 | nn.BatchNorm2d(dim) 243 | ) 244 | 245 | def forward(self, x): 246 | x = self.conv1x1_gelu_bn(x) 247 | out = x + self.act(self.conv3x3_dw(x)) 248 | out = self.conv1x1_pw(out) 249 | return out 250 | 251 | 252 | class ConvDW3x3(nn.Module): 253 | def __init__(self, dim, kernel_size=3): 254 | super(ConvDW3x3, self).__init__() 255 | self.conv = nn.Conv2d( 256 | in_channels=dim, 257 | out_channels=dim, 258 | kernel_size=make_pairs(kernel_size), 259 | padding=make_pairs(1), 260 | groups=dim) 261 | 262 | def forward(self, x): 263 | x = self.conv(x) 264 | return x 265 | 266 | 267 | class ConvGeluBN(nn.Module): 268 | def __init__(self, in_channel, out_channel, kernel_size, stride_size, padding=1): 269 | """build the conv3x3 + gelu + bn module 270 | """ 271 | super(ConvGeluBN, self).__init__() 272 | self.kernel_size = make_pairs(kernel_size) 273 | self.stride_size = make_pairs(stride_size) 274 | self.padding_size = make_pairs(padding) 275 | self.in_channel = in_channel 276 | self.out_channel = out_channel 277 | self.conv3x3_gelu_bn = nn.Sequential( 278 | nn.Conv2d(in_channels=self.in_channel, 279 | out_channels=self.out_channel, 280 | kernel_size=self.kernel_size, 281 | stride=self.stride_size, 282 | padding=self.padding_size), 283 | nn.GELU(), 284 | nn.BatchNorm2d(self.out_channel) 285 | ) 286 | 287 | def forward(self, x): 288 | x = self.conv3x3_gelu_bn(x) 289 | return x 290 | 291 | 292 | class CMTStem(nn.Module): 293 | """make the model conv stem module 294 | """ 295 | def __init__(self, kernel_size, in_channel, out_channel, layers_num): 296 | super(CMTStem, self).__init__() 297 | self.layers_num = layers_num 298 | self.conv3x3_gelu_bn_downsample = ConvGeluBN( 299 | in_channel=in_channel, 300 | out_channel=out_channel, 301 | kernel_size=kernel_size, 302 | stride_size=make_pairs(2) 303 | ) 304 | self.conv3x3_gelu_bn_list = nn.ModuleList( 305 | [ConvGeluBN(kernel_size=kernel_size, in_channel=out_channel, out_channel=out_channel, stride_size=1) for _ in range(self.layers_num)] 306 | ) 307 | 308 | def forward(self, x): 309 | x = self.conv3x3_gelu_bn_downsample(x) 310 | for i in range(self.layers_num): 311 | x = self.conv3x3_gelu_bn_list[i](x) 312 | return x 313 | 314 | 315 | class PatchAggregation(nn.Module): 316 | """down sample the feature resolution, build with conv 2x2 stride 2 317 | """ 318 | def __init__(self, in_channel, out_channel, kernel_size=2, stride_size=2): 319 | super(PatchAggregation, self).__init__() 320 | self.patch_aggregation = nn.Conv2d( 321 | in_channels=in_channel, 322 | out_channels=out_channel, 323 | kernel_size=make_pairs(kernel_size), 324 | stride=make_pairs(stride_size) 325 | ) 326 | 327 | def forward(self, x): 328 | x = self.patch_aggregation(x) 329 | return x 330 | 331 | # 
TODO: add the RPE 332 | class ConvolutionMeetVisionTransformers(nn.Module): 333 | def __init__(self, 334 | input_resolution: tuple, ape: bool , 335 | input_channels: int , dims_list: list, 336 | heads_list: list, block_list: list, 337 | sr_ratio_list: list, qkv_bias : bool, 338 | proj_drop: float, attn_drop: float, 339 | rpe: bool, pe_nd: bool, ffn_ratio: float, 340 | num_classes: int, drop_path_rate: float = 0.1 341 | ): 342 | """CMT implementation 343 | Args: 344 | input_resolution : (h, w) for image resolution 345 | ape: absoluate position embeeding (learnable) 346 | input_channels: images input channel, default 3 347 | dims_list : a list of each stage dimension 348 | heads_list : mutil head self-attention heads numbers 349 | block_list : cmt block numbers for each stage 350 | sr_ratio_list: k,v reduce ratio for each stage 351 | qkv_bias : use bias for qkv embeeding 352 | proj_drop : proj layer dropout 353 | attn_drop : attention dropout 354 | rpe : relative position embeeding (learnable ) 355 | pe_nd : no distance pos embeeding (learnable) 356 | ffn_ratio : ffn up & down dims 357 | num_classes : output numclasses 358 | drop_path_rate: Stochastic depth rate. Default: 0.1 359 | Return: 360 | cmt model 361 | """ 362 | super(ConvolutionMeetVisionTransformers, self).__init__() 363 | assert input_resolution[0]==input_resolution[1], "input must be square " 364 | 365 | self.input_resolution = input_resolution 366 | 367 | self.input_channels = input_channels 368 | self.dims_list = dims_list 369 | self.heads_list = heads_list 370 | self.block_list = block_list 371 | self.sr_ratio_list = sr_ratio_list 372 | 373 | # position embeeding 374 | self.ape = ape 375 | self.rpe = rpe 376 | self.pe_nd = pe_nd 377 | 378 | # ffn ratio 379 | self.ffn_ratio = ffn_ratio 380 | self.drop_path_rate = drop_path_rate 381 | 382 | self.qkv_bias = qkv_bias 383 | self.proj_drop = proj_drop 384 | self.attn_drop = attn_drop 385 | 386 | self.img_height = self.input_resolution[0] 387 | self.img_width = self.input_resolution[1] 388 | self.num_patches = (self.img_width // 4) * (self.img_height // 4) 389 | 390 | # absolate position embeeding, add after the first patch aggregation layers 391 | if self.ape: 392 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, dims_list[1])) 393 | trunc_normal_(self.absolute_pos_embed, std=0.2) 394 | 395 | # down sample the image with 2x 396 | features_downsample_raito = [ 397 | 2**1, 2**2, 2**3, 2**4, 2**5 398 | ] 399 | resolution_list = [self.input_resolution[0] // x for x in features_downsample_raito] 400 | print("resolution :", resolution_list) 401 | self.stem = CMTStem( 402 | kernel_size=3, 403 | in_channel=self.input_channels, 404 | out_channel=dims_list[0], 405 | layers_num=2 406 | ) 407 | # stochastic depth 408 | if self.drop_path_rate > 0.0: 409 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(block_list))] # stochastic depth decay rule 410 | else: 411 | dpr = [-1 for _ in sum(block_list)] 412 | 413 | self.pool1 = PatchAggregation(in_channel=dims_list[0], out_channel=dims_list[1]) 414 | self.pool2 = PatchAggregation(in_channel=dims_list[1], out_channel=dims_list[2]) 415 | self.pool3 = PatchAggregation(in_channel=dims_list[2], out_channel=dims_list[3]) 416 | self.pool4 = PatchAggregation(in_channel=dims_list[3], out_channel=dims_list[4]) 417 | 418 | self.stage1 = CMTBlock( 419 | dim=dims_list[1], num_heads=heads_list[0], relative_pos_embeeding=self.rpe, no_distance_pos_embeeding=self.pe_nd, 420 | features_size=resolution_list[1], 
qkv_bias=self.qkv_bias, attn_drop=self.attn_drop, ffn_ratio=self.ffn_ratio, 421 | proj_drop=self.proj_drop, sr_ratio=sr_ratio_list[0], num_layers=block_list[0], drop_path_rate=dpr[:block_list[0]]) 422 | self.stage2 = CMTBlock( 423 | dim=dims_list[2], num_heads=heads_list[1], relative_pos_embeeding=self.rpe, no_distance_pos_embeeding=self.pe_nd, 424 | features_size=resolution_list[2], qkv_bias=self.qkv_bias, attn_drop=self.attn_drop, ffn_ratio=self.ffn_ratio, 425 | proj_drop=self.proj_drop, sr_ratio=sr_ratio_list[1], num_layers=block_list[1], drop_path_rate=dpr[sum(block_list[:1]): sum(block_list[:2])]) 426 | self.stage3 = CMTBlock( 427 | dim=dims_list[3], num_heads=heads_list[2], relative_pos_embeeding=self.rpe, no_distance_pos_embeeding=self.pe_nd, 428 | features_size=resolution_list[3], qkv_bias=self.qkv_bias, attn_drop=self.attn_drop, ffn_ratio=self.ffn_ratio, 429 | proj_drop=self.proj_drop, sr_ratio=sr_ratio_list[2], num_layers=block_list[2], drop_path_rate=dpr[sum(block_list[:2]): sum(block_list[:3])]) 430 | self.stage4 = CMTBlock( 431 | dim=dims_list[4], num_heads=heads_list[3], relative_pos_embeeding=self.rpe, no_distance_pos_embeeding=self.pe_nd, 432 | features_size=resolution_list[4], qkv_bias=self.qkv_bias, attn_drop=self.attn_drop, ffn_ratio=self.ffn_ratio, 433 | proj_drop=self.proj_drop, sr_ratio=sr_ratio_list[3], num_layers=block_list[3], drop_path_rate=dpr[sum(block_list[:3]): sum(block_list[:4])]) 434 | 435 | self.gap = nn.AdaptiveAvgPool2d(output_size=(1, 1)) 436 | self.fc = nn.Linear(dims_list[4], 1280) 437 | self.classifier = nn.Linear(1280, num_classes) 438 | 439 | self.dropout = nn.Dropout(p=0.1) 440 | 441 | self.apply(self._init_weights) 442 | 443 | @torch.jit.ignore 444 | def no_weight_decay(self): 445 | return {'absolute_pos_embed'} 446 | 447 | def forward_features(self, x): 448 | x = self.stem(x) 449 | x = self.pool1(x) 450 | 451 | if self.ape: 452 | B, C, H, W = x.shape 453 | x = rearrange(x, ' b c h w -> b (h w) c ') 454 | x = x + self.absolute_pos_embed 455 | x = rearrange(x, ' b (h w) c -> b c h w ', h=H) 456 | 457 | x = self.stage1(x) 458 | x = self.pool2(x) 459 | x = self.stage2(x) 460 | x = self.pool3(x) 461 | x = self.stage3(x) 462 | x = self.pool4(x) 463 | x = self.stage4(x) 464 | return x 465 | 466 | def forward(self, x): 467 | x = self.forward_features(x) 468 | x = self.gap(x) 469 | B, C, H, W = x.shape 470 | x = x.view(B, -1) 471 | x = self.fc(x) 472 | x = self.dropout(x) 473 | out = self.classifier(x) 474 | return out 475 | 476 | def _init_weights(self, m): 477 | if isinstance(m, nn.Linear): 478 | trunc_normal_(m.weight, std=.02) 479 | if isinstance(m, nn.Linear) and m.bias is not None: 480 | nn.init.constant_(m.bias, 0) 481 | elif isinstance(m, nn.Conv2d): 482 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 483 | elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)): 484 | nn.init.constant_(m.weight, 1) 485 | nn.init.constant_(m.bias, 0) 486 | 487 | 488 | # cmt tiny with resolution 160x160 489 | def CmtTi(input_resolution=(160, 160), 490 | ape = False, 491 | rpe = True, 492 | pe_nd = False, 493 | ffn_ratio = 3.6, 494 | qkv_bias = True, 495 | proj_drop = 0.1, 496 | attn_drop = 0.1, 497 | num_classes = 1000, 498 | drop_path_rate = 0.1): 499 | model = ConvolutionMeetVisionTransformers( 500 | input_resolution=input_resolution, 501 | ape=ape, 502 | rpe=rpe, 503 | pe_nd=pe_nd, 504 | qkv_bias=qkv_bias, 505 | proj_drop=proj_drop, 506 | attn_drop=attn_drop, 507 | ffn_ratio=ffn_ratio, 508 | input_channels=3, 509 | dims_list=[16, 46, 92, 
184, 368], 510 | heads_list=[1,2,4,8], 511 | block_list=[2, 2, 10, 2], 512 | sr_ratio_list=[8, 4, 2, 1], 513 | num_classes=num_classes, 514 | drop_path_rate = drop_path_rate 515 | ) 516 | return model 517 | 518 | # cmt tiny with resolution 192x192 519 | def CmtXS(input_resolution=(192, 192), 520 | ape = False, 521 | rpe = True, 522 | pe_nd = False, 523 | ffn_ratio = 3.8, 524 | qkv_bias = True, 525 | proj_drop = 0.1, 526 | attn_drop = 0.1, 527 | num_classes = 1000, 528 | drop_path_rate = 0.1): 529 | model = ConvolutionMeetVisionTransformers( 530 | input_resolution=input_resolution, 531 | ape=ape, 532 | rpe=rpe, 533 | pe_nd=pe_nd, 534 | qkv_bias=qkv_bias, 535 | proj_drop=proj_drop, 536 | attn_drop=attn_drop, 537 | ffn_ratio=ffn_ratio, 538 | input_channels=3, 539 | dims_list=[16, 52, 104, 208, 416], 540 | heads_list=[1,2,4,8], 541 | block_list=[3, 3, 12, 3], 542 | sr_ratio_list=[8, 4, 2, 1], 543 | num_classes=num_classes, 544 | drop_path_rate = drop_path_rate 545 | ) 546 | return model 547 | 548 | # cmt small with resolution 224x224 549 | def CmtS(input_resolution=(224, 224), 550 | ape = False, 551 | rpe = True, 552 | pe_nd = False, 553 | ffn_ratio = 4.0, 554 | qkv_bias = True, 555 | proj_drop = 0.1, 556 | attn_drop = 0.1, 557 | num_classes = 1000, 558 | drop_path_rate = 0.1): 559 | model = ConvolutionMeetVisionTransformers( 560 | input_resolution=input_resolution, 561 | ape=ape, 562 | rpe=rpe, 563 | pe_nd=pe_nd, 564 | qkv_bias=qkv_bias, 565 | proj_drop=proj_drop, 566 | attn_drop=attn_drop, 567 | ffn_ratio=ffn_ratio, 568 | input_channels=3, 569 | dims_list=[32, 64, 128, 256, 512], 570 | heads_list=[1, 2, 4, 8], 571 | block_list=[3, 3, 16, 3], 572 | sr_ratio_list=[8, 4, 2, 1], 573 | num_classes=num_classes, 574 | drop_path_rate = drop_path_rate 575 | ) 576 | return model 577 | 578 | 579 | # cmt big with resolution 256x256 580 | def CmtB(input_resolution=(256, 256), 581 | ape = False, 582 | rpe = True, 583 | pe_nd = False, 584 | ffn_ratio = 4.0, 585 | qkv_bias = True, 586 | proj_drop = 0.1, 587 | attn_drop = 0.1, 588 | num_classes = 1000, 589 | drop_path_rate = 0.1): 590 | model = ConvolutionMeetVisionTransformers( 591 | input_resolution=input_resolution, 592 | ape=ape, 593 | rpe=rpe, 594 | pe_nd=pe_nd, 595 | qkv_bias=qkv_bias, 596 | proj_drop=proj_drop, 597 | attn_drop=attn_drop, 598 | ffn_ratio=ffn_ratio, 599 | input_channels=3, 600 | dims_list=[38, 76, 152, 304, 608], 601 | heads_list=[1, 2, 4, 8], 602 | block_list=[4, 4, 20, 4], 603 | sr_ratio_list=[8, 4, 2, 1], 604 | num_classes=num_classes, 605 | drop_path_rate = drop_path_rate 606 | ) 607 | return model 608 | 609 | 610 | if __name__ == "__main__": 611 | x = torch.randn(1, 3, 160, 160) 612 | model = CmtTi() 613 | 614 | print(model) 615 | out = model(x) 616 | print(out.shape) 617 | -------------------------------------------------------------------------------- /model/Transformers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/CMT-pytorch/1ce3c8c9732b09634fffe973b43193ada09c0844/model/Transformers/__init__.py -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding:utf-8 -*- 3 | @author : jiangmingchao@joyy.sg 4 | @datetime: 2021-06-30 5 | @describe: inference code 6 | """ 7 | from random import choice 8 | import warnings 9 | warnings.filterwarnings('ignore') 10 | 11 | import os 12 | import time 13 | import json 14 | 
import argparse 15 | import numpy as np 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.distributed as dist 21 | import torch.multiprocessing as mp 22 | 23 | from torch.utils.data.dataloader import DataLoader 24 | from torch.nn.parallel import DistributedDataParallel as DDP 25 | from torch.utils.data.distributed import DistributedSampler 26 | 27 | 28 | from model.Transformers.CMT.cmt import CmtTi, CmtXS, CmtS, CmtB 29 | from data.ImagenetDataset import ImageDatasetTest 30 | from utils.precise_bn import * 31 | 32 | from torchsummaryX import summary 33 | # from thop import profile 34 | 35 | parser = argparse.ArgumentParser() 36 | # ----- data ------ 37 | parser.add_argument('--test_file', type=str, default="") 38 | parser.add_argument('--num-classes', type=int, default=1000) 39 | 40 | # ----- model ----- 41 | parser.add_argument('--checkpoints-path', default='', type=str) 42 | parser.add_argument('--input_size', default=256, type=int) 43 | parser.add_argument('--crop_size', default=224, type=int) 44 | parser.add_argument('--batch-size', default=128, type=int) 45 | parser.add_argument('--num-workers', default=32, type=int) 46 | parser.add_argument('--save_folder', default="", type=str) 47 | parser.add_argument('--FP16', default=0, type=int) 48 | parser.add_argument('--get_features', default=0, type=int) 49 | parser.add_argument('--model_name', default="R50", type=str) 50 | 51 | # ----- vit ------ 52 | parser.add_argument('--patch_size', default=32, type=int) 53 | parser.add_argument('--dim', default=512, type=int, 54 | help="token embeeding dims") 55 | parser.add_argument('--depth', default=12, type=int, 56 | help="transformers encoder layer numbers") 57 | parser.add_argument('--heads', default=8, type=int, 58 | help="Mutil self attention heads numbers") 59 | parser.add_argument('--dim_head', default=64, type=int, 60 | help="embeeding dims") 61 | parser.add_argument('--mlp_dim', default=2048, type=int, 62 | help="fead forward network fc dimension, simple x4 for the head dims") 63 | parser.add_argument('--dropout', default=0.1, type=float, 64 | help="used for attention and mlp dropout") 65 | parser.add_argument('--emb_dropout', default=0.1, type=float, 66 | help="embeeding dropout used for token embeeding!!!") 67 | 68 | # ------cmt ------- 69 | parser.add_argument('--ape', default=1, type=int) 70 | parser.add_argument('--rpe', default=1, type=int) 71 | parser.add_argument('--pe_nd', default=1, type=int) 72 | parser.add_argument('--qkv_bias', default=1, type=int) 73 | 74 | 75 | # ------swin-transformers------ 76 | parser.add_argument('--cfg', default="/data/jiangmingchao/data/code/ImageClassification/model/Transformers/swin_transformers/configs/swin_large_patch4_window12_384.yaml", type=str) 77 | parser.add_argument('--swin', default=0, type=int) 78 | 79 | # ------ddp ------- 80 | parser.add_argument('--ngpu', type=int, default=1) 81 | parser.add_argument('--world-size', type=int, default=-1, 82 | help="number of nodes for distributed training") 83 | parser.add_argument('--rank', default=-1, type=int, 84 | help='node rank for distributed training') 85 | parser.add_argument('--gpu', default=None, type=int, 86 | help='GPU id to use.') 87 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str, 88 | help='url used to set up distributed training') 89 | parser.add_argument('--dist-backend', default='nccl', type=str, 90 | help='distributed backend') 91 | parser.add_argument('--multiprocessing-distributed', default=1, 
type=int, 92 | help='Use multi-processing distributed training to launch ' 93 | 'N processes per node, which has N GPUs. This is the ' 94 | 'fastest way to use PyTorch for either single node or ' 95 | 'multi node data parallel training') 96 | parser.add_argument('--local_rank', default=1) 97 | 98 | args = parser.parse_args() 99 | 100 | def main_worker(gpu, ngpus_per_node, args): 101 | args.gpu = gpu 102 | 103 | if args.dist_url == "env://" and args.rank == -1: 104 | args.rank = int(os.environ["RANK"]) 105 | 106 | if args.multiprocessing_distributed: 107 | args.rank = args.rank * ngpus_per_node + gpu 108 | 109 | if args.gpu is not None: 110 | print("Use GPU: {} for Testing".format(args.gpu)) 111 | print('rank: {} / {}'.format(args.rank, args.world_size)) 112 | 113 | if args.distributed: 114 | dist.init_process_group( 115 | backend=args.dist_backend, 116 | init_method=args.dist_url, 117 | world_size=args.world_size, 118 | rank=args.rank) 119 | torch.cuda.set_device(args.gpu) 120 | 121 | if args.rank == 0: 122 | if not os.path.isfile(args.checkpoints_path): 123 | os.makedirs(args.checkpoints_path) 124 | 125 | 126 | if args.model_name.lower() == "cmtti": 127 | model = CmtTi( 128 | num_classes=args.num_classes, 129 | ape=True if args.ape else False, 130 | rpe=True if args.rpe else False, 131 | pe_nd=True if args.pe_nd else False, 132 | qkv_bias=True if args.qkv_bias else False, 133 | input_resolution=(args.crop_size, args.crop_size) 134 | ) 135 | mode = "cnn" 136 | elif args.model_name.lower() == "cmtxs": 137 | model = CmtXS( 138 | num_classes=args.num_classes, 139 | ape=True if args.ape else False, 140 | rpe=True if args.rpe else False, 141 | pe_nd=True if args.pe_nd else False, 142 | qkv_bias=True if args.qkv_bias else False, 143 | input_resolution=(args.crop_size, args.crop_size) 144 | ) 145 | mode = "cnn" 146 | elif args.model_name.lower() == "cmts": 147 | model = CmtS( 148 | num_classes=args.num_classes, 149 | ape=True if args.ape else False, 150 | rpe=True if args.rpe else False, 151 | pe_nd=True if args.pe_nd else False, 152 | qkv_bias=True if args.qkv_bias else False, 153 | input_resolution=(args.crop_size, args.crop_size) 154 | ) 155 | mode = "cnn" 156 | elif args.model_name.lower() == "cmtb": 157 | model = CmtB( 158 | num_classes=args.num_classes, 159 | ape=True if args.ape else False, 160 | rpe=True if args.rpe else False, 161 | pe_nd=True if args.pe_nd else False, 162 | qkv_bias=True if args.qkv_bias else False, 163 | input_resolution=(args.crop_size, args.crop_size) 164 | ) 165 | mode = "cnn" 166 | else: 167 | raise NotImplementedError(f"{args.model_name} have not been use!!") 168 | 169 | # # load the model checkpoints 170 | state_dict = torch.load(args.checkpoints_path, map_location="cpu")['state_dict'] 171 | model.load_state_dict(state_dict) 172 | 173 | if args.rank == 0: 174 | print(model) 175 | 176 | # profile(model, inputs=(torch.randn(1, 3, 224, 224), )) 177 | summary(model, torch.randn(1, 3, 160, 160)) 178 | 179 | if args.FP16: 180 | model = model.half() 181 | for bn in get_bn_modules(model): 182 | bn.float() 183 | 184 | if torch.cuda.is_available(): 185 | model.cuda(args.gpu) 186 | 187 | if args.distributed: 188 | model = DDP(model, device_ids=[args.gpu], find_unused_parameters=True) 189 | 190 | dataset = ImageDatasetTest( 191 | image_file = args.test_file, 192 | train_phase= False, 193 | input_size = args.input_size, 194 | crop_size = args.crop_size, 195 | shuffle = False, 196 | mode = mode 197 | ) 198 | 199 | if args.rank == 0: 200 | print("Validation dataset length: ", 
len(dataset)) 201 | 202 | if args.distributed: 203 | sampler = DistributedSampler(dataset) 204 | else: 205 | sampler = None 206 | 207 | criterion = nn.CrossEntropyLoss() 208 | length = len(dataset) 209 | 210 | dataloader = DataLoader( 211 | dataset = dataset, 212 | batch_size = args.batch_size, 213 | shuffle = False, 214 | num_workers= args.num_workers, 215 | sampler = sampler, 216 | drop_last = False 217 | ) 218 | validation(args, dataloader, model, criterion, length) 219 | 220 | 221 | def validation(args, dataloader, model, criterion, length): 222 | model.eval() 223 | device = model.device 224 | total_batch = int(length / (args.batch_size*8)) 225 | 226 | if not os.path.exists(args.save_folder): 227 | os.makedirs(args.save_folder) 228 | 229 | file = open(os.path.join(args.save_folder +'r50_features_'+ str(args.rank) + '.log') , "w") 230 | for batch_idx, data in enumerate(dataloader): 231 | batch_data, batch_label, batch_path = data[0], data[1], data[2] 232 | 233 | batch_data = batch_data.to(device) 234 | batch_label = batch_label.to(device) 235 | 236 | with torch.no_grad(): 237 | start_time = time.time() 238 | 239 | if args.FP16: 240 | batch_data = batch_data.half() 241 | 242 | batch_output = model(batch_data) 243 | batch_time = time.time() - start_time 244 | 245 | batch_losses = criterion(batch_output, batch_label) 246 | batch_logits = batch_output.cpu().numpy() 247 | 248 | for i in range(batch_logits.shape[0]): 249 | image_path = batch_path[i] 250 | output = batch_logits[i].tolist() 251 | gt = batch_label[i].data.item() 252 | result = { 253 | "path" : image_path, 254 | "pred_logits" : output, 255 | "real_label" : gt 256 | } 257 | file.write(json.dumps(result, ensure_ascii=False) + '\n') 258 | 259 | 260 | if args.rank == 0: 261 | print(f"Validation Iter: [{batch_idx+1}/{total_batch}] losses: {batch_losses} , batchtime: {batch_time}") 262 | 263 | file.close() 264 | 265 | 266 | if __name__ == '__main__': 267 | args = parser.parse_args() 268 | 269 | if args.dist_url == "env://" and args.world_size == -1: 270 | args.world_size = int(os.environ["WORLD_SIZE"]) 271 | 272 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 273 | ngpus_per_node = torch.cuda.device_count() 274 | 275 | if args.multiprocessing_distributed: 276 | args.world_size = ngpus_per_node * args.world_size 277 | print("ngpus_per_node", ngpus_per_node) 278 | mp.spawn(main_worker, nprocs=ngpus_per_node, 279 | args=(ngpus_per_node, args)) 280 | else: 281 | # Simply call main_worker function 282 | print("ngpus_per_node", ngpus_per_node) 283 | main_worker(args.gpu, ngpus_per_node, args) -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd /data/jiangmingchao/data/code/ImageClassification; 3 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore /data/jiangmingchao/data/code/ImageClassification/test.py \ 4 | --dist-url 'tcp://127.0.0.1:9966' \ 5 | --dist-backend 'nccl' \ 6 | --multiprocessing-distributed=1 \ 7 | --world-size=1 \ 8 | --rank=0 \ 9 | --test_file /data/jiangmingchao/data/dataset/imagenet/val_oss_imagenet_128w.txt \ 10 | --batch-size 128 \ 11 | --num-workers 48 \ 12 | --num-classes 1000 \ 13 | --input_size 184 \ 14 | --crop_size 160 \ 15 | --ape 0 \ 16 | --rpe 1 \ 17 | --pe_nd 0 \ 18 | --qkv_bias 1 \ 19 | --swin 0 \ 20 | --model_name cmtti \ 21 | --depth 12 \ 22 | --patch_size 32 \ 23 | --heads 12 \ 24 | --dim_head 64 \ 25 | --dim 768 \ 26 | --mlp_dim 3072 
\ 27 | --dropout 0.1 \ 28 | --emb_dropout 0.1 \ 29 | --checkpoints-path /data/jiangmingchao/data/AICutDataset/transformers/CMT/cmt_tiny_160x160_300epoch_mixup_cutmix_adamw_all_wd_0.1_6e-3_dp/checkpoints/r50_accuracy_0.76806640625.pth \ 30 | --save_folder /data/jiangmingchao/data/AICutDataset/imagenet/r50_acc_result/ 31 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding:utf-8 -*- 3 | @author : jiangmingchao@joyy.sg 4 | @datetime: 2021-0628 5 | @describe: Training loop 6 | """ 7 | import torch.nn as nn 8 | import torch 9 | import numpy as np 10 | import random 11 | import math 12 | import time 13 | import os 14 | import torch.distributed as dist 15 | 16 | 17 | from model.Transformers.CMT.cmt import CmtTi, CmtXS, CmtS, CmtB 18 | 19 | from utils.augments import * 20 | from utils.precise_bn import * 21 | from datetime import datetime 22 | from torch.utils.tensorboard import SummaryWriter 23 | from utils.optimizer_step import Optimizer, build_optimizer 24 | from data.ImagenetDataset import ImageDataset 25 | from data.samplers import RASampler 26 | from torch.utils.data.distributed import DistributedSampler 27 | from torch.nn.parallel import DistributedDataParallel as DataParallel 28 | from torch.utils.data import DataLoader 29 | from torch.cuda.amp import autocast as autocast 30 | 31 | from timm.data import Mixup 32 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 33 | 34 | import torch.multiprocessing as mp 35 | import torch.distributed as dist 36 | import torch.nn.functional as F 37 | import argparse 38 | import warnings 39 | 40 | from apex.amp import scaler 41 | warnings.filterwarnings('ignore') 42 | 43 | # apex 44 | try: 45 | from apex import amp 46 | from apex.parallel import convert_syncbn_model 47 | from apex.parallel import DistributedDataParallel as DDP 48 | except Exception as e: 49 | print("amp have not been import !!!") 50 | 51 | # actnn 52 | try: 53 | import actnn 54 | actnn.set_optimization_level("L3") 55 | except Exception as e: 56 | print("actnn have no import !!!") 57 | 58 | 59 | parser = argparse.ArgumentParser() 60 | # ------ddp 61 | parser.add_argument('--ngpu', type=int, default=1) 62 | parser.add_argument('--rank', default=-1, type=int, 63 | help='node rank for distributed training') 64 | parser.add_argument('--dist-backend', default='nccl', 65 | type=str, help='distributed backend') 66 | parser.add_argument('--local_rank', default=-1, type=int) 67 | parser.add_argument('--distributed', default=1, type=int, 68 | help="use distributed method to training!!") 69 | # ----- data 70 | parser.add_argument('--train_file', type=str, 71 | default="/data/jiangmingchao/data/dataset/imagenet/train_oss_imagenet_128w.txt") 72 | parser.add_argument('--val_file', type=str, 73 | default="/data/jiangmingchao/data/dataset/imagenet/val_oss_imagenet_128w.txt") 74 | parser.add_argument('--num-classes', type=int) 75 | parser.add_argument('--input_size', type=int, default=224) 76 | parser.add_argument('--crop_size', type=int, default=224) 77 | parser.add_argument('--num_classes', type=int, default=1000) 78 | 79 | # ----- checkpoints log dir 80 | parser.add_argument('--checkpoints-path', default='checkpoints', type=str) 81 | parser.add_argument('--log-dir', default='logs', type=str) 82 | 83 | # ---- model 84 | parser.add_argument('--model_name', default="R50", type=str) 85 | parser.add_argument('--qkv_bias', default=0, type=int, 
86 | help="qkv embedding bias") 87 | parser.add_argument('--ape', default=0, type=int, 88 | help="absoluate position embeeding") 89 | parser.add_argument('--rpe', default=1, type=int, 90 | help="relative position embeeding") 91 | parser.add_argument('--pe_nd', default=1, type=int, 92 | help="no distance relative position embeeding") 93 | 94 | # ----transformers 95 | parser.add_argument('--patch_size', default=32, type=int) 96 | parser.add_argument('--dim', default=512, type=int, 97 | help="token embeeding dims") 98 | parser.add_argument('--depth', default=12, type=int, 99 | help="transformers encoder layer numbers") 100 | parser.add_argument('--heads', default=8, type=int, 101 | help="Mutil self attention heads numbers") 102 | parser.add_argument('--dim_head', default=64, type=int, 103 | help="embeeding dims") 104 | parser.add_argument('--mlp_dim', default=2048, type=int, 105 | help="fead forward network fc dimension, simple x4 for the head dims") 106 | parser.add_argument('--dropout', default=0.1, type=float, 107 | help="used for attention and mlp dropout") 108 | parser.add_argument('--emb_dropout', default=0.1, type=float, 109 | help="embeeding dropout used for token embeeding!!!") 110 | 111 | # ---- optimizer 112 | parser.add_argument('--optimizer_name', default="sgd", type=str) 113 | parser.add_argument('--tf_optimizer', default=1, type=int) 114 | parser.add_argument('--lr', default=1e-1, type=float) 115 | parser.add_argument('--weight_decay', default=1e-4, type=float) 116 | parser.add_argument('--momentum', default=0.9, type=float) 117 | parser.add_argument('--batch_size', default=64, type=int) 118 | parser.add_argument('--num_workers', default=8, type=int) 119 | parser.add_argument('--cosine', default=0, type=int) 120 | 121 | # clip grad 122 | parser.add_argument('--grad_clip', default=0, type=int) 123 | parser.add_argument('--max_grad_norm', default=5.0, type=float) 124 | 125 | # drop path rate 126 | parser.add_argument('--drop_path_rate', default=0.1, type=float) 127 | 128 | # * Mixup params 129 | parser.add_argument('--mixup', type=float, default=0.8, 130 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 131 | parser.add_argument('--cutmix', type=float, default=1.0, 132 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 133 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 134 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 135 | parser.add_argument('--mixup-prob', type=float, default=1.0, 136 | help='Probability of performing mixup or cutmix when either/both is enabled') 137 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 138 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 139 | parser.add_argument('--mixup-mode', type=str, default='batch', 140 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 141 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 142 | parser.add_argument('--repeated-aug', type=float, default=0, help="repeated aug") 143 | # ---- actnn 2-bit 144 | parser.add_argument('--actnn', default=0, type=int) 145 | 146 | # ---- train 147 | parser.add_argument('--warmup_epochs', default=5, type=int) 148 | parser.add_argument('--max_epochs', default=90, type=int) 149 | parser.add_argument('--FP16', default=0, type=int) 150 | parser.add_argument('--apex', default=0, type=int) 151 | parser.add_argument('--mode', default='O1', type=str) 152 | parser.add_argument('--amp', default=1, type=int) 153 | 154 | # random seed 155 | 156 | 157 | def setup_seed(seed=100): 158 | torch.manual_seed(seed) 159 | torch.cuda.manual_seed_all(seed) 160 | np.random.seed(seed) 161 | random.seed(seed) 162 | torch.backends.cudnn.benchmark = True 163 | torch.backends.cudnn.deterministic = False 164 | 165 | 166 | def translate_state_dict(state_dict): 167 | new_state_dict = {} 168 | for key, value in state_dict.items(): 169 | if 'module' in key: 170 | new_state_dict[key[7:]] = value 171 | else: 172 | new_state_dict[key] = value 173 | return new_state_dict 174 | 175 | 176 | def accuracy(output, target, topk=(1,)): 177 | with torch.no_grad(): 178 | maxk = max(topk) 179 | batch_size = target.size(0) 180 | _, pred = output.topk(k=maxk, dim=1, largest=True, sorted=True) 181 | pred = pred.t() 182 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 183 | res = [] 184 | crr = [] 185 | for k in topk: 186 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 187 | acc = correct_k.mul_(1/batch_size).item() 188 | res.append(acc) # unit: percentage (%) 189 | crr.append(correct_k) 190 | return res, crr 191 | 192 | 193 | class Metric_rank: 194 | def __init__(self, name): 195 | self.name = name 196 | self.sum = 0.0 197 | self.n = 0 198 | 199 | def update(self, val): 200 | self.sum += val 201 | self.n += 1 202 | 203 | @property 204 | def average(self): 205 | return self.sum / self.n 206 | 207 | data_list = [] 208 | 209 | # main func 210 | def main_worker(args): 211 | total_rank = torch.cuda.device_count() 212 | print('rank: {} / {}'.format(args.local_rank, total_rank)) 213 | dist.init_process_group(backend=args.dist_backend) 214 | torch.cuda.set_device(args.local_rank) 215 | 216 | ngpus_per_node = total_rank 217 | 218 | if args.local_rank == 0: 219 | if not os.path.exists(args.checkpoints_path): 220 | os.makedirs(args.checkpoints_path) 221 | 222 | # metric 223 | train_losses_metric = Metric_rank("train_losses") 224 | train_accuracy_metric = Metric_rank("train_accuracy") 225 | train_metric = {"losses": train_losses_metric, 226 | "accuracy": train_accuracy_metric} 227 | 228 | 229 | if args.model_name.lower() == "cmtti": 230 | model = CmtTi(num_classes=args.num_classes, 231 | input_resolution=(args.crop_size, args.crop_size), 232 | qkv_bias=True if args.qkv_bias else False, 233 | ape=True if args.ape else False, 234 | rpe=True if args.rpe else False, 235 | pe_nd=True if args.pe_nd else False, 236 | drop_path_rate = args.drop_path_rate 237 | ) 238 | elif args.model_name.lower() == "cmtxs": 239 | model = CmtXS(num_classes=args.num_classes, 240 | input_resolution=(args.crop_size, args.crop_size), 241 | qkv_bias=True if args.qkv_bias else False, 242 | ape=True if args.ape else False, 243 | rpe=True if args.rpe else False, 244 | pe_nd=True if args.pe_nd else False, 245 | drop_path_rate = args.drop_path_rate 246 | ) 247 | 248 | elif 
args.model_name.lower() == "cmts": 249 | model = CmtS(num_classes=args.num_classes, 250 | input_resolution=(args.crop_size, args.crop_size), 251 | qkv_bias=True if args.qkv_bias else False, 252 | ape=True if args.ape else False, 253 | rpe=True if args.rpe else False, 254 | pe_nd=True if args.pe_nd else False, 255 | drop_path_rate = args.drop_path_rate 256 | ) 257 | 258 | elif args.model_name.lower() == "cmtb": 259 | model = CmtB(num_classes=args.num_classes, 260 | input_resolution=(args.crop_size, args.crop_size), 261 | qkv_bias=True if args.qkv_bias else False, 262 | ape=True if args.ape else False, 263 | rpe=True if args.rpe else False, 264 | pe_nd=True if args.pe_nd else False, 265 | drop_path_rate = args.drop_path_rate 266 | ) 267 | 268 | else: 269 | raise NotImplementedError(f"{args.model_name} have not been use!!") 270 | 271 | if args.local_rank == 0: 272 | print(f"===============model arch ===============") 273 | print(model) 274 | 275 | # model mode 276 | model.train() 277 | 278 | if args.actnn: 279 | model = actnn.QModule(model) 280 | if args.local_rank == 0: 281 | print(model) 282 | 283 | if args.apex: 284 | model = convert_syncbn_model(model) 285 | 286 | # FP16 287 | if args.FP16: 288 | model = model.half() 289 | for bn in get_bn_modules(model): 290 | bn.float() 291 | 292 | if torch.cuda.is_available(): 293 | model.cuda(args.local_rank) 294 | 295 | # loss 296 | if args.mixup > 0.: 297 | # smoothing is handled with mixup label transform 298 | criterion = SoftTargetCrossEntropy() 299 | elif args.smoothing: 300 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 301 | else: 302 | criterion = nn.CrossEntropyLoss() 303 | 304 | # optimizer 305 | print("optimizer name: ", args.optimizer_name) 306 | if args.tf_optimizer: 307 | optimizer = build_optimizer( 308 | model, 309 | args.optimizer_name, 310 | lr=args.lr, 311 | weights_decay=args.weight_decay 312 | ) 313 | else: 314 | optimizer = Optimizer(args.optimizer_name)( 315 | param=model.parameters(), 316 | lr=args.lr, 317 | weight_decay=args.weight_decay 318 | ) 319 | # print(optimizer) 320 | 321 | if args.apex: 322 | model, optimizer = amp.initialize( 323 | model, optimizer, opt_level=args.mode) 324 | model = DDP(model, delay_allreduce=True) 325 | 326 | else: 327 | if args.distributed: 328 | model = DataParallel(model, 329 | device_ids=[args.local_rank], 330 | find_unused_parameters=True) 331 | 332 | # dataset & dataloader 333 | train_dataset = ImageDataset( 334 | image_file=args.train_file, 335 | train_phase=True, 336 | input_size=args.input_size, 337 | crop_size=args.crop_size, 338 | shuffle=True, 339 | interpolation="bilinear", 340 | auto_augment="rand", 341 | color_prob=0.4, 342 | hflip_prob=0.5 343 | ) 344 | 345 | validation_dataset = ImageDataset( 346 | image_file=args.val_file, 347 | train_phase=False, 348 | input_size=args.input_size, 349 | crop_size=args.crop_size, 350 | shuffle=False 351 | ) 352 | 353 | if args.local_rank == 0: 354 | print("Trainig dataset length: ", len(train_dataset)) 355 | print("Validation dataset length: ", len(validation_dataset)) 356 | 357 | # sampler 358 | if args.distributed: 359 | if args.repeated_aug: 360 | train_sampler = RASampler( 361 | train_dataset, 362 | num_replicas=dist.get_world_size(), 363 | rank=dist.get_rank(), 364 | shuffle=True 365 | ) 366 | print("use the repeated augment!!!") 367 | else: 368 | train_sampler = DistributedSampler(train_dataset) 369 | validation_sampler = DistributedSampler(validation_dataset) 370 | else: 371 | train_sampler = None 372 | 
validation_sampler = None 373 | 374 | # logs 375 | log_writer = SummaryWriter(args.log_dir) 376 | 377 | # dataloader 378 | train_loader = DataLoader( 379 | dataset=train_dataset, 380 | batch_size=args.batch_size, 381 | shuffle=(train_sampler is None), 382 | num_workers=args.num_workers, 383 | pin_memory=True, 384 | sampler=train_sampler, 385 | drop_last=True 386 | ) 387 | 388 | validation_loader = DataLoader( 389 | dataset=validation_dataset, 390 | batch_size=args.batch_size, 391 | shuffle=(validation_sampler is None), 392 | num_workers=args.num_workers, 393 | pin_memory=True, 394 | sampler=validation_sampler, 395 | drop_last=True 396 | ) 397 | 398 | # mixup & cutmix 399 | mixup_fn = None 400 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 401 | if mixup_active: 402 | mixup_fn = Mixup( 403 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 404 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 405 | label_smoothing=args.smoothing, num_classes=args.num_classes) 406 | print("use the mixup function ") 407 | 408 | 409 | start_epoch = 1 410 | batch_iter = 0 411 | train_batch = math.ceil(len(train_dataset) / 412 | (args.batch_size * ngpus_per_node)) 413 | total_batch = train_batch * args.max_epochs 414 | no_warmup_total_batch = int( 415 | args.max_epochs - args.warmup_epochs) * train_batch 416 | 417 | if args.amp: 418 | scaler = torch.cuda.amp.GradScaler() 419 | else: 420 | scaler = None 421 | 422 | best_loss, best_acc = np.inf, 0.0 423 | # training loop 424 | for epoch in range(start_epoch, args.max_epochs + 1): 425 | if args.distributed: 426 | train_sampler.set_epoch(epoch) 427 | # train for epoch 428 | batch_iter, scaler = train(args, scaler, train_loader, mixup_fn, model, criterion, optimizer, 429 | epoch, batch_iter, total_batch, train_batch, log_writer, train_metric) 430 | 431 | # calculate the validation with the batch iter 432 | if epoch % 2 == 0: 433 | val_loss, val_acc = val( 434 | args, validation_loader, model, criterion, epoch, log_writer) 435 | # recored & write 436 | if args.local_rank == 0: 437 | best_loss = val_loss 438 | state_dict = translate_state_dict(model.state_dict()) 439 | state_dict = { 440 | 'epoch': epoch, 441 | 'state_dict': state_dict, 442 | 'optimizer': optimizer.state_dict(), 443 | } 444 | torch.save( 445 | state_dict, 446 | args.checkpoints_path + '/' 'r50' + 447 | f'_losses_{best_loss}' + '.pth' 448 | ) 449 | 450 | best_acc = val_acc 451 | state_dict = translate_state_dict(model.state_dict()) 452 | state_dict = { 453 | 'epoch': epoch, 454 | 'state_dict': state_dict, 455 | 'optimizer': optimizer.state_dict(), 456 | } 457 | torch.save(state_dict, 458 | args.checkpoints_path + '/' + 'r50' + f'_accuracy_{best_acc}' + '.pth') 459 | # model mode 460 | model.train() 461 | 462 | 463 | # train function 464 | def train(args, 465 | scaler, 466 | train_loader, 467 | mixup_fn, 468 | model, 469 | criterion, 470 | optimizer, 471 | epoch, 472 | batch_iter, 473 | total_batch, 474 | train_batch, 475 | log_writer, 476 | train_metric, 477 | ): 478 | """Traing with the batch iter for get the metric 479 | """ 480 | model.train() 481 | # device = model.device 482 | loader_length = len(train_loader) 483 | 484 | for batch_idx, data in enumerate(train_loader): 485 | batch_start = time.time() 486 | if args.cosine: 487 | # cosine learning rate 488 | lr = cosine_learning_rate( 489 | args, epoch, batch_iter, optimizer, train_batch 490 | ) 491 | else: 492 | # step learning rate 493 | lr = 
step_learning_rate( 494 | args, epoch, batch_iter, optimizer, train_batch 495 | ) 496 | 497 | # forward 498 | batch_data, batch_label, data_path = data[0], data[1], data[2] 499 | 500 | if args.FP16: 501 | batch_data = batch_data.half() 502 | 503 | batch_data = batch_data.cuda() 504 | batch_label = batch_label.cuda() 505 | 506 | # print(batch_data.shape) 507 | # if torch.isnan(batch_data).float().sum() >= 1: 508 | # print(batch_data) 509 | # with open(f"/data/jiangmingchao/data/AICutDataset/transformers/CMT/data/nan_{args.local_rank}.txt", "w") as file: 510 | # for i in range(len(data_path)): 511 | # file.write(data_path[i] + '\t' + str(batch_label[i]) + '\n') 512 | # print("There are some error on the batch_data & nan!!!!") 513 | # break 514 | 515 | # if args.local_rank == 0: 516 | # print(batch_iter) 517 | 518 | # mixup or cutmix 519 | if mixup_fn is not None: 520 | batch_data, batch_label = mixup_fn(batch_data, batch_label) 521 | 522 | if args.amp: 523 | with autocast(): 524 | batch_output = model(batch_data) 525 | losses = criterion(batch_output, batch_label) 526 | 527 | else: 528 | batch_output = model(batch_data) 529 | losses = criterion(batch_output, batch_label) 530 | 531 | # translate the miuxp one hot to float 532 | if mixup_fn is not None: 533 | batch_label = batch_label.argmax(dim=1) 534 | 535 | optimizer.zero_grad() 536 | 537 | if args.apex: 538 | with amp.scale_loss(losses, optimizer) as scaled_loss: 539 | scaled_loss.backward() 540 | optimizer.step() 541 | 542 | elif args.amp: 543 | scaler.scale(losses).backward() 544 | if args.grad_clip: 545 | scaler.unscale_(optimizer) 546 | torch.nn.utils.clip_grad_norm_( 547 | model.parameters(), args.max_grad_norm, norm_type=2.0) 548 | scaler.step(optimizer) 549 | scaler.update() 550 | 551 | else: 552 | losses.backward() 553 | if args.grad_clip: 554 | torch.nn.utils.clip_grad_norm_( 555 | model.parameters(), args.max_grad_norm, norm_type=2.0) 556 | optimizer.step() 557 | 558 | # calculate the accuracy 559 | batch_acc, _ = accuracy(batch_output, batch_label) 560 | 561 | # record the average momentum result 562 | train_metric["losses"].update(losses.data.item()) 563 | train_metric["accuracy"].update(batch_acc[0]) 564 | 565 | batch_time = time.time() - batch_start 566 | 567 | batch_iter += 1 568 | 569 | if args.local_rank == 0: 570 | print("[Training] Time: {} Epoch: [{}/{}] batch_idx: [{}/{}] batch_iter: [{}/{}] batch_losses: {:.4f} batch_accuracy: {:.4f} LearningRate: {:.6f} BatchTime: {:.4f}".format( 571 | datetime.now().strftime("%Y-%m-%d %H:%M:%S"), 572 | epoch, 573 | args.max_epochs, 574 | batch_idx, 575 | train_batch, 576 | batch_iter, 577 | total_batch, 578 | losses.data.item(), 579 | batch_acc[0], 580 | lr, 581 | batch_time 582 | )) 583 | 584 | if args.local_rank == 0: 585 | # batch record 586 | record_log(log_writer, losses, 587 | batch_acc[0], lr, batch_iter, batch_time) 588 | 589 | if args.local_rank == 0: 590 | # epoch record 591 | record_scalars(log_writer, train_metric["losses"].average, 592 | train_metric["accuracy"].average, epoch, flag="train") 593 | 594 | return batch_iter, scaler 595 | 596 | 597 | def val( 598 | args, 599 | val_loader, 600 | model, 601 | criterion, 602 | epoch, 603 | log_writer, 604 | ): 605 | """Validation and get the metric 606 | """ 607 | model.eval() 608 | # device = model.device 609 | criterion = nn.CrossEntropyLoss() 610 | epoch_losses, epoch_accuracy = 0.0, 0.0 611 | 612 | batch_acc_list = [] 613 | batch_loss_list = [] 614 | 615 | with torch.no_grad(): 616 | for batch_idx, data in 
enumerate(val_loader): 617 | batch_data, batch_label, _ = data[0], data[1], data[2] 618 | 619 | if args.FP16: 620 | batch_data = batch_data.half() 621 | 622 | batch_data = batch_data.cuda() 623 | batch_label = batch_label.cuda() 624 | 625 | if args.amp: 626 | with autocast(): 627 | batch_output = model(batch_data) 628 | batch_losses = criterion(batch_output, batch_label) 629 | else: 630 | batch_output = model(batch_data) 631 | batch_losses = criterion(batch_output, batch_label) 632 | 633 | batch_accuracy, _ = accuracy(batch_output, batch_label) 634 | 635 | batch_acc_list.append(batch_accuracy[0]) 636 | batch_loss_list.append(batch_losses.data.item()) 637 | 638 | epoch_acc = np.mean(batch_acc_list) 639 | epoch_loss = np.mean(batch_loss_list) 640 | 641 | # all reduce the correct number 642 | # dist.all_reduce(epoch_accuracy, op=dist.ReduceOp.SUM) 643 | 644 | if args.local_rank == 0: 645 | print( 646 | f"Validation Epoch: [{epoch}/{args.max_epochs}] Epoch_mean_losses: {epoch_loss} Epoch_mean_accuracy: {epoch_acc}") 647 | 648 | record_scalars(log_writer, epoch_loss, epoch_acc, epoch, flag="val") 649 | 650 | return epoch_loss, epoch_acc 651 | 652 | 653 | def record_scalars(log_writer, mean_loss, mean_acc, epoch, flag="train"): 654 | log_writer.add_scalar(f"{flag}/epoch_average_loss", mean_loss, epoch) 655 | log_writer.add_scalar(f"{flag}/epoch_average_acc", mean_acc, epoch) 656 | 657 | 658 | # batch scalar record 659 | def record_log(log_writer, losses, acc, lr, batch_iter, batch_time, flag="Train"): 660 | log_writer.add_scalar(f"{flag}/batch_loss", losses.data.item(), batch_iter) 661 | log_writer.add_scalar(f"{flag}/batch_acc", acc, batch_iter) 662 | log_writer.add_scalar(f"{flag}/learning_rate", lr, batch_iter) 663 | log_writer.add_scalar(f"{flag}/batch_time", batch_time, batch_iter) 664 | 665 | 666 | def step_learning_rate(args, epoch, batch_iter, optimizer, train_batch): 667 | """Sets the learning rate 668 | # Adapted from PyTorch Imagenet example: 669 | # https://github.com/pytorch/examples/blob/master/imagenet/main.py 670 | """ 671 | total_epochs = args.max_epochs 672 | warm_epochs = args.warmup_epochs 673 | if epoch <= warm_epochs: 674 | lr_adj = (batch_iter + 1) / (warm_epochs * train_batch) 675 | elif epoch < int(0.3 * total_epochs): 676 | lr_adj = 1. 
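    # After warmup (linear ramp above) the base LR is kept for the first 30% of
    # max_epochs, then scaled by 0.1, 0.01 and 0.001 at 30%, 60% and 80% respectively.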
677 | elif epoch < int(0.6 * total_epochs): 678 | lr_adj = 1e-1 679 | elif epoch < int(0.8 * total_epochs): 680 | lr_adj = 1e-2 681 | else: 682 | lr_adj = 1e-3 683 | 684 | for param_group in optimizer.param_groups: 685 | param_group['lr'] = args.lr * lr_adj 686 | return args.lr * lr_adj 687 | 688 | 689 | def cosine_learning_rate(args, epoch, batch_iter, optimizer, train_batch): 690 | """Cosine Learning rate 691 | """ 692 | total_epochs = args.max_epochs 693 | warm_epochs = args.warmup_epochs 694 | if epoch <= warm_epochs: 695 | lr_adj = (batch_iter + 1) / (warm_epochs * train_batch) + 1e-6 696 | else: 697 | lr_adj = 1/2 * (1 + math.cos(batch_iter * math.pi / 698 | ((total_epochs - warm_epochs) * train_batch))) 699 | 700 | for param_group in optimizer.param_groups: 701 | param_group['lr'] = args.lr * lr_adj 702 | return args.lr * lr_adj 703 | 704 | 705 | if __name__ == "__main__": 706 | args = parser.parse_args() 707 | setup_seed() 708 | 709 | main_worker(args) 710 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | OMP_NUM_THREADS=1 3 | MKL_NUM_THREADS=1 4 | export OMP_NUM_THREADS 5 | export MKL_NUM_THREADS 6 | cd CMT-PYTORCH; 7 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore -m torch.distributed.launch --nproc_per_node 8 train_lanuch.py \ 8 | --batch_size 512 \ 9 | --num_workers 48 \ 10 | --lr 6e-3 \ 11 | --optimizer_name "adamw" \ 12 | --tf_optimizer 1 \ 13 | --cosine 1 \ 14 | --model_name cmtti \ 15 | --max_epochs 300 \ 16 | --warmup_epochs 5 \ 17 | --num-classes 1000 \ 18 | --input_size 184 \ 19 | --crop_size 160 \ 20 | --weight_decay 1e-1 \ 21 | --grad_clip 0 \ 22 | --repeated-aug 0 \ 23 | --max_grad_norm 5.0 \ 24 | --drop_path_rate 0.1 \ 25 | --FP16 0 \ 26 | --qkv_bias 1 \ 27 | --ape 0 \ 28 | --rpe 1 \ 29 | --pe_nd 0 \ 30 | --mode O2 \ 31 | --amp 1 \ 32 | --apex 0 \ 33 | --train_file /data/jiangmingchao/data/dataset/imagenet/train_oss_imagenet_128w.txt \ 34 | --val_file /data/jiangmingchao/data/dataset/imagenet/val_oss_imagenet_128w.txt \ 35 | --log-dir /data/jiangmingchao/data/AICutDataset/transformers/CMT/cmt_tiny_160x160_300epoch_mixup_cutmix_adamw_all/log_dir \ 36 | --checkpoints-path /data/jiangmingchao/data/AICutDataset/transformers/CMT/cmt_tiny_160x160_300epoch_mixup_cutmix_adamw_all/checkpoints 37 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FlyEgle/CMT-pytorch/1ce3c8c9732b09634fffe973b43193ada09c0844/utils/__init__.py -------------------------------------------------------------------------------- /utils/augments.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding:utf-8 -*- 3 | cutmix original code: https://github.com/clovaai/CutMix-PyTorch 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import random 10 | import numpy as np 11 | 12 | 13 | def rand_bbox(size, lam): 14 | W = size[2] 15 | H = size[3] 16 | cut_rat = np.sqrt(1. 
- lam)
17 |     cut_w = int(W * cut_rat)
18 |     cut_h = int(H * cut_rat)
19 | 
20 |     # uniform
21 |     cx = np.random.randint(W)
22 |     cy = np.random.randint(H)
23 | 
24 |     bbx1 = np.clip(cx - cut_w // 2, 0, W)
25 |     bby1 = np.clip(cy - cut_h // 2, 0, H)
26 |     bbx2 = np.clip(cx + cut_w // 2, 0, W)
27 |     bby2 = np.clip(cy + cut_h // 2, 0, H)
28 | 
29 |     return bbx1, bby1, bbx2, bby2
30 | 
31 | 
32 | def cutmix_data(x, y, alpha=1., use_cuda=True):
33 |     if alpha > 0.:
34 |         lam = np.random.beta(alpha, alpha)
35 |     else:
36 |         lam = 1.
37 | 
38 |     batch_size = x.size()[0]
39 |     if use_cuda:
40 |         index = torch.randperm(batch_size).cuda()
41 |     else:
42 |         index = torch.randperm(batch_size)
43 | 
44 |     size = x.size()
45 |     bbx1, bby1, bbx2, bby2 = rand_bbox(size, lam)
46 |     x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
47 |     # adjust lambda to exactly match pixel ratio
48 |     lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
49 |     y_a, y_b = y, y[index]
50 |     return x, y_a, y_b, lam
51 | 
52 | 
53 | def mixup_data(x, y, alpha=1.0, use_cuda=True):
54 |     '''Returns mixed inputs, pairs of targets, and lambda'''
55 |     if alpha > 0:
56 |         lam = np.random.beta(alpha, alpha)
57 |     else:
58 |         lam = 1
59 | 
60 |     batch_size = x.size()[0]
61 |     if use_cuda:
62 |         index = torch.randperm(batch_size).cuda()
63 |     else:
64 |         index = torch.randperm(batch_size)
65 | 
66 |     mixed_x = lam * x + (1 - lam) * x[index, :]
67 |     y_a, y_b = y, y[index]
68 |     return mixed_x, y_a, y_b, lam
69 | 
70 | 
71 | class LabelSmoothingCrossEntropy(nn.Module):
72 |     """
73 |     NLL loss with label smoothing.
74 |     """
75 |     def __init__(self, smoothing=0.1):
76 |         """
77 |         Constructor for the LabelSmoothing module.
78 |         :param smoothing: label smoothing factor
79 |         """
80 |         super(LabelSmoothingCrossEntropy, self).__init__()
81 |         assert smoothing < 1.0
82 |         self.smoothing = smoothing
83 |         self.confidence = 1. 
- smoothing 84 | 85 | def forward(self, x, target): 86 | logprobs = F.log_softmax(x, dim=-1) 87 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 88 | nll_loss = nll_loss.squeeze(1) 89 | smooth_loss = -logprobs.mean(dim=-1) 90 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 91 | return loss.mean() 92 | -------------------------------------------------------------------------------- /utils/calculate_acc.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | """ 3 | Calculate the video accuracy 4 | """ 5 | import os 6 | import json 7 | import numpy as np 8 | from scipy.special import softmax 9 | from tqdm import tqdm 10 | import argparse 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--logits_file', type=str, default="/data/jiangmingchao/data/AICutDataset/imagenet/r50_acc_result/") 15 | 16 | def parse_file(data_file): 17 | lines = open(data_file).readlines() 18 | lines_json = [json.loads(x.strip()) for x in lines] 19 | total_length = len(lines_json) 20 | correct = 0 21 | 22 | for line in lines_json: 23 | pred = np.argmax(softmax(np.array(line["pred_logits"]))) 24 | label = line["real_label"] 25 | if pred == label: 26 | correct += 1 27 | return correct, total_length 28 | 29 | # get the max value index 30 | def argmax(data_list: list, num: int): 31 | data_dict = {x:data_list[x] for x in range(len(data_list))} 32 | sorted_data_dict = sorted(data_dict.items(), key=lambda k: (k[1], k[0]), reverse=True) 33 | argmax_data_dict = sorted_data_dict[:num] 34 | argmax_index = [x[0] for x in argmax_data_dict] 35 | return argmax_index 36 | 37 | def acc_top_n(data_file, n=5): 38 | lines = open(data_file).readlines() 39 | lines_json = [json.loads(x.strip()) for x in lines] 40 | total_length = len(lines_json) 41 | correct = 0 42 | 43 | for line in tqdm(lines_json): 44 | pred_list = softmax(np.array(line["pred_logits"])).tolist() 45 | # print(pred_list) 46 | arg_index = argmax(pred_list, n) 47 | # print(arg_index) 48 | label = line["real_label"] 49 | if label in arg_index: 50 | correct += 1 51 | return correct, total_length 52 | 53 | # logits_file = "/data/jiangmingchao/data/AICutDataset/logits2" 54 | 55 | args = parser.parse_args() 56 | 57 | logits_file = args.logits_file 58 | 59 | total_correct, total_num = 0, 0 60 | if os.path.isdir(logits_file): 61 | for file in os.listdir(logits_file): 62 | file_path = os.path.join(logits_file, file) 63 | # correct, num = parse_file(file_path) 64 | correct, num = acc_top_n(file_path, n=1) 65 | total_correct += correct 66 | total_num += num 67 | 68 | print(f"Accuracy is {total_correct/total_num}") 69 | 70 | elif os.path.isfile(logits_file): 71 | correct, num = parse_file(logits_file) 72 | total_correct += correct 73 | total_num += num 74 | print(f"Accuracy is {total_correct/total_num}") 75 | -------------------------------------------------------------------------------- /utils/optimizer_step.py: -------------------------------------------------------------------------------- 1 | from torch import optim as optim 2 | from torch.optim import SGD, Adam, AdamW 3 | 4 | 5 | class Optimizer(object): 6 | def __init__(self, name) -> None: 7 | super().__init__() 8 | self._name = name 9 | 10 | def __call__(self, param, lr, weight_decay): 11 | if self._name.lower() == "sgd": 12 | optimizer = SGD( 13 | param, 14 | lr, 15 | weight_decay=weight_decay, 16 | momentum=0.9 17 | ) 18 | elif self._name.lower() == "adam": 19 | optimizer = Adam( 20 | param, 21 | lr, 22 | 
weight_decay=weight_decay 23 | ) 24 | elif self._name.lower() == "adamw": 25 | optimizer = AdamW( 26 | param, 27 | lr, 28 | weight_decay=weight_decay 29 | ) 30 | else: 31 | raise NotImplementedError(f"{self._name} optimizer have not been implement!") 32 | 33 | return optimizer 34 | 35 | 36 | def build_optimizer(model, opt_name, lr, weights_decay): 37 | """Build optimizer, set weight decay of normalization to 0 by default 38 | """ 39 | skip = {} 40 | skip_keywords = {} 41 | if hasattr(model, 'no_weight_decay'): 42 | skip = model.no_weight_decay() 43 | if hasattr(model, 'no_weight_decay_keywords'): 44 | skip_keywords = model.no_weight_decay_keywords() 45 | parameters = set_weight_decay(model, skip, skip_keywords) 46 | 47 | optimizer = None 48 | if opt_name.lower() == 'sgd': 49 | optimizer = optim.SGD(parameters, 50 | momentum=0.9, 51 | nesterov=True, 52 | lr=lr, 53 | weight_decay=weights_decay) 54 | elif opt_name.lower() == 'adamw': 55 | optimizer = optim.AdamW(parameters, 56 | lr=lr, 57 | weight_decay=weights_decay) 58 | 59 | return optimizer 60 | 61 | 62 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 63 | has_decay = [] 64 | no_decay = [] 65 | 66 | for name, param in model.named_parameters(): 67 | if not param.requires_grad: 68 | continue # frozen weights 69 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 70 | check_keywords_in_name(name, skip_keywords): 71 | no_decay.append(param) 72 | # print(f"{name} has no weight decay") 73 | else: 74 | has_decay.append(param) 75 | return [{'params': has_decay}, 76 | {'params': no_decay, 'weight_decay': 0.}] 77 | 78 | 79 | def check_keywords_in_name(name, keywords=()): 80 | isin = False 81 | for keyword in keywords: 82 | if keyword in name: 83 | isin = True 84 | return isin -------------------------------------------------------------------------------- /utils/precise_bn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | ## https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/precise_bn.py 5 | 6 | import itertools 7 | import torch 8 | import torch.nn as nn 9 | import logging 10 | from typing import Iterable, Any 11 | from torch.distributed import ReduceOp, all_reduce 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | BN_MODULE_TYPES = ( 16 | torch.nn.BatchNorm1d, 17 | torch.nn.BatchNorm2d, 18 | torch.nn.BatchNorm3d, 19 | torch.nn.SyncBatchNorm, 20 | ) 21 | 22 | 23 | # pyre-fixme[56]: Decorator `torch.no_grad(...)` could not be called, because its 24 | # type `no_grad` is not callable. 25 | @torch.no_grad() 26 | def update_bn_stats( 27 | args: Any, model: nn.Module, data_loader: Iterable[Any], num_iters: int = 200 # pyre-ignore 28 | ) -> None: 29 | """ 30 | Recompute and update the batch norm stats to make them more precise. During 31 | training both BN stats and the weight are changing after every iteration, so 32 | the running average can not precisely reflect the actual stats of the 33 | current model. 34 | In this function, the BN stats are recomputed with fixed weights, to make 35 | the running average more precise. Specifically, it computes the true average 36 | of per-batch mean/variance instead of the running average. 37 | Args: 38 | model (nn.Module): the model whose bn stats will be recomputed. 39 | Note that: 40 | 1. This function will not alter the training mode of the given model. 
41 | Users are responsible for setting the layers that needs 42 | precise-BN to training mode, prior to calling this function. 43 | 2. Be careful if your models contain other stateful layers in 44 | addition to BN, i.e. layers whose state can change in forward 45 | iterations. This function will alter their state. If you wish 46 | them unchanged, you need to either pass in a submodule without 47 | those layers, or backup the states. 48 | data_loader (iterator): an iterator. Produce data as inputs to the model. 49 | num_iters (int): number of iterations to compute the stats. 50 | """ 51 | bn_layers = get_bn_modules(model) 52 | 53 | if len(bn_layers) == 0: 54 | return 55 | 56 | # In order to make the running stats only reflect the current batch, the 57 | # momentum is disabled. 58 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 59 | # Setting the momentum to 1.0 to compute the stats without momentum. 60 | momentum_actual = [bn.momentum for bn in bn_layers] 61 | if args.rank == 0: 62 | a = [round(i.running_mean.cpu().numpy().max(), 4) for i in bn_layers] 63 | logger.info('bn mean max, %s', max(a)) 64 | logger.info(a) 65 | a = [round(i.running_var.cpu().numpy().max(), 4) for i in bn_layers] 66 | logger.info('bn var max, %s', max(a)) 67 | logger.info(a) 68 | for bn in bn_layers: 69 | # pyre-fixme[16]: `Module` has no attribute `momentum`. 70 | # bn.running_mean = torch.ones_like(bn.running_mean) 71 | # bn.running_var = torch.zeros_like(bn.running_var) 72 | bn.momentum = 1.0 73 | 74 | # Note that PyTorch's running_var means "running average of 75 | # bessel-corrected batch variance". (PyTorch's BN normalizes by biased 76 | # variance, but updates EMA by unbiased (bessel-corrected) variance). 77 | # So we estimate population variance by "simple average of bessel-corrected 78 | # batch variance". This is the same as in the BatchNorm paper, Sec 3.1. 79 | # This estimator converges to population variance as long as batch size 80 | # is not too small, and total #samples for PreciseBN is large enough. 81 | # Its convergence may be affected by small batch size. 82 | 83 | # Alternatively, one can estimate population variance by the sample variance 84 | # of all batches combined. However, this needs a way to know the batch size 85 | # of each batch in this function (otherwise we only have access to the 86 | # bessel-corrected batch variance given by pytorch), which is an extra 87 | # requirement. 88 | running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] 89 | running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers] 90 | 91 | ind = -1 92 | for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)): 93 | with torch.no_grad(): 94 | model(inputs) 95 | 96 | for i, bn in enumerate(bn_layers): 97 | # Accumulates the bn stats. 
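            # With bn.momentum set to 1.0 above, each layer's running_mean/running_var
            # hold exactly the current batch statistics, so the update below keeps
            # running_mean[i] equal to the simple average of the first (ind + 1) batch
            # means:  m_k = m_{k-1} + (x_k - m_{k-1}) / k  (same for the variances).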
98 | running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) 99 | running_var[i] += (bn.running_var - running_var[i]) / (ind + 1) 100 | if torch.sum(torch.isnan(bn.running_mean)) > 0 or torch.sum(torch.isnan(bn.running_var)) > 0: 101 | raise RuntimeError( 102 | "update_bn_stats ERROR(args.rank {}): Got NaN val".format(args.rank)) 103 | if torch.sum(torch.isinf(bn.running_mean)) > 0 or torch.sum(torch.isinf(bn.running_var)) > 0: 104 | raise RuntimeError( 105 | "update_bn_stats ERROR(args.rank {}): Got INf val".format(args.rank)) 106 | if torch.sum(~torch.isfinite(bn.running_mean)) > 0 or torch.sum(~torch.isfinite(bn.running_var)) > 0: 107 | raise RuntimeError( 108 | "update_bn_stats ERROR(args.rank {}): Got INf val".format(args.rank)) 109 | 110 | assert ind == num_iters - 1, ( 111 | "update_bn_stats is meant to run for {} iterations, " 112 | "but the dataloader stops at {} iterations.".format(num_iters, ind) 113 | ) 114 | 115 | for i, bn in enumerate(bn_layers): 116 | if args.distributed: 117 | all_reduce(running_mean[i], op=ReduceOp.SUM) 118 | all_reduce(running_var[i], op=ReduceOp.SUM) 119 | running_mean[i] = running_mean[i] / args.gpu_nums 120 | running_var[i] = running_var[i] / args.gpu_nums 121 | 122 | # Sets the precise bn stats. 123 | # pyre-fixme[16]: `Module` has no attribute `running_mean`. 124 | bn.running_mean = running_mean[i] 125 | # pyre-fixme[16]: `Module` has no attribute `running_var`. 126 | bn.running_var = running_var[i] 127 | bn.momentum = momentum_actual[i] 128 | 129 | if args.rank == 0: 130 | a = [round(i.cpu().numpy().max(), 4) for i in running_mean] 131 | logger.info('bn mean max, %s (%s)', max(a), a) 132 | a = [round(i.cpu().numpy().max(), 4) for i in running_var] 133 | logger.info('bn var max, %s (%s)', max(a), a) 134 | 135 | 136 | def get_bn_modules(model): 137 | """ 138 | Find all BatchNorm (BN) modules that are in training mode. See 139 | fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are 140 | included in this search. 141 | Args: 142 | model (nn.Module): a model possibly containing BN modules. 143 | Returns: 144 | list[nn.Module]: all BN modules in the model. 145 | """ 146 | # Finds all the bn layers. 147 | bn_layers = [ 148 | m 149 | for m in model.modules() 150 | if m.training and isinstance(m, BN_MODULE_TYPES) 151 | ] 152 | return bn_layers 153 | --------------------------------------------------------------------------------
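For reference, a minimal usage sketch (not part of the repository) of how `update_bn_stats` and `get_bn_modules` from `utils/precise_bn.py` could be invoked once training has finished; the helper name `recompute_bn_stats` is hypothetical, and the `model`, `train_loader`, and the `args` fields the function reads (`rank`, `distributed`, `gpu_nums`) are assumed to be set up as in `train.py`.

```python
# Hypothetical helper showing one way to refresh BN statistics with precise_bn;
# `args`, `model` (already on GPU) and `train_loader` are assumed to come from train.py.
from utils.precise_bn import update_bn_stats, get_bn_modules


def recompute_bn_stats(args, model, train_loader, num_iters=200):
    # get_bn_modules only returns BN layers that are in training mode.
    model.train()
    if len(get_bn_modules(model)) == 0:
        return model
    # The loader yields (image, label, path) tuples; update_bn_stats forwards whatever
    # the iterator produces, so feed it image tensors only. num_iters must not exceed
    # the number of batches, otherwise the assert inside update_bn_stats fires.
    images_only = (batch[0].cuda(non_blocking=True) for batch in train_loader)
    update_bn_stats(args, model, images_only, num_iters=num_iters)
    model.eval()
    return model
```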