├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── EFM-Net-Pytorch.iml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   └── configs.py
├── datasets_builder.py
├── main.py
├── models
│   ├── __init__.py
│   ├── cores.py
│   └── miner.py
├── processor
│   ├── __init__.py
│   ├── engine.py
│   ├── optim_factory.py
│   └── run_with_submitit.py
└── util
    ├── __init__.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/EFM-Net-Pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EFM-Net: An Essential Feature Mining Network for Target Fine-Grained Classification in Optical Remote Sensing Images
2 |
3 | By Yonghao Yi (*Student Member, IEEE*), Yanan You (*Member, IEEE*), Chao Li, and Wenli Zhou.
4 |
5 | Special thanks to [Haopeng Zhang](https://orcid.org/0000-0003-1981-8307) for providing the dataset source file for [FGSCR-42](https://www.mdpi.com/2072-4292/13/4/747).
6 |
7 | ## Introduction
8 |
9 | This code provides a reference implementation of the *IEEE TGRS* paper “EFM-Net: An Essential Feature Mining Network for Target Fine-Grained Classification in Optical Remote Sensing Images”. The project is still under construction.
10 |
11 |
12 |
13 | ## How to run
14 |
15 | ### Prepare the datasets
16 |
17 | **If you want to use the training and test splits used in this paper, please download them from the links below.**
18 |
19 | Download the public benchmark datasets and unzip them to a path of your choice.
20 |
21 | - [FGSC-23](https://drive.google.com/file/d/1DFir2wvzVnMYqleqgvxLoN760hYZe3TW/view?usp=sharing)
22 |
23 | - [FGSCR-42](https://drive.google.com/file/d/1o8QzGA3wEhobGFZ-Hbey0GCgCNIdnEmf/view?usp=sharing)
24 |
25 | - [Aircraft-16](https://drive.google.com/file/d/1n0aoB0FJIvrA5xpC8AfeXKZKqCFu2Gca/view?usp=sharing)
26 |
27 | Move the directory containing the `train` and `test` sub-directories into the `./datasets/` directory and rename it after the dataset, for example:
28 |
29 | ```
30 | mv FGSC23 /path/to/EFM-Net-Pytorch/datasets/FGSC-23
31 | ```
32 |
33 | Make sure each dataset is correctly split into a training set and a test set: the training set goes in a sub-directory named “train” and the test set in a sub-directory named “test” (see the example layout below).
34 |
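For reference, the expected layout under `./datasets/` should look roughly like the sketch below (one sub-directory per class, as required by `torchvision.datasets.ImageFolder`; the class folder names here are placeholders):

```
datasets/
└── FGSC-23/
    ├── train/
    │   ├── <class_name_1>/
    │   │   ├── xxx.jpg
    │   │   └── ...
    │   └── <class_name_2>/
    └── test/
        ├── <class_name_1>/
        └── <class_name_2>/
```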
35 | **Download links for the source files of these public datasets can be found at https://github.com/JACYI/Dataset-for-Remote-Sensing.**
36 |
37 |
38 |
39 | ### Download the pre-trained model
40 |
41 | The feature extraction model is based on ConvNeXt-Base. Please download the pre-trained parameters file:
42 |
43 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
44 |
45 | and move the pre-trained model to `./pretrained` directory.
46 |
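For example, a minimal sketch (assumes `wget` is available and that you run it from the repository root; the default `--finetune` path in `config/configs.py` is `pretrained/convnext_base_22k_224.pth`):

```shell
mkdir -p pretrained
wget https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth -P pretrained/
```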
47 |
48 |
49 | ### Build the running environment
50 |
51 | 1. **Create the virtual environment:**
52 |
53 | ```shell
54 | conda create -n efmnet python=3.8
55 | source activate efmnet
56 | ```
57 |
58 | 2. **Install the requirements** (an example install command follows this list):
59 |
60 | - python=3.8
61 | - pytorch>=1.8.1, torchvision>=0.9.1
62 | - timm=0.3.2
63 | - opencv-python
64 | - scikit-learn
65 | - tensorboardX
66 | - termcolor
67 |
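One possible way to install these inside the `efmnet` environment (a sketch; the package sources and exact pinning are an assumption, and installing PyTorch via the official selector on pytorch.org for your CUDA version is usually preferable):

```shell
conda activate efmnet
pip install "torch>=1.8.1" "torchvision>=0.9.1"
pip install timm==0.3.2 opencv-python scikit-learn tensorboardX termcolor
```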
68 |
69 |
70 | ### For Training
71 |
72 | 1. **Run on a single GPU**
73 |
74 | ```shell
75 | python main.py --exp_name test --attentions 16 --epochs 100 --dataset FGSC-23 --output_dir logs/test --novel_loss 0.5
76 | ```
77 |
78 | 2. **Run on multiple GPUs** (2 GPUs for example)
79 |
80 | ```shell
81 | python -m torch.distributed.launch --master_port 12345 \
82 | --nproc_per_node=2 main.py \
83 | --exp_name test --attentions 16 \
84 | --epochs 120 --dataset FGSC-23 \
85 | --output_dir logs/test --novel_loss 0.5
86 | ```
87 |
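Note: `main.py` appends `<dataset>_<tag>` to `--output_dir` (the tag defaults to the launch timestamp, see `config/configs.py`), so with the commands above the run configuration (`config.txt`), the per-epoch log (`record.txt`), and checkpoints saved via `utils.save_model` end up under a directory such as `logs/test/FGSC-23_<timestamp>/`.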
88 |
89 |
90 | ## Citation
91 |
92 | Full publication details will be added once available.
93 |
94 | ```
95 | @ARTICLE{10097708,
96 | author={Yi, Yonghao and You, Yanan and Li, Chao and Zhou, Wenli},
97 | journal={IEEE Transactions on Geoscience and Remote Sensing},
98 | title={EFM-Net: An Essential Feature Mining Network for Target Fine-Grained Classification in Optical Remote Sensing Images},
99 | year={2023},
100 | volume={},
101 | number={},
102 | pages={1-1},
103 | doi={10.1109/TGRS.2023.3265669}
104 | }
105 | ```
106 |
107 |
108 |
109 | ## Corresponding author information
110 |
111 | Yanan You, Beijing, China,
112 |
113 | Phone: (+86) 15201148169
114 |
115 | Email: youyanan@bupt.edu.cn
116 |
117 |
118 |
119 | ## To do
120 |
121 | 1. Support more feature extraction models;
122 | 2. Provide more interfaces for modifying parameters.
123 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/config/__init__.py
--------------------------------------------------------------------------------
/config/configs.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import argparse
6 | import time
7 |
8 |
9 | def str2bool(v):
10 | """
11 | Converts string to bool type; enables command line
12 | arguments in the format of '--arg1 true --arg2 false'
13 | """
14 | if isinstance(v, bool):
15 | return v
16 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
17 | return True
18 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
19 | return False
20 | else:
21 | raise argparse.ArgumentTypeError('Boolean value expected.')
22 |
23 |
24 | def get_args_parser():
25 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script for image classification', add_help=False)
26 | # Main config
27 | parser.add_argument('--exp_name', default='TGRS-MODEL', type=str)
28 | parser.add_argument('--batch_size', default=16, type=int,
29 | help='Per GPU batch size')
30 | parser.add_argument('--epochs', default=300, type=int, help='Number of training epochs.')
31 | parser.add_argument('--attentions', default=16, type=int,
32 | help='The number of channels in the attention module (kappa in the text)')
33 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
34 | help='learning rate (default: 1e-3)')
35 | parser.add_argument('--novel_loss', type=float, default=0.,
36 | help='SCE loss coefficient from the paper; 0. is equivalent to standard cross-entropy loss')
37 | parser.add_argument('--update_freq', default=30, type=int, help='Gradient accumulation steps')
38 |
39 | # model parameters
40 | parser.add_argument('--model', default='convnext_base', type=str, metavar='MODEL',
41 | help='Name of feature extraction model to train')
42 | parser.add_argument('--drop_path', type=float, default=0.2, metavar='PCT',
43 | help='Drop path rate (default: 0.2)')
44 | parser.add_argument('--input_size', default=224, type=int,
45 | help='image input size')
46 | parser.add_argument('--layer_scale_init_value', default=1e-6, type=float,
47 | help="Layer scale initial values")
48 | parser.add_argument('--use_mha', default=True, type=str2bool)
49 | parser.add_argument('--use_ref', default=True, type=str2bool)
50 |
51 | # EMA related parameters
52 | parser.add_argument('--model_ema', type=str2bool, default=False)
53 | parser.add_argument('--model_ema_decay', type=float, default=0.9999, help='')
54 | parser.add_argument('--model_ema_force_cpu', type=str2bool, default=False, help='')
55 | parser.add_argument('--model_ema_eval', type=str2bool, default=False, help='Using ema to eval during training.')
56 |
57 | # optimization parameters
58 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
59 | help='Optimizer (default: "adamw")')
60 | parser.add_argument('--opt_eps', default=1e-8, type=float, metavar='EPSILON',
61 | help='Optimizer Epsilon (default: 1e-8)')
62 | parser.add_argument('--opt_betas', default=None, type=float, nargs='+', metavar='BETA',
63 | help='Optimizer Betas (default: None, use opt default)')
64 | parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
65 | help='Clip gradient norm (default: None, no clipping)')
66 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
67 | help='SGD momentum (default: 0.9)')
68 | parser.add_argument('--weight_decay', type=float, default=1e-8,
69 | help='weight decay (default: 1e-8)')
70 | parser.add_argument('--weight_decay_end', type=float, default=None, help="""Final value of the
71 | weight decay. We use a cosine schedule for WD and using a larger decay by
72 | the end of training improves performance for ViTs.""")
73 |
74 | # learning parameters
75 | parser.add_argument('--layer_decay', type=float, default=0.8)
76 | parser.add_argument('--min_lr', type=float, default=1e-6, metavar='LR',
77 | help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
78 | parser.add_argument('--warmup_epochs', type=int, default=0, metavar='N',
79 | help='epochs to warmup LR, if scheduler supports')
80 | parser.add_argument('--warmup_steps', type=int, default=-1, metavar='N',
81 | help='num of steps to warmup LR; overrides warmup_epochs if set > 0')
82 |
83 | # augmentation parameters
84 | # parser.add_argument('--theta_detail', type=tuple, default=(0.4, 0.6),
85 | # help='Proportion of regional boundaries of detail images')
86 | # parser.add_argument('--theta_mask', type=tuple, default=(0.2, 0.5),
87 | # help='Proportion of regional boundaries of mask images')
88 | parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
89 | help='Color jitter factor (default: 0.4)')
90 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
91 | help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)')
92 | parser.add_argument('--smoothing', type=float, default=0,
93 | help='Label smoothing (default: 0)')
94 | parser.add_argument('--train_interpolation', type=str, default='bicubic',
95 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
96 |
97 | # evaluation parameters
98 | parser.add_argument('--crop_pct', type=float, default=None)
99 |
100 | # random Erase params
101 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
102 | help='Random erase prob (default: 0.25)')
103 | parser.add_argument('--remode', type=str, default='pixel',
104 | help='Random erase mode (default: "pixel")')
105 | parser.add_argument('--recount', type=int, default=1,
106 | help='Random erase count (default: 1)')
107 | parser.add_argument('--resplit', type=str2bool, default=False,
108 | help='Do not random erase first (clean) augmentation split')
109 |
110 | # mixup params
111 | parser.add_argument('--mixup', type=float, default=0.,
112 | help='mixup alpha, mixup enabled if > 0.')
113 | parser.add_argument('--cutmix', type=float, default=0.,
114 | help='cutmix alpha, cutmix enabled if > 0.')
115 | parser.add_argument('--cutmix_minmax', type=float, nargs='+', default=None,
116 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
117 | parser.add_argument('--mixup_prob', type=float, default=1.0,
118 | help='Probability of performing mixup or cutmix when either/both is enabled')
119 | parser.add_argument('--mixup_switch_prob', type=float, default=0.5,
120 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
121 | parser.add_argument('--mixup_mode', type=str, default='batch',
122 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
123 |
124 | # finetuning params
125 | parser.add_argument('--finetune', default='pretrained/convnext_base_22k_224.pth',
126 | help='finetune from checkpoint')
127 | parser.add_argument('--head_init_scale', default=0.001, type=float,
128 | help='classifier head initial scale, typically adjusted in fine-tuning')
129 | parser.add_argument('--model_key', default='model|module', type=str,
130 | help='which key to load from saved state dict, usually model or model_ema')
131 | parser.add_argument('--model_prefix', default='', type=str)
132 |
133 | # dataset parameters
134 | parser.add_argument('--data_path', default='./datasets/', type=str,
135 | help='dataset path')
136 | parser.add_argument('--dataset', default='FGSC-23')
137 | parser.add_argument('--eval_data_path', default=None, type=str,
138 | help='dataset path for evaluation')
139 | parser.add_argument('--nb_classes', default=23, type=int,
140 | help='number of the classification types')
141 | parser.add_argument('--imagenet_default_mean_and_std', type=str2bool, default=True)
142 | parser.add_argument('--data_set', default='image_folder', choices=['CIFAR', 'IMNET', 'image_folder'],
143 | type=str, help='dataset type to build (CIFAR, IMNET, or a generic image_folder)')
144 | parser.add_argument('--output_dir', default='result/',
145 | help='path where to save, empty for no saving')
146 | parser.add_argument('--log_dir', default=None,
147 | help='path where to tensorboard log')
148 | parser.add_argument('--device', default='cuda',
149 | help='device to use for training / testing')
150 | parser.add_argument('--seed', default=0, type=int)
151 |
152 | # training parameters
153 | parser.add_argument('--resume', default='',
154 | help='resume from checkpoint')
155 | parser.add_argument('--auto_resume', type=str2bool, default=True)
156 | parser.add_argument('--save_ckpt', type=str2bool, default=True)
157 | parser.add_argument('--save_ckpt_freq', default=10, type=int)
158 | parser.add_argument('--save_ckpt_num', default=3, type=int)
159 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
160 | help='start epoch')
161 | parser.add_argument('--eval', type=str2bool, default=False,
162 | help='Perform evaluation only')
163 | parser.add_argument('--dist_eval', type=str2bool, default=True,
164 | help='Enabling distributed evaluation')
165 | parser.add_argument('--disable_eval', type=str2bool, default=False,
166 | help='Disabling evaluation during training')
167 | parser.add_argument('--num_workers', default=10, type=int)
168 | parser.add_argument('--pin_mem', type=str2bool, default=True,
169 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
170 |
171 | # distributed training parameters
172 | parser.add_argument('--world_size', default=1, type=int,
173 | help='number of distributed processes')
174 | parser.add_argument('--local_rank', default=-1, type=int)
175 | parser.add_argument('--dist_on_itp', type=str2bool, default=False)
176 | parser.add_argument('--dist_url', default='env://',
177 | help='url used to set up distributed training')
178 |
179 | parser.add_argument('--use_amp', type=str2bool, default=False,
180 | help="Use PyTorch's AMP (Automatic Mixed Precision) or not")
181 |
182 | # weights and Biases arguments
183 | parser.add_argument('--enable_wandb', type=str2bool, default=False,
184 | help="enable logging to Weights and Biases")
185 | parser.add_argument('--project', default='A-Net', type=str,
186 | help="The name of the W&B project where you're sending the new run.")
187 | parser.add_argument('--wandb_ckpt', type=str2bool, default=False,
188 | help="Save model checkpoints as W&B Artifacts.")
189 | parser.add_argument('--tag', default=time.strftime("%m-%d-%H-%M-%S", time.localtime()), help='tag of experiment')
190 | return parser
191 |
192 | if __name__ == "__main__":
193 | parser = argparse.ArgumentParser('Test for parser.', parents=[get_args_parser()])
194 | args = parser.parse_args()
--------------------------------------------------------------------------------
/datasets_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import os
6 | from torchvision import datasets, transforms
7 | from timm.data.constants import \
8 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
9 | from timm.data import create_transform
10 |
11 | def build_dataset(is_train, args):
12 | transform = build_transform(is_train, args)
13 | print('{}.trans:{}'.format(is_train, transform))
14 | print("Transform = ")
15 | if isinstance(transform, tuple):
16 | for trans in transform:
17 | print(" - - - - - - - - - - ")
18 | for t in trans.transforms:
19 | print(t)
20 | else:
21 | for t in transform.transforms:
22 | print(t)
23 | print("---------------------------")
24 |
25 | if args.data_set == 'CIFAR':
26 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
27 | nb_classes = 100
28 | elif args.data_set == 'IMNET':
29 | print("reading from datapath", args.data_path)
30 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
31 | dataset = datasets.ImageFolder(root, transform=transform)
32 | nb_classes = 1000
33 | elif args.data_set == "image_folder":
34 | root = os.path.join(args.data_path, args.dataset, 'train' if is_train else 'test')
35 |
36 | dataset = datasets.ImageFolder(root, transform=transform)
37 | args.nb_classes = len(dataset.class_to_idx)
38 | nb_classes = args.nb_classes
39 | else:
40 | raise NotImplementedError()
41 | print("Number of the class = %d" % nb_classes)
42 |
43 | return dataset, nb_classes
44 |
45 |
46 | def build_transform(is_train, args):
47 | resize_im = args.input_size > 32
48 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std
49 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN
50 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD
51 |
52 | if is_train:
53 | # this should always dispatch to transforms_imagenet_train
54 | transform = create_transform(
55 | input_size=args.input_size,
56 | is_training=True,
57 | color_jitter=args.color_jitter,
58 | auto_augment=args.aa,
59 | interpolation=args.train_interpolation,
60 | re_prob=args.reprob,
61 | re_mode=args.remode,
62 | re_count=args.recount,
63 | mean=mean,
64 | std=std,
65 | )
66 | if not resize_im:
67 | transform.transforms[0] = transforms.RandomCrop(
68 | args.input_size, padding=4)
69 | return transform
70 |
71 | t = []
72 | if resize_im:
73 | # warping (no cropping) when evaluated at 384 or larger
74 | if args.input_size >= 384:
75 | t.append(
76 | transforms.Resize((args.input_size, args.input_size),
77 | interpolation=transforms.InterpolationMode.BICUBIC),
78 | )
79 | print(f"Warping {args.input_size} size input images...")
80 | else:
81 | if args.crop_pct is None:
82 | args.crop_pct = 224 / 256
83 | size = int(args.input_size / args.crop_pct)
84 | t.append(
85 | # to maintain same ratio w.r.t. 224 images
86 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
87 | )
88 | t.append(transforms.CenterCrop(args.input_size))
89 |
90 | t.append(transforms.ToTensor())
91 | t.append(transforms.Normalize(mean, std))
92 | return transforms.Compose(t)
93 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import argparse
6 | import datetime
7 | import numpy as np
8 | import time
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import json
12 | import os
13 | from pathlib import Path
14 | from timm.data.mixup import Mixup
15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16 | from timm.utils import ModelEma
17 | from config.configs import get_args_parser
18 | from models.miner import convnext_base
19 | from datasets_builder import build_dataset
20 | from processor.engine import train_one_epoch, evaluate
21 | from processor.optim_factory import create_optimizer, LayerDecayValueAssigner
22 | from util.utils import NativeScalerWithGradNormCount as NativeScaler
23 | from util import utils
24 | from util.utils import create_logger, SoftCrossEntropyLoss
25 | import torch.distributed as dist
26 | import warnings
27 | warnings.filterwarnings('ignore')
28 |
29 |
30 | def main(args):
31 | # ---------------------- prepare running --------------------------------
32 | # distributed mode
33 | utils.init_distributed_mode(args)
34 | device = torch.device(args.device)
35 | num_tasks = utils.get_world_size()
36 | global_rank = utils.get_rank()
37 | # logging file
38 | # create the logger in both training and evaluation modes: the evaluation-only
39 | # branch below also calls logger.info(), so leaving it as None would raise an error
40 | logger = create_logger(output_dir=args.output_dir, dist_rank=dist.get_rank())
41 |
42 | # fix the seed for reproducibility
43 | seed = args.seed + utils.get_rank()
44 | torch.manual_seed(seed)
45 | np.random.seed(seed)
46 | cudnn.benchmark = True
47 |
48 | # ------------------------- build dataset ------------------------------
49 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
50 | if args.disable_eval:
51 | args.dist_eval = False
52 | dataset_val = None
53 | else:
54 | dataset_val, _ = build_dataset(is_train=False, args=args)
55 |
56 |
57 | sampler_train = torch.utils.data.DistributedSampler(
58 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True, seed=args.seed,
59 | )
60 | # logger.info("Sampler_train = %s" % str(sampler_train))
61 | if args.dist_eval:
62 | if len(dataset_val) % num_tasks != 0:
63 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
64 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
65 | 'equal num of samples per-process.')
66 | sampler_val = torch.utils.data.DistributedSampler(
67 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
68 | else:
69 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
70 |
71 | if global_rank == 0 and args.log_dir is not None:
72 | os.makedirs(args.log_dir, exist_ok=True)
73 | log_writer = utils.TensorboardLogger(log_dir=args.log_dir)
74 | else:
75 | log_writer = None
76 |
77 | if global_rank == 0 and args.enable_wandb:
78 | wandb_logger = utils.WandbLogger(args)
79 | else:
80 | wandb_logger = None
81 |
82 | data_loader_train = torch.utils.data.DataLoader(
83 | dataset_train, sampler=sampler_train,
84 | batch_size=args.batch_size,
85 | num_workers=args.num_workers,
86 | pin_memory=args.pin_mem,
87 | drop_last=True,
88 | )
89 |
90 | if dataset_val is not None:
91 | data_loader_val = torch.utils.data.DataLoader(
92 | dataset_val, sampler=sampler_val,
93 | batch_size=int(1.5 * args.batch_size),
94 | num_workers=args.num_workers,
95 | pin_memory=args.pin_mem,
96 | drop_last=False
97 | )
98 | else:
99 | data_loader_val = None
100 |
101 | # ------------------------- mixup setting ------------------------------
102 | mixup_fn = None
103 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
104 | if mixup_active:
105 | logger.info("Mixup is activated!")
106 | mixup_fn = Mixup(
107 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
108 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
109 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
110 |
111 | # ---------------------- initialize the model ------------------------------
112 | if args.model == 'convnext_base':
113 | model = convnext_base(
114 | pretrained=False,
115 | M=args.attentions,
116 | num_classes=args.nb_classes,
117 | drop_path_rate=args.drop_path,
118 | layer_scale_init_value=args.layer_scale_init_value,
119 | head_init_scale=args.head_init_scale,
120 | use_mha=args.use_mha,
121 | use_ref=args.use_ref,
122 | )
123 | checkpoint = torch.load(args.finetune, map_location='cpu')
124 | checkpoint_model = None
125 | for model_key in args.model_key.split('|'):
126 | if model_key in checkpoint:
127 | checkpoint_model = checkpoint[model_key]
128 | logger.info("Load state_dict by model_key = %s" % model_key)
129 | break
130 | if checkpoint_model is None:
131 | checkpoint_model = checkpoint
132 | state_dict = model.state_dict()
133 | for k in ['head.weight', 'head.bias']:
134 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
135 | del checkpoint_model[k]
136 | utils.load_state_dict(model, checkpoint_model, prefix=args.model_prefix)
137 | else:
138 | raise ValueError("Unsupported model: %s" % args.model)
139 |
140 | model.to(device)
141 | model_ema = None
142 | if args.model_ema:
143 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
144 | model_ema = ModelEma(
145 | model,
146 | decay=args.model_ema_decay,
147 | device='cpu' if args.model_ema_force_cpu else '',
148 | resume='')
149 | logger.info("Using EMA with decay = %.8f" % args.model_ema_decay)
150 |
151 | model_without_ddp = model
152 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
153 |
154 | total_batch_size = args.batch_size * args.update_freq * utils.get_world_size()
155 | num_training_steps_per_epoch = len(dataset_train) // total_batch_size
156 |
157 | assigner = None  # stays None when layer-wise lr decay is disabled (avoids a NameError below)
158 | if args.layer_decay != 1.0 and args.model in ['convnext_small', 'convnext_base', 'convnext_large', 'convnext_xlarge']:
159 |     num_layers = 12  # convnext layers divided into 12 parts, each with a different decayed lr value.
160 |     assigner = LayerDecayValueAssigner(list(args.layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)))
161 |
162 | # model distributed
163 | if args.distributed:
164 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], broadcast_buffers=False, find_unused_parameters=True)
165 | model_without_ddp = model.module
166 |
167 | # -------------------- initialize the optimizer ------------------------------
168 | optimizer = create_optimizer(
169 | args, model_without_ddp, skip_list=None,
170 | get_num_layer=assigner.get_layer_id if assigner is not None else None,
171 | get_layer_scale=assigner.get_scale if assigner is not None else None)
172 |
173 | loss_scaler = NativeScaler() # if args.use_amp is False, this won't be used
174 | # schedule
175 | logger.info("Use Cosine LR scheduler")
176 | lr_schedule_values = utils.cosine_scheduler(
177 | args.lr, args.min_lr, args.epochs, num_training_steps_per_epoch,
178 | warmup_epochs=args.warmup_epochs, warmup_steps=args.warmup_steps,
179 | )
180 | # weight decay
181 | if args.weight_decay_end is None:
182 | args.weight_decay_end = args.weight_decay
183 | wd_schedule_values = utils.cosine_scheduler(
184 | args.weight_decay, args.weight_decay_end, args.epochs, num_training_steps_per_epoch)
185 |
186 | if mixup_fn is not None:
187 | # smoothing is handled with mixup label transform
188 | criterion = SoftTargetCrossEntropy()
189 | elif args.smoothing > 0.:
190 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
191 | else:
192 | criterion = SoftCrossEntropyLoss(gama=args.novel_loss)
193 |
194 | logger.info("Criterion: %s" % str(criterion))
195 |
196 | utils.auto_load_model(
197 | args=args, model=model, model_without_ddp=model_without_ddp,
198 | optimizer=optimizer, loss_scaler=loss_scaler, model_ema=model_ema)
199 | # for evaluation
200 | if args.eval:
201 | logger.info(f"Eval only mode")
202 | ckpt = torch.load("checkpoint/" + args.dataset + ".pth")["model"]
203 | model_dict = model.state_dict()
204 | ckpt = {k: v for k, v in ckpt.items() if k in model_dict}
205 | model.load_state_dict(ckpt)
206 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp, logger=logger, update_freq=args.update_freq)
207 | logger.info(f"Accuracy of the network on {len(dataset_val)} test images: {test_stats['acc1']:.5f}%")
208 | return
209 |
210 | # ------------------------- training stage ------------------------------
211 | max_accuracy = 0.0
212 | if args.model_ema and args.model_ema_eval:
213 | max_accuracy_ema = 0.0
214 | logger.info("Start training")
215 | start_time = time.time()
216 | for epoch in range(args.start_epoch, args.epochs):
217 | if args.distributed:
218 | data_loader_train.sampler.set_epoch(epoch)
219 | if log_writer is not None:
220 | log_writer.set_step(epoch * num_training_steps_per_epoch * args.update_freq)
221 | if wandb_logger:
222 | wandb_logger.set_steps()
223 | # training
224 | train_stats = train_one_epoch(
225 | model, criterion, data_loader_train, optimizer,
226 | device, epoch, loss_scaler, args.clip_grad, model_ema, mixup_fn,
227 | log_writer=log_writer, wandb_logger=wandb_logger, start_steps=epoch * num_training_steps_per_epoch,
228 | lr_schedule_values=lr_schedule_values, wd_schedule_values=wd_schedule_values,
229 | num_training_steps_per_epoch=num_training_steps_per_epoch, update_freq=args.update_freq,
230 | use_amp=args.use_amp, logger=logger,
231 | )
232 | # save params
233 | if args.output_dir and args.save_ckpt:
234 | if (epoch + 1) % args.save_ckpt_freq == 0 or epoch + 1 == args.epochs:
235 | utils.save_model(
236 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
237 | loss_scaler=loss_scaler, epoch=epoch, model_ema=model_ema)
238 |
239 | # evaluate
240 | if data_loader_val is not None:
241 | test_stats = evaluate(data_loader_val, model, device, use_amp=args.use_amp, logger=logger, update_freq=args.update_freq)
242 | logger.info(f"test accuracy : {test_stats['acc1']:.1f}%")
243 | if max_accuracy < test_stats["acc1"]:
244 | max_accuracy = test_stats["acc1"]
245 | if wandb_logger is not None:
246 | wandb.run.summary["Best Accuracy"] = max_accuracy
247 | wandb.run.summary["Best Epoch"] = epoch
248 | if args.output_dir and args.save_ckpt:
249 | utils.save_model(
250 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
251 | loss_scaler=loss_scaler, epoch="best", model_ema=model_ema)
252 | acc1 = test_stats["acc1"]
253 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
254 | logger.info(f'Max accuracy: {max_accuracy:.2f}%')
255 |
256 | if log_writer is not None:
257 | log_writer.update(test_acc1=test_stats['acc1'], head="perf", step=epoch)
258 | log_writer.update(test_acc5=test_stats['acc5'], head="perf", step=epoch)
259 | log_writer.update(test_loss=test_stats['loss'], head="perf", step=epoch)
260 |
261 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
262 | **{f'test_{k}': v for k, v in test_stats.items()},
263 | 'epoch': epoch,
264 | 'n_parameters': n_parameters}
265 |
266 | # repeat testing routines for EMA, if ema eval is turned on
267 | if args.model_ema and args.model_ema_eval:
268 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, use_amp=args.use_amp, logger=logger, update_freq=args.update_freq)
269 | # logger.info(f"Accuracy of the model EMA on {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%")
270 | if max_accuracy_ema < test_stats_ema["acc1"]:
271 | max_accuracy_ema = test_stats_ema["acc1"]
272 | if args.output_dir and args.save_ckpt:
273 | utils.save_model(
274 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
275 | loss_scaler=loss_scaler, epoch="best-ema", model_ema=model_ema)
276 | logger.info(f'Max EMA accuracy: {max_accuracy_ema:.2f}%')
277 | if log_writer is not None:
278 | log_writer.update(test_acc1_ema=test_stats_ema['acc1'], head="perf", step=epoch)
279 | log_stats.update({**{f'test_{k}_ema': v for k, v in test_stats_ema.items()}})
280 | else:
281 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
282 | 'epoch': epoch,
283 | 'n_parameters': n_parameters}
284 | # update logger info
285 | if args.output_dir and utils.is_main_process():
286 | if log_writer is not None:
287 | log_writer.flush()
288 | with open(os.path.join(args.output_dir, "record.txt"), mode="a", encoding="utf-8") as f:
289 | f.write(json.dumps(log_stats) + "\n")
290 |
291 | if wandb_logger:
292 | wandb_logger.log_epoch_metrics(log_stats)
293 |
294 | if wandb_logger and args.wandb_ckpt and args.save_ckpt and args.output_dir:
295 | wandb_logger.log_checkpoints()
296 |
297 | # ------------------------- finished ------------------------------
298 | total_time = time.time() - start_time
299 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
300 | logger.info('Training time {}'.format(total_time_str))
301 |
302 |
303 | if __name__ == '__main__':
304 |
305 | parser = argparse.ArgumentParser('ConvNeXt training and evaluation script', parents=[get_args_parser()])
306 | args = parser.parse_args()
307 |
308 | args.output_dir = os.path.join(args.output_dir, '%s_%s' % (args.dataset, args.tag))
309 |
310 | if args.output_dir:
311 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
312 |
313 | with open(os.path.join(args.output_dir, 'config.txt'), 'w') as f:
314 | argsDict = args.__dict__
315 | f.writelines('------------------ start ------------------' + '\n')
316 | for eachArg, value in argsDict.items():
317 | f.writelines(eachArg + ' : ' + str(value) + '\n')
318 | f.writelines('------------------- end -------------------')
319 |
320 | main(args)
321 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/models/__init__.py
--------------------------------------------------------------------------------
/models/cores.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | class BasicConv2d(nn.Module):
11 | def __init__(self, in_channels, out_channels, **kwargs):
12 | super(BasicConv2d, self).__init__()
13 | self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
14 | self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
15 |
16 | def forward(self, x):
17 | x = self.conv(x)
18 | x = self.bn(x)
19 | x = F.relu(x, inplace=True)
20 | return x
21 |
22 |
23 | # Attention network
24 | class Attentions(nn.Module):
25 | def __init__(self, channel_size=256):
26 | super(Attentions, self).__init__()
27 | self.A1_c = ChannelGate(channel_size)
28 | self.A2_c = ChannelGate(channel_size)
29 | self.A3_c = ChannelGate(channel_size)
30 |
31 | def forward(self, inputs):
32 | F1, F2, F3 = inputs
33 | # Global Average Pooling to a vector
34 | A1_channel = self.A1_c(F1)
35 | A2_channel = self.A2_c(F2)
36 | A3_channel = self.A3_c(F3)
37 |
38 | # bottom to top
39 | A2_channel = (A2_channel + A1_channel) / 2
40 | A3_channel = (A3_channel + A2_channel) / 2
41 |
42 | # channel pooling
43 | # A1 = F1 * A1_channel
44 | # A2 = F2 * A2_channel
45 | A3 = F3 * A3_channel
46 |
47 | return A3
48 |
49 |
50 | # channel attention
51 | class ChannelGate(nn.Module):
52 | def __init__(self, out_channels):
53 | super(ChannelGate, self).__init__()
54 | self.conv1 = nn.Conv2d(out_channels, out_channels//4, kernel_size=1, stride=1, padding=0)
55 | self.conv2 = nn.Conv2d(out_channels//4, out_channels, kernel_size=1, stride=1, padding=0)
56 | self.relu = nn.ReLU(inplace=True)
57 | self.sigmoid = nn.Sigmoid()
58 |
59 | def forward(self, x):
60 | x = nn.AdaptiveAvgPool2d(output_size=1)(x)
61 | x = self.conv1(x)
62 | x = self.relu(x)
63 | x = self.conv2(x)
64 | x = self.sigmoid(x)
65 | # x = F.relu(self.conv1(x), inplace=True)
66 | # x = torch.sigmoid(self.conv2(x))
67 | return x
68 |
69 |
70 | class Flatten(nn.Module):
71 | def __init__(self):
72 | super(Flatten, self).__init__()
73 |
74 | def forward(self, x):
75 | return x.view(x.size(0), -1)
76 |
77 |
78 | class BasicConv(nn.Module):
79 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
80 | bn=True, bias=False):
81 | super(BasicConv, self).__init__()
82 | self.out_channels = out_planes
83 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
84 | dilation=dilation, groups=groups, bias=bias)
85 | # self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
86 | self.relu = nn.ReLU(inplace=True) if relu else None
87 |
88 | def forward(self, x):
89 | x = self.conv(x)
90 | # if self.bn is not None:
91 | # x = self.bn(x)
92 | if self.relu is not None:
93 | x = self.relu(x)
94 | return x
95 |
96 |
97 | # construct the top of pyramid layer
98 | class SimpleFPA(nn.Module):
99 | def __init__(self, in_planes, out_planes):
100 | super(SimpleFPA, self).__init__()
101 |
102 | self.channels_cond = in_planes
103 | # Master branch
104 | self.conv_master = BasicConv(in_planes, out_planes, kernel_size=1, stride=1)
105 |
106 | # Global pooling branch
107 | self.conv_gpb = BasicConv(in_planes, out_planes, kernel_size=1, stride=1)
108 |
109 | def forward(self, x):
110 | # Master branch
111 | x_master = self.conv_master(x)
112 |
113 | # Global pooling branch
114 | x_gpb = nn.AvgPool2d(x.shape[2:])(x).view(x.shape[0], self.channels_cond, 1, 1)
115 | x_gpb = self.conv_gpb(x_gpb)
116 |
117 | out = x_master + x_gpb
118 |
119 | return out
120 |
121 |
122 | # Feature Pyramid Network
123 | class PyramidFeatures(nn.Module):
124 | def __init__(self, B1_size, B2_size, B3_size, feature_size=256):
125 | super(PyramidFeatures, self).__init__()
126 |
127 | # feature_size represents the number of the channels
128 | self.P3_1 = SimpleFPA(B3_size, feature_size)
129 | self.P3_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
130 |
131 | self.P2_1 = nn.Conv2d(B2_size, feature_size, kernel_size=1, stride=1, padding=0)
132 | self.P2_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
133 |
134 | self.P1_1 = nn.Conv2d(B1_size, feature_size, kernel_size=1, stride=1, padding=0)
135 | self.P1_2 = nn.Conv2d(feature_size, feature_size, kernel_size=3, stride=1, padding=1)
136 |
137 | def forward(self, inputs):
138 | B1, B2, B3 = inputs
139 |
140 | P3_x = self.P3_1(B3)
141 | P3_upsampled_x = F.interpolate(P3_x, scale_factor=2)
142 | P3_x = self.P3_2(P3_x)
143 |
144 | P2_x = self.P2_1(B2)
145 | P2_x = P3_upsampled_x + P2_x
146 | P2_upsampled_x = F.interpolate(P2_x, scale_factor=2)
147 | P2_x = self.P2_2(P2_x)
148 |
149 | P1_x = self.P1_1(B1)
150 | P1_x = P1_x + P2_upsampled_x
151 | P1_x = self.P1_2(P1_x)
152 |
153 | return [P1_x, P2_x, P3_x]
154 |
155 |
156 | class MHAM(nn.Module):
157 | def __init__(self, fpn_sizes, M, num_features, use_mha=True):
158 | super(MHAM, self).__init__()
159 | self.use_mha = use_mha
160 | if self.use_mha:
161 | self.fpn = PyramidFeatures(fpn_sizes[1], fpn_sizes[2], fpn_sizes[3], feature_size=M)
162 | # channel attention
163 | self.ca = Attentions(channel_size=M)
164 | self.conv = BasicConv(in_planes=M, out_planes=M, kernel_size=1)
165 | else:
166 | self.fpa = SimpleFPA(num_features, M)
167 |
168 | def forward(self, input1, input2, input3):
169 | if not self.use_mha:
170 | return self.fpa(input3)
171 | x = self.fpn([input1, input2, input3])
172 | x = self.ca(x)
173 | x = self.conv(x)
174 | return x
175 |
176 |
177 | # bilinear polymerization pooling
178 | class BPP(nn.Module):
179 | def __init__(self, epsilon):
180 | super(BPP, self).__init__()
181 | self.epsilon = epsilon
182 |
183 | def forward(self, features1, features2):
184 | # unify the size of width and height
185 | B, C, H, W = features1.size()
186 | _, M, AH, AW = features2.size()
187 |
188 | # match size
189 | if AH != H or AW != W:
190 | features2 = F.upsample_bilinear(features2, size=(H, W))
191 |
192 | # essential_matrix: (B, M, C) -> (B, M * C)
193 | essential_matrix = (torch.einsum('imjk,injk->imn', (features2, features1)) / float(H * W)).view(B, -1)
194 | # normalize: signed square root followed by L2 normalization
195 | essential_matrix = torch.sign(essential_matrix) * torch.sqrt(torch.abs(essential_matrix) + self.epsilon)
196 | essential_matrix = F.normalize(essential_matrix, dim=-1)
197 |
198 | return essential_matrix
199 |
--------------------------------------------------------------------------------
/models/miner.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from timm.models.layers import trunc_normal_, DropPath
10 | from timm.models.registry import register_model
11 | from models.cores import MHAM, BPP
12 |
13 |
14 | class ConvNeXt(nn.Module):
15 | def __init__(self, in_chans=3, num_classes=1000,
16 | depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], M=16, drop_path_rate=0.2,
17 | layer_scale_init_value=1e-6, head_init_scale=0.001, use_mha=True, use_ref=True
18 | ):
19 | super().__init__()
20 |
21 | self.epsilon = 1e-12
22 | self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
23 | stem = nn.Sequential(
24 | nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
25 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
26 | )
27 | self.downsample_layers.append(stem)
28 | for i in range(3):
29 | downsample_layer = nn.Sequential(
30 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
31 | nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
32 | )
33 | self.downsample_layers.append(downsample_layer)
34 |
35 | self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
36 | dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
37 | cur = 0
38 | for i in range(4):
39 | stage = nn.Sequential(
40 | *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
41 | layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
42 | )
43 | self.stages.append(stage)
44 | cur += depths[i]
45 | self.num_classes = num_classes
46 | self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
47 |
48 | self.use_mha = use_mha # whether to use the MHAM attention module
49 | self.use_ref = use_ref # whether to use the refiner branch
50 |
51 | self.hidden_num = 256
52 | self.num_features = dims[-1] # output feature size
53 | self.M = M # number of attention channels (kappa in the paper)
54 | fpn_sizes = [self.stages[0][-1].pwconv2.out_features,
55 | self.stages[1][-1].pwconv2.out_features,
56 | self.stages[2][-1].pwconv2.out_features,
57 | self.stages[3][-1].pwconv2.out_features]
58 | # feature pyramid network
59 | self.mha = MHAM(fpn_sizes, M, self.num_features, use_mha)
60 | # bilinear polymerization pooling
61 | self.bpp = BPP(self.epsilon)
62 | # classification network
63 | self.head = nn.Linear(self.num_features * self.M, self.num_classes)
64 |
65 | def _init_weights(self, m):
66 | if isinstance(m, (nn.Conv2d, nn.Linear)):
67 | trunc_normal_(m.weight, std=.02)
68 | nn.init.constant_(m.bias, 0)
69 | def refiner_input(self, batch_size, attention_maps):
70 | attentions = []
71 | if self.training:
72 | # Randomly choose one of attention maps Ak
73 | for i in range(batch_size):
74 | attention_weight = torch.sqrt(attention_maps[i].sum(dim=(1, 2)).detach() + self.epsilon)
75 | attention_weight = F.normalize(attention_weight, p=1, dim=0)
76 | k_index = np.random.choice(self.M, 2, p=attention_weight.cpu().numpy())
77 | attentions.append(attention_maps[i, k_index, ...])
78 | attentions = torch.stack(attentions) # (B, 2, H, W) - one for detail image, the other for mask image
79 | else:
80 | # Object Localization Am = mean(Ak)
81 | attentions = torch.mean(attention_maps, dim=1, keepdim=True) # (B, 1, H, W)
82 |
83 | return attentions
84 |
85 |
86 | def forward(self, x):
87 | torch.autograd.set_detect_anomaly(True) # enable autograd anomaly detection (debugging aid; slows training)
88 |
89 | # backbone (convnext)
90 | x1 = self.downsample_layers[0](x)
91 | x1 = self.stages[0](x1)
92 | # stage 2 output: (B, 256, H/8, W/8) with the base channel widths
93 | x2 = self.downsample_layers[1](x1)
94 | x2 = self.stages[1](x2)
95 | # stage 3 output: (B, 512, H/16, W/16)
96 | x3 = self.downsample_layers[2](x2)
97 | x3 = self.stages[2](x3)
98 | # stage 4 output: (B, 1024, H/32, W/32)
99 | x4 = self.downsample_layers[3](x3)
100 | intrinsics = self.stages[3](x4)
101 |
102 | # multiple hierarchies attention module
103 | attention_maps = self.mha(x2, x3, intrinsics)
104 | essential = self.bpp(intrinsics, attention_maps)
105 | pred = self.head(essential)
106 | if self.use_ref:
107 | return pred, self.refiner_input(x.size(0), attention_maps)
108 | else:
109 | return pred, []
110 |
111 |
112 | class LayerNorm(nn.Module):
113 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
114 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
115 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
116 | with shape (batch_size, channels, height, width).
117 | """
118 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
119 | super().__init__()
120 | self.weight = nn.Parameter(torch.ones(normalized_shape))
121 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
122 | self.eps = eps
123 | self.data_format = data_format
124 | if self.data_format not in ["channels_last", "channels_first"]:
125 | raise NotImplementedError
126 | self.normalized_shape = (normalized_shape, )
127 |
128 | def forward(self, x):
129 | if self.data_format == "channels_last":
130 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
131 | elif self.data_format == "channels_first":
132 | u = x.mean(1, keepdim=True)
133 | s = (x - u).pow(2).mean(1, keepdim=True)
134 | x = (x - u) / torch.sqrt(s + self.eps)
135 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
136 | return x
137 |
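# --- Sanity-check sketch (assumed; not part of the original file and not called
# anywhere): the channels_first branch above matches nn.LayerNorm semantics,
# i.e. F.layer_norm applied over the channel dimension after a permute.
def _layernorm_channels_first_demo():
    x = torch.randn(2, 16, 7, 7)
    ln = LayerNorm(16, data_format="channels_first")
    ref = F.layer_norm(x.permute(0, 2, 3, 1), (16,), ln.weight, ln.bias, ln.eps)
    assert torch.allclose(ln(x), ref.permute(0, 3, 1, 2), atol=1e-5)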
138 |
139 | model_urls = {
140 | "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
141 | "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
142 | "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
143 | "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
144 | "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
145 | "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
146 | "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
147 | }
148 |
149 |
150 |
151 | @register_model
152 | def miner(pretrained=False, M=16, in_22k=False, **kwargs):
153 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], M=M, **kwargs)
154 | if pretrained:
155 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
156 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
157 | model.load_state_dict(checkpoint["model"])
158 | return model
159 |
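# --- Usage sketch (assumed; not part of the original file and not called
# anywhere): instantiate the registered model directly and inspect its two
# outputs. num_classes=30 is an illustrative value, not a project default.
def _miner_demo():
    model = miner(pretrained=False, num_classes=30, M=16)
    model.eval()
    with torch.no_grad():
        logits, attention = model(torch.randn(1, 3, 224, 224))
    # logits: (1, 30); attention: (1, 1, h, w) mean attention map returned in eval mode
    return logits, attention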
160 |
161 | @register_model
162 | def convnext_tiny(pretrained=False, **kwargs):
163 | model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
164 | if pretrained:
165 | url = model_urls['convnext_tiny_1k']
166 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
167 | model.load_state_dict(checkpoint["model"])
168 | return model
169 |
170 |
171 | @register_model
172 | def convnext_small(pretrained=False, **kwargs):
173 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
174 | if pretrained:
175 | url = model_urls['convnext_small_1k']
176 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
177 | model.load_state_dict(checkpoint["model"])
178 | return model
179 |
180 |
181 | @register_model
182 | def convnext_base(pretrained=False, M=16, in_22k=False, **kwargs):
183 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], M=M, **kwargs)
184 | if pretrained:
185 | url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k']
186 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
187 | model.load_state_dict(checkpoint["model"])
188 | return model
189 |
190 |
191 | @register_model
192 | def convnext_large(pretrained=False, in_22k=False, **kwargs):
193 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
194 | if pretrained:
195 | url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k']
196 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
197 | model.load_state_dict(checkpoint["model"])
198 | return model
199 |
200 |
201 | @register_model
202 | def convnext_xlarge(pretrained=False, in_22k=False, **kwargs):
203 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
204 | if pretrained:
205 | assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True"
206 | url = model_urls['convnext_xlarge_22k']
207 | checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
208 | model.load_state_dict(checkpoint["model"])
209 | return model
210 |
211 |
212 | class Block(nn.Module):
213 | r""" ConvNeXt Block. There are two equivalent implementations:
214 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
215 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
216 | We use (2) as we find it slightly faster in PyTorch
217 |
218 | Args:
219 | dim (int): Number of input channels.
220 | drop_path (float): Stochastic depth rate. Default: 0.0
221 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
222 | """
223 |
224 | def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
225 | super().__init__()
226 | self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
227 | self.norm = LayerNorm(dim, eps=1e-6)
228 | self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
229 | self.act = nn.GELU()
230 | self.pwconv2 = nn.Linear(4 * dim, dim)
231 | self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
232 | requires_grad=True) if layer_scale_init_value > 0 else None
233 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
234 |
235 | def forward(self, x):
236 | input = x
237 | x = self.dwconv(x)
238 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
239 | x = self.norm(x)
240 | x = self.pwconv1(x)
241 | x = self.act(x)
242 | x = self.pwconv2(x)
243 | if self.gamma is not None:
244 | x = self.gamma * x
245 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
246 |
247 | x = input + self.drop_path(x)
248 | return x
249 |
250 |
251 |
--------------------------------------------------------------------------------
/processor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/processor/__init__.py
--------------------------------------------------------------------------------
/processor/engine.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import math
6 | import time
7 | import datetime
8 | from typing import Iterable, Optional
9 | import numpy as np
10 | import torch
11 | from timm.data import Mixup
12 | from timm.utils import accuracy, ModelEma, AverageMeter
13 | from util.utils import reduce_tensor, MetricLogger, SmoothedValue, get_detail_images, get_mask_images
14 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
15 |
16 |
17 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
18 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
19 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
20 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
21 | wandb_logger=None, start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
22 | num_training_steps_per_epoch=None, update_freq=None, use_amp=False, logger=None,
23 | ):
24 | model.train(True)
25 | metric_logger = MetricLogger(delimiter=" ")
26 | metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
27 | metric_logger.add_meter('min_lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
28 | header = 'Epoch: [{}]'.format(epoch)
29 | print_freq = 10
30 |
31 | optimizer.zero_grad()
32 |
33 | # counter
34 | num_steps = len(data_loader)
35 | batch_time = AverageMeter()
36 | loss_meter = AverageMeter()
37 |
38 | start = time.time()
39 | end = time.time()
40 |
41 | for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
42 | step = data_iter_step // update_freq
43 | if step >= num_training_steps_per_epoch:
44 | continue
45 | it = start_steps + step # global training iteration
46 | # Update LR & WD at the first micro-step of each gradient-accumulation cycle
47 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
48 | for i, param_group in enumerate(optimizer.param_groups):
49 | if lr_schedule_values is not None:
50 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
51 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
52 | param_group["weight_decay"] = wd_schedule_values[it]
53 |
54 | samples = samples.to(device, non_blocking=True)
55 | targets = targets.to(device, non_blocking=True)
56 |
57 | if mixup_fn is not None:
58 | samples, targets = mixup_fn(samples, targets)
59 |
60 | if use_amp:
61 | with torch.cuda.amp.autocast():
62 | input_pred, attention_map = model(samples)
63 | else: # full precision
64 | input_pred, attention_map = model(samples)
65 |
66 | if isinstance(attention_map, torch.Tensor):  # refiner branch returned attention maps
67 | with torch.no_grad():
68 | detail_images = get_detail_images(samples, attention_map[:, :1, :, :], theta_detail=(0.4, 0.6), padding=0.1)
69 |
70 | # detail-images forward
71 | detail_pred, _ = model(detail_images)
72 |
73 | with torch.no_grad():
74 | mask_images = get_mask_images(samples, attention_map[:, 1:, :, :], theta_mask=(0.2, 0.5))
75 |
76 | mask_pred, _ = model(mask_images)
77 |
78 | output = (input_pred + detail_pred + mask_pred)/3.
79 | loss = criterion(input_pred, targets)/3. + \
80 | criterion(detail_pred, targets)/3. + \
81 | criterion(mask_pred, targets)/3.
82 | else:
83 | loss = criterion(input_pred, targets)
84 | output = input_pred
85 |
86 | loss_value = loss.item()
87 |
88 | if not math.isfinite(loss_value): # this could trigger if using AMP
89 | logger.info("Loss is {}, stopping training".format(loss_value))
90 | assert math.isfinite(loss_value)
91 |
92 | if use_amp:
93 | # this attribute is added by timm on one optimizer (adahessian)
94 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
95 | loss /= update_freq
96 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
97 | parameters=model.parameters(), create_graph=is_second_order,
98 | update_grad=(data_iter_step + 1) % update_freq == 0)
99 | if (data_iter_step + 1) % update_freq == 0:
100 | optimizer.zero_grad()
101 | if model_ema is not None:
102 | model_ema.update(model)
103 | else: # full precision
104 | loss /= update_freq
105 | loss.backward()
106 | if (data_iter_step + 1) % update_freq == 0:
107 | optimizer.step()
108 | optimizer.zero_grad()
109 | if model_ema is not None:
110 | model_ema.update(model)
111 |
112 | torch.cuda.synchronize()
113 |
114 | if mixup_fn is None:
115 | class_acc = (output.max(-1)[-1] == targets).float().mean()
116 | else:
117 | class_acc = None
118 | metric_logger.update(loss=loss_value)
119 | metric_logger.update(class_acc=class_acc)
120 | min_lr = 10.
121 | max_lr = 0.
122 | for group in optimizer.param_groups:
123 | min_lr = min(min_lr, group["lr"])
124 | max_lr = max(max_lr, group["lr"])
125 |
126 | metric_logger.update(lr=max_lr)
127 | metric_logger.update(min_lr=min_lr)
128 | weight_decay_value = None
129 | for group in optimizer.param_groups:
130 | if group["weight_decay"] > 0:
131 | weight_decay_value = group["weight_decay"]
132 | metric_logger.update(weight_decay=weight_decay_value)
133 | if use_amp:
134 | metric_logger.update(grad_norm=grad_norm)
135 |
136 | if log_writer is not None:
137 | log_writer.update(loss=loss_value, head="loss")
138 | log_writer.update(class_acc=class_acc, head="loss")
139 | log_writer.update(lr=max_lr, head="opt")
140 | log_writer.update(min_lr=min_lr, head="opt")
141 | log_writer.update(weight_decay=weight_decay_value, head="opt")
142 | if use_amp:
143 | log_writer.update(grad_norm=grad_norm, head="opt")
144 | log_writer.set_step()
145 |
146 | if wandb_logger:
147 | wandb_logger._wandb.log({
148 | 'Rank-0 Batch Wise/train_loss': loss_value,
149 | 'Rank-0 Batch Wise/train_max_lr': max_lr,
150 | 'Rank-0 Batch Wise/train_min_lr': min_lr
151 | }, commit=False)
152 | if class_acc:
153 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_class_acc': class_acc}, commit=False)
154 | if use_amp:
155 | wandb_logger._wandb.log({'Rank-0 Batch Wise/train_grad_norm': grad_norm}, commit=False)
156 | wandb_logger._wandb.log({'Rank-0 Batch Wise/global_train_step': it})
157 |
158 | loss_meter.update(loss.item(), targets.size(0))
159 | batch_time.update(time.time() - end)
160 | end = time.time()
161 |
162 | if data_iter_step % update_freq == 0:
163 | lr = max_lr
164 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
165 | etas = batch_time.avg * (num_steps - data_iter_step)
166 | logger.info(
167 | f'Train: [{epoch}][{data_iter_step}/{num_steps}]\t'
168 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
169 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
170 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
171 | f'mem {memory_used:.0f}MB')
172 |
173 | # gather the stats from all processes
174 | metric_logger.synchronize_between_processes()
175 | train_stat = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
176 |
177 | epoch_time = time.time() - start
178 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
179 |
180 | return train_stat
181 |
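# --- Wiring sketch (assumed; not part of the original file and not called
# anywhere): the arguments train_one_epoch expects from the training script.
# All names are placeholders; the optimizer should come from
# processor.optim_factory.create_optimizer so that each param group carries
# an "lr_scale" entry.
def _train_loop_demo(model, criterion, train_loader, optimizer, device, epochs,
                     lr_schedule_values, wd_schedule_values, loss_scaler, logger,
                     update_freq=1):
    steps_per_epoch = len(train_loader) // update_freq
    for epoch in range(epochs):
        train_one_epoch(model, criterion, train_loader, optimizer, device, epoch,
                        loss_scaler, max_norm=5.0,
                        start_steps=epoch * steps_per_epoch,
                        lr_schedule_values=lr_schedule_values,
                        wd_schedule_values=wd_schedule_values,
                        num_training_steps_per_epoch=steps_per_epoch,
                        update_freq=update_freq, use_amp=False, logger=logger)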
182 | @torch.no_grad()
183 | def evaluate(data_loader, model, device, use_amp=False, logger=None, update_freq=1):
184 | criterion = torch.nn.CrossEntropyLoss()
185 |
186 | metric_logger = MetricLogger(delimiter=" ")
187 | header = 'Test:'
188 |
189 | # switch to evaluation mode
190 | model.eval()
191 |
192 | batch_time = AverageMeter()
193 | loss_meter = AverageMeter()
194 | acc1_meter = AverageMeter()
195 | acc5_meter = AverageMeter()
196 |
197 | # statistics
198 | record_truth = np.array([])
199 | record_pred = np.array([])
200 | record_feature = torch.tensor([])
201 |
202 | end = time.time()
203 | idx = 0
204 | with torch.no_grad():
205 | for batch in metric_logger.log_every(data_loader, 10, header):
206 | images = batch[0]
207 | target = batch[-1]
208 | label_true = target.numpy().squeeze()
209 | record_truth = np.concatenate((record_truth, label_true))
210 |
211 | images = images.to(device, non_blocking=True)
212 | target = target.to(device, non_blocking=True)
213 |
214 | # compute output
215 | if use_amp:
216 | with torch.cuda.amp.autocast():
217 | input_pred, attention_map = model(images)
218 | else:
219 | input_pred, attention_map = model(images)
220 |
221 | if isinstance(attention_map, torch.Tensor):
222 | detail_images = get_detail_images(images, attention_map, theta_detail=0.1, padding=0.05)
223 | detail_pred, _ = model(detail_images)
224 | output = (input_pred + detail_pred)/2.
225 | loss = criterion(output, target)
226 | else:
227 | loss = criterion(input_pred, target)
228 | output = input_pred
229 |
230 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
231 |
232 | # record
233 | _, pred = torch.max(output, dim=1)
234 | pred = pred.cpu().numpy().squeeze()
235 | record_pred = np.concatenate((record_pred, pred))
236 |
237 | acc1 = reduce_tensor(acc1)
238 | acc5 = reduce_tensor(acc5)
239 | loss = reduce_tensor(loss)
240 |
241 | loss_meter.update(loss.item(), target.size(0))
242 | acc1_meter.update(acc1.item(), target.size(0))
243 | acc5_meter.update(acc5.item(), target.size(0))
244 |
245 | # measure elapsed time
246 | batch_time.update(time.time() - end)
247 | end = time.time()
248 |
249 | if idx % update_freq == 0:
250 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
251 | if logger:
252 | logger.info(
253 | f'Test: [{idx}/{len(data_loader)}]\t'
254 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
255 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
256 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
257 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
258 | f'Mem {memory_used:.0f}MB')
259 | idx += 1
260 |
261 | batch_size = images.shape[0]
262 | metric_logger.update(loss=loss.item())
263 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
264 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
265 |
266 | # gather the stats from all processes
267 | metric_logger.synchronize_between_processes()
268 | # metric
269 | acc = accuracy_score(record_truth, record_pred)
270 | precision = precision_score(record_truth, record_pred, average='weighted')
271 | recall = recall_score(record_truth, record_pred, average='weighted')
272 | f1 = f1_score(record_truth, record_pred, average='weighted')
273 | logger.info(f'[Info] acc: {acc}, precision: {precision}, recall: {recall}, f1: {f1}')
274 |
275 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
276 |
--------------------------------------------------------------------------------
/processor/optim_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import torch
6 | from torch import optim as optim
7 |
8 | from timm.optim.adafactor import Adafactor
9 | from timm.optim.adahessian import Adahessian
10 | from timm.optim.adamp import AdamP
11 | from timm.optim.lookahead import Lookahead
12 | from timm.optim.nadam import Nadam
13 | from timm.optim.novograd import NovoGrad
14 | from timm.optim.nvnovograd import NvNovoGrad
15 | from timm.optim.radam import RAdam
16 | from timm.optim.rmsprop_tf import RMSpropTF
17 | from timm.optim.sgdp import SGDP
18 |
19 | import json
20 |
21 | try:
22 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
23 | has_apex = True
24 | except ImportError:
25 | has_apex = False
26 |
27 |
28 | def get_num_layer_for_convnext(var_name):
29 | """
30 | Divide [3, 3, 27, 3] layers into 12 groups; each group is three
31 | consecutive blocks, including possible neighboring downsample layers;
32 | adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
33 | """
34 | num_max_layer = 12
35 | if var_name.startswith("downsample_layers"):
36 | stage_id = int(var_name.split('.')[1])
37 | if stage_id == 0:
38 | layer_id = 0
39 | elif stage_id == 1 or stage_id == 2:
40 | layer_id = stage_id + 1
41 | elif stage_id == 3:
42 | layer_id = 12
43 | return layer_id
44 |
45 | elif var_name.startswith("stages"):
46 | stage_id = int(var_name.split('.')[1])
47 | block_id = int(var_name.split('.')[2])
48 | if stage_id == 0 or stage_id == 1:
49 | layer_id = stage_id + 1
50 | elif stage_id == 2:
51 | layer_id = 3 + block_id // 3
52 | elif stage_id == 3:
53 | layer_id = 12
54 | return layer_id
55 | else:
56 | return num_max_layer + 1
57 |
58 | class LayerDecayValueAssigner(object):
59 | def __init__(self, values):
60 | self.values = values
61 |
62 | def get_scale(self, layer_id):
63 | return self.values[layer_id]
64 |
65 | def get_layer_id(self, var_name):
66 | return get_num_layer_for_convnext(var_name)
67 |
68 |
69 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
70 | parameter_group_names = {}
71 | parameter_group_vars = {}
72 |
73 | for name, param in model.named_parameters():
74 | if not param.requires_grad:
75 | continue # frozen weights
76 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
77 | group_name = "no_decay"
78 | this_weight_decay = 0.
79 | else:
80 | group_name = "decay"
81 | this_weight_decay = weight_decay
82 | if get_num_layer is not None:
83 | layer_id = get_num_layer(name)
84 | group_name = "layer_%d_%s" % (layer_id, group_name)
85 | else:
86 | layer_id = None
87 |
88 | if group_name not in parameter_group_names:
89 | if get_layer_scale is not None:
90 | scale = get_layer_scale(layer_id)
91 | else:
92 | scale = 1.
93 |
94 | parameter_group_names[group_name] = {
95 | "weight_decay": this_weight_decay,
96 | "params": [],
97 | "lr_scale": scale
98 | }
99 | parameter_group_vars[group_name] = {
100 | "weight_decay": this_weight_decay,
101 | "params": [],
102 | "lr_scale": scale
103 | }
104 |
105 | parameter_group_vars[group_name]["params"].append(param)
106 | parameter_group_names[group_name]["params"].append(name)
107 | # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
108 | return list(parameter_group_vars.values())
109 |
110 |
111 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
112 | opt_lower = args.opt.lower()
113 | weight_decay = args.weight_decay
114 | # if weight_decay and filter_bias_and_bn:
115 | if filter_bias_and_bn:
116 | skip = {}
117 | if skip_list is not None:
118 | skip = skip_list
119 | elif hasattr(model, 'no_weight_decay'):
120 | skip = model.no_weight_decay()
121 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
122 | weight_decay = 0.
123 | else:
124 | parameters = model.parameters()
125 |
126 | if 'fused' in opt_lower:
127 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
128 |
129 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
130 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
131 | opt_args['eps'] = args.opt_eps
132 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
133 | opt_args['betas'] = args.opt_betas
134 |
135 | opt_split = opt_lower.split('_')
136 | opt_lower = opt_split[-1]
137 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
138 | opt_args.pop('eps', None)
139 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
140 | elif opt_lower == 'momentum':
141 | opt_args.pop('eps', None)
142 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
143 | elif opt_lower == 'adam':
144 | optimizer = optim.Adam(parameters, **opt_args)
145 | elif opt_lower == 'adamw':
146 | optimizer = optim.AdamW(parameters, **opt_args)
147 | elif opt_lower == 'nadam':
148 | optimizer = Nadam(parameters, **opt_args)
149 | elif opt_lower == 'radam':
150 | optimizer = RAdam(parameters, **opt_args)
151 | elif opt_lower == 'adamp':
152 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
153 | elif opt_lower == 'sgdp':
154 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
155 | elif opt_lower == 'adadelta':
156 | optimizer = optim.Adadelta(parameters, **opt_args)
157 | elif opt_lower == 'adafactor':
158 | if not args.lr:
159 | opt_args['lr'] = None
160 | optimizer = Adafactor(parameters, **opt_args)
161 | elif opt_lower == 'adahessian':
162 | optimizer = Adahessian(parameters, **opt_args)
163 | elif opt_lower == 'rmsprop':
164 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
165 | elif opt_lower == 'rmsproptf':
166 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
167 | elif opt_lower == 'novograd':
168 | optimizer = NovoGrad(parameters, **opt_args)
169 | elif opt_lower == 'nvnovograd':
170 | optimizer = NvNovoGrad(parameters, **opt_args)
171 | elif opt_lower == 'fusedsgd':
172 | opt_args.pop('eps', None)
173 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
174 | elif opt_lower == 'fusedmomentum':
175 | opt_args.pop('eps', None)
176 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
177 | elif opt_lower == 'fusedadam':
178 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
179 | elif opt_lower == 'fusedadamw':
180 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
181 | elif opt_lower == 'fusedlamb':
182 | optimizer = FusedLAMB(parameters, **opt_args)
183 | elif opt_lower == 'fusednovograd':
184 | opt_args.setdefault('betas', (0.95, 0.98))
185 | optimizer = FusedNovoGrad(parameters, **opt_args)
186 | else:
187 | assert False, "Invalid optimizer"
188 |
189 | if len(opt_split) > 1:
190 | if opt_split[0] == 'lookahead':
191 | optimizer = Lookahead(optimizer)
192 |
193 | return optimizer
194 |
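# --- Wiring sketch (assumed; not part of the original file and not called
# anywhere): combining the layer-wise lr-decay helpers above with
# create_optimizer. The decay schedule and hyper-parameters are illustrative
# choices, and `args` only mimics the fields create_optimizer reads.
def _layer_decay_optimizer_demo(model, lr=4e-3, weight_decay=0.05, layer_decay=0.9):
    from types import SimpleNamespace
    num_layers = 12
    scales = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
    assigner = LayerDecayValueAssigner(scales)
    args = SimpleNamespace(opt='adamw', lr=lr, weight_decay=weight_decay,
                           momentum=0.9, opt_eps=1e-8, opt_betas=None)
    return create_optimizer(args, model,
                            get_num_layer=assigner.get_layer_id,
                            get_layer_scale=assigner.get_scale)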
--------------------------------------------------------------------------------
/processor/run_with_submitit.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import argparse
6 | import os
7 | import uuid
8 | from pathlib import Path
9 |
10 | import main as classification
11 | import submitit
12 |
13 | def parse_args():
14 | classification_parser = classification.get_args_parser()
15 | parser = argparse.ArgumentParser("Submitit for ConvNeXt", parents=[classification_parser])
16 | parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
17 | parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
18 | parser.add_argument("--timeout", default=72, type=int, help="Duration of the job, in hours")
19 | parser.add_argument("--job_name", default="convnext", type=str, help="Job name")
20 | parser.add_argument("--job_dir", default="", type=str, help="Job directory; leave empty for default")
21 | parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit")
22 | parser.add_argument("--use_volta32", action='store_true', default=True, help="Big models? Use this")
23 | parser.add_argument('--comment', default="", type=str,
24 | help='Comment to pass to scheduler, e.g. priority message')
25 | return parser.parse_args()
26 |
27 | def get_shared_folder() -> Path:
28 | user = os.getenv("USER")
29 | if Path("/checkpoint/").is_dir():
30 | p = Path(f"/checkpoint/{user}/convnext")
31 | p.mkdir(exist_ok=True)
32 | return p
33 | raise RuntimeError("No shared folder available")
34 |
35 | def get_init_file():
36 | # Init file must not exist, but its parent dir must exist.
37 | os.makedirs(str(get_shared_folder()), exist_ok=True)
38 | init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init"
39 | if init_file.exists():
40 | os.remove(str(init_file))
41 | return init_file
42 |
43 | class Trainer(object):
44 | def __init__(self, args):
45 | self.args = args
46 |
47 | def __call__(self):
48 | import main as classification
49 |
50 | self._setup_gpu_args()
51 | classification.main(self.args)
52 |
53 | def checkpoint(self):
54 | import os
55 | import submitit
56 |
57 | self.args.dist_url = get_init_file().as_uri()
58 | self.args.auto_resume = True
59 | print("Requeuing ", self.args)
60 | empty_trainer = type(self)(self.args)
61 | return submitit.helpers.DelayedSubmission(empty_trainer)
62 |
63 | def _setup_gpu_args(self):
64 | import submitit
65 | from pathlib import Path
66 |
67 | job_env = submitit.JobEnvironment()
68 | self.args.output_dir = Path(self.args.job_dir)
69 | self.args.gpu = job_env.local_rank
70 | self.args.rank = job_env.global_rank
71 | self.args.world_size = job_env.num_tasks
72 | print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
73 |
74 |
75 | def main():
76 | args = parse_args()
77 |
78 | if args.job_dir == "":
79 | args.job_dir = get_shared_folder() / "%j"
80 |
81 | executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
82 |
83 | num_gpus_per_node = args.ngpus
84 | nodes = args.nodes
85 | timeout_min = args.timeout * 60
86 |
87 | partition = args.partition
88 | kwargs = {}
89 | if args.use_volta32:
90 | kwargs['slurm_constraint'] = 'volta32gb'
91 | if args.comment:
92 | kwargs['slurm_comment'] = args.comment
93 |
94 | executor.update_parameters(
95 | mem_gb=40 * num_gpus_per_node,
96 | gpus_per_node=num_gpus_per_node,
97 | tasks_per_node=num_gpus_per_node, # one task per GPU
98 | cpus_per_task=10,
99 | nodes=nodes,
100 | timeout_min=timeout_min, # max is 60 * 72
101 | # Below are cluster dependent parameters
102 | slurm_partition=partition,
103 | slurm_signal_delay_s=120,
104 | **kwargs
105 | )
106 |
107 | executor.update_parameters(name=args.job_name)
108 |
109 | args.dist_url = get_init_file().as_uri()
110 | args.output_dir = args.job_dir
111 |
112 | trainer = Trainer(args)
113 | job = executor.submit(trainer)
114 |
115 | print("Submitted job_id:", job.job_id)
116 |
117 | if __name__ == "__main__":
118 | main()
119 |
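# Example launch (illustrative; only flags defined in parse_args above are shown,
# and any additional flags required by main.get_args_parser are omitted):
#   python processor/run_with_submitit.py --nodes 2 --ngpus 8 \
#       --partition learnlab --job_name efmnet --timeout 72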
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JACYI/EFM-Net-Pytorch/7d0432307f99439d8e7085ee8b055f7ba11b9500/util/__init__.py
--------------------------------------------------------------------------------
/util/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright © Beijing University of Posts and Telecommunications,
2 | # School of Artificial Intelligence.
3 |
4 |
5 | import os
6 | import math
7 | import time
8 | import cv2
9 | from collections import defaultdict, deque
10 | import datetime
11 | import numpy as np
12 | from timm.utils import get_state_dict
13 | from pathlib import Path
14 | import torch
15 | import torch.distributed as dist
16 | from torch._six import inf
17 | import torch.nn as nn
18 | from tensorboardX import SummaryWriter
19 | from termcolor import colored
20 | import functools
21 | import logging
22 | import sys
23 | import random
24 | import torch.nn.functional as F
25 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
26 |
27 |
28 | class SoftCrossEntropyLoss(nn.Module):
29 | def __init__(self, gama=0., reduction='mean'):
30 | super(SoftCrossEntropyLoss, self).__init__()
31 | self.reduction = reduction
32 | self.gama = gama
33 | self.softmax = nn.Softmax(dim=1)
34 |
35 | def __str__(self):
36 | return "Soft-cross-entropy loss"
37 |
38 | def forward(self, outputs, targets):
39 | targets = targets.view(-1, 1)
40 | assert outputs.size(0) == targets.size(0)
41 |
42 | softed = self.softmax(outputs)
43 | loged = -torch.log(softed)
44 |
45 | y_truth = softed.gather(1, targets)
46 | y_truth_1 = torch.pow(1 - y_truth, self.gama)
47 |
48 | log_truth = loged.gather(1, targets)
49 | soft_loss = y_truth_1 * log_truth
50 | if self.reduction == 'mean':
51 | loss = torch.mean(soft_loss)
52 | elif self.reduction == 'sum':
53 | loss = torch.sum(soft_loss)
54 | else:
55 | raise ValueError("reduction must be 'mean' or 'sum', but got {}".format(self.reduction))
56 |
57 | return loss
58 |
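# --- Sanity-check sketch (assumed; not part of the original file and not called
# anywhere): with gama=0 the (1 - p)^gamma modulation above equals 1, so the
# loss reduces to the standard cross-entropy loss.
def _soft_ce_matches_plain_ce_demo():
    logits = torch.randn(4, 10)
    labels = torch.randint(0, 10, (4,))
    soft_ce = SoftCrossEntropyLoss(gama=0.)(logits, labels)
    plain_ce = nn.CrossEntropyLoss()(logits, labels)
    assert torch.allclose(soft_ce, plain_ce, atol=1e-6)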
59 |
60 | class SmoothedValue(object):
61 | """Track a series of values and provide access to smoothed values over a
62 | window or the global series average.
63 | """
64 |
65 | def __init__(self, window_size=20, fmt=None):
66 | if fmt is None:
67 | fmt = "{median:.4f} ({global_avg:.4f})"
68 | self.deque = deque(maxlen=window_size)
69 | self.total = 0.0
70 | self.count = 0
71 | self.fmt = fmt
72 |
73 | def update(self, value, n=1):
74 | self.deque.append(value)
75 | self.count += n
76 | self.total += value * n
77 |
78 | def synchronize_between_processes(self):
79 | """
80 | Warning: does not synchronize the deque!
81 | """
82 | if not is_dist_avail_and_initialized():
83 | return
84 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
85 | dist.barrier()
86 | dist.all_reduce(t)
87 | t = t.tolist()
88 | self.count = int(t[0])
89 | self.total = t[1]
90 |
91 | @property
92 | def median(self):
93 | d = torch.tensor(list(self.deque))
94 | return d.median().item()
95 |
96 | @property
97 | def avg(self):
98 | d = torch.tensor(list(self.deque), dtype=torch.float32)
99 | return d.mean().item()
100 |
101 | @property
102 | def global_avg(self):
103 | return self.total / self.count
104 |
105 | @property
106 | def max(self):
107 | return max(self.deque)
108 |
109 | @property
110 | def value(self):
111 | return self.deque[-1]
112 |
113 | def __str__(self):
114 | return self.fmt.format(
115 | median=self.median,
116 | avg=self.avg,
117 | global_avg=self.global_avg,
118 | max=self.max,
119 | value=self.value)
120 |
121 |
122 | class MetricLogger(object):
123 | def __init__(self, delimiter="\t"):
124 | self.meters = defaultdict(SmoothedValue)
125 | self.delimiter = delimiter
126 |
127 | def update(self, **kwargs):
128 | for k, v in kwargs.items():
129 | if v is None:
130 | continue
131 | if isinstance(v, torch.Tensor):
132 | v = v.item()
133 | assert isinstance(v, (float, int))
134 | self.meters[k].update(v)
135 |
136 | def __getattr__(self, attr):
137 | if attr in self.meters:
138 | return self.meters[attr]
139 | if attr in self.__dict__:
140 | return self.__dict__[attr]
141 | raise AttributeError("'{}' object has no attribute '{}'".format(
142 | type(self).__name__, attr))
143 |
144 | def __str__(self):
145 | loss_str = []
146 | for name, meter in self.meters.items():
147 | loss_str.append(
148 | "{}: {}".format(name, str(meter))
149 | )
150 | return self.delimiter.join(loss_str)
151 |
152 | def synchronize_between_processes(self):
153 | for meter in self.meters.values():
154 | meter.synchronize_between_processes()
155 |
156 | def add_meter(self, name, meter):
157 | self.meters[name] = meter
158 |
159 | def log_every(self, iterable, print_freq, header=None):
160 | i = 0
161 | if not header:
162 | header = ''
163 | start_time = time.time()
164 | end = time.time()
165 | iter_time = SmoothedValue(fmt='{avg:.4f}')
166 | data_time = SmoothedValue(fmt='{avg:.4f}')
167 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
168 | log_msg = [
169 | header,
170 | '[{0' + space_fmt + '}/{1}]',
171 | 'eta: {eta}',
172 | '{meters}',
173 | 'time: {time}',
174 | 'data: {data}'
175 | ]
176 | if torch.cuda.is_available():
177 | log_msg.append('max mem: {memory:.0f}')
178 | log_msg = self.delimiter.join(log_msg)
179 | MB = 1024.0 * 1024.0
180 | for obj in iterable:
181 | data_time.update(time.time() - end)
182 | yield obj
183 | iter_time.update(time.time() - end)
184 | if i % print_freq == 0 or i == len(iterable) - 1:
185 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
186 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
187 | if torch.cuda.is_available():
188 | print(log_msg.format(
189 | i, len(iterable), eta=eta_string,
190 | meters=str(self),
191 | time=str(iter_time), data=str(data_time),
192 | memory=torch.cuda.max_memory_allocated() / MB))
193 | else:
194 | print(log_msg.format(
195 | i, len(iterable), eta=eta_string,
196 | meters=str(self),
197 | time=str(iter_time), data=str(data_time)))
198 | i += 1
199 | end = time.time()
200 | total_time = time.time() - start_time
201 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
202 | print('{} Total time: {} ({:.4f} s / it)'.format(
203 | header, total_time_str, total_time / len(iterable)))
204 |
205 |
206 | class TensorboardLogger(object):
207 | def __init__(self, log_dir):
208 | self.writer = SummaryWriter(logdir=log_dir)
209 | self.step = 0
210 |
211 | def set_step(self, step=None):
212 | if step is not None:
213 | self.step = step
214 | else:
215 | self.step += 1
216 |
217 | def update(self, head='scalar', step=None, **kwargs):
218 | for k, v in kwargs.items():
219 | if v is None:
220 | continue
221 | if isinstance(v, torch.Tensor):
222 | v = v.item()
223 | assert isinstance(v, (float, int))
224 | self.writer.add_scalar(head + "/" + k, v, self.step if step is None else step)
225 |
226 | def flush(self):
227 | self.writer.flush()
228 |
229 |
230 | class WandbLogger(object):
231 | def __init__(self, args):
232 | self.args = args
233 |
234 | try:
235 | import wandb
236 | self._wandb = wandb
237 | except ImportError:
238 | raise ImportError(
239 | "To use the Weights and Biases Logger please install wandb."
240 | "Run `pip install wandb` to install it."
241 | )
242 |
243 | # Initialize a W&B run
244 | if self._wandb.run is None:
245 | self._wandb.init(
246 | entity="yonghao-yi",
247 | project=args.project,
248 | config=args,
249 | name=args.exp_name
250 | )
251 |
252 | def log_epoch_metrics(self, metrics, commit=True):
253 | """
254 | Log train/test metrics onto W&B.
255 | """
256 | # Log number of model parameters as W&B summary
257 | self._wandb.summary['n_parameters'] = metrics.get('n_parameters', None)
258 | metrics.pop('n_parameters', None)
259 |
260 | # Log current epoch
261 | self._wandb.log({'epoch': metrics.get('epoch')}, commit=False)
262 | metrics.pop('epoch')
263 |
264 | for k, v in metrics.items():
265 | if 'train' in k:
266 | self._wandb.log({f'Global Train/{k}': v}, commit=False)
267 | elif 'test' in k:
268 | self._wandb.log({f'Global Test/{k}': v}, commit=False)
269 |
270 | self._wandb.log({})
271 |
272 | def log_checkpoints(self):
273 | output_dir = self.args.output_dir
274 | model_artifact = self._wandb.Artifact(
275 | self._wandb.run.id + "_model", type="model"
276 | )
277 |
278 | model_artifact.add_dir(output_dir)
279 | self._wandb.log_artifact(model_artifact, aliases=["latest", "best"])
280 |
281 | def set_steps(self):
282 | # Set global training step
283 | self._wandb.define_metric('Rank-0 Batch Wise/*', step_metric='Rank-0 Batch Wise/global_train_step')
284 | # Set epoch-wise step
285 | self._wandb.define_metric('Global Train/*', step_metric='epoch')
286 | self._wandb.define_metric('Global Test/*', step_metric='epoch')
287 |
288 |
289 | def setup_for_distributed(is_master):
290 | """
291 | This function disables printing when not in master process
292 | """
293 | import builtins as __builtin__
294 | builtin_print = __builtin__.print
295 |
296 | def print(*args, **kwargs):
297 | force = kwargs.pop('force', False)
298 | if is_master or force:
299 | builtin_print(*args, **kwargs)
300 |
301 | __builtin__.print = print
302 |
303 |
304 | def is_dist_avail_and_initialized():
305 | if not dist.is_available():
306 | return False
307 | if not dist.is_initialized():
308 | return False
309 | return True
310 |
311 |
312 | def get_world_size():
313 | if not is_dist_avail_and_initialized():
314 | return 1
315 | return dist.get_world_size()
316 |
317 |
318 | def get_rank():
319 | if not is_dist_avail_and_initialized():
320 | return 0
321 | return dist.get_rank()
322 |
323 |
324 | def is_main_process():
325 | return get_rank() == 0
326 |
327 |
328 | def save_on_master(*args, **kwargs):
329 | if is_main_process():
330 | torch.save(*args, **kwargs)
331 |
332 |
333 | def init_distributed_mode(args):
334 |
335 | if args.dist_on_itp:
336 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
337 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
338 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
339 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
340 | os.environ['LOCAL_RANK'] = str(args.gpu)
341 | os.environ['RANK'] = str(args.rank)
342 | os.environ['WORLD_SIZE'] = str(args.world_size)
343 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
344 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
345 | args.rank = int(os.environ["RANK"])
346 | args.world_size = int(os.environ['WORLD_SIZE'])
347 | args.gpu = int(os.environ['LOCAL_RANK'])
348 | elif 'SLURM_PROCID' in os.environ:
349 | args.rank = int(os.environ['SLURM_PROCID'])
350 | args.gpu = args.rank % torch.cuda.device_count()
351 |
352 | os.environ['RANK'] = str(args.rank)
353 | os.environ['LOCAL_RANK'] = str(args.gpu)
354 | os.environ['WORLD_SIZE'] = str(args.world_size)
355 | else:
356 | print('Not using distributed mode')
357 | args.distributed = False
358 | return
359 |
360 | args.distributed = True
361 |
362 | torch.cuda.set_device(args.gpu)
363 | args.dist_backend = 'nccl'
364 | print('| distributed init (rank {}): {}, gpu {}'.format(
365 | args.rank, args.dist_url, args.gpu), flush=True)
366 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
367 | world_size=args.world_size, rank=args.rank)
368 | print(f"backend:{args.dist_backend}, init_method:{args.dist_url}, world_size:{args.world_size}, rank:{args.rank}")
369 | torch.distributed.barrier()
370 | setup_for_distributed(args.rank == 0)
371 |
372 |
373 | def load_state_dict(model, state_dict, prefix='', ignore_missing="relative_position_index"):
374 | missing_keys = []
375 | unexpected_keys = []
376 | error_msgs = []
377 | # copy state_dict so _load_from_state_dict can modify it
378 | metadata = getattr(state_dict, '_metadata', None)
379 | state_dict = state_dict.copy()
380 | if metadata is not None:
381 | state_dict._metadata = metadata
382 |
383 | def load(module, prefix=''):
384 | local_metadata = {} if metadata is None else metadata.get(
385 | prefix[:-1], {})
386 | module._load_from_state_dict(
387 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
388 | for name, child in module._modules.items():
389 | if child is not None:
390 | load(child, prefix + name + '.')
391 |
392 | load(model, prefix=prefix)
393 |
394 | warn_missing_keys = []
395 | ignore_missing_keys = []
396 | for key in missing_keys:
397 | keep_flag = True
398 | for ignore_key in ignore_missing.split('|'):
399 | if ignore_key in key:
400 | keep_flag = False
401 | break
402 | if keep_flag:
403 | warn_missing_keys.append(key)
404 | else:
405 | ignore_missing_keys.append(key)
406 |
407 | missing_keys = warn_missing_keys
408 |
409 | if len(missing_keys) > 0:
410 | logging.warning("Weights of {} not initialized from pretrained model: {}".format(
411 | model.__class__.__name__, missing_keys))
412 | if len(unexpected_keys) > 0:
413 | logging.warning("Weights from pretrained model not used in {}: {}".format(
414 | model.__class__.__name__, unexpected_keys))
415 | if len(ignore_missing_keys) > 0:
416 | logging.warning("Ignored weights of {} not initialized from pretrained model: {}".format(
417 | model.__class__.__name__, ignore_missing_keys))
418 | if len(error_msgs) > 0:
419 | print('\n'.join(error_msgs))
420 |
421 |
422 | class NativeScalerWithGradNormCount:
423 | state_dict_key = "amp_scaler"
424 |
425 | def __init__(self):
426 | self._scaler = torch.cuda.amp.GradScaler()
427 |
428 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
429 | self._scaler.scale(loss).backward(create_graph=create_graph)
430 | if update_grad:
431 | if clip_grad is not None:
432 | assert parameters is not None
433 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
434 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
435 | else:
436 | self._scaler.unscale_(optimizer)
437 | norm = get_grad_norm_(parameters)
438 | self._scaler.step(optimizer)
439 | self._scaler.update()
440 | else:
441 | norm = None
442 | return norm
443 |
444 | def state_dict(self):
445 | return self._scaler.state_dict()
446 |
447 | def load_state_dict(self, state_dict):
448 | self._scaler.load_state_dict(state_dict)
449 |
450 |
451 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
452 | if isinstance(parameters, torch.Tensor):
453 | parameters = [parameters]
454 | parameters = [p for p in parameters if p.grad is not None]
455 | norm_type = float(norm_type)
456 | if len(parameters) == 0:
457 | return torch.tensor(0.)
458 | device = parameters[0].grad.device
459 | if norm_type == inf:
460 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
461 | else:
462 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
463 | return total_norm
464 |
465 |
466 | def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0,
467 | start_warmup_value=0, warmup_steps=-1):
468 | warmup_schedule = np.array([])
469 | warmup_iters = warmup_epochs * niter_per_ep
470 | if warmup_steps > 0:
471 | warmup_iters = warmup_steps
472 | print("Set warmup steps = %d" % warmup_iters)
473 | if warmup_epochs > 0:
474 | warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
475 |
476 | iters = np.arange(epochs * niter_per_ep - warmup_iters)
477 | schedule = np.array(
478 | [final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * i / (len(iters)))) for i in iters])
479 |
480 | schedule = np.concatenate((warmup_schedule, schedule))
481 |
482 | assert len(schedule) == epochs * niter_per_ep
483 | return schedule
484 |
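# --- Usage sketch (assumed; not part of the original file and not called
# anywhere): per-iteration learning-rate values consumed in train_one_epoch via
# lr_schedule_values[it]. The numbers below are illustrative, not defaults.
def _cosine_schedule_demo():
    lr_schedule_values = cosine_scheduler(base_value=4e-3, final_value=1e-6,
                                          epochs=100, niter_per_ep=500,
                                          warmup_epochs=5)
    return lr_schedule_values  # length == epochs * niter_per_ep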
485 |
486 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
487 | output_dir = Path(args.output_dir)
488 | epoch_name = str(epoch)
489 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
490 | for checkpoint_path in checkpoint_paths:
491 | to_save = {
492 | 'model': model_without_ddp.state_dict(),
493 | 'optimizer': optimizer.state_dict(),
494 | 'epoch': epoch,
495 | 'scaler': loss_scaler.state_dict(),
496 | 'args': args,
497 | }
498 |
499 | if model_ema is not None:
500 | to_save['model_ema'] = get_state_dict(model_ema)
501 |
502 | save_on_master(to_save, checkpoint_path)
503 |
504 | if is_main_process() and isinstance(epoch, int):
505 | to_del = epoch - args.save_ckpt_num * args.save_ckpt_freq
506 | old_ckpt = output_dir / ('checkpoint-%s.pth' % to_del)
507 | if os.path.exists(old_ckpt):
508 | os.remove(old_ckpt)
509 |
510 |
511 | def auto_load_model(args, model, model_without_ddp, optimizer, loss_scaler, model_ema=None):
512 | output_dir = Path(args.output_dir)
513 | if args.auto_resume and len(args.resume) == 0:
514 | import glob
515 | all_checkpoints = glob.glob(os.path.join(output_dir, 'checkpoint-*.pth'))
516 | latest_ckpt = -1
517 | for ckpt in all_checkpoints:
518 | t = ckpt.split('-')[-1].split('.')[0]
519 | if t.isdigit():
520 | latest_ckpt = max(int(t), latest_ckpt)
521 | if latest_ckpt >= 0:
522 | args.resume = os.path.join(output_dir, 'checkpoint-%d.pth' % latest_ckpt)
523 | print("Auto resume checkpoint: %s" % args.resume)
524 |
525 | if args.resume:
526 | if args.resume.startswith('https'):
527 | checkpoint = torch.hub.load_state_dict_from_url(
528 | args.resume, map_location='cpu', check_hash=True)
529 | else:
530 | checkpoint = torch.load(args.resume, map_location='cpu')
531 | model_without_ddp.load_state_dict(checkpoint['model'])
532 | print("Resume checkpoint %s" % args.resume)
533 | if 'optimizer' in checkpoint and 'epoch' in checkpoint:
534 | optimizer.load_state_dict(checkpoint['optimizer'])
535 | if not isinstance(checkpoint['epoch'], str): # does not support resuming with 'best', 'best-ema'
536 | args.start_epoch = checkpoint['epoch'] + 1
537 | else:
538 | assert args.eval, 'Does not support resuming with checkpoint-best'
539 | if hasattr(args, 'model_ema') and args.model_ema:
540 | if 'model_ema' in checkpoint.keys():
541 | model_ema.ema.load_state_dict(checkpoint['model_ema'])
542 | else:
543 | model_ema.ema.load_state_dict(checkpoint['model'])
544 | if 'scaler' in checkpoint:
545 | loss_scaler.load_state_dict(checkpoint['scaler'])
546 | print("With optim & sched!")
547 |
548 |
549 |
550 | @functools.lru_cache()
551 | def create_logger(output_dir, dist_rank=0, name=''):
552 | # create logger
553 | logger = logging.getLogger(name)
554 | logger.setLevel(logging.INFO)
555 | logger.propagate = False
556 |
557 | # create formatter
558 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
559 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
560 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
561 |
562 | # create console handlers for master process
563 | if dist_rank == 0:
564 | console_handler = logging.StreamHandler(sys.stdout)
565 | console_handler.setLevel(logging.DEBUG)
566 | console_handler.setFormatter(
567 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
568 | logger.addHandler(console_handler)
569 |
570 | # create file handlers
571 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
572 | file_handler.setLevel(logging.DEBUG)
573 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
574 | logger.addHandler(file_handler)
575 |
576 | return logger
577 |
578 | def reduce_tensor(tensor):
579 | rt = tensor.clone()
580 | # dist.all_reduce(rt, op=dist.ReduceOp.SUM)
581 | # rt /= dist.get_world_size()
582 | return rt
583 |
584 |
585 | ##############################################
586 | # Center Loss for Attention Regularization
587 | ##############################################
588 | class CenterLoss(nn.Module):
589 | def __init__(self):
590 | super(CenterLoss, self).__init__()
591 | self.l2_loss = nn.MSELoss(reduction='sum')
592 |
593 | def forward(self, outputs, targets):
594 | return self.l2_loss(outputs, targets) / outputs.size(0)
595 |
596 |
597 | class UnNormalizer(object):
598 | def __init__(self, mean=None, std=None):
599 | if mean is None:
600 | self.mean = [0.485, 0.456, 0.406]
601 | else:
602 | self.mean = mean
603 | if std is None:
604 | self.std = [0.229, 0.224, 0.225]
605 | else:
606 | self.std = std
607 |
608 | def __call__(self, tensor):
609 | """
610 | Args:
611 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
612 | Returns:
613 | Tensor: Normalized image.
614 | """
615 | for t, m, s in zip(tensor, self.mean, self.std):
616 | t.mul_(s).add_(m)
617 | return tensor
618 |
619 |
620 | def save_input_image(input_image, save_path):
621 | unorm = UnNormalizer(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
622 | img = unorm(input_image.cpu()).numpy().copy()
623 | img = np.clip(img * 255, 0, 255).astype(np.uint8)
624 | img = np.transpose(img, [1, 2, 0])
625 | # change the channel(BGR->RGB)
626 | r, g, b = cv2.split(img)
627 | img = cv2.merge([b, g, r])
628 | cv2.imwrite(save_path, img)
629 |
630 |
631 | def get_detail_images(input_images, attentions, theta_detail=0.5, padding=0.1):
632 | batch_size, _, height, width = input_images.size() # B, C, H, W
633 | detail_images = []
634 | for i in range(batch_size):
635 | attention = attentions[i:i + 1]
636 | if isinstance(theta_detail, tuple):
637 | theta = random.uniform(*theta_detail) * attention.max()
638 | else:
639 | theta = theta_detail * attention.max()
640 |
641 | mask = F.upsample_bilinear(attention, size=(height, width)) >= theta
642 | nonzero_indices = torch.nonzero(mask[0, 0, ...])
643 | height_min = max(int((nonzero_indices[:, 0].min().item() - padding * height)), 0)
644 | height_max = min(int(nonzero_indices[:, 0].max().item() + padding * height), height)
645 | width_min = max(int(nonzero_indices[:, 1].min().item() - padding * width), 0)
646 | width_max = min(int(nonzero_indices[:, 1].max().item() + padding * width), width)
647 | detail_images.append(
648 | F.upsample_bilinear(input_images[i:i + 1, :, height_min:height_max, width_min:width_max],
649 | size=(height, width)))
650 |
651 | detail_images = torch.cat(detail_images, dim=0)
652 | return detail_images
653 |
654 |
655 | def get_mask_images(input_images, attentions, theta_mask=0.5):
656 | batch_size, _, height, width = input_images.size()
657 | masks = []
658 | for i in range(batch_size):
659 | attention = attentions[i:i + 1]
660 | if isinstance(theta_mask, tuple):
661 | temp1 = random.uniform(*theta_mask)
662 | temp2 = attention.max()
663 | theta = temp1 * temp2
664 | else:
665 | theta = theta_mask * attention.max()
666 |
667 | masks.append(F.upsample_bilinear(attention, size=(height, width)) < theta)
668 | masks = torch.cat(masks, dim=0)
669 | mask_images = input_images * masks.float()
670 |
671 | return mask_images
672 |
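# --- Illustrative sketch (assumed; not part of the original file and not called
# anywhere): how the two attention-guided augmentations above are driven,
# mirroring their use in processor/engine.py. Shapes are arbitrary examples.
def _attention_guided_augment_demo():
    images = torch.randn(2, 3, 224, 224)
    attn = torch.rand(2, 2, 7, 7)  # (B, 2, h, w): first map for detail, second for mask
    detail = get_detail_images(images, attn[:, :1], theta_detail=(0.4, 0.6), padding=0.1)
    masked = get_mask_images(images, attn[:, 1:], theta_mask=(0.2, 0.5))
    return detail.shape, masked.shape  # both (2, 3, 224, 224)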
673 | # def batch_augment(images, attention_map, mode='crop', theta=0.5, padding_ratio=0.1, image_name=None):
674 | # batches, _, imgH, imgW = images.size()
675 | # if mode == 'crop':
676 | # crop_images = []
677 | # nonzero_indices = None
678 | # for batch_index in range(batches):
679 | # atten_map = attention_map[batch_index:batch_index + 1]
680 | # if isinstance(theta, tuple):
681 | # theta_c = random.uniform(*theta) * atten_map.max()
682 | # else:
683 | # theta_c = theta * atten_map.max()
684 | #
685 | # crop_mask = F.upsample_bilinear(atten_map, size=(imgH, imgW)) >= theta_c
686 | # nonzero_indices = torch.nonzero(crop_mask[0, 0, ...])
687 | # height_min = max(int((nonzero_indices[:, 0].min().item() - padding_ratio * imgH)), 0)
688 | # height_max = min(int(nonzero_indices[:, 0].max().item() + padding_ratio * imgH), imgH)
689 | # width_min = max(int(nonzero_indices[:, 1].min().item() - padding_ratio * imgW), 0)
690 | # width_max = min(int(nonzero_indices[:, 1].max().item() + padding_ratio * imgW), imgW)
691 | # crop_images.append(
692 | # F.upsample_bilinear(images[batch_index:batch_index + 1, :, height_min:height_max, width_min:width_max],
693 | # size=(imgH, imgW)))
694 | # # if image_name:
695 | # # save_input_image(images[batch_index], os.path.join("/home/yiyonghao/TGRS/A-Net/visualization/crop_drop", image_name[batch_index] + "_raw.jpg"))
696 | # # save_input_image(crop_images[batch_index][0], os.path.join("/home/yiyonghao/TGRS/A-Net/visualization/crop_drop", image_name[batch_index] + "_crop.jpg"))
697 | #
698 | # crop_images = torch.cat(crop_images, dim=0)
699 | # return crop_images
700 | #
701 | # elif mode == 'drop':
702 | # drop_masks = []
703 | # for batch_index in range(batches):
704 | # atten_map = attention_map[batch_index:batch_index + 1]
705 | # if isinstance(theta, tuple):
706 | # temp1 = random.uniform(*theta)
707 | # temp2 = atten_map.max()
708 | # theta_d = temp1 * temp2
709 | # else:
710 | # theta_d = theta * atten_map.max()
711 | #
712 | # drop_masks.append(F.upsample_bilinear(atten_map, size=(imgH, imgW)) < theta_d)
713 | # drop_masks = torch.cat(drop_masks, dim=0)
714 | # drop_images = images * drop_masks.float()
715 | # # for idx, drop_image in enumerate(drop_images):
716 | # # if image_name:
717 | # # save_input_image(drop_image,
718 | # # os.path.join("/home/yiyonghao/TGRS/A-Net/visualization/crop_drop",
719 | # # image_name[idx] + "_drop.jpg"))
720 | # return drop_images
721 | #
722 | # else:
723 | # raise ValueError('Expected mode in [\'crop\', \'drop\'], but received unsupported augmentation method %s' % mode)
724 |
--------------------------------------------------------------------------------