├── .gitignore ├── .idea ├── .gitignore ├── EFM-Net-Pytorch.iml ├── inspectionProfiles │ ├── Project_Default.xml │ └── profiles_settings.xml ├── misc.xml ├── modules.xml └── vcs.xml ├── LICENSE ├── README.md ├── config ├── __init__.py └── configs.py ├── datasets_builder.py ├── main.py ├── models ├── __init__.py ├── cores.py └── miner.py ├── processor ├── __init__.py ├── engine.py ├── optim_factory.py └── run_with_submitit.py └── util ├── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/EFM-Net-Pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EFM-Net: An Essential Feature Mining Network for Target Fine-Grained Classification in Optical Remote Sensing Images 2 | 3 | By Yonghao Yi (*Student Member, IEEE*) , Yanan You (*Member, IEEE*) , Chao Li, and Wenli Zhou. 4 | 5 | Special thanks to [Haopeng Zhang](https://orcid.org/0000-0003-1981-8307) for providing the dataset source file for [FGSCR-42](https://www.mdpi.com/2072-4292/13/4/747). 6 | 7 | ## Introduction 8 | 9 | This code provides an reference version for the implementation of the *IEEE-TGRS* paper “EFM-Net: An Essential Feature Mining Network for Target Fine-Grained Classification in Optical Remote Sensing Images”. The projects are still under construction. 10 | 11 | 12 | 13 | ## How to run 14 | 15 | ### Prepare the datasets 16 | 17 | **If you want to use the training and testing sets divided in this paper, please click on the following link to download.** 18 | 19 | Download the public benchmark datasets and unzip them in your own path. 20 | 21 | - [FGSC-23](https://drive.google.com/file/d/1DFir2wvzVnMYqleqgvxLoN760hYZe3TW/view?usp=sharing) 22 | 23 | - [FGSCR-42](https://drive.google.com/file/d/1o8QzGA3wEhobGFZ-Hbey0GCgCNIdnEmf/view?usp=sharing) 24 | 25 | - [Aircraft-16](https://drive.google.com/file/d/1n0aoB0FJIvrA5xpC8AfeXKZKqCFu2Gca/view?usp=sharing) 26 | 27 | Move the directory with `train` and `test` sub-directories to the `./datasets/` directory and rename it with database name, such as 28 | 29 | ``` 30 | mv FGSC23 /path/to/EFM-Net-Pytorch/datasets/FGSC-23 31 | ``` 32 | 33 | Make sure your datasets are correctly split into the training set and test set. The training set should be placed in the directory named “train” while test set named “test”. 34 | 35 | **The source files download links of those public datasets can be found in https://github.com/JACYI/Dataset-for-Remote-Sensing.** 36 | 37 | 38 | 39 | ### Download the pre-trained model 40 | 41 | The feature extraction model is based on ConvNeXt-Base. Please download the pre-trained parameters file: 42 | 43 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 44 | 45 | and move the pre-trained model to `./pretrained` directory. 46 | 47 | 48 | 49 | ### Build the running environment 50 | 51 | 1. **Create the virtual environment:** 52 | 53 | ```shell 54 | conda create -n efmnet python=3.8 55 | source activate efmnet 56 | ``` 57 | 58 | 2. **requirements** 59 | 60 | - python=3.8 61 | - pytorch>=1.8.1, torchvision>=0.9.1 62 | - timm=0.3.2 63 | - opencv-python 64 | - scikit-learn 65 | - tensorboardX 66 | - termcolor 67 | 68 | 69 | 70 | ### For Training 71 | 72 | 1. **Run on a single GPU** 73 | 74 | ```shell 75 | python main.py --exp_name test --attentions 16 --epochs 100 --dataset FGSC-23 --output_dir logs/test --novel_loss 0.5 76 | ``` 77 | 78 | 2. **Run on multiple GPUs** (2 GPUs for example) 79 | 80 | ```shell 81 | python -m torch.distributed.launch --master_port 12345 \ 82 | --nproc_per_node=2 main.py \ 83 | --exp_name test --attentions 16 \ 84 | --epochs 120 --dataset FGSC-23 \ 85 | --output_dir logs/test --novel_loss 0.5 86 | ``` 87 | 88 | 89 | 90 | ## Citation 91 | 92 | More details need to be added. 
93 | 94 | ``` 95 | @ARTICLE{10097708, 96 | author={Yi, Yonghao and You, Yanan and Li, Chao and Zhou, Wenli}, 97 | journal={IEEE Transactions on Geoscience and Remote Sensing}, 98 | title={EFM-Net: An Essential Feature Mining Network for Target Fine-Grained Classification in Optical Remote Sensing Images}, 99 | year={2023}, 100 | volume={}, 101 | number={}, 102 | pages={1-1}, 103 | doi={10.1109/TGRS.2023.3265669} 104 | } 105 | ``` 106 | 107 | 108 | 109 | ## Corresponding author information 110 | 111 | Yanan You, Beijing, China, 112 | 113 | Phone: (+86) 15201148169 114 | 115 | Email: youyanan@bupt.edu.cn 116 | 117 | 118 | 119 | ## To do 120 | 121 | 1. Support more feature extraction models; 122 | 2. Provide more interfaces for modifying parameters. 123 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/config/__init__.py -------------------------------------------------------------------------------- /config/configs.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 3 | 4 | 5 | import argparse 6 | import time 7 | 8 | 9 | def str2bool(v): 10 | """ 11 | Converts string to bool type; enables command line 12 | arguments in the format of '--arg1 true --arg2 false' 13 | """ 14 | if isinstance(v, bool): 15 | return v 16 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 17 | return True 18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 19 | return False 20 | else: 21 | raise argparse.ArgumentTypeError('Boolean value expected.') 22 | 23 | 24 | def get_args_parser(): 25 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False) 26 | # Main config 27 | parser.add_argument('--exp_name', default='TGRS-MODEL', type=str) 28 | parser.add_argument('--batch_size', default=16, type=int, 29 | help='Per GPU batch size') 30 | parser.add_argument('--epochs', default=300, type=int, help='Number of training iterations.') 31 | parser.add_argument('--attentions', default=16, type=int, 32 | help='The number of channels in the attention module (kappa in the text)') 33 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR', 34 | help='learning rate (default: 1e-3)') 35 | parser.add_argument('--novel_loss', type=float, default=0., 36 | help='SCE loss in the paper. 0. 
equals to cross-entropy loss') 37 | parser.add_argument('--update_freq', default=30, type=int, help='Gradient accumulation steps') 38 | 39 | # model parameters 40 | parser.add_argument('--model', default='convnext_base', type=str, metavar='MODEL', 41 | help='Name of feature extraction model to train') 42 | parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT', 43 | help='Drop path rate (default: 0.0)') 44 | parser.add_argument('--input_size', default=224, type=int, 45 | help='image input size') 46 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float, 47 | help="Layer scale initial values") 48 | parser.add_argument('--use_mha', default=True, type=str2bool) 49 | parser.add_argument('--use_ref', default=True, type=str2bool) 50 | 51 | # EMA related parameters 52 | parser.add_argument('--model_ema', type=str2bool, default=False) 53 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='') 54 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='') 55 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.') 56 | 57 | # optimization parameters 58 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 59 | help='Optimizer (default: "adamw"') 60 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON', 61 | help='Optimizer Epsilon (default: 1e-8)') 62 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA', 63 | help='Optimizer Betas (default: None, use opt default)') 64 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', 65 | help='Clip gradient norm (default: None, no clipping)') 66 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 67 | help='SGD momentum (default: 0.9)') 68 | parser.add_argument('--weight_decay', type=float, default=1e-8, 69 | help='weight decay (default: 0.05)') 70 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the 71 | weight decay. We use a cosine schedule for WD and using a larger decay by 72 | the end of training improves performance for ViTs.""") 73 | 74 | # learning parameters 75 | parser.add_argument('--layer_decay', type=float, default=0.8) 76 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR', 77 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') 78 | parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N', 79 | help='epochs to warmup LR, if scheduler supports') 80 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N', 81 | help='num of steps to warmup LR, will overload warmup_epochs if set > 0') 82 | 83 | # augmentation parameters 84 | # parser.add_argument('--theta_detail', type=tuple, default=(0.4, 0.6), 85 | # help='Proportion of regional boundaries of detail images') 86 | # parser.add_argument('--theta_mask', type=tuple, default=(0.2, 0.5), 87 | # help='Proportion of regional boundaries of mask images') 88 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', 89 | help='Color jitter factor (default: 0.4)') 90 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 91 | help='Use AutoAugment policy. "v0" or "original". 
" + "(default: rand-m9-mstd0.5-inc1)'), 92 | parser.add_argument('--smoothing', type=float, default=0, 93 | help='Label smoothing (default: 0.1)') 94 | parser.add_argument('--train_interpolation', type=str, default='bicubic', 95 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 96 | 97 | # evaluation parameters 98 | parser.add_argument('--crop_pct', type=float, default=None) 99 | 100 | # random Erase params 101 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 102 | help='Random erase prob (default: 0.25)') 103 | parser.add_argument('--remode', type=str, default='pixel', 104 | help='Random erase mode (default: "pixel")') 105 | parser.add_argument('--recount', type=int, default=1, 106 | help='Random erase count (default: 1)') 107 | parser.add_argument('--resplit', type=str2bool, default=False, 108 | help='Do not random erase first (clean) augmentation split') 109 | 110 | # mixup params 111 | parser.add_argument('--mixup', type=float, default=0., 112 | help='mixup alpha, mixup enabled if > 0.') 113 | parser.add_argument('--cutmix', type=float, default=0., 114 | help='cutmix alpha, cutmix enabled if > 0.') 115 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None, 116 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 117 | parser.add_argument('--mixup_prob', type=float, default=1.0, 118 | help='Probability of performing mixup or cutmix when either/both is enabled') 119 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5, 120 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 121 | parser.add_argument('--mixup_mode', type=str, default='batch', 122 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 123 | 124 | # finetuning params 125 | parser.add_argument('--finetune', default='pretrained/convnext_base_22k_224.pth', 126 | help='finetune from checkpoint') 127 | parser.add_argument('--head_init_scale', default=0.001, type=float, 128 | help='classifier head initial scale, typically adjusted in fine-tuning') 129 | parser.add_argument('--model_key', default='model|module', type=str, 130 | help='which key to load from saved state dict, usually model or model_ema') 131 | parser.add_argument('--model_prefix', default='', type=str) 132 | 133 | # dataset parameters 134 | parser.add_argument('--data_path', default='./datasets/', type=str, 135 | help='dataset path') 136 | parser.add_argument('--dataset', default='FGSC-23') 137 | parser.add_argument('--eval_data_path', default=None, type=str, 138 | help='dataset path for evaluation') 139 | parser.add_argument('--nb_classes', default=23, type=int, 140 | help='number of the classification types') 141 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True) 142 | parser.add_argument('--data_set', default='image_folder', choices=['CIFAR', 'IMNET', 'image_folder'], 143 | type=str, help='ImageNet dataset path') 144 | parser.add_argument('--output_dir', default='result/', 145 | help='path where to save, empty for no saving') 146 | parser.add_argument('--log_dir', default=None, 147 | help='path where to tensorboard log') 148 | parser.add_argument('--device', default='cuda', 149 | help='device to use for training / testing') 150 | parser.add_argument('--seed', default=0, type=int) 151 | 152 | # training parameters 153 | parser.add_argument('--resume', default='', 154 | help='resume from checkpoint') 155 | parser.add_argument('--auto_resume', 
type=str2bool, default=True) 156 | parser.add_argument('--save_ckpt', type=str2bool, default=True) 157 | parser.add_argument('--save_ckpt_freq', default=10, type=int) 158 | parser.add_argument('--save_ckpt_num', default=3, type=int) 159 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 160 | help='start epoch') 161 | parser.add_argument('--eval', type=str2bool, default=False, 162 | help='Perform evaluation only') 163 | parser.add_argument('--dist_eval', type=str2bool, default=True, 164 | help='Enabling distributed evaluation') 165 | parser.add_argument('--disable_eval', type=str2bool, default=False, 166 | help='Disabling evaluation during training') 167 | parser.add_argument('--num_workers', default=10, type=int) 168 | parser.add_argument('--pin_mem', type=str2bool, default=True, 169 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 170 | 171 | # distributed training parameters 172 | parser.add_argument('--world_size', default=1, type=int, 173 | help='number of distributed processes') 174 | parser.add_argument('--local_rank', default=-1, type=int) 175 | parser.add_argument('--dist_on_itp', type=str2bool, default=False) 176 | parser.add_argument('--dist_url', default='env://', 177 | help='url used to set up distributed training') 178 | 179 | parser.add_argument('--use_amp', type=str2bool, default=False, 180 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not") 181 | 182 | # weights and Biases arguments 183 | parser.add_argument('--enable_wandb', type=str2bool, default=False, 184 | help="enable logging to Weights and Biases") 185 | parser.add_argument('--project', default='A-Net', type=str, 186 | help="The name of the W&B project where you're sending the new run.") 187 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False, 188 | help="Save model checkpoints as W&B Artifacts.") 189 | parser.add_argument('--tag', default=time.strftime("%m-%d-%H-%M-%S", time.localtime()), help='tag of experiment') 190 | return parser 191 | 192 | if __name__ == "__main__": 193 | parser = argparse.ArgumentParser('Test for parser.', parents=[get_args_parser()]) 194 | args = parser.parse_args() -------------------------------------------------------------------------------- /datasets_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 
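# Usage sketch: build_dataset below is the entry point used by main.py, e.g.
#     dataset_train, nb_classes = build_dataset(is_train=True, args=args)
#     dataset_val, _ = build_dataset(is_train=False, args=args)
# With the default --data_set image_folder, images are read from
# <args.data_path>/<args.dataset>/train (or .../test), laid out as one
# sub-directory per class, i.e. the standard torchvision ImageFolder layout.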
3 | 4 | 5 | import os 6 | from torchvision import datasets, transforms 7 | from timm.data.constants import \ 8 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 9 | from timm.data import create_transform 10 | 11 | def build_dataset(is_train, args): 12 | transform = build_transform(is_train, args) 13 | print('{}.trans:{}'.format(is_train, transform)) 14 | print("Transform = ") 15 | if isinstance(transform, tuple): 16 | for trans in transform: 17 | print(" - - - - - - - - - - ") 18 | for t in trans.transforms: 19 | print(t) 20 | else: 21 | for t in transform.transforms: 22 | print(t) 23 | print("---------------------------") 24 | 25 | if args.data_set == 'CIFAR': 26 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True) 27 | nb_classes = 100 28 | elif args.data_set == 'IMNET': 29 | print("reading from datapath", args.data_path) 30 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 31 | dataset = datasets.ImageFolder(root, transform=transform) 32 | nb_classes = 1000 33 | elif args.data_set == "image_folder": 34 | root = os.path.join(args.data_path, args.dataset, 'train' if is_train else 'test') 35 | 36 | dataset = datasets.ImageFolder(root, transform=transform) 37 | args.nb_classes = len(dataset.class_to_idx) 38 | nb_classes = args.nb_classes 39 | else: 40 | raise NotImplementedError() 41 | print("Number of the class = %d" % nb_classes) 42 | 43 | return dataset, nb_classes 44 | 45 | 46 | def build_transform(is_train, args): 47 | resize_im = args.input_size > 32 48 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 49 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 50 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 51 | 52 | if is_train: 53 | # this should always dispatch to transforms_imagenet_train 54 | transform = create_transform( 55 | input_size=args.input_size, 56 | is_training=True, 57 | color_jitter=args.color_jitter, 58 | auto_augment=args.aa, 59 | interpolation=args.train_interpolation, 60 | re_prob=args.reprob, 61 | re_mode=args.remode, 62 | re_count=args.recount, 63 | mean=mean, 64 | std=std, 65 | ) 66 | if not resize_im: 67 | transform.transforms[0] = transforms.RandomCrop( 68 | args.input_size, padding=4) 69 | return transform 70 | 71 | t = [] 72 | if resize_im: 73 | # warping (no cropping) when evaluated at 384 or larger 74 | if args.input_size >= 384: 75 | t.append( 76 | transforms.Resize((args.input_size, args.input_size), 77 | interpolation=transforms.InterpolationMode.BICUBIC), 78 | ) 79 | print(f"Warping {args.input_size} size input images...") 80 | else: 81 | if args.crop_pct is None: 82 | args.crop_pct = 224 / 256 83 | size = int(args.input_size / args.crop_pct) 84 | t.append( 85 | # to maintain same ratio w.r.t. 224 images 86 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC), 87 | ) 88 | t.append(transforms.CenterCrop(args.input_size)) 89 | 90 | t.append(transforms.ToTensor()) 91 | t.append(transforms.Normalize(mean, std)) 92 | return transforms.Compose(t) 93 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 
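# Overview of the training pipeline implemented in main() below:
#   1. initialise distributed mode, logger and random seeds;
#   2. build the train/test datasets and (distributed) samplers via build_dataset;
#   3. create the EFM-Net model (convnext_base from models.miner) and load the
#      ConvNeXt-Base checkpoint given by --finetune;
#   4. build the optimizer with layer-wise lr decay, cosine lr / weight-decay
#      schedules and the classification criterion;
#   5. per epoch: train_one_epoch, evaluate on the test split, and keep the
#      checkpoint with the best top-1 accuracy.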
3 | 4 | 5 | import argparse 6 | import datetime 7 | import numpy as np 8 | import time 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import json 12 | import os 13 | from pathlib import Path 14 | from timm.data.mixup import Mixup 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.utils import ModelEma 17 | from config.configs import get_args_parser 18 | from models.miner import convnext_base 19 | from datasets_builder import build_dataset 20 | from processor.engine import train_one_epoch, evaluate 21 | from processor.optim_factory import create_optimizer, LayerDecayValueAssigner 22 | from util.utils import NativeScalerWithGradNormCount as NativeScaler 23 | from util import utils 24 | from util.utils import create_logger, SoftCrossEntropyLoss 25 | import torch.distributed as dist 26 | import warnings 27 | warnings.filterwarnings('ignore') 28 | 29 | 30 | def main(args): 31 | # ---------------------- prepare running -------------------------------- 32 | # distributed mode 33 | utils.init_distributed_mode(args) 34 | device = torch.device(args.device) 35 | num_tasks = utils.get_world_size() 36 | global_rank = utils.get_rank() 37 | # logging file 38 | if args.eval is False: 39 | logger = create_logger(output_dir=args.output_dir, dist_rank=dist.get_rank()) 40 | else: 41 | logger = None 42 | # fix the seed for reproducibility 43 | seed = args.seed + utils.get_rank() 44 | torch.manual_seed(seed) 45 | np.random.seed(seed) 46 | cudnn.benchmark = True 47 | 48 | # ------------------------- build dataset ------------------------------ 49 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 50 | if args.disable_eval: 51 | args.dist_eval = False 52 | dataset_val = None 53 | else: 54 | dataset_val, _ = build_dataset(is_train=False, args=args) 55 | 56 | 57 | sampler_train = torch.utils.data.DistributedSampler( 58 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed, 59 | ) 60 | # logger.info("Sampler_train = %s" % str(sampler_train)) 61 | if args.dist_eval: 62 | if len(dataset_val) % num_tasks != 0: 63 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
' 64 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 65 | 'equal num of samples per-process.') 66 | sampler_val = torch.utils.data.DistributedSampler( 67 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 68 | else: 69 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 70 | 71 | if global_rank == 0 and args.log_dir is not None: 72 | os.makedirs(args.log_dir, exist_ok=True) 73 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir) 74 | else: 75 | log_writer = None 76 | 77 | if global_rank == 0 and args.enable_wandb: 78 | wandb_logger = utils.WandbLogger(args) 79 | else: 80 | wandb_logger = None 81 | 82 | data_loader_train = torch.utils.data.DataLoader( 83 | dataset_train, sampler=sampler_train, 84 | batch_size=args.batch_size, 85 | num_workers=args.num_workers, 86 | pin_memory=args.pin_mem, 87 | drop_last=True, 88 | ) 89 | 90 | if dataset_val is not None: 91 | data_loader_val = torch.utils.data.DataLoader( 92 | dataset_val, sampler=sampler_val, 93 | batch_size=int(1.5 * args.batch_size), 94 | num_workers=args.num_workers, 95 | pin_memory=args.pin_mem, 96 | drop_last=False 97 | ) 98 | else: 99 | data_loader_val = None 100 | 101 | # ------------------------- mixup setting ------------------------------ 102 | mixup_fn = None 103 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 104 | if mixup_active: 105 | logger.info("Mixup is activated!") 106 | mixup_fn = Mixup( 107 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 108 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 109 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 110 | 111 | # ---------------------- initialize the model ------------------------------ 112 | if args.model == 'convnext_base': 113 | model = convnext_base( 114 | pretrained=False, 115 | M=args.attentions, 116 | num_classes=args.nb_classes, 117 | drop_path_rate=args.drop_path, 118 | layer_scale_init_value=args.layer_scale_init_value, 119 | head_init_scale=args.head_init_scale, 120 | use_mha=args.use_mha, 121 | use_ref=args.use_ref, 122 | ) 123 | checkpoint = torch.load(args.finetune, map_location='cpu') 124 | checkpoint_model = None 125 | for model_key in args.model_key.split('|'): 126 | if model_key in checkpoint: 127 | checkpoint_model = checkpoint[model_key] 128 | logger.info("Load state_dict by model_key = %s" % model_key) 129 | break 130 | if checkpoint_model is None: 131 | checkpoint_model = checkpoint 132 | state_dict = model.state_dict() 133 | for k in ['head.weight', 'head.bias']: 134 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 135 | del checkpoint_model[k] 136 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix) 137 | else: 138 | ValueError("Unsupported model: %s" % args.model) 139 | 140 | model.to(device) 141 | model_ema = None 142 | if args.model_ema: 143 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 144 | model_ema = ModelEma( 145 | model, 146 | decay=args.model_ema_decay, 147 | device='cpu' if args.model_ema_force_cpu else '', 148 | resume='') 149 | logger.info("Using EMA with decay = %.8f" % args.model_ema_decay) 150 | 151 | model_without_ddp = model 152 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 153 | 154 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size() 155 | 
num_training_steps_per_epoch = len(dataset_train) // total_batch_size 156 | 157 | if args.layer_decay < 1.0 or args.layer_decay > 1.0: 158 | num_layers = 12 # convnext layers divided into 12 parts, each with a different decayed lr value. 159 | if args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge']: 160 | assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))) 161 | 162 | # model distributed 163 | if args.distributed: 164 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False, find_unused_parameters=True) 165 | model_without_ddp = model.module 166 | 167 | # -------------------- initialize the optimizer ------------------------------ 168 | optimizer = create_optimizer( 169 | args, model_without_ddp, skip_list=None, 170 | get_num_layer=assigner.get_layer_id if assigner is not None else None, 171 | get_layer_scale=assigner.get_scale if assigner is not None else None) 172 | 173 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used 174 | # schedule 175 | logger.info("Use Cosine LR scheduler") 176 | lr_schedule_values = utils.cosine_scheduler( 177 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch, 178 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps, 179 | ) 180 | # weight decay 181 | if args.weight_decay_end is None: 182 | args.weight_decay_end = args.weight_decay 183 | wd_schedule_values = utils.cosine_scheduler( 184 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch) 185 | 186 | if mixup_fn is not None: 187 | # smoothing is handled with mixup label transform 188 | criterion = SoftTargetCrossEntropy() 189 | elif args.smoothing > 0.: 190 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 191 | else: 192 | criterion = SoftCrossEntropyLoss(gama=args.novel_loss) 193 | 194 | logger.info("Criterion: %s" % str(criterion)) 195 | 196 | utils.auto_load_model( 197 | args=args, model=model, model_without_ddp=model_without_ddp, 198 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema) 199 | # for evaluation 200 | if args.eval: 201 | logger.info(f"Eval only mode") 202 | ckpt = torch.load("checkpoint/" + args.dataset + ".pth")["model"] 203 | model_dict = model.state_dict() 204 | ckpt = {k: v for k, v in ckpt.items() if k in model_dict} 205 | model.load_state_dict(ckpt) 206 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp, logger=logger, update_freq=args.update_freq) 207 | logger.info(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%") 208 | return 209 | 210 | # ------------------------- training stage ------------------------------ 211 | max_accuracy = 0.0 212 | if args.model_ema and args.model_ema_eval: 213 | max_accuracy_ema = 0.0 214 | logger.info("Start training") 215 | start_time = time.time() 216 | for epoch in range(args.start_epoch, args.epochs): 217 | if args.distributed: 218 | data_loader_train.sampler.set_epoch(epoch) 219 | if log_writer is not None: 220 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq) 221 | if wandb_logger: 222 | wandb_logger.set_steps() 223 | # training 224 | train_stats = train_one_epoch( 225 | model, criterion, data_loader_train, optimizer, 226 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn, 227 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch, 228 | 
lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values, 229 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq, 230 | use_amp=args.use_amp, logger=logger, 231 | ) 232 | # save params 233 | if args.output_dir and args.save_ckpt: 234 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs: 235 | utils.save_model( 236 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 237 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema) 238 | 239 | # evaluate 240 | if data_loader_val is not None: 241 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp, logger=logger, update_freq=args.update_freq) 242 | logger.info(f"test accuracy : {test_stats['acc1']:.1f}%") 243 | if max_accuracy < test_stats["acc1"]: 244 | max_accuracy = test_stats["acc1"] 245 | if wandb_logger is not None: 246 | wandb.run.summary["Best Accuracy"] = max_accuracy 247 | wandb.run.summary["Best Epoch"] = epoch 248 | if args.output_dir and args.save_ckpt: 249 | utils.save_model( 250 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 251 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema) 252 | acc1 = test_stats["acc1"] 253 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 254 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 255 | 256 | if log_writer is not None: 257 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch) 258 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch) 259 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch) 260 | 261 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 262 | **{f'test_{k}': v for k, v in test_stats.items()}, 263 | 'epoch': epoch, 264 | 'n_parameters': n_parameters} 265 | 266 | # repeat testing routines for EMA, if ema eval is turned on 267 | if args.model_ema and args.model_ema_eval: 268 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp, logger=logger, update_freq=args.update_freq) 269 | # logger.info(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 270 | if max_accuracy_ema < test_stats_ema["acc1"]: 271 | max_accuracy_ema = test_stats_ema["acc1"] 272 | if args.output_dir and args.save_ckpt: 273 | utils.save_model( 274 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 275 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema) 276 | logger.info(f'Max EMA accuracy: {max_accuracy_ema:.2f}%') 277 | if log_writer is not None: 278 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch) 279 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}}) 280 | else: 281 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 282 | 'epoch': epoch, 283 | 'n_parameters': n_parameters} 284 | # update logger info 285 | if args.output_dir and utils.is_main_process(): 286 | if log_writer is not None: 287 | log_writer.flush() 288 | with open(os.path.join(args.output_dir, "record.txt"), mode="a", encoding="utf-8") as f: 289 | f.write(json.dumps(log_stats) + "\n") 290 | 291 | if wandb_logger: 292 | wandb_logger.log_epoch_metrics(log_stats) 293 | 294 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir: 295 | wandb_logger.log_checkpoints() 296 | 297 | # ------------------------- finished 
------------------------------ 298 | total_time = time.time() - start_time 299 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 300 | logger('Training time {}'.format(total_time_str)) 301 | 302 | 303 | if __name__ == '__main__': 304 | 305 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()]) 306 | args = parser.parse_args() 307 | 308 | args.output_dir = os.path.join(args.output_dir, '%s_%s' % (args.dataset, args.tag)) 309 | 310 | if args.output_dir: 311 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 312 | 313 | with open(os.path.join(args.output_dir, 'config.txt'), 'w') as f: 314 | argsDict = args.__dict__ 315 | f.writelines('------------------ start ------------------' + '\n') 316 | for eachArg, value in argsDict.items(): 317 | f.writelines(eachArg + ' : ' + str(value) + '\n') 318 | f.writelines('------------------- end -------------------') 319 | 320 | main(args) 321 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/models/__init__.py -------------------------------------------------------------------------------- /models/cores.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class BasicConv2d(nn.Module): 11 | def __init__(self, in_channels, out_channels, **kwargs): 12 | super(BasicConv2d, self).__init__() 13 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) 14 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | x = F.relu(x, inplace=True) 20 | return x 21 | 22 | 23 | # Attention network 24 | class Attentions(nn.Module): 25 | def __init__(self, channel_size=256): 26 | super(Attentions, self).__init__() 27 | self.A1_c = ChannelGate(channel_size) 28 | self.A2_c = ChannelGate(channel_size) 29 | self.A3_c = ChannelGate(channel_size) 30 | 31 | def forward(self, inputs): 32 | F1, F2, F3 = inputs 33 | # Global Average Pooling to a vector 34 | A1_channel = self.A1_c(F1) 35 | A2_channel = self.A2_c(F2) 36 | A3_channel = self.A3_c(F3) 37 | 38 | # bottom to top 39 | A2_channel = (A2_channel + A1_channel) / 2 40 | A3_channel = (A3_channel + A2_channel) / 2 41 | 42 | # channel pooling 43 | # A1 = F1 * A1_channel 44 | # A2 = F2 * A2_channel 45 | A3 = F3 * A3_channel 46 | 47 | return A3 48 | 49 | 50 | # channel attention 51 | class ChannelGate(nn.Module): 52 | def __init__(self, out_channels): 53 | super(ChannelGate, self).__init__() 54 | self.conv1 = nn.Conv2d(out_channels, out_channels//4, kernel_size=1, stride=1, padding=0) 55 | self.conv2 = nn.Conv2d(out_channels//4, out_channels, kernel_size=1, stride=1, padding=0) 56 | self.relu = nn.ReLU(inplace=True) 57 | self.sigmoid = nn.Sigmoid() 58 | 59 | def forward(self, x): 60 | x = nn.AdaptiveAvgPool2d(output_size=1)(x) 61 | x = self.conv1(x) 62 | x = self.relu(x) 63 | x = self.conv2(x) 64 | x = self.sigmoid(x) 65 | # x = F.relu(self.conv1(x), inplace=True) 66 | # x = torch.sigmoid(self.conv2(x)) 67 | return x 68 | 69 | 70 | class Flatten(nn.Module): 71 | def __init__(self): 72 | super(Flatten, 
self).__init__() 73 | 74 | def forward(self, x): 75 | return x.view(x.size(0), -1) 76 | 77 | 78 | class BasicConv(nn.Module): 79 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 80 | bn=True, bias=False): 81 | super(BasicConv, self).__init__() 82 | self.out_channels = out_planes 83 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, 84 | dilation=dilation, groups=groups, bias=bias) 85 | # self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 86 | self.relu = nn.ReLU(inplace=True) if relu else None 87 | 88 | def forward(self, x): 89 | x = self.conv(x) 90 | # if self.bn is not None: 91 | # x = self.bn(x) 92 | if self.relu is not None: 93 | x = self.relu(x) 94 | return x 95 | 96 | 97 | # construct the top of pyramid layer 98 | class SimpleFPA(nn.Module): 99 | def __init__(self, in_planes, out_planes): 100 | super(SimpleFPA, self).__init__() 101 | 102 | self.channels_cond = in_planes 103 | # Master branch 104 | self.conv_master = BasicConv(in_planes, out_planes, kernel_size=1, stride=1) 105 | 106 | # Global pooling branch 107 | self.conv_gpb = BasicConv(in_planes, out_planes, kernel_size=1, stride=1) 108 | 109 | def forward(self, x): 110 | # Master branch 111 | x_master = self.conv_master(x) 112 | 113 | # Global pooling branch 114 | x_gpb = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1) 115 | x_gpb = self.conv_gpb(x_gpb) 116 | 117 | out = x_master + x_gpb 118 | 119 | return out 120 | 121 | 122 | # Feature Pyramid Network 123 | class PyramidFeatures(nn.Module): 124 | def __init__(self, B1_size, B2_size, B3_size, feature_size=256): 125 | super(PyramidFeatures, self).__init__() 126 | 127 | # feature_size represents the number of the channels 128 | self.P3_1 = SimpleFPA(B3_size, feature_size) 129 | self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 130 | 131 | self.P2_1 = nn.Conv2d(B2_size, feature_size, kernel_size=1, stride=1, padding=0) 132 | self.P2_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 133 | 134 | self.P1_1 = nn.Conv2d(B1_size, feature_size, kernel_size=1, stride=1, padding=0) 135 | self.P1_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1) 136 | 137 | def forward(self, inputs): 138 | B1, B2, B3 = inputs 139 | 140 | P3_x = self.P3_1(B3) 141 | P3_upsampled_x = F.interpolate(P3_x, scale_factor=2) 142 | P3_x = self.P3_2(P3_x) 143 | 144 | P2_x = self.P2_1(B2) 145 | P2_x = P3_upsampled_x + P2_x 146 | P2_upsampled_x = F.interpolate(P2_x, scale_factor=2) 147 | P2_x = self.P2_2(P2_x) 148 | 149 | P1_x = self.P1_1(B1) 150 | P1_x = P1_x + P2_upsampled_x 151 | P1_x = self.P1_2(P1_x) 152 | 153 | return [P1_x, P2_x, P3_x] 154 | 155 | 156 | class MHAM(nn.Module): 157 | def __init__(self, fpn_sizes, M, num_features, use_mha=True): 158 | super(MHAM, self).__init__() 159 | self.use_mha = use_mha 160 | if self.use_mha: 161 | self.fpn = PyramidFeatures(fpn_sizes[1], fpn_sizes[2], fpn_sizes[3], feature_size=M) 162 | # channel attention 163 | self.ca = Attentions(channel_size=M) 164 | self.conv = BasicConv(in_planes=M, out_planes=M, kernel_size=1) 165 | else: 166 | self.fpa = SimpleFPA(num_features, M) 167 | 168 | def forward(self, input1, input2, input3): 169 | if not self.use_mha: 170 | return self.fpa(input3) 171 | x = self.fpn([input1, input2, input3]) 172 | x = self.ca(x) 173 | x = self.conv(x) 174 | return x 175 | 176 | 177 | # 
bilinear polymerization pooling 178 | class BPP(nn.Module): 179 | def __init__(self, epsilon): 180 | super(BPP, self).__init__() 181 | self.epsilon = epsilon 182 | 183 | def forward(self, features1, features2): 184 | # unify the size of width and height 185 | B, C, H, W = features1.size() 186 | _, M, AH, AW = features2.size() 187 | 188 | # match size 189 | if AH != H or AW != W: 190 | features2 = F.upsample_bilinear(features2, size=(H, W)) 191 | 192 | # essential_matrix: (B, M, C) -> (B, M * C) 193 | essential_matrix = (torch.einsum('imjk,injk->imn', (features2, features1)) / float(H * W)).view(B, -1) 194 | # nornalize 195 | essential_matrix = torch.sign(essential_matrix) * torch.sqrt(torch.abs(essential_matrix) + self.epsilon) 196 | essential_matrix = F.normalize(essential_matrix, dim=-1) 197 | 198 | return essential_matrix 199 | -------------------------------------------------------------------------------- /models/miner.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 3 | 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from timm.models.layers import trunc_normal_, DropPath 10 | from timm.models.registry import register_model 11 | from models.cores import MHAM, BPP 12 | 13 | 14 | class ConvNeXt(nn.Module): 15 | def __init__(self, in_chans=3, num_classes=1000, 16 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], M=16, drop_path_rate=0.2, 17 | layer_scale_init_value=1e-6, head_init_scale=0.001, use_mha=True, use_ref=True 18 | ): 19 | super().__init__() 20 | 21 | self.epsilon = 1e-12 22 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers 23 | stem = nn.Sequential( 24 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), 25 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first") 26 | ) 27 | self.downsample_layers.append(stem) 28 | for i in range(3): 29 | downsample_layer = nn.Sequential( 30 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 31 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), 32 | ) 33 | self.downsample_layers.append(downsample_layer) 34 | 35 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks 36 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 37 | cur = 0 38 | for i in range(4): 39 | stage = nn.Sequential( 40 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 41 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] 42 | ) 43 | self.stages.append(stage) 44 | cur += depths[i] 45 | self.num_classes = num_classes 46 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer 47 | 48 | self.use_mha = use_mha # is network using MHAM module? 49 | self.use_ref = use_ref # is network using Refiner? 
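        # With use_mha=True the attention maps come from MHAM (a feature pyramid
        # over the last three stage outputs plus channel attention, models/cores.py);
        # otherwise a single SimpleFPA head on the final stage output is used.
        # With use_ref=True, forward() also returns refiner inputs sampled from
        # the attention maps (see refiner_input below).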
50 | 51 | self.hidden_num = 256 52 | self.num_features = dims[-1] # output feature size 53 | self.M = M # the channel number of attention module in paper (kappa) 54 | fpn_sizes = [self.stages[0][-1].pwconv2.out_features, 55 | self.stages[1][-1].pwconv2.out_features, 56 | self.stages[2][-1].pwconv2.out_features, 57 | self.stages[3][-1].pwconv2.out_features] 58 | # feature pyramid network 59 | self.mha = MHAM(fpn_sizes, M, self.num_features, use_mha) 60 | # bilinear polymerization pooling 61 | self.bpp = BPP(self.epsilon) 62 | # classification network 63 | self.head = nn.Linear(self.num_features * self.M, self.num_classes) 64 | 65 | def _init_weights(self, m): 66 | if isinstance(m, (nn.Conv2d, nn.Linear)): 67 | trunc_normal_(m.weight, std=.02) 68 | nn.init.constant_(m.bias, 0) 69 | def refiner_input(self, batch_size, attention_maps): 70 | attentions = [] 71 | if self.training: 72 | # Randomly choose one of attention maps Ak 73 | for i in range(batch_size): 74 | attention_weight = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + self.epsilon) 75 | attention_weight = F.normalize(attention_weight, p=1, dim=0) 76 | k_index = np.random.choice(self.M, 2, p=attention_weight.cpu().numpy()) 77 | attentions.append(attention_maps[i, k_index, ...]) 78 | attentions = torch.stack(attentions) # (B, 2, H, W) - one for detail image, the other for mask image 79 | else: 80 | # Object Localization Am = mean(Ak) 81 | attentions = torch.mean(attention_maps, dim=1, keepdim=True) # (B, 1, H, W) 82 | 83 | return attentions 84 | 85 | 86 | def forward(self, x): 87 | torch.autograd.set_detect_anomaly(True) # auto gradient screen 88 | 89 | # backbone (convnext) 90 | x1 = self.downsample_layers[0](x) 91 | x1 = self.stages[0](x1) 92 | # layer2 out(batchsize, 256, h/8, w/8) 93 | x2 = self.downsample_layers[1](x1) 94 | x2 = self.stages[1](x2) 95 | # layer3 out(batchsize, 512, h/16, w/16) 96 | x3 = self.downsample_layers[2](x2) 97 | x3 = self.stages[2](x3) 98 | # layer4 out(batchsize, 1024, h/32, w/32) 99 | x4 = self.downsample_layers[3](x3) 100 | intrisics = self.stages[3](x4) 101 | 102 | # multiple hierarchies attention module 103 | attention_maps = self.mha(x2, x3, intrisics) 104 | essential = self.bpp(intrisics, attention_maps) 105 | pred = self.head(essential) 106 | if self.use_ref: 107 | return pred, self.refiner_input(x.size(0), attention_maps) 108 | else: 109 | return pred, [] 110 | 111 | 112 | class LayerNorm(nn.Module): 113 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 114 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 115 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 116 | with shape (batch_size, channels, height, width). 
117 | """ 118 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 119 | super().__init__() 120 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 121 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 122 | self.eps = eps 123 | self.data_format = data_format 124 | if self.data_format not in ["channels_last", "channels_first"]: 125 | raise NotImplementedError 126 | self.normalized_shape = (normalized_shape, ) 127 | 128 | def forward(self, x): 129 | if self.data_format == "channels_last": 130 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 131 | elif self.data_format == "channels_first": 132 | u = x.mean(1, keepdim=True) 133 | s = (x - u).pow(2).mean(1, keepdim=True) 134 | x = (x - u) / torch.sqrt(s + self.eps) 135 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 136 | return x 137 | 138 | 139 | model_urls = { 140 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", 141 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", 142 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", 143 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", 144 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", 145 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", 146 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", 147 | } 148 | 149 | 150 | 151 | @register_model 152 | def miner(pretrained=False, M=16, in_22k=False, **kwargs): 153 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], M=M, **kwargs) 154 | if pretrained: 155 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 156 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 157 | model.load_state_dict(checkpoint["model"]) 158 | return model 159 | 160 | 161 | @register_model 162 | def convnext_tiny(pretrained=False, **kwargs): 163 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) 164 | if pretrained: 165 | url = model_urls['convnext_tiny_1k'] 166 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) 167 | model.load_state_dict(checkpoint["model"]) 168 | return model 169 | 170 | 171 | @register_model 172 | def convnext_small(pretrained=False, **kwargs): 173 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 174 | if pretrained: 175 | url = model_urls['convnext_small_1k'] 176 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 177 | model.load_state_dict(checkpoint["model"]) 178 | return model 179 | 180 | 181 | @register_model 182 | def convnext_base(pretrained=False, M=16, in_22k=False, **kwargs): 183 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], M=M, **kwargs) 184 | if pretrained: 185 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] 186 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 187 | model.load_state_dict(checkpoint["model"]) 188 | return model 189 | 190 | 191 | @register_model 192 | def convnext_large(pretrained=False, in_22k=False, **kwargs): 193 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 194 | if pretrained: 195 | url = 
model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] 196 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 197 | model.load_state_dict(checkpoint["model"]) 198 | return model 199 | 200 | 201 | @register_model 202 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs): 203 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) 204 | if pretrained: 205 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" 206 | url = model_urls['convnext_xlarge_22k'] 207 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") 208 | model.load_state_dict(checkpoint["model"]) 209 | return model 210 | 211 | 212 | class Block(nn.Module): 213 | r""" ConvNeXt Block. There are two equivalent implementations: 214 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 215 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 216 | We use (2) as we find it slightly faster in PyTorch 217 | 218 | Args: 219 | dim (int): Number of input channels. 220 | drop_path (float): Stochastic depth rate. Default: 0.0 221 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 222 | """ 223 | 224 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): 225 | super().__init__() 226 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 227 | self.norm = LayerNorm(dim, eps=1e-6) 228 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers 229 | self.act = nn.GELU() 230 | self.pwconv2 = nn.Linear(4 * dim, dim) 231 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 232 | requires_grad=True) if layer_scale_init_value > 0 else None 233 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 234 | 235 | def forward(self, x): 236 | input = x 237 | x = self.dwconv(x) 238 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 239 | x = self.norm(x) 240 | x = self.pwconv1(x) 241 | x = self.act(x) 242 | x = self.pwconv2(x) 243 | if self.gamma is not None: 244 | x = self.gamma * x 245 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 246 | 247 | x = input + self.drop_path(x) 248 | return x 249 | 250 | 251 | -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/processor/__init__.py -------------------------------------------------------------------------------- /processor/engine.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 
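# Overview: train_one_epoch() below forwards each batch up to three times. The raw images
# give predictions plus two attention maps per image; get_detail_images() crops an enlarged
# detail view from the first map, get_mask_images() suppresses the region indicated by the
# second, and both augmented batches are forwarded again. The three predictions and the
# three cross-entropy losses are averaged. evaluate() uses only the raw and detail branches
# and averages their two predictions.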
3 | 4 | 5 | import math 6 | import time 7 | import datetime 8 | from typing import Iterable, Optional 9 | import numpy as np 10 | import torch 11 | from timm.data import Mixup 12 | from timm.utils import accuracy, ModelEma, AverageMeter 13 | from util.utils import reduce_tensor, MetricLogger, SmoothedValue, get_detail_images, get_mask_images 14 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 15 | 16 | 17 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 18 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 19 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 20 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 21 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 22 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False, logger=None, 23 | ): 24 | model.train(True) 25 | metric_logger = MetricLogger(delimiter=" ") 26 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 27 | metric_logger.add_meter('min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}')) 28 | header = 'Epoch: [{}]'.format(epoch) 29 | print_freq = 10 30 | 31 | optimizer.zero_grad() 32 | 33 | # counter 34 | num_steps = len(data_loader) 35 | batch_time = AverageMeter() 36 | loss_meter = AverageMeter() 37 | 38 | start = time.time() 39 | end = time.time() 40 | 41 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 42 | step = data_iter_step // update_freq 43 | if step >= num_training_steps_per_epoch: 44 | continue 45 | it = start_steps + step # global training iteration 46 | # Update LR & WD for the first acc 47 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 48 | for i, param_group in enumerate(optimizer.param_groups): 49 | if lr_schedule_values is not None: 50 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 51 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 52 | param_group["weight_decay"] = wd_schedule_values[it] 53 | 54 | samples = samples.to(device, non_blocking=True) 55 | targets = targets.to(device, non_blocking=True) 56 | 57 | if mixup_fn is not None: 58 | samples, targets = mixup_fn(samples, targets) 59 | 60 | if use_amp: 61 | with torch.cuda.amp.autocast(): 62 | input_pred, attention_map = model(samples) 63 | else: # full precision 64 | input_pred, attention_map = model(samples) 65 | 66 | if attention_map != []: 67 | with torch.no_grad(): 68 | detail_images = get_detail_images(samples, attention_map[:, :1, :, :], theta_detail=(0.4, 0.6), padding=0.1) 69 | 70 | # detail-images forward 71 | detail_pred, _ = model(detail_images) 72 | 73 | with torch.no_grad(): 74 | mask_images = get_mask_images(samples, attention_map[:, 1:, :, :], theta_mask=(0.2, 0.5)) 75 | 76 | mask_pred, _ = model(mask_images) 77 | 78 | output = (input_pred + detail_pred + mask_pred)/3. 79 | loss = criterion(input_pred, targets)/3. + \ 80 | criterion(detail_pred, targets)/3. + \ 81 | criterion(mask_pred, targets)/3. 
82 | else: 83 | loss = criterion(input_pred, targets) 84 | output = input_pred 85 | 86 | loss_value = loss.item() 87 | 88 | if not math.isfinite(loss_value): # this could trigger if using AMP 89 | logger.info("Loss is {}, stopping training".format(loss_value)) 90 | assert math.isfinite(loss_value) 91 | 92 | if use_amp: 93 | # this attribute is added by timm on one optimizer (adahessian) 94 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 95 | loss /= update_freq 96 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 97 | parameters=model.parameters(), create_graph=is_second_order, 98 | update_grad=(data_iter_step + 1) % update_freq == 0) 99 | if (data_iter_step + 1) % update_freq == 0: 100 | optimizer.zero_grad() 101 | if model_ema is not None: 102 | model_ema.update(model) 103 | else: # full precision 104 | loss /= update_freq 105 | loss.backward() 106 | if (data_iter_step + 1) % update_freq == 0: 107 | optimizer.step() 108 | optimizer.zero_grad() 109 | if model_ema is not None: 110 | model_ema.update(model) 111 | 112 | torch.cuda.synchronize() 113 | 114 | if mixup_fn is None: 115 | class_acc = (output.max(-1)[-1] == targets).float().mean() 116 | else: 117 | class_acc = None 118 | metric_logger.update(loss=loss_value) 119 | metric_logger.update(class_acc=class_acc) 120 | min_lr = 10. 121 | max_lr = 0. 122 | for group in optimizer.param_groups: 123 | min_lr = min(min_lr, group["lr"]) 124 | max_lr = max(max_lr, group["lr"]) 125 | 126 | metric_logger.update(lr=max_lr) 127 | metric_logger.update(min_lr=min_lr) 128 | weight_decay_value = None 129 | for group in optimizer.param_groups: 130 | if group["weight_decay"] > 0: 131 | weight_decay_value = group["weight_decay"] 132 | metric_logger.update(weight_decay=weight_decay_value) 133 | if use_amp: 134 | metric_logger.update(grad_norm=grad_norm) 135 | 136 | if log_writer is not None: 137 | log_writer.update(loss=loss_value, head="loss") 138 | log_writer.update(class_acc=class_acc, head="loss") 139 | log_writer.update(lr=max_lr, head="opt") 140 | log_writer.update(min_lr=min_lr, head="opt") 141 | log_writer.update(weight_decay=weight_decay_value, head="opt") 142 | if use_amp: 143 | log_writer.update(grad_norm=grad_norm, head="opt") 144 | log_writer.set_step() 145 | 146 | if wandb_logger: 147 | wandb_logger._wandb.log({ 148 | 'Rank-0 Batch Wise/train_loss': loss_value, 149 | 'Rank-0 Batch Wise/train_max_lr': max_lr, 150 | 'Rank-0 Batch Wise/train_min_lr': min_lr 151 | }, commit=False) 152 | if class_acc: 153 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False) 154 | if use_amp: 155 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False) 156 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it}) 157 | 158 | loss_meter.update(loss.item(), targets.size(0)) 159 | batch_time.update(time.time() - end) 160 | end = time.time() 161 | 162 | if data_iter_step % update_freq == 0: 163 | lr = max_lr 164 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 165 | etas = batch_time.avg * (num_steps - data_iter_step) 166 | logger.info( 167 | f'Train: [{epoch}][{data_iter_step}/{num_steps}]\t' 168 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 169 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 170 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 171 | f'mem {memory_used:.0f}MB') 172 | 173 | # gather the stats from all processes 174 | metric_logger.synchronize_between_processes() 175 | 
train_stat = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 176 | 177 | epoch_time = time.time() - start 178 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 179 | 180 | return train_stat 181 | 182 | @torch.no_grad() 183 | def evaluate(data_loader, model, device, use_amp=False, logger=None, update_freq=1): 184 | criterion = torch.nn.CrossEntropyLoss() 185 | 186 | metric_logger = MetricLogger(delimiter=" ") 187 | header = 'Test:' 188 | 189 | # switch to evaluation mode 190 | model.eval() 191 | 192 | batch_time = AverageMeter() 193 | loss_meter = AverageMeter() 194 | acc1_meter = AverageMeter() 195 | acc5_meter = AverageMeter() 196 | 197 | # statistics 198 | record_truth = np.array([]) 199 | record_pred = np.array([]) 200 | record_feature = torch.tensor([]) 201 | 202 | end = time.time() 203 | idx = 0 204 | with torch.no_grad(): 205 | for batch in metric_logger.log_every(data_loader, 10, header): 206 | images = batch[0] 207 | target = batch[-1] 208 | label_true = target.numpy().squeeze() 209 | record_truth = np.concatenate((record_truth, label_true)) 210 | 211 | images = images.to(device, non_blocking=True) 212 | target = target.to(device, non_blocking=True) 213 | 214 | # compute output 215 | if use_amp: 216 | with torch.cuda.amp.autocast(): 217 | input_pred, attention_map = model(images) 218 | else: 219 | input_pred, attention_map = model(images) 220 | 221 | if attention_map != []: 222 | detail_images = get_detail_images(images, attention_map, theta_detail=0.1, padding=0.05) 223 | detail_pred, _ = model(detail_images) 224 | output = (input_pred + detail_pred)/2. 225 | loss = criterion(output, target) 226 | else: 227 | loss = criterion(input_pred, target) 228 | output = input_pred 229 | 230 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 231 | 232 | # record 233 | _, pred = torch.max(output, dim=1) 234 | pred = pred.cpu().numpy().squeeze() 235 | record_pred = np.concatenate((record_pred, pred)) 236 | 237 | acc1 = reduce_tensor(acc1) 238 | acc5 = reduce_tensor(acc5) 239 | loss = reduce_tensor(loss) 240 | 241 | loss_meter.update(loss.item(), target.size(0)) 242 | acc1_meter.update(acc1.item(), target.size(0)) 243 | acc5_meter.update(acc5.item(), target.size(0)) 244 | 245 | # measure elapsed time 246 | batch_time.update(time.time() - end) 247 | end = time.time() 248 | 249 | if idx % update_freq == 0: 250 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 251 | if logger: 252 | logger.info( 253 | f'Test: [{idx}/{len(data_loader)}]\t' 254 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 255 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 256 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 257 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 258 | f'Mem {memory_used:.0f}MB') 259 | idx += 1 260 | 261 | batch_size = images.shape[0] 262 | metric_logger.update(loss=loss.item()) 263 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 264 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 265 | 266 | # gather the stats from all processes 267 | metric_logger.synchronize_between_processes() 268 | # metric 269 | acc = accuracy_score(record_truth, record_pred) 270 | precision = precision_score(record_truth, record_pred, average='weighted') 271 | recall = recall_score(record_truth, record_pred, average='weighted') 272 | f1 = f1_score(record_truth, record_pred, average='weighted') 273 | logger.info(f'[Info] acc: {acc}, precision: {precision}, recall: {recall}, f1: 
{f1}') 274 | 275 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 276 | -------------------------------------------------------------------------------- /processor/optim_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 3 | 4 | 5 | import torch 6 | from torch import optim as optim 7 | 8 | from timm.optim.adafactor import Adafactor 9 | from timm.optim.adahessian import Adahessian 10 | from timm.optim.adamp import AdamP 11 | from timm.optim.lookahead import Lookahead 12 | from timm.optim.nadam import Nadam 13 | from timm.optim.novograd import NovoGrad 14 | from timm.optim.nvnovograd import NvNovoGrad 15 | from timm.optim.radam import RAdam 16 | from timm.optim.rmsprop_tf import RMSpropTF 17 | from timm.optim.sgdp import SGDP 18 | 19 | import json 20 | 21 | try: 22 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 23 | has_apex = True 24 | except ImportError: 25 | has_apex = False 26 | 27 | 28 | def get_num_layer_for_convnext(var_name): 29 | """ 30 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three 31 | consecutive blocks, including possible neighboring downsample layers; 32 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py 33 | """ 34 | num_max_layer = 12 35 | if var_name.startswith("downsample_layers"): 36 | stage_id = int(var_name.split('.')[1]) 37 | if stage_id == 0: 38 | layer_id = 0 39 | elif stage_id == 1 or stage_id == 2: 40 | layer_id = stage_id + 1 41 | elif stage_id == 3: 42 | layer_id = 12 43 | return layer_id 44 | 45 | elif var_name.startswith("stages"): 46 | stage_id = int(var_name.split('.')[1]) 47 | block_id = int(var_name.split('.')[2]) 48 | if stage_id == 0 or stage_id == 1: 49 | layer_id = stage_id + 1 50 | elif stage_id == 2: 51 | layer_id = 3 + block_id // 3 52 | elif stage_id == 3: 53 | layer_id = 12 54 | return layer_id 55 | else: 56 | return num_max_layer + 1 57 | 58 | class LayerDecayValueAssigner(object): 59 | def __init__(self, values): 60 | self.values = values 61 | 62 | def get_scale(self, layer_id): 63 | return self.values[layer_id] 64 | 65 | def get_layer_id(self, var_name): 66 | return get_num_layer_for_convnext(var_name) 67 | 68 | 69 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 70 | parameter_group_names = {} 71 | parameter_group_vars = {} 72 | 73 | for name, param in model.named_parameters(): 74 | if not param.requires_grad: 75 | continue # frozen weights 76 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: 77 | group_name = "no_decay" 78 | this_weight_decay = 0. 79 | else: 80 | group_name = "decay" 81 | this_weight_decay = weight_decay 82 | if get_num_layer is not None: 83 | layer_id = get_num_layer(name) 84 | group_name = "layer_%d_%s" % (layer_id, group_name) 85 | else: 86 | layer_id = None 87 | 88 | if group_name not in parameter_group_names: 89 | if get_layer_scale is not None: 90 | scale = get_layer_scale(layer_id) 91 | else: 92 | scale = 1. 
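# Each group created below stores its weight decay value, its parameter list, and an
# lr_scale taken from the layer-decay assigner (1.0 when no assigner is supplied); groups
# are keyed by layer id and by whether their parameters receive weight decay.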
93 | 94 | parameter_group_names[group_name] = { 95 | "weight_decay": this_weight_decay, 96 | "params": [], 97 | "lr_scale": scale 98 | } 99 | parameter_group_vars[group_name] = { 100 | "weight_decay": this_weight_decay, 101 | "params": [], 102 | "lr_scale": scale 103 | } 104 | 105 | parameter_group_vars[group_name]["params"].append(param) 106 | parameter_group_names[group_name]["params"].append(name) 107 | # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 108 | return list(parameter_group_vars.values()) 109 | 110 | 111 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 112 | opt_lower = args.opt.lower() 113 | weight_decay = args.weight_decay 114 | # if weight_decay and filter_bias_and_bn: 115 | if filter_bias_and_bn: 116 | skip = {} 117 | if skip_list is not None: 118 | skip = skip_list 119 | elif hasattr(model, 'no_weight_decay'): 120 | skip = model.no_weight_decay() 121 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 122 | weight_decay = 0. 123 | else: 124 | parameters = model.parameters() 125 | 126 | if 'fused' in opt_lower: 127 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 128 | 129 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 130 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 131 | opt_args['eps'] = args.opt_eps 132 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 133 | opt_args['betas'] = args.opt_betas 134 | 135 | opt_split = opt_lower.split('_') 136 | opt_lower = opt_split[-1] 137 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 138 | opt_args.pop('eps', None) 139 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 140 | elif opt_lower == 'momentum': 141 | opt_args.pop('eps', None) 142 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 143 | elif opt_lower == 'adam': 144 | optimizer = optim.Adam(parameters, **opt_args) 145 | elif opt_lower == 'adamw': 146 | optimizer = optim.AdamW(parameters, **opt_args) 147 | elif opt_lower == 'nadam': 148 | optimizer = Nadam(parameters, **opt_args) 149 | elif opt_lower == 'radam': 150 | optimizer = RAdam(parameters, **opt_args) 151 | elif opt_lower == 'adamp': 152 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 153 | elif opt_lower == 'sgdp': 154 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 155 | elif opt_lower == 'adadelta': 156 | optimizer = optim.Adadelta(parameters, **opt_args) 157 | elif opt_lower == 'adafactor': 158 | if not args.lr: 159 | opt_args['lr'] = None 160 | optimizer = Adafactor(parameters, **opt_args) 161 | elif opt_lower == 'adahessian': 162 | optimizer = Adahessian(parameters, **opt_args) 163 | elif opt_lower == 'rmsprop': 164 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 165 | elif opt_lower == 'rmsproptf': 166 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 167 | elif opt_lower == 'novograd': 168 | optimizer = NovoGrad(parameters, **opt_args) 169 | elif opt_lower == 'nvnovograd': 170 | optimizer = NvNovoGrad(parameters, **opt_args) 171 | elif opt_lower == 'fusedsgd': 172 | opt_args.pop('eps', None) 173 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 174 | elif opt_lower == 'fusedmomentum': 175 | opt_args.pop('eps', None) 176 | optimizer = 
FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 177 | elif opt_lower == 'fusedadam': 178 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 179 | elif opt_lower == 'fusedadamw': 180 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 181 | elif opt_lower == 'fusedlamb': 182 | optimizer = FusedLAMB(parameters, **opt_args) 183 | elif opt_lower == 'fusednovograd': 184 | opt_args.setdefault('betas', (0.95, 0.98)) 185 | optimizer = FusedNovoGrad(parameters, **opt_args) 186 | else: 187 | assert False, "Invalid optimizer" 188 | 189 | if len(opt_split) > 1: 190 | if opt_split[0] == 'lookahead': 191 | optimizer = Lookahead(optimizer) 192 | 193 | return optimizer 194 | -------------------------------------------------------------------------------- /processor/run_with_submitit.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 3 | 4 | 5 | import argparse 6 | import os 7 | import uuid 8 | from pathlib import Path 9 | 10 | import main as classification 11 | import submitit 12 | 13 | def parse_args(): 14 | classification_parser = classification.get_args_parser() 15 | parser = argparse.ArgumentParser("Submitit for ConvNeXt", parents=[classification_parser]) 16 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node") 17 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request") 18 | parser.add_argument("--timeout", default=72, type=int, help="Duration of the job, in hours") 19 | parser.add_argument("--job_name", default="convnext", type=str, help="Job name") 20 | parser.add_argument("--job_dir", default="", type=str, help="Job directory; leave empty for default") 21 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") 22 | parser.add_argument("--use_volta32", action='store_true', default=True, help="Big models? Use this") 23 | parser.add_argument('--comment', default="", type=str, 24 | help='Comment to pass to scheduler, e.g. priority message') 25 | return parser.parse_args() 26 | 27 | def get_shared_folder() -> Path: 28 | user = os.getenv("USER") 29 | if Path("/checkpoint/").is_dir(): 30 | p = Path(f"/checkpoint/{user}/convnext") 31 | p.mkdir(exist_ok=True) 32 | return p 33 | raise RuntimeError("No shared folder available") 34 | 35 | def get_init_file(): 36 | # Init file must not exist, but its parent dir must exist. 
37 | os.makedirs(str(get_shared_folder()), exist_ok=True) 38 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" 39 | if init_file.exists(): 40 | os.remove(str(init_file)) 41 | return init_file 42 | 43 | class Trainer(object): 44 | def __init__(self, args): 45 | self.args = args 46 | 47 | def __call__(self): 48 | import main as classification 49 | 50 | self._setup_gpu_args() 51 | classification.main(self.args) 52 | 53 | def checkpoint(self): 54 | import os 55 | import submitit 56 | 57 | self.args.dist_url = get_init_file().as_uri() 58 | self.args.auto_resume = True 59 | print("Requeuing ", self.args) 60 | empty_trainer = type(self)(self.args) 61 | return submitit.helpers.DelayedSubmission(empty_trainer) 62 | 63 | def _setup_gpu_args(self): 64 | import submitit 65 | from pathlib import Path 66 | 67 | job_env = submitit.JobEnvironment() 68 | self.args.output_dir = Path(self.args.job_dir) 69 | self.args.gpu = job_env.local_rank 70 | self.args.rank = job_env.global_rank 71 | self.args.world_size = job_env.num_tasks 72 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") 73 | 74 | 75 | def main(): 76 | args = parse_args() 77 | 78 | if args.job_dir == "": 79 | args.job_dir = get_shared_folder() / "%j" 80 | 81 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) 82 | 83 | num_gpus_per_node = args.ngpus 84 | nodes = args.nodes 85 | timeout_min = args.timeout * 60 86 | 87 | partition = args.partition 88 | kwargs = {} 89 | if args.use_volta32: 90 | kwargs['slurm_constraint'] = 'volta32gb' 91 | if args.comment: 92 | kwargs['slurm_comment'] = args.comment 93 | 94 | executor.update_parameters( 95 | mem_gb=40 * num_gpus_per_node, 96 | gpus_per_node=num_gpus_per_node, 97 | tasks_per_node=num_gpus_per_node, # one task per GPU 98 | cpus_per_task=10, 99 | nodes=nodes, 100 | timeout_min=timeout_min, # max is 60 * 72 101 | # Below are cluster dependent parameters 102 | slurm_partition=partition, 103 | slurm_signal_delay_s=120, 104 | **kwargs 105 | ) 106 | 107 | executor.update_parameters(name=args.job_name) 108 | 109 | args.dist_url = get_init_file().as_uri() 110 | args.output_dir = args.job_dir 111 | 112 | trainer = Trainer(args) 113 | job = executor.submit(trainer) 114 | 115 | print("Submitted job_id:", job.job_id) 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/util/__init__.py -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright © Beijing University of Posts and Telecommunications, 2 | # School of Artificial Intelligence. 
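# This module gathers the shared training utilities: smoothed metric logging (SmoothedValue,
# MetricLogger), TensorBoard and Weights & Biases loggers, distributed-training helpers,
# checkpoint saving and auto-resume, a cosine learning-rate/weight-decay scheduler, an AMP
# gradient scaler with gradient-norm tracking, and the attention-guided image generators
# get_detail_images / get_mask_images used by the training engine.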
3 | 4 | 5 | import os 6 | import math 7 | import time 8 | import cv2 9 | from collections import defaultdict, deque 10 | import datetime 11 | import numpy as np 12 | from timm.utils import get_state_dict 13 | from pathlib import Path 14 | import torch 15 | import torch.distributed as dist 16 | from torch._six import inf 17 | import torch.nn as nn 18 | from tensorboardX import SummaryWriter 19 | from termcolor import colored 20 | import functools 21 | import logging 22 | import sys 23 | import random 24 | import torch.nn.functional as F 25 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 26 | 27 | 28 | class SoftCrossEntropyLoss(nn.Module): 29 | def __init__(self, gama=0., reduction='mean'): 30 | super(SoftCrossEntropyLoss, self).__init__() 31 | self.reduction = reduction 32 | self.gama = gama 33 | self.softmax = nn.Softmax(dim=1) 34 | 35 | def __str__(self): 36 | return "Soft-cross-entropy loss" 37 | 38 | def forward(self, outputs, targets): 39 | targets = targets.view(-1, 1) 40 | assert outputs.size(0) == targets.size(0) 41 | 42 | softed = self.softmax(outputs) 43 | loged = -torch.log(softed) 44 | 45 | y_truth = softed.gather(1, targets) 46 | y_truth_1 = torch.pow(1 - y_truth, self.gama) 47 | 48 | log_truth = loged.gather(1, targets) 49 | soft_loss = y_truth_1 * log_truth 50 | if self.reduction == 'mean': 51 | loss = torch.mean(soft_loss) 52 | elif self.reduction == 'sum': 53 | loss = torch.sum(soft_loss) 54 | else: 55 | raise ValueError("std parameter must be in ['mean', 'sum'], but got {}".format(self.std)) 56 | 57 | return loss 58 | 59 | 60 | class SmoothedValue(object): 61 | """Track a series of values and provide access to smoothed values over a 62 | window or the global series average. 63 | """ 64 | 65 | def __init__(self, window_size=20, fmt=None): 66 | if fmt is None: 67 | fmt = "{median:.4f} ({global_avg:.4f})" 68 | self.deque = deque(maxlen=window_size) 69 | self.total = 0.0 70 | self.count = 0 71 | self.fmt = fmt 72 | 73 | def update(self, value, n=1): 74 | self.deque.append(value) 75 | self.count += n 76 | self.total += value * n 77 | 78 | def synchronize_between_processes(self): 79 | """ 80 | Warning: does not synchronize the deque! 
81 | """ 82 | if not is_dist_avail_and_initialized(): 83 | return 84 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 85 | dist.barrier() 86 | dist.all_reduce(t) 87 | t = t.tolist() 88 | self.count = int(t[0]) 89 | self.total = t[1] 90 | 91 | @property 92 | def median(self): 93 | d = torch.tensor(list(self.deque)) 94 | return d.median().item() 95 | 96 | @property 97 | def avg(self): 98 | d = torch.tensor(list(self.deque), dtype=torch.float32) 99 | return d.mean().item() 100 | 101 | @property 102 | def global_avg(self): 103 | return self.total / self.count 104 | 105 | @property 106 | def max(self): 107 | return max(self.deque) 108 | 109 | @property 110 | def value(self): 111 | return self.deque[-1] 112 | 113 | def __str__(self): 114 | return self.fmt.format( 115 | median=self.median, 116 | avg=self.avg, 117 | global_avg=self.global_avg, 118 | max=self.max, 119 | value=self.value) 120 | 121 | 122 | class MetricLogger(object): 123 | def __init__(self, delimiter="\t"): 124 | self.meters = defaultdict(SmoothedValue) 125 | self.delimiter = delimiter 126 | 127 | def update(self, **kwargs): 128 | for k, v in kwargs.items(): 129 | if v is None: 130 | continue 131 | if isinstance(v, torch.Tensor): 132 | v = v.item() 133 | assert isinstance(v, (float, int)) 134 | self.meters[k].update(v) 135 | 136 | def __getattr__(self, attr): 137 | if attr in self.meters: 138 | return self.meters[attr] 139 | if attr in self.__dict__: 140 | return self.__dict__[attr] 141 | raise AttributeError("'{}' object has no attribute '{}'".format( 142 | type(self).__name__, attr)) 143 | 144 | def __str__(self): 145 | loss_str = [] 146 | for name, meter in self.meters.items(): 147 | loss_str.append( 148 | "{}: {}".format(name, str(meter)) 149 | ) 150 | return self.delimiter.join(loss_str) 151 | 152 | def synchronize_between_processes(self): 153 | for meter in self.meters.values(): 154 | meter.synchronize_between_processes() 155 | 156 | def add_meter(self, name, meter): 157 | self.meters[name] = meter 158 | 159 | def log_every(self, iterable, print_freq, header=None): 160 | i = 0 161 | if not header: 162 | header = '' 163 | start_time = time.time() 164 | end = time.time() 165 | iter_time = SmoothedValue(fmt='{avg:.4f}') 166 | data_time = SmoothedValue(fmt='{avg:.4f}') 167 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 168 | log_msg = [ 169 | header, 170 | '[{0' + space_fmt + '}/{1}]', 171 | 'eta: {eta}', 172 | '{meters}', 173 | 'time: {time}', 174 | 'data: {data}' 175 | ] 176 | if torch.cuda.is_available(): 177 | log_msg.append('max mem: {memory:.0f}') 178 | log_msg = self.delimiter.join(log_msg) 179 | MB = 1024.0 * 1024.0 180 | for obj in iterable: 181 | data_time.update(time.time() - end) 182 | yield obj 183 | iter_time.update(time.time() - end) 184 | if i % print_freq == 0 or i == len(iterable) - 1: 185 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 186 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 187 | if torch.cuda.is_available(): 188 | print(log_msg.format( 189 | i, len(iterable), eta=eta_string, 190 | meters=str(self), 191 | time=str(iter_time), data=str(data_time), 192 | memory=torch.cuda.max_memory_allocated() / MB)) 193 | else: 194 | print(log_msg.format( 195 | i, len(iterable), eta=eta_string, 196 | meters=str(self), 197 | time=str(iter_time), data=str(data_time))) 198 | i += 1 199 | end = time.time() 200 | total_time = time.time() - start_time 201 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 202 | print('{} Total time: {} 
({:.4f} s / it)'.format( 203 | header, total_time_str, total_time / len(iterable))) 204 | 205 | 206 | class TensorboardLogger(object): 207 | def __init__(self, log_dir): 208 | self.writer = SummaryWriter(logdir=log_dir) 209 | self.step = 0 210 | 211 | def set_step(self, step=None): 212 | if step is not None: 213 | self.step = step 214 | else: 215 | self.step += 1 216 | 217 | def update(self, head='scalar', step=None, **kwargs): 218 | for k, v in kwargs.items(): 219 | if v is None: 220 | continue 221 | if isinstance(v, torch.Tensor): 222 | v = v.item() 223 | assert isinstance(v, (float, int)) 224 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step) 225 | 226 | def flush(self): 227 | self.writer.flush() 228 | 229 | 230 | class WandbLogger(object): 231 | def __init__(self, args): 232 | self.args = args 233 | 234 | try: 235 | import wandb 236 | self._wandb = wandb 237 | except ImportError: 238 | raise ImportError( 239 | "To use the Weights and Biases Logger please install wandb." 240 | "Run `pip install wandb` to install it." 241 | ) 242 | 243 | # Initialize a W&B run 244 | if self._wandb.run is None: 245 | self._wandb.init( 246 | entity="yonghao-yi", 247 | project=args.project, 248 | config=args, 249 | name=args.exp_name 250 | ) 251 | 252 | def log_epoch_metrics(self, metrics, commit=True): 253 | """ 254 | Log train/test metrics onto W&B. 255 | """ 256 | # Log number of model parameters as W&B summary 257 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None) 258 | metrics.pop('n_parameters', None) 259 | 260 | # Log current epoch 261 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False) 262 | metrics.pop('epoch') 263 | 264 | for k, v in metrics.items(): 265 | if 'train' in k: 266 | self._wandb.log({f'Global Train/{k}': v}, commit=False) 267 | elif 'test' in k: 268 | self._wandb.log({f'Global Test/{k}': v}, commit=False) 269 | 270 | self._wandb.log({}) 271 | 272 | def log_checkpoints(self): 273 | output_dir = self.args.output_dir 274 | model_artifact = self._wandb.Artifact( 275 | self._wandb.run.id + "_model", type="model" 276 | ) 277 | 278 | model_artifact.add_dir(output_dir) 279 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"]) 280 | 281 | def set_steps(self): 282 | # Set global training step 283 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step') 284 | # Set epoch-wise step 285 | self._wandb.define_metric('Global Train/*', step_metric='epoch') 286 | self._wandb.define_metric('Global Test/*', step_metric='epoch') 287 | 288 | 289 | def setup_for_distributed(is_master): 290 | """ 291 | This function disables printing when not in master process 292 | """ 293 | import builtins as __builtin__ 294 | builtin_print = __builtin__.print 295 | 296 | def print(*args, **kwargs): 297 | force = kwargs.pop('force', False) 298 | if is_master or force: 299 | builtin_print(*args, **kwargs) 300 | 301 | __builtin__.print = print 302 | 303 | 304 | def is_dist_avail_and_initialized(): 305 | if not dist.is_available(): 306 | return False 307 | if not dist.is_initialized(): 308 | return False 309 | return True 310 | 311 | 312 | def get_world_size(): 313 | if not is_dist_avail_and_initialized(): 314 | return 1 315 | return dist.get_world_size() 316 | 317 | 318 | def get_rank(): 319 | if not is_dist_avail_and_initialized(): 320 | return 0 321 | return dist.get_rank() 322 | 323 | 324 | def is_main_process(): 325 | return get_rank() == 0 326 | 327 | 328 | def save_on_master(*args, **kwargs): 
329 | if is_main_process(): 330 | torch.save(*args, **kwargs) 331 | 332 | 333 | def init_distributed_mode(args): 334 | 335 | if args.dist_on_itp: 336 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 337 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 338 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 339 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 340 | os.environ['LOCAL_RANK'] = str(args.gpu) 341 | os.environ['RANK'] = str(args.rank) 342 | os.environ['WORLD_SIZE'] = str(args.world_size) 343 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 344 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 345 | args.rank = int(os.environ["RANK"]) 346 | args.world_size = int(os.environ['WORLD_SIZE']) 347 | args.gpu = int(os.environ['LOCAL_RANK']) 348 | elif 'SLURM_PROCID' in os.environ: 349 | args.rank = int(os.environ['SLURM_PROCID']) 350 | args.gpu = args.rank % torch.cuda.device_count() 351 | 352 | os.environ['RANK'] = str(args.rank) 353 | os.environ['LOCAL_RANK'] = str(args.gpu) 354 | os.environ['WORLD_SIZE'] = str(args.world_size) 355 | else: 356 | print('Not using distributed mode') 357 | args.distributed = False 358 | return 359 | 360 | args.distributed = True 361 | 362 | torch.cuda.set_device(args.gpu) 363 | args.dist_backend = 'nccl' 364 | print('| distributed init (rank {}): {}, gpu {}'.format( 365 | args.rank, args.dist_url, args.gpu), flush=True) 366 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 367 | world_size=args.world_size, rank=args.rank) 368 | print(f"backend:{args.dist_backend}, init_method:{args.dist_url}, world_size:{args.world_size}, rank:{args.rank}") 369 | torch.distributed.barrier() 370 | setup_for_distributed(args.rank == 0) 371 | 372 | 373 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"): 374 | missing_keys = [] 375 | unexpected_keys = [] 376 | error_msgs = [] 377 | # copy state_dict so _load_from_state_dict can modify it 378 | metadata = getattr(state_dict, '_metadata', None) 379 | state_dict = state_dict.copy() 380 | if metadata is not None: 381 | state_dict._metadata = metadata 382 | 383 | def load(module, prefix=''): 384 | local_metadata = {} if metadata is None else metadata.get( 385 | prefix[:-1], {}) 386 | module._load_from_state_dict( 387 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) 388 | for name, child in module._modules.items(): 389 | if child is not None: 390 | load(child, prefix + name + '.') 391 | 392 | load(model, prefix=prefix) 393 | 394 | warn_missing_keys = [] 395 | ignore_missing_keys = [] 396 | for key in missing_keys: 397 | keep_flag = True 398 | for ignore_key in ignore_missing.split('|'): 399 | if ignore_key in key: 400 | keep_flag = False 401 | break 402 | if keep_flag: 403 | warn_missing_keys.append(key) 404 | else: 405 | ignore_missing_keys.append(key) 406 | 407 | missing_keys = warn_missing_keys 408 | 409 | if len(missing_keys) > 0: 410 | logging.warning("Weights of {} not initialized from pretrained model: {}".format( 411 | model.__class__.__name__, missing_keys)) 412 | if len(unexpected_keys) > 0: 413 | logging.warning("Weights from pretrained model not used in {}: {}".format( 414 | model.__class__.__name__, unexpected_keys)) 415 | if len(ignore_missing_keys) > 0: 416 | logging.warning("Ignored weights of {} not initialized from pretrained model: {}".format( 417 | model.__class__.__name__, ignore_missing_keys)) 
418 | if len(error_msgs) > 0: 419 | print('\n'.join(error_msgs)) 420 | 421 | 422 | class NativeScalerWithGradNormCount: 423 | state_dict_key = "amp_scaler" 424 | 425 | def __init__(self): 426 | self._scaler = torch.cuda.amp.GradScaler() 427 | 428 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 429 | self._scaler.scale(loss).backward(create_graph=create_graph) 430 | if update_grad: 431 | if clip_grad is not None: 432 | assert parameters is not None 433 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 434 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 435 | else: 436 | self._scaler.unscale_(optimizer) 437 | norm = get_grad_norm_(parameters) 438 | self._scaler.step(optimizer) 439 | self._scaler.update() 440 | else: 441 | norm = None 442 | return norm 443 | 444 | def state_dict(self): 445 | return self._scaler.state_dict() 446 | 447 | def load_state_dict(self, state_dict): 448 | self._scaler.load_state_dict(state_dict) 449 | 450 | 451 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 452 | if isinstance(parameters, torch.Tensor): 453 | parameters = [parameters] 454 | parameters = [p for p in parameters if p.grad is not None] 455 | norm_type = float(norm_type) 456 | if len(parameters) == 0: 457 | return torch.tensor(0.) 458 | device = parameters[0].grad.device 459 | if norm_type == inf: 460 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 461 | else: 462 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 463 | return total_norm 464 | 465 | 466 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, 467 | start_warmup_value=0, warmup_steps=-1): 468 | warmup_schedule = np.array([]) 469 | warmup_iters = warmup_epochs * niter_per_ep 470 | if warmup_steps > 0: 471 | warmup_iters = warmup_steps 472 | print("Set warmup steps = %d" % warmup_iters) 473 | if warmup_epochs > 0: 474 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) 475 | 476 | iters = np.arange(epochs * niter_per_ep - warmup_iters) 477 | schedule = np.array( 478 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters]) 479 | 480 | schedule = np.concatenate((warmup_schedule, schedule)) 481 | 482 | assert len(schedule) == epochs * niter_per_ep 483 | return schedule 484 | 485 | 486 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 487 | output_dir = Path(args.output_dir) 488 | epoch_name = str(epoch) 489 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 490 | for checkpoint_path in checkpoint_paths: 491 | to_save = { 492 | 'model': model_without_ddp.state_dict(), 493 | 'optimizer': optimizer.state_dict(), 494 | 'epoch': epoch, 495 | 'scaler': loss_scaler.state_dict(), 496 | 'args': args, 497 | } 498 | 499 | if model_ema is not None: 500 | to_save['model_ema'] = get_state_dict(model_ema) 501 | 502 | save_on_master(to_save, checkpoint_path) 503 | 504 | if is_main_process() and isinstance(epoch, int): 505 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq 506 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del) 507 | if os.path.exists(old_ckpt): 508 | os.remove(old_ckpt) 509 | 510 | 511 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None): 512 | output_dir = 
Path(args.output_dir) 513 | if args.auto_resume and len(args.resume) == 0: 514 | import glob 515 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth')) 516 | latest_ckpt = -1 517 | for ckpt in all_checkpoints: 518 | t = ckpt.split('-')[-1].split('.')[0] 519 | if t.isdigit(): 520 | latest_ckpt = max(int(t), latest_ckpt) 521 | if latest_ckpt >= 0: 522 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt) 523 | print("Auto resume checkpoint: %s" % args.resume) 524 | 525 | if args.resume: 526 | if args.resume.startswith('https'): 527 | checkpoint = torch.hub.load_state_dict_from_url( 528 | args.resume, map_location='cpu', check_hash=True) 529 | else: 530 | checkpoint = torch.load(args.resume, map_location='cpu') 531 | model_without_ddp.load_state_dict(checkpoint['model']) 532 | print("Resume checkpoint %s" % args.resume) 533 | if 'optimizer' in checkpoint and 'epoch' in checkpoint: 534 | optimizer.load_state_dict(checkpoint['optimizer']) 535 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema' 536 | args.start_epoch = checkpoint['epoch'] + 1 537 | else: 538 | assert args.eval, 'Does not support resuming with checkpoint-best' 539 | if hasattr(args, 'model_ema') and args.model_ema: 540 | if 'model_ema' in checkpoint.keys(): 541 | model_ema.ema.load_state_dict(checkpoint['model_ema']) 542 | else: 543 | model_ema.ema.load_state_dict(checkpoint['model']) 544 | if 'scaler' in checkpoint: 545 | loss_scaler.load_state_dict(checkpoint['scaler']) 546 | print("With optim & sched!") 547 | 548 | 549 | 550 | @functools.lru_cache() 551 | def create_logger(output_dir, dist_rank=0, name=''): 552 | # create logger 553 | logger = logging.getLogger(name) 554 | logger.setLevel(logging.INFO) 555 | logger.propagate = False 556 | 557 | # create formatter 558 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 559 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 560 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 561 | 562 | # create console handlers for master process 563 | if dist_rank == 0: 564 | console_handler = logging.StreamHandler(sys.stdout) 565 | console_handler.setLevel(logging.DEBUG) 566 | console_handler.setFormatter( 567 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 568 | logger.addHandler(console_handler) 569 | 570 | # create file handlers 571 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 572 | file_handler.setLevel(logging.DEBUG) 573 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 574 | logger.addHandler(file_handler) 575 | 576 | return logger 577 | 578 | def reduce_tensor(tensor): 579 | rt = tensor.clone() 580 | # dist.all_reduce(rt, op=dist.ReduceOp.SUM) 581 | # rt /= dist.get_world_size() 582 | return rt 583 | 584 | 585 | ############################################## 586 | # Center Loss for Attention Regularization 587 | ############################################## 588 | class CenterLoss(nn.Module): 589 | def __init__(self): 590 | super(CenterLoss, self).__init__() 591 | self.l2_loss = nn.MSELoss(reduction='sum') 592 | 593 | def forward(self, outputs, targets): 594 | return self.l2_loss(outputs, targets) / outputs.size(0) 595 | 596 | 597 | class UnNormalizer(object): 598 | def __init__(self, mean=None, std=None): 599 | if mean == None: 600 | self.mean = [0.485, 0.456, 0.406] 601 | else: 602 | self.mean = mean 603 | if std 
== None: 604 | self.std = [0.229, 0.224, 0.225] 605 | else: 606 | self.std = std 607 | 608 | def __call__(self, tensor): 609 | """ 610 | Args: 611 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 612 | Returns: 613 | Tensor: Normalized image. 614 | """ 615 | for t, m, s in zip(tensor, self.mean, self.std): 616 | t.mul_(s).add_(m) 617 | return tensor 618 | 619 | 620 | def save_input_image(input_image, save_path): 621 | unorm = UnNormalizer(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) 622 | img = unorm(input_image.cpu()).numpy().copy() 623 | img = np.clip(img * 255, 0, 255).astype(np.uint8) 624 | img = np.transpose(img, [1, 2, 0]) 625 | # change the channel(BGR->RGB) 626 | r, g, b = cv2.split(img) 627 | img = cv2.merge([b, g, r]) 628 | cv2.imwrite(save_path, img) 629 | 630 | 631 | def get_detail_images(input_images, attentions, theta_detail=0.5, padding=0.1): 632 | batch_size, _, height, width = input_images.size() # B, C, H, W 633 | detail_images = [] 634 | for i in range(batch_size): 635 | attention = attentions[i:i + 1] 636 | if isinstance(theta_detail, tuple): 637 | theta = random.uniform(*theta_detail) * attention.max() 638 | else: 639 | theta = theta_detail * attention.max() 640 | 641 | mask = F.upsample_bilinear(attention, size=(height, width)) >= theta 642 | nonzero_indices = torch.nonzero(mask[0, 0, ...]) 643 | height_min = max(int((nonzero_indices[:, 0].min().item() - padding * height)), 0) 644 | height_max = min(int(nonzero_indices[:, 0].max().item() + padding * height), height) 645 | width_min = max(int(nonzero_indices[:, 1].min().item() - padding * width), 0) 646 | width_max = min(int(nonzero_indices[:, 1].max().item() + padding * width), width) 647 | detail_images.append( 648 | F.upsample_bilinear(input_images[i:i + 1, :, height_min:height_max, width_min:width_max], 649 | size=(height, width))) 650 | 651 | detail_images = torch.cat(detail_images, dim=0) 652 | return detail_images 653 | 654 | 655 | def get_mask_images(input_images, attentions, theta_mask=0.5): 656 | batch_size, _, height, weight = input_images.size() 657 | masks = [] 658 | for i in range(batch_size): 659 | attention = attentions[i:i + 1] 660 | if isinstance(theta_mask, tuple): 661 | temp1 = random.uniform(*theta_mask) 662 | temp2 = attention.max() 663 | theta = temp1 * temp2 664 | else: 665 | theta = theta_mask * attention.max() 666 | 667 | masks.append(F.upsample_bilinear(attention, size=(height, weight)) < theta) 668 | masks = torch.cat(masks, dim=0) 669 | mask_images = input_images * masks.float() 670 | 671 | return mask_images 672 | 673 | # def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1, image_name=None): 674 | # batches, _, imgH, imgW = images.size() 675 | # if mode == 'crop': 676 | # crop_images = [] 677 | # nonzero_indices = None 678 | # for batch_index in range(batches): 679 | # atten_map = attention_map[batch_index:batch_index + 1] 680 | # if isinstance(theta, tuple): 681 | # theta_c = random.uniform(*theta) * atten_map.max() 682 | # else: 683 | # theta_c = theta * atten_map.max() 684 | # 685 | # crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c 686 | # nonzero_indices = torch.nonzero(crop_mask[0, 0, ...]) 687 | # height_min = max(int((nonzero_indices[:, 0].min().item() - padding_ratio * imgH)), 0) 688 | # height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH) 689 | # width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0) 690 | # width_max = min(int(nonzero_indices[:, 
1].max().item() + padding_ratio * imgW), imgW) 691 | # crop_images.append( 692 | # F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max], 693 | # size=(imgH, imgW))) 694 | # # if image_name: 695 | # # save_input_image(images[batch_index], os.path.join("/home/yiyonghao/TGRS/A-Net/visualization/crop_drop", image_name[batch_index] + "_raw.jpg")) 696 | # # save_input_image(crop_images[batch_index][0], os.path.join("/home/yiyonghao/TGRS/A-Net/visualization/crop_drop", image_name[batch_index] + "_crop.jpg")) 697 | # 698 | # crop_images = torch.cat(crop_images, dim=0) 699 | # return crop_images 700 | # 701 | # elif mode == 'drop': 702 | # drop_masks = [] 703 | # for batch_index in range(batches): 704 | # atten_map = attention_map[batch_index:batch_index + 1] 705 | # if isinstance(theta, tuple): 706 | # temp1 = random.uniform(*theta) 707 | # temp2 = atten_map.max() 708 | # theta_d = temp1 * temp2 709 | # else: 710 | # theta_d = theta * atten_map.max() 711 | # 712 | # drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d) 713 | # drop_masks = torch.cat(drop_masks, dim=0) 714 | # drop_images = images * drop_masks.float() 715 | # # for idx, drop_image in enumerate(drop_images): 716 | # # if image_name: 717 | # # save_input_image(drop_image, 718 | # # os.path.join("/home/yiyonghao/TGRS/A-Net/visualization/crop_drop", 719 | # # image_name[idx] + "_drop.jpg")) 720 | # return drop_images 721 | # 722 | # else: 723 | # raise ValueError('Expected mode in [\'crop\', \'drop\'], but received unsupported augmentation method %s' % mode) 724 | --------------------------------------------------------------------------------
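# ---------------------------------------------------------------------------
# Illustrative usage sketch (not one of the repository files above). It shows,
# under assumed hyper-parameters and with dummy data, how the pieces fit
# together during training, mirroring processor/engine.py: one forward pass on
# the raw images, one on the attention-cropped detail images, and one on the
# attention-masked images, with the predictions and losses averaged.
import torch

from models.miner import miner
from util.utils import get_detail_images, get_mask_images

# num_classes=10 and the 2x3x224x224 batch below are assumptions for illustration only
model = miner(pretrained=False, M=16, num_classes=10)
model.train()  # in train mode the model returns two sampled attention maps per image

images = torch.randn(2, 3, 224, 224)
targets = torch.tensor([0, 1])
criterion = torch.nn.CrossEntropyLoss()

# raw-image branch: class logits + attention maps of shape (B, 2, h, w)
input_pred, attention_map = model(images)

# detail branch: crop and enlarge the region highlighted by the first attention map
with torch.no_grad():
    detail_images = get_detail_images(images, attention_map[:, :1, :, :],
                                      theta_detail=(0.4, 0.6), padding=0.1)
detail_pred, _ = model(detail_images)

# mask branch: erase the most salient region indicated by the second attention map
with torch.no_grad():
    mask_images = get_mask_images(images, attention_map[:, 1:, :, :],
                                  theta_mask=(0.2, 0.5))
mask_pred, _ = model(mask_images)

# average the three predictions and the three losses, as in train_one_epoch()
output = (input_pred + detail_pred + mask_pred) / 3.
loss = (criterion(input_pred, targets)
        + criterion(detail_pred, targets)
        + criterion(mask_pred, targets)) / 3.
loss.backward()
# ---------------------------------------------------------------------------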