├── .gitignore ├── LICENSE ├── MODEL_ZOO.md ├── README.md ├── engine_segfinetune.py ├── inference.py ├── main_segfinetune.py ├── models ├── __init__.py ├── models_convnext.py ├── models_resnet.py ├── models_rfconvnext.py ├── models_vit.py ├── rfconv.py └── rfconvnext.py └── util ├── datasets.py ├── lr_decay.py ├── lr_sched.py ├── metric.py ├── misc.py ├── pos_embed.py └── transforms.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | dataset/ 4 | output*/ 5 | output_dir/ 6 | output_dir*/ 7 | ckpts/ 8 | *.pth 9 | *.t7 10 | *.png 11 | *.jpg 12 | tmp*.py 13 | *.pdf 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # celery beat schedule file 92 | celerybeat-schedule 93 | 94 | # SageMath parsed files 95 | *.sage.py 96 | 97 | # Environments 98 | .env 99 | .venv 100 | env/ 101 | venv/ 102 | ENV/ 103 | env.bak/ 104 | venv.bak/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | .DS_Store 119 | 120 | .vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 
15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. 
For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. 
Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 
219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. 
No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /MODEL_ZOO.md: -------------------------------------------------------------------------------- 1 | # Model ZOO for Semi-Supervised Learning on ImageNet-S 2 | 3 | [Finetuning with ViT](#1) 4 | 5 | [Finetuning with ResNet](#2) 6 | 7 | [Finetuning with RF-ConvNext](#3) 8 | 9 | 10 |
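All finetuning commands in this model zoo follow the learning-rate convention of `main_segfinetune.py`: the absolute learning rate is `blr * effective_batch_size / 256`, where the effective batch size is `batch_size * accum_iter * #GPUs`. A small worked example with the values used in the commands below (purely illustrative):

```python
# Effective batch size and absolute learning rate, as computed in main_segfinetune.py.
batch_size, accum_iter, num_gpus = 32, 1, 8   # per-GPU batch, grad accumulation, GPUs
blr = 5e-4                                    # --blr from the commands below

eff_batch_size = batch_size * accum_iter * num_gpus   # 256
lr = blr * eff_batch_size / 256                       # 5e-4: equals blr here
print(eff_batch_size, lr)
```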
11 | 
12 | ## Finetuning with ViT
13 | 
| Method | Arch     | Pretraining epochs | Pretraining mode | val  | test | Pretrained | Finetuned |
|--------|----------|--------------------|------------------|------|------|------------|-----------|
| MAE    | ViT-B/16 | 1600               | SSL              | 38.3 | 37.0 | model      | model     |
| MAE    | ViT-B/16 | 1600               | SSL+Sup          | 61.0 | 60.2 | model      | model     |
| SERE   | ViT-S/16 | 100                | SSL              | 41.0 | 40.2 | model      | model     |
| SERE   | ViT-S/16 | 100                | SSL+Sup          | 58.9 | 57.8 | model      | model     |
65 | 66 | ### Masked Autoencoders Are Scalable Vision Learners (MAE) 67 | 68 |
69 | Command for SSL+Sup 70 | 71 | ```shell 72 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 73 | --accum_iter 1 \ 74 | --batch_size 32 \ 75 | --model vit_base_patch16 \ 76 | --finetune mae_finetuned_vit_base.pth \ 77 | --epochs 100 \ 78 | --nb_classes 920 \ 79 | --blr 1e-4 --layer_decay 0.40 \ 80 | --weight_decay 0.05 --drop_path 0.1 \ 81 | --data_path ${IMAGENETS_DIR} \ 82 | --output_dir ${OUTPATH} \ 83 | --dist_eval 84 | ``` 85 | 86 |
87 | 88 |
89 | Command for SSL 90 | 91 | ```shell 92 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 93 | --accum_iter 1 \ 94 | --batch_size 32 \ 95 | --model vit_base_patch16 \ 96 | --finetune mae_pretrain_vit_base.pth \ 97 | --epochs 100 \ 98 | --nb_classes 920 \ 99 | --blr 5e-4 --layer_decay 0.60 \ 100 | --weight_decay 0.05 --drop_path 0.1 \ 101 | --data_path ${IMAGENETS_DIR} \ 102 | --output_dir ${OUTPATH} \ 103 | --dist_eval 104 | ``` 105 | 106 |
107 | 108 | ### SERE: Exploring Feature Self-relation for Self-supervised Transformer 109 | 110 |
111 | Command for SSL+Sup 112 | 113 | ```shell 114 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 115 | --accum_iter 1 \ 116 | --batch_size 32 \ 117 | --model vit_small_patch16 \ 118 | --finetune sere_finetuned_vit_small_ep100.pth \ 119 | --epochs 100 \ 120 | --nb_classes 920 \ 121 | --blr 5e-4 --layer_decay 0.50 \ 122 | --weight_decay 0.05 --drop_path 0.1 \ 123 | --data_path ${IMAGENETS_DIR} \ 124 | --output_dir ${OUTPATH} \ 125 | --dist_eval 126 | ``` 127 | 128 |
129 | 130 |
131 | Command for SSL 132 | 133 | ```shell 134 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 135 | --accum_iter 1 \ 136 | --batch_size 32 \ 137 | --model vit_small_patch16 \ 138 | --finetune sere_pretrained_vit_small_ep100.pth \ 139 | --epochs 100 \ 140 | --nb_classes 920 \ 141 | --blr 5e-4 --layer_decay 0.50 \ 142 | --weight_decay 0.05 --drop_path 0.1 \ 143 | --data_path ${IMAGENETS_DIR} \ 144 | --output_dir ${OUTPATH} \ 145 | --dist_eval 146 | ``` 147 |
148 | 149 | 150 |
151 | 152 | ## Finetuning with ResNet 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 |
| Method | Arch          | Pretraining epochs | Pretraining mode | val  | test | Pretrained | Finetuned |
|--------|---------------|--------------------|------------------|------|------|------------|-----------|
| PASS   | ResNet-50 D32 | 100                | SSL              | 21.0 | 20.3 | model      | model     |
| PASS   | ResNet-50 D16 | 100                | SSL              | 21.6 | 20.8 | model      | model     |
186 | 
187 | `D16` means the output stride is 16, obtained with dilation=2 in the last stage. These results are better than those reported in the paper thanks to the new training scripts.
188 | 
189 | ### Large-scale Unsupervised Semantic Segmentation (PASS)
190 | 
191 | Command for SSL (ResNet-50 D32) 192 | 193 | ```shell 194 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 195 | --accum_iter 1 \ 196 | --batch_size 32 \ 197 | --model resnet50 \ 198 | --finetune pass919_pretrained.pth.tar \ 199 | --epochs 100 \ 200 | --nb_classes 920 \ 201 | --blr 5e-4 --layer_decay 0.4 \ 202 | --weight_decay 0.0005 \ 203 | --data_path ${IMAGENETS_DIR} \ 204 | --output_dir ${OUTPATH} \ 205 | --dist_eval 206 | ``` 207 |
208 | 209 |
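The `D16` variant used in the next command keeps an output stride of 16 by dilating the last stage instead of striding it. A rough illustration of the idea (this sketch uses torchvision and is not the repo's own implementation, which is provided as `resnet50_d16` in `models/models_resnet.py`):

```python
import torchvision

# Output stride 16: the last stage uses dilation=2 with stride 1, so the final
# feature map stays at 1/16 of the input resolution instead of 1/32.
resnet50_d16_like = torchvision.models.resnet50(
    replace_stride_with_dilation=[False, False, True]
)
```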
210 | Command for SSL (ResNet-50 D16) 211 | 212 | ```shell 213 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 214 | --accum_iter 1 \ 215 | --batch_size 32 \ 216 | --model resnet50_d16 \ 217 | --finetune pass919_pretrained.pth.tar \ 218 | --epochs 100 \ 219 | --nb_classes 920 \ 220 | --blr 5e-4 --layer_decay 0.45 \ 221 | --weight_decay 0.0005 \ 222 | --data_path ${IMAGENETS_DIR} \ 223 | --output_dir ${OUTPATH} \ 224 | --dist_eval 225 | ``` 226 |
227 | 228 | 229 |
230 | 231 | ## Finetuning with RF-ConvNeXt 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 |
| Arch          | Pretraining epochs | RF-Next mode | val  | test | Pretrained | Searched | Finetuned |
|---------------|--------------------|--------------|------|------|------------|----------|-----------|
| ConvNeXt-T    | 300                | -            | 48.7 | 48.8 | model      | -        | model     |
| RF-ConvNeXt-T | 300                | rfsingle     | 50.7 | 50.5 | model      | model    | model     |
| RF-ConvNeXt-T | 300                | rfmultiple   | 50.8 | 50.5 | model      | model    | model     |
| RF-ConvNeXt-T | 300                | rfmerge      | 51.3 | 51.1 | model      | model    | model     |
286 | 287 |
288 | Command for ConvNeXt-T 289 | 290 | ```shell 291 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 292 | --accum_iter 1 \ 293 | --batch_size 32 \ 294 | --model convnext_tiny \ 295 | --patch_size 4 \ 296 | --finetune convnext_tiny_1k_224_ema.pth \ 297 | --epochs 100 \ 298 | --nb_classes 920 \ 299 | --blr 2.5e-4 --layer_decay 0.6 \ 300 | --weight_decay 0.05 --drop_path 0.2 \ 301 | --data_path ${IMAGENETS_DIR} \ 302 | --output_dir ${OUTPATH} \ 303 | --dist_eval 304 | ``` 305 |
306 | 
307 | Before training RF-ConvNeXt,
308 | please first search the dilation rates with the rfsearch mode.
309 | 
310 | For rfsingle and rfmultiple, please set `pretrained_rfnext`
311 | to the weights trained in rfsearch.
312 | 
313 | For rfmerge, we initialize the model with the weights from rfmultiple and only finetune `seg_norm`, `seg_head`, and the `rfconvs` whose dilation rates are changed.
314 | The other parts of the network are frozen.
315 | Please set `pretrained_rfnext`
316 | to the weights trained in rfmultiple.
317 | 
318 | **Note that this freezing operation in rfmerge may not be required for other tasks.**
319 | 
320 | 
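As a rough sketch of the rfmerge freezing described above (illustrative only; the name-matching rule below is an assumption, not the repo's exact logic):

```python
import torch.nn as nn

def freeze_for_rfmerge(model: nn.Module,
                       trainable_keys=('seg_norm', 'seg_head', 'rfconv')):
    """Keep only seg_norm, seg_head and the (changed) rfconvs trainable;
    freeze every other parameter. Illustrative helper, not from this repo."""
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in trainable_keys)
```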
321 | Command for RF-ConvNeXt-T (rfsearch) 322 | 323 | ```shell 324 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 325 | --accum_iter 1 \ 326 | --batch_size 32 \ 327 | --model rfconvnext_tiny_rfsearch \ 328 | --patch_size 4 \ 329 | --finetune convnext_tiny_1k_224_ema.pth \ 330 | --epochs 100 \ 331 | --nb_classes 920 \ 332 | --blr 2.5e-4 --layer_decay 0.6 0.9 --layer_multiplier 1.0 10.0 \ 333 | --weight_decay 0.05 --drop_path 0.2 \ 334 | --data_path ${IMAGENETS_DIR} \ 335 | --output_dir ${OUTPATH} \ 336 | --dist_eval 337 | ``` 338 |
339 | 340 |
341 | Command for RF-ConvNeXt-T (rfsingle) 342 | 343 | ```shell 344 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 345 | --accum_iter 1 \ 346 | --batch_size 32 \ 347 | --model rfconvnext_tiny_rfsingle \ 348 | --patch_size 4 \ 349 | --finetune convnext_tiny_1k_224_ema.pth \ 350 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \ 351 | --epochs 100 \ 352 | --nb_classes 920 \ 353 | --blr 2.5e-4 --layer_decay 0.6 0.9 --layer_multiplier 1.0 10.0 \ 354 | --weight_decay 0.05 --drop_path 0.2 \ 355 | --data_path ${IMAGENETS_DIR} \ 356 | --output_dir ${OUTPATH} \ 357 | --dist_eval 358 | 359 | python inference.py --model rfconvnext_tiny_rfsingle \ 360 | --patch_size 4 \ 361 | --nb_classes 920 \ 362 | --output_dir ${OUTPATH}/predictions \ 363 | --data_path ${IMAGENETS_DIR} \ 364 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \ 365 | --finetune ${OUTPATH}/checkpoint-99.pth \ 366 | --mode validation 367 | ``` 368 |
369 | 370 |
371 | Command for RF-ConvNeXt-T (rfmultiple) 372 | 373 | ```shell 374 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 375 | --accum_iter 1 \ 376 | --batch_size 32 \ 377 | --model rfconvnext_tiny_rfmultiple \ 378 | --patch_size 4 \ 379 | --finetune convnext_tiny_1k_224_ema.pth \ 380 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \ 381 | --epochs 100 \ 382 | --nb_classes 920 \ 383 | --blr 2.5e-4 --layer_decay 0.55 0.9 --layer_multiplier 1.0 10.0 \ 384 | --weight_decay 0.05 --drop_path 0.1 \ 385 | --data_path ${IMAGENETS_DIR} \ 386 | --output_dir ${OUTPATH} \ 387 | --dist_eval 388 | 389 | python inference.py --model rfconvnext_tiny_rfmultiple \ 390 | --patch_size 4 \ 391 | --nb_classes 920 \ 392 | --output_dir ${OUTPATH}/predictions \ 393 | --data_path ${IMAGENETS_DIR} \ 394 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \ 395 | --finetune ${OUTPATH}/checkpoint-99.pth \ 396 | --mode validation 397 | ``` 398 |
399 | 400 | 401 |
402 | Command for RF-ConvNeXt-T (rfmerge) 403 | 404 | ```shell 405 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \ 406 | --accum_iter 1 \ 407 | --batch_size 32 \ 408 | --model rfconvnext_tiny_rfmerge \ 409 | --patch_size 4 \ 410 | --pretrained_rfnext ${OUTPATH_OF_RFMULTIPLE}/checkpoint-99.pth \ 411 | --epochs 100 \ 412 | --nb_classes 920 \ 413 | --blr 2.5e-4 --layer_decay 0.55 1.0 --layer_multiplier 1.0 10.0 \ 414 | --weight_decay 0.05 --drop_path 0.2 \ 415 | --data_path ${IMAGENETS_DIR} \ 416 | --output_dir ${OUTPATH} \ 417 | --dist_eval 418 | 419 | python inference.py --model rfconvnext_tiny_rfmerge \ 420 | --patch_size 4 \ 421 | --nb_classes 920 \ 422 | --output_dir ${OUTPATH}/predictions \ 423 | --data_path ${IMAGENETS_DIR} \ 424 | --pretrained_rfnext ${OUTPATH_OF_RFMULTIPLE}/checkpoint-99.pth \ 425 | --finetune ${OUTPATH}/checkpoint-99.pth \ 426 | --mode validation 427 | ``` 428 |
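The ConvNeXt/RF-ConvNeXt commands above all pass `--patch_size 4`: the segmentation head predicts `nb_classes * patch_size**2` channels on a downsampled grid and then rearranges them into a full-resolution prediction (see `forward_head` in `models/models_convnext.py`). A toy sketch of that rearrangement with made-up tensor sizes:

```python
import torch

# Hypothetical sizes: batch 2, 920 classes, patch_size 4, 7x7 feature grid.
b, nb_classes, p, h, w = 2, 920, 4, 7, 7
x = torch.randn(b, nb_classes * p * p, h, w)             # head output
x = x.permute(0, 2, 3, 1).reshape(b, h, w, p, p, nb_classes)
x = torch.einsum('nhwpqc->nchpwq', x)                    # interleave patch pixels
x = x.reshape(b, nb_classes, h * p, w * p)               # (2, 920, 28, 28) logits
print(x.shape)
```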
429 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Semi-supervised Semantic Segmentation on the ImageNet-S dataset
2 | 
3 | This repo provides the code for semi-supervised training of large-scale semantic segmentation models on the ImageNet-S dataset.
4 | 
5 | ## About ImageNet-S
6 | Based on the ImageNet dataset, the ImageNet-S dataset has 1.2 million training images and 50k high-quality semantic segmentation annotations to
7 | support unsupervised/semi-supervised semantic segmentation on the ImageNet dataset. The ImageNet-S dataset is available at [ImageNet-S](https://github.com/LUSSeg/ImageNet-S). For more details about the dataset, please refer to the [project page](https://LUSSeg.github.io/) or the [paper](https://arxiv.org/abs/2106.03149).
8 | 
9 | 
10 | 
11 | ## Usage
12 | - Semi-supervised finetuning with pre-trained checkpoints
13 | ```
14 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
15 | --accum_iter 1 \
16 | --batch_size 32 \
17 | --model vit_small_patch16 \
18 | --finetune ${PRETRAIN_CHKPT} \
19 | --epochs 100 \
20 | --nb_classes 920 | 301 | 51 \
21 | --blr 5e-4 --layer_decay 0.50 \
22 | --weight_decay 0.05 --drop_path 0.1 \
23 | --data_path ${IMAGENETS_DIR} \
24 | --output_dir ${OUTPATH} \
25 | --dist_eval
26 | ```
27 | Note: To use one GPU for training, you can change `--nproc_per_node=8` to `--nproc_per_node=1` and change `--accum_iter 1` to `--accum_iter 8`.
28 | - Get the zip file for the test set. You can submit it to our [online server](https://lusseg.github.io/).
29 | ```
30 | python inference.py --model vit_small_patch16 \
31 | --nb_classes 920 | 301 | 51 \
32 | --output_dir ${OUTPATH}/predictions \
33 | --data_path ${IMAGENETS_DIR} \
34 | --finetune ${OUTPATH}/checkpoint-99.pth \
35 | --mode validation | test
36 | ```
37 | 
38 | ## Model Zoo
39 | **[Model Zoo](MODEL_ZOO.md)**:
40 | We provide a model zoo to record the trend of semi-supervised semantic segmentation on the ImageNet-S dataset.
41 | For now, this repo supports ViT, ResNet, ConvNeXt, and RF-ConvNeXt, and more backbones and pretrained models will be added.
42 | Please open a pull request if you want to update your new results.
43 | 
44 | Supported networks: ViT, ResNet, ConvNeXt, RF-ConvNeXt
45 | 
46 | Supported pretraining methods: MAE, SERE, PASS
47 | 
48 | ## Citation
49 | ```
50 | @article{gao2021luss,
51 | title={Large-scale Unsupervised Semantic Segmentation},
52 | author={Gao, Shanghua and Li, Zhong-Yu and Yang, Ming-Hsuan and Cheng, Ming-Ming and Han, Junwei and Torr, Philip},
53 | journal={arXiv preprint arXiv:2106.03149},
54 | year={2021}
55 | }
56 | ```
57 | 
58 | ## Acknowledgement
59 | 
60 | This codebase is built on top of the [MAE codebase](https://github.com/facebookresearch/mae).
61 | -------------------------------------------------------------------------------- /engine_segfinetune.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | import warnings 15 | from typing import Iterable 16 | 17 | import torch 18 | import torch.distributed as dist 19 | import torch.nn.functional as F 20 | from torch.distributed import ReduceOp 21 | 22 | import util.lr_sched as lr_sched 23 | import util.misc as misc 24 | from util.metric import FMeasureGPU, IoUGPU 25 | 26 | 27 | def train_one_epoch(model: torch.nn.Module, 28 | criterion: torch.nn.Module, 29 | data_loader: Iterable, 30 | optimizer: torch.optim.Optimizer, 31 | device: torch.device, 32 | epoch: int, 33 | loss_scaler, 34 | max_norm: float = 0, 35 | args=None): 36 | model.train(True) 37 | metric_logger = misc.MetricLogger(delimiter=' ') 38 | metric_logger.add_meter( 39 | 'lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 40 | header = 'Epoch: [{}]'.format(epoch) 41 | print_freq = 20 42 | 43 | accum_iter = args.accum_iter 44 | 45 | optimizer.zero_grad() 46 | 47 | for data_iter_step, (samples, targets) in enumerate( 48 | metric_logger.log_every(data_loader, print_freq, header)): 49 | 50 | # we use a per iteration (instead of per epoch) lr scheduler 51 | if data_iter_step % accum_iter == 0: 52 | lr_sched.adjust_learning_rate( 53 | optimizer, data_iter_step / len(data_loader) + epoch, args) 54 | 55 | samples = samples.to(device, non_blocking=True) 56 | targets = targets.to(device, non_blocking=True) 57 | 58 | with torch.cuda.amp.autocast(): 59 | outputs = model(samples) 60 | 61 | outputs = torch.nn.functional.interpolate(outputs, 62 | scale_factor=2, 63 | align_corners=False, 64 | mode='bilinear') 65 | targets = torch.nn.functional.interpolate( 66 | targets.unsqueeze(1), 67 | size=(outputs.shape[2], outputs.shape[3]), 68 | mode='nearest').squeeze(1) 69 | loss = criterion(outputs, targets.long()) 70 | 71 | loss_value = loss.item() 72 | 73 | if not math.isfinite(loss_value): 74 | print('Loss is {}, stopping training'.format(loss_value)) 75 | sys.exit(1) 76 | 77 | loss /= accum_iter 78 | loss_scaler(loss, 79 | optimizer, 80 | clip_grad=max_norm, 81 | parameters=model.parameters(), 82 | create_graph=False, 83 | update_grad=(data_iter_step + 1) % accum_iter == 0) 84 | if (data_iter_step + 1) % accum_iter == 0: 85 | optimizer.zero_grad() 86 | 87 | torch.cuda.synchronize() 88 | 89 | metric_logger.update(loss=loss_value) 90 | max_lr = 0. 91 | for group in optimizer.param_groups: 92 | max_lr = max(max_lr, group['lr']) 93 | 94 | metric_logger.update(lr=max_lr) 95 | 96 | # gather the stats from all processes 97 | metric_logger.synchronize_between_processes() 98 | print('Averaged stats:', metric_logger) 99 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 100 | 101 | 102 | @torch.no_grad() 103 | def evaluate(data_loader, model, device, num_classes, max_res=1000): 104 | metric_logger = misc.MetricLogger(delimiter=' ') 105 | header = 'Test:' 106 | 107 | T = torch.zeros(size=(num_classes, )).cuda() 108 | P = torch.zeros(size=(num_classes, )).cuda() 109 | TP = torch.zeros(size=(num_classes, )).cuda() 110 | IoU = torch.zeros(size=(num_classes, )).cuda() 111 | FMeasure = 0. 
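    # T, P and TP accumulate, per class, the predicted area, the ground-truth area
    # and their intersection over the whole split; after the all-reduce below,
    # IoU = TP / (T + P - TP) per class and FMeasure is the per-image F-score
    # averaged over the dataset.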
112 | 113 | # switch to evaluation mode 114 | model.eval() 115 | 116 | for batch in metric_logger.log_every(data_loader, 100, header): 117 | images = batch[0] 118 | target = batch[-1] 119 | images = images.to(device, non_blocking=True) 120 | target = target.to(device, non_blocking=True) 121 | 122 | # compute output 123 | with torch.no_grad(): 124 | output = model(images) 125 | 126 | # process an image with a large resolution 127 | H, W = target.shape[1], target.shape[2] 128 | if (H > W and H * W > max_res * max_res 129 | and max_res > 0): 130 | output = F.interpolate(output, (max_res, int(max_res * W / H)), 131 | mode='bilinear', 132 | align_corners=False) 133 | output = torch.argmax(output, dim=1, keepdim=True) 134 | output = F.interpolate(output.float(), (H, W), 135 | mode='nearest').long() 136 | elif (H <= W and H * W > max_res * max_res 137 | and max_res > 0): 138 | output = F.interpolate(output, (int(max_res * H / W), max_res), 139 | mode='bilinear', align_corners=False) 140 | output = torch.argmax(output, dim=1, keepdim=True) 141 | output = F.interpolate(output.float(), (H, W), 142 | mode='nearest').long() 143 | else: 144 | output = F.interpolate(output, (H, W), 145 | mode='bilinear', 146 | align_corners=False) 147 | output = torch.argmax(output, dim=1, keepdim=True) 148 | 149 | target = target.view(-1) 150 | output = output.view(-1) 151 | mask = target != 1000 152 | target = target[mask] 153 | output = output[mask] 154 | 155 | area_intersection, area_output, area_target = IoUGPU( 156 | output, target, num_classes) 157 | f_score = FMeasureGPU(output, target) 158 | 159 | T += area_output 160 | P += area_target 161 | TP += area_intersection 162 | FMeasure += f_score 163 | 164 | metric_logger.synchronize_between_processes() 165 | 166 | # gather the stats from all processes 167 | dist.barrier() 168 | dist.all_reduce(T, op=ReduceOp.SUM) 169 | dist.all_reduce(P, op=ReduceOp.SUM) 170 | dist.all_reduce(TP, op=ReduceOp.SUM) 171 | dist.all_reduce(FMeasure, op=ReduceOp.SUM) 172 | 173 | IoU = TP / (T + P - TP + 1e-10) * 100 174 | FMeasure = FMeasure / len(data_loader.dataset) 175 | 176 | mIoU = torch.mean(IoU).item() 177 | FMeasure = FMeasure.item() * 100 178 | 179 | log = {} 180 | log['mIoU'] = mIoU 181 | log['IoUs'] = IoU.tolist() 182 | log['FMeasure'] = FMeasure 183 | 184 | print('* mIoU {mIoU:.3f} FMeasure {FMeasure:.3f}'.format( 185 | mIoU=mIoU, FMeasure=FMeasure)) 186 | 187 | return log 188 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | from torchvision import datasets, transforms 10 | from tqdm import tqdm 11 | 12 | import models 13 | 14 | 15 | class SegmentationFolder(datasets.ImageFolder): 16 | def __getitem__(self, index): 17 | path = self.imgs[index][0] 18 | sample = self.loader(path) 19 | height, width = sample.size[1], sample.size[0] 20 | 21 | if self.transform is not None: 22 | sample = self.transform(sample) 23 | return sample, path, height, width 24 | 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser(description='Inference') 28 | parser.add_argument('--nb_classes', type=int, default=50) 29 | parser.add_argument('--mode', 30 | type=str, 31 | required=True, 32 | help='validation or test', 33 | choices=['validation', 'test']) 34 | parser.add_argument('--output_dir', 35 | 
type=str, 36 | default=None, 37 | help='the path to save segmentation masks') 38 | parser.add_argument('--data_path', 39 | type=str, 40 | default=None, 41 | help='path to imagenetS dataset') 42 | parser.add_argument('--finetune', 43 | type=str, 44 | default=None, 45 | help='the model checkpoint file') 46 | parser.add_argument('--pretrained_rfnext', 47 | default='', 48 | help='pretrained weights for RF-Next') 49 | parser.add_argument('--model', 50 | default='vit_small_patch16', 51 | help='model architecture') 52 | parser.add_argument('--patch_size', 53 | type=int, 54 | default=4, 55 | help='For convnext/rfconvnext, the numnber of output channels is ' 56 | 'nb_classes * patch_size * patch_size.' 57 | 'https://arxiv.org/pdf/2111.06377.pdf') 58 | parser.add_argument( 59 | '--max_res', 60 | default=1000, 61 | type=int, 62 | help='Maximum resolution for evaluation. 0 for disable.') 63 | parser.add_argument('--method', 64 | default='example submission', 65 | help='Method name in method description file(.txt).') 66 | parser.add_argument('--train_data', 67 | default='null', 68 | help='Training data in method description file(.txt).') 69 | parser.add_argument( 70 | '--train_scheme', 71 | default='null', 72 | help='Training scheme in method description file(.txt), \ 73 | e.g., SSL, Sup, SSL+Sup.') 74 | parser.add_argument( 75 | '--link', 76 | default='null', 77 | help='Paper/project link in method description file(.txt).') 78 | parser.add_argument( 79 | '--description', 80 | default='null', 81 | help='Method description in method description file(.txt).') 82 | args = parser.parse_args() 83 | return args 84 | 85 | 86 | def main(): 87 | args = parse_args() 88 | 89 | # build model 90 | model = models.__dict__[args.model](args) 91 | model = model.cuda() 92 | model.eval() 93 | 94 | # load checkpoints 95 | checkpoint = torch.load(args.finetune)['model'] 96 | model.load_state_dict(checkpoint, strict=True) 97 | # build the dataloader 98 | dataset_path = os.path.join(args.data_path, args.mode) 99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 100 | std=[0.229, 0.224, 0.225]) 101 | dataset = SegmentationFolder(root=dataset_path, 102 | transform=transforms.Compose([ 103 | transforms.Resize(256), 104 | transforms.ToTensor(), 105 | normalize, 106 | ])) 107 | dataloader = torch.utils.data.DataLoader(dataset, 108 | batch_size=1, 109 | num_workers=16, 110 | pin_memory=True) 111 | 112 | output_dir = os.path.join(args.output_dir, args.mode) 113 | 114 | for images, path, height, width in tqdm(dataloader): 115 | path = path[0] 116 | cate = path.split('/')[-2] 117 | name = path.split('/')[-1].split('.')[0] 118 | if not os.path.exists(os.path.join(output_dir, cate)): 119 | os.makedirs(os.path.join(output_dir, cate)) 120 | 121 | with torch.no_grad(): 122 | H = height.item() 123 | W = width.item() 124 | 125 | output = model.forward(images.cuda()) 126 | 127 | if (H > W and H * W > args.max_res * args.max_res 128 | and args.max_res > 0): 129 | output = F.interpolate( 130 | output, (args.max_res, int(args.max_res * W / H)), 131 | mode='bilinear', 132 | align_corners=False) 133 | output = torch.argmax(output, dim=1, keepdim=True) 134 | output = F.interpolate(output.float(), (H, W), 135 | mode='nearest').long() 136 | elif (H <= W and H * W > args.max_res * args.max_res 137 | and args.max_res > 0): 138 | output = F.interpolate( 139 | output, (int(args.max_res * H / W), args.max_res), 140 | mode='bilinear', 141 | align_corners=False) 142 | output = torch.argmax(output, dim=1, keepdim=True) 143 | output = 
F.interpolate(output.float(), (H, W), 144 | mode='nearest').long() 145 | else: 146 | output = F.interpolate(output, (H, W), 147 | mode='bilinear', 148 | align_corners=False) 149 | output = torch.argmax(output, dim=1, keepdim=True) 150 | output = output.squeeze() 151 | 152 | res = torch.zeros(size=(output.shape[0], output.shape[1], 3)) 153 | res[:, :, 0] = output % 256 154 | res[:, :, 1] = output // 256 155 | res = res.cpu().numpy() 156 | 157 | res = Image.fromarray(res.astype(np.uint8)) 158 | res.save(os.path.join(output_dir, cate, name + '.png')) 159 | 160 | if args.mode == 'test': 161 | method = 'Method name: {}\n'.format( 162 | args.method) + \ 163 | 'Training data: {}\nTraining scheme: {}\n'.format( 164 | args.train_data, args.train_scheme) + \ 165 | 'Networks: {}\nPaper/Project link: {}\n'.format( 166 | args.model, args.link) + \ 167 | 'Method description: {}'.format( 168 | args.description) 169 | with open(os.path.join(output_dir, 'method.txt'), 'w') as f: 170 | f.write(method) 171 | 172 | # zip for submission 173 | shutil.make_archive(os.path.join(args.output_dir, args.mode), 174 | 'zip', 175 | root_dir=output_dir) 176 | 177 | 178 | if __name__ == '__main__': 179 | main() 180 | -------------------------------------------------------------------------------- /main_segfinetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import numpy as np 20 | import timm 21 | import torch 22 | import torch.backends.cudnn as cudnn 23 | from timm.models.layers import trunc_normal_ 24 | 25 | import models 26 | import util.lr_decay as lrd 27 | import util.misc as misc 28 | from engine_segfinetune import evaluate, train_one_epoch 29 | from util.datasets import build_dataset 30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 31 | from util.pos_embed import interpolate_pos_embed 32 | from timm.models.convnext import checkpoint_filter_fn 33 | 34 | 35 | def get_args_parser(): 36 | parser = argparse.ArgumentParser( 37 | 'Semi-supervised fine-tuning for ' 38 | 'the semantic segmentation on the ImageNet-S dataset', 39 | add_help=False) 40 | parser.add_argument( 41 | '--batch_size', 42 | default=64, 43 | type=int, 44 | help='Batch size per GPU ' 45 | '(effective batch size is batch_size * accum_iter * # gpus') 46 | parser.add_argument('--epochs', default=50, type=int) 47 | parser.add_argument( 48 | '--accum_iter', 49 | default=1, 50 | type=int, 51 | help='Accumulate gradient iterations ' 52 | '(for increasing the effective batch size under memory constraints)') 53 | parser.add_argument('--saveckp_freq', 54 | default=20, 55 | type=int, 56 | help='Save checkpoint every x epochs.') 57 | parser.add_argument('--eval_freq', 58 | default=20, 59 | type=int, 60 | help='Evaluate the model every x epochs.') 61 | parser.add_argument( 62 | '--max_res', 63 | default=1000, 64 | type=int, 65 | help='Maximum resolution for evaluation. 
0 for disable.') 66 | 67 | # Model parameters 68 | parser.add_argument('--model', 69 | default='vit_small_patch16', 70 | type=str, 71 | metavar='MODEL', 72 | help='Name of model to train') 73 | parser.add_argument('--drop_path', 74 | type=float, 75 | default=0.1, 76 | metavar='PCT', 77 | help='Drop path rate (default: 0.1)') 78 | parser.add_argument('--patch_size', 79 | type=int, 80 | default=4, 81 | help='For convnext/rfconvnext, the numnber of output channels is ' 82 | 'nb_classes * patch_size * patch_size.' 83 | 'https://arxiv.org/pdf/2111.06377.pdf') 84 | 85 | # Optimizer parameters 86 | parser.add_argument('--clip_grad', 87 | type=float, 88 | default=None, 89 | metavar='NORM', 90 | help='Clip gradient norm (default: None, no clipping)') 91 | parser.add_argument('--weight_decay', 92 | type=float, 93 | default=0.05, 94 | help='weight decay (default: 0.05)') 95 | 96 | parser.add_argument('--lr', 97 | type=float, 98 | default=None, 99 | metavar='LR', 100 | help='learning rate (absolute lr)') 101 | parser.add_argument('--blr', 102 | type=float, 103 | default=1e-3, 104 | metavar='LR', 105 | help='base learning rate: ' 106 | 'absolute_lr = base_lr * total_batch_size / 256') 107 | parser.add_argument('--layer_decay', 108 | type=float, 109 | default=[0.75], 110 | nargs="+", 111 | help='layer-wise lr decay from ELECTRA/BEiT.' 112 | 'For each layer, the function get_layer_id in utils.lr_decay ' 113 | 'returns (layer_group, layer_id). ' 114 | 'According to the layer_group, different parameters are grouped, ' 115 | 'and the layer_decay[layer_group] is used as the decay rate for different groups.') 116 | parser.add_argument('--layer_multiplier', 117 | type=float, 118 | default=[1.0], 119 | nargs="+", 120 | help='The learning rate multipliers for different layers. ' 121 | 'For each layer, the function get_layer_id in utils.lr_decay ' 122 | 'returns (layer_group, layer_id). 
' 123 | 'According to the layer_group, different parameters are grouped, ' 124 | 'and the learning rate of each group is lr = lr * layer_multiplier[layer_group].') 125 | parser.add_argument('--min_lr', 126 | type=float, 127 | default=1e-6, 128 | metavar='LR', 129 | help='lower lr bound for cyclic schedulers that hit 0') 130 | parser.add_argument('--warmup_epochs', 131 | type=int, 132 | default=5, 133 | metavar='N', 134 | help='epochs to warmup LR') 135 | 136 | # Augmentation parameters 137 | parser.add_argument( 138 | '--color_jitter', 139 | type=float, 140 | default=None, 141 | metavar='PCT', 142 | help='Color jitter factor (enabled only when not using Auto/RandAug)') 143 | 144 | # * Finetuning params 145 | parser.add_argument('--finetune', 146 | default='', 147 | help='finetune from checkpoint') 148 | parser.add_argument('--pretrained_rfnext', 149 | default='', 150 | help='pretrained weights for RF-Next') 151 | 152 | # Dataset parameters 153 | parser.add_argument('--data_path', 154 | default='/datasets01/imagenet_full_size/061417/', 155 | type=str, 156 | help='dataset path') 157 | parser.add_argument('--iteration_one_epoch', 158 | default=-1, 159 | type=int, 160 | help='number of iterations in one epoch') 161 | parser.add_argument('--nb_classes', 162 | default=1000, 163 | type=int, 164 | help='number of the classification types') 165 | 166 | parser.add_argument('--output_dir', 167 | default=None, 168 | help='path where to save, empty for no saving') 169 | parser.add_argument('--device', 170 | default='cuda', 171 | help='device to use for training / testing') 172 | parser.add_argument('--seed', default=0, type=int) 173 | parser.add_argument('--resume', default='', help='resume from checkpoint') 174 | 175 | parser.add_argument('--start_epoch', 176 | default=0, 177 | type=int, 178 | metavar='N', 179 | help='start epoch') 180 | parser.add_argument('--eval', 181 | action='store_true', 182 | help='Perform evaluation only') 183 | parser.add_argument('--dist_eval', 184 | action='store_true', 185 | default=False, 186 | help='Enabling distributed evaluation ' 187 | '(recommended during training for faster monitor') 188 | parser.add_argument('--num_workers', default=10, type=int) 189 | # distributed training parameters 190 | parser.add_argument('--world_size', 191 | default=1, 192 | type=int, 193 | help='number of distributed processes') 194 | parser.add_argument('--local_rank', default=-1, type=int) 195 | parser.add_argument('--dist_on_itp', action='store_true') 196 | parser.add_argument('--dist_url', 197 | default='env://', 198 | help='url used to set up distributed training') 199 | 200 | return parser 201 | 202 | 203 | def main(args): 204 | misc.init_distributed_mode(args) 205 | 206 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 207 | print('{}'.format(args).replace(', ', ',\n')) 208 | 209 | device = torch.device(args.device) 210 | 211 | # fix the seed for reproducibility 212 | seed = args.seed + misc.get_rank() 213 | torch.manual_seed(seed) 214 | np.random.seed(seed) 215 | 216 | cudnn.benchmark = True 217 | 218 | dataset_train = build_dataset(is_train=True, args=args) 219 | dataset_val = build_dataset(is_train=False, args=args) 220 | 221 | if True: # args.distributed: 222 | num_tasks = misc.get_world_size() 223 | global_rank = misc.get_rank() 224 | sampler_train = torch.utils.data.DistributedSampler( 225 | dataset_train, 226 | num_replicas=num_tasks, 227 | rank=global_rank, 228 | shuffle=True) 229 | print('Sampler_train = %s' % str(sampler_train)) 230 | if 
args.dist_eval: 231 | if len(dataset_val) % num_tasks != 0: 232 | print( 233 | 'Warning: Enabling distributed evaluation ' 234 | 'with an eval dataset not divisible by process number. ' 235 | 'This will slightly alter validation ' 236 | 'results as extra duplicate entries are added to achieve ' 237 | 'equal num of samples per-process.') 238 | sampler_val = torch.utils.data.DistributedSampler( 239 | dataset_val, 240 | num_replicas=num_tasks, 241 | rank=global_rank, 242 | shuffle=True) # shuffle=True to reduce monitor bias 243 | else: 244 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 245 | else: 246 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 247 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 248 | 249 | data_loader_train = torch.utils.data.DataLoader( 250 | dataset_train, 251 | sampler=sampler_train, 252 | batch_size=args.batch_size, 253 | num_workers=args.num_workers, 254 | pin_memory=True, 255 | drop_last=False, 256 | ) 257 | 258 | data_loader_val = torch.utils.data.DataLoader(dataset_val, 259 | sampler=sampler_val, 260 | batch_size=1, 261 | num_workers=args.num_workers, 262 | pin_memory=True, 263 | drop_last=False) 264 | args.iteration_one_epoch = len(data_loader_train) 265 | model = models.__dict__[args.model](args) 266 | 267 | if args.finetune and not args.eval: 268 | checkpoint = torch.load(args.finetune, map_location='cpu') 269 | print('Load pre-trained checkpoint from: %s' % args.finetune) 270 | if 'model' in checkpoint: 271 | checkpoint = checkpoint['model'] 272 | elif 'state_dict' in checkpoint: 273 | checkpoint = checkpoint['state_dict'] 274 | checkpoint = { 275 | k.replace('module.', ''): v 276 | for k, v in checkpoint.items() 277 | } 278 | checkpoint = { 279 | k.replace('backbone.', ''): v 280 | for k, v in checkpoint.items() 281 | } 282 | 283 | for k in ['head.weight', 'head.bias']: 284 | if k in checkpoint.keys(): 285 | print(f'Removing key {k} from pretrained checkpoint') 286 | del checkpoint[k] 287 | 288 | if 'vit' in args.model: 289 | # interpolate position embedding 290 | interpolate_pos_embed(model, checkpoint) 291 | elif 'convnext' in args.model: 292 | checkpoint = checkpoint_filter_fn(checkpoint, model) 293 | 294 | # load pre-trained model 295 | msg = model.load_state_dict(checkpoint, strict=False) 296 | print('Missing: {}'.format(msg.missing_keys)) 297 | 298 | model.to(device) 299 | 300 | model_without_ddp = model 301 | n_parameters = sum(p.numel() for p in model.parameters() 302 | if p.requires_grad) 303 | 304 | print('Model = %s' % str(model_without_ddp)) 305 | print('number of params (M): %.2f' % (n_parameters / 1.e6)) 306 | 307 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 308 | 309 | if args.lr is None: # only base_lr is specified 310 | args.lr = args.blr * eff_batch_size / 256 311 | 312 | print('base lr: %.2e' % (args.lr * 256 / eff_batch_size)) 313 | print('actual lr: %.2e' % args.lr) 314 | 315 | print('accumulate grad iterations: %d' % args.accum_iter) 316 | print('effective batch size: %d' % eff_batch_size) 317 | 318 | if args.distributed: 319 | model = torch.nn.parallel.DistributedDataParallel( 320 | model, device_ids=[args.gpu], find_unused_parameters=True) 321 | model_without_ddp = model.module 322 | 323 | # build optimizer with layer-wise lr decay (lrd) 324 | param_groups = lrd.param_groups_lrd( 325 | model_without_ddp, 326 | args.weight_decay, 327 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 328 | layer_decay=args.layer_decay, 329 | 
layer_multiplier=args.layer_multiplier) 330 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 331 | loss_scaler = NativeScaler() 332 | criterion = torch.nn.CrossEntropyLoss( 333 | ignore_index=1000) # 1000 denotes the ignored region in ImageNet-S. 334 | print('criterion = %s' % str(criterion)) 335 | 336 | misc.load_model(args=args, 337 | model_without_ddp=model_without_ddp, 338 | optimizer=optimizer, 339 | loss_scaler=loss_scaler) 340 | 341 | if args.eval: 342 | test_stats = evaluate(data_loader_val, model, device, args.nb_classes) 343 | print(f'mIoU of the network on the {len(dataset_val)} ' 344 | f"test images: {test_stats['mIoU']:.1f}%") 345 | if len(dataset_val) % num_tasks != 0: 346 | print('Warning: Enabling distributed evaluation ' 347 | 'with an eval dataset not divisible by process number. ' 348 | 'This will slightly alter validation ' 349 | 'results as extra duplicate entries are added to achieve ' 350 | 'equal num of samples per-process.') 351 | exit(0) 352 | 353 | print(f'Start training for {args.epochs} epochs') 354 | start_time = time.time() 355 | max_accuracy = 0.0 356 | for epoch in range(args.start_epoch, args.epochs): 357 | if args.distributed: 358 | data_loader_train.sampler.set_epoch(epoch) 359 | train_stats = train_one_epoch(model, 360 | criterion, 361 | data_loader_train, 362 | optimizer, 363 | device, 364 | epoch, 365 | loss_scaler, 366 | args.clip_grad, 367 | args=args) 368 | if args.output_dir and (epoch + 1) % args.saveckp_freq == 0: 369 | misc.save_model(args=args, 370 | model=model, 371 | model_without_ddp=model_without_ddp, 372 | optimizer=optimizer, 373 | loss_scaler=loss_scaler, 374 | epoch=epoch) 375 | 376 | if (epoch + 1) % args.eval_freq == 0 or epoch == 0: 377 | test_stats = evaluate(data_loader_val, 378 | model, 379 | device, 380 | args.nb_classes, 381 | max_res=args.max_res) 382 | print(f'mIoU of the network on the {len(dataset_val)} ' 383 | f"test images: {test_stats['mIoU']:.3f}%") 384 | if len(dataset_val) % num_tasks != 0: 385 | print('Warning: Enabling distributed evaluation ' 386 | 'with an eval dataset not divisible by process number. 
' 387 | 'This will slightly alter validation ' 388 | 'results as extra duplicate entries are added to achieve ' 389 | 'equal num of samples per-process.') 390 | max_accuracy = max(max_accuracy, test_stats['mIoU']) 391 | print(f'Max mIoU: {max_accuracy:.2f}%') 392 | 393 | log_stats = { 394 | **{f'train_{k}': v 395 | for k, v in train_stats.items()}, 396 | **{f'test_{k}': v 397 | for k, v in test_stats.items()}, 'epoch': epoch, 398 | 'n_parameters': n_parameters 399 | } 400 | 401 | if args.output_dir and misc.is_main_process(): 402 | with open(os.path.join(args.output_dir, 'log.txt'), 403 | mode='a', 404 | encoding='utf-8') as f: 405 | f.write(json.dumps(log_stats) + '\n') 406 | 407 | total_time = time.time() - start_time 408 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 409 | print('Training time {}'.format(total_time_str)) 410 | 411 | 412 | if __name__ == '__main__': 413 | args = get_args_parser() 414 | args = args.parse_args() 415 | if args.output_dir: 416 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 417 | main(args) 418 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models_resnet import resnet18, resnet50, resnet50_d16 2 | from .models_vit import vit_base_patch16, vit_small_patch16 3 | from .models_convnext import convnext_tiny 4 | from .models_rfconvnext import rfconvnext_tiny_rfmerge, rfconvnext_tiny_rfmultiple, rfconvnext_tiny_rfsearch, rfconvnext_tiny_rfsingle 5 | -------------------------------------------------------------------------------- /models/models_convnext.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import timm.models.convnext 4 | from collections import OrderedDict 5 | import torch 6 | import torch.nn as nn 7 | from timm.models.layers import trunc_normal_ 8 | 9 | 10 | class ConvNeXt(timm.models.convnext.ConvNeXt): 11 | """Vision Transformer with support for semantic seg.""" 12 | def __init__(self, patch_size=4, **kwargs): 13 | norm_layer = kwargs.pop('norm_layer') 14 | super(ConvNeXt, self).__init__(**kwargs) 15 | assert self.num_classes > 0 16 | 17 | del self.head 18 | del self.norm_pre 19 | 20 | self.patch_size = patch_size 21 | self.depths = kwargs['depths'] 22 | self.num_layers = sum(self.depths) + len(self.depths) 23 | self.rf_change = [] 24 | 25 | self.seg_norm = norm_layer(self.num_features) 26 | self.seg_head = nn.Sequential(OrderedDict([ 27 | ('drop', nn.Dropout(self.drop_rate)), 28 | ('fc', nn.Conv2d(self.num_features, self.num_classes * (self.patch_size**2), 1)) 29 | ])) 30 | 31 | trunc_normal_(self.seg_head[1].weight, std=.02) 32 | torch.nn.init.zeros_(self.seg_head[1].bias) 33 | 34 | @torch.jit.ignore 35 | def no_weight_decay(self): 36 | return dict() 37 | 38 | def forward_features(self, x): 39 | x = self.stem(x) 40 | x = self.stages(x) 41 | b, c, h, w = x.shape 42 | x = x.view(b, c, -1).permute(0, 2, 1) 43 | x = self.seg_norm(x) 44 | x = x.permute(0, 2, 1).view(b, c, h, w) 45 | return x 46 | 47 | def forward_head(self, x): 48 | x = self.seg_head.drop(x) 49 | x = self.seg_head.fc(x) 50 | b, _, h, w = x.shape 51 | x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, self.patch_size, self.patch_size, self.num_classes) 52 | x = torch.einsum('nhwpqc->nchpwq', x) 53 | x = x.contiguous().view(b, self.num_classes, h * self.patch_size, w * self.patch_size) 54 | return x 55 | 56 | def get_layer_id(self, 
name): 57 | """ 58 | Assign a parameter with its layer id for layer-wise decay. 59 | 60 | For each layer, the get_layer_id returns (layer_group, layer_id). 61 | According to the layer_group, different parameters are grouped, 62 | and layers in different groups use different decay rates. 63 | 64 | If only the layer_id is returned, the layer_group are set to 0 by default. 65 | """ 66 | if name in ("cls_token", "mask_token", "pos_embed"): 67 | return (0, 0) 68 | elif name.startswith("stem"): 69 | return (0, 0) 70 | elif name.startswith("stages") and 'downsample' in name: 71 | stage_id = int(name.split('.')[1]) 72 | if stage_id == 0: 73 | layer_id = 0 74 | else: 75 | layer_id = sum(self.depths[:stage_id]) + stage_id 76 | return (0, layer_id) 77 | elif name.startswith("stages") and 'downsample' not in name: 78 | stage_id = int(name.split('.')[1]) 79 | block_id = int(name.split('.')[3]) 80 | if stage_id == 0: 81 | layer_id = block_id + 1 82 | else: 83 | layer_id = sum(self.depths[:stage_id]) + stage_id + block_id + 1 84 | return (0, layer_id) 85 | else: 86 | return (0, self.num_layers) 87 | 88 | 89 | def convnext_tiny(args): 90 | model = ConvNeXt( 91 | depths=(3, 3, 9, 3), 92 | dims=(96, 192, 384, 768), 93 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 94 | num_classes=getattr(args, 'nb_classes', 920), 95 | drop_path_rate=getattr(args, 'drop_path', 0), 96 | patch_size=getattr(args, 'patch_size', 4) 97 | ) 98 | return model 99 | -------------------------------------------------------------------------------- /models/models_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 7 | """3x3 convolution with padding""" 8 | return nn.Conv2d( 9 | in_planes, 10 | out_planes, 11 | kernel_size=3, 12 | stride=stride, 13 | padding=dilation, 14 | groups=groups, 15 | bias=False, 16 | dilation=dilation, 17 | ) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | __constants__ = ["downsample"] 28 | 29 | def __init__( 30 | self, 31 | inplanes, 32 | planes, 33 | stride=1, 34 | downsample=None, 35 | groups=1, 36 | base_width=64, 37 | dilation=1, 38 | norm_layer=None, 39 | ): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError("BasicBlock only supports groups=1 and base_width=64") 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 
__constants__ = ["downsample"] 78 | 79 | def __init__( 80 | self, 81 | inplanes, 82 | planes, 83 | stride=1, 84 | downsample=None, 85 | groups=1, 86 | base_width=64, 87 | dilation=1, 88 | norm_layer=None, 89 | ): 90 | super(Bottleneck, self).__init__() 91 | if norm_layer is None: 92 | norm_layer = nn.BatchNorm2d 93 | width = int(planes * (base_width / 64.0)) * groups 94 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 95 | self.conv1 = conv1x1(inplanes, width) 96 | self.bn1 = norm_layer(width) 97 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 98 | self.bn2 = norm_layer(width) 99 | self.conv3 = conv1x1(width, planes * self.expansion) 100 | self.bn3 = norm_layer(planes * self.expansion) 101 | self.relu = nn.ReLU(inplace=True) 102 | self.downsample = downsample 103 | self.stride = stride 104 | 105 | def forward(self, x): 106 | identity = x 107 | 108 | out = self.conv1(x) 109 | out = self.bn1(out) 110 | out = self.relu(out) 111 | 112 | out = self.conv2(out) 113 | out = self.bn2(out) 114 | out = self.relu(out) 115 | 116 | out = self.conv3(out) 117 | out = self.bn3(out) 118 | 119 | if self.downsample is not None: 120 | identity = self.downsample(x) 121 | 122 | out += identity 123 | out = self.relu(out) 124 | 125 | return out 126 | 127 | 128 | class ResNet(nn.Module): 129 | def __init__( 130 | self, 131 | block, 132 | layers, 133 | zero_init_residual=False, 134 | groups=1, 135 | widen=1, 136 | width_per_group=64, 137 | replace_stride_with_dilation=None, 138 | norm_layer=None, 139 | eval_mode=False, 140 | num_classes=0, 141 | ): 142 | super(ResNet, self).__init__() 143 | if norm_layer is None: 144 | norm_layer = nn.BatchNorm2d 145 | self._norm_layer = norm_layer 146 | 147 | self.eval_mode = eval_mode 148 | self.padding = nn.ConstantPad2d(1, 0.0) 149 | 150 | self.inplanes = width_per_group * widen 151 | self.dilation = 1 152 | if replace_stride_with_dilation is None: 153 | # each element in the tuple indicates if we should replace 154 | # the 2x2 stride with a dilated convolution instead 155 | replace_stride_with_dilation = [False, False, False] 156 | if len(replace_stride_with_dilation) != 3: 157 | raise ValueError( 158 | "replace_stride_with_dilation should be None " 159 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation) 160 | ) 161 | self.groups = groups 162 | self.base_width = width_per_group 163 | self.layers = layers 164 | self.num_layers = sum(self.layers) + 1 165 | 166 | # change padding 3 -> 2 compared to original torchvision code because added a padding layer 167 | num_out_filters = width_per_group * widen 168 | self.conv1 = nn.Conv2d( 169 | 3, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False 170 | ) 171 | self.bn1 = norm_layer(num_out_filters) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 174 | self.layer1 = self._make_layer(block, num_out_filters, layers[0]) 175 | num_out_filters *= 2 176 | self.layer2 = self._make_layer( 177 | block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0] 178 | ) 179 | num_out_filters *= 2 180 | self.layer3 = self._make_layer( 181 | block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1] 182 | ) 183 | num_out_filters *= 2 184 | self.layer4 = self._make_layer( 185 | block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2] 186 | ) 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | 189 | mid_channels = 512 * block.expansion 190 
| # segmentation head and loss function 191 | self.head = nn.Conv2d(mid_channels, num_classes, 1, 1) 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 196 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 197 | nn.init.constant_(m.weight, 1) 198 | nn.init.constant_(m.bias, 0) 199 | 200 | # Zero-initialize the last BN in each residual branch, 201 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 202 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 203 | if zero_init_residual: 204 | for m in self.modules(): 205 | if isinstance(m, Bottleneck): 206 | nn.init.constant_(m.bn3.weight, 0) 207 | elif isinstance(m, BasicBlock): 208 | nn.init.constant_(m.bn2.weight, 0) 209 | 210 | @torch.jit.ignore 211 | def no_weight_decay(self): 212 | return dict() 213 | 214 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 215 | norm_layer = self._norm_layer 216 | downsample = None 217 | previous_dilation = self.dilation 218 | if dilate: 219 | self.dilation *= stride 220 | stride = 1 221 | if stride != 1 or self.inplanes != planes * block.expansion: 222 | downsample = nn.Sequential( 223 | conv1x1(self.inplanes, planes * block.expansion, stride), 224 | norm_layer(planes * block.expansion), 225 | ) 226 | 227 | layers = [] 228 | layers.append( 229 | block( 230 | self.inplanes, 231 | planes, 232 | stride, 233 | downsample, 234 | self.groups, 235 | self.base_width, 236 | previous_dilation, 237 | norm_layer, 238 | ) 239 | ) 240 | self.inplanes = planes * block.expansion 241 | for _ in range(1, blocks): 242 | layers.append( 243 | block( 244 | self.inplanes, 245 | planes, 246 | groups=self.groups, 247 | base_width=self.base_width, 248 | dilation=self.dilation, 249 | norm_layer=norm_layer, 250 | ) 251 | ) 252 | 253 | return nn.Sequential(*layers) 254 | 255 | def forward_backbone(self, x, pool=True): 256 | x = self.padding(x) 257 | x = self.conv1(x) 258 | x = self.bn1(x) 259 | x = self.relu(x) 260 | x = self.maxpool(x) 261 | x = self.layer1(x) 262 | x = self.layer2(x) 263 | x = self.layer3(x) 264 | x = self.layer4(x) 265 | 266 | return x 267 | 268 | def forward(self, inputs): 269 | 270 | out = self.forward_backbone(inputs, pool=False) 271 | out = self.head(out) 272 | out = F.interpolate(out, scale_factor=2, align_corners=False, mode='bilinear') 273 | 274 | return out 275 | 276 | def get_layer_id(self, name): 277 | """ 278 | Assign a parameter with its layer id for layer-wise decay. 279 | 280 | For each layer, the get_layer_id returns (layer_group, layer_id). 281 | According to the layer_group, different parameters are grouped, 282 | and layers in different groups use different decay rates. 283 | 284 | If only the layer_id is returned, the layer_group are set to 0 by default. 
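Example (illustrative addition, not from the original code; assumes the resnet50 layout layers=[3, 4, 6, 3]):

    'conv1.weight'          -> (0, 0)
    'layer2.0.conv2.weight' -> (0, 4)    # sum(layers[:1]) + 0 + 1
    'head.weight'           -> (0, 17)   # num_layers = sum(layers) + 1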
285 | """ 286 | 287 | if name.startswith('conv1'): 288 | return (0, 0) 289 | elif name.startswith('bn1'): 290 | return (0, 0) 291 | elif name.startswith('layer'): 292 | return (0, sum(self.layers[:int(name[5]) - 1]) + int(name[7]) + 1) 293 | else: 294 | return (0, self.num_layers) 295 | 296 | 297 | def resnet18(args): 298 | kwargs=dict(num_classes=args.nb_classes) 299 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 300 | 301 | 302 | def resnet50(args): 303 | kwargs=dict(num_classes=args.nb_classes) 304 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 305 | 306 | def resnet50_d16(args): 307 | kwargs=dict(num_classes=args.nb_classes) 308 | return ResNet(Bottleneck, [3, 4, 6, 3], replace_stride_with_dilation=[False, False, True], **kwargs) 309 | -------------------------------------------------------------------------------- /models/models_rfconvnext.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import models.rfconvnext as rfconvnext 4 | from collections import OrderedDict 5 | import torch 6 | import torch.nn as nn 7 | from timm.models.layers import trunc_normal_ 8 | 9 | 10 | class RFConvNeXt(rfconvnext.RFConvNeXt): 11 | """RFConvNeXt with support for semantic segmentation.""" 12 | def __init__(self, patch_size=4, **kwargs): 13 | norm_layer = kwargs.pop('norm_layer') 14 | super(RFConvNeXt, self).__init__(**kwargs) 15 | assert self.num_classes > 0 16 | 17 | del self.head 18 | del self.norm_pre 19 | 20 | self.patch_size = patch_size 21 | self.depths = kwargs['depths'] 22 | self.num_layers = sum(self.depths) + len(self.depths) 23 | # The layers whose dilation rates are changed in RF-Next. 24 | # These layers use different hyper-parameters in training. 25 | self.rf_change = [] 26 | self.rf_change_name = [] 27 | 28 | self.seg_norm = norm_layer(self.num_features) 29 | self.seg_head = nn.Sequential(OrderedDict([ 30 | ('drop', nn.Dropout(self.drop_rate)), 31 | ('fc', nn.Conv2d(self.num_features, self.num_classes * (self.patch_size**2), 1)) 32 | ])) 33 | 34 | trunc_normal_(self.seg_head[1].weight, std=.02) 35 | torch.nn.init.zeros_(self.seg_head[1].bias) 36 | 37 | self.get_kernel_size_changed() 38 | 39 | 40 | def get_kernel_size_changed(self): 41 | """ 42 | Collect the rfconvs whose dilation rates are changed. 43 | """ 44 | for i, stage in enumerate(self.stages): 45 | for j, block in enumerate(stage.blocks): 46 | if block.conv_dw.dilation[0] > 1 or block.conv_dw.kernel_size[0] > 13: 47 | self.rf_change_name.extend( 48 | [ 49 | 'stages.{}.blocks.{}.conv_dw.weight'.format(i, j), 50 | 'stages.{}.blocks.{}.conv_dw.bias'.format(i, j), 51 | 'stages.{}.blocks.{}.conv_dw.sample_weights'.format(i, j) 52 | ] 53 | ) 54 | self.rf_change.append(self.stages[i].blocks[j].conv_dw) 55 | 56 | def freeze(self): 57 | """ 58 | In rfmerge mode, 59 | we initialize the model with weights from rfmultiple and 60 | only finetune seg_norm, seg_head and the rfconvs whose dilation rates are changed. 61 | The other parts of the network are frozen during finetuning. 62 | 63 | Note that this freezing operation may not be required for other tasks.
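Illustrative sketch (added commentary, not from the original code) of the intended rfmerge usage, where rfconvnext_tiny_rfmerge() calls freeze() internally:

    model = rfconvnext_tiny_rfmerge(args)
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    # 'trainable' now only contains seg_norm.*, seg_head.* and the conv_dw
    # parameters listed in self.rf_change_name.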
64 | """ 65 | if len(self.rf_change_name) == 0: 66 | self.get_kernel_size_changed() 67 | # finetune the rfconvs whose dilate rates are changed 68 | for n, p in self.named_parameters(): 69 | p.requires_grad = True if n in self.rf_change_name else False 70 | # finetune the seg_norm, seg_head 71 | for n, p in self.seg_head.named_parameters(): 72 | p.requires_grad = True 73 | for n, p in self.seg_norm.named_parameters(): 74 | p.requires_grad = True 75 | 76 | @torch.jit.ignore 77 | def no_weight_decay(self): 78 | return dict() 79 | 80 | def forward_features(self, x): 81 | x = self.stem(x) 82 | x = self.stages(x) 83 | b, c, h, w = x.shape 84 | x = x.view(b, c, -1).permute(0, 2, 1) 85 | x = self.seg_norm(x) 86 | x = x.permute(0, 2, 1).view(b, c, h, w) 87 | return x 88 | 89 | def forward_head(self, x): 90 | x = self.seg_head.drop(x) 91 | x = self.seg_head.fc(x) 92 | b, _, h, w = x.shape 93 | x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, self.patch_size, self.patch_size, self.num_classes) 94 | x = torch.einsum('nhwpqc->nchpwq', x) 95 | x = x.contiguous().view(b, self.num_classes, h * self.patch_size, w * self.patch_size) 96 | return x 97 | 98 | def get_layer_id(self, name): 99 | """ 100 | Assign a parameter with its layer id for layer-wise decay. 101 | 102 | For each layer, the get_layer_id returns (layer_group, layer_id). 103 | According to the layer_group, different parameters are grouped, 104 | and layers in different groups use different decay rates. 105 | 106 | If only the layer_id is returned, the layer_group are set to 0 by default. 107 | """ 108 | if name in ("cls_token", "mask_token", "pos_embed"): 109 | return (0, 0) 110 | elif name.startswith("stem"): 111 | return (0, 0) 112 | elif name.startswith("stages") and 'downsample' in name: 113 | stage_id = int(name.split('.')[1]) 114 | if stage_id == 0: 115 | layer_id = 0 116 | else: 117 | layer_id = sum(self.depths[:stage_id]) + stage_id 118 | 119 | if name.endswith('sample_weights') or name in self.rf_change_name: 120 | return (1, layer_id) 121 | return (0, layer_id) 122 | elif name.startswith("stages") and 'downsample' not in name: 123 | stage_id = int(name.split('.')[1]) 124 | block_id = int(name.split('.')[3]) 125 | if stage_id == 0: 126 | layer_id = block_id + 1 127 | else: 128 | layer_id = sum(self.depths[:stage_id]) + stage_id + block_id + 1 129 | 130 | if name.endswith('sample_weights') or name in self.rf_change_name: 131 | return (1, layer_id) 132 | return (0, layer_id) 133 | else: 134 | return (0, self.num_layers) 135 | 136 | 137 | def rfconvnext_tiny_rfsearch(args): 138 | search_cfgs = dict( 139 | num_branches=3, 140 | expand_rate=0.5, 141 | max_dilation=None, 142 | min_dilation=1, 143 | init_weight=0.01, 144 | search_interval=getattr(args, 'iteration_one_epoch', 1250) * 10, # step every 10 epochs 145 | max_search_step=3, # search for 3 steps 146 | ) 147 | model = RFConvNeXt( 148 | depths=(3, 3, 9, 3), 149 | dims=(96, 192, 384, 768), 150 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 151 | rf_mode='rfsearch', 152 | search_cfgs=search_cfgs, 153 | num_classes=getattr(args, 'nb_classes', 920), 154 | drop_path_rate=getattr(args, 'drop_path', 0), 155 | pretrained_weights=getattr(args, 'pretrained_rfnext', None), 156 | patch_size=getattr(args, 'patch_size', 4) 157 | ) 158 | return model 159 | 160 | 161 | def rfconvnext_tiny_rfmultiple(args): 162 | search_cfgs = dict( 163 | num_branches=3, 164 | expand_rate=0.5, 165 | max_dilation=None, 166 | min_dilation=1, 167 | init_weight=0.01, 168 | ) 169 | model = RFConvNeXt( 170 | depths=(3, 
3, 9, 3), 171 | dims=(96, 192, 384, 768), 172 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 173 | rf_mode='rfmultiple', 174 | search_cfgs=search_cfgs, 175 | num_classes=getattr(args, 'nb_classes', 920), 176 | drop_path_rate=getattr(args, 'drop_path', 0), 177 | pretrained_weights=getattr(args, 'pretrained_rfnext', None), 178 | patch_size=getattr(args, 'patch_size', 4) 179 | ) 180 | return model 181 | 182 | def rfconvnext_tiny_rfsingle(args): 183 | search_cfgs = dict( 184 | num_branches=3, 185 | expand_rate=0.5, 186 | max_dilation=None, 187 | min_dilation=1, 188 | init_weight=0.01, 189 | ) 190 | model = RFConvNeXt( 191 | depths=(3, 3, 9, 3), 192 | dims=(96, 192, 384, 768), 193 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 194 | rf_mode='rfsingle', 195 | search_cfgs=search_cfgs, 196 | num_classes=getattr(args, 'nb_classes', 920), 197 | drop_path_rate=getattr(args, 'drop_path', 0), 198 | pretrained_weights=getattr(args, 'pretrained_rfnext', None), 199 | patch_size=getattr(args, 'patch_size', 4) 200 | ) 201 | return model 202 | 203 | def rfconvnext_tiny_rfmerge(args): 204 | search_cfgs = dict( 205 | num_branches=3, 206 | expand_rate=0.5, 207 | max_dilation=None, 208 | min_dilation=1, 209 | init_weight=0.01 210 | ) 211 | model = RFConvNeXt( 212 | depths=(3, 3, 9, 3), 213 | dims=(96, 192, 384, 768), 214 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 215 | rf_mode='rfmerge', 216 | search_cfgs=search_cfgs, 217 | num_classes=getattr(args, 'nb_classes', 920), 218 | drop_path_rate=getattr(args, 'drop_path', 0), 219 | pretrained_weights=getattr(args, 'pretrained_rfnext', None), 220 | patch_size=getattr(args, 'patch_size', 4) 221 | ) 222 | # freeze layers except for seg_norm, seg_head and the rfconvs whose dialtion rates are changed. 223 | model.freeze() 224 | return model 225 | -------------------------------------------------------------------------------- /models/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | from functools import partial 14 | 15 | import timm.models.vision_transformer 16 | import torch 17 | import torch.nn as nn 18 | from timm.models.layers import trunc_normal_ 19 | 20 | 21 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 22 | """Vision Transformer with support for semantic seg.""" 23 | def __init__(self, **kwargs): 24 | super(VisionTransformer, self).__init__(**kwargs) 25 | 26 | embed_dim = kwargs['embed_dim'] 27 | norm_layer = kwargs['norm_layer'] 28 | patch_size = kwargs['patch_size'] 29 | self.num_layers = len(self.blocks) + 1 30 | 31 | self.fc_norm = norm_layer(embed_dim) 32 | del self.norm 33 | 34 | self.patch_embed = PatchEmbed(img_size=3, 35 | patch_size=patch_size, 36 | in_chans=3, 37 | embed_dim=embed_dim) 38 | assert self.num_classes > 0 39 | self.head = nn.Conv2d(self.embed_dim, self.num_classes, 1) 40 | # manually initialize fc layer 41 | trunc_normal_(self.head.weight, std=2e-5) 42 | 43 | def forward_head(self, x): 44 | return self.head(x) 45 | 46 | def forward(self, x): 47 | x = self.forward_features(x) 48 | x = self.forward_head(x) 49 | return x 50 | 51 | def forward_features(self, x): 52 | B, _, w, h = x.shape 53 | x = self.patch_embed(x) 54 | 55 | cls_tokens = self.cls_token.expand(B, -1, -1) 56 | x = torch.cat((cls_tokens, x), dim=1) 57 | x = x + self.interpolate_pos_encoding(x, w, h) 58 | x = self.pos_drop(x) 59 | 60 | for blk in self.blocks: 61 | x = blk(x) 62 | 63 | x = x[:, 1:, :] 64 | x = self.fc_norm(x) 65 | b, _, c = x.shape 66 | ih, iw = w // self.patch_embed.patch_size, \ 67 | h // self.patch_embed.patch_size 68 | x = x.view(b, ih, iw, c).permute(0, 3, 1, 2).contiguous() 69 | 70 | return x 71 | 72 | def interpolate_pos_encoding(self, x, w, h): 73 | npatch = x.shape[1] - 1 74 | N = self.pos_embed.shape[1] - 1 75 | if npatch == N and w == h: 76 | return self.pos_embed 77 | class_pos_embed = self.pos_embed[:, 0] 78 | patch_pos_embed = self.pos_embed[:, 1:] 79 | dim = x.shape[-1] 80 | w0 = w // self.patch_embed.patch_size 81 | h0 = h // self.patch_embed.patch_size 82 | # we add a small number to avoid 83 | # floating point error in the interpolation 84 | # see discussion at https://github.com/facebookresearch/dino/issues/8 85 | w0, h0 = w0 + 0.1, h0 + 0.1 86 | patch_pos_embed = nn.functional.interpolate( 87 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), 88 | dim).permute(0, 3, 1, 2).contiguous(), 89 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), 90 | mode='bicubic', 91 | ) 92 | assert int(w0) == patch_pos_embed.shape[-2] and int( 93 | h0) == patch_pos_embed.shape[-1] 94 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).contiguous().view(1, -1, dim) 95 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), 96 | dim=1) 97 | 98 | def get_layer_id(self, name): 99 | """Assign a parameter with its layer id Following BEiT: https://github.com/ 100 | microsoft/unilm/blob/master/beit/optim_factory.py#L33. 101 | 102 | For each layer, the get_layer_id returns (layer_group, layer_id). 103 | According to the layer_group, different parameters are grouped, 104 | and layers in different groups use different decay rates. 
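Example (illustrative addition, not from the original code; assumes a 12-block model such as vit_base_patch16, so num_layers = 13):

    'cls_token'               -> (0, 0)
    'patch_embed.proj.weight' -> (0, 0)
    'blocks.3.mlp.fc1.weight' -> (0, 4)
    'head.weight'             -> (0, 13)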
105 | 106 | If only the layer_id is returned, the layer_group are set to 0 by default. 107 | """ 108 | if name in ['cls_token', 'pos_embed']: 109 | return (0, 0) 110 | elif name.startswith('patch_embed'): 111 | return (0, 0) 112 | elif name.startswith('blocks'): 113 | return (0, int(name.split('.')[1]) + 1) 114 | else: 115 | return (0, self.num_layers) 116 | 117 | 118 | class PatchEmbed(nn.Module): 119 | """Image to Patch Embedding.""" 120 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 121 | super().__init__() 122 | num_patches = (img_size // patch_size) * (img_size // patch_size) 123 | self.img_size = img_size 124 | self.patch_size = patch_size 125 | self.num_patches = num_patches 126 | 127 | self.proj = nn.Conv2d(in_chans, 128 | embed_dim, 129 | kernel_size=patch_size, 130 | stride=patch_size) 131 | 132 | def forward(self, x): 133 | B, C, H, W = x.shape 134 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() 135 | return x 136 | 137 | 138 | def vit_small_patch16(args): 139 | kwargs = dict( 140 | num_classes=args.nb_classes, 141 | drop_path_rate=getattr(args, 'drop_path', 0) 142 | ) 143 | model = VisionTransformer(patch_size=16, 144 | embed_dim=384, 145 | depth=12, 146 | num_heads=6, 147 | mlp_ratio=4, 148 | qkv_bias=True, 149 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 150 | **kwargs) 151 | return model 152 | 153 | 154 | def vit_base_patch16(args): 155 | kwargs = dict( 156 | num_classes=args.nb_classes, 157 | drop_path_rate=getattr(args, 'drop_path', 0) 158 | ) 159 | model = VisionTransformer(patch_size=16, 160 | embed_dim=768, 161 | depth=12, 162 | num_heads=12, 163 | mlp_ratio=4, 164 | qkv_bias=True, 165 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 166 | **kwargs) 167 | return model 168 | -------------------------------------------------------------------------------- /models/rfconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import collections.abc as container_abcs 5 | from itertools import repeat 6 | from timm.models.layers import get_padding 7 | 8 | 9 | def _ntuple(n): 10 | def parse(x): 11 | if isinstance(x, container_abcs.Iterable): 12 | return x 13 | return tuple(repeat(x, n)) 14 | return parse 15 | 16 | 17 | _pair = _ntuple(2) 18 | 19 | 20 | def value_crop(dilation, min_dilation, max_dilation): 21 | if min_dilation is not None: 22 | if dilation < min_dilation: 23 | dilation = min_dilation 24 | if max_dilation is not None: 25 | if dilation > max_dilation: 26 | dilation = max_dilation 27 | return dilation 28 | 29 | 30 | def rf_expand(dilation, expand_rate, num_branches, min_dilation=1, max_dilation=None): 31 | rate_list = [] 32 | assert num_branches>=2, "number of branches must >=2" 33 | delta_dilation0 = expand_rate * dilation[0] 34 | delta_dilation1 = expand_rate * dilation[1] 35 | for i in range(num_branches): 36 | rate_list.append( 37 | tuple([value_crop( 38 | int(round(dilation[0] - delta_dilation0 + (i) * 2 * delta_dilation0/(num_branches-1))), min_dilation, max_dilation), 39 | value_crop( 40 | int(round(dilation[1] - delta_dilation1 + (i) * 2 * delta_dilation1/(num_branches-1))), min_dilation, max_dilation) 41 | ]) 42 | ) 43 | 44 | unique_rate_list = list(set(rate_list)) 45 | unique_rate_list.sort(key=rate_list.index) 46 | return unique_rate_list 47 | 48 | 49 | class RFConv2d(nn.Conv2d): 50 | 51 | def __init__(self, 52 | in_channels, 53 | out_channels, 54 | kernel_size=1, 55 | stride=1, 56 | padding=0, 57 | 
dilation=1, 58 | groups=1, 59 | bias=True, 60 | padding_mode='zeros', 61 | num_branches=3, 62 | expand_rate=0.5, 63 | min_dilation=1, 64 | max_dilation=None, 65 | init_weight=0.01, 66 | search_interval=1250, 67 | max_search_step=0, 68 | rf_mode='rfsearch', 69 | pretrained=None 70 | ): 71 | if pretrained is not None and rf_mode == 'rfmerge': 72 | rates = pretrained['rates'] 73 | num_rates = pretrained['num_rates'] 74 | sample_weights = pretrained['sample_weights'] 75 | sample_weights = self.normlize(sample_weights[:num_rates.item()]) 76 | max_dliation_rate = rates[num_rates.item() - 1] 77 | if isinstance(kernel_size, int): 78 | kernel_size = [kernel_size, kernel_size] 79 | if isinstance(stride, int): 80 | stride = [stride, stride] 81 | new_kernel_size = ( 82 | kernel_size[0] + (max_dliation_rate[0].item() - 83 | 1) * (kernel_size[0] // 2) * 2, 84 | kernel_size[1] + (max_dliation_rate[1].item() - 1) * (kernel_size[1] // 2) * 2) 85 | # assign dilation to (1, 1) after merge 86 | new_dilation = (1, 1) 87 | new_padding = ( 88 | get_padding(new_kernel_size[0], stride[0], new_dilation[0]), 89 | get_padding(new_kernel_size[1], stride[1], new_dilation[1])) 90 | 91 | # merge weight of each branch 92 | old_weight = pretrained['weight'] 93 | new_weight = torch.zeros( 94 | size=(old_weight.shape[0], old_weight.shape[1], 95 | new_kernel_size[0], new_kernel_size[1]), 96 | dtype=old_weight.dtype) 97 | for r, rate in enumerate(rates[:num_rates.item()]): 98 | rate = (rate[0].item(), rate[1].item()) 99 | for i in range(- (kernel_size[0] // 2), kernel_size[0] // 2 + 1): 100 | for j in range(- (kernel_size[1] // 2), kernel_size[1] // 2 + 1): 101 | new_weight[:, :, 102 | new_kernel_size[0] // 2 - i * rate[0], 103 | new_kernel_size[1] // 2 - j * rate[1]] += \ 104 | old_weight[:, :, kernel_size[0] // 2 - i, 105 | kernel_size[1] // 2 - j] * sample_weights[r] 106 | 107 | kernel_size = new_kernel_size 108 | padding = new_padding 109 | dilation = new_dilation 110 | pretrained['rates'][0] = torch.FloatTensor([1, 1]) 111 | pretrained['num_rates'] = torch.IntTensor([1]) 112 | pretrained['weight'] = new_weight 113 | # re-initilize the sample_weights 114 | pretrained['sample_weights'] = pretrained['sample_weights'] * \ 115 | 0.0 + init_weight 116 | 117 | super(RFConv2d, self).__init__( 118 | in_channels, 119 | out_channels, 120 | kernel_size, 121 | stride, 122 | padding, 123 | dilation, 124 | groups, 125 | bias, 126 | padding_mode 127 | ) 128 | self.rf_mode = rf_mode 129 | self.pretrained = pretrained 130 | self.num_branches = num_branches 131 | self.max_dilation = max_dilation 132 | self.min_dilation = min_dilation 133 | self.expand_rate = expand_rate 134 | self.init_weight = init_weight 135 | self.search_interval = search_interval 136 | self.max_search_step = max_search_step 137 | self.sample_weights = nn.Parameter(torch.Tensor(self.num_branches)) 138 | self.register_buffer('counter', torch.zeros(1)) 139 | self.register_buffer('current_search_step', torch.zeros(1)) 140 | self.register_buffer('rates', torch.ones( 141 | size=(self.num_branches, 2), dtype=torch.int32)) 142 | self.register_buffer('num_rates', torch.ones(1, dtype=torch.int32)) 143 | self.rates[0] = torch.FloatTensor([self.dilation[0], self.dilation[1]]) 144 | self.sample_weights.data.fill_(self.init_weight) 145 | 146 | # rf-next 147 | if pretrained is not None: 148 | # load pretrained weights 149 | msg = self.load_state_dict(pretrained, strict=False) 150 | assert all([key in ['sample_weights', 'counter', 'current_search_step', 'rates', 'num_rates'] for key in 
msg.missing_keys]), \ 151 | 'Missing keys: {}'.format(msg.missing_keys) 152 | if self.rf_mode == 'rfsearch': 153 | self.estimate() 154 | self.expand() 155 | elif self.rf_mode == 'rfsingle': 156 | self.estimate() 157 | self.max_search_step = 0 158 | self.sample_weights.requires_grad = False 159 | elif self.rf_mode == 'rfmultiple': 160 | self.estimate() 161 | self.expand() 162 | # re-initilize the sample_weights 163 | self.sample_weights.data.fill_(self.init_weight) 164 | self.max_search_step = 0 165 | elif self.rf_mode == 'rfmerge': 166 | self.max_search_step = 0 167 | self.sample_weights.requires_grad = False 168 | else: 169 | raise NotImplementedError() 170 | 171 | if self.rf_mode in ['rfsingle', 'rfmerge']: 172 | assert self.num_rates.item() == 1 173 | 174 | def _conv_forward_dilation(self, input, dilation_rate): 175 | if self.padding_mode != 'zeros': 176 | return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 177 | self.weight, self.bias, self.stride, 178 | _pair(0), dilation_rate, self.groups) 179 | else: 180 | padding = ( 181 | dilation_rate[0] * (self.kernel_size[0] - 1) // 2, dilation_rate[1] * (self.kernel_size[1] - 1) // 2) 182 | return F.conv2d(input, self.weight, self.bias, self.stride, 183 | padding, dilation_rate, self.groups) 184 | 185 | def normlize(self, w): 186 | abs_w = torch.abs(w) 187 | norm_w = abs_w / torch.sum(abs_w) 188 | return norm_w 189 | 190 | def forward(self, x): 191 | if self.num_rates.item() == 1: 192 | return super().forward(x) 193 | else: 194 | norm_w = self.normlize(self.sample_weights[:self.num_rates.item()]) 195 | xx = [ 196 | self._conv_forward_dilation( 197 | x, (self.rates[i][0].item(), self.rates[i][1].item())) 198 | * norm_w[i] for i in range(self.num_rates.item()) 199 | ] 200 | x = xx[0] 201 | for i in range(1, self.num_rates.item()): 202 | x += xx[i] 203 | if self.training: 204 | self.searcher() 205 | return x 206 | 207 | def searcher(self): 208 | self.counter += 1 209 | if self.counter % self.search_interval == 0 and self.current_search_step < self.max_search_step and self.max_search_step != 0: 210 | self.counter[0] = 0 211 | self.current_search_step += 1 212 | self.estimate() 213 | self.expand() 214 | 215 | def tensor_to_tuple(self, tensor): 216 | return tuple([(x[0].item(), x[1].item()) for x in tensor]) 217 | 218 | def estimate(self): 219 | norm_w = self.normlize(self.sample_weights[:self.num_rates.item()]) 220 | print('Estimate dilation {} with weight {}.'.format( 221 | self.tensor_to_tuple(self.rates[:self.num_rates.item()]), norm_w.detach().cpu().numpy().tolist())) 222 | 223 | sum0, sum1, w_sum = 0, 0, 0 224 | for i in range(self.num_rates.item()): 225 | sum0 += norm_w[i].item() * self.rates[i][0].item() 226 | sum1 += norm_w[i].item() * self.rates[i][1].item() 227 | w_sum += norm_w[i].item() 228 | estimated = [value_crop( 229 | int(round(sum0 / w_sum)), 230 | self.min_dilation, 231 | self.max_dilation), value_crop( 232 | int(round(sum1 / w_sum)), 233 | self.min_dilation, 234 | self.max_dilation)] 235 | self.dilation = tuple(estimated) 236 | self.padding = ( 237 | get_padding(self.kernel_size[0], self.stride[0], self.dilation[0]), 238 | get_padding(self.kernel_size[1], self.stride[1], self.dilation[1]) 239 | ) 240 | self.rates[0] = torch.FloatTensor([self.dilation[0], self.dilation[1]]) 241 | self.num_rates[0] = 1 242 | print('Estimate as {}'.format(self.dilation)) 243 | 244 | def expand(self): 245 | rates = rf_expand(self.dilation, self.expand_rate, 246 | self.num_branches, 247 | 
min_dilation=self.min_dilation, 248 | max_dilation=self.max_dilation) 249 | for i, rate in enumerate(rates): 250 | self.rates[i] = torch.FloatTensor([rate[0], rate[1]]) 251 | self.num_rates[0] = len(rates) 252 | self.sample_weights.data.fill_(self.init_weight) 253 | print('Expand as {}'.format(self.rates[:len(rates)].cpu().tolist())) 254 | -------------------------------------------------------------------------------- /models/rfconvnext.py: -------------------------------------------------------------------------------- 1 | """ RFConvNeXt 2 | Paper: RF-Next: Efficient Receptive Field Search for Convolutional Neural Networks 3 | https://arxiv.org/abs/2206.06637 4 | 5 | Modified from https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/convnext.py 6 | """ 7 | from collections import OrderedDict 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq 15 | from timm.models.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\ 16 | create_conv2d, make_divisible, get_padding 17 | from .rfconv import RFConv2d 18 | import os 19 | 20 | __all__ = ['RFConvNeXt'] 21 | 22 | 23 | def _cfg(url='', **kwargs): 24 | return { 25 | 'url': url, 26 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 27 | 'crop_pct': 0.875, 'interpolation': 'bicubic', 28 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 29 | 'first_conv': 'stem.0', 'classifier': 'head.fc', 30 | **kwargs 31 | } 32 | 33 | 34 | default_cfgs = dict( 35 | convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"), 36 | convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"), 37 | convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"), 38 | convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), 39 | 40 | # timm specific variants 41 | convnext_atto=_cfg(url=''), 42 | convnext_atto_ols=_cfg(url=''), 43 | convnext_femto=_cfg( 44 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', 45 | test_input_size=(3, 288, 288), test_crop_pct=0.95), 46 | convnext_femto_ols=_cfg( 47 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', 48 | test_input_size=(3, 288, 288), test_crop_pct=0.95), 49 | convnext_pico=_cfg( 50 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', 51 | test_input_size=(3, 288, 288), test_crop_pct=0.95), 52 | convnext_pico_ols=_cfg( 53 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', 54 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), 55 | convnext_nano=_cfg( 56 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', 57 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), 58 | convnext_nano_ols=_cfg( 59 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', 60 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), 61 | 
convnext_tiny_hnf=_cfg( 62 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', 63 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), 64 | 65 | convnext_tiny_in22ft1k=_cfg( 66 | url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'), 67 | convnext_small_in22ft1k=_cfg( 68 | url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'), 69 | convnext_base_in22ft1k=_cfg( 70 | url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), 71 | convnext_large_in22ft1k=_cfg( 72 | url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'), 73 | convnext_xlarge_in22ft1k=_cfg( 74 | url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'), 75 | 76 | convnext_tiny_384_in22ft1k=_cfg( 77 | url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', 78 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), 79 | convnext_small_384_in22ft1k=_cfg( 80 | url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', 81 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), 82 | convnext_base_384_in22ft1k=_cfg( 83 | url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', 84 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), 85 | convnext_large_384_in22ft1k=_cfg( 86 | url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', 87 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), 88 | convnext_xlarge_384_in22ft1k=_cfg( 89 | url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', 90 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), 91 | 92 | convnext_tiny_in22k=_cfg( 93 | url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), 94 | convnext_small_in22k=_cfg( 95 | url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), 96 | convnext_base_in22k=_cfg( 97 | url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), 98 | convnext_large_in22k=_cfg( 99 | url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), 100 | convnext_xlarge_in22k=_cfg( 101 | url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), 102 | ) 103 | 104 | 105 | default_search_cfg = dict( 106 | num_branches=3, 107 | expand_rate=0.5, 108 | max_dilation=None, 109 | min_dilation=1, 110 | init_weight=0.01, 111 | search_interval=1250, 112 | max_search_step=0, 113 | ) 114 | 115 | 116 | class RFConvNeXtBlock(nn.Module): 117 | """ ConvNeXt Block 118 | There are two equivalent implementations: 119 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 120 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 121 | 122 | Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate 123 | choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear 124 | is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. 125 | 126 | Args: 127 | dim (int): Number of input channels. 128 | drop_path (float): Stochastic depth rate. Default: 0.0 129 | ls_init_value (float): Init value for Layer Scale. Default: 1e-6. 
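Example (illustrative sketch, added commentary; assumes the default search_cfgs and no pretrained RF-Next weights):

    import torch
    blk = RFConvNeXtBlock(dim=96, drop_path=0.)
    y = blk(torch.randn(2, 96, 56, 56))   # spatial size preserved: (2, 96, 56, 56)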
130 | """ 131 | 132 | def __init__( 133 | self, 134 | dim, 135 | dim_out=None, 136 | stride=1, 137 | dilation=1, 138 | mlp_ratio=4, 139 | conv_mlp=False, 140 | conv_bias=True, 141 | ls_init_value=1e-6, 142 | norm_layer=None, 143 | act_layer=nn.GELU, 144 | drop_path=0., 145 | search_cfgs=default_search_cfg 146 | ): 147 | super().__init__() 148 | dim_out = dim_out or dim 149 | if not norm_layer: 150 | norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) 151 | mlp_layer = ConvMlp if conv_mlp else Mlp 152 | self.use_conv_mlp = conv_mlp 153 | 154 | # replace dwconv with rfconv 155 | self.conv_dw = RFConv2d( 156 | in_channels=dim, 157 | out_channels=dim_out, 158 | kernel_size=7, 159 | stride=stride, 160 | padding=get_padding(kernel_size=7, stride=stride, dilation=dilation), 161 | dilation=dilation, 162 | groups=dim, 163 | bias=conv_bias, 164 | **search_cfgs) 165 | self.norm = norm_layer(dim_out) 166 | self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer) 167 | self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None 168 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 169 | 170 | def forward(self, x): 171 | shortcut = x 172 | x = self.conv_dw(x) 173 | if self.use_conv_mlp: 174 | x = self.norm(x) 175 | x = self.mlp(x) 176 | else: 177 | x = x.permute(0, 2, 3, 1) 178 | x = self.norm(x) 179 | x = self.mlp(x) 180 | x = x.permute(0, 3, 1, 2) 181 | if self.gamma is not None: 182 | x = x.mul(self.gamma.reshape(1, -1, 1, 1)) 183 | 184 | x = self.drop_path(x) + shortcut 185 | return x 186 | 187 | 188 | class RFConvNeXtStage(nn.Module): 189 | 190 | def __init__( 191 | self, 192 | in_chs, 193 | out_chs, 194 | stride=2, 195 | depth=2, 196 | dilation=(1, 1), 197 | drop_path_rates=None, 198 | ls_init_value=1.0, 199 | conv_mlp=False, 200 | conv_bias=True, 201 | norm_layer=None, 202 | norm_layer_cl=None, 203 | search_cfgs=default_search_cfg 204 | ): 205 | super().__init__() 206 | self.grad_checkpointing = False 207 | 208 | if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]: 209 | ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 210 | pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used 211 | self.downsample = nn.Sequential( 212 | norm_layer(in_chs), 213 | create_conv2d( 214 | in_chs, out_chs, kernel_size=ds_ks, stride=stride, 215 | dilation=dilation[0], padding=pad, bias=conv_bias), 216 | ) 217 | in_chs = out_chs 218 | else: 219 | self.downsample = nn.Identity() 220 | 221 | drop_path_rates = drop_path_rates or [0.] * depth 222 | stage_blocks = [] 223 | for i in range(depth): 224 | stage_blocks.append(RFConvNeXtBlock( 225 | dim=in_chs, 226 | dim_out=out_chs, 227 | dilation=dilation[1], 228 | drop_path=drop_path_rates[i], 229 | ls_init_value=ls_init_value, 230 | conv_mlp=conv_mlp, 231 | conv_bias=conv_bias, 232 | norm_layer=norm_layer if conv_mlp else norm_layer_cl, 233 | search_cfgs=search_cfgs 234 | )) 235 | in_chs = out_chs 236 | self.blocks = nn.Sequential(*stage_blocks) 237 | 238 | def forward(self, x): 239 | x = self.downsample(x) 240 | if self.grad_checkpointing and not torch.jit.is_scripting(): 241 | x = checkpoint_seq(self.blocks, x) 242 | else: 243 | x = self.blocks(x) 244 | return x 245 | 246 | 247 | class RFConvNeXt(nn.Module): 248 | r""" ConvNeXt 249 | A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf 250 | 251 | Args: 252 | in_chans (int): Number of input image channels. 
Default: 3 253 | num_classes (int): Number of classes for classification head. Default: 1000 254 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] 255 | dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] 256 | drop_rate (float): Head dropout rate 257 | drop_path_rate (float): Stochastic depth rate. Default: 0. 258 | ls_init_value (float): Init value for Layer Scale. Default: 1e-6. 259 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. 260 | rf_mode (str): Training mode for RF-Next. Choose from ['rfsearch', 'rfsingle', 'rfmultiple', 'rfmerge']. 261 | kernel_cfgs (Dict(str, int)): Kernel size for each RFConv. Example: {"stages.0.blocks.0.conv_dw": 7, "stages.0.blocks.1.conv_dw": 7, ...}. 262 | """ 263 | 264 | def __init__( 265 | self, 266 | in_chans=3, 267 | num_classes=1000, 268 | global_pool='avg', 269 | output_stride=32, 270 | depths=(3, 3, 9, 3), 271 | dims=(96, 192, 384, 768), 272 | ls_init_value=1e-6, 273 | stem_type='patch', 274 | patch_size=4, 275 | head_init_scale=1., 276 | head_norm_first=False, 277 | conv_mlp=False, 278 | conv_bias=True, 279 | norm_layer=None, 280 | drop_rate=0., 281 | drop_path_rate=0., 282 | pretrained_weights=None, 283 | rf_mode='rfsearch', 284 | kernel_cfgs=None, 285 | search_cfgs=default_search_cfg 286 | ): 287 | super().__init__() 288 | assert output_stride in (8, 16, 32) 289 | if norm_layer is None: 290 | norm_layer = partial(LayerNorm2d, eps=1e-6) 291 | norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) 292 | else: 293 | assert conv_mlp,\ 294 | 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' 295 | norm_layer_cl = norm_layer 296 | 297 | self.num_classes = num_classes 298 | self.drop_rate = drop_rate 299 | self.feature_info = [] 300 | 301 | assert stem_type in ('patch', 'overlap', 'overlap_tiered') 302 | if stem_type == 'patch': 303 | # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 304 | self.stem = nn.Sequential( 305 | nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias), 306 | norm_layer(dims[0]) 307 | ) 308 | stem_stride = patch_size 309 | else: 310 | mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0] 311 | self.stem = nn.Sequential( 312 | nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias), 313 | nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias), 314 | norm_layer(dims[0]), 315 | ) 316 | stem_stride = 4 317 | 318 | self.stages = nn.Sequential() 319 | dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] 320 | stages = [] 321 | prev_chs = dims[0] 322 | curr_stride = stem_stride 323 | dilation = 1 324 | # 4 feature resolution stages, each consisting of multiple residual blocks 325 | for i in range(4): 326 | stride = 2 if curr_stride == 2 or i > 0 else 1 327 | if curr_stride >= output_stride and stride > 1: 328 | dilation *= stride 329 | stride = 1 330 | curr_stride *= stride 331 | first_dilation = 1 if dilation in (1, 2) else 2 332 | out_chs = dims[i] 333 | stages.append(RFConvNeXtStage( 334 | prev_chs, 335 | out_chs, 336 | stride=stride, 337 | dilation=(first_dilation, dilation), 338 | depth=depths[i], 339 | drop_path_rates=dp_rates[i], 340 | ls_init_value=ls_init_value, 341 | conv_mlp=conv_mlp, 342 | conv_bias=conv_bias, 343 | norm_layer=norm_layer, 344 | 
norm_layer_cl=norm_layer_cl, 345 | search_cfgs=search_cfgs 346 | )) 347 | prev_chs = out_chs 348 | # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 349 | self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] 350 | self.stages = nn.Sequential(*stages) 351 | self.num_features = prev_chs 352 | 353 | # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets 354 | # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) 355 | self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() 356 | self.head = nn.Sequential(OrderedDict([ 357 | ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), 358 | ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), 359 | ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), 360 | ('drop', nn.Dropout(self.drop_rate)), 361 | ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) 362 | 363 | named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) 364 | 365 | # RF-Next 366 | self.prepare_rfsearch(pretrained_weights, rf_mode, kernel_cfgs, search_cfgs) 367 | if self.rf_mode not in ['rfsearch', 'rfmultiple']: 368 | for n, p, in self.named_parameters(): 369 | if 'sample_weights' in n: 370 | p.requires_grad = False 371 | 372 | def prepare_rfsearch(self, pretrained_weights, rf_mode, kernel_cfgs, search_cfgs): 373 | self.rf_mode = rf_mode 374 | self.pretrained_weights = pretrained_weights 375 | assert self.rf_mode in ['rfsearch', 'rfsingle', 'rfmultiple', 'rfmerge'], \ 376 | "rf_mode should be in ['rfsearch', 'rfsingle', 'rfmultiple', 'rfmerge']." 377 | if pretrained_weights is None or not os.path.exists(pretrained_weights): 378 | checkpoint = None 379 | else: 380 | checkpoint = torch.load(pretrained_weights, map_location='cpu') 381 | checkpoint = checkpoint_filter_fn(checkpoint, self) 382 | # Remove the prefix in checkpint, e.g., 'backbone' and 'module', 383 | # to guarantee the matching between 'checkpoint' and 'model.state_dict'. 384 | checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()} 385 | checkpoint = {k.replace('backbone.', ''): v for k, v in checkpoint.items()} 386 | for name in list(checkpoint.keys()): 387 | if name.endswith('counter') or name.endswith('current_search_step'): 388 | # Do not load pretrained buffer of counter and current_step!!!!!!! 389 | print(f"RF-Next: Removing key {name} from pretrained checkpoint") 390 | del checkpoint[name] 391 | 392 | # Remove the parameters with mismatched shape from checkpoint 393 | for name, module in self.named_parameters(): 394 | if name in checkpoint and module.shape != checkpoint[name].shape: 395 | print(f"RF-Next: Removing key {name} from pretrained checkpoint") 396 | del checkpoint[name] 397 | # Load the pretrained weights for a rfconv. 398 | # The pretarined weights are obtained after rfseach. 
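        # Illustrative note (added commentary, not original code): besides the ordinary conv/norm
        # weights, a checkpoint produced after rfsearch typically also carries per-RFConv2d entries
        # such as 'stages.0.blocks.0.conv_dw.sample_weights' (branch weights) and
        # 'stages.0.blocks.0.conv_dw.rates' / '...num_rates' (searched dilation rates), which are
        # loaded here, while the '...counter' and '...current_search_step' buffers were dropped above.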
399 | msg = self.load_state_dict(checkpoint, strict=False) 400 | missing_keys = list(msg.missing_keys) 401 | missing_keys = list(filter(lambda x: not x.endswith('.counter') and not x.endswith('.current_search_step'), missing_keys)) 402 | print('RF-Next: RF-Next init, missing keys: {}'.format(missing_keys)) 403 | 404 | print('RF-Next: convert rfconv.') 405 | # Convert conv to rfconv 406 | def convert_rfconv(module, prefix): 407 | module_output = module 408 | if isinstance(module, RFConv2d): 409 | if kernel_cfgs is not None: 410 | kernel = kernel_cfgs[prefix] 411 | else: 412 | kernel = module.kernel_size 413 | if checkpoint is not None: 414 | module_pretrained = dict() 415 | # Load the pretrained weights for a rfconv. 416 | # The pretarined weights are obtained after rfseach. 417 | for k in checkpoint.keys(): 418 | if k.startswith(prefix): 419 | module_pretrained[k.replace('{}.'.format(prefix), '')] = checkpoint[k] 420 | else: 421 | module_pretrained = None 422 | if isinstance(kernel, int): 423 | kernel = (kernel, kernel) 424 | module_output = RFConv2d( 425 | in_channels=module.in_channels, 426 | out_channels=module.out_channels, 427 | kernel_size=kernel, 428 | stride=module.stride, 429 | padding=( 430 | get_padding(kernel[0], module.stride[0], module.dilation[0]), 431 | get_padding(kernel[1], module.stride[1], module.dilation[1])), 432 | dilation=module.dilation, 433 | groups=module.groups, 434 | bias=hasattr(module, 'bias'), 435 | rf_mode=self.rf_mode, 436 | pretrained=module_pretrained, 437 | **search_cfgs 438 | ) 439 | 440 | for name, child in module.named_children(): 441 | fullname = name 442 | if prefix != '': 443 | fullname = prefix + '.' + name 444 | # Replace the conv with rfconv。 445 | module_output.add_module(name, convert_rfconv(child, fullname)) 446 | del module 447 | return module_output 448 | 449 | convert_rfconv(self, '') 450 | 451 | if self.rf_mode == 'rfmerge': 452 | # Show the kernel sizes after rfmerge。 453 | rfmerge = dict() 454 | for name, module in self.named_modules(): 455 | if isinstance(module, RFConv2d): 456 | rfmerge[name] = module.kernel_size 457 | 458 | print('Merged structure:') 459 | print(rfmerge) 460 | print('RF-Next: convert done.') 461 | 462 | @torch.jit.ignore 463 | def group_matcher(self, coarse=False): 464 | return dict( 465 | stem=r'^stem', 466 | blocks=r'^stages\.(\d+)' if coarse else [ 467 | (r'^stages\.(\d+)\.downsample', (0,)), # blocks 468 | (r'^stages\.(\d+)\.blocks\.(\d+)', None), 469 | (r'^norm_pre', (99999,)) 470 | ] 471 | ) 472 | 473 | @torch.jit.ignore 474 | def set_grad_checkpointing(self, enable=True): 475 | for s in self.stages: 476 | s.grad_checkpointing = enable 477 | 478 | @torch.jit.ignore 479 | def get_classifier(self): 480 | return self.head.fc 481 | 482 | def reset_classifier(self, num_classes=0, global_pool=None): 483 | if global_pool is not None: 484 | self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) 485 | self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() 486 | self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 487 | 488 | def forward_features(self, x): 489 | x = self.stem(x) 490 | x = self.stages(x) 491 | x = self.norm_pre(x) 492 | return x 493 | 494 | def forward_head(self, x, pre_logits: bool = False): 495 | # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( 496 | x = self.head.global_pool(x) 497 | x = self.head.norm(x) 498 | x = self.head.flatten(x) 499 | x = self.head.drop(x) 500 | return x if pre_logits 
else self.head.fc(x) 501 | 502 | def forward(self, x): 503 | x = self.forward_features(x) 504 | x = self.forward_head(x) 505 | return x 506 | 507 | 508 | def _init_weights(module, name=None, head_init_scale=1.0): 509 | if isinstance(module, nn.Conv2d): 510 | trunc_normal_(module.weight, std=.02) 511 | if module.bias is not None: 512 | nn.init.zeros_(module.bias) 513 | elif isinstance(module, nn.Linear): 514 | trunc_normal_(module.weight, std=.02) 515 | nn.init.zeros_(module.bias) 516 | if name and 'head.' in name: 517 | module.weight.data.mul_(head_init_scale) 518 | module.bias.data.mul_(head_init_scale) 519 | 520 | 521 | def checkpoint_filter_fn(state_dict, model): 522 | """ Remap FB checkpoints -> timm """ 523 | if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: 524 | return state_dict # non-FB checkpoint 525 | if 'model' in state_dict: 526 | state_dict = state_dict['model'] 527 | out_dict = {} 528 | import re 529 | for k, v in state_dict.items(): 530 | k = k.replace('downsample_layers.0.', 'stem.') 531 | k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) 532 | k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) 533 | k = k.replace('dwconv', 'conv_dw') 534 | k = k.replace('pwconv', 'mlp.fc') 535 | k = k.replace('head.', 'head.fc.') 536 | if k.startswith('norm.'): 537 | k = k.replace('norm', 'head.norm') 538 | if v.ndim == 2 and 'head' not in k: 539 | model_shape = model.state_dict()[k].shape 540 | v = v.reshape(model_shape) 541 | if ('current_search_step' in k) or ('counter' in k): 542 | continue 543 | out_dict[k] = v 544 | return out_dict 545 | 546 | 547 | def _create_rfconvnext(variant, pretrained=False, **kwargs): 548 | model = build_model_with_cfg( 549 | RFConvNeXt, variant, pretrained, 550 | pretrained_filter_fn=checkpoint_filter_fn, 551 | feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), 552 | **kwargs) 553 | return model 554 | 555 | 556 | def rfconvnext_tiny(pretrained=False, **kwargs): 557 | model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) 558 | model = _create_rfconvnext('convnext_tiny', pretrained=pretrained, **model_args) 559 | return model 560 | 561 | 562 | def rfconvnext_small(pretrained=False, **kwargs): 563 | model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) 564 | model = _create_rfconvnext('convnext_small', pretrained=pretrained, **model_args) 565 | return model 566 | 567 | 568 | def rfconvnext_base(pretrained=False, **kwargs): 569 | model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) 570 | model = _create_rfconvnext('convnext_base', pretrained=pretrained, **model_args) 571 | return model 572 | 573 | 574 | def rfconvnext_large(pretrained=False, **kwargs): 575 | model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 576 | model = _create_rfconvnext('convnext_large', pretrained=pretrained, **model_args) 577 | return model 578 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import random 13 | 14 | import numpy as np 15 | import torch 16 | from PIL import Image, ImageFilter 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | from torchvision import datasets, transforms 19 | 20 | import util.transforms as custom_transforms 21 | 22 | 23 | def build_dataset(is_train, args): 24 | transform = build_transform(is_train) 25 | data_root = os.path.join(args.data_path, 26 | 'train-semi' if is_train else 'validation') 27 | gt_root = os.path.join( 28 | args.data_path, 29 | 'train-semi-segmentation' if is_train else 'validation-segmentation') 30 | dataset = SegDataset(data_root, gt_root, transform, is_train) 31 | return dataset 32 | 33 | 34 | def build_transform(is_train): 35 | mean = IMAGENET_DEFAULT_MEAN 36 | std = IMAGENET_DEFAULT_STD 37 | # train transform 38 | if is_train: 39 | # this should always dispatch to transforms_imagenet_train 40 | color_transform = [get_color_distortion(), PILRandomGaussianBlur()] 41 | randomresizedcrop = custom_transforms.RandomResizedCrop( 42 | 224, 43 | scale=(0.14, 1), 44 | ) 45 | transform = custom_transforms.Compose([ 46 | randomresizedcrop, 47 | custom_transforms.RandomHorizontalFlip(p=0.5), 48 | transforms.Compose(color_transform), 49 | custom_transforms.ToTensor(), 50 | transforms.Normalize(mean=mean, std=std) 51 | ]) 52 | return transform 53 | 54 | # eval transform 55 | t = [] 56 | t.append(transforms.Resize(256)) 57 | t.append(transforms.ToTensor()) 58 | t.append(transforms.Normalize(mean, std)) 59 | return transforms.Compose(t) 60 | 61 | 62 | class SegDataset(datasets.ImageFolder): 63 | def __init__(self, data_root, gt_root=None, transform=None, is_train=True): 64 | super(SegDataset, self).__init__(data_root) 65 | assert gt_root is not None 66 | self.gt_root = gt_root 67 | self.transform = transform 68 | self.is_train = is_train 69 | 70 | def __getitem__(self, index): 71 | path, _ = self.samples[index] 72 | image = self.loader(path) 73 | segmentation = self.load_segmentation(path) 74 | 75 | if self.is_train: 76 | image, segmentation = self.transform(image, segmentation) 77 | else: 78 | image = self.transform(image) 79 | segmentation = torch.from_numpy(np.array(segmentation)) 80 | segmentation = segmentation.long() 81 | 82 | segmentation = segmentation[:, :, 1] * 256 + segmentation[:, :, 0] 83 | return image, segmentation 84 | 85 | def load_segmentation(self, path): 86 | cate, name = path.split('/')[-2:] 87 | name = name.replace('JPEG', 'png') 88 | path = os.path.join(self.gt_root, cate, name) 89 | segmentation = Image.open(path) 90 | return segmentation 91 | 92 | 93 | class PILRandomGaussianBlur(object): 94 | """Apply Gaussian Blur to the PIL image. Take the radius and probability of 95 | application as the parameter. 
96 | 97 | This transform was used in SimCLR - https://arxiv.org/abs/2002.05709 98 | """ 99 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.): 100 | self.prob = p 101 | self.radius_min = radius_min 102 | self.radius_max = radius_max 103 | 104 | def __call__(self, img): 105 | do_it = np.random.rand() <= self.prob 106 | if not do_it: 107 | return img 108 | 109 | return img.filter( 110 | ImageFilter.GaussianBlur( 111 | radius=random.uniform(self.radius_min, self.radius_max))) 112 | 113 | 114 | def get_color_distortion(s=1.0): 115 | # s is the strength of color distortion. 116 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s) 117 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 118 | rnd_gray = transforms.RandomGrayscale(p=0.2) 119 | color_distort = transforms.Compose([rnd_color_jitter, rnd_gray]) 120 | return color_distort 121 | -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | 13 | def param_groups_lrd(model, 14 | weight_decay=0.05, 15 | no_weight_decay_list=[], 16 | layer_decay=[.75], 17 | layer_multiplier=[1.0]): 18 | """Parameter groups for layer-wise lr decay, following BEiT: 19 | 20 | https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58. 21 | """ 22 | param_group_names = {} 23 | param_groups = {} 24 | 25 | num_layers = model.num_layers 26 | 27 | if isinstance(layer_decay, (float, int)): 28 | layer_decay = [layer_decay] 29 | 30 | layer_scales = [ 31 | list(decay**(num_layers - i) for i in range(num_layers + 1)) for decay in layer_decay] 32 | 33 | for n, p in model.named_parameters(): 34 | if not p.requires_grad: 35 | continue 36 | 37 | # no decay: all 1D parameters and model specific ones 38 | if p.ndim == 1 or n in no_weight_decay_list: 39 | g_decay = 'no_decay' 40 | this_decay = 0. 41 | else: 42 | g_decay = 'decay' 43 | this_decay = weight_decay 44 | 45 | """ 46 | For each parameter, get_layer_id returns (layer_group, layer_id). 47 | Parameters are grouped according to layer_group, 48 | and different groups use different decay rates. 49 | 50 | If only a layer_id is returned, the layer_group defaults to 0.
51 | """ 52 | layer_group_id = model.get_layer_id(n) 53 | if isinstance(layer_group_id, (list, tuple)): 54 | layer_group, layer_id = layer_group_id 55 | elif isinstance(layer_group_id, int): 56 | layer_group, layer_id = 0, layer_group_id 57 | else: 58 | raise NotImplementedError() 59 | group_name = 'layer_%d_%d_%s' % (layer_group, layer_id, g_decay) 60 | 61 | if group_name not in param_group_names: 62 | this_scale = layer_scales[layer_group][layer_id] * layer_multiplier[layer_group] 63 | 64 | param_group_names[group_name] = { 65 | 'lr_scale': this_scale, 66 | 'weight_decay': this_decay, 67 | 'params': [], 68 | } 69 | param_groups[group_name] = { 70 | 'lr_scale': this_scale, 71 | 'weight_decay': this_decay, 72 | 'params': [], 73 | } 74 | 75 | param_group_names[group_name]['params'].append(n) 76 | param_groups[group_name]['params'].append(p) 77 | 78 | return list(param_groups.values()) 79 | -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | 10 | def adjust_learning_rate(optimizer, epoch, args): 11 | """Decay the learning rate with half-cycle cosine after warmup.""" 12 | if epoch < args.warmup_epochs: 13 | lr = args.lr * epoch / args.warmup_epochs 14 | else: 15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 16 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / 17 | (args.epochs - args.warmup_epochs))) 18 | for param_group in optimizer.param_groups: 19 | if 'lr_scale' in param_group: 20 | param_group['lr'] = lr * param_group['lr_scale'] 21 | else: 22 | param_group['lr'] = lr 23 | return lr 24 | -------------------------------------------------------------------------------- /util/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def IoUGPU(output, target, K): 5 | # 'K' classes, output and target sizes are 6 | # N or N * L or N * H * W, each value in range 0 to K - 1. 7 | assert (output.dim() in [1, 2, 3]) 8 | assert output.shape == target.shape 9 | output = output.view(-1) 10 | target = target.view(-1) 11 | intersection = output[output == target] 12 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) 13 | area_output = torch.histc(output, bins=K, min=0, max=K - 1) 14 | area_target = torch.histc(target, bins=K, min=0, max=K - 1) 15 | return area_intersection, area_output, area_target 16 | 17 | 18 | def FMeasureGPU(output, target, eps=1e-20, beta=0.3): 19 | target = (target > 0) * 1.0 20 | output = (output > 0) * 1.0 21 | 22 | t = torch.sum(target) 23 | p = torch.sum(output) 24 | tp = torch.sum(target * output) 25 | recall = tp / (t + eps) 26 | precision = tp / (p + eps) 27 | f_score = (1 + beta) * precision * recall / (beta * precision + recall + 28 | eps) 29 | 30 | return f_score 31 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch._six import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average.""" 27 | def __init__(self, window_size=20, fmt=None): 28 | if fmt is None: 29 | fmt = '{median:.4f} ({global_avg:.4f})' 30 | self.deque = deque(maxlen=window_size) 31 | self.total = 0.0 32 | self.count = 0 33 | self.fmt = fmt 34 | 35 | def update(self, value, n=1): 36 | self.deque.append(value) 37 | self.count += n 38 | self.total += value * n 39 | 40 | def synchronize_between_processes(self): 41 | """ 42 | Warning: does not synchronize the deque! 43 | """ 44 | if not is_dist_avail_and_initialized(): 45 | return 46 | t = torch.tensor([self.count, self.total], 47 | dtype=torch.float64, 48 | device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format(median=self.median, 79 | avg=self.avg, 80 | global_avg=self.global_avg, 81 | max=self.max, 82 | value=self.value) 83 | 84 | 85 | class MetricLogger(object): 86 | def __init__(self, delimiter='\t'): 87 | self.meters = defaultdict(SmoothedValue) 88 | self.delimiter = delimiter 89 | 90 | def update(self, **kwargs): 91 | for k, v in kwargs.items(): 92 | if v is None: 93 | continue 94 | if isinstance(v, torch.Tensor): 95 | v = v.item() 96 | assert isinstance(v, (float, int)) 97 | self.meters[k].update(v) 98 | 99 | def __getattr__(self, attr): 100 | if attr in self.meters: 101 | return self.meters[attr] 102 | if attr in self.__dict__: 103 | return self.__dict__[attr] 104 | raise AttributeError("'{}' object has no attribute '{}'".format( 105 | type(self).__name__, attr)) 106 | 107 | def __str__(self): 108 | loss_str = [] 109 | for name, meter in self.meters.items(): 110 | loss_str.append('{}: {}'.format(name, str(meter))) 111 | return self.delimiter.join(loss_str) 112 | 113 | def synchronize_between_processes(self): 114 | for meter in self.meters.values(): 115 | meter.synchronize_between_processes() 116 | 117 | def add_meter(self, name, meter): 118 | self.meters[name] = meter 119 | 120 | def log_every(self, iterable, print_freq, header=None): 121 | i = 0 122 | if not header: 123 | header = '' 124 | start_time = time.time() 125 | end = time.time() 126 | iter_time = SmoothedValue(fmt='{avg:.4f}') 127 | data_time = SmoothedValue(fmt='{avg:.4f}') 128 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 129 | log_msg = [ 130 | header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}', 131 | 'time: {time}', 
'data: {data}' 132 | ] 133 | if torch.cuda.is_available(): 134 | log_msg.append('max mem: {memory:.0f}') 135 | log_msg = self.delimiter.join(log_msg) 136 | MB = 1024.0 * 1024.0 137 | for obj in iterable: 138 | data_time.update(time.time() - end) 139 | yield obj 140 | iter_time.update(time.time() - end) 141 | if i % print_freq == 0 or i == len(iterable) - 1: 142 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 144 | if torch.cuda.is_available(): 145 | print( 146 | log_msg.format( 147 | i, 148 | len(iterable), 149 | eta=eta_string, 150 | meters=str(self), 151 | time=str(iter_time), 152 | data=str(data_time), 153 | memory=torch.cuda.max_memory_allocated() / MB)) 154 | else: 155 | print( 156 | log_msg.format(i, 157 | len(iterable), 158 | eta=eta_string, 159 | meters=str(self), 160 | time=str(iter_time), 161 | data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """This function disables printing when not in master process.""" 172 | builtin_print = builtins.print 173 | 174 | def print(*args, **kwargs): 175 | force = kwargs.pop('force', False) 176 | force = force or (get_world_size() > 8) 177 | if is_master or force: 178 | now = datetime.datetime.now().time() 179 | builtin_print('[{}] '.format(now), end='') # print with time stamp 180 | builtin_print(*args, **kwargs) 181 | 182 | builtins.print = print 183 | 184 | 185 | def is_dist_avail_and_initialized(): 186 | if not dist.is_available(): 187 | return False 188 | if not dist.is_initialized(): 189 | return False 190 | return True 191 | 192 | 193 | def get_world_size(): 194 | if not is_dist_avail_and_initialized(): 195 | return 1 196 | return dist.get_world_size() 197 | 198 | 199 | def get_rank(): 200 | if not is_dist_avail_and_initialized(): 201 | return 0 202 | return dist.get_rank() 203 | 204 | 205 | def is_main_process(): 206 | return get_rank() == 0 207 | 208 | 209 | def save_on_master(*args, **kwargs): 210 | if is_main_process(): 211 | torch.save(*args, **kwargs) 212 | 213 | 214 | def init_distributed_mode(args): 215 | if args.dist_on_itp: 216 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 217 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 218 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 219 | args.dist_url = 'tcp://%s:%s' % (os.environ['MASTER_ADDR'], 220 | os.environ['MASTER_PORT']) 221 | os.environ['LOCAL_RANK'] = str(args.gpu) 222 | os.environ['RANK'] = str(args.rank) 223 | os.environ['WORLD_SIZE'] = str(args.world_size) 224 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 225 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 226 | args.rank = int(os.environ['RANK']) 227 | args.world_size = int(os.environ['WORLD_SIZE']) 228 | args.gpu = int(os.environ['LOCAL_RANK']) 229 | elif 'SLURM_PROCID' in os.environ: 230 | args.rank = int(os.environ['SLURM_PROCID']) 231 | args.gpu = args.rank % torch.cuda.device_count() 232 | else: 233 | print('Not using distributed mode') 234 | setup_for_distributed(is_master=True) # hack 235 | args.distributed = False 236 | return 237 | 238 | args.distributed = True 239 | 240 | torch.cuda.set_device(args.gpu) 241 | args.dist_backend = 'nccl' 242 | print('| distributed init 
(rank {}): {}, gpu {}'.format( 243 | args.rank, args.dist_url, args.gpu), 244 | flush=True) 245 | torch.distributed.init_process_group(backend=args.dist_backend, 246 | init_method=args.dist_url, 247 | world_size=args.world_size, 248 | rank=args.rank) 249 | torch.distributed.barrier() 250 | setup_for_distributed(args.rank == 0) 251 | 252 | 253 | class NativeScalerWithGradNormCount: 254 | state_dict_key = 'amp_scaler' 255 | 256 | def __init__(self): 257 | self._scaler = torch.cuda.amp.GradScaler() 258 | 259 | def __call__(self, 260 | loss, 261 | optimizer, 262 | clip_grad=None, 263 | parameters=None, 264 | create_graph=False, 265 | update_grad=True): 266 | self._scaler.scale(loss).backward(create_graph=create_graph) 267 | if update_grad: 268 | if clip_grad is not None: 269 | assert parameters is not None 270 | self._scaler.unscale_(optimizer) 271 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 272 | else: 273 | self._scaler.unscale_(optimizer) 274 | norm = get_grad_norm_(parameters) 275 | self._scaler.step(optimizer) 276 | self._scaler.update() 277 | else: 278 | norm = None 279 | return norm 280 | 281 | def state_dict(self): 282 | return self._scaler.state_dict() 283 | 284 | def load_state_dict(self, state_dict): 285 | self._scaler.load_state_dict(state_dict) 286 | 287 | 288 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 289 | if isinstance(parameters, torch.Tensor): 290 | parameters = [parameters] 291 | parameters = [p for p in parameters if p.grad is not None] 292 | norm_type = float(norm_type) 293 | if len(parameters) == 0: 294 | return torch.tensor(0.) 295 | device = parameters[0].grad.device 296 | if norm_type == inf: 297 | total_norm = max(p.grad.detach().abs().max().to(device) 298 | for p in parameters) 299 | else: 300 | total_norm = torch.norm( 301 | torch.stack([ 302 | torch.norm(p.grad.detach(), norm_type).to(device) 303 | for p in parameters 304 | ]), norm_type) 305 | return total_norm 306 | 307 | 308 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 309 | output_dir = Path(args.output_dir) 310 | epoch_name = str(epoch) 311 | if loss_scaler is not None: 312 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 313 | for checkpoint_path in checkpoint_paths: 314 | to_save = { 315 | 'model': model_without_ddp.state_dict(), 316 | 'optimizer': optimizer.state_dict(), 317 | 'epoch': epoch, 318 | 'scaler': loss_scaler.state_dict(), 319 | 'args': args, 320 | } 321 | 322 | save_on_master(to_save, checkpoint_path) 323 | else: 324 | client_state = {'epoch': epoch} 325 | model.save_checkpoint(save_dir=args.output_dir, 326 | tag='checkpoint-%s' % epoch_name, 327 | client_state=client_state) 328 | 329 | 330 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 331 | if args.resume: 332 | if args.resume.startswith('https'): 333 | checkpoint = torch.hub.load_state_dict_from_url(args.resume, 334 | map_location='cpu', 335 | check_hash=True) 336 | else: 337 | checkpoint = torch.load(args.resume, map_location='cpu') 338 | model_without_ddp.load_state_dict(checkpoint['model']) 339 | print('Resume checkpoint %s' % args.resume) 340 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not ( 341 | hasattr(args, 'eval') and args.eval): 342 | optimizer.load_state_dict(checkpoint['optimizer']) 343 | args.start_epoch = checkpoint['epoch'] + 1 344 | if 'scaler' in checkpoint: 345 | loss_scaler.load_state_dict(checkpoint['scaler']) 346 | print('With optim & sched!') 347 | 348 | 349 | def 
all_reduce_mean(x): 350 | world_size = get_world_size() 351 | if world_size > 1: 352 | x_reduce = torch.tensor(x).cuda() 353 | dist.all_reduce(x_reduce) 354 | x_reduce /= world_size 355 | return x_reduce.item() 356 | else: 357 | return x 358 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: 18 | # https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 19 | # MoCo v3: 20 | # https://github.com/facebookresearch/moco-v3 21 | # -------------------------------------------------------- 22 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 23 | """ 24 | grid_size: int of the grid height and width 25 | return: 26 | pos_embed: [grid_size*grid_size, embed_dim] 27 | or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 28 | """ 29 | grid_h = np.arange(grid_size, dtype=np.float32) 30 | grid_w = np.arange(grid_size, dtype=np.float32) 31 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 32 | grid = np.stack(grid, axis=0) 33 | 34 | grid = grid.reshape([2, 1, grid_size, grid_size]) 35 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 36 | if cls_token: 37 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], 38 | axis=0) 39 | return pos_embed 40 | 41 | 42 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 43 | assert embed_dim % 2 == 0 44 | 45 | # use half of dimensions to encode grid_h 46 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, 47 | grid[0]) # (H*W, D/2) 48 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, 49 | grid[1]) # (H*W, D/2) 50 | 51 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 52 | return emb 53 | 54 | 55 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 56 | """ 57 | embed_dim: output dimension for each position 58 | pos: a list of positions to be encoded: size (M,) 59 | out: (M, D) 60 | """ 61 | assert embed_dim % 2 == 0 62 | omega = np.arange(embed_dim // 2, dtype=np.float) 63 | omega /= embed_dim / 2. 64 | omega = 1. 
/ 10000**omega # (D/2,) 65 | 66 | pos = pos.reshape(-1) # (M,) 67 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 68 | 69 | emb_sin = np.sin(out) # (M, D/2) 70 | emb_cos = np.cos(out) # (M, D/2) 71 | 72 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 73 | return emb 74 | 75 | 76 | # -------------------------------------------------------- 77 | # Interpolate position embeddings for high-resolution 78 | # References: 79 | # DeiT: https://github.com/facebookresearch/deit 80 | # -------------------------------------------------------- 81 | def interpolate_pos_embed(model, checkpoint_model): 82 | if 'pos_embed' in checkpoint_model: 83 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 84 | embedding_size = pos_embed_checkpoint.shape[-1] 85 | num_patches = model.patch_embed.num_patches 86 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 87 | # height (== width) for the checkpoint position embedding 88 | orig_size = int( 89 | (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5) 90 | # height (== width) for the new position embedding 91 | new_size = int(num_patches**0.5) 92 | # class_token and dist_token are kept unchanged 93 | if orig_size != new_size: 94 | print('Position interpolate from %dx%d to %dx%d' % 95 | (orig_size, orig_size, new_size, new_size)) 96 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 97 | # only the position tokens are interpolated 98 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 99 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, 100 | embedding_size).permute( 101 | 0, 3, 1, 2).contiguous() 102 | pos_tokens = torch.nn.functional.interpolate(pos_tokens, 103 | size=(new_size, 104 | new_size), 105 | mode='bicubic', 106 | align_corners=False) 107 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).contiguous().flatten(1, 2) 108 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 109 | checkpoint_model['pos_embed'] = new_pos_embed 110 | -------------------------------------------------------------------------------- /util/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import random 5 | import warnings 6 | from collections import Iterable 7 | 8 | import numpy as np 9 | import torch 10 | from torchvision.transforms import functional as F 11 | 12 | try: 13 | from torchvision.transforms import InterpolationMode 14 | 15 | NEAREST = InterpolationMode.NEAREST 16 | BILINEAR = InterpolationMode.BILINEAR 17 | BICUBIC = InterpolationMode.BICUBIC 18 | LANCZOS = InterpolationMode.LANCZOS 19 | HAMMING = InterpolationMode.HAMMING 20 | HAMMING = InterpolationMode.HAMMING 21 | 22 | _pil_interpolation_to_str = { 23 | InterpolationMode.NEAREST: 'InterpolationMode.NEAREST', 24 | InterpolationMode.BILINEAR: 'InterpolationMode.BILINEAR', 25 | InterpolationMode.BICUBIC: 'InterpolationMode.BICUBIC', 26 | InterpolationMode.LANCZOS: 'InterpolationMode.LANCZOS', 27 | InterpolationMode.HAMMING: 'InterpolationMode.HAMMING', 28 | InterpolationMode.BOX: 'InterpolationMode.BOX', 29 | } 30 | 31 | except: 32 | from PIL import Image 33 | 34 | NEAREST = Image.NEAREST 35 | BILINEAR = Image.BILINEAR 36 | BICUBIC = Image.BICUBIC 37 | LANCZOS = Image.LANCZOS 38 | HAMMING = Image.HAMMING 39 | HAMMING = Image.HAMMING 40 | 41 | _pil_interpolation_to_str = { 42 | Image.NEAREST: 'PIL.Image.NEAREST', 43 | Image.BILINEAR: 'PIL.Image.BILINEAR', 44 | Image.BICUBIC: 'PIL.Image.BICUBIC', 45 | Image.LANCZOS: 'PIL.Image.LANCZOS', 46 
| Image.HAMMING: 'PIL.Image.HAMMING', 47 | Image.BOX: 'PIL.Image.BOX', 48 | } 49 | 50 | def _get_image_size(img): 51 | if F._is_pil_image(img): 52 | return img.size 53 | elif isinstance(img, torch.Tensor) and img.dim() > 2: 54 | return img.shape[-2:][::-1] 55 | else: 56 | raise TypeError('Unexpected type {}'.format(type(img))) 57 | 58 | 59 | class Compose(object): 60 | """Composes several transforms together. 61 | 62 | Args: 63 | transforms (list of ``Transform`` objects): 64 | list of transforms to compose. 65 | 66 | Example: 67 | >>> transforms.Compose([ 68 | >>> transforms.CenterCrop(10), 69 | >>> transforms.ToTensor(), 70 | >>> ]) 71 | """ 72 | def __init__(self, transforms): 73 | self.transforms = transforms 74 | 75 | def __call__(self, img, gt): 76 | for t in self.transforms: 77 | if 'RandomResizedCrop' in t.__class__.__name__: 78 | img, gt = t(img, gt) 79 | elif 'Flip' in t.__class__.__name__: 80 | img, gt = t(img, gt) 81 | elif 'ToTensor' in t.__class__.__name__: 82 | img, gt = t(img, gt) 83 | else: 84 | img = t(img) 85 | gt = gt.float() 86 | 87 | return img, gt 88 | 89 | def __repr__(self): 90 | format_string = self.__class__.__name__ + '(' 91 | for t in self.transforms: 92 | format_string += '\n' 93 | format_string += ' {0}'.format(t) 94 | format_string += '\n)' 95 | return format_string 96 | 97 | 98 | class RandomHorizontalFlip(object): 99 | """Horizontally flip the given PIL Image randomly with a given probability. 100 | 101 | Args: 102 | p (float): probability of the image being flipped. Default value is 0.5 103 | """ 104 | def __init__(self, p=0.5): 105 | self.p = p 106 | 107 | def __call__(self, img, gt): 108 | """ 109 | Args: 110 | img (PIL Image): Image to be flipped. 111 | 112 | Returns: 113 | PIL Image: Randomly flipped image. 114 | """ 115 | if random.random() < self.p: 116 | return F.hflip(img), F.hflip(gt) 117 | return img, gt 118 | 119 | def __repr__(self): 120 | return self.__class__.__name__ + '(p={})'.format(self.p) 121 | 122 | 123 | class RandomResizedCrop(object): 124 | """Crop the given PIL Image to random size and aspect ratio. 125 | 126 | A crop of random size (default: of 0.08 to 1.0) of the original size 127 | and a random aspect ratio (default: of 3/4 to 4/3) of the original 128 | aspect ratio is made. This crop is finally resized to given size. 129 | This is popularly used to train the Inception networks. 130 | 131 | Args: 132 | size: expected output size of each edge 133 | scale: range of size of the origin size cropped 134 | ratio: range of aspect ratio of the origin aspect ratio cropped 135 | interpolation: Default: PIL.Image.BILINEAR 136 | """ 137 | def __init__(self, 138 | size, 139 | scale=(0.08, 1.0), 140 | ratio=(3. / 4., 4. / 3.), 141 | interpolation=BILINEAR): 142 | if isinstance(size, (tuple, list)): 143 | self.size = size 144 | else: 145 | self.size = (size, size) 146 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 147 | warnings.warn('range should be of kind (min, max)') 148 | 149 | self.interpolation = interpolation 150 | self.scale = scale 151 | self.ratio = ratio 152 | 153 | @staticmethod 154 | def get_params(img, scale, ratio): 155 | """Get parameters for ``crop`` for a random sized crop. 156 | 157 | Args: 158 | img (PIL Image): Image to be cropped. 159 | scale (tuple): 160 | range of size of the origin size cropped 161 | ratio (tuple): 162 | range of aspect ratio of the origin aspect ratio cropped 163 | 164 | Returns: 165 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 166 | sized crop. 
167 | """ 168 | width, height = _get_image_size(img) 169 | area = height * width 170 | 171 | for attempt in range(10): 172 | target_area = random.uniform(*scale) * area 173 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 174 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 175 | 176 | w = int(round(math.sqrt(target_area * aspect_ratio))) 177 | h = int(round(math.sqrt(target_area / aspect_ratio))) 178 | 179 | if 0 < w <= width and 0 < h <= height: 180 | i = random.randint(0, height - h) 181 | j = random.randint(0, width - w) 182 | return i, j, h, w 183 | 184 | # Fallback to central crop 185 | in_ratio = float(width) / float(height) 186 | if (in_ratio < min(ratio)): 187 | w = width 188 | h = int(round(w / min(ratio))) 189 | elif (in_ratio > max(ratio)): 190 | h = height 191 | w = int(round(h * max(ratio))) 192 | else: # whole image 193 | w = width 194 | h = height 195 | i = (height - h) // 2 196 | j = (width - w) // 2 197 | return i, j, h, w 198 | 199 | def __call__(self, img, gt): 200 | """ 201 | Args: 202 | img (PIL Image): Image to be cropped and resized. 203 | 204 | Returns: 205 | PIL Image: Randomly cropped and resized image. 206 | """ 207 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 208 | return F.resized_crop( 209 | img, i, j, h, w, self.size, self.interpolation), \ 210 | F.resized_crop( 211 | gt, i, j, h, w, self.size, NEAREST) 212 | 213 | def __repr__(self): 214 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 215 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 216 | format_string += ', scale={0}'.format( 217 | tuple(round(s, 4) for s in self.scale)) 218 | format_string += ', ratio={0}'.format( 219 | tuple(round(r, 4) for r in self.ratio)) 220 | format_string += ', interpolation={0})'.format(interpolate_str) 221 | return format_string 222 | 223 | class ToTensor(object): 224 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 225 | 226 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 227 | [0, 255] to a torch.FloatTensor of 228 | shape (C x H x W) in the range [0.0, 1.0] 229 | if the PIL Image belongs to one of the 230 | modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 231 | or if the numpy.ndarray has dtype = np.uint8 232 | 233 | In the other cases, tensors are returned without scaling. 234 | """ 235 | def __call__(self, pic, gt): 236 | """ 237 | Args: 238 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 239 | 240 | Returns: 241 | Tensor: Converted image. 242 | """ 243 | return F.to_tensor(pic), torch.from_numpy(np.array(gt)) 244 | 245 | def __repr__(self): 246 | return self.__class__.__name__ + '()' 247 | --------------------------------------------------------------------------------
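# ------------------------------------------------------------------------------
# Usage sketch (not part of the repository sources above): a minimal example,
# mirroring build_transform(is_train=True) in util/datasets.py, of how the
# paired transforms in util/transforms.py are composed and applied to an
# (image, mask) pair. The PIL images created here are placeholders.
from PIL import Image
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms

import util.transforms as custom_transforms
from util.datasets import PILRandomGaussianBlur, get_color_distortion

paired_transform = custom_transforms.Compose([
    custom_transforms.RandomResizedCrop(224, scale=(0.14, 1)),
    custom_transforms.RandomHorizontalFlip(p=0.5),
    # Color distortion and blur are image-only; custom_transforms.Compose
    # applies them to the image branch and leaves the mask untouched.
    transforms.Compose([get_color_distortion(), PILRandomGaussianBlur()]),
    custom_transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])

image = Image.new('RGB', (500, 375))  # placeholder input image
mask = Image.new('RGB', (500, 375))   # placeholder segmentation mask
img_tensor, gt_tensor = paired_transform(image, mask)
print(img_tensor.shape, gt_tensor.shape)  # [3, 224, 224] and [224, 224, 3]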