├── LICENSE ├── README.md ├── requirements.txt └── src ├── check_model.py ├── datasets.py ├── edgevit.py ├── engine.py ├── hubconf.py ├── imagenet_h5.py ├── losses.py ├── main.py ├── samplers.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers](https://arxiv.org/abs/2205.03436) 2 | 3 | ## Abstract 4 | Self-attention based models such as vision transformers (ViTs) have emerged as a very competitive architecture alternative to convolutional neural networks (CNNs) in computer vision. Despite increasingly stronger variants with ever-higher recognition accuracies, due to the quadratic complexity of self-attention, existing ViTs are typically demanding in computation and model size. Although several successful design choices (e.g., the convolutions and hierarchical multi-stage structure) of prior CNNs have been reintroduced into recent ViTs, they are still not sufficient to meet the limited resource requirements of mobile devices. This motivates a very recent attempt to develop light ViTs based on the state-of-the-art MobileNet-v2, but still leaves a performance gap behind. In this work, pushing further along this under-studied direction we introduce EdgeViTs, a new family of light-weight ViTs that, for the first time, enable attention-based vision models to compete with the best light-weight CNNs in the tradeoff between accuracy and on-device efficiency. This is realized by introducing a highly cost-effective local-global-local (LGL) information exchange bottleneck based on optimal integration of self-attention and convolutions. For device-dedicated evaluation, rather than relying on inaccurate proxies like the number of FLOPs or parameters, we adopt a practical approach of focusing directly on on-device latency and, for the first time, energy efficiency. Specifically, we show that our models are Pareto-optimal when both accuracy-latency and accuracy-energy trade-offs are considered, achieving strict dominance over other ViTs in almost all cases and competing with the most efficient CNNs. 
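The LGL bottleneck mentioned above is implemented in `src/edgevit.py` (`LGLBlock`, built from `LocalAgg` and `GlobalSparseAttn` with its `LocalProp` upsampler). A rough, simplified sketch of the information flow is given below; the `lgl_block` helper and its callable arguments are illustrative placeholders, not the actual modules:

```
# Rough information flow of one LGL block (see LGLBlock / GlobalSparseAttn in src/edgevit.py).
# local_agg:      residual depth-wise convolution block (local aggregation)
# subsample:      token sub-sampler, e.g. average pooling with stride sr_ratio
# self_attention: multi-head self-attention over the sparse "delegate" tokens
# local_prop:     transposed convolution spreading global context back to all tokens
def lgl_block(x, local_agg, subsample, self_attention, local_prop):
    x = local_agg(x)                  # aggregate local information
    tokens = subsample(x)             # keep a sparse set of delegate tokens
    tokens = self_attention(tokens)   # exchange information globally among delegates
    x = x + local_prop(tokens)        # propagate the global context back locally
    return x
```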
5 | 6 | ## Software required 7 | The code has only been tested on 64-bit Linux. 8 | Please install PyTorch 1.7.0+, torchvision 0.8.1+, and [pytorch-image-models 0.3.2](https://github.com/rwightman/pytorch-image-models): 9 | ``` 10 | conda install -c pytorch pytorch torchvision 11 | pip install timm==0.3.2 12 | ``` 13 | 14 | ## Data preparation 15 | 16 | Download and extract the ImageNet train and val images from http://image-net.org/. 17 | The directory structure is the standard layout for the torchvision [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder), with the training and validation data expected in the `train/` and `val/` folders, respectively: 18 | 19 | ``` 20 | /path/to/imagenet/ 21 | train/ 22 | class1/ 23 | img1.jpeg 24 | class2/ 25 | img2.jpeg 26 | val/ 27 | class1/ 28 | img3.jpeg 29 | class2/ 30 | img4.jpeg 31 | ``` 32 | 33 | 34 | ## Training 35 | To train the model on ImageNet with an 8-GPU server for 300 epochs: 36 | 37 | EdgeViT-small 38 | ``` 39 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model edgevit_s --batch-size 256 --data-path /path/to/imagenet --output_dir /path/to/save 40 | ``` 41 | 42 | If you find our paper/code useful, please consider citing: 43 | 44 | ``` 45 | @inproceedings{pan2022edgevits, 46 | title={EdgeViTs: Competing Light-weight CNNs on Mobile Devices with Vision Transformers}, 47 | author={Pan, Junting and Bulat, Adrian and Tan, Fuwen and Zhu, Xiatian and Dudziak, Lukasz and Li, Hongsheng and Tzimiropoulos, Georgios and Martinez, Brais}, 48 | booktitle={European Conference on Computer Vision}, 49 | year={2022} 50 | } 51 | ``` 52 | 53 | ## Acknowledgement 54 | 55 | This repository is built using the [timm](https://github.com/rwightman/pytorch-image-models) library and the [DeiT](https://github.com/facebookresearch/deit) and [Uniformer](https://github.com/Sense-X/UniFormer/tree/main/image_classification) repositories. 
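## Evaluation

Evaluation of a trained checkpoint can presumably be run through the `--eval` and `--resume` flags in `src/main.py`; a single-GPU invocation would look roughly like the following (the checkpoint path is a placeholder):

```
python main.py --eval --model edgevit_s --resume /path/to/checkpoint.pth --data-path /path/to/imagenet
```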
56 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torchvision==0.8.1 3 | timm==0.3.2 4 | -------------------------------------------------------------------------------- /src/check_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from fvcore.nn import FlopCountAnalysis 3 | from fvcore.nn import flop_count_table 4 | # from models import uniformer_small 5 | from timm import create_model 6 | 7 | import uniformer_mvit 8 | # import uniformer_mvit_5x5 9 | import uniformer_twins 10 | # import uniformer_twins_convnext 11 | import uniformer 12 | import uniformer_deconv 13 | # import uniformer_twins_channel_split 14 | import uniformer_deconv_channel_split 15 | import uniformer_upsample 16 | import uniformer_deconv_ffn_channel_split 17 | import vit 18 | import uniformer_select 19 | import uniformer_sparse 20 | import uniformer_select_1 21 | import uniformer_deconv_sparse 22 | import uniformer_deconv_shuffle 23 | import uniformer_twins_speed 24 | import uniformer_deconv_lsa 25 | import uniformer_twins_lsa 26 | import uniformer_twins_ablate_deconv 27 | import uniformer_twins_ablate 28 | 29 | import sys 30 | sys.path.append('../SlowFast_dev') 31 | from slowfast.config.defaults import get_cfg 32 | from slowfast.models.video_model_builder import MViT 33 | from timm.models.registry import register_model 34 | 35 | sys.path.append('../Twins') 36 | import gvt 37 | 38 | import time 39 | import argparse 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--model', type=str, 43 | help='A required integer positional argument') 44 | 45 | args = parser.parse_args() 46 | 47 | torch.backends.cudnn.benchmark = True 48 | 49 | @register_model 50 | def mvit_tiny(pretrained=False, **kwargs): 51 | cfg = get_cfg() 52 | cfg_file = '../SlowFast_dev/configs/ImageNet/MVIT_T_10_CONV.yaml' 53 | cfg.merge_from_file(cfg_file) 54 | model = MViT(cfg) 55 | return model 56 | 57 | 58 | # model_name = 'uniformer_small_convnext_plus_ls' 59 | # model_name = 'uniformer_small_twins_plus_sr' 60 | # model_name = 'uniformer_small_convnext' 61 | 62 | model_name = args.model 63 | model = create_model( 64 | model_name, 65 | pretrained=False, 66 | num_classes=1000, 67 | drop_rate=0.0, 68 | drop_path_rate=0.0, 69 | drop_block_rate=None, 70 | ) 71 | 72 | # print (model) 73 | input = torch.rand(1, 3, 224, 224) 74 | output = model(input) 75 | print(output.shape) 76 | 77 | model.eval() 78 | flops = FlopCountAnalysis(model, torch.rand(1, 3, 224, 224)) 79 | print(flop_count_table(flops)) 80 | 81 | iterations = 30 82 | 83 | # def throughput(model): 84 | # model.eval() 85 | # model.cuda() 86 | # # for idx, (images, _) in enumerate(data_loader): 87 | # images = torch.rand(64, 3, 224, 224).cuda(non_blocking=True) 88 | # batch_size = images.shape[0] 89 | # for i in range(50): 90 | # model(images) 91 | # torch.cuda.synchronize() 92 | # print(f"throughput averaged with 30 times") 93 | # # warm up 94 | # for i in range(10): 95 | # model(images) 96 | # tic1 = time.time() 97 | # for i in range(iterations): 98 | # model(images) 99 | # torch.cuda.synchronize() 100 | # tic2 = time.time() 101 | # print(f"batch_size {batch_size} throughput {iterations * batch_size / (tic2 - tic1)}") 102 | # 103 | # throughput(model) 104 | -------------------------------------------------------------------------------- /src/datasets.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import os 4 | import json 5 | 6 | from torchvision import datasets, transforms 7 | from torchvision.datasets.folder import ImageFolder, default_loader 8 | 9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 10 | from timm.data import create_transform 11 | 12 | from imagenet_h5 import Imagenet 13 | 14 | class INatDataset(ImageFolder): 15 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 16 | category='name', loader=default_loader): 17 | self.transform = transform 18 | self.loader = loader 19 | self.target_transform = target_transform 20 | self.year = year 21 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 22 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 23 | with open(path_json) as json_file: 24 | data = json.load(json_file) 25 | 26 | with open(os.path.join(root, 'categories.json')) as json_file: 27 | data_catg = json.load(json_file) 28 | 29 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 30 | 31 | with open(path_json_for_targeter) as json_file: 32 | data_for_targeter = json.load(json_file) 33 | 34 | targeter = {} 35 | indexer = 0 36 | for elem in data_for_targeter['annotations']: 37 | king = [] 38 | king.append(data_catg[int(elem['category_id'])][category]) 39 | if king[0] not in targeter.keys(): 40 | targeter[king[0]] = indexer 41 | indexer += 1 42 | self.nb_classes = len(targeter) 43 | 44 | self.samples = [] 45 | for elem in data['images']: 46 | cut = elem['file_name'].split('/') 47 | target_current = int(cut[2]) 48 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 49 | 50 | categors = data_catg[target_current] 51 | target_current_true = targeter[categors[category]] 52 | self.samples.append((path_current, target_current_true)) 53 | 54 | # __getitem__ and __len__ inherited from ImageFolder 55 | 56 | 57 | def build_dataset(is_train, args): 58 | transform = build_transform(is_train, args) 59 | 60 | if args.data_set == 'CIFAR': 61 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 62 | nb_classes = 100 63 | elif args.data_set == 'IMNET': 64 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 65 | dataset = datasets.ImageFolder(root, transform=transform) 66 | # dataset = Imagenet(args.data_path, train=is_train, transform=transform) 67 | nb_classes = 1000 68 | elif args.data_set == 'INAT': 69 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 70 | category=args.inat_category, transform=transform) 71 | nb_classes = dataset.nb_classes 72 | elif args.data_set == 'INAT19': 73 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 74 | category=args.inat_category, transform=transform) 75 | nb_classes = dataset.nb_classes 76 | elif args.data_set == 'IMNETH5': 77 | dataset = Imagenet(args.data_path, train=is_train, transform=transform) 78 | nb_classes = 1000 79 | 80 | return dataset, nb_classes 81 | 82 | 83 | def build_transform(is_train, args): 84 | resize_im = args.input_size > 32 85 | if is_train: 86 | # this should always dispatch to transforms_imagenet_train 87 | transform = create_transform( 88 | input_size=args.input_size, 89 | is_training=True, 90 | color_jitter=args.color_jitter, 91 | auto_augment=args.aa, 92 | interpolation=args.train_interpolation, 93 | re_prob=args.reprob, 94 | 
re_mode=args.remode, 95 | re_count=args.recount, 96 | ) 97 | if not resize_im: 98 | # replace RandomResizedCropAndInterpolation with 99 | # RandomCrop 100 | transform.transforms[0] = transforms.RandomCrop( 101 | args.input_size, padding=4) 102 | return transform 103 | 104 | t = [] 105 | if resize_im: 106 | size = int((256 / 224) * args.input_size) 107 | t.append( 108 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images 109 | ) 110 | t.append(transforms.CenterCrop(args.input_size)) 111 | 112 | t.append(transforms.ToTensor()) 113 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 114 | return transforms.Compose(t) 115 | -------------------------------------------------------------------------------- /src/edgevit.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | from functools import partial 5 | import torch.nn.functional as F 6 | import math 7 | from timm.models.vision_transformer import _cfg 8 | from timm.models.registry import register_model 9 | from timm.models.layers import trunc_normal_, DropPath, to_2tuple 10 | 11 | 12 | class Mlp(nn.Module): 13 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 14 | super().__init__() 15 | out_features = out_features or in_features 16 | hidden_features = hidden_features or in_features 17 | self.fc1 = nn.Linear(in_features, hidden_features) 18 | self.act = act_layer() 19 | self.fc2 = nn.Linear(hidden_features, out_features) 20 | self.drop = nn.Dropout(drop) 21 | 22 | def forward(self, x): 23 | x = self.fc1(x) 24 | x = self.act(x) 25 | x = self.drop(x) 26 | x = self.fc2(x) 27 | x = self.drop(x) 28 | return x 29 | 30 | 31 | class CMlp(nn.Module): 32 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 33 | super().__init__() 34 | out_features = out_features or in_features 35 | hidden_features = hidden_features or in_features 36 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 37 | self.act = act_layer() 38 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 39 | self.drop = nn.Dropout(drop) 40 | 41 | def forward(self, x): 42 | x = self.fc1(x) 43 | x = self.act(x) 44 | x = self.drop(x) 45 | x = self.fc2(x) 46 | x = self.drop(x) 47 | return x 48 | 49 | 50 | class GlobalSparseAttn(nn.Module): 51 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 52 | super().__init__() 53 | self.num_heads = num_heads 54 | head_dim = dim // num_heads 55 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 56 | self.scale = qk_scale or head_dim ** -0.5 57 | 58 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 59 | self.attn_drop = nn.Dropout(attn_drop) 60 | self.proj = nn.Linear(dim, dim) 61 | self.proj_drop = nn.Dropout(proj_drop) 62 | 63 | # self.upsample = nn.Upsample(scale_factor=sr_ratio, mode='nearest') 64 | self.sr= sr_ratio 65 | if self.sr > 1: 66 | self.sampler = nn.AvgPool2d(1, sr_ratio) 67 | kernel_size = sr_ratio 68 | self.LocalProp= nn.ConvTranspose2d(dim, dim, kernel_size, stride=sr_ratio, groups=dim) 69 | self.norm = nn.LayerNorm(dim) 70 | else: 71 | self.sampler = nn.Identity() 72 | self.upsample = nn.Identity() 73 | self.norm = nn.Identity() 74 | 75 | 76 | def forward(self, x, H:int, W:int): 77 | B, N, C = x.shape 78 | if self.sr > 1.: 79 | x = 
x.transpose(1, 2).reshape(B, C, H, W) 80 | x = self.sampler(x) 81 | x = x.flatten(2).transpose(1, 2) 82 | 83 | qkv = self.qkv(x).reshape(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 84 | q, k, v = qkv[0], qkv[1], qkv[2] 85 | 86 | attn = (q @ k.transpose(-2, -1)) * self.scale 87 | attn = attn.softmax(dim=-1) 88 | attn = self.attn_drop(attn) 89 | 90 | x = (attn @ v).transpose(1, 2).reshape(B, -1, C) 91 | 92 | if self.sr > 1: 93 | x = x.permute(0, 2, 1).reshape(B, C, int(H/self.sr), int(W/self.sr)) 94 | x = self.LocalProp(x) 95 | x = x.reshape(B, C, -1).permute(0, 2, 1) 96 | x = self.norm(x) 97 | 98 | x = self.proj(x) 99 | x = self.proj_drop(x) 100 | return x 101 | 102 | 103 | class LocalAgg(nn.Module): 104 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 105 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 106 | super().__init__() 107 | self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) 108 | self.norm1 = nn.BatchNorm2d(dim) 109 | self.conv1 = nn.Conv2d(dim, dim, 1) 110 | self.conv2 = nn.Conv2d(dim, dim, 1) 111 | self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 112 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 113 | self.norm2 = nn.BatchNorm2d(dim) 114 | mlp_hidden_dim = int(dim * mlp_ratio) 115 | self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 116 | 117 | def forward(self, x): 118 | x = x + self.pos_embed(x) 119 | x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x))))) 120 | x = x + self.drop_path(self.mlp(self.norm2(x))) 121 | return x 122 | 123 | 124 | class SelfAttn(nn.Module): 125 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 126 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1.): 127 | super().__init__() 128 | self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) 129 | self.norm1 = norm_layer(dim) 130 | self.attn = GlobalSparseAttn( 131 | dim, 132 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 133 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 134 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 135 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 136 | self.norm2 = norm_layer(dim) 137 | mlp_hidden_dim = int(dim * mlp_ratio) 138 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 139 | # global layer_scale 140 | # self.ls = layer_scale 141 | 142 | def forward(self, x): 143 | x = x + self.pos_embed(x) 144 | B, N, H, W = x.shape 145 | x = x.flatten(2).transpose(1, 2) 146 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 147 | x = x + self.drop_path(self.mlp(self.norm2(x))) 148 | x = x.transpose(1, 2).reshape(B, N, H, W) 149 | return x 150 | 151 | 152 | class LGLBlock(nn.Module): 153 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 154 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1.): 155 | super().__init__() 156 | 157 | if sr_ratio > 1: 158 | self.LocalAgg = LocalAgg(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, act_layer, norm_layer) 159 | else: 160 | self.LocalAgg = nn.Identity() 161 | 162 | self.SelfAttn = SelfAttn(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, act_layer, norm_layer, sr_ratio) 163 | 164 | def forward(self, x): 165 | x = self.LocalAgg(x) 166 | x = self.SelfAttn(x) 167 | return x 168 | 169 | 170 | class PatchEmbed(nn.Module): 171 | """ Image to Patch Embedding 172 | """ 173 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 174 | super().__init__() 175 | img_size = to_2tuple(img_size) 176 | patch_size = to_2tuple(patch_size) 177 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 178 | self.img_size = img_size 179 | self.patch_size = patch_size 180 | self.num_patches = num_patches 181 | self.norm = nn.LayerNorm(embed_dim) 182 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 183 | 184 | def forward(self, x): 185 | B, C, H, W = x.shape 186 | assert H == self.img_size[0] and W == self.img_size[1], \ 187 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
188 | x = self.proj(x) 189 | B, C, H, W = x.shape 190 | x = x.flatten(2).transpose(1, 2) 191 | x = self.norm(x) 192 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 193 | return x 194 | 195 | 196 | class EdgeVit(nn.Module): 197 | """ Vision Transformer 198 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 199 | https://arxiv.org/abs/2010.11929 200 | """ 201 | def __init__(self, depth=[1, 2, 5, 3], img_size=224, in_chans=3, num_classes=1000, embed_dim=[48, 96, 240, 384], 202 | head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 203 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, sr_ratios=[4,2,2,1], **kwargs): 204 | """ 205 | Args: 206 | depth (list): depth of each stage 207 | img_size (int, tuple): input image size 208 | in_chans (int): number of input channels 209 | num_classes (int): number of classes for classification head 210 | embed_dim (list): embedding dimension of each stage 211 | head_dim (int): head dimension 212 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 213 | qkv_bias (bool): enable bias for qkv if True 214 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 215 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 216 | drop_rate (float): dropout rate 217 | attn_drop_rate (float): attention dropout rate 218 | drop_path_rate (float): stochastic depth rate 219 | norm_layer (nn.Module): normalization layer 220 | """ 221 | super().__init__() 222 | self.num_classes = num_classes 223 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 224 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 225 | 226 | self.patch_embed1 = PatchEmbed( 227 | img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0]) 228 | self.patch_embed2 = PatchEmbed( 229 | img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1]) 230 | self.patch_embed3 = PatchEmbed( 231 | img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2]) 232 | self.patch_embed4 = PatchEmbed( 233 | img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3]) 234 | 235 | self.pos_drop = nn.Dropout(p=drop_rate) 236 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))] # stochastic depth decay rule 237 | num_heads = [dim // head_dim for dim in embed_dim] 238 | self.blocks1 = nn.ModuleList([ 239 | LGLBlock( 240 | dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio[0], qkv_bias=qkv_bias, qk_scale=qk_scale, 241 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, sr_ratio=sr_ratios[0]) 242 | for i in range(depth[0])]) 243 | self.blocks2 = nn.ModuleList([ 244 | LGLBlock( 245 | dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio[1], qkv_bias=qkv_bias, qk_scale=qk_scale, 246 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer, sr_ratio=sr_ratios[1]) 247 | for i in range(depth[1])]) 248 | self.blocks3 = nn.ModuleList([ 249 | LGLBlock( 250 | dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio[2], qkv_bias=qkv_bias, qk_scale=qk_scale, 251 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer, sr_ratio=sr_ratios[2]) 252 | for i in range(depth[2])]) 253 | self.blocks4 = nn.ModuleList([ 254 | LGLBlock( 255 | dim=embed_dim[3], 
num_heads=num_heads[3], mlp_ratio=mlp_ratio[3], qkv_bias=qkv_bias, qk_scale=qk_scale, 256 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer, sr_ratio=sr_ratios[3]) 257 | for i in range(depth[3])]) 258 | self.norm = nn.BatchNorm2d(embed_dim[-1]) 259 | 260 | # Representation layer 261 | if representation_size: 262 | self.num_features = representation_size 263 | self.pre_logits = nn.Sequential(OrderedDict([ 264 | ('fc', nn.Linear(embed_dim, representation_size)), 265 | ('act', nn.Tanh()) 266 | ])) 267 | else: 268 | self.pre_logits = nn.Identity() 269 | 270 | # Classifier head 271 | self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() 272 | 273 | self.apply(self._init_weights) 274 | 275 | def _init_weights(self, m): 276 | if isinstance(m, nn.Linear): 277 | trunc_normal_(m.weight, std=.02) 278 | if isinstance(m, nn.Linear) and m.bias is not None: 279 | nn.init.constant_(m.bias, 0) 280 | elif isinstance(m, nn.LayerNorm): 281 | nn.init.constant_(m.bias, 0) 282 | nn.init.constant_(m.weight, 1.0) 283 | 284 | @torch.jit.ignore 285 | def no_weight_decay(self): 286 | return {'pos_embed', 'cls_token'} 287 | 288 | def get_classifier(self): 289 | return self.head 290 | 291 | def reset_classifier(self, num_classes, global_pool=''): 292 | self.num_classes = num_classes 293 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 294 | 295 | def forward_features(self, x): 296 | x = self.patch_embed1(x) 297 | x = self.pos_drop(x) 298 | for blk in self.blocks1: 299 | x = blk(x) 300 | x = self.patch_embed2(x) 301 | for blk in self.blocks2: 302 | x = blk(x) 303 | x = self.patch_embed3(x) 304 | for blk in self.blocks3: 305 | x = blk(x) 306 | x = self.patch_embed4(x) 307 | for blk in self.blocks4: 308 | x = blk(x) 309 | x = self.norm(x) 310 | x = self.pre_logits(x) 311 | return x 312 | 313 | def forward(self, x): 314 | x = self.forward_features(x) 315 | x = x.flatten(2).mean(-1) 316 | x = self.head(x) 317 | return x 318 | 319 | 320 | @register_model 321 | def edgevit_xxs(pretrained=True, **kwargs): 322 | model = EdgeVit( 323 | depth=[1, 1, 3, 2], 324 | embed_dim=[36, 72, 144, 288], head_dim=36, mlp_ratio=[4]*4, qkv_bias=True, 325 | norm_layer=partial(nn.LayerNorm, eps=1e-6), sr_ratios=[4,2,2,1], **kwargs) 326 | model.default_cfg = _cfg() 327 | return model 328 | 329 | @register_model 330 | def edgevit_xs(pretrained=True, **kwargs): 331 | model = EdgeVit( 332 | depth=[1, 1, 3, 1], 333 | embed_dim=[48, 96, 240, 384], head_dim=48, mlp_ratio=[4]*4, qkv_bias=True, 334 | norm_layer=partial(nn.LayerNorm, eps=1e-6), sr_ratios=[4,2,2,1], **kwargs) 335 | model.default_cfg = _cfg() 336 | return model 337 | 338 | 339 | @register_model 340 | def edgevit_s(pretrained=True, **kwargs): 341 | model = EdgeVit( 342 | depth=[1, 2, 5, 3], 343 | embed_dim=[48, 96, 240, 384], head_dim=48, mlp_ratio=[4]*4, qkv_bias=True, 344 | norm_layer=partial(nn.LayerNorm, eps=1e-6), sr_ratios=[4,2,2,1], **kwargs) 345 | model.default_cfg = _cfg() 346 | return model 347 | -------------------------------------------------------------------------------- /src/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable, Optional 9 | 10 | import torch 11 | 12 | from timm.data import Mixup 13 | from timm.utils import accuracy, ModelEma 14 | 15 | from losses import DistillationLoss 16 | import utils 17 | 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 22 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 23 | set_training_mode=True): 24 | model.train(set_training_mode) 25 | metric_logger = utils.MetricLogger(delimiter=" ") 26 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 27 | header = 'Epoch: [{}]'.format(epoch) 28 | print_freq = 10 29 | 30 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 31 | samples = samples.to(device, non_blocking=True) 32 | targets = targets.to(device, non_blocking=True) 33 | 34 | if mixup_fn is not None: 35 | samples, targets = mixup_fn(samples, targets) 36 | 37 | with torch.cuda.amp.autocast(): 38 | outputs = model(samples) 39 | loss = criterion(samples, outputs, targets) 40 | 41 | loss_value = loss.item() 42 | 43 | if not math.isfinite(loss_value): 44 | print("Loss is {}, stopping training".format(loss_value)) 45 | sys.exit(1) 46 | 47 | optimizer.zero_grad() 48 | 49 | # this attribute is added by timm on one optimizer (adahessian) 50 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 51 | loss_scaler(loss, optimizer, clip_grad=max_norm, 52 | parameters=model.parameters(), create_graph=is_second_order) 53 | 54 | torch.cuda.synchronize() 55 | if model_ema is not None: 56 | model_ema.update(model) 57 | 58 | metric_logger.update(loss=loss_value) 59 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 60 | # gather the stats from all processes 61 | metric_logger.synchronize_between_processes() 62 | print("Averaged stats:", metric_logger) 63 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 64 | 65 | 66 | @torch.no_grad() 67 | def evaluate(data_loader, model, device): 68 | criterion = torch.nn.CrossEntropyLoss() 69 | 70 | metric_logger = utils.MetricLogger(delimiter=" ") 71 | header = 'Test:' 72 | 73 | # switch to evaluation mode 74 | model.eval() 75 | 76 | for images, target in metric_logger.log_every(data_loader, 10, header): 77 | images = images.to(device, non_blocking=True) 78 | target = target.to(device, non_blocking=True) 79 | 80 | # compute output 81 | with torch.cuda.amp.autocast(): 82 | output = model(images) 83 | loss = criterion(output, target) 84 | 85 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 86 | 87 | batch_size = images.shape[0] 88 | metric_logger.update(loss=loss.item()) 89 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 90 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 91 | # gather the stats from all processes 92 | metric_logger.synchronize_between_processes() 93 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 94 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 95 | 96 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 97 | -------------------------------------------------------------------------------- /src/hubconf.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | from models import * 4 | from cait_models import * 5 | from resmlp_models import * 6 | #from patchconvnet_models import * 7 | 8 | dependencies = ["torch", "torchvision", "timm"] 9 | -------------------------------------------------------------------------------- /src/imagenet_h5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | import json 4 | import numpy as np 5 | import os 6 | import random 7 | import re 8 | import torch 9 | import torch.utils.data 10 | from PIL import Image 11 | from torchvision import transforms as transforms_tv 12 | 13 | import h5py, io, time 14 | import os.path as osp 15 | 16 | 17 | class Imagenet(torch.utils.data.Dataset): 18 | """ImageNet dataset.""" 19 | 20 | def __init__(self, root_path, train=True, transform=None): 21 | if train: 22 | self.mode = 'train' 23 | else: 24 | self.mode = 'val' 25 | self.data_path = root_path 26 | 27 | # logger.info("Constructing ImageNet {}...".format(mode)) 28 | self._construct_imdb_h5() 29 | self.transform = transform 30 | 31 | def safe_record_loader(self, raw_frame, attempts=10, retry_delay=1): 32 | for j in range(attempts): 33 | try: 34 | img = Image.open(io.BytesIO(raw_frame)).convert('RGB') 35 | return img 36 | except OSError as e: 37 | print(f'Attempt {j}/{attempts}: failed to load\n{e}', flush=True) 38 | if j == attempts - 1: 39 | raise 40 | time.sleep(retry_delay) 41 | 42 | def _construct_imdb_h5(self): 43 | # def __init__(self, h5_path, transform=None): 44 | # self.h5_fp = h5_path 45 | self.h5_fp = os.path.join(self.data_path, self.mode+'.h5') 46 | # logger.info("{} data path: {}".format(self.mode, self.h5_fp)) 47 | 48 | assert osp.isfile(self.h5_fp), "File not found: {}".format(self.h5_fp) 49 | self.h5_file = None 50 | h5_file = h5py.File(self.h5_fp, 'r') 51 | self._imdb = [] 52 | # labels = list(h5_file.keys()) 53 | labels = sorted(list(h5_file.keys())) 54 | for key, value in h5_file.items(): 55 | target = labels.index(key) 56 | for img_name in value.keys(): 57 | self._imdb.append({'image_name': img_name, 'class_name': key, 'class': target}) 58 | self.num_videos = len(self._imdb) 59 | # logger.info("Number of images: {}".format(len(self._imdb))) 60 | # logger.info("Number of classes: {}".format(len(labels))) 61 | 62 | def __load_h5__(self, index): 63 | try: 64 | # Load the image 65 | if self.h5_file is None: 66 | self.h5_file = h5py.File(self.h5_fp, 'r') 67 | record = self._imdb[index] 68 | raw_frame = self.h5_file[record['class_name']][record['image_name']][()] 69 | img = self.safe_record_loader(raw_frame) 70 | return img 71 | except Exception: 72 | return None 73 | 74 | def __getitem__(self, index): 75 | im = self.__load_h5__(index) 76 | if self.transform is not None: 77 | im = self.transform(im) 78 | # Retrieve the label 79 | label = self._imdb[index]["class"] 80 | return im, label 81 | 82 | def __len__(self): 83 | return len(self._imdb) 84 | -------------------------------------------------------------------------------- /src/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | """ 4 | Implements the knowledge distillation loss 5 | """ 6 | import torch 7 | from torch.nn import functional as F 8 | 9 | 10 | class DistillationLoss(torch.nn.Module): 11 | """ 12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 13 | taking a teacher model prediction and using it as additional supervision. 14 | """ 15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 16 | distillation_type: str, alpha: float, tau: float): 17 | super().__init__() 18 | self.base_criterion = base_criterion 19 | self.teacher_model = teacher_model 20 | assert distillation_type in ['none', 'soft', 'hard'] 21 | self.distillation_type = distillation_type 22 | self.alpha = alpha 23 | self.tau = tau 24 | 25 | def forward(self, inputs, outputs, labels): 26 | """ 27 | Args: 28 | inputs: The original inputs that are feed to the teacher model 29 | outputs: the outputs of the model to be trained. It is expected to be 30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 31 | in the first position and the distillation predictions as the second output 32 | labels: the labels for the base criterion 33 | """ 34 | outputs_kd = None 35 | if not isinstance(outputs, torch.Tensor): 36 | # assume that the model outputs a tuple of [outputs, outputs_kd] 37 | outputs, outputs_kd = outputs 38 | base_loss = self.base_criterion(outputs, labels) 39 | if self.distillation_type == 'none': 40 | return base_loss 41 | 42 | if outputs_kd is None: 43 | raise ValueError("When knowledge distillation is enabled, the model is " 44 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 45 | "class_token and the dist_token") 46 | # don't backprop throught the teacher 47 | with torch.no_grad(): 48 | teacher_outputs = self.teacher_model(inputs) 49 | 50 | if self.distillation_type == 'soft': 51 | T = self.tau 52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 53 | # with slight modifications 54 | distillation_loss = F.kl_div( 55 | F.log_softmax(outputs_kd / T, dim=1), 56 | #We provide the teacher's targets in log probability because we use log_target=True 57 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719) 58 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. 59 | F.log_softmax(teacher_outputs / T, dim=1), 60 | reduction='sum', 61 | log_target=True 62 | ) * (T * T) / outputs_kd.numel() 63 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 64 | #But we also experiments output_kd.size(0) 65 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details 66 | elif self.distillation_type == 'hard': 67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 68 | 69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 70 | return loss 71 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | import argparse 4 | import datetime 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import json 10 | 11 | from pathlib import Path 12 | 13 | from timm.data import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.scheduler import create_scheduler 17 | from timm.optim import create_optimizer 18 | from timm.utils import NativeScaler, get_state_dict, ModelEma 19 | 20 | from datasets import build_dataset 21 | from engine import train_one_epoch, evaluate 22 | from losses import DistillationLoss 23 | from samplers import RASampler 24 | import utils 25 | 26 | 27 | import edgevit 28 | 29 | def get_args_parser(): 30 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 31 | parser.add_argument('--batch-size', default=64, type=int) 32 | parser.add_argument('--epochs', default=300, type=int) 33 | 34 | # Model parameters 35 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 36 | help='Name of model to train') 37 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 38 | 39 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 40 | help='Dropout rate (default: 0.)') 41 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 42 | help='Drop path rate (default: 0.1)') 43 | 44 | parser.add_argument('--model-ema', action='store_true') 45 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 46 | parser.set_defaults(model_ema=True) 47 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 48 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 49 | 50 | # Optimizer parameters 51 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 52 | help='Optimizer (default: "adamw"') 53 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 54 | help='Optimizer Epsilon (default: 1e-8)') 55 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 56 | help='Optimizer Betas (default: None, use opt default)') 57 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 58 | help='Clip gradient norm (default: None, no clipping)') 59 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 60 | help='SGD momentum (default: 0.9)') 61 | parser.add_argument('--weight-decay', type=float, default=0.05, 62 | help='weight decay (default: 0.05)') 63 | # Learning rate schedule parameters 64 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 65 | help='LR scheduler (default: "cosine"') 66 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 67 | help='learning rate (default: 5e-4)') 68 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 69 | help='learning rate noise on/off epoch percentages') 70 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 71 | help='learning rate noise limit percent (default: 0.67)') 72 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 73 | help='learning rate noise std-dev (default: 1.0)') 74 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 75 | help='warmup learning rate (default: 1e-6)') 76 | parser.add_argument('--min-lr', 
type=float, default=1e-5, metavar='LR', 77 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 78 | 79 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 80 | help='epoch interval to decay LR') 81 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 82 | help='epochs to warmup LR, if scheduler supports') 83 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 84 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 85 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 86 | help='patience epochs for Plateau LR scheduler (default: 10') 87 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 88 | help='LR decay rate (default: 0.1)') 89 | 90 | # Augmentation parameters 91 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 92 | help='Color jitter factor (default: 0.4)') 93 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 94 | help='Use AutoAugment policy. "v0" or "original". " + \ 95 | "(default: rand-m9-mstd0.5-inc1)'), 96 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 97 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 98 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 99 | 100 | parser.add_argument('--repeated-aug', action='store_true') 101 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 102 | parser.set_defaults(repeated_aug=True) 103 | 104 | # * Random Erase params 105 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 106 | help='Random erase prob (default: 0.25)') 107 | parser.add_argument('--remode', type=str, default='pixel', 108 | help='Random erase mode (default: "pixel")') 109 | parser.add_argument('--recount', type=int, default=1, 110 | help='Random erase count (default: 1)') 111 | parser.add_argument('--resplit', action='store_true', default=False, 112 | help='Do not random erase first (clean) augmentation split') 113 | 114 | # * Mixup params 115 | parser.add_argument('--mixup', type=float, default=0.8, 116 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 117 | parser.add_argument('--cutmix', type=float, default=1.0, 118 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 119 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 120 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 121 | parser.add_argument('--mixup-prob', type=float, default=1.0, 122 | help='Probability of performing mixup or cutmix when either/both is enabled') 123 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 124 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 125 | parser.add_argument('--mixup-mode', type=str, default='batch', 126 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 127 | 128 | # Distillation parameters 129 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 130 | help='Name of teacher model to train (default: "regnety_160"') 131 | parser.add_argument('--teacher-path', type=str, default='') 132 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 133 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 134 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 135 | 136 | # * Finetuning params 137 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 138 | 139 | # Dataset parameters 140 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 141 | help='dataset path') 142 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19', 'IMNETH5'], 143 | type=str, help='Image Net dataset path') 144 | parser.add_argument('--inat-category', default='name', 145 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 146 | type=str, help='semantic granularity') 147 | 148 | parser.add_argument('--output_dir', default='', 149 | help='path where to save, empty for no saving') 150 | parser.add_argument('--device', default='cuda', 151 | help='device to use for training / testing') 152 | parser.add_argument('--seed', default=0, type=int) 153 | parser.add_argument('--resume', default='', help='resume from checkpoint') 154 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 155 | help='start epoch') 156 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 157 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 158 | parser.add_argument('--num_workers', default=10, type=int) 159 | parser.add_argument('--pin-mem', action='store_true', 160 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 161 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 162 | help='') 163 | parser.set_defaults(pin_mem=True) 164 | 165 | # distributed training parameters 166 | parser.add_argument('--world_size', default=1, type=int, 167 | help='number of distributed processes') 168 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 169 | return parser 170 | 171 | 172 | def main(args): 173 | utils.init_distributed_mode(args) 174 | 175 | print(args) 176 | 177 | if args.distillation_type != 'none' and args.finetune and not args.eval: 178 | raise NotImplementedError("Finetuning with distillation not yet supported") 179 | 180 | device = torch.device(args.device) 181 | 182 | # fix the seed for reproducibility 183 | seed = args.seed + utils.get_rank() 184 | torch.manual_seed(seed) 185 | np.random.seed(seed) 186 | # random.seed(seed) 187 | 188 | cudnn.benchmark = True 189 | 190 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 191 | dataset_val, _ = build_dataset(is_train=False, args=args) 192 | 193 | if True: # args.distributed: 194 | num_tasks = utils.get_world_size() 195 | global_rank = utils.get_rank() 196 | if args.repeated_aug: 197 | sampler_train = RASampler( 198 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 199 | ) 200 | else: 201 | sampler_train = torch.utils.data.DistributedSampler( 202 | dataset_train, 
num_replicas=num_tasks, rank=global_rank, shuffle=True 203 | ) 204 | if args.dist_eval: 205 | if len(dataset_val) % num_tasks != 0: 206 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 207 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 208 | 'equal num of samples per-process.') 209 | sampler_val = torch.utils.data.DistributedSampler( 210 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 211 | else: 212 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 213 | else: 214 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 215 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 216 | 217 | data_loader_train = torch.utils.data.DataLoader( 218 | dataset_train, sampler=sampler_train, 219 | batch_size=args.batch_size, 220 | num_workers=args.num_workers, 221 | pin_memory=args.pin_mem, 222 | drop_last=True, 223 | ) 224 | 225 | data_loader_val = torch.utils.data.DataLoader( 226 | dataset_val, sampler=sampler_val, 227 | batch_size=int(1.5 * args.batch_size), 228 | num_workers=args.num_workers, 229 | pin_memory=args.pin_mem, 230 | drop_last=False 231 | ) 232 | 233 | mixup_fn = None 234 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 235 | if mixup_active: 236 | mixup_fn = Mixup( 237 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 238 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 239 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 240 | 241 | print(f"Creating model: {args.model}") 242 | model = create_model( 243 | args.model, 244 | pretrained=False, 245 | num_classes=args.nb_classes, 246 | drop_rate=args.drop, 247 | drop_path_rate=args.drop_path, 248 | drop_block_rate=None, 249 | ) 250 | 251 | if args.finetune: 252 | if args.finetune.startswith('https'): 253 | checkpoint = torch.hub.load_state_dict_from_url( 254 | args.finetune, map_location='cpu', check_hash=True) 255 | else: 256 | checkpoint = torch.load(args.finetune, map_location='cpu') 257 | 258 | checkpoint_model = checkpoint['model'] 259 | state_dict = model.state_dict() 260 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 261 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 262 | print(f"Removing key {k} from pretrained checkpoint") 263 | del checkpoint_model[k] 264 | 265 | # interpolate position embedding 266 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 267 | embedding_size = pos_embed_checkpoint.shape[-1] 268 | num_patches = model.patch_embed.num_patches 269 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 270 | # height (== width) for the checkpoint position embedding 271 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 272 | # height (== width) for the new position embedding 273 | new_size = int(num_patches ** 0.5) 274 | # class_token and dist_token are kept unchanged 275 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 276 | # only the position tokens are interpolated 277 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 278 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 279 | pos_tokens = torch.nn.functional.interpolate( 280 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 281 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 282 | 
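# Illustrative shape walk-through for the interpolation above (the concrete numbers are an
# assumption, not taken from this repo): with a 16-pixel patch size, a checkpoint trained at
# 224x224 has orig_size = 14 (196 position tokens besides the extra class/dist tokens),
# while finetuning at 384x384 gives new_size = 24, so pos_tokens go
# [1, 196, D] -> [1, D, 14, 14] -> bicubic resize -> [1, D, 24, 24] -> [1, 576, D]
# before being concatenated with the untouched extra tokens below.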
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 283 | checkpoint_model['pos_embed'] = new_pos_embed 284 | 285 | model.load_state_dict(checkpoint_model, strict=False) 286 | 287 | model.to(device) 288 | 289 | model_ema = None 290 | if args.model_ema: 291 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 292 | model_ema = ModelEma( 293 | model, 294 | decay=args.model_ema_decay, 295 | device='cpu' if args.model_ema_force_cpu else '', 296 | resume='') 297 | 298 | model_without_ddp = model 299 | if args.distributed: 300 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 301 | model_without_ddp = model.module 302 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 303 | print('number of params:', n_parameters) 304 | 305 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 306 | args.lr = linear_scaled_lr 307 | optimizer = create_optimizer(args, model_without_ddp) 308 | loss_scaler = NativeScaler() 309 | 310 | lr_scheduler, _ = create_scheduler(args, optimizer) 311 | 312 | criterion = LabelSmoothingCrossEntropy() 313 | 314 | if mixup_active: 315 | # smoothing is handled with mixup label transform 316 | criterion = SoftTargetCrossEntropy() 317 | elif args.smoothing: 318 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 319 | else: 320 | criterion = torch.nn.CrossEntropyLoss() 321 | 322 | teacher_model = None 323 | if args.distillation_type != 'none': 324 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 325 | print(f"Creating teacher model: {args.teacher_model}") 326 | teacher_model = create_model( 327 | args.teacher_model, 328 | pretrained=False, 329 | num_classes=args.nb_classes, 330 | global_pool='avg', 331 | ) 332 | if args.teacher_path.startswith('https'): 333 | checkpoint = torch.hub.load_state_dict_from_url( 334 | args.teacher_path, map_location='cpu', check_hash=True) 335 | else: 336 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 337 | teacher_model.load_state_dict(checkpoint['model']) 338 | teacher_model.to(device) 339 | teacher_model.eval() 340 | 341 | # wrap the criterion in our custom DistillationLoss, which 342 | # just dispatches to the original criterion if args.distillation_type is 'none' 343 | criterion = DistillationLoss( 344 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 345 | ) 346 | 347 | output_dir = Path(args.output_dir) 348 | if args.resume: 349 | if args.resume.startswith('https'): 350 | checkpoint = torch.hub.load_state_dict_from_url( 351 | args.resume, map_location='cpu', check_hash=True) 352 | else: 353 | checkpoint = torch.load(args.resume, map_location='cpu') 354 | model_without_ddp.load_state_dict(checkpoint['model']) 355 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 356 | optimizer.load_state_dict(checkpoint['optimizer']) 357 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 358 | args.start_epoch = checkpoint['epoch'] + 1 359 | if args.model_ema: 360 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 361 | if 'scaler' in checkpoint: 362 | loss_scaler.load_state_dict(checkpoint['scaler']) 363 | 364 | if args.eval: 365 | test_stats = evaluate(data_loader_val, model, device) 366 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 367 | return 368 | 369 | print(f"Start 
training for {args.epochs} epochs") 370 | start_time = time.time() 371 | max_accuracy = 0.0 372 | for epoch in range(args.start_epoch, args.epochs): 373 | if args.distributed: 374 | data_loader_train.sampler.set_epoch(epoch) 375 | 376 | train_stats = train_one_epoch( 377 | model, criterion, data_loader_train, 378 | optimizer, device, epoch, loss_scaler, 379 | args.clip_grad, model_ema, mixup_fn, 380 | set_training_mode=args.finetune == '' # keep in eval mode during finetuning 381 | ) 382 | 383 | lr_scheduler.step(epoch) 384 | if args.output_dir: 385 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 386 | for checkpoint_path in checkpoint_paths: 387 | utils.save_on_master({ 388 | 'model': model_without_ddp.state_dict(), 389 | 'optimizer': optimizer.state_dict(), 390 | 'lr_scheduler': lr_scheduler.state_dict(), 391 | 'epoch': epoch, 392 | 'model_ema': get_state_dict(model_ema), 393 | 'scaler': loss_scaler.state_dict(), 394 | 'args': args, 395 | }, checkpoint_path) 396 | 397 | 398 | test_stats = evaluate(data_loader_val, model, device) 399 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 400 | 401 | if max_accuracy < test_stats["acc1"]: 402 | max_accuracy = test_stats["acc1"] 403 | if args.output_dir: 404 | checkpoint_paths = [output_dir / 'best_checkpoint.pth'] 405 | for checkpoint_path in checkpoint_paths: 406 | utils.save_on_master({ 407 | 'model': model_without_ddp.state_dict(), 408 | 'optimizer': optimizer.state_dict(), 409 | 'lr_scheduler': lr_scheduler.state_dict(), 410 | 'epoch': epoch, 411 | 'model_ema': get_state_dict(model_ema), 412 | 'scaler': loss_scaler.state_dict(), 413 | 'args': args, 414 | }, checkpoint_path) 415 | 416 | print(f'Max accuracy: {max_accuracy:.2f}%') 417 | 418 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 419 | **{f'test_{k}': v for k, v in test_stats.items()}, 420 | 'epoch': epoch, 421 | 'n_parameters': n_parameters} 422 | 423 | 424 | 425 | 426 | if args.output_dir and utils.is_main_process(): 427 | with (output_dir / "log.txt").open("a") as f: 428 | f.write(json.dumps(log_stats) + "\n") 429 | 430 | total_time = time.time() - start_time 431 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 432 | print('Training time {}'.format(total_time_str)) 433 | 434 | 435 | if __name__ == '__main__': 436 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 437 | args = parser.parse_args() 438 | if args.output_dir: 439 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 440 | main(args) 441 | -------------------------------------------------------------------------------- /src/samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed, 10 | with repeated augmentation. 
11 | It ensures that each augmented version of a sample is visible to a 12 | different process (GPU). 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Misc functions, including distributed helpers. 5 | 6 | Mostly copy-paste from torchvision references. 7 | """ 8 | import io 9 | import os 10 | import time 11 | from collections import defaultdict, deque 12 | import datetime 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | class SmoothedValue(object): 19 | """Track a series of values and provide access to smoothed values over a 20 | window or the global series average. 21 | """ 22 | 23 | def __init__(self, window_size=20, fmt=None): 24 | if fmt is None: 25 | fmt = "{median:.4f} ({global_avg:.4f})" 26 | self.deque = deque(maxlen=window_size) 27 | self.total = 0.0 28 | self.count = 0 29 | self.fmt = fmt 30 | 31 | def update(self, value, n=1): 32 | self.deque.append(value) 33 | self.count += n 34 | self.total += value * n 35 | 36 | def synchronize_between_processes(self): 37 | """ 38 | Warning: does not synchronize the deque! 
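Only count and total are all-reduced across processes, so global_avg becomes a global statistic; the window-based values (median, avg, max, value) stay local to each process.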
39 | """ 40 | if not is_dist_avail_and_initialized(): 41 | return 42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 43 | dist.barrier() 44 | dist.all_reduce(t) 45 | t = t.tolist() 46 | self.count = int(t[0]) 47 | self.total = t[1] 48 | 49 | @property 50 | def median(self): 51 | d = torch.tensor(list(self.deque)) 52 | return d.median().item() 53 | 54 | @property 55 | def avg(self): 56 | d = torch.tensor(list(self.deque), dtype=torch.float32) 57 | return d.mean().item() 58 | 59 | @property 60 | def global_avg(self): 61 | return self.total / self.count 62 | 63 | @property 64 | def max(self): 65 | return max(self.deque) 66 | 67 | @property 68 | def value(self): 69 | return self.deque[-1] 70 | 71 | def __str__(self): 72 | return self.fmt.format( 73 | median=self.median, 74 | avg=self.avg, 75 | global_avg=self.global_avg, 76 | max=self.max, 77 | value=self.value) 78 | 79 | 80 | class MetricLogger(object): 81 | def __init__(self, delimiter="\t"): 82 | self.meters = defaultdict(SmoothedValue) 83 | self.delimiter = delimiter 84 | 85 | def update(self, **kwargs): 86 | for k, v in kwargs.items(): 87 | if isinstance(v, torch.Tensor): 88 | v = v.item() 89 | assert isinstance(v, (float, int)) 90 | self.meters[k].update(v) 91 | 92 | def __getattr__(self, attr): 93 | if attr in self.meters: 94 | return self.meters[attr] 95 | if attr in self.__dict__: 96 | return self.__dict__[attr] 97 | raise AttributeError("'{}' object has no attribute '{}'".format( 98 | type(self).__name__, attr)) 99 | 100 | def __str__(self): 101 | loss_str = [] 102 | for name, meter in self.meters.items(): 103 | loss_str.append( 104 | "{}: {}".format(name, str(meter)) 105 | ) 106 | return self.delimiter.join(loss_str) 107 | 108 | def synchronize_between_processes(self): 109 | for meter in self.meters.values(): 110 | meter.synchronize_between_processes() 111 | 112 | def add_meter(self, name, meter): 113 | self.meters[name] = meter 114 | 115 | def log_every(self, iterable, print_freq, header=None): 116 | i = 0 117 | if not header: 118 | header = '' 119 | start_time = time.time() 120 | end = time.time() 121 | iter_time = SmoothedValue(fmt='{avg:.4f}') 122 | data_time = SmoothedValue(fmt='{avg:.4f}') 123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 124 | log_msg = [ 125 | header, 126 | '[{0' + space_fmt + '}/{1}]', 127 | 'eta: {eta}', 128 | '{meters}', 129 | 'time: {time}', 130 | 'data: {data}' 131 | ] 132 | if torch.cuda.is_available(): 133 | log_msg.append('max mem: {memory:.0f}') 134 | log_msg = self.delimiter.join(log_msg) 135 | MB = 1024.0 * 1024.0 136 | for obj in iterable: 137 | data_time.update(time.time() - end) 138 | yield obj 139 | iter_time.update(time.time() - end) 140 | if i % print_freq == 0 or i == len(iterable) - 1: 141 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 143 | if torch.cuda.is_available(): 144 | print(log_msg.format( 145 | i, len(iterable), eta=eta_string, 146 | meters=str(self), 147 | time=str(iter_time), data=str(data_time), 148 | memory=torch.cuda.max_memory_allocated() / MB)) 149 | else: 150 | print(log_msg.format( 151 | i, len(iterable), eta=eta_string, 152 | meters=str(self), 153 | time=str(iter_time), data=str(data_time))) 154 | i += 1 155 | end = time.time() 156 | total_time = time.time() - start_time 157 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 158 | print('{} Total time: {} ({:.4f} s / it)'.format( 159 | header, total_time_str, total_time / 
len(iterable))) 160 | 161 | 162 | def _load_checkpoint_for_ema(model_ema, checkpoint): 163 | """ 164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object 165 | """ 166 | mem_file = io.BytesIO() 167 | torch.save(checkpoint, mem_file) 168 | mem_file.seek(0) 169 | model_ema._load_checkpoint(mem_file) 170 | 171 | 172 | def setup_for_distributed(is_master): 173 | """ 174 | This function disables printing when not in master process 175 | """ 176 | import builtins as __builtin__ 177 | builtin_print = __builtin__.print 178 | 179 | def print(*args, **kwargs): 180 | force = kwargs.pop('force', False) 181 | if is_master or force: 182 | builtin_print(*args, **kwargs) 183 | 184 | __builtin__.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | 216 | def init_distributed_mode(args): 217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 218 | args.rank = int(os.environ["RANK"]) 219 | args.world_size = int(os.environ['WORLD_SIZE']) 220 | args.gpu = int(os.environ['LOCAL_RANK']) 221 | elif 'SLURM_PROCID' in os.environ: 222 | args.rank = int(os.environ['SLURM_PROCID']) 223 | args.gpu = args.rank % torch.cuda.device_count() 224 | else: 225 | print('Not using distributed mode') 226 | args.distributed = False 227 | return 228 | 229 | args.distributed = True 230 | 231 | torch.cuda.set_device(args.gpu) 232 | args.dist_backend = 'nccl' 233 | print('| distributed init (rank {}): {}'.format( 234 | args.rank, args.dist_url), flush=True) 235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 236 | world_size=args.world_size, rank=args.rank) 237 | torch.distributed.barrier() 238 | setup_for_distributed(args.rank == 0) 239 | --------------------------------------------------------------------------------
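A minimal usage sketch of the RASampler defined in src/samplers.py above (assuming src/ is on the import path so the module imports as samplers); num_replicas and rank are passed explicitly, so the snippet runs without initializing torch.distributed:

# Inspect the per-rank index streams produced by repeated augmentation.
from samplers import RASampler

dataset = list(range(512))              # stand-in dataset; the sampler only needs len()
for rank in range(2):                   # pretend world_size == 2
    sampler = RASampler(dataset, num_replicas=2, rank=rank, shuffle=False, num_repeats=3)
    # repeats of the same index alternate between the ranks; only the first
    # len(sampler) indices are consumed per epoch
    print(rank, len(sampler), list(sampler)[:6])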