├── LICENSE ├── README.md ├── datasets.py ├── engine.py ├── images ├── EdgeNext.png ├── Segmentation.png ├── madds_vs_top_1.png └── table_2.png ├── main.py ├── models ├── conv_encoder.py ├── edgenext.py ├── edgenext_bn_hs.py ├── layers.py ├── model.py └── sdta_encoder.py ├── optim_factory.py ├── requirements.txt ├── sampler.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Muhammad Maaz 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EdgeNeXt 2 | ### **EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications [CADL'22, ECCVW]** 3 | 4 | [Muhammad Maaz](https://scholar.google.com/citations?user=vTy9Te8AAAAJ&hl=en&authuser=1&oi=sra), 5 | [Abdelrahman Shaker](https://scholar.google.com/citations?hl=en&user=eEz4Wu4AAAAJ), 6 | [Hisham Cholakkal](https://scholar.google.com/citations?hl=en&user=bZ3YBRcAAAAJ), 7 | [Salman Khan](https://salman-h-khan.github.io), 8 | [Syed Waqas Zamir](https://www.waqaszamir.com), 9 | [Rao Muhammad Anwer](https://scholar.google.com/citations?hl=en&authuser=1&user=_KlvMVoAAAAJ) 10 | and [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en) 11 | 12 | [![Website](https://img.shields.io/badge/Project-Website-87CEEB)](https://mmaaz60.github.io/EdgeNeXt) 13 | [![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2206.10589) 14 | [![video](https://img.shields.io/badge/Video-Presentation-F9D371)](https://www.youtube.com/watch?v=Oh-ooHlx58o) 15 | [![slides](https://img.shields.io/badge/Presentation-Slides-B762C1)](https://mbzuaiac-my.sharepoint.com/:b:/g/personal/muhammad_maaz_mbzuai_ac_ae/EaFA4bSPEMBNlJuHMbKDD3UBHmwXrmpijSRqZITk2l1-wQ?e=b7ruLV) 16 | 17 | ## :rocket: News 18 | * **(Jul 26, 2023):** [SwiftFormer](https://github.com/Amshaker/SwiftFormer) is accepted at ICCV 2023 :fire::fire::fire:. 19 | * **(Mar 28, 2023):** [SwiftFormer](https://github.com/Amshaker/SwiftFormer) is released :fire::fire::fire:. Our small variant achieves 78.5% top-1 ImageNet-1K accuracy with only 0.8 ms latency on iPhone 14. 20 | * **(Aug 10, 2022):** EdgeNeXt-B ImageNet-21K pretrained model is released. 
It achieves 83.31% top-1 ImageNet-1K accuracy. The weights are available at [EdgeNeXt-B-IN21K](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth). 21 | * **(Oct 23, 2022):** EdgeNeXt is presented at [International Workshop on Computational Aspects of Deep Learning at ECCV 2022](https://ailb-web.ing.unimore.it/cadl2022) in a *full length oral presentation*. 22 | * **(Jul 28, 2022):** EdgeNeXt-B model is released. It achieves 82.5% top-1 ImageNet-1K accuracy with 18.51M parameters and 3.84G MAdds. 23 | Further, using USI (https://arxiv.org/abs/2204.03475) training recipe, the same model achieves 83.7% accuracy. 24 | 25 | * **(Jun 28, 2022):** EdgeNeXt-S model trained using USI (https://arxiv.org/abs/2204.03475) is released. 26 | It achieves 81.1% top-1 ImageNet-1K accuracy with only 5.59M parameters and 1.26G MAdds. 27 | 28 | * **(Jun 22, 2022):** Training and evaluation code along with pre-trained models are released. 29 | 30 |
31 | 32 | ![main figure](images/EdgeNext.png) 33 | > **Abstract:** *In the pursuit of achieving ever-increasing accuracy, large and complex neural networks are usually developed. Such models demand high computational resources and therefore cannot be deployed on edge devices. It is of great interest to build resource-efficient general purpose networks due to their usefulness in several application areas. In this work, we strive to effectively combine the strengths of both CNN and Transformer models and propose a new efficient hybrid architecture EdgeNeXt. Specifically in EdgeNeXt, we introduce split depth-wise transpose attention (SDTA) encoder that splits input tensors into multiple channel groups and utilizes depth-wise convolution along with self-attention across channel dimensions to implicitly increase the receptive field and encode multi-scale features. Our extensive experiments on classification, detection and segmentation tasks, reveal the merits of the proposed approach, outperforming state-of-the-art methods with comparatively lower compute requirements. Our EdgeNeXt model with 1.3M parameters achieves 71.2\% top-1 accuracy on ImageNet-1K, outperforming MobileViT with an absolute gain of 2.2\% with 28\% reduction in FLOPs. Further, our EdgeNeXt model with 5.6M parameters achieves 79.4\% top-1 accuracy on ImageNet-1K.* 34 |
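To make the SDTA idea from the abstract concrete, here is a minimal sketch of its two ingredients — channel-wise splitting with progressive depth-wise convolutions, and transposed (cross-covariance) attention across channel dimensions. This is our simplified illustration, not the repository's exact implementation; see `models/sdta_encoder.py` for the actual `SDTAEncoder`, which additionally includes positional encoding and an MLP sub-block:

```python
import torch
from torch import nn
import torch.nn.functional as F


class SDTASketch(nn.Module):
    """Simplified SDTA-style block: split channels into `scales` groups,
    refine them with progressive depth-wise convs, then attend across the
    channel dimension (C x C affinities instead of N x N)."""

    def __init__(self, dim, scales=4, num_heads=4):
        super().__init__()
        assert dim % scales == 0 and dim % num_heads == 0
        self.width = dim // scales
        # One depth-wise 3x3 conv per group except the first, which passes through.
        self.convs = nn.ModuleList(
            nn.Conv2d(self.width, self.width, 3, padding=1, groups=self.width)
            for _ in range(scales - 1))
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                       # x: (B, C, H, W)
        B, C, H, W = x.shape
        groups = torch.split(x, self.width, dim=1)
        out, prev = [groups[0]], groups[0]
        for conv, g in zip(self.convs, groups[1:]):
            prev = conv(g + prev)               # each group also sees the previous output
            out.append(prev)
        x = torch.cat(out, dim=1)               # implicitly enlarged receptive field

        t = x.flatten(2).transpose(1, 2)        # (B, N, C) with N = H * W
        q, k, v = self.qkv(t).chunk(3, dim=-1)

        def heads(z):                           # (B, N, C) -> (B, h, C/h, N)
            return z.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 3, 1)

        q, k, v = heads(q), heads(k), heads(v)
        attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)).softmax(-1)
        t = self.proj((attn @ v).permute(0, 3, 1, 2).reshape(B, -1, C))
        return x + t.transpose(1, 2).reshape(B, C, H, W)   # residual connection


y = SDTASketch(dim=64)(torch.randn(1, 64, 32, 32))  # -> torch.Size([1, 64, 32, 32])
```

Note that the attention cost here is linear in the number of tokens (the affinity matrix is C/h × C/h rather than N × N), which is what keeps the block affordable at mobile-friendly resolutions.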
35 | 36 | ## Model Zoo 37 | 38 | | Name |Acc@1 | #Params | MAdds | Model | 39 | |---|:---:|:---:| :---:|:---:| 40 | | edgenext_base_usi | 83.68 | 18.51M | 3.84G | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth) 41 | | edgenext_base_IN21K | 83.31 | 18.51M | 3.84G | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.21/edgenext_base_IN21K.pth) 42 | | edgenext_base | 82.47 | 18.51M | 3.84G | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base.pth) 43 | | edgenext_small_usi | 81.07 | 5.59M | 1.26G | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth) 44 | | edgenext_small | 79.41 | 5.59M | 1.26G | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth) 45 | | edgenext_x_small | 74.96 | 2.34M | 538M | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth) 46 | | edgenext_xx_small | 71.23 | 1.33M | 261M | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth) 47 | | edgenext_small_bn_hs | 78.39 | 5.58M | 1.25G | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small_bn_hs.pth) 48 | | edgenext_x_small_bn_hs | 74.87 | 2.34M | 536M | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small_bn_hs.pth) 49 | | edgenext_xx_small_bn_hs | 70.33 | 1.33M | 260M | [model](https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small_bn_hs.pth) 50 | 51 |
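Outside of `main.py`, a checkpoint from the table can be loaded through timm's registry, since importing `models.model` registers the `edgenext_*` variants. A rough sketch — the state-dict key varies per release (`state_dict` for USI checkpoints, `model_ema` or `model` otherwise, mirroring the logic in `main.py`), and `strict=False` hedges against minor key mismatches:

```python
import torch
from timm.models import create_model
from fvcore.nn import FlopCountAnalysis

import models.model  # noqa: F401 -- registers the edgenext_* models with timm

# classifier_dropout is a required kwarg of the EdgeNeXt constructor
model = create_model("edgenext_small", classifier_dropout=0.0)
ckpt = torch.load("edgenext_small.pth", map_location="cpu")
state = ckpt.get("state_dict") or ckpt.get("model_ema") or ckpt.get("model") or ckpt
model.load_state_dict(state, strict=False)
model.eval()

# Parameter count and MAdds, mirroring the fvcore-based report in main.py
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
madds = FlopCountAnalysis(model, torch.randn(1, 3, 256, 256)).total()
print(f"Params: {n_params / 1e6:.2f}M, MAdds: {madds / 1e6:.2f}M")
```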
52 | 53 | ## Comparison with SOTA ViTs and Hybrid Architectures 54 | ![results](images/madds_vs_top_1.png) 55 | 56 |
57 | 58 | ## Comparison with Previous SOTA [MobileViT (ICLR-2022)](https://arxiv.org/abs/2110.02178) 59 | ![results](images/table_2.png) 60 | 61 |
62 | 63 | ## Qualitative Results (Segmentation) 64 | ![results](images/Segmentation.png) 65 | 66 | ## Installation 67 | 1. Create conda environment 68 | ```shell 69 | conda create --name edgenext python=3.8 70 | conda activate edgenext 71 | ``` 72 | 2. Install PyTorch and torchvision 73 | ```shell 74 | pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 75 | ``` 76 | 3. Install other dependencies 77 | ```shell 78 | pip install -r requirements.txt 79 | ``` 80 | 81 |
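A quick, optional check that the environment is usable (package names as used throughout this repo; versions are not pinned here):

```python
import torch, torchvision, timm

print("torch:", torch.__version__, "| torchvision:", torchvision.__version__, "| timm:", timm.__version__)
print("CUDA available:", torch.cuda.is_available())
```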
82 | 83 | ## Dataset Preparation 84 | Download the [ImageNet-1K](http://image-net.org/) classification dataset and structure the data as follows: 85 | ``` 86 | /path/to/imagenet-1k/ 87 | train/ 88 | class1/ 89 | img1.jpeg 90 | class2/ 91 | img2.jpeg 92 | val/ 93 | class1/ 94 | img3.jpeg 95 | class2/ 96 | img4.jpeg 97 | ``` 98 | 99 |
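With the data arranged as above, torchvision's `ImageFolder` — which `datasets.py` builds on — should discover 1000 classes in each split. A quick sanity check, using the placeholder path from above:

```python
from torchvision import datasets

root = "/path/to/imagenet-1k"
train, val = datasets.ImageFolder(f"{root}/train"), datasets.ImageFolder(f"{root}/val")
print(len(train.classes), "train classes /", len(val.classes), "val classes")
assert len(train.classes) == len(val.classes) == 1000
```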
100 | 101 | ## Evaluation 102 | Download the pretrained weights and run the following command to evaluate on the ImageNet-1K dataset. 103 | 104 | ```shell 105 | wget https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth 106 | python main.py --model edgenext_small --eval True --batch_size 16 --data_path /path/to/imagenet-1k --output_dir /path/to/results --resume edgenext_small.pth 107 | ``` 108 | This should give: 109 | ```text 110 | Acc@1 79.412 Acc@5 94.512 loss 0.881 111 | ``` 112 | 113 | ##### Note: For evaluating the USI model, please pass `--usi_eval True`. 114 | 115 |
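For a quick spot check on a single image, the evaluation transform from `datasets.py` (resize by `1 / crop_pct = 256 / 224` with bicubic interpolation, center-crop to the 256-pixel input size, normalize with the ImageNet mean/std) can be reproduced in a few lines. The image path is a placeholder, and `model` is assumed to be loaded as in the Model Zoo sketch above:

```python
import torch
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

preprocess = transforms.Compose([
    transforms.Resize(292, interpolation=transforms.InterpolationMode.BICUBIC),  # int(256 / 0.875)
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    logits = model(img)  # `model` from the Model Zoo loading sketch
print("Predicted ImageNet class index:", logits.argmax(-1).item())
```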
116 | 117 | ## Training 118 | 119 | On a single machine with 8 GPUs, run the following command to train the EdgeNeXt-S model. 120 | 121 | ```shell 122 | python -m torch.distributed.launch --nproc_per_node=8 main.py \ 123 | --model edgenext_small --drop_path 0.1 \ 124 | --batch_size 256 --lr 6e-3 --update_freq 2 \ 125 | --model_ema true --model_ema_eval true \ 126 | --data_path /path/to/imagenet-1k
\ 127 | --output_dir /path/to/results \ 128 | --use_amp True --multi_scale_sampler 129 | ``` 130 |
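For reference, `main.py` computes the effective batch size as `batch_size × update_freq × world_size`; with the command above that is 256 × 2 × 8 = 4096, which is the total batch size the default learning rate of 6e-3 is tuned for (see the `--lr` help text in `main.py`).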
131 | 132 | ## Citation 133 | If you use our work, please consider citing: 134 | ```bibtex 135 | @inproceedings{Maaz2022EdgeNeXt, 136 | title={EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications}, 137 | author={Muhammad Maaz and Abdelrahman Shaker and Hisham Cholakkal and Salman Khan and Syed Waqas Zamir and Rao Muhammad Anwer and Fahad Shahbaz Khan}, 138 | booktitle={International Workshop on Computational Aspects of Deep Learning at 17th European Conference on Computer Vision (CADL2022)}, 139 | year={2022}, 140 | organization={Springer} 141 | } 142 | ``` 143 | 144 |
145 | 146 | ## Contact 147 | Should you have any questions, please create an issue on this repository or contact us at muhammad.maaz@mbzuai.ac.ae & abdelrahman.youssief@mbzuai.ac.ae 148 | 149 |
150 | 151 | ## References 152 | Our code is based on [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) repository. 153 | We thank them for releasing their code. 154 | 155 | ## Our Related Works 156 | - SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications. [Paper](https://arxiv.org/abs/2303.15446) | [Code](https://github.com/Amshaker/SwiftFormer). 157 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torchvision import datasets, transforms 3 | 4 | from timm.data.constants import \ 5 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 6 | from timm.data import create_transform 7 | from sampler import MultiScaleImageFolder 8 | 9 | 10 | def build_dataset(is_train, args): 11 | transform = build_transform(is_train, args) 12 | 13 | print("Transform = ") 14 | if isinstance(transform, tuple): 15 | for trans in transform: 16 | print(" - - - - - - - - - - ") 17 | for t in trans.transforms: 18 | print(t) 19 | else: 20 | for t in transform.transforms: 21 | print(t) 22 | print("---------------------------") 23 | 24 | if args.data_set == 'CIFAR': 25 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 26 | nb_classes = 100 27 | elif args.data_set == 'IMNET': 28 | print("reading from datapath", args.data_path) 29 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 30 | if is_train and args.multi_scale_sampler: 31 | dataset = MultiScaleImageFolder(root, args) 32 | else: 33 | dataset = datasets.ImageFolder(root, transform=transform) 34 | nb_classes = 1000 35 | elif args.data_set == "image_folder": 36 | root = args.data_path if is_train else args.eval_data_path 37 | dataset = datasets.ImageFolder(root, transform=transform) 38 | nb_classes = args.nb_classes 39 | assert len(dataset.class_to_idx) == nb_classes 40 | else: 41 | raise NotImplementedError() 42 | print("Number of the class = %d" % nb_classes) 43 | 44 | return dataset, nb_classes 45 | 46 | 47 | def build_transform(is_train, args): 48 | resize_im = args.input_size > 32 49 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 50 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 51 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 52 | 53 | if is_train: 54 | # This should always dispatch to transforms_imagenet_train 55 | transform = create_transform( 56 | input_size=args.input_size, 57 | is_training=True, 58 | color_jitter=args.color_jitter if args.color_jitter > 0 else None, 59 | auto_augment=args.aa, 60 | interpolation=args.train_interpolation, 61 | re_prob=args.reprob, 62 | re_mode=args.remode, 63 | re_count=args.recount, 64 | mean=mean, 65 | std=std, 66 | ) 67 | if args.three_aug: # --aa should not be "" to use this as it actually overrides the auto-augment 68 | print(f"Using 3-Augments instead of Rand Augment") 69 | cur_augs = transform.transforms 70 | three_aug = transforms.RandomChoice([transforms.Grayscale(num_output_channels=3), 71 | transforms.RandomSolarize(threshold=192.0), 72 | transforms.GaussianBlur(kernel_size=(5, 9))]) 73 | final_transforms = cur_augs[0:2] + [three_aug] + cur_augs[2:] 74 | transform = transforms.Compose(final_transforms) 75 | if not resize_im: 76 | transform.transforms[0] = transforms.RandomCrop( 77 | 
args.input_size, padding=4) 78 | return transform 79 | 80 | t = [] 81 | if resize_im: 82 | # Warping (no cropping) when evaluated at 384 or larger 83 | if args.input_size >= 384: 84 | t.append( 85 | transforms.Resize((args.input_size, args.input_size), 86 | interpolation=transforms.InterpolationMode.BICUBIC), 87 | ) 88 | print(f"Warping {args.input_size} size input images...") 89 | else: 90 | if args.crop_pct is None: 91 | args.crop_pct = 224 / 256 92 | size = int(args.input_size / args.crop_pct) 93 | t.append( 94 | # To maintain same ratio w.r.t. 224 images 95 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 96 | ) 97 | t.append(transforms.CenterCrop(args.input_size)) 98 | 99 | t.append(transforms.ToTensor()) 100 | t.append(transforms.Normalize(mean, std)) 101 | return transforms.Compose(t) 102 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterable, Optional 3 | import torch 4 | from timm.data import Mixup 5 | from timm.utils import accuracy, ModelEma 6 | 7 | import utils 8 | 9 | 10 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 11 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 12 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 13 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 14 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 15 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False): 16 | model.train(True) 17 | metric_logger = utils.MetricLogger(delimiter=" ") 18 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 19 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 20 | header = 'Epoch: [{}]'.format(epoch) 21 | print_freq = 10 22 | 23 | optimizer.zero_grad() 24 | 25 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 26 | step = data_iter_step // update_freq 27 | if step >= num_training_steps_per_epoch: 28 | continue 29 | it = start_steps + step # Global training iteration 30 | # Update LR & WD on the first micro-step of each gradient accumulation cycle 31 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0: 32 | for i, param_group in enumerate(optimizer.param_groups): 33 | if lr_schedule_values is not None: 34 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 35 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 36 | param_group["weight_decay"] = wd_schedule_values[it] 37 | 38 | samples = samples.to(device, non_blocking=True) 39 | targets = targets.to(device, non_blocking=True) 40 | 41 | if mixup_fn is not None: 42 | samples, targets = mixup_fn(samples, targets) 43 | 44 | if use_amp: 45 | with torch.cuda.amp.autocast(): 46 | output = model(samples) 47 | loss = criterion(output, targets) 48 | else: # Full precision 49 | output = model(samples) 50 | loss = criterion(output, targets) 51 | 52 | loss_value = loss.item() 53 | 54 | if not math.isfinite(loss_value): # This could trigger if using AMP 55 | print("Loss is {}, stopping training".format(loss_value)) 56 | assert math.isfinite(loss_value) 57 | 58 | if use_amp: 59 | # This attribute is added by timm on one optimizer (adahessian) 60 | is_second_order = hasattr(optimizer, 'is_second_order')
and optimizer.is_second_order 61 | loss /= update_freq 62 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 63 | parameters=model.parameters(), create_graph=is_second_order, 64 | update_grad=(data_iter_step + 1) % update_freq == 0) 65 | if (data_iter_step + 1) % update_freq == 0: 66 | optimizer.zero_grad() 67 | if model_ema is not None: 68 | model_ema.update(model) 69 | else: # Full precision 70 | loss /= update_freq 71 | loss.backward() 72 | if (data_iter_step + 1) % update_freq == 0: 73 | optimizer.step() 74 | optimizer.zero_grad() 75 | if model_ema is not None: 76 | model_ema.update(model) 77 | 78 | torch.cuda.synchronize() 79 | 80 | if mixup_fn is None: 81 | class_acc = (output.max(-1)[-1] == targets).float().mean() 82 | else: 83 | class_acc = None 84 | metric_logger.update(loss=loss_value) 85 | metric_logger.update(class_acc=class_acc) 86 | min_lr = 10. 87 | max_lr = 0. 88 | for group in optimizer.param_groups: 89 | min_lr = min(min_lr, group["lr"]) 90 | max_lr = max(max_lr, group["lr"]) 91 | 92 | metric_logger.update(lr=max_lr) 93 | metric_logger.update(min_lr=min_lr) 94 | weight_decay_value = None 95 | for group in optimizer.param_groups: 96 | if group["weight_decay"] > 0: 97 | weight_decay_value = group["weight_decay"] 98 | metric_logger.update(weight_decay=weight_decay_value) 99 | if use_amp: 100 | metric_logger.update(grad_norm=grad_norm) 101 | 102 | if log_writer is not None: 103 | log_writer.update(loss=loss_value, head="loss") 104 | log_writer.update(class_acc=class_acc, head="loss") 105 | log_writer.update(lr=max_lr, head="opt") 106 | log_writer.update(min_lr=min_lr, head="opt") 107 | log_writer.update(weight_decay=weight_decay_value, head="opt") 108 | if use_amp: 109 | log_writer.update(grad_norm=grad_norm, head="opt") 110 | log_writer.set_step() 111 | 112 | if wandb_logger: 113 | wandb_logger._wandb.log({ 114 | 'Rank-0 Batch Wise/train_loss': loss_value, 115 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 116 | 'Rank-0 Batch Wise/train_min_lr': min_lr 117 | }, commit=False) 118 | if class_acc: 119 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False) 120 | if use_amp: 121 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 122 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 123 | 124 | # Gather the stats from all processes 125 | metric_logger.synchronize_between_processes() 126 | print("Averaged stats:", metric_logger) 127 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 128 | 129 | 130 | @torch.no_grad() 131 | def evaluate(data_loader, model, device, use_amp=False): 132 | criterion = torch.nn.CrossEntropyLoss() 133 | 134 | metric_logger = utils.MetricLogger(delimiter=" ") 135 | header = 'Test:' 136 | 137 | # Switch to evaluation mode 138 | model.eval() 139 | for batch in metric_logger.log_every(data_loader, 10, header): 140 | images = batch[0] 141 | target = batch[-1] 142 | 143 | images = images.to(device, non_blocking=True) 144 | target = target.to(device, non_blocking=True) 145 | 146 | # Compute output 147 | if use_amp: 148 | with torch.cuda.amp.autocast(): 149 | output = model(images) 150 | loss = criterion(output, target) 151 | else: 152 | output = model(images) 153 | loss = criterion(output, target) 154 | 155 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 156 | 157 | batch_size = images.shape[0] 158 | metric_logger.update(loss=loss.item()) 159 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 160 | 
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 161 | # Gather the stats from all processes 162 | metric_logger.synchronize_between_processes() 163 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 164 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 165 | 166 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 167 | -------------------------------------------------------------------------------- /images/EdgeNext.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmaaz60/EdgeNeXt/388450404fdd6d6e097f1cf94f5afcbbdc352f08/images/EdgeNext.png -------------------------------------------------------------------------------- /images/Segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmaaz60/EdgeNeXt/388450404fdd6d6e097f1cf94f5afcbbdc352f08/images/Segmentation.png -------------------------------------------------------------------------------- /images/madds_vs_top_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmaaz60/EdgeNeXt/388450404fdd6d6e097f1cf94f5afcbbdc352f08/images/madds_vs_top_1.png -------------------------------------------------------------------------------- /images/table_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mmaaz60/EdgeNeXt/388450404fdd6d6e097f1cf94f5afcbbdc352f08/images/table_2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import numpy as np 4 | import time 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import json 8 | import os 9 | 10 | from pathlib import Path 11 | 12 | from timm.data.mixup import Mixup 13 | from timm.models import create_model 14 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 15 | from timm.utils import ModelEma 16 | from optim_factory import create_optimizer 17 | 18 | from datasets import build_dataset 19 | from engine import train_one_epoch, evaluate 20 | 21 | from utils import NativeScalerWithGradNormCount as NativeScaler 22 | import utils 23 | import models.model 24 | 25 | from sampler import MultiScaleSamplerDDP 26 | from fvcore.nn import FlopCountAnalysis 27 | 28 | 29 | def str2bool(v): 30 | """ 31 | Converts string to bool type; enables command line 32 | arguments in the format of '--arg1 true --arg2 false' 33 | """ 34 | if isinstance(v, bool): 35 | return v 36 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 37 | return True 38 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 39 | return False 40 | else: 41 | raise argparse.ArgumentTypeError('Boolean value expected.') 42 | 43 | 44 | def get_args_parser(): 45 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 46 | parser.add_argument('--batch_size', default=256, type=int, 47 | help='Per GPU batch size') 48 | parser.add_argument('--epochs', default=300, type=int) 49 | parser.add_argument('--update_freq', default=2, type=int, 50 | help='gradient accumulation steps') 51 | 52 | # Model parameters 53 | parser.add_argument('--model', default='edgenext_small', type=str, metavar='MODEL', 54 | 
help='Name of model to train') 55 | parser.add_argument('--drop_path', type=float, default=0.1, metavar='PCT', 56 | help='Drop path rate (default: 0.1)') 57 | parser.add_argument('--input_size', default=256, type=int, 58 | help='image input size') 59 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 60 | help="Layer scale initial values") 61 | 62 | # EMA related parameters 63 | parser.add_argument('--model_ema', type=str2bool, default=False) 64 | parser.add_argument('--model_ema_decay', type=float, default=0.9995, help='') # MobileViT also uses 0.9995 65 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 66 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 67 | 68 | # Optimization parameters 69 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', help='Optimizer (default: "adamw")') 70 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 71 | help='Optimizer Epsilon (default: 1e-8)') 72 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 73 | help='Optimizer Betas (default: None, use opt default)') 74 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 75 | help='Clip gradient norm (default: None, no clipping)') 76 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 77 | help='SGD momentum (default: 0.9)') 78 | parser.add_argument('--weight_decay', type=float, default=0.05, 79 | help='weight decay (default: 0.05)') 80 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 81 | weight decay. We use a cosine schedule for WD and using a larger decay by 82 | the end of training improves performance for ViTs.""") 83 | 84 | parser.add_argument('--lr', type=float, default=6e-3, metavar='LR', 85 | help='learning rate (default: 6e-3), with total batch size 4096') 86 | parser.add_argument('--layer_decay', type=float, default=1.0) 87 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 88 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 89 | parser.add_argument('--warmup_epochs', type=int, default=20, metavar='N', 90 | help='epochs to warmup LR, if scheduler supports') 91 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 92 | help='num of steps to warmup LR, will override warmup_epochs if set > 0') 93 | parser.add_argument('--warmup_start_lr', type=float, default=0, metavar='LR', 94 | help='Starting LR for warmup (default 0)') 95 | 96 | # Augmentation parameters 97 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 98 | help='Color jitter factor (default: 0.4)') 99 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 100 | help='Use AutoAugment policy. "v0" or "original"
" + "(default: rand-m9-mstd0.5-inc1)'), 101 | parser.add_argument('--smoothing', type=float, default=0.1, 102 | help='Label smoothing (default: 0.1)') 103 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 104 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 105 | 106 | # Evaluation parameters 107 | parser.add_argument('--crop_pct', type=float, default=None) 108 | 109 | # * Random Erase params 110 | parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT', 111 | help='Random erase prob (default: 0.0)') 112 | parser.add_argument('--remode', type=str, default='pixel', 113 | help='Random erase mode (default: "pixel")') 114 | parser.add_argument('--recount', type=int, default=1, 115 | help='Random erase count (default: 1)') 116 | parser.add_argument('--resplit', type=str2bool, default=False, 117 | help='Do not random erase first (clean) augmentation split') 118 | 119 | # Mixup params 120 | parser.add_argument('--mixup', type=float, default=0.0, 121 | help='mixup alpha, mixup enabled if > 0.') 122 | parser.add_argument('--cutmix', type=float, default=0.0, 123 | help='cutmix alpha, cutmix enabled if > 0.') 124 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 125 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 126 | parser.add_argument('--mixup_prob', type=float, default=0.0, 127 | help='Probability of performing mixup or cutmix when either/both is enabled') 128 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 129 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 130 | parser.add_argument('--mixup_mode', type=str, default='batch', 131 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 132 | 133 | # Dataset parameters 134 | parser.add_argument('--data_path', default='datasets/imagenet_full', type=str, 135 | help='dataset path (path to full imagenet)') 136 | parser.add_argument('--eval_data_path', default=None, type=str, 137 | help='dataset path for evaluation') 138 | parser.add_argument('--nb_classes', default=1000, type=int, 139 | help='number of the classification types') 140 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 141 | parser.add_argument('--data_set', default='IMNET', choices=['IMNET', 'image_folder'], 142 | type=str, help='ImageNet dataset path') 143 | parser.add_argument('--output_dir', default='', 144 | help='path where to save, empty for no saving') 145 | parser.add_argument('--log_dir', default=None, 146 | help='path where to tensorboard log') 147 | parser.add_argument('--device', default='cuda', 148 | help='device to use for training / testing') 149 | parser.add_argument('--seed', default=0, type=int) 150 | 151 | parser.add_argument('--resume', default='', 152 | help='resume from checkpoint') 153 | parser.add_argument('--finetune', default='', 154 | help='finetune the model') 155 | parser.add_argument('--auto_resume', type=str2bool, default=True) 156 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 157 | parser.add_argument('--save_ckpt_freq', default=1, type=int) 158 | parser.add_argument('--save_ckpt_num', default=3, type=int) 159 | 160 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 161 | help='start epoch') 162 | parser.add_argument('--eval', type=str2bool, default=False, 163 | help='Perform evaluation only') 164 | parser.add_argument('--dist_eval', type=str2bool, default=True, 165 | help='Enabling distributed evaluation') 166 | parser.add_argument('--disable_eval', type=str2bool, default=False, 167 | help='Disabling evaluation during training') 168 | parser.add_argument('--num_workers', default=10, type=int) 169 | parser.add_argument('--pin_mem', type=str2bool, default=True, 170 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 171 | 172 | # Distributed training parameters 173 | parser.add_argument('--world_size', default=1, type=int, 174 | help='number of distributed processes') 175 | parser.add_argument('--local_rank', default=-1, type=int) 176 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 177 | parser.add_argument('--dist_url', default='env://', 178 | help='url used to set up distributed training') 179 | 180 | parser.add_argument('--use_amp', type=str2bool, default=True, 181 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 182 | 183 | # Weights and Biases arguments 184 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 185 | help="enable logging to Weights and Biases") 186 | parser.add_argument('--project', default='edgenext', type=str, 187 | help="The name of the W&B project where you're sending the new run.") 188 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 189 | help="Save model checkpoints as W&B Artifacts.") 190 | parser.add_argument("--multi_scale_sampler", action="store_true", help="Either to use multi-scale sampler or not.") 191 | parser.add_argument('--min_crop_size_w', default=160, type=int) 192 | parser.add_argument('--max_crop_size_w', default=320, type=int) 193 | parser.add_argument('--min_crop_size_h', default=160, type=int) 194 | parser.add_argument('--max_crop_size_h', default=320, type=int) 195 | 
parser.add_argument("--find_unused_params", action="store_true", 196 | help="Set this flag to enable unused parameters finding in DistributedDataParallel()") 197 | parser.add_argument("--three_aug", action="store_true", 198 | help="Either to use three augments proposed by DeiT-III") 199 | parser.add_argument('--classifier_dropout', default=0.0, type=float) 200 | parser.add_argument('--usi_eval', type=str2bool, default=False, 201 | help="Enable it when testing USI model.") 202 | 203 | return parser 204 | 205 | 206 | def main(args): 207 | utils.init_distributed_mode(args) 208 | print(args) 209 | device = torch.device(args.device) 210 | 211 | # Eval/USI_eval configurations 212 | if args.eval: 213 | if args.usi_eval: 214 | args.crop_pct = 0.95 215 | model_state_dict_name = 'state_dict' 216 | else: 217 | model_state_dict_name = 'model_ema' 218 | else: 219 | model_state_dict_name = 'model' 220 | 221 | # Fix the seed for reproducibility 222 | seed = args.seed + utils.get_rank() 223 | torch.manual_seed(seed) 224 | np.random.seed(seed) 225 | cudnn.benchmark = True 226 | 227 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 228 | if args.disable_eval: 229 | args.dist_eval = False 230 | dataset_val = None 231 | else: 232 | dataset_val, _ = build_dataset(is_train=False, args=args) 233 | 234 | num_tasks = utils.get_world_size() 235 | global_rank = utils.get_rank() 236 | if args.multi_scale_sampler: 237 | sampler_train = MultiScaleSamplerDDP(base_im_w=args.input_size, base_im_h=args.input_size, 238 | base_batch_size=args.batch_size, n_data_samples=len(dataset_train), 239 | is_training=True, distributed=args.distributed, 240 | min_crop_size_w=args.min_crop_size_w, max_crop_size_w=args.max_crop_size_w, 241 | min_crop_size_h=args.min_crop_size_h, max_crop_size_h=args.max_crop_size_h) 242 | else: 243 | sampler_train = torch.utils.data.DistributedSampler( 244 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 245 | ) 246 | print("Sampler_train = %s" % str(sampler_train)) 247 | if args.dist_eval: 248 | if len(dataset_val) % num_tasks != 0: 249 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 250 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 251 | 'equal num of samples per-process.') 252 | sampler_val = torch.utils.data.DistributedSampler( 253 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 254 | else: 255 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 256 | 257 | if global_rank == 0 and args.log_dir is not None: 258 | os.makedirs(args.log_dir, exist_ok=True) 259 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 260 | else: 261 | log_writer = None 262 | 263 | if global_rank == 0 and args.enable_wandb: 264 | wandb_logger = utils.WandbLogger(args) 265 | else: 266 | wandb_logger = None 267 | 268 | if args.multi_scale_sampler: 269 | data_loader_train = torch.utils.data.DataLoader( 270 | dataset_train, batch_sampler=sampler_train, 271 | batch_size=1, 272 | num_workers=args.num_workers, 273 | pin_memory=args.pin_mem, 274 | ) 275 | else: 276 | data_loader_train = torch.utils.data.DataLoader( 277 | dataset_train, sampler=sampler_train, 278 | batch_size=args.batch_size, 279 | num_workers=args.num_workers, 280 | pin_memory=args.pin_mem, 281 | drop_last=True, 282 | ) 283 | 284 | if dataset_val is not None: 285 | data_loader_val = torch.utils.data.DataLoader( 286 | dataset_val, sampler=sampler_val, 287 | batch_size=int(1.5 * args.batch_size), 288 | num_workers=args.num_workers, 289 | pin_memory=args.pin_mem, 290 | drop_last=False 291 | ) 292 | else: 293 | data_loader_val = None 294 | 295 | mixup_fn = None 296 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 297 | if mixup_active: 298 | print("Mixup is activated!") 299 | mixup_fn = Mixup( 300 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 301 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 302 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 303 | 304 | model = create_model( 305 | args.model, 306 | pretrained=False, 307 | num_classes=args.nb_classes, 308 | drop_path_rate=args.drop_path, 309 | layer_scale_init_value=args.layer_scale_init_value, 310 | head_init_scale=1.0, 311 | input_res=args.input_size, 312 | classifier_dropout=args.classifier_dropout, 313 | ) 314 | if args.finetune: 315 | checkpoint = torch.load(args.finetune, map_location="cpu") 316 | state_dict = checkpoint[model_state_dict_name] 317 | utils.load_state_dict(model, state_dict) 318 | model.to(device) 319 | 320 | model_ema = None 321 | if args.model_ema: 322 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 323 | model_ema = ModelEma( 324 | model, 325 | decay=args.model_ema_decay, 326 | device='cpu' if args.model_ema_force_cpu else '', 327 | resume='') 328 | print("Using EMA with decay = %.8f" % args.model_ema_decay) 329 | 330 | model_without_ddp = model 331 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 332 | 333 | print("Model = %s" % str(model_without_ddp)) 334 | print('number of params:', n_parameters) 335 | 336 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 337 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 338 | print("LR = %.8f" % args.lr) 339 | print("Batch size = %d" % total_batch_size) 340 | print("Update frequent = %d" % args.update_freq) 341 | print("Number of training examples = %d" % len(dataset_train)) 342 | print("Number of training training per epoch = %d" % num_training_steps_per_epoch) 343 | 
344 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 345 | # Layer decay not supported 346 | raise NotImplementedError 347 | else: 348 | assigner = None 349 | 350 | if assigner is not None: 351 | print("Assigned values = %s" % str(assigner.values)) 352 | 353 | if args.distributed: 354 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], 355 | find_unused_parameters=args.find_unused_params) 356 | model_without_ddp = model.module 357 | 358 | optimizer = create_optimizer( 359 | args, model_without_ddp, skip_list=None, 360 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 361 | get_layer_scale=assigner.get_scale if assigner is not None else None) 362 | 363 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 364 | 365 | print("Use Cosine LR scheduler") 366 | lr_schedule_values = utils.cosine_scheduler( 367 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 368 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 369 | start_warmup_value=args.warmup_start_lr 370 | ) 371 | 372 | if args.weight_decay_end is None: 373 | args.weight_decay_end = args.weight_decay 374 | wd_schedule_values = utils.cosine_scheduler( 375 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 376 | print("Max WD = %.7f, Min WD = %.7f" % (max(wd_schedule_values), min(wd_schedule_values))) 377 | 378 | if mixup_fn is not None: 379 | # smoothing is handled with mixup label transform 380 | criterion = SoftTargetCrossEntropy() 381 | elif args.smoothing > 0.: 382 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 383 | else: 384 | criterion = torch.nn.CrossEntropyLoss() 385 | 386 | print("criterion = %s" % str(criterion)) 387 | 388 | utils.auto_load_model( 389 | args=args, model=model, model_without_ddp=model_without_ddp, 390 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema, state_dict_name=model_state_dict_name) 391 | 392 | if args.eval: 393 | print(f"Eval only mode") 394 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 395 | print(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 396 | return 397 | 398 | max_accuracy = 0.0 399 | if args.model_ema and args.model_ema_eval: 400 | max_accuracy_ema = 0.0 401 | 402 | def count_parameters(model): 403 | total_trainable_params = 0 404 | for name, parameter in model.named_parameters(): 405 | if not parameter.requires_grad: 406 | continue 407 | params = parameter.numel() 408 | total_trainable_params += params 409 | return total_trainable_params 410 | 411 | total_params = count_parameters(model) 412 | # fvcore to calculate MAdds 413 | input_res = (3, args.input_size, args.input_size) 414 | input = torch.ones(()).new_empty((1, *input_res), dtype=next(model.parameters()).dtype, 415 | device=next(model.parameters()).device) 416 | flops = FlopCountAnalysis(model, input) 417 | model_flops = flops.total() 418 | print(f"Total Trainable Params: {round(total_params * 1e-6, 2)} M") 419 | print(f"MAdds: {round(model_flops * 1e-6, 2)} M") 420 | 421 | print("Start training for %d epochs" % args.epochs) 422 | start_time = time.time() 423 | for epoch in range(args.start_epoch, args.epochs): 424 | if args.multi_scale_sampler: 425 | data_loader_train.batch_sampler.set_epoch(epoch) 426 | elif args.distributed: 427 | data_loader_train.sampler.set_epoch(epoch) 428 | if log_writer is not None: 429 | log_writer.set_step(epoch * num_training_steps_per_epoch * 
args.update_freq) 430 | if wandb_logger: 431 | wandb_logger.set_steps() 432 | train_stats = train_one_epoch( 433 | model, criterion, data_loader_train, optimizer, 434 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 435 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 436 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 437 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 438 | use_amp=args.use_amp 439 | ) 440 | if args.output_dir and args.save_ckpt: 441 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 442 | utils.save_model( 443 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 444 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 445 | if data_loader_val is not None: 446 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp) 447 | print(f"Accuracy of the model on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 448 | if max_accuracy < test_stats["acc1"]: 449 | max_accuracy = test_stats["acc1"] 450 | if args.output_dir and args.save_ckpt: 451 | utils.save_model( 452 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 453 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 454 | print(f'Max accuracy: {max_accuracy:.2f}%') 455 | 456 | if log_writer is not None: 457 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 458 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 459 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 460 | 461 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 462 | **{f'test_{k}': v for k, v in test_stats.items()}, 463 | 'epoch': epoch, 464 | 'n_parameters': n_parameters} 465 | 466 | # Repeat testing routines for EMA, if ema eval is turned on 467 | if args.model_ema and args.model_ema_eval: 468 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp) 469 | print(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 470 | if max_accuracy_ema < test_stats_ema["acc1"]: 471 | max_accuracy_ema = test_stats_ema["acc1"] 472 | if args.output_dir and args.save_ckpt: 473 | utils.save_model( 474 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 475 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 476 | print(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 477 | if log_writer is not None: 478 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 479 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 480 | else: 481 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 482 | 'epoch': epoch, 483 | 'n_parameters': n_parameters} 484 | 485 | if args.output_dir and utils.is_main_process(): 486 | if log_writer is not None: 487 | log_writer.flush() 488 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 489 | f.write(json.dumps(log_stats) + "\n") 490 | 491 | if wandb_logger: 492 | wandb_logger.log_epoch_metrics(log_stats) 493 | 494 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 495 | wandb_logger.log_checkpoints() 496 | 497 | total_time = time.time() - start_time 498 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 499 | print('Training 
time {}'.format(total_time_str)) 500 | 501 | 502 | if __name__ == '__main__': 503 | parser = argparse.ArgumentParser('EdgeNeXt training and evaluation script', parents=[get_args_parser()]) 504 | args = parser.parse_args() 505 | if args.output_dir: 506 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 507 | main(args) 508 | -------------------------------------------------------------------------------- /models/conv_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.layers import DropPath 4 | from .layers import LayerNorm 5 | 6 | 7 | class ConvEncoder(nn.Module): 8 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7): 9 | super().__init__() 10 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim) 11 | self.norm = LayerNorm(dim, eps=1e-6) 12 | self.pwconv1 = nn.Linear(dim, expan_ratio * dim) 13 | self.act = nn.GELU() 14 | self.pwconv2 = nn.Linear(expan_ratio * dim, dim) 15 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), 16 | requires_grad=True) if layer_scale_init_value > 0 else None 17 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 18 | 19 | def forward(self, x): 20 | input = x 21 | x = self.dwconv(x) 22 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 23 | x = self.norm(x) 24 | x = self.pwconv1(x) 25 | x = self.act(x) 26 | x = self.pwconv2(x) 27 | if self.gamma is not None: 28 | x = self.gamma * x 29 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 30 | 31 | x = input + self.drop_path(x) 32 | return x 33 | 34 | 35 | class ConvEncoderBNHS(nn.Module): 36 | """ 37 | Conv. Encoder with Batch Norm and Hard-Swish Activation 38 | """ 39 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, kernel_size=7): 40 | super().__init__() 41 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, bias=False) 42 | self.norm = nn.BatchNorm2d(dim) 43 | self.pwconv1 = nn.Linear(dim, expan_ratio * dim) 44 | self.act = nn.Hardswish() 45 | self.pwconv2 = nn.Linear(expan_ratio * dim, dim) 46 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), 47 | requires_grad=True) if layer_scale_init_value > 0 else None 48 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 49 | 50 | def forward(self, x): 51 | input = x 52 | x = self.dwconv(x) 53 | x = self.norm(x) 54 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 55 | x = self.pwconv1(x) 56 | x = self.act(x) 57 | x = self.pwconv2(x) 58 | if self.gamma is not None: 59 | x = self.gamma * x 60 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 61 | 62 | x = input + self.drop_path(x) 63 | return x 64 | -------------------------------------------------------------------------------- /models/edgenext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.layers import trunc_normal_ 4 | from .layers import LayerNorm, PositionalEncodingFourier 5 | from .sdta_encoder import SDTAEncoder 6 | from .conv_encoder import ConvEncoder 7 | 8 | 9 | class EdgeNeXt(nn.Module): 10 | def __init__(self, in_chans=3, num_classes=1000, 11 | depths=[3, 3, 9, 3], dims=[24, 48, 88, 168], 12 | global_block=[0, 0, 0, 3], global_block_type=['None', 'None', 'None', 'SDTA'], 13 | drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4, 14 | kernel_sizes=[7, 7, 7, 7], heads=[8, 8, 8, 8], use_pos_embd_xca=[False, False, False, False], 15 | use_pos_embd_global=False, d2_scales=[2, 3, 4, 5], **kwargs): 16 | super().__init__() 17 | for g in global_block_type: 18 | assert g in ['None', 'SDTA'] 19 | if use_pos_embd_global: 20 | self.pos_embd = PositionalEncodingFourier(dim=dims[0]) 21 | else: 22 | self.pos_embd = None 23 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 24 | stem = nn.Sequential( 25 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 26 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 27 | ) 28 | self.downsample_layers.append(stem) 29 | for i in range(3): 30 | downsample_layer = nn.Sequential( 31 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 32 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), 33 | ) 34 | self.downsample_layers.append(downsample_layer) 35 | 36 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 37 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 38 | cur = 0 39 | for i in range(4): 40 | stage_blocks = [] 41 | for j in range(depths[i]): 42 | if j > depths[i] - global_block[i] - 1: 43 | if global_block_type[i] == 'SDTA': 44 | stage_blocks.append(SDTAEncoder(dim=dims[i], drop_path=dp_rates[cur + j], 45 | expan_ratio=expan_ratio, scales=d2_scales[i], 46 | use_pos_emb=use_pos_embd_xca[i], num_heads=heads[i])) 47 | else: 48 | raise NotImplementedError 49 | else: 50 | stage_blocks.append(ConvEncoder(dim=dims[i], drop_path=dp_rates[cur + j], 51 | layer_scale_init_value=layer_scale_init_value, 52 | expan_ratio=expan_ratio, kernel_size=kernel_sizes[i])) 53 | 54 | self.stages.append(nn.Sequential(*stage_blocks)) 55 | cur += depths[i] 56 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # Final norm layer 57 | self.head = nn.Linear(dims[-1], num_classes) 58 | 59 | self.apply(self._init_weights) 60 | self.head_dropout = nn.Dropout(kwargs["classifier_dropout"]) 61 | self.head.weight.data.mul_(head_init_scale) 62 | self.head.bias.data.mul_(head_init_scale) 63 | 64 | def _init_weights(self, m): # TODO: MobileViT is using 'kaiming_normal' for initializing conv layers 65 | if isinstance(m, (nn.Conv2d, nn.Linear)): 66 | trunc_normal_(m.weight, std=.02) 67 | if m.bias is not None: 68 | nn.init.constant_(m.bias, 0) 
69 | elif isinstance(m, (LayerNorm, nn.LayerNorm)): 70 | nn.init.constant_(m.bias, 0) 71 | nn.init.constant_(m.weight, 1.0) 72 | 73 | def forward_features(self, x): 74 | x = self.downsample_layers[0](x) 75 | x = self.stages[0](x) 76 | if self.pos_embd: 77 | B, C, H, W = x.shape 78 | x = x + self.pos_embd(B, H, W) 79 | for i in range(1, 4): 80 | x = self.downsample_layers[i](x) 81 | x = self.stages[i](x) 82 | 83 | return self.norm(x.mean([-2, -1])) # Global average pooling, (N, C, H, W) -> (N, C) 84 | 85 | def forward(self, x): 86 | x = self.forward_features(x) 87 | x = self.head(self.head_dropout(x)) 88 | return x 89 | -------------------------------------------------------------------------------- /models/edgenext_bn_hs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.layers import trunc_normal_ 4 | from .layers import LayerNorm, PositionalEncodingFourier 5 | from .sdta_encoder import SDTAEncoderBNHS 6 | from .conv_encoder import ConvEncoderBNHS 7 | 8 | 9 | class EdgeNeXtBNHS(nn.Module): 10 | def __init__(self, in_chans=3, num_classes=1000, 11 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 12 | global_block=[0, 0, 0, 3], global_block_type=['None', 'None', 'None', 'SDTA_BN_HS'], 13 | drop_path_rate=0., layer_scale_init_value=1e-6, head_init_scale=1., expan_ratio=4, 14 | kernel_sizes=[7, 7, 7, 7], heads=[8, 8, 8, 8], use_pos_embd_xca=[False, False, False, False], 15 | use_pos_embd_global=False, d2_scales=[2, 3, 4, 5], **kwargs): 16 | super().__init__() 17 | for g in global_block_type: 18 | assert g in ['None', 'SDTA_BN_HS'] 19 | 20 | if use_pos_embd_global: 21 | self.pos_embd = PositionalEncodingFourier(dim=dims[0]) 22 | else: 23 | self.pos_embd = None 24 | 25 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 26 | stem = nn.Sequential( 27 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=False), 28 | nn.BatchNorm2d(dims[0]) 29 | ) 30 | self.downsample_layers.append(stem) 31 | for i in range(3): 32 | downsample_layer = nn.Sequential( 33 | nn.BatchNorm2d(dims[i]), 34 | nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2, bias=False), 35 | ) 36 | self.downsample_layers.append(downsample_layer) 37 | 38 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 39 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 40 | cur = 0 41 | for i in range(4): 42 | stage_blocks = [] 43 | for j in range(depths[i]): 44 | if j > depths[i] - global_block[i] - 1: 45 | if global_block_type[i] == 'SDTA_BN_HS': 46 | stage_blocks.append(SDTAEncoderBNHS(dim=dims[i], drop_path=dp_rates[cur + j], 47 | expan_ratio=expan_ratio, scales=d2_scales[i], 48 | use_pos_emb=use_pos_embd_xca[i], 49 | num_heads=heads[i])) 50 | else: 51 | raise NotImplementedError 52 | else: 53 | stage_blocks.append(ConvEncoderBNHS(dim=dims[i], drop_path=dp_rates[cur + j], 54 | layer_scale_init_value=layer_scale_init_value, 55 | expan_ratio=expan_ratio, kernel_size=kernel_sizes[i])) 56 | 57 | self.stages.append(nn.Sequential(*stage_blocks)) 58 | cur += depths[i] 59 | self.norm = nn.BatchNorm2d(dims[-1]) 60 | self.head = nn.Linear(dims[-1], num_classes) 61 | 62 | self.apply(self._init_weights) 63 | self.head_dropout = nn.Dropout(kwargs["classifier_dropout"]) 64 | self.head.weight.data.mul_(head_init_scale) 65 | self.head.bias.data.mul_(head_init_scale) 66 | 67 | def _init_weights(self, m): # TODO: MobileViT is using 
'kaiming_normal' for initializing conv layers 68 | if isinstance(m, (nn.Conv2d, nn.Linear)): 69 | trunc_normal_(m.weight, std=.02) 70 | if m.bias is not None: 71 | nn.init.constant_(m.bias, 0) 72 | elif isinstance(m, (LayerNorm, nn.LayerNorm)): 73 | nn.init.constant_(m.bias, 0) 74 | nn.init.constant_(m.weight, 1.0) 75 | 76 | def forward_features(self, x): 77 | x = self.downsample_layers[0](x) 78 | x = self.stages[0](x) 79 | if self.pos_embd: 80 | B, C, H, W = x.shape 81 | x = x + self.pos_embd(B, H, W) 82 | for i in range(1, 4): 83 | x = self.downsample_layers[i](x) 84 | x = self.stages[i](x) 85 | return self.norm(x).mean([-2, -1]) 86 | 87 | def forward(self, x): 88 | x = self.forward_features(x) 89 | x = self.head(self.head_dropout(x)) 90 | return x 91 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | import math 5 | 6 | 7 | class LayerNorm(nn.Module): 8 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 9 | super().__init__() 10 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 11 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 12 | self.eps = eps 13 | self.data_format = data_format 14 | if self.data_format not in ["channels_last", "channels_first"]: 15 | raise NotImplementedError 16 | self.normalized_shape = (normalized_shape,) 17 | 18 | def forward(self, x): 19 | if self.data_format == "channels_last": 20 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 21 | elif self.data_format == "channels_first": 22 | u = x.mean(1, keepdim=True) 23 | s = (x - u).pow(2).mean(1, keepdim=True) 24 | x = (x - u) / torch.sqrt(s + self.eps) 25 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 26 | return x 27 | 28 | 29 | class PositionalEncodingFourier(nn.Module): 30 | def __init__(self, hidden_dim=32, dim=768, temperature=10000): 31 | super().__init__() 32 | self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) 33 | self.scale = 2 * math.pi 34 | self.temperature = temperature 35 | self.hidden_dim = hidden_dim 36 | self.dim = dim 37 | 38 | def forward(self, B, H, W): 39 | mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device) 40 | not_mask = ~mask 41 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 42 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 43 | eps = 1e-6 44 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 45 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 46 | 47 | dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device) 48 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim) 49 | 50 | pos_x = x_embed[:, :, :, None] / dim_t 51 | pos_y = y_embed[:, :, :, None] / dim_t 52 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), 53 | pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 54 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), 55 | pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 56 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 57 | pos = self.token_projection(pos) 58 | 59 | return pos 60 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from .edgenext import EdgeNeXt 2 | from .edgenext_bn_hs import EdgeNeXtBNHS 3 | from timm.models.registry 
import register_model 4 | 5 | """ 6 | -- Main Models 7 | XX-Small -> 1.3M 8 | X-Small -> 2.3M 9 | Small -> 5.6M 10 | """ 11 | 12 | 13 | @register_model 14 | def edgenext_xx_small(pretrained=False, **kwargs): 15 | # 1.33M & 260.58M @ 256 resolution 16 | # 71.23% Top-1 accuracy 17 | # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler 18 | # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS 19 | # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS 20 | model = EdgeNeXt(depths=[2, 2, 6, 2], dims=[24, 48, 88, 168], expan_ratio=4, 21 | global_block=[0, 1, 1, 1], 22 | global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], 23 | use_pos_embd_xca=[False, True, False, False], 24 | kernel_sizes=[3, 5, 7, 9], 25 | heads=[4, 4, 4, 4], 26 | d2_scales=[2, 2, 3, 4], 27 | **kwargs) 28 | 29 | return model 30 | 31 | 32 | @register_model 33 | def edgenext_x_small(pretrained=False, **kwargs): 34 | # 2.34M & 538.0M @ 256 resolution 35 | # 75.00% Top-1 accuracy 36 | # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler 37 | # Jetson FPS=31.61 versus 28.49 for MobileViT_XS 38 | # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS 39 | model = EdgeNeXt(depths=[3, 3, 9, 3], dims=[32, 64, 100, 192], expan_ratio=4, 40 | global_block=[0, 1, 1, 1], 41 | global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], 42 | use_pos_embd_xca=[False, True, False, False], 43 | kernel_sizes=[3, 5, 7, 9], 44 | heads=[4, 4, 4, 4], 45 | d2_scales=[2, 2, 3, 4], 46 | **kwargs) 47 | 48 | return model 49 | 50 | 51 | @register_model 52 | def edgenext_small(pretrained=False, **kwargs): 53 | # 5.59M & 1260.59M @ 256 resolution 54 | # 79.43% Top-1 accuracy 55 | # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler 56 | # Jetson FPS=20.47 versus 18.86 for MobileViT_S 57 | # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S 58 | model = EdgeNeXt(depths=[3, 3, 9, 3], dims=[48, 96, 160, 304], expan_ratio=4, 59 | global_block=[0, 1, 1, 1], 60 | global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], 61 | use_pos_embd_xca=[False, True, False, False], 62 | kernel_sizes=[3, 5, 7, 9], 63 | d2_scales=[2, 2, 3, 4], 64 | **kwargs) 65 | 66 | return model 67 | 68 | 69 | @register_model 70 | def edgenext_base(pretrained=False, **kwargs): 71 | # 18.51M & 3840.93M @ 256 resolution 72 | # 82.5% (normal) 83.7% (USI) Top-1 accuracy 73 | # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler 74 | # Jetson FPS=xx.xx versus xx.xx for MobileViT_S 75 | # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx 76 | model = EdgeNeXt(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], expan_ratio=4, 77 | global_block=[0, 1, 1, 1], 78 | global_block_type=['None', 'SDTA', 'SDTA', 'SDTA'], 79 | use_pos_embd_xca=[False, True, False, False], 80 | kernel_sizes=[3, 5, 7, 9], 81 | d2_scales=[2, 2, 3, 4], 82 | **kwargs) 83 | 84 | return model 85 | 86 | 87 | """ 88 | Using BN & HSwish instead of LN & GeLU 89 | """ 90 | 91 | 92 | @register_model 93 | def edgenext_xx_small_bn_hs(pretrained=False, **kwargs): 94 | # 1.33M & 259.53M @ 256 resolution 95 | # 70.33% Top-1 accuracy 96 | # For A100: FPS @ BS=1: 219.66 & @ BS=256: 10359.98 97 | model = EdgeNeXtBNHS(depths=[2, 2, 6, 2], dims=[24, 48, 88, 168], expan_ratio=4, 98 | global_block=[0, 1, 1, 1], 99 | global_block_type=['None', 'SDTA_BN_HS', 'SDTA_BN_HS', 
'SDTA_BN_HS'], 100 | use_pos_embd_xca=[False, True, False, False], 101 | kernel_sizes=[3, 5, 7, 9], 102 | heads=[4, 4, 4, 4], 103 | d2_scales=[2, 2, 3, 4], 104 | **kwargs) 105 | 106 | return model 107 | 108 | 109 | @register_model 110 | def edgenext_x_small_bn_hs(pretrained=False, **kwargs): 111 | # 2.34M & 535.84M @ 256 resolution 112 | # 74.87% Top-1 accuracy 113 | # For A100: FPS @ BS=1: 179.25 & @ BS=256: 6059.59 114 | model = EdgeNeXtBNHS(depths=[3, 3, 9, 3], dims=[32, 64, 100, 192], expan_ratio=4, 115 | global_block=[0, 1, 1, 1], 116 | global_block_type=['None', 'SDTA_BN_HS', 'SDTA_BN_HS', 'SDTA_BN_HS'], 117 | use_pos_embd_xca=[False, True, False, False], 118 | kernel_sizes=[3, 5, 7, 9], 119 | heads=[4, 4, 4, 4], 120 | d2_scales=[2, 2, 3, 4], 121 | **kwargs) 122 | 123 | return model 124 | 125 | 126 | @register_model 127 | def edgenext_small_bn_hs(pretrained=False, **kwargs): 128 | # 5.58M & 1257.28M @ 256 resolution 129 | # 78.39% Top-1 accuracy 130 | # For A100: FPS @ BS=1: 174.68 & @ BS=256: 3808.19 131 | model = EdgeNeXtBNHS(depths=[3, 3, 9, 3], dims=[48, 96, 160, 304], expan_ratio=4, 132 | global_block=[0, 1, 1, 1], 133 | global_block_type=['None', 'SDTA_BN_HS', 'SDTA_BN_HS', 'SDTA_BN_HS'], 134 | use_pos_embd_xca=[False, True, False, False], 135 | kernel_sizes=[3, 5, 7, 9], 136 | d2_scales=[2, 2, 3, 4], 137 | **kwargs) 138 | 139 | return model 140 | -------------------------------------------------------------------------------- /models/sdta_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.layers import DropPath 4 | from .layers import LayerNorm, PositionalEncodingFourier 5 | import math 6 | 7 | 8 | class SDTAEncoder(nn.Module): 9 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, 10 | use_pos_emb=True, num_heads=8, qkv_bias=True, attn_drop=0., drop=0., scales=1): 11 | super().__init__() 12 | width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales))) 13 | self.width = width 14 | if scales == 1: 15 | self.nums = 1 16 | else: 17 | self.nums = scales - 1 18 | convs = [] 19 | for i in range(self.nums): 20 | convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, groups=width)) 21 | self.convs = nn.ModuleList(convs) 22 | 23 | self.pos_embd = None 24 | if use_pos_emb: 25 | self.pos_embd = PositionalEncodingFourier(dim=dim) 26 | self.norm_xca = LayerNorm(dim, eps=1e-6) 27 | self.gamma_xca = nn.Parameter(layer_scale_init_value * torch.ones(dim), 28 | requires_grad=True) if layer_scale_init_value > 0 else None 29 | self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 30 | 31 | self.norm = LayerNorm(dim, eps=1e-6) 32 | self.pwconv1 = nn.Linear(dim, expan_ratio * dim) # pointwise/1x1 convs, implemented with linear layers 33 | self.act = nn.GELU() # TODO: MobileViT is using 'swish' 34 | self.pwconv2 = nn.Linear(expan_ratio * dim, dim) 35 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 36 | requires_grad=True) if layer_scale_init_value > 0 else None 37 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 38 | 39 | def forward(self, x): 40 | input = x 41 | 42 | spx = torch.split(x, self.width, 1) 43 | for i in range(self.nums): 44 | if i == 0: 45 | sp = spx[i] 46 | else: 47 | sp = sp + spx[i] 48 | sp = self.convs[i](sp) 49 | if i == 0: 50 | out = sp 51 | else: 52 | out = torch.cat((out, sp), 1) 53 | x = torch.cat((out, spx[self.nums]), 1) 54 | # XCA 55 | B, C, H, W = x.shape 56 | x = x.reshape(B, C, H * W).permute(0, 2, 1) 57 | if self.pos_embd: 58 | pos_encoding = self.pos_embd(B, H, W).reshape(B, -1, x.shape[1]).permute(0, 2, 1) 59 | x = x + pos_encoding 60 | x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x))) 61 | x = x.reshape(B, H, W, C) 62 | 63 | # Inverted Bottleneck 64 | x = self.norm(x) 65 | x = self.pwconv1(x) 66 | x = self.act(x) 67 | x = self.pwconv2(x) 68 | if self.gamma is not None: 69 | x = self.gamma * x 70 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 71 | 72 | x = input + self.drop_path(x) 73 | 74 | return x 75 | 76 | 77 | class SDTAEncoderBNHS(nn.Module): 78 | """ 79 | SDTA Encoder with Batch Norm and Hard-Swish Activation 80 | """ 81 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, expan_ratio=4, 82 | use_pos_emb=True, num_heads=8, qkv_bias=True, attn_drop=0., drop=0., scales=1): 83 | super().__init__() 84 | width = max(int(math.ceil(dim / scales)), int(math.floor(dim // scales))) 85 | self.width = width 86 | if scales == 1: 87 | self.nums = 1 88 | else: 89 | self.nums = scales - 1 90 | convs = [] 91 | for i in range(self.nums): 92 | convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, groups=width)) 93 | self.convs = nn.ModuleList(convs) 94 | 95 | self.pos_embd = None 96 | if use_pos_emb: 97 | self.pos_embd = PositionalEncodingFourier(dim=dim) 98 | self.norm_xca = nn.BatchNorm2d(dim) 99 | self.gamma_xca = nn.Parameter(layer_scale_init_value * torch.ones(dim), 100 | requires_grad=True) if layer_scale_init_value > 0 else None 101 | self.xca = XCA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 102 | 103 | self.norm = nn.BatchNorm2d(dim) 104 | self.pwconv1 = nn.Linear(dim, expan_ratio * dim) # pointwise/1x1 convs, implemented with linear layers 105 | self.act = nn.Hardswish() # TODO: MobileViT is using 'swish' 106 | self.pwconv2 = nn.Linear(expan_ratio * dim, dim) 107 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 108 | requires_grad=True) if layer_scale_init_value > 0 else None 109 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 110 | 111 | def forward(self, x): 112 | input = x 113 | 114 | spx = torch.split(x, self.width, 1) 115 | for i in range(self.nums): 116 | if i == 0: 117 | sp = spx[i] 118 | else: 119 | sp = sp + spx[i] 120 | sp = self.convs[i](sp) 121 | if i == 0: 122 | out = sp 123 | else: 124 | out = torch.cat((out, sp), 1) 125 | x = torch.cat((out, spx[self.nums]), 1) 126 | # XCA 127 | x = self.norm_xca(x) 128 | B, C, H, W = x.shape 129 | x = x.reshape(B, C, H * W).permute(0, 2, 1) 130 | if self.pos_embd: 131 | pos_encoding = self.pos_embd(B, H, W).reshape(B, -1, x.shape[1]).permute(0, 2, 1) 132 | x = x + pos_encoding 133 | x = x + self.drop_path(self.gamma_xca * self.xca(x)) 134 | x = x.reshape(B, H, W, C).permute(0, 3, 1, 2) 135 | 136 | # Inverted Bottleneck 137 | x = self.norm(x) 138 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 139 | x = self.pwconv1(x) 140 | x = self.act(x) 141 | x = self.pwconv2(x) 142 | if self.gamma is not None: 143 | x = self.gamma * x 144 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 145 | 146 | x = input + self.drop_path(x) 147 | 148 | return x 149 | 150 | 151 | class XCA(nn.Module): 152 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 153 | super().__init__() 154 | self.num_heads = num_heads 155 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 156 | 157 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 158 | self.attn_drop = nn.Dropout(attn_drop) 159 | self.proj = nn.Linear(dim, dim) 160 | self.proj_drop = nn.Dropout(proj_drop) 161 | 162 | def forward(self, x): 163 | B, N, C = x.shape 164 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 165 | qkv = qkv.permute(2, 0, 3, 1, 4) 166 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 167 | 168 | q = q.transpose(-2, -1) 169 | k = k.transpose(-2, -1) 170 | v = v.transpose(-2, -1) 171 | 172 | q = torch.nn.functional.normalize(q, dim=-1) 173 | k = torch.nn.functional.normalize(k, dim=-1) 174 | 175 | attn = (q @ k.transpose(-2, -1)) * self.temperature 176 | # ------------------- 177 | attn = attn.softmax(dim=-1) 178 | attn = self.attn_drop(attn) 179 | 180 | x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) 181 | # ------------------ 182 | x = self.proj(x) 183 | x = self.proj_drop(x) 184 | 185 | return x 186 | 187 | @torch.jit.ignore 188 | def no_weight_decay(self): 189 | return {'temperature'} 190 | -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | from timm.optim.adafactor import Adafactor 4 | from timm.optim.adahessian import Adahessian 5 | from timm.optim.adamp import AdamP 6 | from timm.optim.lookahead import Lookahead 7 | from timm.optim.nadam import Nadam 8 | from timm.optim.novograd import NovoGrad 9 | from timm.optim.nvnovograd import NvNovoGrad 10 | from timm.optim.radam import RAdam 11 | from timm.optim.rmsprop_tf import RMSpropTF 12 | from timm.optim.sgdp import SGDP 13 | 14 | import json 15 | 16 | try: 17 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 18 | 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_convnext(var_name): 25 | """ 26 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 27 | consecutive blocks, including possible neighboring downsample layers; 28 | adapted from 
https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 29 | """ 30 | num_max_layer = 12 31 | if var_name.startswith("downsample_layers"): 32 | stage_id = int(var_name.split('.')[1]) 33 | if stage_id == 0: 34 | layer_id = 0 35 | elif stage_id == 1 or stage_id == 2: 36 | layer_id = stage_id + 1 37 | elif stage_id == 3: 38 | layer_id = 12 39 | return layer_id 40 | 41 | elif var_name.startswith("stages"): 42 | stage_id = int(var_name.split('.')[1]) 43 | block_id = int(var_name.split('.')[2]) 44 | if stage_id == 0 or stage_id == 1: 45 | layer_id = stage_id + 1 46 | elif stage_id == 2: 47 | layer_id = 3 + block_id // 3 48 | elif stage_id == 3: 49 | layer_id = 12 50 | return layer_id 51 | else: 52 | return num_max_layer + 1 53 | 54 | 55 | class LayerDecayValueAssigner(object): 56 | def __init__(self, values): 57 | self.values = values 58 | 59 | def get_scale(self, layer_id): 60 | return self.values[layer_id] 61 | 62 | def get_layer_id(self, var_name): 63 | return get_num_layer_for_convnext(var_name) 64 | 65 | 66 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 67 | parameter_group_names = {} 68 | parameter_group_vars = {} 69 | 70 | for name, param in model.named_parameters(): 71 | if not param.requires_grad: 72 | continue # frozen weights 73 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 74 | group_name = "no_decay" 75 | this_weight_decay = 0. 76 | else: 77 | group_name = "decay" 78 | this_weight_decay = weight_decay 79 | if get_num_layer is not None: 80 | layer_id = get_num_layer(name) 81 | group_name = "layer_%d_%s" % (layer_id, group_name) 82 | else: 83 | layer_id = None 84 | 85 | if group_name not in parameter_group_names: 86 | if get_layer_scale is not None: 87 | scale = get_layer_scale(layer_id) 88 | else: 89 | scale = 1. 90 | 91 | parameter_group_names[group_name] = { 92 | "weight_decay": this_weight_decay, 93 | "params": [], 94 | "lr_scale": scale 95 | } 96 | parameter_group_vars[group_name] = { 97 | "weight_decay": this_weight_decay, 98 | "params": [], 99 | "lr_scale": scale 100 | } 101 | 102 | parameter_group_vars[group_name]["params"].append(param) 103 | parameter_group_names[group_name]["params"].append(name) 104 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 105 | return list(parameter_group_vars.values()) 106 | 107 | 108 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 109 | opt_lower = args.opt.lower() 110 | weight_decay = args.weight_decay 111 | # if weight_decay and filter_bias_and_bn: 112 | if filter_bias_and_bn: 113 | skip = {} 114 | if skip_list is not None: 115 | skip = skip_list 116 | elif hasattr(model, 'no_weight_decay'): 117 | skip = model.no_weight_decay() 118 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 119 | weight_decay = 0. 
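        # weight decay is already assigned per parameter group in get_parameter_groups above,
        # so the optimizer-level value is zeroed here to avoid applying decay a second time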
120 |     else:
121 |         parameters = model.parameters()
122 | 
123 |     if 'fused' in opt_lower:
124 |         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
125 | 
126 |     opt_args = dict(lr=args.lr, weight_decay=weight_decay)
127 |     if hasattr(args, 'opt_eps') and args.opt_eps is not None:
128 |         opt_args['eps'] = args.opt_eps
129 |     if hasattr(args, 'opt_betas') and args.opt_betas is not None:
130 |         opt_args['betas'] = args.opt_betas
131 | 
132 |     opt_split = opt_lower.split('_')
133 |     opt_lower = opt_split[-1]
134 |     if opt_lower == 'sgd' or opt_lower == 'nesterov':
135 |         opt_args.pop('eps', None)
136 |         optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
137 |     elif opt_lower == 'momentum':
138 |         opt_args.pop('eps', None)
139 |         optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
140 |     elif opt_lower == 'adam':
141 |         optimizer = optim.Adam(parameters, **opt_args)
142 |     elif opt_lower == 'adamw':
143 |         optimizer = optim.AdamW(parameters, **opt_args)
144 |     elif opt_lower == 'nadam':
145 |         optimizer = Nadam(parameters, **opt_args)
146 |     elif opt_lower == 'radam':
147 |         optimizer = RAdam(parameters, **opt_args)
148 |     elif opt_lower == 'adamp':
149 |         optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
150 |     elif opt_lower == 'sgdp':
151 |         optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
152 |     elif opt_lower == 'adadelta':
153 |         optimizer = optim.Adadelta(parameters, **opt_args)
154 |     elif opt_lower == 'adafactor':
155 |         if not args.lr:
156 |             opt_args['lr'] = None
157 |         optimizer = Adafactor(parameters, **opt_args)
158 |     elif opt_lower == 'adahessian':
159 |         optimizer = Adahessian(parameters, **opt_args)
160 |     elif opt_lower == 'rmsprop':
161 |         optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
162 |     elif opt_lower == 'rmsproptf':
163 |         optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
164 |     elif opt_lower == 'novograd':
165 |         optimizer = NovoGrad(parameters, **opt_args)
166 |     elif opt_lower == 'nvnovograd':
167 |         optimizer = NvNovoGrad(parameters, **opt_args)
168 |     elif opt_lower == 'fusedsgd':
169 |         opt_args.pop('eps', None)
170 |         optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
171 |     elif opt_lower == 'fusedmomentum':
172 |         opt_args.pop('eps', None)
173 |         optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
174 |     elif opt_lower == 'fusedadam':
175 |         optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
176 |     elif opt_lower == 'fusedadamw':
177 |         optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
178 |     elif opt_lower == 'fusedlamb':
179 |         optimizer = FusedLAMB(parameters, **opt_args)
180 |     elif opt_lower == 'fusednovograd':
181 |         opt_args.setdefault('betas', (0.95, 0.98))
182 |         optimizer = FusedNovoGrad(parameters, **opt_args)
183 |     else:
184 |         assert False, "Invalid optimizer"
185 | 
186 |     if len(opt_split) > 1:
187 |         if opt_split[0] == 'lookahead':
188 |             optimizer = Lookahead(optimizer)
189 | 
190 |     return optimizer
191 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.4.12
2 | tensorboardX==2.2
3 | six==1.16.0
4 | fvcore==0.1.5.post20220414
5 | protobuf==3.20.*
--------------------------------------------------------------------------------
/sampler.py:
-------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | import torch.distributed as dist 3 | import math 4 | import random 5 | import numpy as np 6 | from torchvision.datasets import ImageFolder 7 | from timm.data import create_transform 8 | from timm.data.constants import \ 9 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 10 | from torchvision import transforms 11 | from typing import Tuple 12 | from typing import Optional, Union 13 | 14 | 15 | class MultiScaleSamplerDDP(Sampler): 16 | def __init__(self, base_im_w: int, base_im_h: int, base_batch_size: int, n_data_samples: int, 17 | min_crop_size_w: int = 160, max_crop_size_w: int = 320, 18 | min_crop_size_h: int = 160, max_crop_size_h: int = 320, 19 | n_scales: int = 5, is_training: bool = True, distributed=True) -> None: 20 | # min. and max. spatial dimensions 21 | min_im_w, max_im_w = min_crop_size_w, max_crop_size_w 22 | min_im_h, max_im_h = min_crop_size_h, max_crop_size_h 23 | 24 | # Get the GPU and node related information 25 | if not distributed: 26 | num_replicas = 1 27 | rank = 0 28 | else: 29 | num_replicas = dist.get_world_size() 30 | rank = dist.get_rank() 31 | 32 | # adjust the total samples to avoid batch dropping 33 | num_samples_per_replica = int(math.ceil(n_data_samples * 1.0 / num_replicas)) 34 | total_size = num_samples_per_replica * num_replicas 35 | img_indices = [idx for idx in range(n_data_samples)] 36 | img_indices += img_indices[:(total_size - n_data_samples)] 37 | assert len(img_indices) == total_size 38 | 39 | self.shuffle = True if is_training else False 40 | if is_training: 41 | self.img_batch_pairs = _image_batch_pairs(base_im_w, base_im_h, base_batch_size, num_replicas, n_scales, 32, 42 | min_im_w, max_im_w, min_im_h, max_im_h) 43 | else: 44 | self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)] 45 | 46 | self.img_indices = img_indices 47 | self.n_samples_per_replica = num_samples_per_replica 48 | self.epoch = 0 49 | self.rank = rank 50 | self.num_replicas = num_replicas 51 | self.batch_size_gpu0 = base_batch_size 52 | 53 | def __iter__(self): 54 | if self.shuffle: 55 | random.seed(self.epoch) 56 | random.shuffle(self.img_indices) 57 | random.shuffle(self.img_batch_pairs) 58 | indices_rank_i = self.img_indices[self.rank:len(self.img_indices):self.num_replicas] 59 | else: 60 | indices_rank_i = self.img_indices[self.rank:len(self.img_indices):self.num_replicas] 61 | 62 | start_index = 0 63 | while start_index < self.n_samples_per_replica: 64 | curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs) 65 | 66 | end_index = min(start_index + curr_bsz, self.n_samples_per_replica) 67 | batch_ids = indices_rank_i[start_index:end_index] 68 | n_batch_samples = len(batch_ids) 69 | if n_batch_samples != curr_bsz: 70 | batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)] 71 | start_index += curr_bsz 72 | 73 | if len(batch_ids) > 0: 74 | batch = [(curr_h, curr_w, b_id) for b_id in batch_ids] 75 | yield batch 76 | 77 | def set_epoch(self, epoch: int) -> None: 78 | self.epoch = epoch 79 | 80 | def __len__(self): 81 | return self.n_samples_per_replica 82 | 83 | 84 | def _image_batch_pairs(crop_size_w: int, 85 | crop_size_h: int, 86 | batch_size_gpu0: int, 87 | n_gpus: int, 88 | max_scales: Optional[float] = 5, 89 | check_scale_div_factor: Optional[int] = 32, 90 | min_crop_size_w: Optional[int] = 160, 91 | max_crop_size_w: Optional[int] = 320, 92 | min_crop_size_h: 
Optional[int] = 160,
93 |                        max_crop_size_h: Optional[int] = 320,
94 |                        *args, **kwargs) -> list:
95 |     """
96 |     This function creates batch and image size pairs. For a given base batch size and image size, different image
97 |     sizes are generated and the batch size is adjusted so that GPU memory is utilized efficiently.
98 | 
99 |     :param crop_size_w: Base image width (e.g., 224)
100 |     :param crop_size_h: Base image height (e.g., 224)
101 |     :param batch_size_gpu0: Batch size on GPU 0 for the base image size
102 |     :param n_gpus: Number of available GPUs
103 |     :param max_scales: Number of scales, i.e., how many image sizes to generate between the min and max scale factors.
104 |     :param check_scale_div_factor: Check if image scales are divisible by this factor.
105 |     :param min_crop_size_w: Min. crop size along width
106 |     :param max_crop_size_w: Max. crop size along width
107 |     :param min_crop_size_h: Min. crop size along height
108 |     :param max_crop_size_h: Max. crop size along height
109 |     :param args:
110 |     :param kwargs:
111 |     :return: a sorted list of tuples, each of the form (h, w, batch_size)
112 |     """
113 | 
114 |     width_dims = list(np.linspace(min_crop_size_w, max_crop_size_w, max_scales))
115 |     if crop_size_w not in width_dims:
116 |         width_dims.append(crop_size_w)
117 | 
118 |     height_dims = list(np.linspace(min_crop_size_h, max_crop_size_h, max_scales))
119 |     if crop_size_h not in height_dims:
120 |         height_dims.append(crop_size_h)
121 | 
122 |     image_scales = set()
123 | 
124 |     for h, w in zip(height_dims, width_dims):
125 |         # ensure that sampled sizes are divisible by check_scale_div_factor
126 |         # This is important in some cases where input undergoes a fixed number of down-sampling stages
127 |         # for instance, in ImageNet training, CNNs usually have 5 downsampling stages, which downsamples the
128 |         # input image of resolution 224x224 to 7x7 size
129 |         h = make_divisible(h, check_scale_div_factor)
130 |         w = make_divisible(w, check_scale_div_factor)
131 |         image_scales.add((h, w))
132 | 
133 |     image_scales = list(image_scales)
134 | 
135 |     img_batch_tuples = set()
136 |     n_elements = crop_size_w * crop_size_h * batch_size_gpu0
137 |     for (crop_h, crop_w) in image_scales:
138 |         # compute the batch size for sampled image resolutions with respect to the base resolution
139 |         _bsz = max(batch_size_gpu0, int(round(n_elements / (crop_h * crop_w), 2)))
140 | 
141 |         _bsz = make_divisible(_bsz, n_gpus)
142 |         _bsz = _bsz if _bsz % 2 == 0 else _bsz - 1  # Batch size must be even
143 |         img_batch_tuples.add((crop_h, crop_w, _bsz))
144 | 
145 |     img_batch_tuples = list(img_batch_tuples)
146 |     return sorted(img_batch_tuples)
147 | 
148 | 
149 | def make_divisible(v: Union[float, int],
150 |                    divisor: Optional[int] = 8,
151 |                    min_value: Optional[Union[float, int]] = None) -> Union[float, int]:
152 |     """
153 |     This function is taken from the original tf repo.
154 |     It ensures that all layers have a channel number that is divisible by the given divisor (8 by default).
155 |     It can be seen here:
156 |     https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
157 |     :param v: value to round
158 |     :param divisor: the returned value is a multiple of this number
159 |     :param min_value: optional lower bound on the result (defaults to divisor)
160 |     :return: the rounded value
161 |     """
162 |     if min_value is None:
163 |         min_value = divisor
164 |     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
165 |     # Make sure that round down does not go down by more than 10%.
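    # e.g. make_divisible(10, 8) first rounds to the nearest multiple, 8, but 8 < 0.9 * 10, so it is bumped to 16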
166 |     if new_v < 0.9 * v:
167 |         new_v += divisor
168 |     return new_v
169 | 
170 | 
171 | class MultiScaleImageFolder(ImageFolder):
172 |     def __init__(self, root, args) -> None:
173 |         self.args = args
174 |         ImageFolder.__init__(self, root=root, transform=None, target_transform=None, is_valid_file=None)
175 | 
176 |     def get_transforms(self, size: int):
177 |         imagenet_default_mean_and_std = self.args.imagenet_default_mean_and_std
178 |         mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
179 |         std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
180 |         resize_im = size > 32
181 |         transform = create_transform(
182 |             input_size=size,
183 |             is_training=True,
184 |             color_jitter=self.args.color_jitter,
185 |             auto_augment=self.args.aa,
186 |             interpolation=self.args.train_interpolation,
187 |             re_prob=self.args.reprob,
188 |             re_mode=self.args.remode,
189 |             re_count=self.args.recount,
190 |             mean=mean,
191 |             std=std,
192 |         )
193 |         if not resize_im:
194 |             transform.transforms[0] = transforms.RandomCrop(size, padding=4)
195 | 
196 |         return transform
197 | 
198 |     def __getitem__(self, batch_indexes_tup: Tuple):
199 |         crop_size_h, crop_size_w, img_index = batch_indexes_tup
200 |         transform = self.get_transforms(size=int(crop_size_w))  # a local name that does not shadow the torchvision `transforms` module
201 | 
202 |         path, target = self.samples[img_index]
203 |         sample = self.loader(path)
204 |         if transform is not None:
205 |             sample = transform(sample)
206 |         if self.target_transform is not None:
207 |             target = self.target_transform(target)
208 | 
209 |         return sample, target
210 | 
211 |     def __len__(self):
212 |         return len(self.samples)
213 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | from collections import defaultdict, deque
5 | import datetime
6 | import numpy as np
7 | from timm.utils import get_state_dict
8 | 
9 | from pathlib import Path
10 | 
11 | import torch
12 | import torch.distributed as dist
13 | from torch._six import inf  # torch._six was dropped in newer PyTorch releases; math.inf is a drop-in replacement there
14 | 
15 | from tensorboardX import SummaryWriter
16 | 
17 | import subprocess
18 | 
19 | 
20 | class SmoothedValue(object):
21 |     """Track a series of values and provide access to smoothed values over a
22 |     window or the global series average.
23 |     """
24 | 
25 |     def __init__(self, window_size=20, fmt=None):
26 |         if fmt is None:
27 |             fmt = "{median:.4f} ({global_avg:.4f})"
28 |         self.deque = deque(maxlen=window_size)
29 |         self.total = 0.0
30 |         self.count = 0
31 |         self.fmt = fmt
32 | 
33 |     def update(self, value, n=1):
34 |         self.deque.append(value)
35 |         self.count += n
36 |         self.total += value * n
37 | 
38 |     def synchronize_between_processes(self):
39 |         """
40 |         Warning: does not synchronize the deque!
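        Only count and total are all-reduced across processes, so median, avg, max and value stay rank-local; global_avg is the only property that reflects every rank.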
41 | """ 42 | if not is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value) 80 | 81 | 82 | class MetricLogger(object): 83 | def __init__(self, delimiter="\t"): 84 | self.meters = defaultdict(SmoothedValue) 85 | self.delimiter = delimiter 86 | 87 | def update(self, **kwargs): 88 | for k, v in kwargs.items(): 89 | if v is None: 90 | continue 91 | if isinstance(v, torch.Tensor): 92 | v = v.item() 93 | assert isinstance(v, (float, int)) 94 | self.meters[k].update(v) 95 | 96 | def __getattr__(self, attr): 97 | if attr in self.meters: 98 | return self.meters[attr] 99 | if attr in self.__dict__: 100 | return self.__dict__[attr] 101 | raise AttributeError("'{}' object has no attribute '{}'".format( 102 | type(self).__name__, attr)) 103 | 104 | def __str__(self): 105 | loss_str = [] 106 | for name, meter in self.meters.items(): 107 | loss_str.append( 108 | "{}: {}".format(name, str(meter)) 109 | ) 110 | return self.delimiter.join(loss_str) 111 | 112 | def synchronize_between_processes(self): 113 | for meter in self.meters.values(): 114 | meter.synchronize_between_processes() 115 | 116 | def add_meter(self, name, meter): 117 | self.meters[name] = meter 118 | 119 | def log_every(self, iterable, print_freq, header=None): 120 | i = 0 121 | if not header: 122 | header = '' 123 | start_time = time.time() 124 | end = time.time() 125 | iter_time = SmoothedValue(fmt='{avg:.4f}') 126 | data_time = SmoothedValue(fmt='{avg:.4f}') 127 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 128 | log_msg = [ 129 | header, 130 | '[{0' + space_fmt + '}/{1}]', 131 | 'eta: {eta}', 132 | '{meters}', 133 | 'time: {time}', 134 | 'data: {data}' 135 | ] 136 | if torch.cuda.is_available(): 137 | log_msg.append('max mem: {memory:.0f}') 138 | log_msg = self.delimiter.join(log_msg) 139 | MB = 1024.0 * 1024.0 140 | for obj in iterable: 141 | data_time.update(time.time() - end) 142 | yield obj 143 | iter_time.update(time.time() - end) 144 | if i % print_freq == 0 or i == len(iterable) - 1: 145 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 146 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 147 | if torch.cuda.is_available(): 148 | print(log_msg.format( 149 | i, len(iterable), eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), data=str(data_time), 152 | memory=torch.cuda.max_memory_allocated() / MB)) 153 | else: 154 | print(log_msg.format( 155 | i, len(iterable), eta=eta_string, 156 | meters=str(self), 157 | time=str(iter_time), data=str(data_time))) 158 | i += 1 159 | end = time.time() 160 | total_time = time.time() - start_time 161 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 162 | print('{} Total time: {} ({:.4f} s / it)'.format( 163 | header, 
total_time_str, total_time / len(iterable))) 164 | 165 | 166 | class TensorboardLogger(object): 167 | def __init__(self, log_dir): 168 | self.writer = SummaryWriter(logdir=log_dir) 169 | self.step = 0 170 | 171 | def set_step(self, step=None): 172 | if step is not None: 173 | self.step = step 174 | else: 175 | self.step += 1 176 | 177 | def update(self, head='scalar', step=None, **kwargs): 178 | for k, v in kwargs.items(): 179 | if v is None: 180 | continue 181 | if isinstance(v, torch.Tensor): 182 | v = v.item() 183 | assert isinstance(v, (float, int)) 184 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 185 | 186 | def flush(self): 187 | self.writer.flush() 188 | 189 | 190 | class WandbLogger(object): 191 | def __init__(self, args): 192 | self.args = args 193 | 194 | try: 195 | import wandb 196 | self._wandb = wandb 197 | except ImportError: 198 | raise ImportError( 199 | "To use the Weights and Biases Logger please install wandb." 200 | "Run `pip install wandb` to install it." 201 | ) 202 | 203 | # Initialize a W&B run 204 | if self._wandb.run is None: 205 | self._wandb.init( 206 | project=args.project, 207 | config=args 208 | ) 209 | 210 | def log_epoch_metrics(self, metrics, commit=True): 211 | """ 212 | Log train/test metrics onto W&B. 213 | """ 214 | # Log number of model parameters as W&B summary 215 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 216 | metrics.pop('n_parameters', None) 217 | 218 | # Log current epoch 219 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 220 | metrics.pop('epoch') 221 | 222 | for k, v in metrics.items(): 223 | if 'train' in k: 224 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 225 | elif 'test' in k: 226 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 227 | 228 | self._wandb.log({}) 229 | 230 | def log_checkpoints(self): 231 | output_dir = self.args.output_dir 232 | model_artifact = self._wandb.Artifact( 233 | self._wandb.run.id + "_model", type="model" 234 | ) 235 | 236 | model_artifact.add_dir(output_dir) 237 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 238 | 239 | def set_steps(self): 240 | # Set global training step 241 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 242 | # Set epoch-wise step 243 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 244 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 245 | 246 | 247 | def setup_for_distributed(is_master): 248 | """ 249 | This function disables printing when not in master process 250 | """ 251 | import builtins as __builtin__ 252 | builtin_print = __builtin__.print 253 | 254 | def print(*args, **kwargs): 255 | force = kwargs.pop('force', False) 256 | if is_master or force: 257 | builtin_print(*args, **kwargs) 258 | 259 | __builtin__.print = print 260 | 261 | 262 | def is_dist_avail_and_initialized(): 263 | if not dist.is_available(): 264 | return False 265 | if not dist.is_initialized(): 266 | return False 267 | return True 268 | 269 | 270 | def get_world_size(): 271 | if not is_dist_avail_and_initialized(): 272 | return 1 273 | return dist.get_world_size() 274 | 275 | 276 | def get_rank(): 277 | if not is_dist_avail_and_initialized(): 278 | return 0 279 | return dist.get_rank() 280 | 281 | 282 | def is_main_process(): 283 | return get_rank() == 0 284 | 285 | 286 | def save_on_master(*args, **kwargs): 287 | if is_main_process(): 288 | torch.save(*args, **kwargs) 289 | 290 | 291 | def 
init_distributed_mode(args): 292 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 293 | args.rank = int(os.environ["RANK"]) 294 | args.world_size = int(os.environ['WORLD_SIZE']) 295 | args.gpu = int(os.environ['LOCAL_RANK']) 296 | args.dist_url = 'env://' 297 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 298 | print('Using distributed mode: 1') 299 | elif 'SLURM_PROCID' in os.environ: 300 | proc_id = int(os.environ['SLURM_PROCID']) 301 | ntasks = int(os.environ['SLURM_NTASKS']) 302 | node_list = os.environ['SLURM_NODELIST'] 303 | num_gpus = torch.cuda.device_count() 304 | addr = subprocess.getoutput( 305 | 'scontrol show hostname {} | head -n1'.format(node_list)) 306 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 307 | os.environ['MASTER_ADDR'] = addr 308 | os.environ['WORLD_SIZE'] = str(ntasks) 309 | os.environ['RANK'] = str(proc_id) 310 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 311 | os.environ['LOCAL_SIZE'] = str(num_gpus) 312 | args.dist_url = 'env://' 313 | args.world_size = ntasks 314 | args.rank = proc_id 315 | args.gpu = proc_id % num_gpus 316 | print('Using distributed mode: slurm') 317 | print(f"world: {os.environ['WORLD_SIZE']}, rank:{os.environ['RANK']}," 318 | f" local_rank{os.environ['LOCAL_RANK']}, local_size{os.environ['LOCAL_SIZE']}") 319 | else: 320 | print('Not using distributed mode') 321 | args.distributed = False 322 | return 323 | 324 | args.distributed = True 325 | 326 | torch.cuda.set_device(args.gpu) 327 | args.dist_backend = 'nccl' 328 | print('| distributed init (rank {}): {}'.format( 329 | args.rank, args.dist_url), flush=True) 330 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 331 | world_size=args.world_size, rank=args.rank) 332 | torch.distributed.barrier() 333 | setup_for_distributed(args.rank == 0) 334 | 335 | 336 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 337 | missing_keys = [] 338 | unexpected_keys = [] 339 | error_msgs = [] 340 | # copy state_dict so _load_from_state_dict can modify it 341 | metadata = getattr(state_dict, '_metadata', None) 342 | state_dict = state_dict.copy() 343 | if metadata is not None: 344 | state_dict._metadata = metadata 345 | 346 | def load(module, prefix=''): 347 | local_metadata = {} if metadata is None else metadata.get( 348 | prefix[:-1], {}) 349 | module._load_from_state_dict( 350 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 351 | for name, child in module._modules.items(): 352 | if child is not None: 353 | load(child, prefix + name + '.') 354 | 355 | load(model, prefix=prefix) 356 | 357 | warn_missing_keys = [] 358 | ignore_missing_keys = [] 359 | for key in missing_keys: 360 | keep_flag = True 361 | for ignore_key in ignore_missing.split('|'): 362 | if ignore_key in key: 363 | keep_flag = False 364 | break 365 | if keep_flag: 366 | warn_missing_keys.append(key) 367 | else: 368 | ignore_missing_keys.append(key) 369 | 370 | missing_keys = warn_missing_keys 371 | 372 | if len(missing_keys) > 0: 373 | print("Weights of {} not initialized from pretrained model: {}".format( 374 | model.__class__.__name__, missing_keys)) 375 | if len(unexpected_keys) > 0: 376 | print("Weights from pretrained model not used in {}: {}".format( 377 | model.__class__.__name__, unexpected_keys)) 378 | if len(ignore_missing_keys) > 0: 379 | print("Ignored weights of {} not initialized from pretrained model: {}".format( 380 | model.__class__.__name__, 
ignore_missing_keys)) 381 | if len(error_msgs) > 0: 382 | print('\n'.join(error_msgs)) 383 | 384 | 385 | class NativeScalerWithGradNormCount: 386 | state_dict_key = "amp_scaler" 387 | 388 | def __init__(self): 389 | self._scaler = torch.cuda.amp.GradScaler() 390 | 391 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 392 | self._scaler.scale(loss).backward(create_graph=create_graph) 393 | if update_grad: 394 | if clip_grad is not None: 395 | assert parameters is not None 396 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 397 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 398 | else: 399 | self._scaler.unscale_(optimizer) 400 | norm = get_grad_norm_(parameters) 401 | self._scaler.step(optimizer) 402 | self._scaler.update() 403 | else: 404 | norm = None 405 | return norm 406 | 407 | def state_dict(self): 408 | return self._scaler.state_dict() 409 | 410 | def load_state_dict(self, state_dict): 411 | self._scaler.load_state_dict(state_dict) 412 | 413 | 414 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 415 | if isinstance(parameters, torch.Tensor): 416 | parameters = [parameters] 417 | parameters = [p for p in parameters if p.grad is not None] 418 | norm_type = float(norm_type) 419 | if len(parameters) == 0: 420 | return torch.tensor(0.) 421 | device = parameters[0].grad.device 422 | if norm_type == inf: 423 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 424 | else: 425 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 426 | return total_norm 427 | 428 | 429 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 430 | start_warmup_value=0, warmup_steps=-1): 431 | warmup_schedule = np.array([]) 432 | warmup_iters = warmup_epochs * niter_per_ep 433 | if warmup_steps > 0: 434 | warmup_iters = warmup_steps 435 | print("Set warmup steps = %d" % warmup_iters) 436 | if warmup_epochs > 0: 437 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 438 | 439 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 440 | schedule = np.array( 441 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 442 | 443 | schedule = np.concatenate((warmup_schedule, schedule)) 444 | 445 | assert len(schedule) == epochs * niter_per_ep 446 | return schedule 447 | 448 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 449 | output_dir = Path(args.output_dir) 450 | epoch_name = str(epoch) 451 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 452 | for checkpoint_path in checkpoint_paths: 453 | to_save = { 454 | 'model': model_without_ddp.state_dict(), 455 | 'optimizer': optimizer.state_dict(), 456 | 'epoch': epoch, 457 | 'scaler': loss_scaler.state_dict(), 458 | 'args': args, 459 | } 460 | 461 | if model_ema is not None: 462 | to_save['model_ema'] = get_state_dict(model_ema) 463 | 464 | save_on_master(to_save, checkpoint_path) 465 | 466 | if is_main_process() and isinstance(epoch, int): 467 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 468 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 469 | if os.path.exists(old_ckpt): 470 | os.remove(old_ckpt) 471 | 472 | 473 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None, 
state_dict_name='model'): 474 | output_dir = Path(args.output_dir) 475 | if args.auto_resume and len(args.resume) == 0: 476 | import glob 477 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 478 | latest_ckpt = -1 479 | for ckpt in all_checkpoints: 480 | t = ckpt.split('-')[-1].split('.')[0] 481 | if t.isdigit(): 482 | latest_ckpt = max(int(t), latest_ckpt) 483 | if latest_ckpt >= 0: 484 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 485 | print("Auto resume checkpoint: %s" % args.resume) 486 | 487 | if args.resume: 488 | if args.resume.startswith('https'): 489 | checkpoint = torch.hub.load_state_dict_from_url( 490 | args.resume, map_location='cpu', check_hash=True) 491 | else: 492 | checkpoint = torch.load(args.resume, map_location='cpu') 493 | model_without_ddp.load_state_dict(checkpoint[state_dict_name]) 494 | print("Resume checkpoint %s" % args.resume) 495 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 496 | optimizer.load_state_dict(checkpoint['optimizer']) 497 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 498 | args.start_epoch = checkpoint['epoch'] + 1 499 | else: 500 | assert args.eval, 'Does not support resuming with checkpoint-best' 501 | if hasattr(args, 'model_ema') and args.model_ema: 502 | if 'model_ema' in checkpoint.keys(): 503 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 504 | else: 505 | model_ema.ema.load_state_dict(checkpoint['model']) 506 | if 'scaler' in checkpoint: 507 | loss_scaler.load_state_dict(checkpoint['scaler']) 508 | print("With optim & sched!") 509 | --------------------------------------------------------------------------------
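
A minimal usage sketch of the pieces above, assuming the repo root is on PYTHONPATH and the pinned timm==0.4.12 from requirements.txt is installed. classifier_dropout is passed explicitly because the constructors shown above read it from kwargs, and the SimpleNamespace stands in for the argparse namespace that main.py presumably builds; treat the exact fields as assumptions rather than the repository's documented API.

import torch
from types import SimpleNamespace

from timm.models import create_model

import models.model  # noqa: F401 -- importing this module registers the edgenext_* variants with timm
from optim_factory import create_optimizer

# The EdgeNeXt constructors read `classifier_dropout` from **kwargs, so it must be supplied.
model = create_model('edgenext_small', pretrained=False, classifier_dropout=0.0)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))  # the stats quoted in models/model.py are at 256x256
print(logits.shape)  # torch.Size([1, 1000])

# create_optimizer touches only a few attributes of the namespace on the 'adamw' path.
args = SimpleNamespace(opt='adamw', lr=6e-3, weight_decay=0.05, momentum=0.9)
optimizer = create_optimizer(args, model)  # also prints the parameter groups it builds

For multi-scale training, the same model would instead be fed by MultiScaleSamplerDDP and MultiScaleImageFolder from sampler.py, which yield (h, w, index) batches so each step can run at a different resolution.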