├── README.md ├── SSA.py ├── configs └── Shunted │ ├── shunted_B.py │ ├── shunted_S.py │ └── shunted_T.py ├── datasets.py ├── dist_train.sh ├── engine.py ├── hubconf.py ├── losses.py ├── main.py ├── mcloader ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── classification.cpython-38.pyc │ ├── data_prefetcher.cpython-38.pyc │ ├── image_list.cpython-38.pyc │ ├── imagenet.cpython-38.pyc │ └── mcloader.cpython-38.pyc ├── classification.py ├── data_prefetcher.py ├── image_list.py ├── imagenet.py └── mcloader.py ├── requirements.txt ├── samplers.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Shunted Transformer 2 | 3 | This is the official implementation of [Shunted Self-Attention via Multi-Scale Token Aggregation](https://arxiv.org/abs/2111.15193) 4 | by Sucheng Ren, Daquan Zhou, Shengfeng He, Jiashi Feng, Xinchao Wang 5 | ### Training from scratch 6 | 7 | ## Training 8 | ```shell 9 | bash dist_train.sh 10 | ``` 11 | 12 | ## Model Zoo 13 | The checkpoints can be found at [Google Drive](https://drive.google.com/drive/folders/15iZKXFT7apjUSoN2WUMAbb0tvJgyh3YP?usp=sharing), [Baidu Pan](https://pan.baidu.com/s/1a9nVWpw2SzP0csCBCF8DNw) (code:hazr) (Checkpoints of the large models are coming soon.) 
14 | 15 | | Method | Size | Acc@1 | #Params (M) | 16 | |------------------|:----:|:-----:|:-----------:| 17 | | Shunted-T | 224 | 79.8 | 11.5 | 18 | | Shunted-S | 224 | 82.9 | 22.4 | 19 | | Shunted-B | 224 | 84.0 | 39.6 | 20 | 21 | 22 | ## Citation 23 | ```shell 24 | @misc{ren2021shunted, 25 | title={Shunted Self-Attention via Multi-Scale Token Aggregation}, 26 | author={Sucheng Ren and Daquan Zhou and Shengfeng He and Jiashi Feng and Xinchao Wang}, 27 | year={2021}, 28 | eprint={2111.15193}, 29 | archivePrefix={arXiv}, 30 | primaryClass={cs.CV} 31 | } 32 | ``` -------------------------------------------------------------------------------- /SSA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from functools import partial 5 | 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | from timm.models.registry import register_model 8 | from timm.models.vision_transformer import _cfg 9 | import math 10 | 11 | 12 | class Mlp(nn.Module): 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.dwconv = DWConv(hidden_features) 19 | self.act = act_layer() 20 | self.fc2 = nn.Linear(hidden_features, out_features) 21 | self.drop = nn.Dropout(drop) 22 | self.apply(self._init_weights) 23 | 24 | def _init_weights(self, m): 25 | if isinstance(m, nn.Linear): 26 | trunc_normal_(m.weight, std=.02) 27 | if isinstance(m, nn.Linear) and m.bias is not None: 28 | nn.init.constant_(m.bias, 0) 29 | elif isinstance(m, nn.LayerNorm): 30 | nn.init.constant_(m.bias, 0) 31 | nn.init.constant_(m.weight, 1.0) 32 | elif isinstance(m, nn.Conv2d): 33 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 34 | fan_out 
//= m.groups 35 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 36 | if m.bias is not None: 37 | m.bias.data.zero_() 38 | 39 | def forward(self, x, H, W): 40 | x = self.fc1(x) 41 | x = self.act(x + self.dwconv(x, H, W)) 42 | x = self.drop(x) 43 | x = self.fc2(x) 44 | x = self.drop(x) 45 | return x 46 | 47 | 48 | class Attention(nn.Module): 49 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 50 | super().__init__() 51 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 52 | 53 | self.dim = dim 54 | self.num_heads = num_heads 55 | head_dim = dim // num_heads 56 | self.scale = qk_scale or head_dim ** -0.5 57 | 58 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 59 | 60 | self.attn_drop = nn.Dropout(attn_drop) 61 | self.proj = nn.Linear(dim, dim) 62 | self.proj_drop = nn.Dropout(proj_drop) 63 | 64 | 65 | self.sr_ratio = sr_ratio 66 | if sr_ratio > 1: 67 | self.act = nn.GELU() 68 | if sr_ratio==8: 69 | self.sr1 = nn.Conv2d(dim, dim, kernel_size=8, stride=8) 70 | self.norm1 = nn.LayerNorm(dim) 71 | self.sr2 = nn.Conv2d(dim, dim, kernel_size=4, stride=4) 72 | self.norm2 = nn.LayerNorm(dim) 73 | if sr_ratio==4: 74 | self.sr1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4) 75 | self.norm1 = nn.LayerNorm(dim) 76 | self.sr2 = nn.Conv2d(dim, dim, kernel_size=2, stride=2) 77 | self.norm2 = nn.LayerNorm(dim) 78 | if sr_ratio==2: 79 | self.sr1 = nn.Conv2d(dim, dim, kernel_size=2, stride=2) 80 | self.norm1 = nn.LayerNorm(dim) 81 | self.sr2 = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 82 | self.norm2 = nn.LayerNorm(dim) 83 | self.kv1 = nn.Linear(dim, dim, bias=qkv_bias) 84 | self.kv2 = nn.Linear(dim, dim, bias=qkv_bias) 85 | self.local_conv1 = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1, stride=1, groups=dim//2) 86 | self.local_conv2 = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1, stride=1, groups=dim//2) 87 | else: 88 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 
89 | self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim) 90 | self.apply(self._init_weights) 91 | 92 | def _init_weights(self, m): 93 | if isinstance(m, nn.Linear): 94 | trunc_normal_(m.weight, std=.02) 95 | if isinstance(m, nn.Linear) and m.bias is not None: 96 | nn.init.constant_(m.bias, 0) 97 | elif isinstance(m, nn.LayerNorm): 98 | nn.init.constant_(m.bias, 0) 99 | nn.init.constant_(m.weight, 1.0) 100 | elif isinstance(m, nn.Conv2d): 101 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | fan_out //= m.groups 103 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 104 | if m.bias is not None: 105 | m.bias.data.zero_() 106 | 107 | def forward(self, x, H, W): 108 | B, N, C = x.shape 109 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 110 | if self.sr_ratio > 1: 111 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 112 | x_1 = self.act(self.norm1(self.sr1(x_).reshape(B, C, -1).permute(0, 2, 1))) 113 | x_2 = self.act(self.norm2(self.sr2(x_).reshape(B, C, -1).permute(0, 2, 1))) 114 | kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4) 116 | k1, v1 = kv1[0], kv1[1] #B head N C 117 | k2, v2 = kv2[0], kv2[1] 118 | attn1 = (q[:, :self.num_heads//2] @ k1.transpose(-2, -1)) * self.scale 119 | attn1 = attn1.softmax(dim=-1) 120 | attn1 = self.attn_drop(attn1) 121 | v1 = v1 + self.local_conv1(v1.transpose(1, 2).reshape(B, -1, C//2). 
122 | transpose(1, 2).view(B,C//2, H//self.sr_ratio, W//self.sr_ratio)).\ 123 | view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2) 124 | x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2) 125 | attn2 = (q[:, self.num_heads // 2:] @ k2.transpose(-2, -1)) * self.scale 126 | attn2 = attn2.softmax(dim=-1) 127 | attn2 = self.attn_drop(attn2) 128 | v2 = v2 + self.local_conv2(v2.transpose(1, 2).reshape(B, -1, C//2). 129 | transpose(1, 2).view(B, C//2, H*2//self.sr_ratio, W*2//self.sr_ratio)).\ 130 | view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2) 131 | x2 = (attn2 @ v2).transpose(1, 2).reshape(B, N, C//2) 132 | 133 | x = torch.cat([x1,x2], dim=-1) 134 | else: 135 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 136 | k, v = kv[0], kv[1] 137 | 138 | attn = (q @ k.transpose(-2, -1)) * self.scale 139 | attn = attn.softmax(dim=-1) 140 | attn = self.attn_drop(attn) 141 | 142 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) + self.local_conv(v.transpose(1, 2).reshape(B, N, C). 143 | transpose(1, 2).view(B,C, H, W)).view(B, C, N).transpose(1, 2) 144 | x = self.proj(x) 145 | x = self.proj_drop(x) 146 | 147 | return x 148 | 149 | 150 | class Block(nn.Module): 151 | 152 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 153 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 154 | super().__init__() 155 | self.norm1 = norm_layer(dim) 156 | self.attn = Attention( 157 | dim, 158 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 159 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 160 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 161 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 162 | self.norm2 = norm_layer(dim) 163 | mlp_hidden_dim = int(dim * mlp_ratio) 164 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 165 | 166 | self.apply(self._init_weights) 167 | 168 | def _init_weights(self, m): 169 | if isinstance(m, nn.Linear): 170 | trunc_normal_(m.weight, std=.02) 171 | if isinstance(m, nn.Linear) and m.bias is not None: 172 | nn.init.constant_(m.bias, 0) 173 | elif isinstance(m, nn.LayerNorm): 174 | nn.init.constant_(m.bias, 0) 175 | nn.init.constant_(m.weight, 1.0) 176 | elif isinstance(m, nn.Conv2d): 177 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 178 | fan_out //= m.groups 179 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 180 | if m.bias is not None: 181 | m.bias.data.zero_() 182 | 183 | def forward(self, x, H, W): 184 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 185 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 186 | 187 | return x 188 | 189 | 190 | class OverlapPatchEmbed(nn.Module): 191 | """ Image to Patch Embedding 192 | """ 193 | 194 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 195 | super().__init__() 196 | img_size = to_2tuple(img_size) 197 | patch_size = to_2tuple(patch_size) 198 | 199 | self.img_size = img_size 200 | self.patch_size = patch_size 201 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 202 | self.num_patches = self.H * self.W 203 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 204 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 205 | self.norm = nn.LayerNorm(embed_dim) 206 | 207 | self.apply(self._init_weights) 208 | 209 | def _init_weights(self, m): 210 | if isinstance(m, nn.Linear): 211 | trunc_normal_(m.weight, std=.02) 212 | if isinstance(m, nn.Linear) and m.bias is not None: 213 | nn.init.constant_(m.bias, 0) 214 | elif isinstance(m, nn.LayerNorm): 215 | nn.init.constant_(m.bias, 0) 216 
| nn.init.constant_(m.weight, 1.0) 217 | elif isinstance(m, nn.Conv2d): 218 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 219 | fan_out //= m.groups 220 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 221 | if m.bias is not None: 222 | m.bias.data.zero_() 223 | 224 | def forward(self, x): 225 | x = self.proj(x) 226 | _, _, H, W = x.shape 227 | x = x.flatten(2).transpose(1, 2) 228 | x = self.norm(x) 229 | 230 | return x, H, W 231 | 232 | class Head(nn.Module): 233 | def __init__(self, num): 234 | super(Head, self).__init__() 235 | stem = [nn.Conv2d(3, 64, 7, 2, padding=3, bias=False), nn.BatchNorm2d(64), nn.ReLU(True)] 236 | for i in range(num): 237 | stem.append(nn.Conv2d(64, 64, 3, 1, padding=1, bias=False)) 238 | stem.append(nn.BatchNorm2d(64)) 239 | stem.append(nn.ReLU(True)) 240 | stem.append(nn.Conv2d(64, 64, kernel_size=2, stride=2)) 241 | self.conv = nn.Sequential(*stem) 242 | self.norm = nn.LayerNorm(64) 243 | self.apply(self._init_weights) 244 | 245 | def _init_weights(self, m): 246 | if isinstance(m, nn.Linear): 247 | trunc_normal_(m.weight, std=.02) 248 | if isinstance(m, nn.Linear) and m.bias is not None: 249 | nn.init.constant_(m.bias, 0) 250 | elif isinstance(m, nn.LayerNorm): 251 | nn.init.constant_(m.bias, 0) 252 | nn.init.constant_(m.weight, 1.0) 253 | elif isinstance(m, nn.Conv2d): 254 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 255 | fan_out //= m.groups 256 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 257 | if m.bias is not None: 258 | m.bias.data.zero_() 259 | def forward(self, x): 260 | x = self.conv(x) 261 | _, _, H, W = x.shape 262 | x = x.flatten(2).transpose(1, 2) 263 | x = self.norm(x) 264 | return x, H,W 265 | 266 | class ShuntedTransformer(nn.Module): 267 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512], 268 | num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0., 269 | 
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, 270 | depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, num_conv=0): 271 | super().__init__() 272 | self.num_classes = num_classes 273 | self.depths = depths 274 | self.num_stages = num_stages 275 | 276 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 277 | cur = 0 278 | 279 | for i in range(num_stages): 280 | if i ==0: 281 | patch_embed = Head(num_conv)# 282 | else: 283 | patch_embed = OverlapPatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)), 284 | patch_size=7 if i == 0 else 3, 285 | stride=4 if i == 0 else 2, 286 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 287 | embed_dim=embed_dims[i]) 288 | 289 | block = nn.ModuleList([Block( 290 | dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias, qk_scale=qk_scale, 291 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer, 292 | sr_ratio=sr_ratios[i]) 293 | for j in range(depths[i])]) 294 | norm = norm_layer(embed_dims[i]) 295 | cur += depths[i] 296 | 297 | setattr(self, f"patch_embed{i + 1}", patch_embed) 298 | setattr(self, f"block{i + 1}", block) 299 | setattr(self, f"norm{i + 1}", norm) 300 | 301 | # classification head 302 | self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity() 303 | 304 | self.apply(self._init_weights) 305 | 306 | def _init_weights(self, m): 307 | if isinstance(m, nn.Linear): 308 | trunc_normal_(m.weight, std=.02) 309 | if isinstance(m, nn.Linear) and m.bias is not None: 310 | nn.init.constant_(m.bias, 0) 311 | elif isinstance(m, nn.LayerNorm): 312 | nn.init.constant_(m.bias, 0) 313 | nn.init.constant_(m.weight, 1.0) 314 | elif isinstance(m, nn.Conv2d): 315 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 316 | fan_out //= m.groups 317 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 318 | if m.bias is not None: 319 | 
m.bias.data.zero_() 320 | 321 | def freeze_patch_emb(self): 322 | self.patch_embed1.requires_grad = False 323 | 324 | @torch.jit.ignore 325 | def no_weight_decay(self): 326 | return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better 327 | 328 | def get_classifier(self): 329 | return self.head 330 | 331 | def reset_classifier(self, num_classes, global_pool=''): 332 | self.num_classes = num_classes 333 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 334 | 335 | def forward_features(self, x): 336 | B = x.shape[0] 337 | 338 | for i in range(self.num_stages): 339 | patch_embed = getattr(self, f"patch_embed{i + 1}") 340 | block = getattr(self, f"block{i + 1}") 341 | norm = getattr(self, f"norm{i + 1}") 342 | x, H, W = patch_embed(x) 343 | for blk in block: 344 | x = blk(x, H, W) 345 | x = norm(x) 346 | if i != self.num_stages - 1: 347 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 348 | 349 | return x.mean(dim=1) 350 | 351 | def forward(self, x): 352 | x = self.forward_features(x) 353 | x = self.head(x) 354 | 355 | return x 356 | 357 | 358 | class DWConv(nn.Module): 359 | def __init__(self, dim=768): 360 | super(DWConv, self).__init__() 361 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 362 | 363 | def forward(self, x, H, W): 364 | B, N, C = x.shape 365 | x = x.transpose(1, 2).view(B, C, H, W) 366 | x = self.dwconv(x) 367 | x = x.flatten(2).transpose(1, 2) 368 | 369 | return x 370 | 371 | 372 | def _conv_filter(state_dict, patch_size=16): 373 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 374 | out_dict = {} 375 | for k, v in state_dict.items(): 376 | if 'patch_embed.proj.weight' in k: 377 | v = v.reshape((v.shape[0], 3, patch_size, patch_size)) 378 | out_dict[k] = v 379 | 380 | return out_dict 381 | 382 | 383 | 384 | @register_model 385 | def shunted_t(pretrained=False, **kwargs): 386 | model = 
ShuntedTransformer( 387 | patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 388 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[1, 2, 4, 1], sr_ratios=[8, 4, 2, 1], num_conv=0, 389 | **kwargs) 390 | model.default_cfg = _cfg() 391 | 392 | return model 393 | 394 | 395 | @register_model 396 | def shunted_s(pretrained=False, **kwargs): 397 | model = ShuntedTransformer( 398 | patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 399 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 4, 12, 1], sr_ratios=[8, 4, 2, 1], num_conv=1, **kwargs) 400 | model.default_cfg = _cfg() 401 | 402 | return model 403 | 404 | 405 | @register_model 406 | def shunted_b(pretrained=False, **kwargs): 407 | model = ShuntedTransformer( 408 | patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[8, 8, 4, 4], qkv_bias=True, 409 | norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 24, 2], sr_ratios=[8, 4, 2, 1], num_conv=2, 410 | **kwargs) 411 | model.default_cfg = _cfg() 412 | 413 | return model 414 | -------------------------------------------------------------------------------- /configs/Shunted/shunted_B.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='shunted_b', 3 | drop_path=0.3, 4 | clip_grad=1.0, 5 | output_dir='checkpoints/shunted_b', 6 | ) 7 | -------------------------------------------------------------------------------- /configs/Shunted/shunted_S.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | model='shunted_s', 3 | drop_path=0.2, 4 | clip_grad=None, 5 | output_dir='checkpoints/shunted_s', 6 | ) -------------------------------------------------------------------------------- /configs/Shunted/shunted_T.py: -------------------------------------------------------------------------------- 1 | cfg = dict( 2 | 
model='shunted_t', 3 | drop_path=0.1, 4 | clip_grad=None, 5 | output_dir='checkpoints/shunted_t', 6 | ) -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | from mcloader import ClassificationDataset 12 | 13 | 14 | class INatDataset(ImageFolder): 15 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 16 | category='name', loader=default_loader): 17 | self.transform = transform 18 | self.loader = loader 19 | self.target_transform = target_transform 20 | self.year = year 21 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 22 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 23 | with open(path_json) as json_file: 24 | data = json.load(json_file) 25 | 26 | with open(os.path.join(root, 'categories.json')) as json_file: 27 | data_catg = json.load(json_file) 28 | 29 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 30 | 31 | with open(path_json_for_targeter) as json_file: 32 | data_for_targeter = json.load(json_file) 33 | 34 | targeter = {} 35 | indexer = 0 36 | for elem in data_for_targeter['annotations']: 37 | king = [] 38 | king.append(data_catg[int(elem['category_id'])][category]) 39 | if king[0] not in targeter.keys(): 40 | targeter[king[0]] = indexer 41 | indexer += 1 42 | self.nb_classes = len(targeter) 43 | 44 | self.samples = [] 45 | for elem in data['images']: 46 | cut = elem['file_name'].split('/') 47 | target_current = int(cut[2]) 48 | path_current = 
os.path.join(root, cut[0], cut[2], cut[3]) 49 | 50 | categors = data_catg[target_current] 51 | target_current_true = targeter[categors[category]] 52 | self.samples.append((path_current, target_current_true)) 53 | 54 | # __getitem__ and __len__ inherited from ImageFolder 55 | 56 | 57 | def build_dataset(is_train, args): 58 | transform = build_transform(is_train, args) 59 | 60 | if args.data_set == 'CIFAR': 61 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 62 | nb_classes = 100 63 | elif args.data_set == 'IMNET': 64 | if not args.use_mcloader: 65 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 66 | dataset = datasets.ImageFolder(root, transform=transform) 67 | else: 68 | dataset = ClassificationDataset( 69 | 'train' if is_train else 'val', 70 | pipeline=transform 71 | ) 72 | nb_classes = 1000 73 | elif args.data_set == 'INAT': 74 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 75 | category=args.inat_category, transform=transform) 76 | nb_classes = dataset.nb_classes 77 | elif args.data_set == 'INAT19': 78 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 79 | category=args.inat_category, transform=transform) 80 | nb_classes = dataset.nb_classes 81 | 82 | return dataset, nb_classes 83 | 84 | 85 | def build_transform(is_train, args): 86 | resize_im = args.input_size > 32 87 | if is_train: 88 | # this should always dispatch to transforms_imagenet_train 89 | transform = create_transform( 90 | input_size=args.input_size, 91 | is_training=True, 92 | color_jitter=args.color_jitter, 93 | auto_augment=args.aa, 94 | interpolation=args.train_interpolation, 95 | re_prob=args.reprob, 96 | re_mode=args.remode, 97 | re_count=args.recount, 98 | ) 99 | if not resize_im: 100 | # replace RandomResizedCropAndInterpolation with 101 | # RandomCrop 102 | transform.transforms[0] = transforms.RandomCrop( 103 | args.input_size, padding=4) 104 | return transform 105 | 106 | t = [] 107 | if resize_im: 
108 | size = int((256 / 224) * args.input_size) 109 | t.append( 110 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 111 | ) 112 | t.append(transforms.CenterCrop(args.input_size)) 113 | 114 | t.append(transforms.ToTensor()) 115 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 116 | return transforms.Compose(t) 117 | -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_LL_THRESHOLD=0 3 | 4 | PORT=${PORT:-6666} 5 | 6 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=$PORT \ 7 | --use_env main.py --config configs/Shunted/shunted_S.py --data-path /path/to/ImageNet/ --batch-size 128 8 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | from losses import DistillationLoss 16 | import utils 17 | 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 23 | set_training_mode=True, 24 | fp32=False): 25 | model.train(set_training_mode) 26 | metric_logger = utils.MetricLogger(delimiter=" ") 27 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | header = 'Epoch: [{}]'.format(epoch) 29 | print_freq = 10 30 | 31 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 32 | samples = samples.to(device, non_blocking=True) 33 | targets = targets.to(device, non_blocking=True) 34 | 35 | if mixup_fn is not None: 36 | samples, targets = mixup_fn(samples, targets) 37 | 38 | # with torch.cuda.amp.autocast(): 39 | # outputs = model(samples) 40 | # loss = criterion(samples, outputs, targets) 41 | with torch.cuda.amp.autocast(enabled=not fp32): 42 | outputs = model(samples) 43 | loss = criterion(samples, outputs, targets) 44 | 45 | loss_value = loss.item() 46 | 47 | if not math.isfinite(loss_value): 48 | print("Loss is {}, stopping training".format(loss_value)) 49 | sys.exit(1) 50 | 51 | optimizer.zero_grad() 52 | 53 | # this attribute is added by timm on one optimizer (adahessian) 54 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 55 | loss_scaler(loss, optimizer, clip_grad=max_norm, 56 | parameters=model.parameters(), create_graph=is_second_order) 57 | 58 | torch.cuda.synchronize() 59 | if model_ema is not None: 60 | 
model_ema.update(model) 61 | 62 | metric_logger.update(loss=loss_value) 63 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 64 | # gather the stats from all processes 65 | metric_logger.synchronize_between_processes() 66 | print("Averaged stats:", metric_logger) 67 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 68 | 69 | 70 | @torch.no_grad() 71 | def evaluate(data_loader, model, device): 72 | criterion = torch.nn.CrossEntropyLoss() 73 | 74 | metric_logger = utils.MetricLogger(delimiter=" ") 75 | header = 'Test:' 76 | 77 | # switch to evaluation mode 78 | model.eval() 79 | 80 | for images, target in metric_logger.log_every(data_loader, 10, header): 81 | images = images.to(device, non_blocking=True) 82 | target = target.to(device, non_blocking=True) 83 | 84 | # compute output 85 | #with torch.cuda.amp.autocast(): 86 | output = model(images) 87 | loss = criterion(output, target) 88 | 89 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 90 | 91 | batch_size = images.shape[0] 92 | metric_logger.update(loss=loss.item()) 93 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 94 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 95 | # gather the stats from all processes 96 | metric_logger.synchronize_between_processes() 97 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 98 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 99 | 100 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 101 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | from models import * 4 | 5 | dependencies = ["torch", "torchvision", "timm"] 6 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are feed to the teacher model 29 | outputs: the outputs of the model to be trained. 
It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop throught the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | F.log_softmax(teacher_outputs / T, dim=1), 57 | reduction='sum', 58 | log_target=True 59 | ) * (T * T) / outputs_kd.numel() 60 | elif self.distillation_type == 'hard': 61 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 62 | 63 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 64 | return loss 65 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json

from pathlib import Path

from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma

from datasets import build_dataset
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
from samplers import RASampler
# import models
import SSA
import utils
import collections


def get_args_parser():
    """Build the command-line parser for training / evaluation.

    The parser is created with ``add_help=False`` so it can be used as a
    parent of the top-level parser constructed in the ``__main__`` guard.

    Returns:
        argparse.ArgumentParser: parser holding every option read by ``main``.
    """
    parser = argparse.ArgumentParser('PVT training and evaluation script', add_help=False)
    parser.add_argument('--fp32-resume', action='store_true', default=False)
    parser.add_argument('--batch-size', default=128, type=int)
    parser.add_argument('--epochs', default=300, type=int)
    parser.add_argument('--config', required=True, type=str, help='config')

    # Model parameters
    parser.add_argument('--model', default='pvt_small', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    # Model EMA options are disabled in this script:
    # parser.add_argument('--model-ema', action='store_true')
    # parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    # parser.set_defaults(model_ema=True)
    # parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    # parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The help text previously embedded a stray '" + \' fragment from a broken
    # string concatenation; it is now a single clean literal.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation options are disabled in this script:
    # parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
    #                     help='Name of teacher model to train (default: "regnety_160"')
    # parser.add_argument('--teacher-path', type=str, default='')
    # parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    # parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    # parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # * Finetuning params
    parser.add_argument('--finetune', default='', help='finetune from checkpoint')

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
                        type=str, help='Image Net dataset path')
    parser.add_argument('--use-mcloader', action='store_true', default=False, help='Use mcloader')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # The help text used to be a copy of the --world_size description.
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='local rank of this process')
    return parser


def main(args):
    """Train and/or evaluate the model selected by ``args.model``.

    Steps, in order: initialise distributed mode, build the train/val
    datasets, samplers and loaders, create the model (optionally loading
    finetune weights), build optimizer / scheduler / criterion, optionally
    resume from a checkpoint, then either evaluate once (``--eval``) or run
    the epoch loop, writing a checkpoint and a ``log.txt`` line per epoch
    when ``args.output_dir`` is set.

    Args:
        args: parsed namespace from ``get_args_parser`` (after config merge).
    """
    utils.init_distributed_mode(args)
    print(args)
    # if args.distillation_type != 'none' and args.finetune and not args.eval:
    #     raise NotImplementedError("Finetuning with distillation not yet supported")

    device = torch.device(args.device)

    # Seeding is currently disabled (args.seed is parsed but not applied):
    # seed = args.seed + utils.get_rank()
    # torch.manual_seed(seed)
    # np.random.seed(seed)
    # random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # Distributed-style samplers are used unconditionally; in a single-process
    # run the helpers report world size 1 / rank 0.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        # num_replicas was hard-coded to 0 here, which cannot partition the
        # dataset; it must be the number of participating processes.
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train,
            num_replicas=num_tasks,
            rank=global_rank, shuffle=True
        )
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val,
            num_replicas=num_tasks,
            rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(args.batch_size*1.5),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None,
    )

    if args.finetune:
        if args.finetune.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.finetune, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.finetune, map_location='cpu')

        if 'model' in checkpoint:
            checkpoint_model = checkpoint['model']
        else:
            checkpoint_model = checkpoint
        state_dict = model.state_dict()
        # Drop classifier weights whose shape does not match the new head.
        for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        model.load_state_dict(checkpoint_model, strict=False)

    model.to(device)

    model_ema = None

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Linear LR scaling with the global batch size, relative to 512.
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Distillation is disabled: wrap the base criterion with type 'none'.
    criterion = DistillationLoss(
        criterion, None, 'none', 0, 0
    )

    output_dir = Path(args.output_dir)
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        if 'model' in checkpoint:
            msg = model_without_ddp.load_state_dict(checkpoint['model'])
        else:
            msg = model_without_ddp.load_state_dict(checkpoint)
        print(msg)
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            # if args.model_ema:
            #     utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0

    # Per-epoch checkpoints go to ./checkpoints; make sure it exists before
    # the first save instead of failing at the end of epoch 0.
    checkpoint_dir = Path("./checkpoints")
    if args.output_dir and utils.is_main_process():
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        if args.fp32_resume and epoch > args.start_epoch + 1:
            args.fp32_resume = False
        loss_scaler._scaler = torch.cuda.amp.GradScaler(enabled=not args.fp32_resume)

        # The train sampler is always RASampler or DistributedSampler, both of
        # which seed their shuffle from the epoch; update it in every mode so
        # single-process runs do not replay the same order each epoch.
        data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            set_training_mode=args.finetune == '',  # keep in eval mode during finetuning
            fp32=args.fp32_resume
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            utils.save_on_master({
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                # 'model_ema': get_state_dict(model_ema),
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }, "./checkpoints/epoch_"+str(epoch)+".pth")

        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.2f}%")
        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    args = utils.update_from_config(args)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
| -------------------------------------------------------------------------------- /mcloader/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import ClassificationDataset 2 | from .data_prefetcher import DataPrefetcher -------------------------------------------------------------------------------- /mcloader/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OliverRensu/Shunted-Transformer/6861101d8f4f592da79f84294f015e1e6b150351/mcloader/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /mcloader/__pycache__/classification.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OliverRensu/Shunted-Transformer/6861101d8f4f592da79f84294f015e1e6b150351/mcloader/__pycache__/classification.cpython-38.pyc -------------------------------------------------------------------------------- /mcloader/__pycache__/data_prefetcher.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OliverRensu/Shunted-Transformer/6861101d8f4f592da79f84294f015e1e6b150351/mcloader/__pycache__/data_prefetcher.cpython-38.pyc -------------------------------------------------------------------------------- /mcloader/__pycache__/image_list.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OliverRensu/Shunted-Transformer/6861101d8f4f592da79f84294f015e1e6b150351/mcloader/__pycache__/image_list.cpython-38.pyc -------------------------------------------------------------------------------- /mcloader/__pycache__/imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/OliverRensu/Shunted-Transformer/6861101d8f4f592da79f84294f015e1e6b150351/mcloader/__pycache__/imagenet.cpython-38.pyc
# --------------------------------------------------------------------------------
# /mcloader/__pycache__/mcloader.cpython-38.pyc:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/OliverRensu/Shunted-Transformer/6861101d8f4f592da79f84294f015e1e6b150351/mcloader/__pycache__/mcloader.cpython-38.pyc

# --------------------------------------------------------------------------------
# /mcloader/classification.py:
# --------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from .imagenet import ImageNet


class ClassificationDataset(Dataset):
    """Dataset for classification.

    Wraps an ``ImageNet`` list-file data source (memcached-backed) and applies
    an optional transform pipeline to each image.

    Args:
        split: ``'train'`` selects the training list; any other value selects
            the validation list.
        pipeline: optional callable applied to the loaded image.
    """

    def __init__(self, split='train', pipeline=None):
        subset = 'train' if split == 'train' else 'val'
        self.data_source = ImageNet(root='data/imagenet/' + subset,
                                    list_file='data/imagenet/meta/' + subset + '.txt',
                                    memcached=True,
                                    mclient_path='/mnt/lustre/share/memcached_client')
        self.pipeline = pipeline

    def __len__(self):
        return self.data_source.get_length()

    def __getitem__(self, idx):
        img, target = self.data_source.get_sample(idx)
        if self.pipeline is not None:
            img = self.pipeline(img)

        return img, target


# --------------------------------------------------------------------------------
# /mcloader/data_prefetcher.py:
# --------------------------------------------------------------------------------
import torch


class DataPrefetcher:
    """Fetch the next ``(input, target)`` batch onto the GPU on a side stream.

    Args:
        loader: iterable yielding ``(input, target)`` pairs whose elements
            expose ``.cuda(non_blocking=True)``.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Pull the next batch and start its host-to-device copy."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            # Exhausted: next() will hand back (None, None).
            self.next_input = None
            self.next_target = None
            return

        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)

    def next(self):
        """Return the prefetched batch, or ``(None, None)`` when exhausted."""
        # Make the current stream wait for the side-stream copy to finish.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch_input = self.next_input
        batch_target = self.next_target
        if batch_input is not None:
            self.preload()
        return batch_input, batch_target


# --------------------------------------------------------------------------------
# /mcloader/image_list.py:
# --------------------------------------------------------------------------------
import os
from PIL import Image

from .mcloader import McLoader


class ImageList(object):
    """Image collection described by a text list file.

    Each non-blank line of ``list_file`` is either ``<filename>`` or
    ``<filename> <int label>``; the first line decides which format is used.

    Args:
        root: directory prefix joined onto every filename.
        list_file: path of the list file.
        memcached: when true, images are read through ``McLoader``.
        mclient_path: memcached client config directory (required when
            ``memcached`` is true).
    """

    def __init__(self, root, list_file, memcached=False, mclient_path=None):
        with open(list_file, 'r') as f:
            # Ignore blank lines (e.g. a trailing newline) so they cannot
            # break the filename/label unpacking below.
            lines = [line for line in f if line.strip()]
        # An empty list file yields an empty, unlabeled collection.
        self.has_labels = bool(lines) and len(lines[0].split()) == 2
        if self.has_labels:
            self.fns, self.labels = zip(*[l.strip().split() for l in lines])
            self.labels = [int(l) for l in self.labels]
        else:
            self.fns = [l.strip() for l in lines]
        self.fns = [os.path.join(root, fn) for fn in self.fns]
        self.memcached = memcached
        self.mclient_path = mclient_path
        self.initialized = False

    def _init_memcached(self):
        """Create the memcached loader on first use."""
        if not self.initialized:
            assert self.mclient_path is not None
            self.mc_loader = McLoader(self.mclient_path)
            self.initialized = True

    def get_length(self):
        return len(self.fns)

    def get_sample(self, idx):
        """Load sample ``idx`` as an RGB image.

        Returns:
            ``(img, target)`` when the list file has labels, else ``img``.

        Raises:
            OSError: if the memcached loader could not read the image.
        """
        path = self.fns[idx]
        if self.memcached:
            self._init_memcached()
            img = self.mc_loader(path)
            if img is None:
                # McLoader reports failures by returning None; fail with a
                # clear message instead of an AttributeError on None.
                raise OSError('Failed to load image from memcached: {}'.format(path))
            img = img.convert('RGB')
        else:
            # convert() returns a new, fully loaded image, so the file handle
            # can be released right away.
            with Image.open(path) as raw_img:
                img = raw_img.convert('RGB')
        if self.has_labels:
            target = self.labels[idx]
            return img, target
        else:
            return img


# --------------------------------------------------------------------------------
# /mcloader/imagenet.py:
# --------------------------------------------------------------------------------
from .image_list import ImageList


class ImageNet(ImageList):
    """``ImageList`` with all four constructor arguments required."""

    def __init__(self, root, list_file, memcached, mclient_path):
        super(ImageNet, self).__init__(
            root, list_file, memcached, mclient_path)


# --------------------------------------------------------------------------------
# /mcloader/mcloader.py:
# --------------------------------------------------------------------------------
import io
from PIL import Image
try:
    import mc
except ImportError:
    # The memcached client is optional at import time; McLoader raises a
    # clear error if it is actually needed.
    mc = None


def pil_loader(img_str):
    """Open a PIL image from an in-memory encoded byte string."""
    buff = io.BytesIO(img_str)
    return Image.open(buff)


class McLoader(object):
    """Callable that reads images through a memcached client.

    Args:
        mclient_path: directory holding ``server_list.conf`` and
            ``client.conf``.

    Raises:
        ImportError: if the ``mc`` module is not installed.
    """

    def __init__(self, mclient_path):
        assert mclient_path is not None, \
            "Please specify 'data_mclient_path' in the config."
        if mc is None:
            raise ImportError("The 'mc' memcached client module is required by McLoader.")
        self.mclient_path = mclient_path
        server_list_config_file = "{}/server_list.conf".format(
            self.mclient_path)
        client_config_file = "{}/client.conf".format(self.mclient_path)
        self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file,
                                                      client_config_file)

    def __call__(self, fn):
        """Return the image stored under key ``fn``, or ``None`` on failure."""
        try:
            img_value = mc.pyvector()
            self.mclient.Get(fn, img_value)
            img_value_str = mc.ConvertBuffer(img_value)
            img = pil_loader(img_value_str)
        except Exception:
            # Best-effort read: report and return None, but let
            # KeyboardInterrupt / SystemExit propagate.
            print('Read image failed ({})'.format(fn))
            return None
        else:
            return img


# --------------------------------------------------------------------------------
# /requirements.txt:
# --------------------------------------------------------------------------------
# torch==1.7.0
# torchvision==0.8.1
# timm==0.3.2
# mmcv==1.3.8

# --------------------------------------------------------------------------------
# /samplers.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import torch
import torch.distributed as dist
import math


class RASampler(torch.utils.data.Sampler):
    """Sampler that restricts data loading to a subset of the dataset for distributed,
    with repeated augmentation.
    It ensures that each augmented version of a sample will be visible to a
    different process (GPU).
    Heavily based on torch.utils.data.DistributedSampler

    Args:
        dataset: dataset to sample from (only ``len(dataset)`` is used).
        num_replicas: number of processes; defaults to the world size.
        rank: rank of this process; defaults to the current rank.
        shuffle: shuffle indices with an epoch-seeded generator.
    """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        # Every index is repeated 3 times before being split across replicas.
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        # Only this many indices are yielded per epoch: the dataset length
        # rounded down to a multiple of 256, divided across replicas.
        self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        if self.shuffle:
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        # repeat each index 3 times, then pad to make it evenly divisible
        indices = [ele for ele in indices for i in range(3)]
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample: consecutive repeats land on different ranks
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Set the epoch used to seed the next shuffle."""
        self.epoch = epoch

# --------------------------------------------------------------------------------
# /utils.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist

try:
    import mmcv
except ImportError:
    # Only update_from_config needs mmcv; keep the logging and distributed
    # helpers importable without it.
    mmcv = None


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and with weight ``n`` globally."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        # Nothing recorded yet: report 0.0 instead of dividing by zero.
        if self.count == 0:
            return 0.0
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Collection of named ``SmoothedValue`` meters with periodic printing."""

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Expose meters as attributes, e.g. ``logger.loss``.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items of ``iterable`` and print progress every ``print_freq`` items.

        Args:
            iterable: sized iterable (``len()`` is required).
            print_freq: number of iterations between progress lines; the
                last iteration is always printed.
            header: optional prefix for every printed line.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_iters = len(iterable)
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the item vs. the whole iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # An empty iterable must not crash the summary line.
        time_per_iter = total_time / num_iters if num_iters else 0.0
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, time_per_iter))


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save(checkpoint, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # ``force=True`` prints from any process.
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """Return the distributed world size, or 1 outside distributed mode."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return this process's rank, or 0 outside distributed mode."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the rank-0 process only."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Initialise torch.distributed from environment variables.

    Sets ``args.distributed`` and, when distributed, ``args.rank``,
    ``args.gpu`` and ``args.dist_backend``. With ``RANK``/``WORLD_SIZE`` in
    the environment, ``args.world_size`` is also read from it.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE(review): this branch does not set args.world_size, so the
        # value already on ``args`` is used below; confirm that is intended.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # From here on only rank 0 prints (unless force=True is passed).
    setup_for_distributed(args.rank == 0)


def update_from_config(args):
    """Overwrite attributes of ``args`` with values from ``args.config``.

    Every top-level entry of the config is expected to be a mapping; each of
    its key/value pairs is set as an attribute on ``args``.

    Raises:
        ImportError: if mmcv is not installed.
    """
    if mmcv is None:
        raise ImportError("mmcv is required to load a config file")
    cfg = mmcv.Config.fromfile(args.config)
    for cfg_item in cfg._cfg_dict.values():
        for k, v in cfg_item.items():
            setattr(args, k, v)
    return args