├── LICENSE ├── README.md ├── benchmark.py ├── distributed_train.sh ├── docker └── Dockerfile ├── models ├── __init__.py ├── convnext.py └── inceptionnext.py ├── scripts ├── convnext_variants │ ├── train_convnext_tiny_k3_ema.sh │ ├── train_convnext_tiny_k3_par1_16_ema.sh │ ├── train_convnext_tiny_k3_par1_2_ema.sh │ ├── train_convnext_tiny_k3_par1_4_ema.sh │ ├── train_convnext_tiny_k3_par1_8_ema.sh │ ├── train_convnext_tiny_k3_par3_8_ema.sh │ └── train_convnext_tiny_k5_ema.sh └── inceptionnext │ ├── finetune_inceptionnext_base.sh │ ├── train_inceptionnext_atto.sh │ ├── train_inceptionnext_base.sh │ ├── train_inceptionnext_small.sh │ └── train_inceptionnext_tiny.sh ├── train.py ├── utils.py └── validate.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [InceptionNeXt: When Inception Meets ConvNeXt](https://arxiv.org/abs/2303.16900) (CVPR 2024) 2 | 3 |


9 | 10 | This is a PyTorch implementation of InceptionNeXt proposed by our paper "[InceptionNeXt: When Inception Meets ConvNeXt](https://arxiv.org/abs/2303.16900)". Many thanks to [Ross Wightman](https://github.com/rwightman); InceptionNeXt has been integrated into [timm](https://github.com/huggingface/pytorch-image-models). 11 | 12 | ![InceptionNeXt](https://user-images.githubusercontent.com/15921929/228630174-1d31ac66-174b-4014-9f6a-b7e6d46af958.jpeg) 13 | **TLDR**: To speed up ConvNeXt, we build InceptionNeXt by decomposing the large-kernel depthwise convolution in Inception style (a minimal sketch of this decomposition is shown after the Train section). **Our InceptionNeXt-T enjoys both ResNet-50’s speed and ConvNeXt-T’s accuracy.** 14 | 15 | 16 | ## Requirements 17 | Our models are trained and tested with PyTorch 1.13, NVIDIA CUDA 11.7.1, and timm 0.6.11 (`pip install timm==0.6.11`). If you use Docker, see the [Dockerfile](docker/Dockerfile) that we used. 18 | 19 | 20 | Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4). 21 | 22 | ``` 23 | │imagenet/ 24 | ├──train/ 25 | │ ├── n01440764 26 | │ │ ├── n01440764_10026.JPEG 27 | │ │ ├── n01440764_10027.JPEG 28 | │ │ ├── ...... 29 | │ ├── ...... 30 | ├──val/ 31 | │ ├── n01440764 32 | │ │ ├── ILSVRC2012_val_00000293.JPEG 33 | │ │ ├── ILSVRC2012_val_00002138.JPEG 34 | │ │ ├── ...... 35 | │ ├── ...... 36 | ``` 37 | 38 | 39 | ## Models 40 | ### InceptionNeXt trained on ImageNet-1K 41 | | Model | Resolution | Params | MACs | Train throughput | Infer. throughput | Top1 Acc | 42 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | 43 | | mobilenetv2_140 | 224 | 6.1M | 0.60G | 1001 | 5190 | 74.7 | 44 | | efficientnet_b0 | 224 | 5.3M | 0.40G | 954 | 5502 | 77.1 | 45 | | ghostnet_130 | 224 | 7.3M | 0.24G | 946 | 7451 | 75.7 | 46 | | convnext_atto | 224 | 3.7M | 0.55G | 835 | 4539 | 75.7 | 47 | | [inceptionnext_atto](https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_atto.pth) | 224 | 4.2M | 0.51G | 2661 | 9876 | 75.3 | 48 | | resnet50 | 224 | 26M | 4.1G | 969 | 3149 | 78.4 | 49 | | convnext_tiny | 224 | 29M | 4.5G | 575 | 2413 | 82.1 | 50 | | [inceptionnext_tiny](https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth) | 224 | 28M | 4.2G | 901 | 2900 | 82.3 | 51 | | [inceptionnext_small](https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth) | 224 | 49M | 8.4G | 521 | 1750 | 83.5 | 52 | | [inceptionnext_base](https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth) | 224 | 87M | 14.9G | 375 | 1244 | 84.0 | 53 | | [inceptionnext_base_384](https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth) | 384 | 87M | 43.6G | 139 | 428 | 85.2 | 54 | 55 | ### ConvNeXt variants trained on ImageNet-1K 56 | | Model | Resolution | Params | MACs | Train throughput | Infer. 
throughput | Top1 Acc | 57 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | 58 | | resnet50 | 224 | 26M | 4.1G | 969 | 3149 | 78.4 | 59 | | convnext_tiny | 224 | 29M | 4.5G | 575 | 2413 | 82.1 | 60 | | [convnext_tiny_k5](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k5.pth) | 224 | 29M | 4.4G | 675 | 2704 | 82.0 | 61 | | [convnext_tiny_k3](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3.pth) | 224 | 28M | 4.4G | 798 | 2802 | 81.5 | 62 | | [convnext_tiny_k3_par1_2](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_2.pth) | 224 | 28M | 4.4G | 818 | 2740 | 81.4 | 63 | | [convnext_tiny_k3_par3_8](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par3_8.pth) | 224 | 28M | 4.4G | 847 | 2762 | 81.4 | 64 | | [convnext_tiny_k3_par1_4](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_4.pth) | 224 | 28M | 4.4G | 871 | 2808 | 81.3 | 65 | | [convnext_tiny_k3_par1_8](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_8.pth) | 224 | 28M | 4.4G | 901 | 2833 | 80.8 | 66 | | [convnext_tiny_k3_par1_16](https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_16.pth) | 224 | 28M | 4.4G | 916 | 2846 | 80.1 | 67 | 68 | The throughputs are measured on an A100 with full precision and a batch size of 128. See [Benchmarking throughput](#benchmarking-throughput). 69 | 70 | #### Usage 71 | We also provide a Colab notebook that runs the steps to perform inference with InceptionNeXt: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1-CAPm6FNKYRbe_lAPxIBxsIH4xowgfg8?usp=sharing). A minimal local inference sketch is also provided after the Train section. 72 | 73 | 74 | ## Validation 75 | 76 | To evaluate our InceptionNeXt models (e.g., InceptionNeXt-T), run: 77 | 78 | ```bash 79 | MODEL=inceptionnext_tiny 80 | python3 validate.py /path/to/imagenet --model $MODEL -b 128 \ 81 | --pretrained 82 | ``` 83 | 84 | ## Benchmarking throughput 85 | On the environment described above, we benchmark throughputs on an A100 with a batch size of 128. For each model, the better result of the "Channel First" and "Channel Last" memory layouts is reported. 86 | 87 | For Channel First: 88 | ```bash 89 | MODEL=inceptionnext_tiny # convnext_tiny 90 | python3 benchmark.py --model $MODEL 91 | ``` 92 | 93 | For Channel Last: 94 | ```bash 95 | MODEL=inceptionnext_tiny # convnext_tiny 96 | python3 benchmark.py --model $MODEL --channels-last 97 | ``` 98 | 99 | ## Train 100 | We use a batch size of 4096 by default and show how to train models with 8 GPUs. For multi-node training, adjust `--grad-accum-steps` according to your situation. 101 | 102 | 103 | ```bash 104 | DATA_PATH=/path/to/imagenet 105 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 106 | 107 | 108 | ALL_BATCH_SIZE=4096 109 | NUM_GPU=8 110 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU count and memory size. 111 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 112 | 113 | 114 | MODEL=inceptionnext_tiny # inceptionnext_small, inceptionnext_base 115 | DROP_PATH=0.1 # 0.3, 0.4 116 | 117 | 118 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 119 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 120 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 121 | --drop-path $DROP_PATH 122 | ``` 123 | Training (fine-tuning) scripts for other models are provided in [scripts](/scripts/). 
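
## Sketch: Inception depthwise decomposition

As noted in the TLDR above, InceptionNeXt decomposes the large-kernel depthwise convolution into parallel branches: an identity branch, a small square-kernel branch, and two orthogonal band-kernel branches. The snippet below is a minimal illustration using the `InceptionDWConv2d` module defined in `models/inceptionnext.py` of this repo; the channel count and feature-map size are illustrative assumptions (roughly matching the first stage of the tiny model), not values taken from any training script.

```python
import torch
from models.inceptionnext import InceptionDWConv2d  # module defined in this repo

# With 96 input channels and the default branch_ratio=0.125, each convolution
# branch gets 12 channels; the remaining 60 channels pass through as identity.
mixer = InceptionDWConv2d(96, square_kernel_size=3, band_kernel_size=11, branch_ratio=0.125)
print(mixer.split_indexes)  # (60, 12, 12, 12)

x = torch.randn(2, 96, 56, 56)  # assumed input: a batch of stage-1 feature maps
out = mixer(x)
print(out.shape)  # torch.Size([2, 96, 56, 56]); channel and spatial sizes are preserved
```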
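
## Sketch: local inference

Besides the Colab notebook linked in the Usage section, the snippet below sketches how inference could be run locally. It assumes the script is launched from the repository root so that `import models` registers the InceptionNeXt variants with timm (as `benchmark.py` does), that the pretrained weights can be downloaded automatically (the same weights used by `--pretrained` in the Validation section), and that `example.jpg` is a placeholder image path; preprocessing is resolved from the model's default data config.

```python
import torch
from PIL import Image
from timm import create_model
from timm.data import resolve_data_config, create_transform

import models  # registers inceptionnext_* / convnext_* variants from this repo

model = create_model('inceptionnext_tiny', pretrained=True)
model.eval()

# Build the eval transform (resize, center crop, normalize) from the model's default config
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

img = Image.open('example.jpg').convert('RGB')  # placeholder image path
x = transform(img).unsqueeze(0)                 # (1, 3, 224, 224)

with torch.no_grad():
    probs = model(x).softmax(dim=-1)
print(probs.argmax(dim=-1).item())  # predicted ImageNet-1K class index
```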
124 | 125 | 126 | ## Bibtex 127 | ``` 128 | @inproceedings{yu2024inceptionnext, 129 | title={Inceptionnext: When inception meets convnext}, 130 | author={Yu, Weihao and Zhou, Pan and Yan, Shuicheng and Wang, Xinchao}, 131 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 132 | pages={5672--5683}, 133 | year={2024} 134 | } 135 | ``` 136 | 137 | ## Acknowledgment 138 | Weihao Yu would like to thank TRC program and GCP research credits for the support of partial computational resources. Our implementation is based on [pytorch-image-models](https://github.com/huggingface/pytorch-image-models), [poolformer](https://github.com/sail-sg/poolformer), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt) and [metaformer](https://github.com/sail-sg/metaformer). 139 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copied from: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/benchmark.py 3 | """ Model Benchmark Script 4 | 5 | An inference and train step benchmark script for timm models. 6 | 7 | Hacked together by Ross Wightman (https://github.com/rwightman) 8 | """ 9 | import argparse 10 | import csv 11 | import json 12 | import logging 13 | import time 14 | from collections import OrderedDict 15 | from contextlib import suppress 16 | from functools import partial 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.parallel 21 | 22 | from timm.data import resolve_data_config 23 | from timm.models import create_model, is_model, list_models, set_fast_norm 24 | from timm.optim import create_optimizer_v2 25 | from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry 26 | 27 | import models 28 | 29 | has_apex = False 30 | try: 31 | from apex import amp 32 | has_apex = True 33 | except ImportError: 34 | pass 35 | 36 | has_native_amp = False 37 | try: 38 | if getattr(torch.cuda.amp, 'autocast') is not None: 39 | has_native_amp = True 40 | except AttributeError: 41 | pass 42 | 43 | try: 44 | from deepspeed.profiling.flops_profiler import get_model_profile 45 | has_deepspeed_profiling = True 46 | except ImportError as e: 47 | has_deepspeed_profiling = False 48 | 49 | try: 50 | from fvcore.nn import FlopCountAnalysis, flop_count_str, ActivationCountAnalysis 51 | has_fvcore_profiling = True 52 | except ImportError as e: 53 | FlopCountAnalysis = None 54 | has_fvcore_profiling = False 55 | 56 | try: 57 | from functorch.compile import memory_efficient_fusion 58 | has_functorch = True 59 | except ImportError as e: 60 | has_functorch = False 61 | 62 | 63 | torch.backends.cudnn.benchmark = True 64 | _logger = logging.getLogger('validate') 65 | 66 | 67 | parser = argparse.ArgumentParser(description='PyTorch Benchmark') 68 | 69 | # benchmark specific args 70 | parser.add_argument('--model-list', metavar='NAME', default='', 71 | help='txt file based list of model names to benchmark') 72 | parser.add_argument('--bench', default='both', type=str, 73 | help="Benchmark mode. One of 'inference', 'train', 'both'. Defaults to 'both'") 74 | parser.add_argument('--detail', action='store_true', default=False, 75 | help='Provide train fwd/bwd/opt breakdown detail if True. 
Defaults to False') 76 | parser.add_argument('--no-retry', action='store_true', default=False, 77 | help='Do not decay batch size and retry on error.') 78 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 79 | help='Output csv file for validation results (summary)') 80 | parser.add_argument('--num-warm-iter', default=10, type=int, 81 | metavar='N', help='Number of warmup iterations (default: 10)') 82 | parser.add_argument('--num-bench-iter', default=40, type=int, 83 | metavar='N', help='Number of benchmark iterations (default: 40)') 84 | 85 | # common inference / train args 86 | parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', 87 | help='model architecture (default: resnet50)') 88 | parser.add_argument('-b', '--batch-size', default=128, type=int, 89 | metavar='N', help='mini-batch size (default: 128)') 90 | parser.add_argument('--img-size', default=None, type=int, 91 | metavar='N', help='Input image dimension, uses model default if empty') 92 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 93 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') 94 | parser.add_argument('--use-train-size', action='store_true', default=False, 95 | help='Run inference at train size, not test-input-size if it exists.') 96 | parser.add_argument('--num-classes', type=int, default=None, 97 | help='Number classes in dataset') 98 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 99 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 100 | parser.add_argument('--channels-last', action='store_true', default=False, 101 | help='Use channels_last memory layout') 102 | parser.add_argument('--grad-checkpointing', action='store_true', default=False, 103 | help='Enable gradient checkpointing through model blocks/stages') 104 | parser.add_argument('--amp', action='store_true', default=False, 105 | help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') 106 | parser.add_argument('--precision', default='float32', type=str, 107 | help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') 108 | parser.add_argument('--fuser', default='', type=str, 109 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 110 | scripting_group = parser.add_mutually_exclusive_group() 111 | scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', 112 | help='convert model torchscript for inference') 113 | scripting_group.add_argument('--aot-autograd', default=False, action='store_true', 114 | help="Enable AOT Autograd support. 
(It's recommended to use this option with `--fuser nvfuser` together)") 115 | scripting_group.add_argument('--fast-norm', default=False, action='store_true', 116 | help='enable experimental fast-norm') 117 | 118 | # train optimizer parameters 119 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', 120 | help='Optimizer (default: "sgd"') 121 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 122 | help='Optimizer Epsilon (default: None, use opt default)') 123 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 124 | help='Optimizer Betas (default: None, use opt default)') 125 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 126 | help='Optimizer momentum (default: 0.9)') 127 | parser.add_argument('--weight-decay', type=float, default=0.0001, 128 | help='weight decay (default: 0.0001)') 129 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 130 | help='Clip gradient norm (default: None, no clipping)') 131 | parser.add_argument('--clip-mode', type=str, default='norm', 132 | help='Gradient clipping mode. One of ("norm", "value", "agc")') 133 | 134 | 135 | # model regularization / loss params that impact model or loss fn 136 | parser.add_argument('--smoothing', type=float, default=0.1, 137 | help='Label smoothing (default: 0.1)') 138 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 139 | help='Dropout rate (default: 0.)') 140 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', 141 | help='Drop path rate (default: None)') 142 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 143 | help='Drop block rate (default: None)') 144 | 145 | 146 | def timestamp(sync=False): 147 | return time.perf_counter() 148 | 149 | 150 | def cuda_timestamp(sync=False, device=None): 151 | if sync: 152 | torch.cuda.synchronize(device=device) 153 | return time.perf_counter() 154 | 155 | 156 | def count_params(model: nn.Module): 157 | return sum([m.numel() for m in model.parameters()]) 158 | 159 | 160 | def resolve_precision(precision: str): 161 | assert precision in ('amp', 'float16', 'bfloat16', 'float32') 162 | use_amp = False 163 | model_dtype = torch.float32 164 | data_dtype = torch.float32 165 | if precision == 'amp': 166 | use_amp = True 167 | elif precision == 'float16': 168 | model_dtype = torch.float16 169 | data_dtype = torch.float16 170 | elif precision == 'bfloat16': 171 | model_dtype = torch.bfloat16 172 | data_dtype = torch.bfloat16 173 | return use_amp, model_dtype, data_dtype 174 | 175 | 176 | def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): 177 | _, macs, _ = get_model_profile( 178 | model=model, 179 | input_shape=(batch_size,) + input_size, # input shape/resolution 180 | print_profile=detailed, # prints the model graph with the measured profile attached to each module 181 | detailed=detailed, # print the detailed profile 182 | warm_up=10, # the number of warm-ups before measuring the time of each module 183 | as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) 184 | output_file=None, # path to the output file. If None, the profiler prints to stdout. 
185 | ignore_modules=None) # the list of modules to ignore in the profiling 186 | return macs, 0 # no activation count in DS 187 | 188 | 189 | def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False, force_cpu=False): 190 | if force_cpu: 191 | model = model.to('cpu') 192 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 193 | example_input = torch.ones((batch_size,) + input_size, device=device, dtype=dtype) 194 | fca = FlopCountAnalysis(model, example_input) 195 | aca = ActivationCountAnalysis(model, example_input) 196 | if detailed: 197 | fcs = flop_count_str(fca) 198 | print(fcs) 199 | return fca.total(), aca.total() 200 | 201 | 202 | class BenchmarkRunner: 203 | def __init__( 204 | self, 205 | model_name, 206 | detail=False, 207 | device='cuda', 208 | torchscript=False, 209 | aot_autograd=False, 210 | precision='float32', 211 | fuser='', 212 | num_warm_iter=10, 213 | num_bench_iter=50, 214 | use_train_size=False, 215 | **kwargs 216 | ): 217 | self.model_name = model_name 218 | self.detail = detail 219 | self.device = device 220 | self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) 221 | self.channels_last = kwargs.pop('channels_last', False) 222 | self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress 223 | 224 | if fuser: 225 | set_jit_fuser(fuser) 226 | self.model = create_model( 227 | model_name, 228 | num_classes=kwargs.pop('num_classes', None), 229 | in_chans=3, 230 | global_pool=kwargs.pop('gp', 'fast'), 231 | scriptable=torchscript, 232 | drop_rate=kwargs.pop('drop', 0.), 233 | drop_path_rate=kwargs.pop('drop_path', None), 234 | drop_block_rate=kwargs.pop('drop_block', None), 235 | ) 236 | self.model.to( 237 | device=self.device, 238 | dtype=self.model_dtype, 239 | memory_format=torch.channels_last if self.channels_last else None) 240 | self.num_classes = self.model.num_classes 241 | self.param_count = count_params(self.model) 242 | _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) 243 | 244 | data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) 245 | self.scripted = False 246 | if torchscript: 247 | self.model = torch.jit.script(self.model) 248 | self.scripted = True 249 | self.input_size = data_config['input_size'] 250 | self.batch_size = kwargs.pop('batch_size', 256) 251 | 252 | if aot_autograd: 253 | assert has_functorch, "functorch is needed for --aot-autograd" 254 | self.model = memory_efficient_fusion(self.model) 255 | 256 | self.example_inputs = None 257 | self.num_warm_iter = num_warm_iter 258 | self.num_bench_iter = num_bench_iter 259 | self.log_freq = num_bench_iter // 5 260 | if 'cuda' in self.device: 261 | self.time_fn = partial(cuda_timestamp, device=self.device) 262 | else: 263 | self.time_fn = timestamp 264 | 265 | def _init_input(self): 266 | self.example_inputs = torch.randn( 267 | (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype) 268 | if self.channels_last: 269 | self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last) 270 | 271 | 272 | class InferenceBenchmarkRunner(BenchmarkRunner): 273 | 274 | def __init__( 275 | self, 276 | model_name, 277 | device='cuda', 278 | torchscript=False, 279 | **kwargs 280 | ): 281 | super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) 282 | self.model.eval() 283 | 284 | def run(self): 285 | def _step(): 286 | t_step_start = self.time_fn() 287 | with 
self.amp_autocast(): 288 | output = self.model(self.example_inputs) 289 | t_step_end = self.time_fn(True) 290 | return t_step_end - t_step_start 291 | 292 | _logger.info( 293 | f'Running inference benchmark on {self.model_name} for {self.num_bench_iter} steps w/ ' 294 | f'input size {self.input_size} and batch size {self.batch_size}.') 295 | 296 | with torch.no_grad(): 297 | self._init_input() 298 | 299 | for _ in range(self.num_warm_iter): 300 | _step() 301 | 302 | total_step = 0. 303 | num_samples = 0 304 | t_run_start = self.time_fn() 305 | for i in range(self.num_bench_iter): 306 | delta_fwd = _step() 307 | total_step += delta_fwd 308 | num_samples += self.batch_size 309 | num_steps = i + 1 310 | if num_steps % self.log_freq == 0: 311 | _logger.info( 312 | f"Infer [{num_steps}/{self.num_bench_iter}]." 313 | f" {num_samples / total_step:0.2f} samples/sec." 314 | f" {1000 * total_step / num_steps:0.3f} ms/step.") 315 | t_run_end = self.time_fn(True) 316 | t_run_elapsed = t_run_end - t_run_start 317 | 318 | results = dict( 319 | samples_per_sec=round(num_samples / t_run_elapsed, 2), 320 | step_time=round(1000 * total_step / self.num_bench_iter, 3), 321 | batch_size=self.batch_size, 322 | img_size=self.input_size[-1], 323 | param_count=round(self.param_count / 1e6, 2), 324 | ) 325 | 326 | retries = 0 if self.scripted else 2 # skip profiling if model is scripted 327 | while retries: 328 | retries -= 1 329 | try: 330 | if has_deepspeed_profiling: 331 | macs, _ = profile_deepspeed(self.model, self.input_size) 332 | results['gmacs'] = round(macs / 1e9, 2) 333 | elif has_fvcore_profiling: 334 | macs, activations = profile_fvcore(self.model, self.input_size, force_cpu=not retries) 335 | results['gmacs'] = round(macs / 1e9, 2) 336 | results['macts'] = round(activations / 1e6, 2) 337 | except RuntimeError as e: 338 | pass 339 | 340 | _logger.info( 341 | f"Inference benchmark of {self.model_name} done. " 342 | f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step") 343 | 344 | return results 345 | 346 | 347 | class TrainBenchmarkRunner(BenchmarkRunner): 348 | 349 | def __init__( 350 | self, 351 | model_name, 352 | device='cuda', 353 | torchscript=False, 354 | **kwargs 355 | ): 356 | super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs) 357 | self.model.train() 358 | 359 | self.loss = nn.CrossEntropyLoss().to(self.device) 360 | self.target_shape = tuple() 361 | 362 | self.optimizer = create_optimizer_v2( 363 | self.model, 364 | opt=kwargs.pop('opt', 'sgd'), 365 | lr=kwargs.pop('lr', 1e-4)) 366 | 367 | if kwargs.pop('grad_checkpointing', False): 368 | self.model.set_grad_checkpointing() 369 | 370 | def _gen_target(self, batch_size): 371 | return torch.empty( 372 | (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes) 373 | 374 | def run(self): 375 | def _step(detail=False): 376 | self.optimizer.zero_grad() # can this be ignored? 
377 | t_start = self.time_fn() 378 | t_fwd_end = t_start 379 | t_bwd_end = t_start 380 | with self.amp_autocast(): 381 | output = self.model(self.example_inputs) 382 | if isinstance(output, tuple): 383 | output = output[0] 384 | if detail: 385 | t_fwd_end = self.time_fn(True) 386 | target = self._gen_target(output.shape[0]) 387 | self.loss(output, target).backward() 388 | if detail: 389 | t_bwd_end = self.time_fn(True) 390 | self.optimizer.step() 391 | t_end = self.time_fn(True) 392 | if detail: 393 | delta_fwd = t_fwd_end - t_start 394 | delta_bwd = t_bwd_end - t_fwd_end 395 | delta_opt = t_end - t_bwd_end 396 | return delta_fwd, delta_bwd, delta_opt 397 | else: 398 | delta_step = t_end - t_start 399 | return delta_step 400 | 401 | _logger.info( 402 | f'Running train benchmark on {self.model_name} for {self.num_bench_iter} steps w/ ' 403 | f'input size {self.input_size} and batch size {self.batch_size}.') 404 | 405 | self._init_input() 406 | 407 | for _ in range(self.num_warm_iter): 408 | _step() 409 | 410 | t_run_start = self.time_fn() 411 | if self.detail: 412 | total_fwd = 0. 413 | total_bwd = 0. 414 | total_opt = 0. 415 | num_samples = 0 416 | for i in range(self.num_bench_iter): 417 | delta_fwd, delta_bwd, delta_opt = _step(True) 418 | num_samples += self.batch_size 419 | total_fwd += delta_fwd 420 | total_bwd += delta_bwd 421 | total_opt += delta_opt 422 | num_steps = (i + 1) 423 | if num_steps % self.log_freq == 0: 424 | total_step = total_fwd + total_bwd + total_opt 425 | _logger.info( 426 | f"Train [{num_steps}/{self.num_bench_iter}]." 427 | f" {num_samples / total_step:0.2f} samples/sec." 428 | f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd," 429 | f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd," 430 | f" {1000 * total_opt / num_steps:0.3f} ms/step opt." 431 | ) 432 | total_step = total_fwd + total_bwd + total_opt 433 | t_run_elapsed = self.time_fn() - t_run_start 434 | results = dict( 435 | samples_per_sec=round(num_samples / t_run_elapsed, 2), 436 | step_time=round(1000 * total_step / self.num_bench_iter, 3), 437 | fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3), 438 | bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3), 439 | opt_time=round(1000 * total_opt / self.num_bench_iter, 3), 440 | batch_size=self.batch_size, 441 | img_size=self.input_size[-1], 442 | param_count=round(self.param_count / 1e6, 2), 443 | ) 444 | else: 445 | total_step = 0. 446 | num_samples = 0 447 | for i in range(self.num_bench_iter): 448 | delta_step = _step(False) 449 | num_samples += self.batch_size 450 | total_step += delta_step 451 | num_steps = (i + 1) 452 | if num_steps % self.log_freq == 0: 453 | _logger.info( 454 | f"Train [{num_steps}/{self.num_bench_iter}]." 455 | f" {num_samples / total_step:0.2f} samples/sec." 456 | f" {1000 * total_step / num_steps:0.3f} ms/step.") 457 | t_run_elapsed = self.time_fn() - t_run_start 458 | results = dict( 459 | samples_per_sec=round(num_samples / t_run_elapsed, 2), 460 | step_time=round(1000 * total_step / self.num_bench_iter, 3), 461 | batch_size=self.batch_size, 462 | img_size=self.input_size[-1], 463 | param_count=round(self.param_count / 1e6, 2), 464 | ) 465 | 466 | _logger.info( 467 | f"Train benchmark of {self.model_name} done. 
" 468 | f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample") 469 | 470 | return results 471 | 472 | 473 | class ProfileRunner(BenchmarkRunner): 474 | 475 | def __init__(self, model_name, device='cuda', profiler='', **kwargs): 476 | super().__init__(model_name=model_name, device=device, **kwargs) 477 | if not profiler: 478 | if has_deepspeed_profiling: 479 | profiler = 'deepspeed' 480 | elif has_fvcore_profiling: 481 | profiler = 'fvcore' 482 | assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work." 483 | self.profiler = profiler 484 | self.model.eval() 485 | 486 | def run(self): 487 | _logger.info( 488 | f'Running profiler on {self.model_name} w/ ' 489 | f'input size {self.input_size} and batch size {self.batch_size}.') 490 | 491 | macs = 0 492 | activations = 0 493 | if self.profiler == 'deepspeed': 494 | macs, _ = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True) 495 | elif self.profiler == 'fvcore': 496 | macs, activations = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True) 497 | 498 | results = dict( 499 | gmacs=round(macs / 1e9, 2), 500 | macts=round(activations / 1e6, 2), 501 | batch_size=self.batch_size, 502 | img_size=self.input_size[-1], 503 | param_count=round(self.param_count / 1e6, 2), 504 | ) 505 | 506 | _logger.info( 507 | f"Profile of {self.model_name} done. " 508 | f"{results['gmacs']:.2f} GMACs, {results['param_count']:.2f} M params.") 509 | 510 | return results 511 | 512 | 513 | def _try_run( 514 | model_name, 515 | bench_fn, 516 | bench_kwargs, 517 | initial_batch_size, 518 | no_batch_size_retry=False 519 | ): 520 | batch_size = initial_batch_size 521 | results = dict() 522 | error_str = 'Unknown' 523 | while batch_size: 524 | try: 525 | torch.cuda.empty_cache() 526 | bench = bench_fn(model_name=model_name, batch_size=batch_size, **bench_kwargs) 527 | results = bench.run() 528 | return results 529 | except RuntimeError as e: 530 | error_str = str(e) 531 | _logger.error(f'"{error_str}" while running benchmark.') 532 | if not check_batch_size_retry(error_str): 533 | _logger.error(f'Unrecoverable error encountered while benchmarking {model_name}, skipping.') 534 | break 535 | if no_batch_size_retry: 536 | break 537 | batch_size = decay_batch_step(batch_size) 538 | _logger.warning(f'Reducing batch size to {batch_size} for retry.') 539 | results['error'] = error_str 540 | return results 541 | 542 | 543 | def benchmark(args): 544 | if args.amp: 545 | _logger.warning("Overriding precision to 'amp' since --amp flag set.") 546 | args.precision = 'amp' 547 | _logger.info(f'Benchmarking in {args.precision} precision. ' 548 | f'{"NHWC" if args.channels_last else "NCHW"} layout. 
' 549 | f'torchscript {"enabled" if args.torchscript else "disabled"}') 550 | 551 | bench_kwargs = vars(args).copy() 552 | bench_kwargs.pop('amp') 553 | model = bench_kwargs.pop('model') 554 | batch_size = bench_kwargs.pop('batch_size') 555 | 556 | bench_fns = (InferenceBenchmarkRunner,) 557 | prefixes = ('infer',) 558 | if args.bench == 'both': 559 | bench_fns = ( 560 | InferenceBenchmarkRunner, 561 | TrainBenchmarkRunner 562 | ) 563 | prefixes = ('infer', 'train') 564 | elif args.bench == 'train': 565 | bench_fns = TrainBenchmarkRunner, 566 | prefixes = 'train', 567 | elif args.bench.startswith('profile'): 568 | # specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore 569 | if 'deepspeed' in args.bench: 570 | assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter" 571 | bench_kwargs['profiler'] = 'deepspeed' 572 | elif 'fvcore' in args.bench: 573 | assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter" 574 | bench_kwargs['profiler'] = 'fvcore' 575 | bench_fns = ProfileRunner, 576 | batch_size = 1 577 | 578 | model_results = OrderedDict(model=model) 579 | for prefix, bench_fn in zip(prefixes, bench_fns): 580 | run_results = _try_run( 581 | model, 582 | bench_fn, 583 | bench_kwargs=bench_kwargs, 584 | initial_batch_size=batch_size, 585 | no_batch_size_retry=args.no_retry, 586 | ) 587 | if prefix and 'error' not in run_results: 588 | run_results = {'_'.join([prefix, k]): v for k, v in run_results.items()} 589 | model_results.update(run_results) 590 | if 'error' in run_results: 591 | break 592 | if 'error' not in model_results: 593 | param_count = model_results.pop('infer_param_count', model_results.pop('train_param_count', 0)) 594 | model_results.setdefault('param_count', param_count) 595 | model_results.pop('train_param_count', 0) 596 | return model_results 597 | 598 | 599 | def main(): 600 | setup_default_logging() 601 | args = parser.parse_args() 602 | model_cfgs = [] 603 | model_names = [] 604 | 605 | if args.fast_norm: 606 | set_fast_norm() 607 | 608 | if args.model_list: 609 | args.model = '' 610 | with open(args.model_list) as f: 611 | model_names = [line.rstrip() for line in f] 612 | model_cfgs = [(n, None) for n in model_names] 613 | elif args.model == 'all': 614 | # validate all models in a list of names with pretrained checkpoints 615 | args.pretrained = True 616 | model_names = list_models(pretrained=True, exclude_filters=['*in21k']) 617 | model_cfgs = [(n, None) for n in model_names] 618 | elif not is_model(args.model): 619 | # model name doesn't exist, try as wildcard filter 620 | model_names = list_models(args.model) 621 | model_cfgs = [(n, None) for n in model_names] 622 | 623 | if len(model_cfgs): 624 | results_file = args.results_file or './benchmark.csv' 625 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) 626 | results = [] 627 | try: 628 | for m, _ in model_cfgs: 629 | if not m: 630 | continue 631 | args.model = m 632 | r = benchmark(args) 633 | if r: 634 | results.append(r) 635 | time.sleep(10) 636 | except KeyboardInterrupt as e: 637 | pass 638 | sort_key = 'infer_samples_per_sec' 639 | if 'train' in args.bench: 640 | sort_key = 'train_samples_per_sec' 641 | elif 'profile' in args.bench: 642 | sort_key = 'infer_gmacs' 643 | results = filter(lambda x: sort_key in x, results) 644 | results = sorted(results, key=lambda x: x[sort_key], reverse=True) 645 | if len(results): 646 | 
write_results(results_file, results) 647 | else: 648 | results = benchmark(args) 649 | 650 | # output results in JSON to stdout w/ delimiter for runner script 651 | print(f'--result\n{json.dumps(results, indent=4)}') 652 | 653 | 654 | def write_results(results_file, results): 655 | with open(results_file, mode='w') as cf: 656 | dw = csv.DictWriter(cf, fieldnames=results[0].keys()) 657 | dw.writeheader() 658 | for r in results: 659 | dw.writerow(r) 660 | cf.flush() 661 | 662 | 663 | if __name__ == '__main__': 664 | main() -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:22.08-py3 2 | FROM ${FROM_IMAGE_NAME} 3 | 4 | RUN pip install timm==0.6.11 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # from .poolformer import * 2 | from .inceptionnext import * 3 | from .convnext import * -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from ConvNeXt official repo: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py 3 | """ 4 | 5 | # Copyright (c) Meta Platforms, Inc. and affiliates. 6 | 7 | # All rights reserved. 8 | 9 | # This source code is licensed under the license found in the 10 | # LICENSE file in the root directory of this source tree. 11 | 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from timm.models.layers import trunc_normal_, DropPath 17 | from timm.models.registry import register_model 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | from functools import partial 20 | 21 | 22 | class PartialConv2d(nn.Module): 23 | r""" 24 | Conduct convolution on partial channels. 25 | """ 26 | def __init__(self, in_channels, out_channels, kernel_size, 27 | conv_ratio=1.0, 28 | stride=1, padding=0, dilation=1, groups=1, bias=True, **kwargs, 29 | ): 30 | super().__init__() 31 | in_chs = int(in_channels * conv_ratio) 32 | out_chs = int(out_channels * conv_ratio) 33 | gps = int(groups * conv_ratio) or 1 # groups should be at least 1 34 | self.conv = nn.Conv2d(in_chs, out_chs, 35 | kernel_size=kernel_size, 36 | stride=stride, padding=padding, dilation=dilation, 37 | groups=gps, bias=bias, 38 | **kwargs, 39 | ) 40 | self.split_indices = (in_channels - in_chs, in_chs) 41 | 42 | def forward(self, x): 43 | identity, conv = torch.split(x, self.split_indices, dim=1) 44 | return torch.cat( 45 | (identity, self.conv(conv)), 46 | dim=1, 47 | ) 48 | 49 | 50 | class Block(nn.Module): 51 | r""" ConvNeXt Block. 
There are two equivalent implementations: 52 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 53 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 54 | We use (2) as we find it slightly faster in PyTorch 55 | 56 | Args: 57 | dim (int): Number of input channels. 58 | drop_path (float): Stochastic depth rate. Default: 0.0 59 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 60 | """ 61 | def __init__(self, dim, kernel_size=7, 62 | drop_path=0., layer_scale_init_value=1e-6, 63 | conv_fn=nn.Conv2d, 64 | ): 65 | super().__init__() 66 | self.dwconv = conv_fn(dim, dim, kernel_size=kernel_size, padding=kernel_size//2, groups=dim) # depthwise conv 67 | self.norm = LayerNorm(dim, eps=1e-6) 68 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 69 | self.act = nn.GELU() 70 | self.pwconv2 = nn.Linear(4 * dim, dim) 71 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 72 | requires_grad=True) if layer_scale_init_value > 0 else None 73 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 74 | 75 | def forward(self, x): 76 | input = x 77 | x = self.dwconv(x) 78 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 79 | x = self.norm(x) 80 | x = self.pwconv1(x) 81 | x = self.act(x) 82 | x = self.pwconv2(x) 83 | if self.gamma is not None: 84 | x = self.gamma * x 85 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 86 | 87 | x = input + self.drop_path(x) 88 | return x 89 | 90 | 91 | class ConvNeXt(nn.Module): 92 | r""" ConvNeXt 93 | A PyTorch impl of : `A ConvNet for the 2020s` - 94 | https://arxiv.org/pdf/2201.03545.pdf 95 | 96 | Args: 97 | in_chans (int): Number of input image channels. Default: 3 98 | num_classes (int): Number of classes for classification head. Default: 1000 99 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 100 | dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] 101 | drop_path_rate (float): Stochastic depth rate. Default: 0. 102 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 103 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 
104 | """ 105 | def __init__(self, in_chans=3, num_classes=1000, 106 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 107 | layer_scale_init_value=1e-6, head_init_scale=1., 108 | kernel_sizes=7, conv_fns=nn.Conv2d, 109 | **kwargs, 110 | ): 111 | super().__init__() 112 | 113 | num_stages = len(depths) 114 | self.num_stages = num_stages 115 | 116 | if not isinstance(kernel_sizes, (list, tuple)): 117 | kernel_sizes = [kernel_sizes] * num_stages 118 | if not isinstance(conv_fns, (list, tuple)): 119 | conv_fns = [conv_fns] * num_stages 120 | 121 | self.num_classes = num_classes 122 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 123 | stem = nn.Sequential( 124 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 125 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 126 | ) 127 | self.downsample_layers.append(stem) 128 | for i in range(self.num_stages - 1): 129 | downsample_layer = nn.Sequential( 130 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 131 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 132 | ) 133 | self.downsample_layers.append(downsample_layer) 134 | 135 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 136 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 137 | cur = 0 138 | for i in range(self.num_stages): 139 | stage = nn.Sequential( 140 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 141 | kernel_size=kernel_sizes[i], 142 | layer_scale_init_value=layer_scale_init_value, 143 | conv_fn=conv_fns[i], 144 | ) for j in range(depths[i])] 145 | ) 146 | self.stages.append(stage) 147 | cur += depths[i] 148 | 149 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 150 | self.head = nn.Linear(dims[-1], num_classes) 151 | 152 | self.apply(self._init_weights) 153 | self.head.weight.data.mul_(head_init_scale) 154 | self.head.bias.data.mul_(head_init_scale) 155 | 156 | def _init_weights(self, m): 157 | if isinstance(m, (nn.Conv2d, nn.Linear)): 158 | trunc_normal_(m.weight, std=.02) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | def forward_features(self, x): 162 | for i in range(self.num_stages): 163 | x = self.downsample_layers[i](x) 164 | x = self.stages[i](x) 165 | return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C) 166 | 167 | def forward(self, x): 168 | x = self.forward_features(x) 169 | x = self.head(x) 170 | return x 171 | 172 | 173 | class LayerNorm(nn.Module): 174 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 175 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 176 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 177 | with shape (batch_size, channels, height, width). 
178 | """ 179 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 180 | super().__init__() 181 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 182 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 183 | self.eps = eps 184 | self.data_format = data_format 185 | if self.data_format not in ["channels_last", "channels_first"]: 186 | raise NotImplementedError 187 | self.normalized_shape = (normalized_shape, ) 188 | 189 | def forward(self, x): 190 | if self.data_format == "channels_last": 191 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 192 | elif self.data_format == "channels_first": 193 | u = x.mean(1, keepdim=True) 194 | s = (x - u).pow(2).mean(1, keepdim=True) 195 | x = (x - u) / torch.sqrt(s + self.eps) 196 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 197 | return x 198 | 199 | 200 | def _cfg(url='', **kwargs): 201 | return { 202 | 'url': url, 203 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 204 | 'crop_pct': 0.875, 'interpolation': 'bicubic', 205 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 206 | 'first_conv': 'stem.0', 'classifier': 'head.fc', 207 | **kwargs 208 | } 209 | 210 | 211 | model_urls = { 212 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 213 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 214 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 215 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 216 | "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", 217 | "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", 218 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 219 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 220 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 221 | 222 | # add by this InceptionNeXt repo 223 | "convnext_tiny_k5_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k5_1k_224_ema.pth", 224 | "convnext_tiny_k3_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_1k_224_ema.pth", 225 | "convnext_tiny_k3_par1_2_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_2_1k_224_ema.pth", 226 | "convnext_tiny_k3_par3_8_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par3_8_1k_224_ema.pth", 227 | "convnext_tiny_k3_par1_4_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_4_1k_224_ema.pth", 228 | "convnext_tiny_k3_par1_8_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_8_1k_224_ema.pth", 229 | "convnext_tiny_k3_par1_16_1k": "https://github.com/sail-sg/inceptionnext/releases/download/model/convnext_tiny_k3_par1_16_1k_224_ema.pth", 230 | 231 | } 232 | 233 | 234 | @register_model 235 | def convnext_tiny(pretrained=False,in_22k=False, **kwargs): 236 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 237 | model.default_cfg = _cfg() 238 | if pretrained: 239 | url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] 240 | checkpoint = 
torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 241 | model.load_state_dict(checkpoint["model"]) 242 | return model 243 | 244 | @register_model 245 | def convnext_tiny_k5(pretrained=False,in_22k=False, **kwargs): 246 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 247 | kernel_sizes=5, 248 | **kwargs) 249 | assert not in_22k, "22k pre-trained model not available" 250 | model.default_cfg = _cfg() 251 | if pretrained: 252 | url = model_urls['convnext_tiny_k5_1k'] 253 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 254 | model.load_state_dict(checkpoint) 255 | return model 256 | 257 | @register_model 258 | def convnext_tiny_k3(pretrained=False,in_22k=False, **kwargs): 259 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 260 | kernel_sizes=3, 261 | **kwargs) 262 | assert not in_22k, "22k pre-trained model not available" 263 | model.default_cfg = _cfg() 264 | if pretrained: 265 | url = model_urls['convnext_tiny_k3_1k'] 266 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 267 | model.load_state_dict(checkpoint) 268 | return model 269 | 270 | @register_model 271 | def convnext_tiny_k3_par1_2(pretrained=False,in_22k=False, **kwargs): 272 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 273 | kernel_sizes=3, 274 | conv_fns=partial(PartialConv2d, conv_ratio=0.5), 275 | **kwargs) 276 | assert not in_22k, "22k pre-trained model not available" 277 | model.default_cfg = _cfg() 278 | if pretrained: 279 | url = model_urls['convnext_tiny_k3_par1_2_1k'] 280 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 281 | model.load_state_dict(checkpoint) 282 | return model 283 | 284 | @register_model 285 | def convnext_tiny_k3_par3_8(pretrained=False,in_22k=False, **kwargs): 286 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 287 | kernel_sizes=3, 288 | conv_fns=partial(PartialConv2d, conv_ratio=3/8), 289 | **kwargs) 290 | assert not in_22k, "22k pre-trained model not available" 291 | model.default_cfg = _cfg() 292 | if pretrained: 293 | url = model_urls['convnext_tiny_k3_par3_8_1k'] 294 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 295 | model.load_state_dict(checkpoint) 296 | return model 297 | 298 | @register_model 299 | def convnext_tiny_k3_par1_4(pretrained=False,in_22k=False, **kwargs): 300 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 301 | kernel_sizes=3, 302 | conv_fns=partial(PartialConv2d, conv_ratio=0.25), 303 | **kwargs) 304 | assert not in_22k, "22k pre-trained model not available" 305 | model.default_cfg = _cfg() 306 | if pretrained: 307 | url = model_urls['convnext_tiny_k3_par1_4_1k'] 308 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 309 | model.load_state_dict(checkpoint) 310 | return model 311 | 312 | @register_model 313 | def convnext_tiny_k3_par1_8(pretrained=False,in_22k=False, **kwargs): 314 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 315 | kernel_sizes=3, 316 | conv_fns=partial(PartialConv2d, conv_ratio=0.125), 317 | **kwargs) 318 | assert not in_22k, "22k pre-trained model not available" 319 | model.default_cfg = _cfg() 320 | if pretrained: 321 | url = model_urls['convnext_tiny_k3_par1_8_1k'] 322 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 323 | 
model.load_state_dict(checkpoint) 324 | return model 325 | 326 | @register_model 327 | def convnext_tiny_k3_par1_16(pretrained=False,in_22k=False, **kwargs): 328 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], 329 | kernel_sizes=3, 330 | conv_fns=partial(PartialConv2d, conv_ratio=1/16), 331 | **kwargs) 332 | assert not in_22k, "22k pre-trained model not available" 333 | model.default_cfg = _cfg() 334 | if pretrained: 335 | url = model_urls['convnext_tiny_k3_par1_16_1k'] 336 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 337 | model.load_state_dict(checkpoint) 338 | return model 339 | 340 | @register_model 341 | def convnext_small(pretrained=False,in_22k=False, **kwargs): 342 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 343 | model.default_cfg = _cfg() 344 | if pretrained: 345 | url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] 346 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 347 | model.load_state_dict(checkpoint["model"]) 348 | return model 349 | 350 | @register_model 351 | def convnext_base(pretrained=False, in_22k=False, **kwargs): 352 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 353 | model.default_cfg = _cfg() 354 | if pretrained: 355 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 356 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 357 | model.load_state_dict(checkpoint["model"]) 358 | return model 359 | 360 | @register_model 361 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 362 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 363 | model.default_cfg = _cfg() 364 | if pretrained: 365 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 366 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 367 | model.load_state_dict(checkpoint["model"]) 368 | return model 369 | 370 | @register_model 371 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 372 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 373 | model.default_cfg = _cfg() 374 | if pretrained: 375 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 376 | url = model_urls['convnext_xlarge_22k'] 377 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 378 | model.load_state_dict(checkpoint["model"]) 379 | return model -------------------------------------------------------------------------------- /models/inceptionnext.py: -------------------------------------------------------------------------------- 1 | """ 2 | InceptionNeXt implementation, paper: https://arxiv.org/abs/2303.16900 3 | 4 | Some code is borrowed from timm: https://github.com/huggingface/pytorch-image-models 5 | """ 6 | 7 | from functools import partial 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.models.helpers import checkpoint_seq 14 | from timm.models.layers import trunc_normal_, DropPath 15 | from timm.models.registry import register_model 16 | from timm.models.layers.helpers import to_2tuple 17 | 18 | 19 | class InceptionDWConv2d(nn.Module): 20 | """ Inception depthweise convolution 21 | """ 22 | def __init__(self, in_channels, square_kernel_size=3, band_kernel_size=11, 
branch_ratio=0.125): 23 | super().__init__() 24 | 25 | gc = int(in_channels * branch_ratio) # channel numbers of a convolution branch 26 | self.dwconv_hw = nn.Conv2d(gc, gc, square_kernel_size, padding=square_kernel_size//2, groups=gc) 27 | self.dwconv_w = nn.Conv2d(gc, gc, kernel_size=(1, band_kernel_size), padding=(0, band_kernel_size//2), groups=gc) 28 | self.dwconv_h = nn.Conv2d(gc, gc, kernel_size=(band_kernel_size, 1), padding=(band_kernel_size//2, 0), groups=gc) 29 | self.split_indexes = (in_channels - 3 * gc, gc, gc, gc) 30 | 31 | def forward(self, x): 32 | x_id, x_hw, x_w, x_h = torch.split(x, self.split_indexes, dim=1) 33 | return torch.cat( 34 | (x_id, self.dwconv_hw(x_hw), self.dwconv_w(x_w), self.dwconv_h(x_h)), 35 | dim=1, 36 | ) 37 | 38 | 39 | class ConvMlp(nn.Module): 40 | """ MLP using 1x1 convs that keeps spatial dims 41 | copied from timm: https://github.com/huggingface/pytorch-image-models/blob/v0.6.11/timm/models/layers/mlp.py 42 | """ 43 | def __init__( 44 | self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, 45 | norm_layer=None, bias=True, drop=0.): 46 | super().__init__() 47 | out_features = out_features or in_features 48 | hidden_features = hidden_features or in_features 49 | bias = to_2tuple(bias) 50 | 51 | self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) 52 | self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() 53 | self.act = act_layer() 54 | self.drop = nn.Dropout(drop) 55 | self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) 56 | 57 | def forward(self, x): 58 | x = self.fc1(x) 59 | x = self.norm(x) 60 | x = self.act(x) 61 | x = self.drop(x) 62 | x = self.fc2(x) 63 | return x 64 | 65 | 66 | class MlpHead(nn.Module): 67 | """ MLP classification head 68 | """ 69 | def __init__(self, dim, num_classes=1000, mlp_ratio=3, act_layer=nn.GELU, 70 | norm_layer=partial(nn.LayerNorm, eps=1e-6), drop=0., bias=True): 71 | super().__init__() 72 | hidden_features = int(mlp_ratio * dim) 73 | self.fc1 = nn.Linear(dim, hidden_features, bias=bias) 74 | self.act = act_layer() 75 | self.norm = norm_layer(hidden_features) 76 | self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) 77 | self.drop = nn.Dropout(drop) 78 | 79 | def forward(self, x): 80 | x = x.mean((2, 3)) # global average pooling 81 | x = self.fc1(x) 82 | x = self.act(x) 83 | x = self.norm(x) 84 | x = self.drop(x) 85 | x = self.fc2(x) 86 | return x 87 | 88 | 89 | class MetaNeXtBlock(nn.Module): 90 | """ MetaNeXtBlock Block 91 | Args: 92 | dim (int): Number of input channels. 93 | drop_path (float): Stochastic depth rate. Default: 0.0 94 | ls_init_value (float): Init value for Layer Scale. Default: 1e-6. 95 | """ 96 | 97 | def __init__( 98 | self, 99 | dim, 100 | token_mixer=nn.Identity, 101 | norm_layer=nn.BatchNorm2d, 102 | mlp_layer=ConvMlp, 103 | mlp_ratio=4, 104 | act_layer=nn.GELU, 105 | ls_init_value=1e-6, 106 | drop_path=0., 107 | 108 | ): 109 | super().__init__() 110 | self.token_mixer = token_mixer(dim) 111 | self.norm = norm_layer(dim) 112 | self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=act_layer) 113 | self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value else None 114 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 115 | 116 | def forward(self, x): 117 | shortcut = x 118 | x = self.token_mixer(x) 119 | x = self.norm(x) 120 | x = self.mlp(x) 121 | if self.gamma is not None: 122 | x = x.mul(self.gamma.reshape(1, -1, 1, 1)) 123 | x = self.drop_path(x) + shortcut 124 | return x 125 | 126 | 127 | class MetaNeXtStage(nn.Module): 128 | def __init__( 129 | self, 130 | in_chs, 131 | out_chs, 132 | ds_stride=2, 133 | depth=2, 134 | drop_path_rates=None, 135 | ls_init_value=1.0, 136 | token_mixer=nn.Identity, 137 | act_layer=nn.GELU, 138 | norm_layer=None, 139 | mlp_ratio=4, 140 | ): 141 | super().__init__() 142 | self.grad_checkpointing = False 143 | if ds_stride > 1: 144 | self.downsample = nn.Sequential( 145 | norm_layer(in_chs), 146 | nn.Conv2d(in_chs, out_chs, kernel_size=ds_stride, stride=ds_stride), 147 | ) 148 | else: 149 | self.downsample = nn.Identity() 150 | 151 | drop_path_rates = drop_path_rates or [0.] * depth 152 | stage_blocks = [] 153 | for i in range(depth): 154 | stage_blocks.append(MetaNeXtBlock( 155 | dim=out_chs, 156 | drop_path=drop_path_rates[i], 157 | ls_init_value=ls_init_value, 158 | token_mixer=token_mixer, 159 | act_layer=act_layer, 160 | norm_layer=norm_layer, 161 | mlp_ratio=mlp_ratio, 162 | )) 163 | in_chs = out_chs 164 | self.blocks = nn.Sequential(*stage_blocks) 165 | 166 | def forward(self, x): 167 | x = self.downsample(x) 168 | if self.grad_checkpointing and not torch.jit.is_scripting(): 169 | x = checkpoint_seq(self.blocks, x) 170 | else: 171 | x = self.blocks(x) 172 | return x 173 | 174 | 175 | class MetaNeXt(nn.Module): 176 | r""" MetaNeXt 177 | A PyTorch impl of : `InceptionNeXt: When Inception Meets ConvNeXt` - https://arxiv.org/abs/2303.16900 178 | 179 | Args: 180 | in_chans (int): Number of input image channels. Default: 3 181 | num_classes (int): Number of classes for classification head. Default: 1000 182 | depths (tuple(int)): Number of blocks at each stage. Default: (3, 3, 9, 3) 183 | dims (tuple(int)): Feature dimension at each stage. Default: (96, 192, 384, 768) 184 | token_mixers: Token mixer function. Default: nn.Identity 185 | norm_layer: Normalization layer. Default: nn.BatchNorm2d 186 | act_layer: Activation function for MLP. Default: nn.GELU 187 | mlp_ratios (int or tuple(int)): MLP ratios. Default: (4, 4, 4, 3) 188 | head_fn: classifier head 189 | drop_rate (float): Head dropout rate 190 | drop_path_rate (float): Stochastic depth rate. Default: 0. 191 | ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
192 | """ 193 | 194 | def __init__( 195 | self, 196 | in_chans=3, 197 | num_classes=1000, 198 | depths=(3, 3, 9, 3), 199 | dims=(96, 192, 384, 768), 200 | token_mixers=nn.Identity, 201 | norm_layer=nn.BatchNorm2d, 202 | act_layer=nn.GELU, 203 | mlp_ratios=(4, 4, 4, 3), 204 | head_fn=MlpHead, 205 | drop_rate=0., 206 | drop_path_rate=0., 207 | ls_init_value=1e-6, 208 | **kwargs, 209 | ): 210 | super().__init__() 211 | 212 | num_stage = len(depths) 213 | if not isinstance(token_mixers, (list, tuple)): 214 | token_mixers = [token_mixers] * num_stage 215 | if not isinstance(mlp_ratios, (list, tuple)): 216 | mlp_ratios = [mlp_ratios] * num_stage 217 | 218 | 219 | self.num_classes = num_classes 220 | self.drop_rate = drop_rate 221 | self.stem = nn.Sequential( 222 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 223 | norm_layer(dims[0]) 224 | ) 225 | 226 | self.stages = nn.Sequential() 227 | dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] 228 | stages = [] 229 | prev_chs = dims[0] 230 | # feature resolution stages, each consisting of multiple residual blocks 231 | for i in range(num_stage): 232 | out_chs = dims[i] 233 | stages.append(MetaNeXtStage( 234 | prev_chs, 235 | out_chs, 236 | ds_stride=2 if i > 0 else 1, 237 | depth=depths[i], 238 | drop_path_rates=dp_rates[i], 239 | ls_init_value=ls_init_value, 240 | act_layer=act_layer, 241 | token_mixer=token_mixers[i], 242 | norm_layer=norm_layer, 243 | mlp_ratio=mlp_ratios[i], 244 | )) 245 | prev_chs = out_chs 246 | self.stages = nn.Sequential(*stages) 247 | self.num_features = prev_chs 248 | self.head = head_fn(self.num_features, num_classes, drop=drop_rate) 249 | self.apply(self._init_weights) 250 | 251 | @torch.jit.ignore 252 | def set_grad_checkpointing(self, enable=True): 253 | for s in self.stages: 254 | s.grad_checkpointing = enable 255 | 256 | @torch.jit.ignore 257 | def no_weight_decay(self): 258 | return {'norm'} 259 | 260 | 261 | def forward_features(self, x): 262 | x = self.stem(x) 263 | x = self.stages(x) 264 | return x 265 | 266 | def forward_head(self, x): 267 | x = self.head(x) 268 | return x 269 | 270 | def forward(self, x): 271 | x = self.forward_features(x) 272 | x = self.forward_head(x) 273 | return x 274 | 275 | 276 | def _init_weights(self, m): 277 | if isinstance(m, (nn.Conv2d, nn.Linear)): 278 | trunc_normal_(m.weight, std=.02) 279 | if m.bias is not None: 280 | nn.init.constant_(m.bias, 0) 281 | 282 | 283 | 284 | def _cfg(url='', **kwargs): 285 | return { 286 | 'url': url, 287 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 288 | 'crop_pct': 0.875, 'interpolation': 'bicubic', 289 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 290 | 'first_conv': 'stem.0', 'classifier': 'head.fc', 291 | **kwargs 292 | } 293 | 294 | 295 | default_cfgs = dict( 296 | inceptionnext_atto=_cfg( 297 | url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_atto.pth', 298 | ), 299 | inceptionnext_tiny=_cfg( 300 | url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_tiny.pth', 301 | ), 302 | inceptionnext_small=_cfg( 303 | url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_small.pth', 304 | ), 305 | inceptionnext_base=_cfg( 306 | url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base.pth', 307 | ), 308 | inceptionnext_base_384=_cfg( 309 | url='https://github.com/sail-sg/inceptionnext/releases/download/model/inceptionnext_base_384.pth', 
310 | input_size=(3, 384, 384), crop_pct=1.0, 311 | ), 312 | ) 313 | 314 | @register_model 315 | def inceptionnext_atto(pretrained=False, **kwargs): 316 | model = MetaNeXt(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), 317 | token_mixers=partial(InceptionDWConv2d, band_kernel_size=9, branch_ratio=0.25), 318 | **kwargs 319 | ) 320 | model.default_cfg = default_cfgs['inceptionnext_atto'] 321 | if pretrained: 322 | state_dict = torch.hub.load_state_dict_from_url( 323 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 324 | model.load_state_dict(state_dict) 325 | return model 326 | 327 | @register_model 328 | def inceptionnext_tiny(pretrained=False, **kwargs): 329 | model = MetaNeXt(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), 330 | token_mixers=InceptionDWConv2d, 331 | **kwargs 332 | ) 333 | model.default_cfg = default_cfgs['inceptionnext_tiny'] 334 | if pretrained: 335 | state_dict = torch.hub.load_state_dict_from_url( 336 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 337 | model.load_state_dict(state_dict) 338 | return model 339 | 340 | @register_model 341 | def inceptionnext_small(pretrained=False, **kwargs): 342 | model = MetaNeXt(depths=(3, 3, 27, 3), dims=(96, 192, 384, 768), 343 | token_mixers=InceptionDWConv2d, 344 | **kwargs 345 | ) 346 | model.default_cfg = default_cfgs['inceptionnext_small'] 347 | if pretrained: 348 | state_dict = torch.hub.load_state_dict_from_url( 349 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 350 | model.load_state_dict(state_dict) 351 | return model 352 | 353 | @register_model 354 | def inceptionnext_base(pretrained=False, **kwargs): 355 | model = MetaNeXt(depths=(3, 3, 27, 3), dims=(128, 256, 512, 1024), 356 | token_mixers=InceptionDWConv2d, 357 | **kwargs 358 | ) 359 | model.default_cfg = default_cfgs['inceptionnext_base'] 360 | if pretrained: 361 | state_dict = torch.hub.load_state_dict_from_url( 362 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 363 | model.load_state_dict(state_dict) 364 | return model 365 | 366 | @register_model 367 | def inceptionnext_base_384(pretrained=False, **kwargs): 368 | model = MetaNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], 369 | mlp_ratios=[4, 4, 4, 3], 370 | token_mixers=InceptionDWConv2d, 371 | **kwargs 372 | ) 373 | model.default_cfg = default_cfgs['inceptionnext_base_384'] 374 | if pretrained: 375 | state_dict = torch.hub.load_state_dict_from_url( 376 | url=model.default_cfg['url'], map_location="cpu", check_hash=True) 377 | model.load_state_dict(state_dict) 378 | return model 379 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k3_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 
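# Editorial note: with the defaults above, the line below resolves to a per-GPU
# batch of 4096 / 8 / 4 = 128. If NUM_GPU or GRAD_ACCUM_STEPS is changed, keep
# ALL_BATCH_SIZE = BATCH_SIZE * NUM_GPU * GRAD_ACCUM_STEPS so that the effective
# batch size paired with the 4e-3 learning rate stays at 4096.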
8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k3 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k3_par1_16_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k3_par1_16 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k3_par1_2_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k3_par1_2 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k3_par1_4_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k3_par1_4 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k3_par1_8_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size.
8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k3_par1_8 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k3_par3_8_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k3_par3_8 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/convnext_variants/train_convnext_tiny_k5_ema.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=convnext_tiny_k5 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --model-ema --model-ema-decay 0.9999 20 | -------------------------------------------------------------------------------- /scripts/inceptionnext/finetune_inceptionnext_base.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | INIT_CKPT=/path/to/trained/model.pth 4 | 5 | 6 | ALL_BATCH_SIZE=1024 7 | NUM_GPU=8 8 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size.
9 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 10 | 11 | 12 | MODEL=inceptionnext_base_384 13 | DROP_PATH=0.7 14 | DROP=0.5 15 | 16 | 17 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 18 | --model $MODEL --img-size 384 --epochs 30 --opt adamw --lr 5e-5 --sched None \ 19 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 20 | --initial-checkpoint $INIT_CKPT \ 21 | --mixup 0 --cutmix 0 \ 22 | --model-ema --model-ema-decay 0.9999 \ 23 | --drop-path $DROP_PATH --drop $DROP -------------------------------------------------------------------------------- /scripts/inceptionnext/train_inceptionnext_atto.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=1280 6 | NUM_GPU=4 7 | GRAD_ACCUM_STEPS=1 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=inceptionnext_atto 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 1e-3 --warmup-epochs 5 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH \ 19 | --aa rand-m5-inc1-mstd101 \ 20 | --color-jitter 0 \ 21 | --epochs 600 \ 22 | --mixup 0.2 \ 23 | --reprob 0.1 \ 24 | --warmup-lr 5e-7 --min-lr 5e-7 25 | -------------------------------------------------------------------------------- /scripts/inceptionnext/train_inceptionnext_base.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=inceptionnext_base 12 | DROP_PATH=0.4 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH -------------------------------------------------------------------------------- /scripts/inceptionnext/train_inceptionnext_small.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. 8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=inceptionnext_small 12 | DROP_PATH=0.3 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH -------------------------------------------------------------------------------- /scripts/inceptionnext/train_inceptionnext_tiny.sh: -------------------------------------------------------------------------------- 1 | DATA_PATH=/path/to/imagenet 2 | CODE_PATH=/path/to/code/inceptionnext # modify code path here 3 | 4 | 5 | ALL_BATCH_SIZE=4096 6 | NUM_GPU=8 7 | GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size.
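# Editorial note: besides the model, learning rate, warmup, batch size, and
# drop-path set in this script, the recipe relies on train.py's DeiT-style
# defaults (300 epochs, cosine schedule, weight decay 0.05, label smoothing 0.1,
# rand-m9-mstd0.5-inc1 RandAugment, mixup 0.8 / cutmix 1.0, random erasing 0.25).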
8 | let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS 9 | 10 | 11 | MODEL=inceptionnext_tiny 12 | DROP_PATH=0.1 13 | 14 | 15 | cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ 16 | --model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ 17 | -b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ 18 | --drop-path $DROP_PATH -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | r""" 3 | This script is mostly copied from https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/train.py 4 | and make some modifications: 5 | 1) enable the gradient accumulation (`--grad-accum-steps`) 6 | 2) add `--head-dropout` for ConvFormer and CAFormer with MLP head 7 | 3) Set some default values of hyper-parameters following DeiT: 8 | -j 8 \ 9 | --opt adamw \ 10 | --epochs 300 \ 11 | --sched cosine \ 12 | --warmup-epochs 5 \ 13 | --warmup-lr 1e-6 \ 14 | --min-lr 1e-5 \ 15 | --weight-decay 0.05 \ 16 | --smoothing 0.1 \ 17 | --aa rand-m9-mstd0.5-inc1 \ 18 | --mixup 0.8 \ 19 | --cutmix 1.0 \ 20 | --remode pixel \ 21 | --reprob 0.25 \ 22 | """ 23 | 24 | """ ImageNet Training Script 25 | 26 | This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet 27 | training results with some of the latest networks and training techniques. It favours canonical PyTorch 28 | and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed 29 | and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. 30 | 31 | This script was started from an early version of the PyTorch ImageNet example 32 | (https://github.com/pytorch/examples/tree/master/imagenet) 33 | 34 | NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples 35 | (https://github.com/NVIDIA/apex/tree/master/examples/imagenet) 36 | 37 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 38 | """ 39 | import argparse 40 | import logging 41 | import os 42 | import time 43 | from collections import OrderedDict 44 | from contextlib import suppress 45 | from datetime import datetime 46 | 47 | import torch 48 | import torch.nn as nn 49 | import torchvision.utils 50 | import yaml 51 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 52 | 53 | from timm import utils 54 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 55 | from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ 56 | LabelSmoothingCrossEntropy 57 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ 58 | convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm 59 | from timm.optim import create_optimizer_v2, optimizer_kwargs 60 | from timm.scheduler import create_scheduler 61 | # from timm.utils import ApexScaler, NativeScaler 62 | from utils import ApexScalerAccum as ApexScaler 63 | from utils import NativeScalerAccum as NativeScaler 64 | 65 | import models 66 | 67 | try: 68 | from apex import amp 69 | from apex.parallel import DistributedDataParallel as ApexDDP 70 | from apex.parallel import convert_syncbn_model 71 | has_apex = True 72 | except ImportError: 73 | has_apex = False 74 | 75 | has_native_amp = False 76 | try: 77 | if getattr(torch.cuda.amp, 'autocast') is not None: 78 | has_native_amp = True 79 | 
except AttributeError: 80 | pass 81 | 82 | try: 83 | import wandb 84 | has_wandb = True 85 | except ImportError: 86 | has_wandb = False 87 | 88 | try: 89 | from functorch.compile import memory_efficient_fusion 90 | has_functorch = True 91 | except ImportError as e: 92 | has_functorch = False 93 | 94 | 95 | torch.backends.cudnn.benchmark = True 96 | _logger = logging.getLogger('train') 97 | 98 | # The first arg parser parses out only the --config argument, this argument is used to 99 | # load a yaml file containing key-values that override the defaults for the main parser below 100 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 101 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 102 | help='YAML config file specifying default arguments') 103 | 104 | 105 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 106 | 107 | # Dataset parameters 108 | group = parser.add_argument_group('Dataset parameters') 109 | # Keep this argument outside of the dataset group because it is positional. 110 | parser.add_argument('data_dir', metavar='DIR', 111 | help='path to dataset') 112 | group.add_argument('--dataset', '-d', metavar='NAME', default='', 113 | help='dataset type (default: ImageFolder/ImageTar if empty)') 114 | group.add_argument('--train-split', metavar='NAME', default='train', 115 | help='dataset train split (default: train)') 116 | group.add_argument('--val-split', metavar='NAME', default='validation', 117 | help='dataset validation split (default: validation)') 118 | group.add_argument('--dataset-download', action='store_true', default=False, 119 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 120 | group.add_argument('--class-map', default='', type=str, metavar='FILENAME', 121 | help='path to class to idx mapping file (default: "")') 122 | 123 | # Model parameters 124 | group = parser.add_argument_group('Model parameters') 125 | group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', 126 | help='Name of model to train (default: "resnet50"') 127 | group.add_argument('--pretrained', action='store_true', default=False, 128 | help='Start with pretrained version of specified network (if avail)') 129 | group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 130 | help='Initialize model from this checkpoint (default: none)') 131 | group.add_argument('--resume', default='', type=str, metavar='PATH', 132 | help='Resume full model and optimizer state from checkpoint (default: none)') 133 | group.add_argument('--no-resume-opt', action='store_true', default=False, 134 | help='prevent resume of optimizer state when resuming model') 135 | group.add_argument('--num-classes', type=int, default=None, metavar='N', 136 | help='number of label classes (Model default if None)') 137 | group.add_argument('--gp', default=None, type=str, metavar='POOL', 138 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 139 | group.add_argument('--img-size', type=int, default=None, metavar='N', 140 | help='Image patch size (default: None => model default)') 141 | group.add_argument('--input-size', default=None, nargs=3, type=int, 142 | metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') 143 | group.add_argument('--crop-pct', default=None, type=float, 144 | metavar='N', help='Input image center crop percent (for validation only)') 145 | group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 146 | help='Override mean pixel value of dataset') 147 | group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 148 | help='Override std deviation of dataset') 149 | group.add_argument('--interpolation', default='', type=str, metavar='NAME', 150 | help='Image resize interpolation type (overrides model)') 151 | group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 152 | help='Input batch size for training (default: 128)') 153 | group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', 154 | help='Validation batch size override (default: None)') 155 | group.add_argument('--channels-last', action='store_true', default=False, 156 | help='Use channels_last memory layout') 157 | scripting_group = group.add_mutually_exclusive_group() 158 | scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', 159 | help='torch.jit.script the full model') 160 | scripting_group.add_argument('--aot-autograd', default=False, action='store_true', 161 | help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") 162 | group.add_argument('--fuser', default='', type=str, 163 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 164 | group.add_argument('--fast-norm', default=False, action='store_true', 165 | help='enable experimental fast-norm') 166 | group.add_argument('--grad-checkpointing', action='store_true', default=False, 167 | help='Enable gradient checkpointing through model blocks/stages') 168 | 169 | # Optimizer parameters 170 | group = parser.add_argument_group('Optimizer parameters') 171 | group.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 172 | help='Optimizer (default: "adamw"') 173 | group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 174 | help='Optimizer Epsilon (default: None, use opt default)') 175 | group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 176 | help='Optimizer Betas (default: None, use opt default)') 177 | group.add_argument('--momentum', type=float, default=0.9, metavar='M', 178 | help='Optimizer momentum (default: 0.9)') 179 | group.add_argument('--weight-decay', type=float, default=0.05, 180 | help='weight decay (default: 0.05)') 181 | group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 182 | help='Clip gradient norm (default: None, no clipping)') 183 | group.add_argument('--clip-mode', type=str, default='norm', 184 | help='Gradient clipping mode. 
One of ("norm", "value", "agc")') 185 | group.add_argument('--layer-decay', type=float, default=None, 186 | help='layer-wise learning rate decay (default: None)') 187 | 188 | # Learning rate schedule parameters 189 | group = parser.add_argument_group('Learning rate schedule parameters') 190 | group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 191 | help='LR scheduler (default: "cosine"') 192 | group.add_argument('--lr', type=float, default=0.05, metavar='LR', 193 | help='learning rate (default: 0.05)') 194 | group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 195 | help='learning rate noise on/off epoch percentages') 196 | group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 197 | help='learning rate noise limit percent (default: 0.67)') 198 | group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 199 | help='learning rate noise std-dev (default: 1.0)') 200 | group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 201 | help='learning rate cycle len multiplier (default: 1.0)') 202 | group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', 203 | help='amount to decay each learning rate cycle (default: 0.5)') 204 | group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 205 | help='learning rate cycle limit, cycles enabled if > 1') 206 | group.add_argument('--lr-k-decay', type=float, default=1.0, 207 | help='learning rate k-decay for cosine/poly (default: 1.0)') 208 | group.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 209 | help='warmup learning rate (default: 1e-6)') 210 | group.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 211 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 212 | group.add_argument('--epochs', type=int, default=300, metavar='N', 213 | help='number of epochs to train (default: 300)') 214 | parser.add_argument('--grad-accum-steps', default=1, type=int, 215 | help='gradient accumulation steps') 216 | group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 217 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') 218 | group.add_argument('--start-epoch', default=None, type=int, metavar='N', 219 | help='manual epoch number (useful on restarts)') 220 | group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", 221 | help='list of decay epoch indices for multistep lr. 
must be increasing') 222 | group.add_argument('--decay-epochs', type=float, default=100, metavar='N', 223 | help='epoch interval to decay LR') 224 | group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 225 | help='epochs to warmup LR, if scheduler supports') 226 | group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 227 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 228 | group.add_argument('--patience-epochs', type=int, default=10, metavar='N', 229 | help='patience epochs for Plateau LR scheduler (default: 10') 230 | group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 231 | help='LR decay rate (default: 0.1)') 232 | 233 | # Augmentation & regularization parameters 234 | group = parser.add_argument_group('Augmentation and regularization parameters') 235 | group.add_argument('--no-aug', action='store_true', default=False, 236 | help='Disable all training augmentation, override other train aug args') 237 | group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 238 | help='Random resize scale (default: 0.08 1.0)') 239 | group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', 240 | help='Random resize aspect ratio (default: 0.75 1.33)') 241 | group.add_argument('--hflip', type=float, default=0.5, 242 | help='Horizontal flip training aug probability') 243 | group.add_argument('--vflip', type=float, default=0., 244 | help='Vertical flip training aug probability') 245 | group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 246 | help='Color jitter factor (default: 0.4)') 247 | group.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 248 | help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)'), 249 | group.add_argument('--aug-repeats', type=float, default=0, 250 | help='Number of augmentation repetitions (distributed training only) (default: 0)') 251 | group.add_argument('--aug-splits', type=int, default=0, 252 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 253 | group.add_argument('--jsd-loss', action='store_true', default=False, 254 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 255 | group.add_argument('--bce-loss', action='store_true', default=False, 256 | help='Enable BCE loss w/ Mixup/CutMix use.') 257 | group.add_argument('--bce-target-thresh', type=float, default=None, 258 | help='Threshold for binarizing softened BCE targets (default: None, disabled)') 259 | group.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 260 | help='Random erase prob (default: 0.25)') 261 | group.add_argument('--remode', type=str, default='pixel', 262 | help='Random erase mode (default: "pixel")') 263 | group.add_argument('--recount', type=int, default=1, 264 | help='Random erase count (default: 1)') 265 | group.add_argument('--resplit', action='store_true', default=False, 266 | help='Do not random erase first (clean) augmentation split') 267 | group.add_argument('--mixup', type=float, default=0.8, 268 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 269 | group.add_argument('--cutmix', type=float, default=1.0, 270 | help='cutmix alpha, cutmix enabled if > 0. 
(default: 1.0)') 271 | group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 272 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 273 | group.add_argument('--mixup-prob', type=float, default=1.0, 274 | help='Probability of performing mixup or cutmix when either/both is enabled') 275 | group.add_argument('--mixup-switch-prob', type=float, default=0.5, 276 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 277 | group.add_argument('--mixup-mode', type=str, default='batch', 278 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 279 | group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 280 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 281 | group.add_argument('--smoothing', type=float, default=0.1, 282 | help='Label smoothing (default: 0.1)') 283 | group.add_argument('--train-interpolation', type=str, default='random', 284 | help='Training interpolation (random, bilinear, bicubic default: "random")') 285 | group.add_argument('--drop', type=float, default=0.0, metavar='PCT', 286 | help='Dropout rate (default: 0.)') 287 | group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 288 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 289 | group.add_argument('--drop-path', type=float, default=None, metavar='PCT', 290 | help='Drop path rate (default: None)') 291 | group.add_argument('--drop-block', type=float, default=None, metavar='PCT', 292 | help='Drop block rate (default: None)') 293 | group.add_argument('--head-dropout', type=float, default=0.0, metavar='PCT', 294 | help='dropout rate for classifier (default: 0.0)') 295 | 296 | # Batch norm parameters (only works with gen_efficientnet based models currently) 297 | group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') 298 | group.add_argument('--bn-momentum', type=float, default=None, 299 | help='BatchNorm momentum override (if not None)') 300 | group.add_argument('--bn-eps', type=float, default=None, 301 | help='BatchNorm epsilon override (if not None)') 302 | group.add_argument('--sync-bn', action='store_true', 303 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 304 | group.add_argument('--dist-bn', type=str, default='reduce', 305 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 306 | group.add_argument('--split-bn', action='store_true', 307 | help='Enable separate BN layers per augmentation split.') 308 | 309 | # Model Exponential Moving Average 310 | group = parser.add_argument_group('Model exponential moving average parameters') 311 | group.add_argument('--model-ema', action='store_true', default=False, 312 | help='Enable tracking moving average of model weights') 313 | group.add_argument('--model-ema-force-cpu', action='store_true', default=False, 314 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 315 | group.add_argument('--model-ema-decay', type=float, default=0.9998, 316 | help='decay factor for model weights moving average (default: 0.9998)') 317 | 318 | # Misc 319 | group = parser.add_argument_group('Miscellaneous parameters') 320 | group.add_argument('--seed', type=int, default=42, metavar='S', 321 | help='random seed (default: 42)') 322 | group.add_argument('--worker-seeding', type=str, default='all', 323 | help='worker seed mode (default: all)') 324 | group.add_argument('--log-interval', type=int, default=50, metavar='N', 325 | help='how many batches to wait before logging training status') 326 | group.add_argument('--recovery-interval', type=int, default=0, metavar='N', 327 | help='how many batches to wait before writing recovery checkpoint') 328 | group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', 329 | help='number of checkpoints to keep (default: 10)') 330 | group.add_argument('-j', '--workers', type=int, default=8, metavar='N', 331 | help='how many training processes to use (default: 8)') 332 | group.add_argument('--save-images', action='store_true', default=False, 333 | help='save images of input bathes every log interval for debugging') 334 | group.add_argument('--amp', action='store_true', default=False, 335 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 336 | group.add_argument('--apex-amp', action='store_true', default=False, 337 | help='Use NVIDIA Apex AMP mixed precision') 338 | group.add_argument('--native-amp', action='store_true', default=False, 339 | help='Use Native Torch AMP mixed precision') 340 | group.add_argument('--no-ddp-bb', action='store_true', default=False, 341 | help='Force broadcast buffers for native DDP to off.') 342 | group.add_argument('--pin-mem', action='store_true', default=False, 343 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 344 | group.add_argument('--no-prefetcher', action='store_true', default=False, 345 | help='disable fast prefetcher') 346 | group.add_argument('--output', default='', type=str, metavar='PATH', 347 | help='path to output folder (default: none, current dir)') 348 | group.add_argument('--experiment', default='', type=str, metavar='NAME', 349 | help='name of train experiment, name of sub-folder for output') 350 | group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 351 | help='Best metric (default: "top1"') 352 | group.add_argument('--tta', type=int, default=0, metavar='N', 353 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 354 | group.add_argument("--local_rank", default=0, type=int) 355 | group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 356 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 357 | group.add_argument('--log-wandb', action='store_true', default=False, 358 | help='log training and validation metrics to wandb') 359 | 360 | 361 | def _parse_args(): 362 | # Do we have a config file to parse? 363 | args_config, remaining = config_parser.parse_known_args() 364 | if args_config.config: 365 | with open(args_config.config, 'r') as f: 366 | cfg = yaml.safe_load(f) 367 | parser.set_defaults(**cfg) 368 | 369 | # The main arg parser parses the rest of the args, the usual 370 | # defaults will have been overridden if config file specified. 
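# Editorial illustration (not part of the original file): a YAML file passed via
# `--config cfg.yaml` containing, e.g.,
#     model: inceptionnext_tiny
#     drop_path: 0.1
#     lr: 4e-3
# is read with yaml.safe_load above and applied through parser.set_defaults(),
# so it only replaces the defaults; anything given explicitly on the command
# line is still parsed from `remaining` below and takes precedence.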
371 | args = parser.parse_args(remaining) 372 | 373 | # Cache the args as a text string to save them in the output dir later 374 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 375 | return args, args_text 376 | 377 | 378 | def main(): 379 | utils.setup_default_logging() 380 | args, args_text = _parse_args() 381 | 382 | args.prefetcher = not args.no_prefetcher 383 | args.distributed = False 384 | if 'WORLD_SIZE' in os.environ: 385 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 386 | args.device = 'cuda:0' 387 | args.world_size = 1 388 | args.rank = 0 # global rank 389 | if args.distributed: 390 | if 'LOCAL_RANK' in os.environ: 391 | args.local_rank = int(os.getenv('LOCAL_RANK')) 392 | args.device = 'cuda:%d' % args.local_rank 393 | torch.cuda.set_device(args.local_rank) 394 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 395 | args.world_size = torch.distributed.get_world_size() 396 | args.rank = torch.distributed.get_rank() 397 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 398 | % (args.rank, args.world_size)) 399 | else: 400 | _logger.info('Training with a single process on 1 GPUs.') 401 | assert args.rank >= 0 402 | 403 | if args.rank == 0 and args.log_wandb: 404 | if has_wandb: 405 | wandb.init(project=args.experiment, config=args) 406 | else: 407 | _logger.warning("You've requested to log metrics to wandb but package not found. " 408 | "Metrics not being logged to wandb, try `pip install wandb`") 409 | 410 | # resolve AMP arguments based on PyTorch / Apex availability 411 | use_amp = None 412 | if args.amp: 413 | # `--amp` chooses native amp before apex (APEX ver not actively maintained) 414 | if has_native_amp: 415 | args.native_amp = True 416 | elif has_apex: 417 | args.apex_amp = True 418 | if args.apex_amp and has_apex: 419 | use_amp = 'apex' 420 | elif args.native_amp and has_native_amp: 421 | use_amp = 'native' 422 | elif args.apex_amp or args.native_amp: 423 | _logger.warning("Neither APEX or native Torch AMP is available, using float32. " 424 | "Install NVIDA apex or upgrade to PyTorch 1.6") 425 | 426 | utils.random_seed(args.seed, args.rank) 427 | 428 | if args.fuser: 429 | utils.set_jit_fuser(args.fuser) 430 | if args.fast_norm: 431 | set_fast_norm() 432 | 433 | create_model_args = dict( 434 | model_name=args.model, 435 | pretrained=args.pretrained, 436 | num_classes=args.num_classes, 437 | drop_rate=args.drop, 438 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 439 | drop_path_rate=args.drop_path, 440 | drop_block_rate=args.drop_block, 441 | global_pool=args.gp, 442 | bn_momentum=args.bn_momentum, 443 | bn_eps=args.bn_eps, 444 | scriptable=args.torchscript, 445 | checkpoint_path=args.initial_checkpoint 446 | ) 447 | 448 | if 'convformer' in args.model or 'caformer' in args.model: 449 | create_model_args.update(head_dropout=args.head_dropout) 450 | 451 | model = create_model(**create_model_args) 452 | 453 | if args.num_classes is None: 454 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
455 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 456 | 457 | if args.grad_checkpointing: 458 | model.set_grad_checkpointing(enable=True) 459 | 460 | if args.local_rank == 0: 461 | _logger.info( 462 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') 463 | 464 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) 465 | 466 | # setup augmentation batch splits for contrastive loss or split bn 467 | num_aug_splits = 0 468 | if args.aug_splits > 0: 469 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 470 | num_aug_splits = args.aug_splits 471 | 472 | # enable split bn (separate bn stats per batch-portion) 473 | if args.split_bn: 474 | assert num_aug_splits > 1 or args.resplit 475 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 476 | 477 | # move model to GPU, enable channels last layout if set 478 | model.cuda() 479 | if args.channels_last: 480 | model = model.to(memory_format=torch.channels_last) 481 | 482 | # setup synchronized BatchNorm for distributed training 483 | if args.distributed and args.sync_bn: 484 | args.dist_bn = '' # disable dist_bn when sync BN active 485 | assert not args.split_bn 486 | if has_apex and use_amp == 'apex': 487 | # Apex SyncBN used with Apex AMP 488 | # WARNING this won't currently work with models using BatchNormAct2d 489 | model = convert_syncbn_model(model) 490 | else: 491 | model = convert_sync_batchnorm(model) 492 | if args.local_rank == 0: 493 | _logger.info( 494 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 495 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 496 | 497 | if args.torchscript: 498 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 499 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 500 | model = torch.jit.script(model) 501 | if args.aot_autograd: 502 | assert has_functorch, "functorch is needed for --aot-autograd" 503 | model = memory_efficient_fusion(model) 504 | 505 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) 506 | 507 | # setup automatic mixed-precision (AMP) loss scaling and op casting 508 | amp_autocast = suppress # do nothing 509 | loss_scaler = None 510 | if use_amp == 'apex': 511 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 512 | loss_scaler = ApexScaler() 513 | if args.local_rank == 0: 514 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 515 | elif use_amp == 'native': 516 | amp_autocast = torch.cuda.amp.autocast 517 | loss_scaler = NativeScaler() 518 | if args.local_rank == 0: 519 | _logger.info('Using native Torch AMP. Training in mixed precision.') 520 | else: 521 | if args.local_rank == 0: 522 | _logger.info('AMP not enabled. 
Training in float32.') 523 | 524 | 525 | # optionally resume from a checkpoint 526 | resume_epoch = None 527 | if args.resume: 528 | resume_epoch = resume_checkpoint( 529 | model, args.resume, 530 | optimizer=None if args.no_resume_opt else optimizer, 531 | loss_scaler=None if args.no_resume_opt else loss_scaler, 532 | log_info=args.local_rank == 0) 533 | 534 | # setup exponential moving average of model weights, SWA could be used here too 535 | model_ema = None 536 | if args.model_ema: 537 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper 538 | model_ema = utils.ModelEmaV2( 539 | model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) 540 | if args.resume: 541 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 542 | 543 | # setup distributed training 544 | if args.distributed: 545 | if has_apex and use_amp == 'apex': 546 | # Apex DDP preferred unless native amp is activated 547 | if args.local_rank == 0: 548 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 549 | model = ApexDDP(model, delay_allreduce=True) 550 | else: 551 | if args.local_rank == 0: 552 | _logger.info("Using native Torch DistributedDataParallel.") 553 | model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) 554 | # NOTE: EMA model does not need to be wrapped by DDP 555 | 556 | # setup learning rate schedule and starting epoch 557 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 558 | start_epoch = 0 559 | if args.start_epoch is not None: 560 | # a specified start_epoch will always override the resume epoch 561 | start_epoch = args.start_epoch 562 | elif resume_epoch is not None: 563 | start_epoch = resume_epoch 564 | if lr_scheduler is not None and start_epoch > 0: 565 | lr_scheduler.step(start_epoch) 566 | 567 | if args.local_rank == 0: 568 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 569 | 570 | # create the train and eval datasets 571 | dataset_train = create_dataset( 572 | args.dataset, root=args.data_dir, split=args.train_split, is_training=True, 573 | class_map=args.class_map, 574 | download=args.dataset_download, 575 | batch_size=args.batch_size, 576 | repeats=args.epoch_repeats) 577 | dataset_eval = create_dataset( 578 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False, 579 | class_map=args.class_map, 580 | download=args.dataset_download, 581 | batch_size=args.batch_size) 582 | 583 | total_batch_size = args.batch_size * args.grad_accum_steps * args.world_size 584 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size 585 | if args.local_rank == 0: 586 | _logger.info('Total batch size: {}'.format(total_batch_size)) 587 | 588 | # setup mixup / cutmix 589 | collate_fn = None 590 | mixup_fn = None 591 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 592 | if mixup_active: 593 | mixup_args = dict( 594 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 595 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 596 | label_smoothing=args.smoothing, num_classes=args.num_classes) 597 | if args.prefetcher: 598 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 599 | collate_fn = FastCollateMixup(**mixup_args) 600 | else: 601 | mixup_fn = Mixup(**mixup_args) 602 | 603 | # wrap dataset in AugMix helper 604 | if num_aug_splits > 1: 605 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 606 | 607 | # create data loaders w/ augmentation pipeiine 608 | train_interpolation = args.train_interpolation 609 | if args.no_aug or not train_interpolation: 610 | train_interpolation = data_config['interpolation'] 611 | loader_train = create_loader( 612 | dataset_train, 613 | input_size=data_config['input_size'], 614 | batch_size=args.batch_size, 615 | is_training=True, 616 | use_prefetcher=args.prefetcher, 617 | no_aug=args.no_aug, 618 | re_prob=args.reprob, 619 | re_mode=args.remode, 620 | re_count=args.recount, 621 | re_split=args.resplit, 622 | scale=args.scale, 623 | ratio=args.ratio, 624 | hflip=args.hflip, 625 | vflip=args.vflip, 626 | color_jitter=args.color_jitter, 627 | auto_augment=args.aa, 628 | num_aug_repeats=args.aug_repeats, 629 | num_aug_splits=num_aug_splits, 630 | interpolation=train_interpolation, 631 | mean=data_config['mean'], 632 | std=data_config['std'], 633 | num_workers=args.workers, 634 | distributed=args.distributed, 635 | collate_fn=collate_fn, 636 | pin_memory=args.pin_mem, 637 | use_multi_epochs_loader=args.use_multi_epochs_loader, 638 | worker_seeding=args.worker_seeding, 639 | ) 640 | 641 | loader_eval = create_loader( 642 | dataset_eval, 643 | input_size=data_config['input_size'], 644 | batch_size=args.validation_batch_size or args.batch_size, 645 | is_training=False, 646 | use_prefetcher=args.prefetcher, 647 | interpolation=data_config['interpolation'], 648 | mean=data_config['mean'], 649 | std=data_config['std'], 650 | num_workers=args.workers, 651 | distributed=args.distributed, 652 | crop_pct=data_config['crop_pct'], 653 | pin_memory=args.pin_mem, 654 | ) 655 | 656 | # setup loss function 657 | if args.jsd_loss: 658 | assert num_aug_splits > 1 # JSD only valid with aug splits set 659 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) 660 | elif mixup_active: 661 | # smoothing is handled with mixup target transform which outputs sparse, soft targets 662 | if args.bce_loss: 663 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) 664 | else: 665 | train_loss_fn = SoftTargetCrossEntropy() 666 | elif args.smoothing: 667 | if args.bce_loss: 668 | train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) 669 | else: 670 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 671 | else: 672 | train_loss_fn = nn.CrossEntropyLoss() 673 | train_loss_fn = train_loss_fn.cuda() 674 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 675 | 676 | # setup checkpoint saver and eval metric tracking 677 | eval_metric = args.eval_metric 678 | best_metric = None 679 | best_epoch = None 680 | saver = None 681 | output_dir = None 682 | if args.rank == 0: 683 | if args.experiment: 684 | exp_name = args.experiment 685 | else: 686 | exp_name = '-'.join([ 687 | 
datetime.now().strftime("%Y%m%d-%H%M%S"), 688 | safe_model_name(args.model), 689 | str(data_config['input_size'][-1]) 690 | ]) 691 | output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) 692 | decreasing = True if eval_metric == 'loss' else False 693 | saver = utils.CheckpointSaver( 694 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 695 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) 696 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 697 | f.write(args_text) 698 | 699 | try: 700 | for epoch in range(start_epoch, num_epochs): 701 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 702 | loader_train.sampler.set_epoch(epoch) 703 | 704 | train_metrics = train_one_epoch( 705 | epoch, model, loader_train, optimizer, train_loss_fn, args, 706 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, 707 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, 708 | grad_accum_steps=args.grad_accum_steps, num_training_steps_per_epoch=num_training_steps_per_epoch 709 | ) 710 | 711 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 712 | if args.local_rank == 0: 713 | _logger.info("Distributing BatchNorm running means and vars") 714 | utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 715 | 716 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 717 | 718 | if model_ema is not None and not args.model_ema_force_cpu: 719 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 720 | utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 721 | ema_eval_metrics = validate( 722 | model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') 723 | eval_metrics = ema_eval_metrics 724 | 725 | if lr_scheduler is not None: 726 | # step LR for next epoch 727 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 728 | 729 | if output_dir is not None: 730 | utils.update_summary( 731 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), 732 | write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) 733 | 734 | if saver is not None: 735 | # save proper checkpoint with eval metric 736 | save_metric = eval_metrics[eval_metric] 737 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) 738 | 739 | except KeyboardInterrupt: 740 | pass 741 | if best_metric is not None: 742 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 743 | 744 | 745 | def train_one_epoch( 746 | epoch, model, loader, optimizer, loss_fn, args, 747 | lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, 748 | loss_scaler=None, model_ema=None, mixup_fn=None, 749 | grad_accum_steps=1, num_training_steps_per_epoch=None): 750 | 751 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 752 | if args.prefetcher and loader.mixup_enabled: 753 | loader.mixup_enabled = False 754 | elif mixup_fn is not None: 755 | mixup_fn.mixup_enabled = False 756 | 757 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 758 | batch_time_m = utils.AverageMeter() 759 | data_time_m = utils.AverageMeter() 760 | losses_m = utils.AverageMeter() 761 | 762 | model.train() 763 | optimizer.zero_grad() 764 | 765 | end = time.time() 766 | last_idx = len(loader) - 1 767 
| num_updates = epoch * len(loader) 768 | for batch_idx, (input, target) in enumerate(loader): 769 | step = batch_idx // grad_accum_steps 770 | if step >= num_training_steps_per_epoch: 771 | continue 772 | # last_batch = batch_idx == last_idx 773 | last_batch = ((batch_idx + 1) // grad_accum_steps) == num_training_steps_per_epoch 774 | data_time_m.update(time.time() - end) 775 | if not args.prefetcher: 776 | input, target = input.cuda(), target.cuda() 777 | if mixup_fn is not None: 778 | input, target = mixup_fn(input, target) 779 | if args.channels_last: 780 | input = input.contiguous(memory_format=torch.channels_last) 781 | 782 | with amp_autocast(): 783 | output = model(input) 784 | loss = loss_fn(output, target) 785 | 786 | if not args.distributed: 787 | losses_m.update(loss.item(), input.size(0)) 788 | 789 | 790 | update_grad = (batch_idx + 1) % grad_accum_steps == 0 791 | loss_update = loss / grad_accum_steps 792 | if loss_scaler is not None: 793 | loss_scaler( 794 | loss_update, optimizer, 795 | clip_grad=args.clip_grad, clip_mode=args.clip_mode, 796 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), 797 | create_graph=second_order, update_grad=update_grad) 798 | else: 799 | loss_update.backward(create_graph=second_order) 800 | if update_grad: 801 | if args.clip_grad is not None: 802 | utils.dispatch_clip_grad( 803 | model_parameters(model, exclude_head='agc' in args.clip_mode), 804 | value=args.clip_grad, mode=args.clip_mode) 805 | optimizer.step() 806 | 807 | if update_grad: 808 | optimizer.zero_grad() 809 | if model_ema is not None: 810 | model_ema.update(model) 811 | 812 | torch.cuda.synchronize() 813 | num_updates += 1 814 | batch_time_m.update(time.time() - end) 815 | if last_batch or batch_idx % args.log_interval == 0: 816 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 817 | lr = sum(lrl) / len(lrl) 818 | 819 | if args.distributed: 820 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size) 821 | losses_m.update(reduced_loss.item(), input.size(0)) 822 | 823 | if args.local_rank == 0: 824 | _logger.info( 825 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 826 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' 827 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 828 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 829 | 'LR: {lr:.3e} ' 830 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 831 | epoch, 832 | batch_idx, len(loader), 833 | 100. 
* batch_idx / last_idx, 834 | loss=losses_m, 835 | batch_time=batch_time_m, 836 | rate=input.size(0) * args.world_size / batch_time_m.val, 837 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 838 | lr=lr, 839 | data_time=data_time_m)) 840 | 841 | if args.save_images and output_dir: 842 | torchvision.utils.save_image( 843 | input, 844 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 845 | padding=0, 846 | normalize=True) 847 | 848 | if saver is not None and args.recovery_interval and ( 849 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 850 | saver.save_recovery(epoch, batch_idx=batch_idx) 851 | 852 | if lr_scheduler is not None: 853 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 854 | 855 | end = time.time() 856 | # end for 857 | 858 | if hasattr(optimizer, 'sync_lookahead'): 859 | optimizer.sync_lookahead() 860 | 861 | return OrderedDict([('loss', losses_m.avg)]) 862 | 863 | 864 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): 865 | batch_time_m = utils.AverageMeter() 866 | losses_m = utils.AverageMeter() 867 | top1_m = utils.AverageMeter() 868 | top5_m = utils.AverageMeter() 869 | 870 | model.eval() 871 | 872 | end = time.time() 873 | last_idx = len(loader) - 1 874 | with torch.no_grad(): 875 | for batch_idx, (input, target) in enumerate(loader): 876 | last_batch = batch_idx == last_idx 877 | if not args.prefetcher: 878 | input = input.cuda() 879 | target = target.cuda() 880 | if args.channels_last: 881 | input = input.contiguous(memory_format=torch.channels_last) 882 | 883 | with amp_autocast(): 884 | output = model(input) 885 | if isinstance(output, (tuple, list)): 886 | output = output[0] 887 | 888 | # augmentation reduction 889 | reduce_factor = args.tta 890 | if reduce_factor > 1: 891 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 892 | target = target[0:target.size(0):reduce_factor] 893 | 894 | loss = loss_fn(output, target) 895 | acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) 896 | 897 | if args.distributed: 898 | reduced_loss = utils.reduce_tensor(loss.data, args.world_size) 899 | acc1 = utils.reduce_tensor(acc1, args.world_size) 900 | acc5 = utils.reduce_tensor(acc5, args.world_size) 901 | else: 902 | reduced_loss = loss.data 903 | 904 | torch.cuda.synchronize() 905 | 906 | losses_m.update(reduced_loss.item(), input.size(0)) 907 | top1_m.update(acc1.item(), output.size(0)) 908 | top5_m.update(acc5.item(), output.size(0)) 909 | 910 | batch_time_m.update(time.time() - end) 911 | end = time.time() 912 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 913 | log_name = 'Test' + log_suffix 914 | _logger.info( 915 | '{0}: [{1:>4d}/{2}] ' 916 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 917 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 918 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 919 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 920 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 921 | loss=losses_m, top1=top1_m, top5=top5_m)) 922 | 923 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 924 | 925 | return metrics 926 | 927 | 928 | if __name__ == '__main__': 929 | main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Modified from the timm and Swin repos. 
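# Editor's note: the commented sketch below is illustrative only and is not part of the original file.
# It shows how train.py is expected to drive the accumulation-aware scalers defined here: every
# micro-batch backpropagates loss / grad_accum_steps, while gradient clipping and optimizer.step()
# run only on accumulation boundaries (update_grad=True), so the effective batch size becomes
# batch_size * grad_accum_steps * world_size, as computed in train.py.
#
#     scaler = NativeScalerAccum()
#     optimizer.zero_grad()
#     for batch_idx, (input, target) in enumerate(loader):
#         with torch.cuda.amp.autocast():
#             loss = loss_fn(model(input), target) / grad_accum_steps
#         update_grad = (batch_idx + 1) % grad_accum_steps == 0  # step only every grad_accum_steps batches
#         scaler(loss, optimizer, parameters=model.parameters(), update_grad=update_grad)
#         if update_grad:
#             optimizer.zero_grad()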
2 | 3 | """ CUDA / AMP utils 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import torch 8 | 9 | try: 10 | from apex import amp 11 | has_apex = True 12 | except ImportError: 13 | amp = None 14 | has_apex = False 15 | 16 | from timm.utils.clip_grad import dispatch_clip_grad 17 | 18 | 19 | class ApexScalerAccum: 20 | state_dict_key = "amp" 21 | 22 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, 23 | update_grad=True): 24 | with amp.scale_loss(loss, optimizer) as scaled_loss: 25 | scaled_loss.backward(create_graph=create_graph) 26 | if update_grad: 27 | if clip_grad is not None: 28 | dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) 29 | optimizer.step() 30 | 31 | def state_dict(self): 32 | if 'state_dict' in amp.__dict__: 33 | return amp.state_dict() 34 | 35 | def load_state_dict(self, state_dict): 36 | if 'load_state_dict' in amp.__dict__: 37 | amp.load_state_dict(state_dict) 38 | 39 | 40 | class NativeScalerAccum: 41 | state_dict_key = "amp_scaler" 42 | 43 | def __init__(self): 44 | self._scaler = torch.cuda.amp.GradScaler() 45 | 46 | def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, 47 | update_grad=True): 48 | self._scaler.scale(loss).backward(create_graph=create_graph) 49 | if update_grad: 50 | if clip_grad is not None: 51 | assert parameters is not None 52 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 53 | dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) 54 | self._scaler.step(optimizer) 55 | self._scaler.update() 56 | 57 | def state_dict(self): 58 | return self._scaler.state_dict() 59 | 60 | def load_state_dict(self, state_dict): 61 | self._scaler.load_state_dict(state_dict) 62 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # This script is mostly copied from https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/validate.py 3 | """ ImageNet Validation Script 4 | 5 | This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained 6 | models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes 7 | canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit. 
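For illustration only (the dataset path and checkpoint are placeholders, and the model name is just an
example of a model registered by this repo's `models` package), a typical single-checkpoint run looks like:

    python validate.py /path/to/imagenet --model inceptionnext_tiny --checkpoint /path/to/checkpoint.pth.tar --amp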
8 | 9 | Hacked together by Ross Wightman (https://github.com/rwightman) 10 | """ 11 | import argparse 12 | import os 13 | import csv 14 | import glob 15 | import json 16 | import time 17 | import logging 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.parallel 21 | from collections import OrderedDict 22 | from contextlib import suppress 23 | 24 | from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm 25 | from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet 26 | from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ 27 | decay_batch_step, check_batch_size_retry 28 | 29 | import models 30 | 31 | has_apex = False 32 | try: 33 | from apex import amp 34 | has_apex = True 35 | except ImportError: 36 | pass 37 | 38 | has_native_amp = False 39 | try: 40 | if getattr(torch.cuda.amp, 'autocast') is not None: 41 | has_native_amp = True 42 | except AttributeError: 43 | pass 44 | 45 | try: 46 | from functorch.compile import memory_efficient_fusion 47 | has_functorch = True 48 | except ImportError as e: 49 | has_functorch = False 50 | 51 | torch.backends.cudnn.benchmark = True 52 | _logger = logging.getLogger('validate') 53 | 54 | 55 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 56 | parser.add_argument('data', metavar='DIR', 57 | help='path to dataset') 58 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 59 | help='dataset type (default: ImageFolder/ImageTar if empty)') 60 | parser.add_argument('--split', metavar='NAME', default='validation', 61 | help='dataset split (default: validation)') 62 | parser.add_argument('--dataset-download', action='store_true', default=False, 63 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 64 | parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', 65 | help='model architecture (default: dpn92)') 66 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 67 | help='number of data loading workers (default: 4)') 68 | parser.add_argument('-b', '--batch-size', default=256, type=int, 69 | metavar='N', help='mini-batch size (default: 256)') 70 | parser.add_argument('--img-size', default=None, type=int, 71 | metavar='N', help='Input image dimension, uses model default if empty') 72 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 73 | metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') 74 | parser.add_argument('--use-train-size', action='store_true', default=False, 75 | help='force use of train input size, even when test size is specified in pretrained cfg') 76 | parser.add_argument('--crop-pct', default=None, type=float, 77 | metavar='N', help='Input image center crop pct') 78 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 79 | help='Override mean pixel value of dataset') 80 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 81 | help='Override std deviation of dataset') 82 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 83 | help='Image resize interpolation type (overrides model)') 84 | parser.add_argument('--num-classes', type=int, default=None, 85 | help='Number of classes in dataset') 86 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 87 | help='path to class to idx mapping file (default: "")') 88 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 89 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 90 | parser.add_argument('--log-freq', default=10, type=int, 91 | metavar='N', help='batch logging frequency (default: 10)') 92 | parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', 93 | help='path to latest checkpoint (default: none)') 94 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 95 | help='use pre-trained model') 96 | parser.add_argument('--num-gpu', type=int, default=1, 97 | help='Number of GPUS to use') 98 | parser.add_argument('--test-pool', dest='test_pool', action='store_true', 99 | help='enable test time pool') 100 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 101 | help='disable fast prefetcher') 102 | parser.add_argument('--pin-mem', action='store_true', default=False, 103 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 104 | parser.add_argument('--channels-last', action='store_true', default=False, 105 | help='Use channels_last memory layout') 106 | parser.add_argument('--amp', action='store_true', default=False, 107 | help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') 108 | parser.add_argument('--apex-amp', action='store_true', default=False, 109 | help='Use NVIDIA Apex AMP mixed precision') 110 | parser.add_argument('--native-amp', action='store_true', default=False, 111 | help='Use Native Torch AMP mixed precision') 112 | parser.add_argument('--tf-preprocessing', action='store_true', default=False, 113 | help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)') 114 | parser.add_argument('--use-ema', dest='use_ema', action='store_true', 115 | help='use ema version of weights if present') 116 | scripting_group = parser.add_mutually_exclusive_group() 117 | scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', 118 | help='torch.jit.script the full model') 119 | scripting_group.add_argument('--aot-autograd', default=False, action='store_true', 120 | help="Enable AOT Autograd support. (It's recommended to use this option together with `--fuser nvfuser`)") 121 | parser.add_argument('--fuser', default='', type=str, 122 | help="Select jit fuser. 
One of ('', 'te', 'old', 'nvfuser')") 123 | parser.add_argument('--fast-norm', default=False, action='store_true', 124 | help='enable experimental fast-norm') 125 | parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', 126 | help='Output csv file for validation results (summary)') 127 | parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', 128 | help='Real labels JSON file for imagenet evaluation') 129 | parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', 130 | help='Valid label indices txt file for validation of partial label space') 131 | parser.add_argument('--retry', default=False, action='store_true', 132 | help='Enable batch size decay & retry for single model validation') 133 | 134 | 135 | def validate(args): 136 | # might as well try to validate something 137 | args.pretrained = args.pretrained or not args.checkpoint 138 | args.prefetcher = not args.no_prefetcher 139 | amp_autocast = suppress # do nothing 140 | if args.amp: 141 | if has_native_amp: 142 | args.native_amp = True 143 | elif has_apex: 144 | args.apex_amp = True 145 | else: 146 | _logger.warning("Neither APEX nor Native Torch AMP is available.") 147 | assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." 148 | if args.native_amp: 149 | amp_autocast = torch.cuda.amp.autocast 150 | _logger.info('Validating in mixed precision with native PyTorch AMP.') 151 | elif args.apex_amp: 152 | _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') 153 | else: 154 | _logger.info('Validating in float32. AMP not enabled.') 155 | 156 | if args.fuser: 157 | set_jit_fuser(args.fuser) 158 | if args.fast_norm: 159 | set_fast_norm() 160 | 161 | # create model 162 | model = create_model( 163 | args.model, 164 | pretrained=args.pretrained, 165 | num_classes=args.num_classes, 166 | in_chans=3, 167 | global_pool=args.gp, 168 | scriptable=args.torchscript) 169 | if args.num_classes is None: 170 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
171 | args.num_classes = model.num_classes 172 | 173 | if args.checkpoint: 174 | load_checkpoint(model, args.checkpoint, args.use_ema) 175 | 176 | param_count = sum([m.numel() for m in model.parameters()]) 177 | _logger.info('Model %s created, param count: %d' % (args.model, param_count)) 178 | 179 | data_config = resolve_data_config( 180 | vars(args), 181 | model=model, 182 | use_test_size=not args.use_train_size, 183 | verbose=True 184 | ) 185 | test_time_pool = False 186 | if args.test_pool: 187 | model, test_time_pool = apply_test_time_pool(model, data_config) 188 | 189 | if args.torchscript: 190 | torch.jit.optimized_execution(True) 191 | model = torch.jit.script(model) 192 | if args.aot_autograd: 193 | assert has_functorch, "functorch is needed for --aot-autograd" 194 | model = memory_efficient_fusion(model) 195 | 196 | model = model.cuda() 197 | if args.apex_amp: 198 | model = amp.initialize(model, opt_level='O1') 199 | 200 | if args.channels_last: 201 | model = model.to(memory_format=torch.channels_last) 202 | 203 | if args.num_gpu > 1: 204 | model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) 205 | 206 | criterion = nn.CrossEntropyLoss().cuda() 207 | 208 | dataset = create_dataset( 209 | root=args.data, name=args.dataset, split=args.split, 210 | download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map) 211 | 212 | if args.valid_labels: 213 | with open(args.valid_labels, 'r') as f: 214 | valid_labels = {int(line.rstrip()) for line in f} 215 | valid_labels = [i in valid_labels for i in range(args.num_classes)] 216 | else: 217 | valid_labels = None 218 | 219 | if args.real_labels: 220 | real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) 221 | else: 222 | real_labels = None 223 | 224 | crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] 225 | loader = create_loader( 226 | dataset, 227 | input_size=data_config['input_size'], 228 | batch_size=args.batch_size, 229 | use_prefetcher=args.prefetcher, 230 | interpolation=data_config['interpolation'], 231 | mean=data_config['mean'], 232 | std=data_config['std'], 233 | num_workers=args.workers, 234 | crop_pct=crop_pct, 235 | pin_memory=args.pin_mem, 236 | tf_preprocessing=args.tf_preprocessing) 237 | 238 | batch_time = AverageMeter() 239 | losses = AverageMeter() 240 | top1 = AverageMeter() 241 | top5 = AverageMeter() 242 | 243 | model.eval() 244 | with torch.no_grad(): 245 | # warmup, reduce variability of first batch time, especially for comparing torchscript vs non 246 | input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() 247 | if args.channels_last: 248 | input = input.contiguous(memory_format=torch.channels_last) 249 | with amp_autocast(): 250 | model(input) 251 | 252 | end = time.time() 253 | for batch_idx, (input, target) in enumerate(loader): 254 | if args.no_prefetcher: 255 | target = target.cuda() 256 | input = input.cuda() 257 | if args.channels_last: 258 | input = input.contiguous(memory_format=torch.channels_last) 259 | 260 | # compute output 261 | with amp_autocast(): 262 | output = model(input) 263 | 264 | if valid_labels is not None: 265 | output = output[:, valid_labels] 266 | loss = criterion(output, target) 267 | 268 | if real_labels is not None: 269 | real_labels.add_result(output) 270 | 271 | # measure accuracy and record loss 272 | acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) 273 | losses.update(loss.item(), input.size(0)) 274 | top1.update(acc1.item(), input.size(0)) 
275 | top5.update(acc5.item(), input.size(0)) 276 | 277 | # measure elapsed time 278 | batch_time.update(time.time() - end) 279 | end = time.time() 280 | 281 | if batch_idx % args.log_freq == 0: 282 | _logger.info( 283 | 'Test: [{0:>4d}/{1}] ' 284 | 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 285 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 286 | 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 287 | 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( 288 | batch_idx, len(loader), batch_time=batch_time, 289 | rate_avg=input.size(0) / batch_time.avg, 290 | loss=losses, top1=top1, top5=top5)) 291 | 292 | if real_labels is not None: 293 | # real labels mode replaces topk values at the end 294 | top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) 295 | else: 296 | top1a, top5a = top1.avg, top5.avg 297 | results = OrderedDict( 298 | model=args.model, 299 | top1=round(top1a, 4), top1_err=round(100 - top1a, 4), 300 | top5=round(top5a, 4), top5_err=round(100 - top5a, 4), 301 | param_count=round(param_count / 1e6, 2), 302 | img_size=data_config['input_size'][-1], 303 | crop_pct=crop_pct, 304 | interpolation=data_config['interpolation']) 305 | 306 | _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( 307 | results['top1'], results['top1_err'], results['top5'], results['top5_err'])) 308 | 309 | return results 310 | 311 | 312 | def _try_run(args, initial_batch_size): 313 | batch_size = initial_batch_size 314 | results = OrderedDict() 315 | error_str = 'Unknown' 316 | while batch_size: 317 | args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case 318 | try: 319 | torch.cuda.empty_cache() 320 | results = validate(args) 321 | return results 322 | except RuntimeError as e: 323 | error_str = str(e) 324 | _logger.error(f'"{error_str}" while running validation.') 325 | if not check_batch_size_retry(error_str): 326 | break 327 | batch_size = decay_batch_step(batch_size) 328 | _logger.warning(f'Reducing batch size to {batch_size} for retry.') 329 | results['error'] = error_str 330 | _logger.error(f'{args.model} failed to validate ({error_str}).') 331 | return results 332 | 333 | 334 | def main(): 335 | setup_default_logging() 336 | args = parser.parse_args() 337 | model_cfgs = [] 338 | model_names = [] 339 | if os.path.isdir(args.checkpoint): 340 | # validate all checkpoints in a path with same model 341 | checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') 342 | checkpoints += glob.glob(args.checkpoint + '/*.pth') 343 | model_names = list_models(args.model) 344 | model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] 345 | else: 346 | if args.model == 'all': 347 | # validate all models in a list of names with pretrained checkpoints 348 | args.pretrained = True 349 | model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino']) 350 | model_cfgs = [(n, '') for n in model_names] 351 | elif not is_model(args.model): 352 | # model name doesn't exist, try as wildcard filter 353 | model_names = list_models(args.model) 354 | model_cfgs = [(n, '') for n in model_names] 355 | 356 | if not model_cfgs and os.path.isfile(args.model): 357 | with open(args.model) as f: 358 | model_names = [line.rstrip() for line in f] 359 | model_cfgs = [(n, None) for n in model_names if n] 360 | 361 | if len(model_cfgs): 362 | results_file = args.results_file or './results-all.csv' 363 | _logger.info('Running bulk validation on these pretrained models: {}'.format(', 
'.join(model_names))) 364 | results = [] 365 | try: 366 | initial_batch_size = args.batch_size 367 | for m, c in model_cfgs: 368 | args.model = m 369 | args.checkpoint = c 370 | r = _try_run(args, initial_batch_size) 371 | if 'error' in r: 372 | continue 373 | if args.checkpoint: 374 | r['checkpoint'] = args.checkpoint 375 | results.append(r) 376 | except KeyboardInterrupt as e: 377 | pass 378 | results = sorted(results, key=lambda x: x['top1'], reverse=True) 379 | if len(results): 380 | write_results(results_file, results) 381 | else: 382 | if args.retry: 383 | results = _try_run(args, args.batch_size) 384 | else: 385 | results = validate(args) 386 | # output results in JSON to stdout w/ delimiter for runner script 387 | print(f'--result\n{json.dumps(results, indent=4)}') 388 | 389 | 390 | def write_results(results_file, results): 391 | with open(results_file, mode='w') as cf: 392 | dw = csv.DictWriter(cf, fieldnames=results[0].keys()) 393 | dw.writeheader() 394 | for r in results: 395 | dw.writerow(r) 396 | cf.flush() 397 | 398 | 399 | if __name__ == '__main__': 400 | main() --------------------------------------------------------------------------------
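Editor's note on the --valid-labels option in validate.py above: the txt file lists the class indices that remain valid for a reduced label space; validate() turns it into a boolean mask over all num_classes outputs and slices the logits with it before computing loss and accuracy. A minimal, self-contained sketch of that masking (the numbers are made up for illustration):

    import torch

    num_classes = 5
    valid = {0, 2, 4}                                  # indices as read from the txt file
    mask = [i in valid for i in range(num_classes)]    # boolean mask, built as in validate()
    logits = torch.randn(2, num_classes)               # stand-in for model output on 2 images
    masked = logits[:, mask]                           # keep only the columns of valid labels
    print(masked.shape)                                # torch.Size([2, 3])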