├── LICENSE ├── README.md ├── ceit.png ├── ceit_model.py ├── datasets.py ├── engine.py ├── main.py ├── requirements.txt ├── samplers.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Incorporating Convolution Designs into Visual Transformers 2 | 3 | 4 | This repository is the official implementation of CeiT (Convolution-enhanced image Transformer). It builds from [Data-Efficient Vision Transformer](https://github.com/facebookresearch/deit) and [timm](https://github.com/rwightman/pytorch-image-models) 5 | 6 | ![CeiT](ceit.png) 7 | 8 | For details see [Incorporating Convolution Designs into Visual Transformers](https://arxiv.org/pdf/2103.11816.pdf) by Kun Yuan, Shaopeng Guo, Ziwei Liu, Aojun Zhou, Fengwei Yu and Wei Wu 9 | 10 | If you use this code for a paper please cite: 11 | 12 | ``` 13 | @article{DBLP:journals/corr/abs-2103-11816, 14 | author = {Kun Yuan and 15 | Shaopeng Guo and 16 | Ziwei Liu and 17 | Aojun Zhou and 18 | Fengwei Yu and 19 | Wei Wu}, 20 | title = {Incorporating Convolution Designs into Visual Transformers}, 21 | journal = {CoRR}, 22 | volume = {abs/2103.11816}, 23 | year = {2021}, 24 | url = {https://arxiv.org/abs/2103.11816}, 25 | archivePrefix = {arXiv}, 26 | eprint = {2103.11816}, 27 | timestamp = {Wed, 24 Mar 2021 15:50:40 +0100}, 28 | biburl = {https://dblp.org/rec/journals/corr/abs-2103-11816.bib}, 29 | bibsource = {dblp computer science bibliography, https://dblp.org} 30 | } 31 | ``` 32 | 33 | # Model Zoo 34 | 35 | We provide baseline CeiT models pretrained on ImageNet 2012. The checkpoint can be downloaded from [here](https://drive.google.com/file/d/1S19SQUic9ILBGNkcJTOy74MGoj-JRM4x/view?usp=sharing) 36 | 37 | | model name | epoch | acc@1 | acc@5 | #params | 38 | | --- | --- | --- | --- | --- | 39 | | ceit_tiny_patch16_224 | 300 | 76.4 | 93.4 | 6.4M | 40 | | ceit_tiny_patch16_384 | 300 | 78.8| 94.7 | 6.4M | 41 | | ceit_small_patch16_224 | 300 | 82.0 | 95.9 | 24.2M | 42 | | ceit_small_patch16_384 | 300 | 83.3 | 96.5 | 24.2M | 43 | | ceit_base_patch16_224 | 100 | 81.9 | 95.8 | 94.4M | 44 | | ceit_base_patch16_224 | 150 | 82.5 | 95.6 | 94.4M | 45 | 46 | Before using it, make sure you have the pytorch-image-models package [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models) by [Ross Wightman](https://github.com/rwightman) installed. Note that our work relies of the augmentations proposed in this library. 47 | 48 | # Usage 49 | 50 | First, clone the repository locally: 51 | ``` 52 | git clone https://github.com/coeusguo/ceit.git 53 | ``` 54 | Then, install PyTorch 1.7.0+ and torchvision 0.8.1+ and [pytorch-image-models 0.3.2](https://github.com/rwightman/pytorch-image-models): 55 | 56 | ``` 57 | conda install -c pytorch pytorch torchvision 58 | pip install timm==0.3.2 59 | ``` 60 | 61 | ## Data preparation 62 | 63 | Download and extract ImageNet train and val images from http://image-net.org/. 64 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), and the training and validation data is expected to be in the `train/` folder and `val` folder respectively: 65 | 66 | ``` 67 | /path/to/imagenet/ 68 | train/ 69 | class1/ 70 | img1.jpeg 71 | class2/ 72 | img2.jpeg 73 | val/ 74 | class1/ 75 | img3.jpeg 76 | class/2 77 | img4.jpeg 78 | ``` 79 | 80 | ## Evaluation 81 | To evaluate a pre-trained CeiT model on ImageNet val with a single GPU run: 82 | ``` 83 | python main.py --eval --model --resume /path/to/checkpoint --data-path /path/to/imagenet 84 | ``` 85 | 86 | ## Training 87 | To train CeiT-Tiny and CeiT-small on ImageNet on a single node with 4 gpus for 300 epochs run: 88 | 89 | CeiT-tiny 90 | ``` 91 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model ceit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet 92 | ``` 93 | 94 | CeiT-small 95 | ``` 96 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model ceit_small_patch16_224 --batch-size 256 --data-path /path/to/imagenet 97 | ``` 98 | 99 | To train CeiT-Base on ImageNet on a single node with 4 gpus for 100 epochs run: 100 | CeiT-base 101 | ``` 102 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model ceit_base_patch16_224 --batch-size 256 --data-path /path/to/imagenet --epochs 100 103 | ``` 104 | -------------------------------------------------------------------------------- /ceit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/coeusguo/ceit/44c16d3ee9ed5306abffecdc254af11e89197253/ceit.png -------------------------------------------------------------------------------- /ceit_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from functools import partial 6 | 7 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 8 | from timm.models.registry import register_model 9 | from timm.models.vision_transformer import default_cfgs, _cfg 10 | 11 | 12 | __all__ = [ 13 | 'ceit_tiny_patch16_224', 'ceit_small_patch16_224', 'ceit_base_patch16_224', 14 | 'ceit_tiny_patch16_384', 'ceit_small_patch16_384', 15 | ] 16 | 17 | 18 | class Image2Tokens(nn.Module): 19 | def __init__(self, in_chans=3, out_chans=64, kernel_size=7, stride=2): 20 | super(Image2Tokens, self).__init__() 21 | self.conv = nn.Conv2d(in_chans, out_chans, kernel_size=kernel_size, stride=stride, 22 | padding=kernel_size // 2, bias=False) 23 | self.bn = nn.BatchNorm2d(out_chans) 24 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 25 | 26 | def forward(self, x): 27 | x = self.conv(x) 28 | x = self.bn(x) 29 | x = self.maxpool(x) 30 | return x 31 | 32 | 33 | class Mlp(nn.Module): 34 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 35 | super().__init__() 36 | out_features = out_features or in_features 37 | hidden_features = hidden_features or in_features 38 | self.fc1 = nn.Linear(in_features, hidden_features) 39 | self.act = act_layer() 40 | self.fc2 = nn.Linear(hidden_features, out_features) 41 | self.drop = nn.Dropout(drop) 42 | 43 | def forward(self, x): 44 | x = self.fc1(x) 45 | x = self.act(x) 46 | x = self.drop(x) 47 | x = self.fc2(x) 48 | x = self.drop(x) 49 | return x 50 | 51 | 52 | class LocallyEnhancedFeedForward(nn.Module): 53 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., 54 | kernel_size=3, with_bn=True): 55 | super().__init__() 56 | out_features = out_features or in_features 57 | hidden_features = hidden_features or in_features 58 | # pointwise 59 | self.conv1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, stride=1, padding=0) 60 | # depthwise 61 | self.conv2 = nn.Conv2d( 62 | hidden_features, hidden_features, kernel_size=kernel_size, stride=1, 63 | padding=(kernel_size - 1) // 2, groups=hidden_features 64 | ) 65 | # pointwise 66 | self.conv3 = nn.Conv2d(hidden_features, out_features, kernel_size=1, stride=1, padding=0) 67 | self.act = act_layer() 68 | # self.drop = nn.Dropout(drop) 69 | 70 | self.with_bn = with_bn 71 | if self.with_bn: 72 | self.bn1 = nn.BatchNorm2d(hidden_features) 73 | self.bn2 = nn.BatchNorm2d(hidden_features) 74 | self.bn3 = nn.BatchNorm2d(out_features) 75 | 76 | def forward(self, x): 77 | b, n, k = x.size() 78 | cls_token, tokens = torch.split(x, [1, n - 1], dim=1) 79 | x = tokens.reshape(b, int(math.sqrt(n - 1)), int(math.sqrt(n - 1)), k).permute(0, 3, 1, 2) 80 | if self.with_bn: 81 | x = self.conv1(x) 82 | x = self.bn1(x) 83 | x = self.act(x) 84 | x = self.conv2(x) 85 | x = self.bn2(x) 86 | x = self.act(x) 87 | x = self.conv3(x) 88 | x = self.bn3(x) 89 | else: 90 | x = self.conv1(x) 91 | x = self.act(x) 92 | x = self.conv2(x) 93 | x = self.act(x) 94 | x = self.conv3(x) 95 | 96 | tokens = x.flatten(2).permute(0, 2, 1) 97 | out = torch.cat((cls_token, tokens), dim=1) 98 | return out 99 | 100 | 101 | class Attention(nn.Module): 102 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 103 | super().__init__() 104 | self.num_heads = num_heads 105 | head_dim = dim // num_heads 106 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 107 | self.scale = qk_scale or head_dim ** -0.5 108 | 109 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 110 | self.attn_drop = nn.Dropout(attn_drop) 111 | self.proj = nn.Linear(dim, dim) 112 | self.proj_drop = nn.Dropout(proj_drop) 113 | self.attention_map = None 114 | 115 | def forward(self, x): 116 | B, N, C = x.shape 117 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 118 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 119 | 120 | attn = (q @ k.transpose(-2, -1)) * self.scale 121 | attn = attn.softmax(dim=-1) 122 | # self.attention_map = attn 123 | attn = self.attn_drop(attn) 124 | 125 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 126 | x = self.proj(x) 127 | x = self.proj_drop(x) 128 | return x 129 | 130 | 131 | class AttentionLCA(Attention): 132 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 133 | super(AttentionLCA, self).__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop) 134 | self.dim = dim 135 | self.qkv_bias = qkv_bias 136 | 137 | def forward(self, x): 138 | 139 | q_weight = self.qkv.weight[:self.dim, :] 140 | q_bias = None if not self.qkv_bias else self.qkv.bias[:self.dim] 141 | kv_weight = self.qkv.weight[self.dim:, :] 142 | kv_bias = None if not self.qkv_bias else self.qkv.bias[self.dim:] 143 | 144 | B, N, C = x.shape 145 | _, last_token = torch.split(x, [N-1, 1], dim=1) 146 | 147 | q = F.linear(last_token, q_weight, q_bias)\ 148 | .reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 149 | kv = F.linear(x, kv_weight, kv_bias)\ 150 | .reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 151 | k, v = kv[0], kv[1] 152 | 153 | attn = (q @ k.transpose(-2, -1)) * self.scale 154 | attn = attn.softmax(dim=-1) 155 | # self.attention_map = attn 156 | attn = self.attn_drop(attn) 157 | 158 | x = (attn @ v).transpose(1, 2).reshape(B, 1, C) 159 | x = self.proj(x) 160 | x = self.proj_drop(x) 161 | return x 162 | 163 | 164 | class Block(nn.Module): 165 | 166 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 167 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3, with_bn=True, 168 | feedforward_type='leff'): 169 | super().__init__() 170 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 171 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 172 | self.norm2 = norm_layer(dim) 173 | mlp_hidden_dim = int(dim * mlp_ratio) 174 | self.norm1 = norm_layer(dim) 175 | self.feedforward_type = feedforward_type 176 | 177 | if feedforward_type == 'leff': 178 | self.attn = Attention( 179 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 180 | self.leff = LocallyEnhancedFeedForward( 181 | in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, 182 | kernel_size=kernel_size, with_bn=with_bn, 183 | ) 184 | else: # LCA 185 | self.attn = AttentionLCA( 186 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 187 | self.feedforward = Mlp( 188 | in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop 189 | ) 190 | 191 | def forward(self, x): 192 | if self.feedforward_type == 'leff': 193 | x = x + self.drop_path(self.attn(self.norm1(x))) 194 | x = x + self.drop_path(self.leff(self.norm2(x))) 195 | return x, x[:, 0] 196 | else: # LCA 197 | _, last_token = torch.split(x, [x.size(1)-1, 1], dim=1) 198 | x = last_token + self.drop_path(self.attn(self.norm1(x))) 199 | x = x + self.drop_path(self.feedforward(self.norm2(x))) 200 | return x 201 | 202 | 203 | class HybridEmbed(nn.Module): 204 | """ CNN Feature Map Embedding 205 | Extract feature map from CNN, flatten, project to embedding dim. 206 | """ 207 | def __init__(self, backbone, img_size=224, patch_size=16, feature_size=None, in_chans=3, embed_dim=768): 208 | super().__init__() 209 | assert isinstance(backbone, nn.Module) 210 | img_size = to_2tuple(img_size) 211 | self.img_size = img_size 212 | self.backbone = backbone 213 | if feature_size is None: 214 | with torch.no_grad(): 215 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 216 | # map for all networks, the feature metadata has reliable channel and stride info, but using 217 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 218 | training = backbone.training 219 | if training: 220 | backbone.eval() 221 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 222 | if isinstance(o, (list, tuple)): 223 | o = o[-1] # last feature if backbone outputs list/tuple of features 224 | feature_size = o.shape[-2:] 225 | feature_dim = o.shape[1] 226 | backbone.train(training) 227 | else: 228 | feature_size = to_2tuple(feature_size) 229 | feature_dim = self.backbone.feature_info.channels()[-1] 230 | print('feature_size is {}, feature_dim is {}, patch_size is {}'.format( 231 | feature_size, feature_dim, patch_size 232 | )) 233 | self.num_patches = (feature_size[0] // patch_size) * (feature_size[1] // patch_size) 234 | self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) 235 | 236 | def forward(self, x): 237 | x = self.backbone(x) 238 | if isinstance(x, (list, tuple)): 239 | x = x[-1] # last feature if backbone outputs list/tuple of features 240 | x = self.proj(x).flatten(2).transpose(1, 2) 241 | return x 242 | 243 | 244 | class CeIT(nn.Module): 245 | def __init__(self, 246 | img_size=224, 247 | patch_size=16, 248 | in_chans=3, 249 | num_classes=1000, 250 | embed_dim=768, 251 | depth=12, 252 | num_heads=12, 253 | mlp_ratio=4., 254 | qkv_bias=False, 255 | qk_scale=None, 256 | drop_rate=0., 257 | attn_drop_rate=0., 258 | drop_path_rate=0., 259 | hybrid_backbone=None, 260 | norm_layer=nn.LayerNorm, 261 | leff_local_size=3, 262 | leff_with_bn=True): 263 | """ 264 | args: 265 | - img_size (:obj:`int`): input image size 266 | - patch_size (:obj:`int`): patch size 267 | - in_chans (:obj:`int`): input channels 268 | - num_classes (:obj:`int`): number of classes 269 | - embed_dim (:obj:`int`): embedding dimensions for tokens 270 | - depth (:obj:`int`): depth of encoder 271 | - num_heads (:obj:`int`): number of heads in multi-head self-attention 272 | - mlp_ratio (:obj:`float`): expand ratio in feedforward 273 | - qkv_bias (:obj:`bool`): whether to add bias for mlp of qkv 274 | - qk_scale (:obj:`float`): scale ratio for qk, default is head_dim ** -0.5 275 | - drop_rate (:obj:`float`): dropout rate in feedforward module after linear operation 276 | and projection drop rate in attention 277 | - attn_drop_rate (:obj:`float`): dropout rate for attention 278 | - drop_path_rate (:obj:`float`): drop_path rate after attention 279 | - hybrid_backbone (:obj:`nn.Module`): backbone e.g. resnet 280 | - norm_layer (:obj:`nn.Module`): normalization type 281 | - leff_local_size (:obj:`int`): kernel size in LocallyEnhancedFeedForward 282 | - leff_with_bn (:obj:`bool`): whether add bn in LocallyEnhancedFeedForward 283 | """ 284 | super().__init__() 285 | self.num_classes = num_classes 286 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 287 | 288 | self.i2t = HybridEmbed( 289 | hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 290 | num_patches = self.i2t.num_patches 291 | 292 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 293 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 294 | self.pos_drop = nn.Dropout(p=drop_rate) 295 | 296 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 297 | self.blocks = nn.ModuleList([ 298 | Block( 299 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 300 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 301 | kernel_size=leff_local_size, with_bn=leff_with_bn) 302 | for i in range(depth)]) 303 | 304 | # without droppath 305 | self.lca = Block( 306 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 307 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=0., norm_layer=norm_layer, 308 | feedforward_type = 'lca' 309 | ) 310 | self.pos_layer_embed = nn.Parameter(torch.zeros(1, depth, embed_dim)) 311 | 312 | self.norm = norm_layer(embed_dim) 313 | 314 | # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here 315 | # self.repr = nn.Linear(embed_dim, representation_size) 316 | # self.repr_act = nn.Tanh() 317 | 318 | # Classifier head 319 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 320 | 321 | trunc_normal_(self.pos_embed, std=.02) 322 | trunc_normal_(self.cls_token, std=.02) 323 | self.apply(self._init_weights) 324 | 325 | def _init_weights(self, m): 326 | if isinstance(m, nn.Linear): 327 | trunc_normal_(m.weight, std=.02) 328 | if isinstance(m, nn.Linear) and m.bias is not None: 329 | nn.init.constant_(m.bias, 0) 330 | elif isinstance(m, nn.LayerNorm): 331 | nn.init.constant_(m.bias, 0) 332 | nn.init.constant_(m.weight, 1.0) 333 | 334 | @torch.jit.ignore 335 | def no_weight_decay(self): 336 | return {'pos_embed', 'cls_token'} 337 | 338 | def get_classifier(self): 339 | return self.head 340 | 341 | def reset_classifier(self, num_classes, global_pool=''): 342 | self.num_classes = num_classes 343 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 344 | 345 | def forward_features(self, x): 346 | B = x.shape[0] 347 | x = self.i2t(x) 348 | 349 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 350 | x = torch.cat((cls_tokens, x), dim=1) 351 | x = x + self.pos_embed 352 | x = self.pos_drop(x) 353 | 354 | cls_token_list = [] 355 | for blk in self.blocks: 356 | x, curr_cls_token = blk(x) 357 | cls_token_list.append(curr_cls_token) 358 | 359 | all_cls_token = torch.stack(cls_token_list, dim=1) # B*D*K 360 | all_cls_token = all_cls_token + self.pos_layer_embed 361 | # attention over cls tokens 362 | last_cls_token = self.lca(all_cls_token) 363 | last_cls_token = self.norm(last_cls_token) 364 | 365 | return last_cls_token.view(B, -1) 366 | 367 | def forward(self, x): 368 | x = self.forward_features(x) 369 | x = self.head(x) 370 | return x 371 | 372 | 373 | @register_model 374 | def ceit_tiny_patch16_224(pretrained=False, **kwargs): 375 | """ 376 | convolutional + pooling stem 377 | local enhanced feedforward 378 | attention over cls_tokens 379 | """ 380 | i2t = Image2Tokens() 381 | model = CeIT( 382 | hybrid_backbone=i2t, 383 | patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 384 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 385 | model.default_cfg = _cfg() 386 | return model 387 | 388 | 389 | @register_model 390 | def ceit_small_patch16_224(pretrained=False, **kwargs): 391 | """ 392 | convolutional + pooling stem 393 | local enhanced feedforward 394 | attention over cls_tokens 395 | """ 396 | i2t = Image2Tokens() 397 | model = CeIT( 398 | hybrid_backbone=i2t, 399 | patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 400 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 401 | model.default_cfg = _cfg() 402 | return model 403 | 404 | 405 | @register_model 406 | def ceit_base_patch16_224(pretrained=False, **kwargs): 407 | """ 408 | convolutional + pooling stem 409 | local enhanced feedforward 410 | attention over cls_tokens 411 | """ 412 | i2t = Image2Tokens() 413 | model = CeIT( 414 | hybrid_backbone=i2t, 415 | patch_size=4, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 416 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 417 | model.default_cfg = _cfg() 418 | return model 419 | 420 | 421 | @register_model 422 | def ceit_tiny_patch16_384(pretrained=False, **kwargs): 423 | """ 424 | convolutional + pooling stem 425 | local enhanced feedforward 426 | attention over cls_tokens 427 | """ 428 | i2t = Image2Tokens() 429 | model = CeIT( 430 | hybrid_backbone=i2t, img_size=384, 431 | patch_size=4, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 432 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 433 | model.default_cfg = _cfg() 434 | return model 435 | 436 | 437 | @register_model 438 | def ceit_small_patch16_384(pretrained=False, **kwargs): 439 | """ 440 | convolutional + pooling stem 441 | local enhanced feedforward 442 | attention over cls_tokens 443 | """ 444 | i2t = Image2Tokens() 445 | model = CeIT( 446 | hybrid_backbone=i2t, img_size=384, 447 | patch_size=4, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 448 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 449 | model.default_cfg = _cfg() 450 | return model 451 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | from PIL import ImageFile 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | 15 | class INatDataset(ImageFolder): 16 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 17 | category='name', loader=default_loader): 18 | self.transform = transform 19 | self.loader = loader 20 | self.target_transform = target_transform 21 | self.year = year 22 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 23 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 24 | with open(path_json) as json_file: 25 | data = json.load(json_file) 26 | 27 | with open(os.path.join(root, 'categories.json')) as json_file: 28 | data_catg = json.load(json_file) 29 | 30 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 31 | 32 | with open(path_json_for_targeter) as json_file: 33 | data_for_targeter = json.load(json_file) 34 | 35 | targeter = {} 36 | indexer = 0 37 | for elem in data_for_targeter['annotations']: 38 | king = [] 39 | king.append(data_catg[int(elem['category_id'])][category]) 40 | if king[0] not in targeter.keys(): 41 | targeter[king[0]] = indexer 42 | indexer += 1 43 | self.nb_classes = len(targeter) 44 | 45 | self.samples = [] 46 | for elem in data['images']: 47 | cut = elem['file_name'].split('/') 48 | target_current = int(cut[2]) 49 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 50 | 51 | categors = data_catg[target_current] 52 | target_current_true = targeter[categors[category]] 53 | self.samples.append((path_current, target_current_true)) 54 | 55 | # __getitem__ and __len__ inherited from ImageFolder 56 | 57 | 58 | def build_dataset(is_train, args): 59 | transform = build_transform(is_train, args) 60 | 61 | if args.data_set == 'CIFAR': 62 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 63 | nb_classes = 100 64 | elif args.data_set == 'IMNET': 65 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 66 | dataset = datasets.ImageFolder(root, transform=transform) 67 | nb_classes = 1000 68 | elif args.data_set == 'INAT': 69 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 70 | category=args.inat_category, transform=transform) 71 | nb_classes = dataset.nb_classes 72 | elif args.data_set == 'INAT19': 73 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 74 | category=args.inat_category, transform=transform) 75 | nb_classes = dataset.nb_classes 76 | 77 | return dataset, nb_classes 78 | 79 | 80 | def build_transform(is_train, args): 81 | resize_im = args.input_size > 32 82 | if is_train: 83 | # this should always dispatch to transforms_imagenet_train 84 | transform = create_transform( 85 | input_size=args.input_size, 86 | is_training=True, 87 | color_jitter=args.color_jitter, 88 | auto_augment=args.aa, 89 | interpolation=args.train_interpolation, 90 | re_prob=args.reprob, 91 | re_mode=args.remode, 92 | re_count=args.recount, 93 | ) 94 | if not resize_im: 95 | # replace RandomResizedCropAndInterpolation with 96 | # RandomCrop 97 | transform.transforms[0] = transforms.RandomCrop( 98 | args.input_size, padding=4) 99 | return transform 100 | 101 | t = [] 102 | if resize_im: 103 | size = int((256 / 224) * args.input_size) 104 | t.append( 105 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 106 | ) 107 | t.append(transforms.CenterCrop(args.input_size)) 108 | 109 | t.append(transforms.ToTensor()) 110 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 111 | return transforms.Compose(t) 112 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | import utils 16 | 17 | 18 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 19 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 20 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 21 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 22 | set_training_mode=True): 23 | model.train(set_training_mode) 24 | metric_logger = utils.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | header = 'Epoch: [{}]'.format(epoch) 27 | print_freq = 10 28 | 29 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 30 | samples = samples.to(device, non_blocking=True) 31 | targets = targets.to(device, non_blocking=True) 32 | 33 | if mixup_fn is not None: 34 | samples, targets = mixup_fn(samples, targets) 35 | 36 | with torch.cuda.amp.autocast(): 37 | outputs = model(samples) 38 | loss = criterion(outputs, targets) 39 | 40 | loss_value = loss.item() 41 | 42 | if not math.isfinite(loss_value): 43 | print("Loss is {}, stopping training".format(loss_value)) 44 | sys.exit(1) 45 | 46 | optimizer.zero_grad() 47 | 48 | # this attribute is added by timm on one optimizer (adahessian) 49 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 50 | loss_scaler(loss, optimizer, clip_grad=max_norm, 51 | parameters=model.parameters(), create_graph=is_second_order) 52 | 53 | torch.cuda.synchronize() 54 | if model_ema is not None: 55 | model_ema.update(model) 56 | 57 | metric_logger.update(loss=loss_value) 58 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 59 | # gather the stats from all processes 60 | metric_logger.synchronize_between_processes() 61 | print("Averaged stats:", metric_logger) 62 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 63 | 64 | 65 | @torch.no_grad() 66 | def evaluate(data_loader, model, device): 67 | criterion = torch.nn.CrossEntropyLoss() 68 | 69 | metric_logger = utils.MetricLogger(delimiter=" ") 70 | header = 'Test:' 71 | 72 | # switch to evaluation mode 73 | model.eval() 74 | 75 | for images, target in metric_logger.log_every(data_loader, 10, header): 76 | images = images.to(device, non_blocking=True) 77 | target = target.to(device, non_blocking=True) 78 | 79 | # compute output 80 | with torch.cuda.amp.autocast(): 81 | output = model(images) 82 | loss = criterion(output, target) 83 | 84 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 85 | 86 | batch_size = images.shape[0] 87 | metric_logger.update(loss=loss.item()) 88 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 89 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 90 | # gather the stats from all processes 91 | metric_logger.synchronize_between_processes() 92 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 93 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 94 | 95 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 96 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import argparse 4 | import datetime 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import json 10 | 11 | from pathlib import Path 12 | 13 | from timm.data import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.scheduler import create_scheduler 17 | from timm.optim import create_optimizer 18 | from timm.utils import NativeScaler, get_state_dict, ModelEma 19 | 20 | from datasets import build_dataset 21 | from engine import train_one_epoch, evaluate 22 | from samplers import RASampler 23 | import ceit_model 24 | import utils 25 | 26 | 27 | def get_args_parser(): 28 | parser = argparse.ArgumentParser('CeiT training and evaluation script', add_help=False) 29 | parser.add_argument('--batch-size', default=64, type=int) 30 | parser.add_argument('--epochs', default=300, type=int) 31 | 32 | # Model parameters 33 | parser.add_argument('--model', default='ceit_tiny_patch16_224', type=str, metavar='MODEL', 34 | help='Name of model to train') 35 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 36 | 37 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 38 | help='Dropout rate (default: 0.)') 39 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 40 | help='Drop path rate (default: 0.1)') 41 | 42 | parser.add_argument('--model-ema', action='store_true') 43 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 44 | parser.set_defaults(model_ema=True) 45 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 46 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 47 | 48 | # Optimizer parameters 49 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 50 | help='Optimizer (default: "adamw"') 51 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 52 | help='Optimizer Epsilon (default: 1e-8)') 53 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 54 | help='Optimizer Betas (default: None, use opt default)') 55 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 56 | help='Clip gradient norm (default: None, no clipping)') 57 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 58 | help='SGD momentum (default: 0.9)') 59 | parser.add_argument('--weight-decay', type=float, default=0.05, 60 | help='weight decay (default: 0.05)') 61 | 62 | # Learning rate schedule parameters 63 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 64 | help='LR scheduler (default: "cosine"') 65 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 66 | help='learning rate (default: 5e-4)') 67 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 68 | help='learning rate noise on/off epoch percentages') 69 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 70 | help='learning rate noise limit percent (default: 0.67)') 71 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 72 | help='learning rate noise std-dev (default: 1.0)') 73 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 74 | help='warmup learning rate (default: 1e-6)') 75 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 76 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 77 | 78 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 79 | help='epoch interval to decay LR') 80 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 81 | help='epochs to warmup LR, if scheduler supports') 82 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 83 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 84 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 85 | help='patience epochs for Plateau LR scheduler (default: 10') 86 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 87 | help='LR decay rate (default: 0.1)') 88 | 89 | # Augmentation parameters 90 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 91 | help='Color jitter factor (default: 0.4)') 92 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 93 | help='Use AutoAugment policy. "v0" or "original". " + \ 94 | "(default: rand-m9-mstd0.5-inc1)'), 95 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 96 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 97 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 98 | 99 | parser.add_argument('--repeated-aug', action='store_true') 100 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 101 | parser.set_defaults(repeated_aug=True) 102 | 103 | # * Random Erase params 104 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 105 | help='Random erase prob (default: 0.25)') 106 | parser.add_argument('--remode', type=str, default='pixel', 107 | help='Random erase mode (default: "pixel")') 108 | parser.add_argument('--recount', type=int, default=1, 109 | help='Random erase count (default: 1)') 110 | parser.add_argument('--resplit', action='store_true', default=False, 111 | help='Do not random erase first (clean) augmentation split') 112 | 113 | # * Mixup params 114 | parser.add_argument('--mixup', type=float, default=0.8, 115 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 116 | parser.add_argument('--cutmix', type=float, default=1.0, 117 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 118 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 119 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 120 | parser.add_argument('--mixup-prob', type=float, default=1.0, 121 | help='Probability of performing mixup or cutmix when either/both is enabled') 122 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 123 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 124 | parser.add_argument('--mixup-mode', type=str, default='batch', 125 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 126 | 127 | # * Finetuning params 128 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 129 | 130 | # Dataset parameters 131 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 132 | help='dataset path') 133 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 134 | type=str, help='Image Net dataset path') 135 | parser.add_argument('--inat-category', default='name', 136 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 137 | type=str, help='semantic granularity') 138 | 139 | parser.add_argument('--output_dir', default='', 140 | help='path where to save, empty for no saving') 141 | parser.add_argument('--device', default='cuda', 142 | help='device to use for training / testing') 143 | parser.add_argument('--seed', default=0, type=int) 144 | parser.add_argument('--resume', default='', help='resume from checkpoint') 145 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 146 | help='start epoch') 147 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 148 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 149 | parser.add_argument('--num_workers', default=10, type=int) 150 | parser.add_argument('--pin-mem', action='store_true', 151 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 152 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 153 | help='') 154 | parser.set_defaults(pin_mem=True) 155 | 156 | # distributed training parameters 157 | parser.add_argument('--world_size', default=1, type=int, 158 | help='kernel size ') 159 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 160 | 161 | # ceit leff module 162 | parser.add_argument('--leff-local-size', default=3, type=int, 163 | help='Kernel size of depth-wise conv in leff module') 164 | parser.add_argument('--leff-with-bn', default=True, help='Using batchnorm in leff module') 165 | 166 | return parser 167 | 168 | 169 | def main(args): 170 | utils.init_distributed_mode(args) 171 | 172 | print(args) 173 | 174 | device = torch.device(args.device) 175 | 176 | # fix the seed for reproducibility 177 | seed = args.seed + utils.get_rank() 178 | torch.manual_seed(seed) 179 | np.random.seed(seed) 180 | # random.seed(seed) 181 | 182 | cudnn.benchmark = True 183 | 184 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 185 | dataset_val, _ = build_dataset(is_train=False, args=args) 186 | 187 | num_tasks = utils.get_world_size() 188 | global_rank = utils.get_rank() 189 | if args.repeated_aug: 190 | sampler_train = RASampler( 191 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 192 | ) 193 | else: 194 | sampler_train = torch.utils.data.DistributedSampler( 195 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 196 | ) 197 | if args.dist_eval: 198 | if len(dataset_val) % num_tasks != 0: 199 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 200 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 201 | 'equal num of samples per-process.') 202 | sampler_val = torch.utils.data.DistributedSampler( 203 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 204 | else: 205 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 206 | 207 | 208 | data_loader_train = torch.utils.data.DataLoader( 209 | dataset_train, sampler=sampler_train, 210 | batch_size=args.batch_size, 211 | num_workers=args.num_workers, 212 | pin_memory=args.pin_mem, 213 | drop_last=True, 214 | ) 215 | 216 | data_loader_val = torch.utils.data.DataLoader( 217 | dataset_val, sampler=sampler_val, 218 | batch_size=int(1.5 * args.batch_size), 219 | num_workers=args.num_workers, 220 | pin_memory=args.pin_mem, 221 | drop_last=False 222 | ) 223 | 224 | mixup_fn = None 225 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 226 | if mixup_active: 227 | mixup_fn = Mixup( 228 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 229 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 230 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 231 | 232 | print(f"Creating model: {args.model}") 233 | model = create_model( 234 | args.model, 235 | pretrained=False, 236 | num_classes=args.nb_classes, 237 | drop_rate=args.drop, 238 | drop_path_rate=args.drop_path, 239 | drop_block_rate=None, 240 | leff_local_size=args.leff_local_size, 241 | leff_with_bn=args.leff_with_bn 242 | ) 243 | 244 | if args.finetune: 245 | if args.finetune.startswith('https'): 246 | checkpoint = torch.hub.load_state_dict_from_url( 247 | args.finetune, map_location='cpu', check_hash=True) 248 | else: 249 | checkpoint = torch.load(args.finetune, map_location='cpu') 250 | 251 | checkpoint_model = checkpoint['model'] 252 | state_dict = model.state_dict() 253 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 254 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 255 | print(f"Removing key {k} from pretrained checkpoint") 256 | del checkpoint_model[k] 257 | 258 | # interpolate position embedding 259 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 260 | embedding_size = pos_embed_checkpoint.shape[-1] 261 | num_patches = model.patch_embed.num_patches 262 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 263 | # height (== width) for the checkpoint position embedding 264 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 265 | # height (== width) for the new position embedding 266 | new_size = int(num_patches ** 0.5) 267 | # class_token and dist_token are kept unchanged 268 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 269 | # only the position tokens are interpolated 270 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 271 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 272 | pos_tokens = torch.nn.functional.interpolate( 273 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 274 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 275 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 276 | checkpoint_model['pos_embed'] = new_pos_embed 277 | 278 | model.load_state_dict(checkpoint_model, strict=False) 279 | 280 | model.to(device) 281 | 282 | model_ema = None 283 | if args.model_ema: 284 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 285 | model_ema = ModelEma( 286 | model, 287 | decay=args.model_ema_decay, 288 | device='cpu' if args.model_ema_force_cpu else '', 289 | resume='') 290 | 291 | model_without_ddp = model 292 | if args.distributed: 293 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 294 | model_without_ddp = model.module 295 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 296 | print('number of params:', n_parameters) 297 | 298 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 299 | args.lr = linear_scaled_lr 300 | optimizer = create_optimizer(args, model_without_ddp) 301 | loss_scaler = NativeScaler() 302 | 303 | lr_scheduler, _ = create_scheduler(args, optimizer) 304 | 305 | criterion = LabelSmoothingCrossEntropy() 306 | 307 | if args.mixup > 0.: 308 | # smoothing is handled with mixup label transform 309 | criterion = SoftTargetCrossEntropy() 310 | elif args.smoothing: 311 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 312 | else: 313 | criterion = torch.nn.CrossEntropyLoss() 314 | 315 | output_dir = Path(args.output_dir) 316 | if args.resume: 317 | if args.resume.startswith('https'): 318 | checkpoint = torch.hub.load_state_dict_from_url( 319 | args.resume, map_location='cpu', check_hash=True) 320 | else: 321 | checkpoint = torch.load(args.resume, map_location='cpu') 322 | model_without_ddp.load_state_dict(checkpoint['model']) 323 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 324 | optimizer.load_state_dict(checkpoint['optimizer']) 325 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 326 | args.start_epoch = checkpoint['epoch'] + 1 327 | if args.model_ema: 328 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 329 | if 'scaler' in checkpoint: 330 | loss_scaler.load_state_dict(checkpoint['scaler']) 331 | 332 | if args.eval: 333 | test_stats = evaluate(data_loader_val, model, device) 334 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 335 | return 336 | 337 | print(f"Start training for {args.epochs} epochs") 338 | start_time = time.time() 339 | max_accuracy = 0.0 340 | for epoch in range(args.start_epoch, args.epochs): 341 | if args.distributed: 342 | data_loader_train.sampler.set_epoch(epoch) 343 | 344 | train_stats = train_one_epoch( 345 | model, criterion, data_loader_train, 346 | optimizer, device, epoch, loss_scaler, 347 | args.clip_grad, model_ema, mixup_fn, 348 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 349 | ) 350 | 351 | lr_scheduler.step(epoch) 352 | if args.output_dir: 353 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 354 | for checkpoint_path in checkpoint_paths: 355 | utils.save_on_master({ 356 | 'model': model_without_ddp.state_dict(), 357 | 'optimizer': optimizer.state_dict(), 358 | 'lr_scheduler': lr_scheduler.state_dict(), 359 | 'epoch': epoch, 360 | 'model_ema': get_state_dict(model_ema), 361 | 'scaler': loss_scaler.state_dict(), 362 | 'args': args, 363 | }, checkpoint_path) 364 | 365 | test_stats = evaluate(data_loader_val, model, device) 366 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 367 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 368 | print(f'Max accuracy: {max_accuracy:.2f}%') 369 | 370 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 371 | **{f'test_{k}': v for k, v in test_stats.items()}, 372 | 'epoch': epoch, 373 | 'n_parameters': n_parameters} 374 | 375 | if args.output_dir and utils.is_main_process(): 376 | with (output_dir / "log.txt").open("a") as f: 377 | f.write(json.dumps(log_stats) + "\n") 378 | 379 | total_time = time.time() - start_time 380 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 381 | print('Training time {}'.format(total_time_str)) 382 | 383 | 384 | if __name__ == '__main__': 385 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 386 | args = parser.parse_args() 387 | if args.output_dir: 388 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 389 | main(args) 390 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torchvision==0.8.1 3 | timm==0.3.2 4 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 11 | It ensures that different each augmented version of a sample will be visible to a 12 | different process (GPU) 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | self.epoch = 0 29 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 30 | self.total_size = self.num_samples * self.num_replicas 31 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 32 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 33 | self.shuffle = shuffle 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | if self.shuffle: 40 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 41 | else: 42 | indices = list(range(len(self.dataset))) 43 | 44 | # add extra samples to make it evenly divisible 45 | indices = [ele for ele in indices for i in range(3)] 46 | indices += indices[:(self.total_size - len(indices))] 47 | assert len(indices) == self.total_size 48 | 49 | # subsample 50 | indices = indices[self.rank:self.total_size:self.num_replicas] 51 | assert len(indices) == self.num_samples 52 | 53 | return iter(indices[:self.num_selected_samples]) 54 | 55 | def __len__(self): 56 | return self.num_selected_samples 57 | 58 | def set_epoch(self, epoch): 59 | self.epoch = epoch 60 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save(checkpoint, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | --------------------------------------------------------------------------------