├── .gitignore
├── LICENSE
├── MODEL_ZOO.md
├── README.md
├── engine_segfinetune.py
├── inference.py
├── main_segfinetune.py
├── models
│   ├── __init__.py
│   ├── models_convnext.py
│   ├── models_resnet.py
│   ├── models_rfconvnext.py
│   ├── models_vit.py
│   ├── rfconv.py
│   └── rfconvnext.py
└── util
    ├── datasets.py
    ├── lr_decay.py
    ├── lr_sched.py
    ├── metric.py
    ├── misc.py
    ├── pos_embed.py
    └── transforms.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | dataset/
4 | output*/
5 | output_dir/
6 | output_dir*/
7 | ckpts/
8 | *.pth
9 | *.t7
10 | *.png
11 | *.jpg
12 | tmp*.py
13 | *.pdf
14 |
15 | # Byte-compiled / optimized / DLL files
16 | __pycache__/
17 | *.py[cod]
18 | *$py.class
19 |
20 | # C extensions
21 | *.so
22 |
23 | # Distribution / packaging
24 | .Python
25 | build/
26 | develop-eggs/
27 | dist/
28 | downloads/
29 | eggs/
30 | .eggs/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .coverage
55 | .coverage.*
56 | .cache
57 | nosetests.xml
58 | coverage.xml
59 | *.cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 |
72 | # Flask stuff:
73 | instance/
74 | .webassets-cache
75 |
76 | # Scrapy stuff:
77 | .scrapy
78 |
79 | # Sphinx documentation
80 | docs/_build/
81 |
82 | # PyBuilder
83 | target/
84 |
85 | # Jupyter Notebook
86 | .ipynb_checkpoints
87 |
88 | # pyenv
89 | .python-version
90 |
91 | # celery beat schedule file
92 | celerybeat-schedule
93 |
94 | # SageMath parsed files
95 | *.sage.py
96 |
97 | # Environments
98 | .env
99 | .venv
100 | env/
101 | venv/
102 | ENV/
103 | env.bak/
104 | venv.bak/
105 |
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 |
110 | # Rope project settings
111 | .ropeproject
112 |
113 | # mkdocs documentation
114 | /site
115 |
116 | # mypy
117 | .mypy_cache/
118 | .DS_Store
119 |
120 | .vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Attribution-NonCommercial 4.0 International
3 |
4 | =======================================================================
5 |
6 | Creative Commons Corporation ("Creative Commons") is not a law firm and
7 | does not provide legal services or legal advice. Distribution of
8 | Creative Commons public licenses does not create a lawyer-client or
9 | other relationship. Creative Commons makes its licenses and related
10 | information available on an "as-is" basis. Creative Commons gives no
11 | warranties regarding its licenses, any material licensed under their
12 | terms and conditions, or any related information. Creative Commons
13 | disclaims all liability for damages resulting from their use to the
14 | fullest extent possible.
15 |
16 | Using Creative Commons Public Licenses
17 |
18 | Creative Commons public licenses provide a standard set of terms and
19 | conditions that creators and other rights holders may use to share
20 | original works of authorship and other material subject to copyright
21 | and certain other rights specified in the public license below. The
22 | following considerations are for informational purposes only, are not
23 | exhaustive, and do not form part of our licenses.
24 |
25 | Considerations for licensors: Our public licenses are
26 | intended for use by those authorized to give the public
27 | permission to use material in ways otherwise restricted by
28 | copyright and certain other rights. Our licenses are
29 | irrevocable. Licensors should read and understand the terms
30 | and conditions of the license they choose before applying it.
31 | Licensors should also secure all rights necessary before
32 | applying our licenses so that the public can reuse the
33 | material as expected. Licensors should clearly mark any
34 | material not subject to the license. This includes other CC-
35 | licensed material, or material used under an exception or
36 | limitation to copyright. More considerations for licensors:
37 | wiki.creativecommons.org/Considerations_for_licensors
38 |
39 | Considerations for the public: By using one of our public
40 | licenses, a licensor grants the public permission to use the
41 | licensed material under specified terms and conditions. If
42 | the licensor's permission is not necessary for any reason--for
43 | example, because of any applicable exception or limitation to
44 | copyright--then that use is not regulated by the license. Our
45 | licenses grant only permissions under copyright and certain
46 | other rights that a licensor has authority to grant. Use of
47 | the licensed material may still be restricted for other
48 | reasons, including because others have copyright or other
49 | rights in the material. A licensor may make special requests,
50 | such as asking that all changes be marked or described.
51 | Although not required by our licenses, you are encouraged to
52 | respect those requests where reasonable. More_considerations
53 | for the public:
54 | wiki.creativecommons.org/Considerations_for_licensees
55 |
56 | =======================================================================
57 |
58 | Creative Commons Attribution-NonCommercial 4.0 International Public
59 | License
60 |
61 | By exercising the Licensed Rights (defined below), You accept and agree
62 | to be bound by the terms and conditions of this Creative Commons
63 | Attribution-NonCommercial 4.0 International Public License ("Public
64 | License"). To the extent this Public License may be interpreted as a
65 | contract, You are granted the Licensed Rights in consideration of Your
66 | acceptance of these terms and conditions, and the Licensor grants You
67 | such rights in consideration of benefits the Licensor receives from
68 | making the Licensed Material available under these terms and
69 | conditions.
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 | Section 2 -- Scope.
142 |
143 | a. License grant.
144 |
145 | 1. Subject to the terms and conditions of this Public License,
146 | the Licensor hereby grants You a worldwide, royalty-free,
147 | non-sublicensable, non-exclusive, irrevocable license to
148 | exercise the Licensed Rights in the Licensed Material to:
149 |
150 | a. reproduce and Share the Licensed Material, in whole or
151 | in part, for NonCommercial purposes only; and
152 |
153 | b. produce, reproduce, and Share Adapted Material for
154 | NonCommercial purposes only.
155 |
156 | 2. Exceptions and Limitations. For the avoidance of doubt, where
157 | Exceptions and Limitations apply to Your use, this Public
158 | License does not apply, and You do not need to comply with
159 | its terms and conditions.
160 |
161 | 3. Term. The term of this Public License is specified in Section
162 | 6(a).
163 |
164 | 4. Media and formats; technical modifications allowed. The
165 | Licensor authorizes You to exercise the Licensed Rights in
166 | all media and formats whether now known or hereafter created,
167 | and to make technical modifications necessary to do so. The
168 | Licensor waives and/or agrees not to assert any right or
169 | authority to forbid You from making technical modifications
170 | necessary to exercise the Licensed Rights, including
171 | technical modifications necessary to circumvent Effective
172 | Technological Measures. For purposes of this Public License,
173 | simply making modifications authorized by this Section 2(a)
174 | (4) never produces Adapted Material.
175 |
176 | 5. Downstream recipients.
177 |
178 | a. Offer from the Licensor -- Licensed Material. Every
179 | recipient of the Licensed Material automatically
180 | receives an offer from the Licensor to exercise the
181 | Licensed Rights under the terms and conditions of this
182 | Public License.
183 |
184 | b. No downstream restrictions. You may not offer or impose
185 | any additional or different terms or conditions on, or
186 | apply any Effective Technological Measures to, the
187 | Licensed Material if doing so restricts exercise of the
188 | Licensed Rights by any recipient of the Licensed
189 | Material.
190 |
191 | 6. No endorsement. Nothing in this Public License constitutes or
192 | may be construed as permission to assert or imply that You
193 | are, or that Your use of the Licensed Material is, connected
194 | with, or sponsored, endorsed, or granted official status by,
195 | the Licensor or others designated to receive attribution as
196 | provided in Section 3(a)(1)(A)(i).
197 |
198 | b. Other rights.
199 |
200 | 1. Moral rights, such as the right of integrity, are not
201 | licensed under this Public License, nor are publicity,
202 | privacy, and/or other similar personality rights; however, to
203 | the extent possible, the Licensor waives and/or agrees not to
204 | assert any such rights held by the Licensor to the limited
205 | extent necessary to allow You to exercise the Licensed
206 | Rights, but not otherwise.
207 |
208 | 2. Patent and trademark rights are not licensed under this
209 | Public License.
210 |
211 | 3. To the extent possible, the Licensor waives any right to
212 | collect royalties from You for the exercise of the Licensed
213 | Rights, whether directly or through a collecting society
214 | under any voluntary or waivable statutory or compulsory
215 | licensing scheme. In all other cases the Licensor expressly
216 | reserves any right to collect such royalties, including when
217 | the Licensed Material is used other than for NonCommercial
218 | purposes.
219 |
220 | Section 3 -- License Conditions.
221 |
222 | Your exercise of the Licensed Rights is expressly made subject to the
223 | following conditions.
224 |
225 | a. Attribution.
226 |
227 | 1. If You Share the Licensed Material (including in modified
228 | form), You must:
229 |
230 | a. retain the following if it is supplied by the Licensor
231 | with the Licensed Material:
232 |
233 | i. identification of the creator(s) of the Licensed
234 | Material and any others designated to receive
235 | attribution, in any reasonable manner requested by
236 | the Licensor (including by pseudonym if
237 | designated);
238 |
239 | ii. a copyright notice;
240 |
241 | iii. a notice that refers to this Public License;
242 |
243 | iv. a notice that refers to the disclaimer of
244 | warranties;
245 |
246 | v. a URI or hyperlink to the Licensed Material to the
247 | extent reasonably practicable;
248 |
249 | b. indicate if You modified the Licensed Material and
250 | retain an indication of any previous modifications; and
251 |
252 | c. indicate the Licensed Material is licensed under this
253 | Public License, and include the text of, or the URI or
254 | hyperlink to, this Public License.
255 |
256 | 2. You may satisfy the conditions in Section 3(a)(1) in any
257 | reasonable manner based on the medium, means, and context in
258 | which You Share the Licensed Material. For example, it may be
259 | reasonable to satisfy the conditions by providing a URI or
260 | hyperlink to a resource that includes the required
261 | information.
262 |
263 | 3. If requested by the Licensor, You must remove any of the
264 | information required by Section 3(a)(1)(A) to the extent
265 | reasonably practicable.
266 |
267 | 4. If You Share Adapted Material You produce, the Adapter's
268 | License You apply must not prevent recipients of the Adapted
269 | Material from complying with this Public License.
270 |
271 | Section 4 -- Sui Generis Database Rights.
272 |
273 | Where the Licensed Rights include Sui Generis Database Rights that
274 | apply to Your use of the Licensed Material:
275 |
276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
277 | to extract, reuse, reproduce, and Share all or a substantial
278 | portion of the contents of the database for NonCommercial purposes
279 | only;
280 |
281 | b. if You include all or a substantial portion of the database
282 | contents in a database in which You have Sui Generis Database
283 | Rights, then the database in which You have Sui Generis Database
284 | Rights (but not its individual contents) is Adapted Material; and
285 |
286 | c. You must comply with the conditions in Section 3(a) if You Share
287 | all or a substantial portion of the contents of the database.
288 |
289 | For the avoidance of doubt, this Section 4 supplements and does not
290 | replace Your obligations under this Public License where the Licensed
291 | Rights include other Copyright and Similar Rights.
292 |
293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
294 |
295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
305 |
306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
315 |
316 | c. The disclaimer of warranties and limitation of liability provided
317 | above shall be interpreted in a manner that, to the extent
318 | possible, most closely approximates an absolute disclaimer and
319 | waiver of all liability.
320 |
321 | Section 6 -- Term and Termination.
322 |
323 | a. This Public License applies for the term of the Copyright and
324 | Similar Rights licensed here. However, if You fail to comply with
325 | this Public License, then Your rights under this Public License
326 | terminate automatically.
327 |
328 | b. Where Your right to use the Licensed Material has terminated under
329 | Section 6(a), it reinstates:
330 |
331 | 1. automatically as of the date the violation is cured, provided
332 | it is cured within 30 days of Your discovery of the
333 | violation; or
334 |
335 | 2. upon express reinstatement by the Licensor.
336 |
337 | For the avoidance of doubt, this Section 6(b) does not affect any
338 | right the Licensor may have to seek remedies for Your violations
339 | of this Public License.
340 |
341 | c. For the avoidance of doubt, the Licensor may also offer the
342 | Licensed Material under separate terms or conditions or stop
343 | distributing the Licensed Material at any time; however, doing so
344 | will not terminate this Public License.
345 |
346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
347 | License.
348 |
349 | Section 7 -- Other Terms and Conditions.
350 |
351 | a. The Licensor shall not be bound by any additional or different
352 | terms or conditions communicated by You unless expressly agreed.
353 |
354 | b. Any arrangements, understandings, or agreements regarding the
355 | Licensed Material not stated herein are separate from and
356 | independent of the terms and conditions of this Public License.
357 |
358 | Section 8 -- Interpretation.
359 |
360 | a. For the avoidance of doubt, this Public License does not, and
361 | shall not be interpreted to, reduce, limit, restrict, or impose
362 | conditions on any use of the Licensed Material that could lawfully
363 | be made without permission under this Public License.
364 |
365 | b. To the extent possible, if any provision of this Public License is
366 | deemed unenforceable, it shall be automatically reformed to the
367 | minimum extent necessary to make it enforceable. If the provision
368 | cannot be reformed, it shall be severed from this Public License
369 | without affecting the enforceability of the remaining terms and
370 | conditions.
371 |
372 | c. No term or condition of this Public License will be waived and no
373 | failure to comply consented to unless expressly agreed to by the
374 | Licensor.
375 |
376 | d. Nothing in this Public License constitutes or may be interpreted
377 | as a limitation upon, or waiver of, any privileges and immunities
378 | that apply to the Licensor or You, including from the legal
379 | processes of any jurisdiction or authority.
380 |
381 | =======================================================================
382 |
383 | Creative Commons is not a party to its public
384 | licenses. Notwithstanding, Creative Commons may elect to apply one of
385 | its public licenses to material it publishes and in those instances
386 | will be considered the “Licensor.” The text of the Creative Commons
387 | public licenses is dedicated to the public domain under the CC0 Public
388 | Domain Dedication. Except for the limited purpose of indicating that
389 | material is shared under a Creative Commons public license or as
390 | otherwise permitted by the Creative Commons policies published at
391 | creativecommons.org/policies, Creative Commons does not authorize the
392 | use of the trademark "Creative Commons" or any other trademark or logo
393 | of Creative Commons without its prior written consent including,
394 | without limitation, in connection with any unauthorized modifications
395 | to any of its public licenses or any other arrangements,
396 | understandings, or agreements concerning use of licensed material. For
397 | the avoidance of doubt, this paragraph does not form part of the
398 | public licenses.
399 |
400 | Creative Commons may be contacted at creativecommons.org.
--------------------------------------------------------------------------------
/MODEL_ZOO.md:
--------------------------------------------------------------------------------
 1 | # Model Zoo for Semi-supervised Semantic Segmentation on ImageNet-S
2 |
 3 | [Finetuning with ViT](#finetuning-with-vit)
 4 |
 5 | [Finetuning with ResNet](#finetuning-with-resnet)
 6 |
 7 | [Finetuning with RF-ConvNeXt](#finetuning-with-rf-convnext)
8 |
9 |
10 |
11 |
12 | ## Finetuning with ViT
13 |
14 |
15 | | Method | Arch | Pretraining epochs | Pretraining mode | val | test | Pretrained | Finetuned |
16 | | :----: | :--: | :----------------: | :--------------: | :-: | :--: | :--------: | :-------: |
17 | | MAE | ViT-B/16 | 1600 | SSL | 38.3 | 37.0 | model | model |
18 | | MAE | ViT-B/16 | 1600 | SSL+Sup | 61.0 | 60.2 | model | model |
19 | | SERE | ViT-S/16 | 100 | SSL | 41.0 | 40.2 | model | model |
20 | | SERE | ViT-S/16 | 100 | SSL+Sup | 58.9 | 57.8 | model | model |
63 |
64 |
65 |
66 | ### Masked Autoencoders Are Scalable Vision Learners (MAE)
67 |
68 |
69 | Command for SSL+Sup
70 |
71 | ```shell
72 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
73 | --accum_iter 1 \
74 | --batch_size 32 \
75 | --model vit_base_patch16 \
76 | --finetune mae_finetuned_vit_base.pth \
77 | --epochs 100 \
78 | --nb_classes 920 \
79 | --blr 1e-4 --layer_decay 0.40 \
80 | --weight_decay 0.05 --drop_path 0.1 \
81 | --data_path ${IMAGENETS_DIR} \
82 | --output_dir ${OUTPATH} \
83 | --dist_eval
84 | ```
85 |
86 |
87 |
88 |
89 | Command for SSL
90 |
91 | ```shell
92 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
93 | --accum_iter 1 \
94 | --batch_size 32 \
95 | --model vit_base_patch16 \
96 | --finetune mae_pretrain_vit_base.pth \
97 | --epochs 100 \
98 | --nb_classes 920 \
99 | --blr 5e-4 --layer_decay 0.60 \
100 | --weight_decay 0.05 --drop_path 0.1 \
101 | --data_path ${IMAGENETS_DIR} \
102 | --output_dir ${OUTPATH} \
103 | --dist_eval
104 | ```
105 |
106 |
107 |
108 | ### SERE: Exploring Feature Self-relation for Self-supervised Transformer
109 |
110 |
111 | Command for SSL+Sup
112 |
113 | ```shell
114 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
115 | --accum_iter 1 \
116 | --batch_size 32 \
117 | --model vit_small_patch16 \
118 | --finetune sere_finetuned_vit_small_ep100.pth \
119 | --epochs 100 \
120 | --nb_classes 920 \
121 | --blr 5e-4 --layer_decay 0.50 \
122 | --weight_decay 0.05 --drop_path 0.1 \
123 | --data_path ${IMAGENETS_DIR} \
124 | --output_dir ${OUTPATH} \
125 | --dist_eval
126 | ```
127 |
128 |
129 |
130 |
131 | Command for SSL
132 |
133 | ```shell
134 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
135 | --accum_iter 1 \
136 | --batch_size 32 \
137 | --model vit_small_patch16 \
138 | --finetune sere_pretrained_vit_small_ep100.pth \
139 | --epochs 100 \
140 | --nb_classes 920 \
141 | --blr 5e-4 --layer_decay 0.50 \
142 | --weight_decay 0.05 --drop_path 0.1 \
143 | --data_path ${IMAGENETS_DIR} \
144 | --output_dir ${OUTPATH} \
145 | --dist_eval
146 | ```
147 |
148 |
149 |
150 |
151 |
152 | ## Finetuning with ResNet
153 |
154 |
155 | | Method | Arch | Pretraining epochs | Pretraining mode | val | test | Pretrained | Finetuned |
156 | | :----: | :--: | :----------------: | :--------------: | :-: | :--: | :--------: | :-------: |
157 | | PASS | ResNet-50 D32 | 100 | SSL | 21.0 | 20.3 | model | model |
158 | | PASS | ResNet-50 D16 | 100 | SSL | 21.6 | 20.8 | model | model |
184 |
185 |
186 |
187 | `D16` means the output stride is 16, obtained with dilation=2 in the last stage. These results are better than those reported in the paper thanks to the new training scripts.
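
For illustration only (not necessarily the exact code in `models_resnet.py`), an output stride of 16 can be obtained in torchvision by replacing the last-stage stride with dilation:

```python
import torchvision

# Sketch: the default ResNet-50 has output stride 32 (D32); replacing the
# stride of the last stage (layer4) with dilation=2 gives output stride 16 (D16).
backbone_d32 = torchvision.models.resnet50()
backbone_d16 = torchvision.models.resnet50(
    replace_stride_with_dilation=[False, False, True])
```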
188 |
189 | ### Large-scale Unsupervised Semantic Segmentation (PASS)
190 |
191 | Command for SSL (ResNet-50 D32)
192 |
193 | ```shell
194 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
195 | --accum_iter 1 \
196 | --batch_size 32 \
197 | --model resnet50 \
198 | --finetune pass919_pretrained.pth.tar \
199 | --epochs 100 \
200 | --nb_classes 920 \
201 | --blr 5e-4 --layer_decay 0.4 \
202 | --weight_decay 0.0005 \
203 | --data_path ${IMAGENETS_DIR} \
204 | --output_dir ${OUTPATH} \
205 | --dist_eval
206 | ```
207 |
208 |
209 |
210 | Command for SSL (ResNet-50 D16)
211 |
212 | ```shell
213 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
214 | --accum_iter 1 \
215 | --batch_size 32 \
216 | --model resnet50_d16 \
217 | --finetune pass919_pretrained.pth.tar \
218 | --epochs 100 \
219 | --nb_classes 920 \
220 | --blr 5e-4 --layer_decay 0.45 \
221 | --weight_decay 0.0005 \
222 | --data_path ${IMAGENETS_DIR} \
223 | --output_dir ${OUTPATH} \
224 | --dist_eval
225 | ```
226 |
227 |
228 |
229 |
230 |
231 | ## Finetuning with RF-ConvNeXt
232 |
233 |
234 |
235 | | Arch | Pretraining epochs | RF-Next mode | val | test | Pretrained | Searched | Finetuned |
236 | | :--: | :----------------: | :----------: | :-: | :--: | :--------: | :------: | :-------: |
237 | | ConvNeXt-T | 300 | - | 48.7 | 48.8 | model | - | model |
238 | | RF-ConvNeXt-T | 300 | rfsingle | 50.7 | 50.5 | model | model | model |
239 | | RF-ConvNeXt-T | 300 | rfmultiple | 50.8 | 50.5 | model | model | model |
240 | | RF-ConvNeXt-T | 300 | rfmerge | 51.3 | 51.1 | model | model | model |
284 |
285 |
286 |
287 |
288 | Command for ConvNeXt-T
289 |
290 | ```shell
291 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
292 | --accum_iter 1 \
293 | --batch_size 32 \
294 | --model convnext_tiny \
295 | --patch_size 4 \
296 | --finetune convnext_tiny_1k_224_ema.pth \
297 | --epochs 100 \
298 | --nb_classes 920 \
299 | --blr 2.5e-4 --layer_decay 0.6 \
300 | --weight_decay 0.05 --drop_path 0.2 \
301 | --data_path ${IMAGENETS_DIR} \
302 | --output_dir ${OUTPATH} \
303 | --dist_eval
304 | ```
305 |
306 |
307 | Before training RF-ConvNeXt,
308 | please first search the dilation rates with the rfsearch mode.
309 |
310 | For rfmultiple and rfsingle, please set `pretrained_rfnext`
311 | to the weights trained with rfsearch.
312 |
313 | For rfmerge, we initialize the model with the rfmultiple weights and only finetune `seg_norm`, `seg_head`, and the `rfconvs` whose dilation rates have changed.
314 | The other parts of the network are frozen.
315 | Please set `pretrained_rfnext`
316 | to the weights trained with rfmultiple.
317 |
318 | **Note that this freezing operation in rfmerge may not be required for other tasks.**
319 |
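For illustration, the rfmerge-style freezing could look like the sketch below (the keyword list is an assumption based on the module names above; the actual code additionally restricts `rfconvs` to those whose dilation rates changed):

```python
# Hypothetical sketch: keep only the finetuned parts trainable, freeze the rest.
TRAINABLE_KEYWORDS = ('seg_norm', 'seg_head', 'rfconv')  # assumed name fragments


def freeze_for_rfmerge(model):
    for name, param in model.named_parameters():
        param.requires_grad = any(key in name for key in TRAINABLE_KEYWORDS)
```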
320 |
321 | Command for RF-ConvNeXt-T (rfsearch)
322 |
323 | ```shell
324 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
325 | --accum_iter 1 \
326 | --batch_size 32 \
327 | --model rfconvnext_tiny_rfsearch \
328 | --patch_size 4 \
329 | --finetune convnext_tiny_1k_224_ema.pth \
330 | --epochs 100 \
331 | --nb_classes 920 \
332 | --blr 2.5e-4 --layer_decay 0.6 0.9 --layer_multiplier 1.0 10.0 \
333 | --weight_decay 0.05 --drop_path 0.2 \
334 | --data_path ${IMAGENETS_DIR} \
335 | --output_dir ${OUTPATH} \
336 | --dist_eval
337 | ```
338 |
339 |
340 |
341 | Command for RF-ConvNeXt-T (rfsingle)
342 |
343 | ```shell
344 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
345 | --accum_iter 1 \
346 | --batch_size 32 \
347 | --model rfconvnext_tiny_rfsingle \
348 | --patch_size 4 \
349 | --finetune convnext_tiny_1k_224_ema.pth \
350 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \
351 | --epochs 100 \
352 | --nb_classes 920 \
353 | --blr 2.5e-4 --layer_decay 0.6 0.9 --layer_multiplier 1.0 10.0 \
354 | --weight_decay 0.05 --drop_path 0.2 \
355 | --data_path ${IMAGENETS_DIR} \
356 | --output_dir ${OUTPATH} \
357 | --dist_eval
358 |
359 | python inference.py --model rfconvnext_tiny_rfsingle \
360 | --patch_size 4 \
361 | --nb_classes 920 \
362 | --output_dir ${OUTPATH}/predictions \
363 | --data_path ${IMAGENETS_DIR} \
364 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \
365 | --finetune ${OUTPATH}/checkpoint-99.pth \
366 | --mode validation
367 | ```
368 |
369 |
370 |
371 | Command for RF-ConvNeXt-T (rfmultiple)
372 |
373 | ```shell
374 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
375 | --accum_iter 1 \
376 | --batch_size 32 \
377 | --model rfconvnext_tiny_rfmultiple \
378 | --patch_size 4 \
379 | --finetune convnext_tiny_1k_224_ema.pth \
380 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \
381 | --epochs 100 \
382 | --nb_classes 920 \
383 | --blr 2.5e-4 --layer_decay 0.55 0.9 --layer_multiplier 1.0 10.0 \
384 | --weight_decay 0.05 --drop_path 0.1 \
385 | --data_path ${IMAGENETS_DIR} \
386 | --output_dir ${OUTPATH} \
387 | --dist_eval
388 |
389 | python inference.py --model rfconvnext_tiny_rfmultiple \
390 | --patch_size 4 \
391 | --nb_classes 920 \
392 | --output_dir ${OUTPATH}/predictions \
393 | --data_path ${IMAGENETS_DIR} \
394 | --pretrained_rfnext ${OUTPATH_OF_RFSEARCH}/checkpoint-99.pth \
395 | --finetune ${OUTPATH}/checkpoint-99.pth \
396 | --mode validation
397 | ```
398 |
399 |
400 |
401 |
402 | Command for RF-ConvNeXt-T (rfmerge)
403 |
404 | ```shell
405 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
406 | --accum_iter 1 \
407 | --batch_size 32 \
408 | --model rfconvnext_tiny_rfmerge \
409 | --patch_size 4 \
410 | --pretrained_rfnext ${OUTPATH_OF_RFMULTIPLE}/checkpoint-99.pth \
411 | --epochs 100 \
412 | --nb_classes 920 \
413 | --blr 2.5e-4 --layer_decay 0.55 1.0 --layer_multiplier 1.0 10.0 \
414 | --weight_decay 0.05 --drop_path 0.2 \
415 | --data_path ${IMAGENETS_DIR} \
416 | --output_dir ${OUTPATH} \
417 | --dist_eval
418 |
419 | python inference.py --model rfconvnext_tiny_rfmerge \
420 | --patch_size 4 \
421 | --nb_classes 920 \
422 | --output_dir ${OUTPATH}/predictions \
423 | --data_path ${IMAGENETS_DIR} \
424 | --pretrained_rfnext ${OUTPATH_OF_RFMULTIPLE}/checkpoint-99.pth \
425 | --finetune ${OUTPATH}/checkpoint-99.pth \
426 | --mode validation
427 | ```
428 |
429 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semi-supervised Semantic Segmentation on the ImageNet-S dataset
2 |
 3 | This repo provides the code for semi-supervised training of large-scale semantic segmentation models on the ImageNet-S dataset.
4 |
5 | ## About ImageNet-S
6 | Based on the ImageNet dataset, the ImageNet-S dataset has 1.2 million training images and 50k high-quality semantic segmentation annotations to
 7 | support unsupervised/semi-supervised semantic segmentation on ImageNet. The ImageNet-S dataset is available at [ImageNet-S](https://github.com/LUSSeg/ImageNet-S). For more details about the dataset, please refer to the [project page](https://LUSSeg.github.io/) or the [paper](https://arxiv.org/abs/2106.03149).
8 |
9 |
10 |
11 | ## Usage
12 | - Semi-supervised finetuning with pre-trained checkpoints
13 | ```
14 | python -m torch.distributed.launch --nproc_per_node=8 main_segfinetune.py \
15 | --accum_iter 1 \
16 | --batch_size 32 \
17 | --model vit_small_patch16 \
18 | --finetune ${PRETRAIN_CHKPT} \
19 | --epochs 100 \
20 | --nb_classes 920 | 301 | 51 \
21 | --blr 5e-4 --layer_decay 0.50 \
22 | --weight_decay 0.05 --drop_path 0.1 \
23 | --data_path ${IMAGENETS_DIR} \
24 | --output_dir ${OUTPATH} \
25 | --dist_eval
26 | ```
27 | Note: Set `--nb_classes` to 920, 301, or 51 when training on the 919-, 300-, or 50-class ImageNet-S subset, respectively. To train with one GPU, change `--nproc_per_node=8` to `--nproc_per_node=1` and `--accum_iter 1` to `--accum_iter 8`, which keeps the same effective batch size.
28 | - Get the zip file for the testing set. You can submit it to our [online server](https://lusseg.github.io/).
29 | ```
30 | python inference.py --model vit_small_patch16 \
31 | --nb_classes 920 | 301 | 51 \
32 | --output_dir ${OUTPATH}/predictions \
33 | --data_path ${IMAGENETS_DIR} \
34 | --finetune ${OUTPATH}/checkpoint-99.pth \
35 | --mode validation | test
36 | ```
37 |
38 | ## Model Zoo
39 | **[Model Zoo](MODEL_ZOO.md)**:
40 | We provide a model zoo to record the trend of semi-supervised semantic segmentation on the ImageNet-S dataset.
41 | For now, this repo supports ViT, ResNet, ConvNeXt, and RF-ConvNeXt backbones, and more backbones and pretrained models will be added.
42 | Please open a pull request if you want to update your new results.
43 |
44 | Supported networks: ViT, ResNet, ConvNeXt, RF-ConvNeXt
45 |
46 | Supported pretrain: MAE, SERE, PASS
47 |
48 | ## Citation
49 | ```
50 | @article{gao2021luss,
51 | title={Large-scale Unsupervised Semantic Segmentation},
52 | author={Gao, Shanghua and Li, Zhong-Yu and Yang, Ming-Hsuan and Cheng, Ming-Ming and Han, Junwei and Torr, Philip},
53 | journal={arXiv preprint arXiv:2106.03149},
54 | year={2021}
55 | }
56 | ```
57 |
58 | ## Acknowledgement
59 |
60 | This codebase is built on top of the [MAE codebase](https://github.com/facebookresearch/mae).
61 |
--------------------------------------------------------------------------------
/engine_segfinetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import math
13 | import sys
14 | import warnings
15 | from typing import Iterable
16 |
17 | import torch
18 | import torch.distributed as dist
19 | import torch.nn.functional as F
20 | from torch.distributed import ReduceOp
21 |
22 | import util.lr_sched as lr_sched
23 | import util.misc as misc
24 | from util.metric import FMeasureGPU, IoUGPU
25 |
26 |
27 | def train_one_epoch(model: torch.nn.Module,
28 | criterion: torch.nn.Module,
29 | data_loader: Iterable,
30 | optimizer: torch.optim.Optimizer,
31 | device: torch.device,
32 | epoch: int,
33 | loss_scaler,
34 | max_norm: float = 0,
35 | args=None):
36 | model.train(True)
37 | metric_logger = misc.MetricLogger(delimiter=' ')
38 | metric_logger.add_meter(
39 | 'lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
40 | header = 'Epoch: [{}]'.format(epoch)
41 | print_freq = 20
42 |
43 | accum_iter = args.accum_iter
44 |
45 | optimizer.zero_grad()
46 |
47 | for data_iter_step, (samples, targets) in enumerate(
48 | metric_logger.log_every(data_loader, print_freq, header)):
49 |
50 | # we use a per iteration (instead of per epoch) lr scheduler
51 | if data_iter_step % accum_iter == 0:
52 | lr_sched.adjust_learning_rate(
53 | optimizer, data_iter_step / len(data_loader) + epoch, args)
54 |
55 | samples = samples.to(device, non_blocking=True)
56 | targets = targets.to(device, non_blocking=True)
57 |
58 | with torch.cuda.amp.autocast():
59 | outputs = model(samples)
60 |
61 | outputs = torch.nn.functional.interpolate(outputs,
62 | scale_factor=2,
63 | align_corners=False,
64 | mode='bilinear')
65 | targets = torch.nn.functional.interpolate(
66 | targets.unsqueeze(1),
67 | size=(outputs.shape[2], outputs.shape[3]),
68 | mode='nearest').squeeze(1)
69 | loss = criterion(outputs, targets.long())
70 |
71 | loss_value = loss.item()
72 |
73 | if not math.isfinite(loss_value):
74 | print('Loss is {}, stopping training'.format(loss_value))
75 | sys.exit(1)
76 |
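        # Gradient accumulation: the loss is scaled by 1/accum_iter and the
        # optimizer only steps (update_grad=True) every accum_iter iterations.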
77 | loss /= accum_iter
78 | loss_scaler(loss,
79 | optimizer,
80 | clip_grad=max_norm,
81 | parameters=model.parameters(),
82 | create_graph=False,
83 | update_grad=(data_iter_step + 1) % accum_iter == 0)
84 | if (data_iter_step + 1) % accum_iter == 0:
85 | optimizer.zero_grad()
86 |
87 | torch.cuda.synchronize()
88 |
89 | metric_logger.update(loss=loss_value)
90 | max_lr = 0.
91 | for group in optimizer.param_groups:
92 | max_lr = max(max_lr, group['lr'])
93 |
94 | metric_logger.update(lr=max_lr)
95 |
96 | # gather the stats from all processes
97 | metric_logger.synchronize_between_processes()
98 | print('Averaged stats:', metric_logger)
99 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
100 |
101 |
102 | @torch.no_grad()
103 | def evaluate(data_loader, model, device, num_classes, max_res=1000):
104 | metric_logger = misc.MetricLogger(delimiter=' ')
105 | header = 'Test:'
106 |
107 | T = torch.zeros(size=(num_classes, )).cuda()
108 | P = torch.zeros(size=(num_classes, )).cuda()
109 | TP = torch.zeros(size=(num_classes, )).cuda()
110 | IoU = torch.zeros(size=(num_classes, )).cuda()
111 | FMeasure = 0.
112 |
113 | # switch to evaluation mode
114 | model.eval()
115 |
116 | for batch in metric_logger.log_every(data_loader, 100, header):
117 | images = batch[0]
118 | target = batch[-1]
119 | images = images.to(device, non_blocking=True)
120 | target = target.to(device, non_blocking=True)
121 |
122 | # compute output
123 | with torch.no_grad():
124 | output = model(images)
125 |
126 | # process an image with a large resolution
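        # If H * W exceeds max_res * max_res, take the argmax on a bilinearly
        # downscaled prediction (long side capped at max_res) and upsample the
        # resulting label map back to (H, W) with nearest-neighbor interpolation.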
127 | H, W = target.shape[1], target.shape[2]
128 | if (H > W and H * W > max_res * max_res
129 | and max_res > 0):
130 | output = F.interpolate(output, (max_res, int(max_res * W / H)),
131 | mode='bilinear',
132 | align_corners=False)
133 | output = torch.argmax(output, dim=1, keepdim=True)
134 | output = F.interpolate(output.float(), (H, W),
135 | mode='nearest').long()
136 | elif (H <= W and H * W > max_res * max_res
137 | and max_res > 0):
138 | output = F.interpolate(output, (int(max_res * H / W), max_res),
139 | mode='bilinear', align_corners=False)
140 | output = torch.argmax(output, dim=1, keepdim=True)
141 | output = F.interpolate(output.float(), (H, W),
142 | mode='nearest').long()
143 | else:
144 | output = F.interpolate(output, (H, W),
145 | mode='bilinear',
146 | align_corners=False)
147 | output = torch.argmax(output, dim=1, keepdim=True)
148 |
149 | target = target.view(-1)
150 | output = output.view(-1)
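        # Label 1000 marks ignored (unannotated) pixels in ImageNet-S;
        # exclude them from the IoU and F-measure computation.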
151 | mask = target != 1000
152 | target = target[mask]
153 | output = output[mask]
154 |
155 | area_intersection, area_output, area_target = IoUGPU(
156 | output, target, num_classes)
157 | f_score = FMeasureGPU(output, target)
158 |
159 | T += area_output
160 | P += area_target
161 | TP += area_intersection
162 | FMeasure += f_score
163 |
164 | metric_logger.synchronize_between_processes()
165 |
166 | # gather the stats from all processes
167 | dist.barrier()
168 | dist.all_reduce(T, op=ReduceOp.SUM)
169 | dist.all_reduce(P, op=ReduceOp.SUM)
170 | dist.all_reduce(TP, op=ReduceOp.SUM)
171 | dist.all_reduce(FMeasure, op=ReduceOp.SUM)
172 |
173 | IoU = TP / (T + P - TP + 1e-10) * 100
174 | FMeasure = FMeasure / len(data_loader.dataset)
175 |
176 | mIoU = torch.mean(IoU).item()
177 | FMeasure = FMeasure.item() * 100
178 |
179 | log = {}
180 | log['mIoU'] = mIoU
181 | log['IoUs'] = IoU.tolist()
182 | log['FMeasure'] = FMeasure
183 |
184 | print('* mIoU {mIoU:.3f} FMeasure {FMeasure:.3f}'.format(
185 | mIoU=mIoU, FMeasure=FMeasure))
186 |
187 | return log
188 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from PIL import Image
9 | from torchvision import datasets, transforms
10 | from tqdm import tqdm
11 |
12 | import models
13 |
14 |
15 | class SegmentationFolder(datasets.ImageFolder):
16 | def __getitem__(self, index):
17 | path = self.imgs[index][0]
18 | sample = self.loader(path)
19 | height, width = sample.size[1], sample.size[0]
20 |
21 | if self.transform is not None:
22 | sample = self.transform(sample)
23 | return sample, path, height, width
24 |
25 |
26 | def parse_args():
27 | parser = argparse.ArgumentParser(description='Inference')
28 | parser.add_argument('--nb_classes', type=int, default=50)
29 | parser.add_argument('--mode',
30 | type=str,
31 | required=True,
32 | help='validation or test',
33 | choices=['validation', 'test'])
34 | parser.add_argument('--output_dir',
35 | type=str,
36 | default=None,
37 | help='the path to save segmentation masks')
38 | parser.add_argument('--data_path',
39 | type=str,
40 | default=None,
41 | help='path to imagenetS dataset')
42 | parser.add_argument('--finetune',
43 | type=str,
44 | default=None,
45 | help='the model checkpoint file')
46 | parser.add_argument('--pretrained_rfnext',
47 | default='',
48 | help='pretrained weights for RF-Next')
49 | parser.add_argument('--model',
50 | default='vit_small_patch16',
51 | help='model architecture')
52 | parser.add_argument('--patch_size',
53 | type=int,
54 | default=4,
55 |                         help='For convnext/rfconvnext, the number of output channels is '
56 |                         'nb_classes * patch_size * patch_size. '
57 |                         'See https://arxiv.org/pdf/2111.06377.pdf')
58 | parser.add_argument(
59 | '--max_res',
60 | default=1000,
61 | type=int,
62 | help='Maximum resolution for evaluation. 0 for disable.')
63 | parser.add_argument('--method',
64 | default='example submission',
65 | help='Method name in method description file(.txt).')
66 | parser.add_argument('--train_data',
67 | default='null',
68 | help='Training data in method description file(.txt).')
69 | parser.add_argument(
70 | '--train_scheme',
71 | default='null',
72 | help='Training scheme in method description file(.txt), \
73 | e.g., SSL, Sup, SSL+Sup.')
74 | parser.add_argument(
75 | '--link',
76 | default='null',
77 | help='Paper/project link in method description file(.txt).')
78 | parser.add_argument(
79 | '--description',
80 | default='null',
81 | help='Method description in method description file(.txt).')
82 | args = parser.parse_args()
83 | return args
84 |
85 |
86 | def main():
87 | args = parse_args()
88 |
89 | # build model
90 | model = models.__dict__[args.model](args)
91 | model = model.cuda()
92 | model.eval()
93 |
94 | # load checkpoints
95 | checkpoint = torch.load(args.finetune)['model']
96 | model.load_state_dict(checkpoint, strict=True)
97 | # build the dataloader
98 | dataset_path = os.path.join(args.data_path, args.mode)
99 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
100 | std=[0.229, 0.224, 0.225])
101 | dataset = SegmentationFolder(root=dataset_path,
102 | transform=transforms.Compose([
103 | transforms.Resize(256),
104 | transforms.ToTensor(),
105 | normalize,
106 | ]))
107 | dataloader = torch.utils.data.DataLoader(dataset,
108 | batch_size=1,
109 | num_workers=16,
110 | pin_memory=True)
111 |
112 | output_dir = os.path.join(args.output_dir, args.mode)
113 |
114 | for images, path, height, width in tqdm(dataloader):
115 | path = path[0]
116 | cate = path.split('/')[-2]
117 | name = path.split('/')[-1].split('.')[0]
118 | if not os.path.exists(os.path.join(output_dir, cate)):
119 | os.makedirs(os.path.join(output_dir, cate))
120 |
121 | with torch.no_grad():
122 | H = height.item()
123 | W = width.item()
124 |
125 | output = model.forward(images.cuda())
126 |
127 | if (H > W and H * W > args.max_res * args.max_res
128 | and args.max_res > 0):
129 | output = F.interpolate(
130 | output, (args.max_res, int(args.max_res * W / H)),
131 | mode='bilinear',
132 | align_corners=False)
133 | output = torch.argmax(output, dim=1, keepdim=True)
134 | output = F.interpolate(output.float(), (H, W),
135 | mode='nearest').long()
136 | elif (H <= W and H * W > args.max_res * args.max_res
137 | and args.max_res > 0):
138 | output = F.interpolate(
139 | output, (int(args.max_res * H / W), args.max_res),
140 | mode='bilinear',
141 | align_corners=False)
142 | output = torch.argmax(output, dim=1, keepdim=True)
143 | output = F.interpolate(output.float(), (H, W),
144 | mode='nearest').long()
145 | else:
146 | output = F.interpolate(output, (H, W),
147 | mode='bilinear',
148 | align_corners=False)
149 | output = torch.argmax(output, dim=1, keepdim=True)
150 | output = output.squeeze()
151 |
152 | res = torch.zeros(size=(output.shape[0], output.shape[1], 3))
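            # Encode each predicted class id into an RGB mask:
            # R = id % 256, G = id // 256 (B is unused).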
153 | res[:, :, 0] = output % 256
154 | res[:, :, 1] = output // 256
155 | res = res.cpu().numpy()
156 |
157 | res = Image.fromarray(res.astype(np.uint8))
158 | res.save(os.path.join(output_dir, cate, name + '.png'))
159 |
160 | if args.mode == 'test':
161 | method = 'Method name: {}\n'.format(
162 | args.method) + \
163 | 'Training data: {}\nTraining scheme: {}\n'.format(
164 | args.train_data, args.train_scheme) + \
165 | 'Networks: {}\nPaper/Project link: {}\n'.format(
166 | args.model, args.link) + \
167 | 'Method description: {}'.format(
168 | args.description)
169 | with open(os.path.join(output_dir, 'method.txt'), 'w') as f:
170 | f.write(method)
171 |
172 | # zip for submission
173 | shutil.make_archive(os.path.join(args.output_dir, args.mode),
174 | 'zip',
175 | root_dir=output_dir)
176 |
177 |
178 | if __name__ == '__main__':
179 | main()
180 |
--------------------------------------------------------------------------------
/main_segfinetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import argparse
13 | import datetime
14 | import json
15 | import os
16 | import time
17 | from pathlib import Path
18 |
19 | import numpy as np
20 | import timm
21 | import torch
22 | import torch.backends.cudnn as cudnn
23 | from timm.models.layers import trunc_normal_
24 |
25 | import models
26 | import util.lr_decay as lrd
27 | import util.misc as misc
28 | from engine_segfinetune import evaluate, train_one_epoch
29 | from util.datasets import build_dataset
30 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
31 | from util.pos_embed import interpolate_pos_embed
32 | from timm.models.convnext import checkpoint_filter_fn
33 |
34 |
35 | def get_args_parser():
36 | parser = argparse.ArgumentParser(
37 | 'Semi-supervised fine-tuning for '
38 | 'the semantic segmentation on the ImageNet-S dataset',
39 | add_help=False)
40 | parser.add_argument(
41 | '--batch_size',
42 | default=64,
43 | type=int,
44 | help='Batch size per GPU '
45 |         '(effective batch size is batch_size * accum_iter * # gpus)')
46 | parser.add_argument('--epochs', default=50, type=int)
47 | parser.add_argument(
48 | '--accum_iter',
49 | default=1,
50 | type=int,
51 | help='Accumulate gradient iterations '
52 | '(for increasing the effective batch size under memory constraints)')
53 | parser.add_argument('--saveckp_freq',
54 | default=20,
55 | type=int,
56 | help='Save checkpoint every x epochs.')
57 | parser.add_argument('--eval_freq',
58 | default=20,
59 | type=int,
60 | help='Evaluate the model every x epochs.')
61 | parser.add_argument(
62 | '--max_res',
63 | default=1000,
64 | type=int,
65 | help='Maximum resolution for evaluation. 0 for disable.')
66 |
67 | # Model parameters
68 | parser.add_argument('--model',
69 | default='vit_small_patch16',
70 | type=str,
71 | metavar='MODEL',
72 | help='Name of model to train')
73 | parser.add_argument('--drop_path',
74 | type=float,
75 | default=0.1,
76 | metavar='PCT',
77 | help='Drop path rate (default: 0.1)')
78 | parser.add_argument('--patch_size',
79 | type=int,
80 | default=4,
81 |                         help='For convnext/rfconvnext, the number of output channels is '
82 |                         'nb_classes * patch_size * patch_size. '
83 |                         'See https://arxiv.org/pdf/2111.06377.pdf')
84 |
85 | # Optimizer parameters
86 | parser.add_argument('--clip_grad',
87 | type=float,
88 | default=None,
89 | metavar='NORM',
90 | help='Clip gradient norm (default: None, no clipping)')
91 | parser.add_argument('--weight_decay',
92 | type=float,
93 | default=0.05,
94 | help='weight decay (default: 0.05)')
95 |
96 | parser.add_argument('--lr',
97 | type=float,
98 | default=None,
99 | metavar='LR',
100 | help='learning rate (absolute lr)')
101 | parser.add_argument('--blr',
102 | type=float,
103 | default=1e-3,
104 | metavar='LR',
105 | help='base learning rate: '
106 | 'absolute_lr = base_lr * total_batch_size / 256')
107 | parser.add_argument('--layer_decay',
108 | type=float,
109 | default=[0.75],
110 | nargs="+",
111 |                         help='layer-wise lr decay from ELECTRA/BEiT. '
112 |                         'For each layer, the function get_layer_id in util.lr_decay '
113 | 'returns (layer_group, layer_id). '
114 | 'According to the layer_group, different parameters are grouped, '
115 | 'and the layer_decay[layer_group] is used as the decay rate for different groups.')
116 | parser.add_argument('--layer_multiplier',
117 | type=float,
118 | default=[1.0],
119 | nargs="+",
120 | help='The learning rate multipliers for different layers. '
121 |                         'For each layer, the function get_layer_id in util.lr_decay '
122 | 'returns (layer_group, layer_id). '
123 | 'According to the layer_group, different parameters are grouped, '
124 | 'and the learning rate of each group is lr = lr * layer_multiplier[layer_group].')
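    # Example: `--layer_decay 0.6 0.9 --layer_multiplier 1.0 10.0` defines two
    # parameter groups: group 0 decays with 0.6 and multiplier 1.0, group 1
    # decays with 0.9 and multiplier 10.0; the group a parameter falls into is
    # decided by get_layer_id in util/lr_decay.py.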
125 | parser.add_argument('--min_lr',
126 | type=float,
127 | default=1e-6,
128 | metavar='LR',
129 | help='lower lr bound for cyclic schedulers that hit 0')
130 | parser.add_argument('--warmup_epochs',
131 | type=int,
132 | default=5,
133 | metavar='N',
134 | help='epochs to warmup LR')
135 |
136 | # Augmentation parameters
137 | parser.add_argument(
138 | '--color_jitter',
139 | type=float,
140 | default=None,
141 | metavar='PCT',
142 | help='Color jitter factor (enabled only when not using Auto/RandAug)')
143 |
144 | # * Finetuning params
145 | parser.add_argument('--finetune',
146 | default='',
147 | help='finetune from checkpoint')
148 | parser.add_argument('--pretrained_rfnext',
149 | default='',
150 | help='pretrained weights for RF-Next')
151 |
152 | # Dataset parameters
153 | parser.add_argument('--data_path',
154 | default='/datasets01/imagenet_full_size/061417/',
155 | type=str,
156 | help='dataset path')
157 | parser.add_argument('--iteration_one_epoch',
158 | default=-1,
159 | type=int,
160 | help='number of iterations in one epoch')
161 | parser.add_argument('--nb_classes',
162 | default=1000,
163 | type=int,
164 | help='number of the classification types')
165 |
166 | parser.add_argument('--output_dir',
167 | default=None,
168 | help='path where to save, empty for no saving')
169 | parser.add_argument('--device',
170 | default='cuda',
171 | help='device to use for training / testing')
172 | parser.add_argument('--seed', default=0, type=int)
173 | parser.add_argument('--resume', default='', help='resume from checkpoint')
174 |
175 | parser.add_argument('--start_epoch',
176 | default=0,
177 | type=int,
178 | metavar='N',
179 | help='start epoch')
180 | parser.add_argument('--eval',
181 | action='store_true',
182 | help='Perform evaluation only')
183 | parser.add_argument('--dist_eval',
184 | action='store_true',
185 | default=False,
186 | help='Enabling distributed evaluation '
187 |                         '(recommended during training for faster monitoring)')
188 | parser.add_argument('--num_workers', default=10, type=int)
189 | # distributed training parameters
190 | parser.add_argument('--world_size',
191 | default=1,
192 | type=int,
193 | help='number of distributed processes')
194 | parser.add_argument('--local_rank', default=-1, type=int)
195 | parser.add_argument('--dist_on_itp', action='store_true')
196 | parser.add_argument('--dist_url',
197 | default='env://',
198 | help='url used to set up distributed training')
199 |
200 | return parser
201 |
202 |
203 | def main(args):
204 | misc.init_distributed_mode(args)
205 |
206 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
207 | print('{}'.format(args).replace(', ', ',\n'))
208 |
209 | device = torch.device(args.device)
210 |
211 | # fix the seed for reproducibility
212 | seed = args.seed + misc.get_rank()
213 | torch.manual_seed(seed)
214 | np.random.seed(seed)
215 |
216 | cudnn.benchmark = True
217 |
218 | dataset_train = build_dataset(is_train=True, args=args)
219 | dataset_val = build_dataset(is_train=False, args=args)
220 |
221 | if True: # args.distributed:
222 | num_tasks = misc.get_world_size()
223 | global_rank = misc.get_rank()
224 | sampler_train = torch.utils.data.DistributedSampler(
225 | dataset_train,
226 | num_replicas=num_tasks,
227 | rank=global_rank,
228 | shuffle=True)
229 | print('Sampler_train = %s' % str(sampler_train))
230 | if args.dist_eval:
231 | if len(dataset_val) % num_tasks != 0:
232 | print(
233 | 'Warning: Enabling distributed evaluation '
234 | 'with an eval dataset not divisible by process number. '
235 | 'This will slightly alter validation '
236 | 'results as extra duplicate entries are added to achieve '
237 | 'equal num of samples per-process.')
238 | sampler_val = torch.utils.data.DistributedSampler(
239 | dataset_val,
240 | num_replicas=num_tasks,
241 | rank=global_rank,
242 | shuffle=True) # shuffle=True to reduce monitor bias
243 | else:
244 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
245 | else:
246 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
247 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
248 |
249 | data_loader_train = torch.utils.data.DataLoader(
250 | dataset_train,
251 | sampler=sampler_train,
252 | batch_size=args.batch_size,
253 | num_workers=args.num_workers,
254 | pin_memory=True,
255 | drop_last=False,
256 | )
257 |
258 | data_loader_val = torch.utils.data.DataLoader(dataset_val,
259 | sampler=sampler_val,
260 | batch_size=1,
261 | num_workers=args.num_workers,
262 | pin_memory=True,
263 | drop_last=False)
264 | args.iteration_one_epoch = len(data_loader_train)
265 | model = models.__dict__[args.model](args)
266 |
267 | if args.finetune and not args.eval:
268 | checkpoint = torch.load(args.finetune, map_location='cpu')
269 | print('Load pre-trained checkpoint from: %s' % args.finetune)
270 | if 'model' in checkpoint:
271 | checkpoint = checkpoint['model']
272 | elif 'state_dict' in checkpoint:
273 | checkpoint = checkpoint['state_dict']
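        # Strip common key prefixes so the checkpoint matches this model:
        # 'module.' (e.g. from DistributedDataParallel) and 'backbone.'
        # (e.g. from pretraining wrappers).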
274 | checkpoint = {
275 | k.replace('module.', ''): v
276 | for k, v in checkpoint.items()
277 | }
278 | checkpoint = {
279 | k.replace('backbone.', ''): v
280 | for k, v in checkpoint.items()
281 | }
282 |
283 | for k in ['head.weight', 'head.bias']:
284 | if k in checkpoint.keys():
285 | print(f'Removing key {k} from pretrained checkpoint')
286 | del checkpoint[k]
287 |
288 | if 'vit' in args.model:
289 | # interpolate position embedding
290 | interpolate_pos_embed(model, checkpoint)
291 | elif 'convnext' in args.model:
292 | checkpoint = checkpoint_filter_fn(checkpoint, model)
293 |
294 | # load pre-trained model
295 | msg = model.load_state_dict(checkpoint, strict=False)
296 | print('Missing: {}'.format(msg.missing_keys))
297 |
298 | model.to(device)
299 |
300 | model_without_ddp = model
301 | n_parameters = sum(p.numel() for p in model.parameters()
302 | if p.requires_grad)
303 |
304 | print('Model = %s' % str(model_without_ddp))
305 | print('number of params (M): %.2f' % (n_parameters / 1.e6))
306 |
307 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
308 |
309 | if args.lr is None: # only base_lr is specified
310 | args.lr = args.blr * eff_batch_size / 256
311 |
312 | print('base lr: %.2e' % (args.lr * 256 / eff_batch_size))
313 | print('actual lr: %.2e' % args.lr)
314 |
315 | print('accumulate grad iterations: %d' % args.accum_iter)
316 | print('effective batch size: %d' % eff_batch_size)
317 |
318 | if args.distributed:
319 | model = torch.nn.parallel.DistributedDataParallel(
320 | model, device_ids=[args.gpu], find_unused_parameters=True)
321 | model_without_ddp = model.module
322 |
323 | # build optimizer with layer-wise lr decay (lrd)
324 | param_groups = lrd.param_groups_lrd(
325 | model_without_ddp,
326 | args.weight_decay,
327 | no_weight_decay_list=model_without_ddp.no_weight_decay(),
328 | layer_decay=args.layer_decay,
329 | layer_multiplier=args.layer_multiplier)
330 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr)
331 | loss_scaler = NativeScaler()
332 | criterion = torch.nn.CrossEntropyLoss(
333 | ignore_index=1000) # 1000 denotes the ignored region in ImageNet-S.
334 | print('criterion = %s' % str(criterion))
335 |
336 | misc.load_model(args=args,
337 | model_without_ddp=model_without_ddp,
338 | optimizer=optimizer,
339 | loss_scaler=loss_scaler)
340 |
341 | if args.eval:
342 | test_stats = evaluate(data_loader_val, model, device, args.nb_classes)
343 | print(f'mIoU of the network on the {len(dataset_val)} '
344 | f"test images: {test_stats['mIoU']:.1f}%")
345 | if len(dataset_val) % num_tasks != 0:
346 | print('Warning: Enabling distributed evaluation '
347 | 'with an eval dataset not divisible by process number. '
348 | 'This will slightly alter validation '
349 | 'results as extra duplicate entries are added to achieve '
350 | 'equal num of samples per-process.')
351 | exit(0)
352 |
353 | print(f'Start training for {args.epochs} epochs')
354 | start_time = time.time()
355 | max_accuracy = 0.0
356 | for epoch in range(args.start_epoch, args.epochs):
357 | if args.distributed:
358 | data_loader_train.sampler.set_epoch(epoch)
359 | train_stats = train_one_epoch(model,
360 | criterion,
361 | data_loader_train,
362 | optimizer,
363 | device,
364 | epoch,
365 | loss_scaler,
366 | args.clip_grad,
367 | args=args)
368 | if args.output_dir and (epoch + 1) % args.saveckp_freq == 0:
369 | misc.save_model(args=args,
370 | model=model,
371 | model_without_ddp=model_without_ddp,
372 | optimizer=optimizer,
373 | loss_scaler=loss_scaler,
374 | epoch=epoch)
375 |
376 | if (epoch + 1) % args.eval_freq == 0 or epoch == 0:
377 | test_stats = evaluate(data_loader_val,
378 | model,
379 | device,
380 | args.nb_classes,
381 | max_res=args.max_res)
382 | print(f'mIoU of the network on the {len(dataset_val)} '
383 | f"test images: {test_stats['mIoU']:.3f}%")
384 | if len(dataset_val) % num_tasks != 0:
385 | print('Warning: Enabling distributed evaluation '
386 | 'with an eval dataset not divisible by process number. '
387 | 'This will slightly alter validation '
388 | 'results as extra duplicate entries are added to achieve '
389 | 'equal num of samples per-process.')
390 | max_accuracy = max(max_accuracy, test_stats['mIoU'])
391 | print(f'Max mIoU: {max_accuracy:.2f}%')
392 |
393 | log_stats = {
394 | **{f'train_{k}': v
395 | for k, v in train_stats.items()},
396 | **{f'test_{k}': v
397 | for k, v in test_stats.items()}, 'epoch': epoch,
398 | 'n_parameters': n_parameters
399 | }
400 |
401 | if args.output_dir and misc.is_main_process():
402 | with open(os.path.join(args.output_dir, 'log.txt'),
403 | mode='a',
404 | encoding='utf-8') as f:
405 | f.write(json.dumps(log_stats) + '\n')
406 |
407 | total_time = time.time() - start_time
408 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
409 | print('Training time {}'.format(total_time_str))
410 |
411 |
412 | if __name__ == '__main__':
413 | args = get_args_parser()
414 | args = args.parse_args()
415 | if args.output_dir:
416 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
417 | main(args)
418 |
--------------------------------------------------------------------------------
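Note on the learning-rate handling in main() above: when only --blr is given, the actual lr follows the linear scaling rule lr = blr * eff_batch_size / 256, with eff_batch_size = batch_size * accum_iter * world_size. A minimal sketch with hypothetical launch settings (illustrative values, not repo defaults):

batch_size, accum_iter, world_size = 32, 1, 8            # hypothetical launch settings
blr = 1e-3                                               # hypothetical base learning rate
eff_batch_size = batch_size * accum_iter * world_size    # 32 * 1 * 8 = 256
lr = blr * eff_batch_size / 256                          # 1e-3, i.e. blr scaled by eff_batch_size / 256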
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models_resnet import resnet18, resnet50, resnet50_d16
2 | from .models_vit import vit_base_patch16, vit_small_patch16
3 | from .models_convnext import convnext_tiny
4 | from .models_rfconvnext import rfconvnext_tiny_rfmerge, rfconvnext_tiny_rfmultiple, rfconvnext_tiny_rfsearch, rfconvnext_tiny_rfsingle
5 |
--------------------------------------------------------------------------------
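These exports are what main_segfinetune.py resolves by name via models.__dict__[args.model](args). A minimal, hypothetical lookup (only nb_classes is required here; the other fields fall back to getattr defaults inside the factories):

import argparse
import models

args = argparse.Namespace(nb_classes=920)             # hypothetical minimal args
model = models.__dict__['vit_base_patch16'](args)     # same lookup as in main_segfinetune.py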
/models/models_convnext.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import timm.models.convnext
4 | from collections import OrderedDict
5 | import torch
6 | import torch.nn as nn
7 | from timm.models.layers import trunc_normal_
8 |
9 |
10 | class ConvNeXt(timm.models.convnext.ConvNeXt):
11 |     """ConvNeXt with support for semantic segmentation."""
12 | def __init__(self, patch_size=4, **kwargs):
13 | norm_layer = kwargs.pop('norm_layer')
14 | super(ConvNeXt, self).__init__(**kwargs)
15 | assert self.num_classes > 0
16 |
17 | del self.head
18 | del self.norm_pre
19 |
20 | self.patch_size = patch_size
21 | self.depths = kwargs['depths']
22 | self.num_layers = sum(self.depths) + len(self.depths)
23 | self.rf_change = []
24 |
25 | self.seg_norm = norm_layer(self.num_features)
26 | self.seg_head = nn.Sequential(OrderedDict([
27 | ('drop', nn.Dropout(self.drop_rate)),
28 | ('fc', nn.Conv2d(self.num_features, self.num_classes * (self.patch_size**2), 1))
29 | ]))
30 |
31 | trunc_normal_(self.seg_head[1].weight, std=.02)
32 | torch.nn.init.zeros_(self.seg_head[1].bias)
33 |
34 | @torch.jit.ignore
35 | def no_weight_decay(self):
36 | return dict()
37 |
38 | def forward_features(self, x):
39 | x = self.stem(x)
40 | x = self.stages(x)
41 | b, c, h, w = x.shape
42 | x = x.view(b, c, -1).permute(0, 2, 1)
43 | x = self.seg_norm(x)
44 | x = x.permute(0, 2, 1).view(b, c, h, w)
45 | return x
46 |
47 | def forward_head(self, x):
48 | x = self.seg_head.drop(x)
49 | x = self.seg_head.fc(x)
50 | b, _, h, w = x.shape
51 | x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, self.patch_size, self.patch_size, self.num_classes)
52 | x = torch.einsum('nhwpqc->nchpwq', x)
53 | x = x.contiguous().view(b, self.num_classes, h * self.patch_size, w * self.patch_size)
54 | return x
55 |
56 | def get_layer_id(self, name):
57 | """
58 | Assign a parameter with its layer id for layer-wise decay.
59 |
60 | For each layer, the get_layer_id returns (layer_group, layer_id).
61 | According to the layer_group, different parameters are grouped,
62 | and layers in different groups use different decay rates.
63 |
64 |         If only the layer_id is returned, the layer_group is set to 0 by default.
65 | """
66 | if name in ("cls_token", "mask_token", "pos_embed"):
67 | return (0, 0)
68 | elif name.startswith("stem"):
69 | return (0, 0)
70 | elif name.startswith("stages") and 'downsample' in name:
71 | stage_id = int(name.split('.')[1])
72 | if stage_id == 0:
73 | layer_id = 0
74 | else:
75 | layer_id = sum(self.depths[:stage_id]) + stage_id
76 | return (0, layer_id)
77 | elif name.startswith("stages") and 'downsample' not in name:
78 | stage_id = int(name.split('.')[1])
79 | block_id = int(name.split('.')[3])
80 | if stage_id == 0:
81 | layer_id = block_id + 1
82 | else:
83 | layer_id = sum(self.depths[:stage_id]) + stage_id + block_id + 1
84 | return (0, layer_id)
85 | else:
86 | return (0, self.num_layers)
87 |
88 |
89 | def convnext_tiny(args):
90 | model = ConvNeXt(
91 | depths=(3, 3, 9, 3),
92 | dims=(96, 192, 384, 768),
93 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
94 | num_classes=getattr(args, 'nb_classes', 920),
95 | drop_path_rate=getattr(args, 'drop_path', 0),
96 | patch_size=getattr(args, 'patch_size', 4)
97 | )
98 | return model
99 |
--------------------------------------------------------------------------------
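The forward_head above predicts num_classes * patch_size**2 logits per feature-map location and unfolds them into a dense map upsampled by patch_size. A minimal shape sketch with illustrative tensors (not repo code):

import torch

b, h, w = 2, 7, 7                     # hypothetical feature-map size
num_classes, patch_size = 920, 4
x = torch.randn(b, num_classes * patch_size**2, h, w)   # output of seg_head.fc

x = x.permute(0, 2, 3, 1).reshape(b, h, w, patch_size, patch_size, num_classes)
x = torch.einsum('nhwpqc->nchpwq', x)                    # same rearrangement as forward_head
x = x.reshape(b, num_classes, h * patch_size, w * patch_size)
print(x.shape)                        # torch.Size([2, 920, 28, 28])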
/models/models_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
7 | """3x3 convolution with padding"""
8 | return nn.Conv2d(
9 | in_planes,
10 | out_planes,
11 | kernel_size=3,
12 | stride=stride,
13 | padding=dilation,
14 | groups=groups,
15 | bias=False,
16 | dilation=dilation,
17 | )
18 |
19 |
20 | def conv1x1(in_planes, out_planes, stride=1):
21 | """1x1 convolution"""
22 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 | __constants__ = ["downsample"]
28 |
29 | def __init__(
30 | self,
31 | inplanes,
32 | planes,
33 | stride=1,
34 | downsample=None,
35 | groups=1,
36 | base_width=64,
37 | dilation=1,
38 | norm_layer=None,
39 | ):
40 | super(BasicBlock, self).__init__()
41 | if norm_layer is None:
42 | norm_layer = nn.BatchNorm2d
43 | if groups != 1 or base_width != 64:
44 | raise ValueError("BasicBlock only supports groups=1 and base_width=64")
45 | if dilation > 1:
46 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
48 | self.conv1 = conv3x3(inplanes, planes, stride)
49 | self.bn1 = norm_layer(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3(planes, planes)
52 | self.bn2 = norm_layer(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | identity = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | out = self.bn2(out)
65 |
66 | if self.downsample is not None:
67 | identity = self.downsample(x)
68 |
69 | out += identity
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | expansion = 4
77 | __constants__ = ["downsample"]
78 |
79 | def __init__(
80 | self,
81 | inplanes,
82 | planes,
83 | stride=1,
84 | downsample=None,
85 | groups=1,
86 | base_width=64,
87 | dilation=1,
88 | norm_layer=None,
89 | ):
90 | super(Bottleneck, self).__init__()
91 | if norm_layer is None:
92 | norm_layer = nn.BatchNorm2d
93 | width = int(planes * (base_width / 64.0)) * groups
94 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
95 | self.conv1 = conv1x1(inplanes, width)
96 | self.bn1 = norm_layer(width)
97 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
98 | self.bn2 = norm_layer(width)
99 | self.conv3 = conv1x1(width, planes * self.expansion)
100 | self.bn3 = norm_layer(planes * self.expansion)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.downsample = downsample
103 | self.stride = stride
104 |
105 | def forward(self, x):
106 | identity = x
107 |
108 | out = self.conv1(x)
109 | out = self.bn1(out)
110 | out = self.relu(out)
111 |
112 | out = self.conv2(out)
113 | out = self.bn2(out)
114 | out = self.relu(out)
115 |
116 | out = self.conv3(out)
117 | out = self.bn3(out)
118 |
119 | if self.downsample is not None:
120 | identity = self.downsample(x)
121 |
122 | out += identity
123 | out = self.relu(out)
124 |
125 | return out
126 |
127 |
128 | class ResNet(nn.Module):
129 | def __init__(
130 | self,
131 | block,
132 | layers,
133 | zero_init_residual=False,
134 | groups=1,
135 | widen=1,
136 | width_per_group=64,
137 | replace_stride_with_dilation=None,
138 | norm_layer=None,
139 | eval_mode=False,
140 | num_classes=0,
141 | ):
142 | super(ResNet, self).__init__()
143 | if norm_layer is None:
144 | norm_layer = nn.BatchNorm2d
145 | self._norm_layer = norm_layer
146 |
147 | self.eval_mode = eval_mode
148 | self.padding = nn.ConstantPad2d(1, 0.0)
149 |
150 | self.inplanes = width_per_group * widen
151 | self.dilation = 1
152 | if replace_stride_with_dilation is None:
153 | # each element in the tuple indicates if we should replace
154 | # the 2x2 stride with a dilated convolution instead
155 | replace_stride_with_dilation = [False, False, False]
156 | if len(replace_stride_with_dilation) != 3:
157 | raise ValueError(
158 | "replace_stride_with_dilation should be None "
159 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
160 | )
161 | self.groups = groups
162 | self.base_width = width_per_group
163 | self.layers = layers
164 | self.num_layers = sum(self.layers) + 1
165 |
166 |         # change padding 3 -> 2 compared to the original torchvision code because a padding layer was added
167 | num_out_filters = width_per_group * widen
168 | self.conv1 = nn.Conv2d(
169 | 3, num_out_filters, kernel_size=7, stride=2, padding=2, bias=False
170 | )
171 | self.bn1 = norm_layer(num_out_filters)
172 | self.relu = nn.ReLU(inplace=True)
173 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
174 | self.layer1 = self._make_layer(block, num_out_filters, layers[0])
175 | num_out_filters *= 2
176 | self.layer2 = self._make_layer(
177 | block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
178 | )
179 | num_out_filters *= 2
180 | self.layer3 = self._make_layer(
181 | block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
182 | )
183 | num_out_filters *= 2
184 | self.layer4 = self._make_layer(
185 | block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
186 | )
187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
188 |
189 | mid_channels = 512 * block.expansion
190 | # segmentation head and loss function
191 | self.head = nn.Conv2d(mid_channels, num_classes, 1, 1)
192 |
193 | for m in self.modules():
194 | if isinstance(m, nn.Conv2d):
195 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
196 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
197 | nn.init.constant_(m.weight, 1)
198 | nn.init.constant_(m.bias, 0)
199 |
200 | # Zero-initialize the last BN in each residual branch,
201 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
202 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
203 | if zero_init_residual:
204 | for m in self.modules():
205 | if isinstance(m, Bottleneck):
206 | nn.init.constant_(m.bn3.weight, 0)
207 | elif isinstance(m, BasicBlock):
208 | nn.init.constant_(m.bn2.weight, 0)
209 |
210 | @torch.jit.ignore
211 | def no_weight_decay(self):
212 | return dict()
213 |
214 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
215 | norm_layer = self._norm_layer
216 | downsample = None
217 | previous_dilation = self.dilation
218 | if dilate:
219 | self.dilation *= stride
220 | stride = 1
221 | if stride != 1 or self.inplanes != planes * block.expansion:
222 | downsample = nn.Sequential(
223 | conv1x1(self.inplanes, planes * block.expansion, stride),
224 | norm_layer(planes * block.expansion),
225 | )
226 |
227 | layers = []
228 | layers.append(
229 | block(
230 | self.inplanes,
231 | planes,
232 | stride,
233 | downsample,
234 | self.groups,
235 | self.base_width,
236 | previous_dilation,
237 | norm_layer,
238 | )
239 | )
240 | self.inplanes = planes * block.expansion
241 | for _ in range(1, blocks):
242 | layers.append(
243 | block(
244 | self.inplanes,
245 | planes,
246 | groups=self.groups,
247 | base_width=self.base_width,
248 | dilation=self.dilation,
249 | norm_layer=norm_layer,
250 | )
251 | )
252 |
253 | return nn.Sequential(*layers)
254 |
255 | def forward_backbone(self, x, pool=True):
256 | x = self.padding(x)
257 | x = self.conv1(x)
258 | x = self.bn1(x)
259 | x = self.relu(x)
260 | x = self.maxpool(x)
261 | x = self.layer1(x)
262 | x = self.layer2(x)
263 | x = self.layer3(x)
264 | x = self.layer4(x)
265 |
266 | return x
267 |
268 | def forward(self, inputs):
269 |
270 | out = self.forward_backbone(inputs, pool=False)
271 | out = self.head(out)
272 | out = F.interpolate(out, scale_factor=2, align_corners=False, mode='bilinear')
273 |
274 | return out
275 |
276 | def get_layer_id(self, name):
277 | """
278 | Assign a parameter with its layer id for layer-wise decay.
279 |
280 | For each layer, the get_layer_id returns (layer_group, layer_id).
281 | According to the layer_group, different parameters are grouped,
282 | and layers in different groups use different decay rates.
283 |
284 |         If only the layer_id is returned, the layer_group is set to 0 by default.
285 | """
286 |
287 | if name.startswith('conv1'):
288 | return (0, 0)
289 | elif name.startswith('bn1'):
290 | return (0, 0)
291 | elif name.startswith('layer'):
292 | return (0, sum(self.layers[:int(name[5]) - 1]) + int(name[7]) + 1)
293 | else:
294 | return (0, self.num_layers)
295 |
296 |
297 | def resnet18(args):
298 | kwargs=dict(num_classes=args.nb_classes)
299 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
300 |
301 |
302 | def resnet50(args):
303 | kwargs=dict(num_classes=args.nb_classes)
304 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
305 |
306 | def resnet50_d16(args):
307 | kwargs=dict(num_classes=args.nb_classes)
308 | return ResNet(Bottleneck, [3, 4, 6, 3], replace_stride_with_dilation=[False, False, True], **kwargs)
309 |
--------------------------------------------------------------------------------
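For reference, a small worked example of ResNet.get_layer_id, which drives the layer-wise lr decay (the parameter names and nb_classes are illustrative; the mapping assumes the single-digit stage/block indices produced by this implementation):

import argparse
from models.models_resnet import resnet50

model = resnet50(argparse.Namespace(nb_classes=920))
print(model.get_layer_id('conv1.weight'))            # (0, 0): stem
print(model.get_layer_id('layer3.2.conv1.weight'))   # (0, 10): sum([3, 4]) + 2 + 1
print(model.get_layer_id('head.weight'))             # (0, 17): num_layers = 3 + 4 + 6 + 3 + 1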
/models/models_rfconvnext.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import models.rfconvnext as rfconvnext
4 | from collections import OrderedDict
5 | import torch
6 | import torch.nn as nn
7 | from timm.models.layers import trunc_normal_
8 |
9 |
10 | class RFConvNeXt(rfconvnext.RFConvNeXt):
11 |     """RFConvNeXt with support for semantic segmentation."""
12 | def __init__(self, patch_size=4, **kwargs):
13 | norm_layer = kwargs.pop('norm_layer')
14 | super(RFConvNeXt, self).__init__(**kwargs)
15 | assert self.num_classes > 0
16 |
17 | del self.head
18 | del self.norm_pre
19 |
20 | self.patch_size = patch_size
21 | self.depths = kwargs['depths']
22 | self.num_layers = sum(self.depths) + len(self.depths)
23 | # The layers whose dilation rates are changed in RF-Next.
24 | # These layers use different hyper-parameters in training.
25 | self.rf_change = []
26 | self.rf_change_name = []
27 |
28 | self.seg_norm = norm_layer(self.num_features)
29 | self.seg_head = nn.Sequential(OrderedDict([
30 | ('drop', nn.Dropout(self.drop_rate)),
31 | ('fc', nn.Conv2d(self.num_features, self.num_classes * (self.patch_size**2), 1))
32 | ]))
33 |
34 | trunc_normal_(self.seg_head[1].weight, std=.02)
35 | torch.nn.init.zeros_(self.seg_head[1].bias)
36 |
37 | self.get_kernel_size_changed()
38 |
39 |
40 | def get_kernel_size_changed(self):
41 | """
42 |         Collect the rfconvs whose dilation rates are changed.
43 | """
44 | for i, stage in enumerate(self.stages):
45 | for j, block in enumerate(stage.blocks):
46 | if block.conv_dw.dilation[0] > 1 or block.conv_dw.kernel_size[0] > 13:
47 | self.rf_change_name.extend(
48 | [
49 | 'stages.{}.blocks.{}.conv_dw.weight'.format(i, j),
50 | 'stages.{}.blocks.{}.conv_dw.bias'.format(i, j),
51 | 'stages.{}.blocks.{}.conv_dw.sample_weights'.format(i, j)
52 | ]
53 | )
54 | self.rf_change.append(self.stages[i].blocks[j].conv_dw)
55 |
56 | def freeze(self):
57 | """
58 |         In rfmerge mode,
59 |         we initialize the model with the weights from rfmultiple and
60 |         only finetune seg_norm, seg_head and the rfconvs whose dilation rates are changed.
61 |         The other parts of the network are frozen during finetuning.
62 | 
63 |         Note that this freezing operation may not be required for other tasks.
64 | """
65 | if len(self.rf_change_name) == 0:
66 | self.get_kernel_size_changed()
67 |         # finetune the rfconvs whose dilation rates are changed
68 | for n, p in self.named_parameters():
69 | p.requires_grad = True if n in self.rf_change_name else False
70 | # finetune the seg_norm, seg_head
71 | for n, p in self.seg_head.named_parameters():
72 | p.requires_grad = True
73 | for n, p in self.seg_norm.named_parameters():
74 | p.requires_grad = True
75 |
76 | @torch.jit.ignore
77 | def no_weight_decay(self):
78 | return dict()
79 |
80 | def forward_features(self, x):
81 | x = self.stem(x)
82 | x = self.stages(x)
83 | b, c, h, w = x.shape
84 | x = x.view(b, c, -1).permute(0, 2, 1)
85 | x = self.seg_norm(x)
86 | x = x.permute(0, 2, 1).view(b, c, h, w)
87 | return x
88 |
89 | def forward_head(self, x):
90 | x = self.seg_head.drop(x)
91 | x = self.seg_head.fc(x)
92 | b, _, h, w = x.shape
93 | x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, self.patch_size, self.patch_size, self.num_classes)
94 | x = torch.einsum('nhwpqc->nchpwq', x)
95 | x = x.contiguous().view(b, self.num_classes, h * self.patch_size, w * self.patch_size)
96 | return x
97 |
98 | def get_layer_id(self, name):
99 | """
100 | Assign a parameter with its layer id for layer-wise decay.
101 |
102 | For each layer, the get_layer_id returns (layer_group, layer_id).
103 | According to the layer_group, different parameters are grouped,
104 | and layers in different groups use different decay rates.
105 |
106 |         If only the layer_id is returned, the layer_group is set to 0 by default.
107 | """
108 | if name in ("cls_token", "mask_token", "pos_embed"):
109 | return (0, 0)
110 | elif name.startswith("stem"):
111 | return (0, 0)
112 | elif name.startswith("stages") and 'downsample' in name:
113 | stage_id = int(name.split('.')[1])
114 | if stage_id == 0:
115 | layer_id = 0
116 | else:
117 | layer_id = sum(self.depths[:stage_id]) + stage_id
118 |
119 | if name.endswith('sample_weights') or name in self.rf_change_name:
120 | return (1, layer_id)
121 | return (0, layer_id)
122 | elif name.startswith("stages") and 'downsample' not in name:
123 | stage_id = int(name.split('.')[1])
124 | block_id = int(name.split('.')[3])
125 | if stage_id == 0:
126 | layer_id = block_id + 1
127 | else:
128 | layer_id = sum(self.depths[:stage_id]) + stage_id + block_id + 1
129 |
130 | if name.endswith('sample_weights') or name in self.rf_change_name:
131 | return (1, layer_id)
132 | return (0, layer_id)
133 | else:
134 | return (0, self.num_layers)
135 |
136 |
137 | def rfconvnext_tiny_rfsearch(args):
138 | search_cfgs = dict(
139 | num_branches=3,
140 | expand_rate=0.5,
141 | max_dilation=None,
142 | min_dilation=1,
143 | init_weight=0.01,
144 | search_interval=getattr(args, 'iteration_one_epoch', 1250) * 10, # step every 10 epochs
145 | max_search_step=3, # search for 3 steps
146 | )
147 | model = RFConvNeXt(
148 | depths=(3, 3, 9, 3),
149 | dims=(96, 192, 384, 768),
150 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
151 | rf_mode='rfsearch',
152 | search_cfgs=search_cfgs,
153 | num_classes=getattr(args, 'nb_classes', 920),
154 | drop_path_rate=getattr(args, 'drop_path', 0),
155 | pretrained_weights=getattr(args, 'pretrained_rfnext', None),
156 | patch_size=getattr(args, 'patch_size', 4)
157 | )
158 | return model
159 |
160 |
161 | def rfconvnext_tiny_rfmultiple(args):
162 | search_cfgs = dict(
163 | num_branches=3,
164 | expand_rate=0.5,
165 | max_dilation=None,
166 | min_dilation=1,
167 | init_weight=0.01,
168 | )
169 | model = RFConvNeXt(
170 | depths=(3, 3, 9, 3),
171 | dims=(96, 192, 384, 768),
172 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
173 | rf_mode='rfmultiple',
174 | search_cfgs=search_cfgs,
175 | num_classes=getattr(args, 'nb_classes', 920),
176 | drop_path_rate=getattr(args, 'drop_path', 0),
177 | pretrained_weights=getattr(args, 'pretrained_rfnext', None),
178 | patch_size=getattr(args, 'patch_size', 4)
179 | )
180 | return model
181 |
182 | def rfconvnext_tiny_rfsingle(args):
183 | search_cfgs = dict(
184 | num_branches=3,
185 | expand_rate=0.5,
186 | max_dilation=None,
187 | min_dilation=1,
188 | init_weight=0.01,
189 | )
190 | model = RFConvNeXt(
191 | depths=(3, 3, 9, 3),
192 | dims=(96, 192, 384, 768),
193 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
194 | rf_mode='rfsingle',
195 | search_cfgs=search_cfgs,
196 | num_classes=getattr(args, 'nb_classes', 920),
197 | drop_path_rate=getattr(args, 'drop_path', 0),
198 | pretrained_weights=getattr(args, 'pretrained_rfnext', None),
199 | patch_size=getattr(args, 'patch_size', 4)
200 | )
201 | return model
202 |
203 | def rfconvnext_tiny_rfmerge(args):
204 | search_cfgs = dict(
205 | num_branches=3,
206 | expand_rate=0.5,
207 | max_dilation=None,
208 | min_dilation=1,
209 | init_weight=0.01
210 | )
211 | model = RFConvNeXt(
212 | depths=(3, 3, 9, 3),
213 | dims=(96, 192, 384, 768),
214 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
215 | rf_mode='rfmerge',
216 | search_cfgs=search_cfgs,
217 | num_classes=getattr(args, 'nb_classes', 920),
218 | drop_path_rate=getattr(args, 'drop_path', 0),
219 | pretrained_weights=getattr(args, 'pretrained_rfnext', None),
220 | patch_size=getattr(args, 'patch_size', 4)
221 | )
222 |     # freeze layers except for seg_norm, seg_head and the rfconvs whose dilation rates are changed.
223 | model.freeze()
224 | return model
225 |
--------------------------------------------------------------------------------
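A hedged usage sketch of these factories, mirroring how the training script builds them from args (the checkpoint path and argument values below are hypothetical; in practice rfmerge is initialized from an rfsearch/rfmultiple checkpoint via --pretrained_rfnext):

import argparse
from models.models_rfconvnext import rfconvnext_tiny_rfsearch, rfconvnext_tiny_rfmerge

args = argparse.Namespace(nb_classes=920, drop_path=0.1, patch_size=4,
                          pretrained_rfnext=None, iteration_one_epoch=1250)
search_model = rfconvnext_tiny_rfsearch(args)             # step 1: search dilation rates

args.pretrained_rfnext = 'ckpts/rfsearch_checkpoint.pth'  # hypothetical searched checkpoint
merge_model = rfconvnext_tiny_rfmerge(args)               # merge branches and freeze the rest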
/models/models_vit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9 | # DeiT: https://github.com/facebookresearch/deit
10 | # --------------------------------------------------------
11 |
12 | import math
13 | from functools import partial
14 |
15 | import timm.models.vision_transformer
16 | import torch
17 | import torch.nn as nn
18 | from timm.models.layers import trunc_normal_
19 |
20 |
21 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer):
22 |     """Vision Transformer with support for semantic segmentation."""
23 | def __init__(self, **kwargs):
24 | super(VisionTransformer, self).__init__(**kwargs)
25 |
26 | embed_dim = kwargs['embed_dim']
27 | norm_layer = kwargs['norm_layer']
28 | patch_size = kwargs['patch_size']
29 | self.num_layers = len(self.blocks) + 1
30 |
31 | self.fc_norm = norm_layer(embed_dim)
32 | del self.norm
33 |
34 | self.patch_embed = PatchEmbed(img_size=3,
35 | patch_size=patch_size,
36 | in_chans=3,
37 | embed_dim=embed_dim)
38 | assert self.num_classes > 0
39 | self.head = nn.Conv2d(self.embed_dim, self.num_classes, 1)
40 | # manually initialize fc layer
41 | trunc_normal_(self.head.weight, std=2e-5)
42 |
43 | def forward_head(self, x):
44 | return self.head(x)
45 |
46 | def forward(self, x):
47 | x = self.forward_features(x)
48 | x = self.forward_head(x)
49 | return x
50 |
51 | def forward_features(self, x):
52 | B, _, w, h = x.shape
53 | x = self.patch_embed(x)
54 |
55 | cls_tokens = self.cls_token.expand(B, -1, -1)
56 | x = torch.cat((cls_tokens, x), dim=1)
57 | x = x + self.interpolate_pos_encoding(x, w, h)
58 | x = self.pos_drop(x)
59 |
60 | for blk in self.blocks:
61 | x = blk(x)
62 |
63 | x = x[:, 1:, :]
64 | x = self.fc_norm(x)
65 | b, _, c = x.shape
66 | ih, iw = w // self.patch_embed.patch_size, \
67 | h // self.patch_embed.patch_size
68 | x = x.view(b, ih, iw, c).permute(0, 3, 1, 2).contiguous()
69 |
70 | return x
71 |
72 | def interpolate_pos_encoding(self, x, w, h):
73 | npatch = x.shape[1] - 1
74 | N = self.pos_embed.shape[1] - 1
75 | if npatch == N and w == h:
76 | return self.pos_embed
77 | class_pos_embed = self.pos_embed[:, 0]
78 | patch_pos_embed = self.pos_embed[:, 1:]
79 | dim = x.shape[-1]
80 | w0 = w // self.patch_embed.patch_size
81 | h0 = h // self.patch_embed.patch_size
82 | # we add a small number to avoid
83 | # floating point error in the interpolation
84 | # see discussion at https://github.com/facebookresearch/dino/issues/8
85 | w0, h0 = w0 + 0.1, h0 + 0.1
86 | patch_pos_embed = nn.functional.interpolate(
87 | patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
88 | dim).permute(0, 3, 1, 2).contiguous(),
89 | scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
90 | mode='bicubic',
91 | )
92 | assert int(w0) == patch_pos_embed.shape[-2] and int(
93 | h0) == patch_pos_embed.shape[-1]
94 | patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).contiguous().view(1, -1, dim)
95 | return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
96 | dim=1)
97 |
98 | def get_layer_id(self, name):
99 |         """Assign a parameter with its layer id for layer-wise decay, following BEiT:
100 |         https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33.
101 |
102 | For each layer, the get_layer_id returns (layer_group, layer_id).
103 | According to the layer_group, different parameters are grouped,
104 | and layers in different groups use different decay rates.
105 |
106 |         If only the layer_id is returned, the layer_group is set to 0 by default.
107 | """
108 | if name in ['cls_token', 'pos_embed']:
109 | return (0, 0)
110 | elif name.startswith('patch_embed'):
111 | return (0, 0)
112 | elif name.startswith('blocks'):
113 | return (0, int(name.split('.')[1]) + 1)
114 | else:
115 | return (0, self.num_layers)
116 |
117 |
118 | class PatchEmbed(nn.Module):
119 | """Image to Patch Embedding."""
120 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
121 | super().__init__()
122 | num_patches = (img_size // patch_size) * (img_size // patch_size)
123 | self.img_size = img_size
124 | self.patch_size = patch_size
125 | self.num_patches = num_patches
126 |
127 | self.proj = nn.Conv2d(in_chans,
128 | embed_dim,
129 | kernel_size=patch_size,
130 | stride=patch_size)
131 |
132 | def forward(self, x):
133 | B, C, H, W = x.shape
134 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous()
135 | return x
136 |
137 |
138 | def vit_small_patch16(args):
139 | kwargs = dict(
140 | num_classes=args.nb_classes,
141 | drop_path_rate=getattr(args, 'drop_path', 0)
142 | )
143 | model = VisionTransformer(patch_size=16,
144 | embed_dim=384,
145 | depth=12,
146 | num_heads=6,
147 | mlp_ratio=4,
148 | qkv_bias=True,
149 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
150 | **kwargs)
151 | return model
152 |
153 |
154 | def vit_base_patch16(args):
155 | kwargs = dict(
156 | num_classes=args.nb_classes,
157 | drop_path_rate=getattr(args, 'drop_path', 0)
158 | )
159 | model = VisionTransformer(patch_size=16,
160 | embed_dim=768,
161 | depth=12,
162 | num_heads=12,
163 | mlp_ratio=4,
164 | qkv_bias=True,
165 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
166 | **kwargs)
167 | return model
168 |
--------------------------------------------------------------------------------
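Since the head above is a 1x1 convolution on the reshaped patch-token features, the ViT variants output logits at stride 16 of the input. A minimal sketch (the input size is hypothetical and the printed shape is the expected result):

import argparse
import torch
from models.models_vit import vit_base_patch16

model = vit_base_patch16(argparse.Namespace(nb_classes=920)).eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)    # expected: torch.Size([1, 920, 14, 14]) for a 224x224 input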
/models/rfconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import collections.abc as container_abcs
5 | from itertools import repeat
6 | from timm.models.layers import get_padding
7 |
8 |
9 | def _ntuple(n):
10 | def parse(x):
11 | if isinstance(x, container_abcs.Iterable):
12 | return x
13 | return tuple(repeat(x, n))
14 | return parse
15 |
16 |
17 | _pair = _ntuple(2)
18 |
19 |
20 | def value_crop(dilation, min_dilation, max_dilation):
21 | if min_dilation is not None:
22 | if dilation < min_dilation:
23 | dilation = min_dilation
24 | if max_dilation is not None:
25 | if dilation > max_dilation:
26 | dilation = max_dilation
27 | return dilation
28 |
29 |
30 | def rf_expand(dilation, expand_rate, num_branches, min_dilation=1, max_dilation=None):
31 | rate_list = []
32 |         assert num_branches >= 2, "number of branches must be >= 2"
33 | delta_dilation0 = expand_rate * dilation[0]
34 | delta_dilation1 = expand_rate * dilation[1]
35 | for i in range(num_branches):
36 | rate_list.append(
37 | tuple([value_crop(
38 | int(round(dilation[0] - delta_dilation0 + (i) * 2 * delta_dilation0/(num_branches-1))), min_dilation, max_dilation),
39 | value_crop(
40 | int(round(dilation[1] - delta_dilation1 + (i) * 2 * delta_dilation1/(num_branches-1))), min_dilation, max_dilation)
41 | ])
42 | )
43 |
44 | unique_rate_list = list(set(rate_list))
45 | unique_rate_list.sort(key=rate_list.index)
46 | return unique_rate_list
47 |
48 |
49 | class RFConv2d(nn.Conv2d):
50 |
51 | def __init__(self,
52 | in_channels,
53 | out_channels,
54 | kernel_size=1,
55 | stride=1,
56 | padding=0,
57 | dilation=1,
58 | groups=1,
59 | bias=True,
60 | padding_mode='zeros',
61 | num_branches=3,
62 | expand_rate=0.5,
63 | min_dilation=1,
64 | max_dilation=None,
65 | init_weight=0.01,
66 | search_interval=1250,
67 | max_search_step=0,
68 | rf_mode='rfsearch',
69 | pretrained=None
70 | ):
71 | if pretrained is not None and rf_mode == 'rfmerge':
72 | rates = pretrained['rates']
73 | num_rates = pretrained['num_rates']
74 | sample_weights = pretrained['sample_weights']
75 | sample_weights = self.normlize(sample_weights[:num_rates.item()])
76 |             max_dilation_rate = rates[num_rates.item() - 1]
77 |             if isinstance(kernel_size, int):
78 |                 kernel_size = [kernel_size, kernel_size]
79 |             if isinstance(stride, int):
80 |                 stride = [stride, stride]
81 |             new_kernel_size = (
82 |                 kernel_size[0] + (max_dilation_rate[0].item() -
83 |                                   1) * (kernel_size[0] // 2) * 2,
84 |                 kernel_size[1] + (max_dilation_rate[1].item() - 1) * (kernel_size[1] // 2) * 2)
85 | # assign dilation to (1, 1) after merge
86 | new_dilation = (1, 1)
87 | new_padding = (
88 | get_padding(new_kernel_size[0], stride[0], new_dilation[0]),
89 | get_padding(new_kernel_size[1], stride[1], new_dilation[1]))
90 |
91 | # merge weight of each branch
92 | old_weight = pretrained['weight']
93 | new_weight = torch.zeros(
94 | size=(old_weight.shape[0], old_weight.shape[1],
95 | new_kernel_size[0], new_kernel_size[1]),
96 | dtype=old_weight.dtype)
97 | for r, rate in enumerate(rates[:num_rates.item()]):
98 | rate = (rate[0].item(), rate[1].item())
99 | for i in range(- (kernel_size[0] // 2), kernel_size[0] // 2 + 1):
100 | for j in range(- (kernel_size[1] // 2), kernel_size[1] // 2 + 1):
101 | new_weight[:, :,
102 | new_kernel_size[0] // 2 - i * rate[0],
103 | new_kernel_size[1] // 2 - j * rate[1]] += \
104 | old_weight[:, :, kernel_size[0] // 2 - i,
105 | kernel_size[1] // 2 - j] * sample_weights[r]
106 |
107 | kernel_size = new_kernel_size
108 | padding = new_padding
109 | dilation = new_dilation
110 | pretrained['rates'][0] = torch.FloatTensor([1, 1])
111 | pretrained['num_rates'] = torch.IntTensor([1])
112 | pretrained['weight'] = new_weight
113 |             # re-initialize the sample_weights
114 | pretrained['sample_weights'] = pretrained['sample_weights'] * \
115 | 0.0 + init_weight
116 |
117 | super(RFConv2d, self).__init__(
118 | in_channels,
119 | out_channels,
120 | kernel_size,
121 | stride,
122 | padding,
123 | dilation,
124 | groups,
125 | bias,
126 | padding_mode
127 | )
128 | self.rf_mode = rf_mode
129 | self.pretrained = pretrained
130 | self.num_branches = num_branches
131 | self.max_dilation = max_dilation
132 | self.min_dilation = min_dilation
133 | self.expand_rate = expand_rate
134 | self.init_weight = init_weight
135 | self.search_interval = search_interval
136 | self.max_search_step = max_search_step
137 | self.sample_weights = nn.Parameter(torch.Tensor(self.num_branches))
138 | self.register_buffer('counter', torch.zeros(1))
139 | self.register_buffer('current_search_step', torch.zeros(1))
140 | self.register_buffer('rates', torch.ones(
141 | size=(self.num_branches, 2), dtype=torch.int32))
142 | self.register_buffer('num_rates', torch.ones(1, dtype=torch.int32))
143 | self.rates[0] = torch.FloatTensor([self.dilation[0], self.dilation[1]])
144 | self.sample_weights.data.fill_(self.init_weight)
145 |
146 | # rf-next
147 | if pretrained is not None:
148 | # load pretrained weights
149 | msg = self.load_state_dict(pretrained, strict=False)
150 | assert all([key in ['sample_weights', 'counter', 'current_search_step', 'rates', 'num_rates'] for key in msg.missing_keys]), \
151 | 'Missing keys: {}'.format(msg.missing_keys)
152 | if self.rf_mode == 'rfsearch':
153 | self.estimate()
154 | self.expand()
155 | elif self.rf_mode == 'rfsingle':
156 | self.estimate()
157 | self.max_search_step = 0
158 | self.sample_weights.requires_grad = False
159 | elif self.rf_mode == 'rfmultiple':
160 | self.estimate()
161 | self.expand()
162 |                 # re-initialize the sample_weights
163 | self.sample_weights.data.fill_(self.init_weight)
164 | self.max_search_step = 0
165 | elif self.rf_mode == 'rfmerge':
166 | self.max_search_step = 0
167 | self.sample_weights.requires_grad = False
168 | else:
169 | raise NotImplementedError()
170 |
171 | if self.rf_mode in ['rfsingle', 'rfmerge']:
172 | assert self.num_rates.item() == 1
173 |
174 | def _conv_forward_dilation(self, input, dilation_rate):
175 | if self.padding_mode != 'zeros':
176 | return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
177 | self.weight, self.bias, self.stride,
178 | _pair(0), dilation_rate, self.groups)
179 | else:
180 | padding = (
181 | dilation_rate[0] * (self.kernel_size[0] - 1) // 2, dilation_rate[1] * (self.kernel_size[1] - 1) // 2)
182 | return F.conv2d(input, self.weight, self.bias, self.stride,
183 | padding, dilation_rate, self.groups)
184 |
185 | def normlize(self, w):
186 | abs_w = torch.abs(w)
187 | norm_w = abs_w / torch.sum(abs_w)
188 | return norm_w
189 |
190 | def forward(self, x):
191 | if self.num_rates.item() == 1:
192 | return super().forward(x)
193 | else:
194 | norm_w = self.normlize(self.sample_weights[:self.num_rates.item()])
195 | xx = [
196 | self._conv_forward_dilation(
197 | x, (self.rates[i][0].item(), self.rates[i][1].item()))
198 | * norm_w[i] for i in range(self.num_rates.item())
199 | ]
200 | x = xx[0]
201 | for i in range(1, self.num_rates.item()):
202 | x += xx[i]
203 | if self.training:
204 | self.searcher()
205 | return x
206 |
207 | def searcher(self):
208 | self.counter += 1
209 | if self.counter % self.search_interval == 0 and self.current_search_step < self.max_search_step and self.max_search_step != 0:
210 | self.counter[0] = 0
211 | self.current_search_step += 1
212 | self.estimate()
213 | self.expand()
214 |
215 | def tensor_to_tuple(self, tensor):
216 | return tuple([(x[0].item(), x[1].item()) for x in tensor])
217 |
218 | def estimate(self):
219 | norm_w = self.normlize(self.sample_weights[:self.num_rates.item()])
220 | print('Estimate dilation {} with weight {}.'.format(
221 | self.tensor_to_tuple(self.rates[:self.num_rates.item()]), norm_w.detach().cpu().numpy().tolist()))
222 |
223 | sum0, sum1, w_sum = 0, 0, 0
224 | for i in range(self.num_rates.item()):
225 | sum0 += norm_w[i].item() * self.rates[i][0].item()
226 | sum1 += norm_w[i].item() * self.rates[i][1].item()
227 | w_sum += norm_w[i].item()
228 | estimated = [value_crop(
229 | int(round(sum0 / w_sum)),
230 | self.min_dilation,
231 | self.max_dilation), value_crop(
232 | int(round(sum1 / w_sum)),
233 | self.min_dilation,
234 | self.max_dilation)]
235 | self.dilation = tuple(estimated)
236 | self.padding = (
237 | get_padding(self.kernel_size[0], self.stride[0], self.dilation[0]),
238 | get_padding(self.kernel_size[1], self.stride[1], self.dilation[1])
239 | )
240 | self.rates[0] = torch.FloatTensor([self.dilation[0], self.dilation[1]])
241 | self.num_rates[0] = 1
242 | print('Estimate as {}'.format(self.dilation))
243 |
244 | def expand(self):
245 | rates = rf_expand(self.dilation, self.expand_rate,
246 | self.num_branches,
247 | min_dilation=self.min_dilation,
248 | max_dilation=self.max_dilation)
249 | for i, rate in enumerate(rates):
250 | self.rates[i] = torch.FloatTensor([rate[0], rate[1]])
251 | self.num_rates[0] = len(rates)
252 | self.sample_weights.data.fill_(self.init_weight)
253 | print('Expand as {}'.format(self.rates[:len(rates)].cpu().tolist()))
254 |
--------------------------------------------------------------------------------
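A small worked example of rf_expand, which generates the candidate dilation rates around the current estimate (the inputs are chosen for illustration):

from models.rfconv import rf_expand

print(rf_expand((2, 2), expand_rate=0.5, num_branches=3))   # [(1, 1), (2, 2), (3, 3)]
print(rf_expand((1, 1), expand_rate=0.5, num_branches=3))   # [(1, 1), (2, 2)] after clipping and dedup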
/models/rfconvnext.py:
--------------------------------------------------------------------------------
1 | """ RFConvNeXt
2 | Paper: RF-Next: Efficient Receptive Field Search for Convolutional Neural Networks
3 | https://arxiv.org/abs/2206.06637
4 |
5 | Modified from https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/convnext.py
6 | """
7 | from collections import OrderedDict
8 | from functools import partial
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14 | from timm.models.helpers import named_apply, build_model_with_cfg, checkpoint_seq
15 | from timm.models.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d,\
16 | create_conv2d, make_divisible, get_padding
17 | from .rfconv import RFConv2d
18 | import os
19 |
20 | __all__ = ['RFConvNeXt']
21 |
22 |
23 | def _cfg(url='', **kwargs):
24 | return {
25 | 'url': url,
26 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
27 | 'crop_pct': 0.875, 'interpolation': 'bicubic',
28 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
29 | 'first_conv': 'stem.0', 'classifier': 'head.fc',
30 | **kwargs
31 | }
32 |
33 |
34 | default_cfgs = dict(
35 | convnext_tiny=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth"),
36 | convnext_small=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth"),
37 | convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
38 | convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
39 |
40 | # timm specific variants
41 | convnext_atto=_cfg(url=''),
42 | convnext_atto_ols=_cfg(url=''),
43 | convnext_femto=_cfg(
44 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
45 | test_input_size=(3, 288, 288), test_crop_pct=0.95),
46 | convnext_femto_ols=_cfg(
47 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
48 | test_input_size=(3, 288, 288), test_crop_pct=0.95),
49 | convnext_pico=_cfg(
50 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
51 | test_input_size=(3, 288, 288), test_crop_pct=0.95),
52 | convnext_pico_ols=_cfg(
53 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
54 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
55 | convnext_nano=_cfg(
56 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
57 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
58 | convnext_nano_ols=_cfg(
59 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
60 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
61 | convnext_tiny_hnf=_cfg(
62 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
63 | crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
64 |
65 | convnext_tiny_in22ft1k=_cfg(
66 | url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
67 | convnext_small_in22ft1k=_cfg(
68 | url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
69 | convnext_base_in22ft1k=_cfg(
70 | url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
71 | convnext_large_in22ft1k=_cfg(
72 | url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth'),
73 | convnext_xlarge_in22ft1k=_cfg(
74 | url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
75 |
76 | convnext_tiny_384_in22ft1k=_cfg(
77 | url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
78 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
79 | convnext_small_384_in22ft1k=_cfg(
80 | url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
81 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
82 | convnext_base_384_in22ft1k=_cfg(
83 | url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
84 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
85 | convnext_large_384_in22ft1k=_cfg(
86 | url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
87 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
88 | convnext_xlarge_384_in22ft1k=_cfg(
89 | url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
90 | input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
91 |
92 | convnext_tiny_in22k=_cfg(
93 | url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
94 | convnext_small_in22k=_cfg(
95 | url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
96 | convnext_base_in22k=_cfg(
97 | url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
98 | convnext_large_in22k=_cfg(
99 | url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
100 | convnext_xlarge_in22k=_cfg(
101 | url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
102 | )
103 |
104 |
105 | default_search_cfg = dict(
106 | num_branches=3,
107 | expand_rate=0.5,
108 | max_dilation=None,
109 | min_dilation=1,
110 | init_weight=0.01,
111 | search_interval=1250,
112 | max_search_step=0,
113 | )
114 |
115 |
116 | class RFConvNeXtBlock(nn.Module):
117 | """ ConvNeXt Block
118 | There are two equivalent implementations:
119 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
120 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
121 |
122 | Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
123 | choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
124 | is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
125 |
126 | Args:
127 | dim (int): Number of input channels.
128 | drop_path (float): Stochastic depth rate. Default: 0.0
129 | ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
130 | """
131 |
132 | def __init__(
133 | self,
134 | dim,
135 | dim_out=None,
136 | stride=1,
137 | dilation=1,
138 | mlp_ratio=4,
139 | conv_mlp=False,
140 | conv_bias=True,
141 | ls_init_value=1e-6,
142 | norm_layer=None,
143 | act_layer=nn.GELU,
144 | drop_path=0.,
145 | search_cfgs=default_search_cfg
146 | ):
147 | super().__init__()
148 | dim_out = dim_out or dim
149 | if not norm_layer:
150 | norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
151 | mlp_layer = ConvMlp if conv_mlp else Mlp
152 | self.use_conv_mlp = conv_mlp
153 |
154 | # replace dwconv with rfconv
155 | self.conv_dw = RFConv2d(
156 | in_channels=dim,
157 | out_channels=dim_out,
158 | kernel_size=7,
159 | stride=stride,
160 | padding=get_padding(kernel_size=7, stride=stride, dilation=dilation),
161 | dilation=dilation,
162 | groups=dim,
163 | bias=conv_bias,
164 | **search_cfgs)
165 | self.norm = norm_layer(dim_out)
166 | self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
167 | self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
168 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
169 |
170 | def forward(self, x):
171 | shortcut = x
172 | x = self.conv_dw(x)
173 | if self.use_conv_mlp:
174 | x = self.norm(x)
175 | x = self.mlp(x)
176 | else:
177 | x = x.permute(0, 2, 3, 1)
178 | x = self.norm(x)
179 | x = self.mlp(x)
180 | x = x.permute(0, 3, 1, 2)
181 | if self.gamma is not None:
182 | x = x.mul(self.gamma.reshape(1, -1, 1, 1))
183 |
184 | x = self.drop_path(x) + shortcut
185 | return x
186 |
187 |
188 | class RFConvNeXtStage(nn.Module):
189 |
190 | def __init__(
191 | self,
192 | in_chs,
193 | out_chs,
194 | stride=2,
195 | depth=2,
196 | dilation=(1, 1),
197 | drop_path_rates=None,
198 | ls_init_value=1.0,
199 | conv_mlp=False,
200 | conv_bias=True,
201 | norm_layer=None,
202 | norm_layer_cl=None,
203 | search_cfgs=default_search_cfg
204 | ):
205 | super().__init__()
206 | self.grad_checkpointing = False
207 |
208 | if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
209 | ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
210 | pad = 'same' if dilation[1] > 1 else 0 # same padding needed if dilation used
211 | self.downsample = nn.Sequential(
212 | norm_layer(in_chs),
213 | create_conv2d(
214 | in_chs, out_chs, kernel_size=ds_ks, stride=stride,
215 | dilation=dilation[0], padding=pad, bias=conv_bias),
216 | )
217 | in_chs = out_chs
218 | else:
219 | self.downsample = nn.Identity()
220 |
221 | drop_path_rates = drop_path_rates or [0.] * depth
222 | stage_blocks = []
223 | for i in range(depth):
224 | stage_blocks.append(RFConvNeXtBlock(
225 | dim=in_chs,
226 | dim_out=out_chs,
227 | dilation=dilation[1],
228 | drop_path=drop_path_rates[i],
229 | ls_init_value=ls_init_value,
230 | conv_mlp=conv_mlp,
231 | conv_bias=conv_bias,
232 | norm_layer=norm_layer if conv_mlp else norm_layer_cl,
233 | search_cfgs=search_cfgs
234 | ))
235 | in_chs = out_chs
236 | self.blocks = nn.Sequential(*stage_blocks)
237 |
238 | def forward(self, x):
239 | x = self.downsample(x)
240 | if self.grad_checkpointing and not torch.jit.is_scripting():
241 | x = checkpoint_seq(self.blocks, x)
242 | else:
243 | x = self.blocks(x)
244 | return x
245 |
246 |
247 | class RFConvNeXt(nn.Module):
248 | r""" ConvNeXt
249 | A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
250 |
251 | Args:
252 | in_chans (int): Number of input image channels. Default: 3
253 | num_classes (int): Number of classes for classification head. Default: 1000
254 | depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
255 | dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
256 | drop_rate (float): Head dropout rate
257 | drop_path_rate (float): Stochastic depth rate. Default: 0.
258 | ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
259 | head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
260 | rf_mode (str): Training mode for RF-Next. Choose from ['rfsearch', 'rfsingle', 'rfmultiple', 'rfmerge'].
261 | kernel_cfgs (Dict(str, int)): Kernel size for each RFConv. Example: {"stages.0.blocks.0.conv_dw": 7, "stages.0.blocks.1.conv_dw": 7, ...}.
262 | """
263 |
264 | def __init__(
265 | self,
266 | in_chans=3,
267 | num_classes=1000,
268 | global_pool='avg',
269 | output_stride=32,
270 | depths=(3, 3, 9, 3),
271 | dims=(96, 192, 384, 768),
272 | ls_init_value=1e-6,
273 | stem_type='patch',
274 | patch_size=4,
275 | head_init_scale=1.,
276 | head_norm_first=False,
277 | conv_mlp=False,
278 | conv_bias=True,
279 | norm_layer=None,
280 | drop_rate=0.,
281 | drop_path_rate=0.,
282 | pretrained_weights=None,
283 | rf_mode='rfsearch',
284 | kernel_cfgs=None,
285 | search_cfgs=default_search_cfg
286 | ):
287 | super().__init__()
288 | assert output_stride in (8, 16, 32)
289 | if norm_layer is None:
290 | norm_layer = partial(LayerNorm2d, eps=1e-6)
291 | norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
292 | else:
293 | assert conv_mlp,\
294 | 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
295 | norm_layer_cl = norm_layer
296 |
297 | self.num_classes = num_classes
298 | self.drop_rate = drop_rate
299 | self.feature_info = []
300 |
301 | assert stem_type in ('patch', 'overlap', 'overlap_tiered')
302 | if stem_type == 'patch':
303 | # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
304 | self.stem = nn.Sequential(
305 | nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
306 | norm_layer(dims[0])
307 | )
308 | stem_stride = patch_size
309 | else:
310 | mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
311 | self.stem = nn.Sequential(
312 | nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
313 | nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
314 | norm_layer(dims[0]),
315 | )
316 | stem_stride = 4
317 |
318 | self.stages = nn.Sequential()
319 | dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
320 | stages = []
321 | prev_chs = dims[0]
322 | curr_stride = stem_stride
323 | dilation = 1
324 | # 4 feature resolution stages, each consisting of multiple residual blocks
325 | for i in range(4):
326 | stride = 2 if curr_stride == 2 or i > 0 else 1
327 | if curr_stride >= output_stride and stride > 1:
328 | dilation *= stride
329 | stride = 1
330 | curr_stride *= stride
331 | first_dilation = 1 if dilation in (1, 2) else 2
332 | out_chs = dims[i]
333 | stages.append(RFConvNeXtStage(
334 | prev_chs,
335 | out_chs,
336 | stride=stride,
337 | dilation=(first_dilation, dilation),
338 | depth=depths[i],
339 | drop_path_rates=dp_rates[i],
340 | ls_init_value=ls_init_value,
341 | conv_mlp=conv_mlp,
342 | conv_bias=conv_bias,
343 | norm_layer=norm_layer,
344 | norm_layer_cl=norm_layer_cl,
345 | search_cfgs=search_cfgs
346 | ))
347 | prev_chs = out_chs
348 | # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
349 | self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
350 | self.stages = nn.Sequential(*stages)
351 | self.num_features = prev_chs
352 |
353 | # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
354 | # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
355 | self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
356 | self.head = nn.Sequential(OrderedDict([
357 | ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
358 | ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
359 | ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
360 | ('drop', nn.Dropout(self.drop_rate)),
361 | ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
362 |
363 | named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
364 |
365 | # RF-Next
366 | self.prepare_rfsearch(pretrained_weights, rf_mode, kernel_cfgs, search_cfgs)
367 | if self.rf_mode not in ['rfsearch', 'rfmultiple']:
368 | for n, p, in self.named_parameters():
369 | if 'sample_weights' in n:
370 | p.requires_grad = False
371 |
372 | def prepare_rfsearch(self, pretrained_weights, rf_mode, kernel_cfgs, search_cfgs):
373 | self.rf_mode = rf_mode
374 | self.pretrained_weights = pretrained_weights
375 | assert self.rf_mode in ['rfsearch', 'rfsingle', 'rfmultiple', 'rfmerge'], \
376 | "rf_mode should be in ['rfsearch', 'rfsingle', 'rfmultiple', 'rfmerge']."
377 | if pretrained_weights is None or not os.path.exists(pretrained_weights):
378 | checkpoint = None
379 | else:
380 | checkpoint = torch.load(pretrained_weights, map_location='cpu')
381 | checkpoint = checkpoint_filter_fn(checkpoint, self)
382 |             # Remove prefixes in the checkpoint, e.g., 'backbone' and 'module',
383 |             # so that the checkpoint keys match those of 'model.state_dict'.
384 | checkpoint = {k.replace('module.', ''): v for k, v in checkpoint.items()}
385 | checkpoint = {k.replace('backbone.', ''): v for k, v in checkpoint.items()}
386 | for name in list(checkpoint.keys()):
387 | if name.endswith('counter') or name.endswith('current_search_step'):
388 | # Do not load the pretrained 'counter' and 'current_search_step' buffers.
389 | print(f"RF-Next: Removing key {name} from pretrained checkpoint")
390 | del checkpoint[name]
391 |
392 | # Remove parameters whose shapes do not match the model from the checkpoint.
393 | for name, param in self.named_parameters():
394 | if name in checkpoint and param.shape != checkpoint[name].shape:
395 | print(f"RF-Next: Removing key {name} from pretrained checkpoint")
396 | del checkpoint[name]
397 | # Load the pretrained weights into the model.
398 | # The pretrained weights are obtained after rfsearch.
399 | msg = self.load_state_dict(checkpoint, strict=False)
400 | missing_keys = list(msg.missing_keys)
401 | missing_keys = list(filter(lambda x: not x.endswith('.counter') and not x.endswith('.current_search_step'), missing_keys))
402 | print('RF-Next: RF-Next init, missing keys: {}'.format(missing_keys))
403 |
404 | print('RF-Next: convert rfconv.')
405 | # Convert conv to rfconv
406 | def convert_rfconv(module, prefix):
407 | module_output = module
408 | if isinstance(module, RFConv2d):
409 | if kernel_cfgs is not None:
410 | kernel = kernel_cfgs[prefix]
411 | else:
412 | kernel = module.kernel_size
413 | if checkpoint is not None:
414 | module_pretrained = dict()
415 | # Collect the pretrained weights belonging to this rfconv.
416 | # The pretrained weights are obtained after rfsearch.
417 | for k in checkpoint.keys():
418 | if k.startswith(prefix):
419 | module_pretrained[k.replace('{}.'.format(prefix), '')] = checkpoint[k]
420 | else:
421 | module_pretrained = None
422 | if isinstance(kernel, int):
423 | kernel = (kernel, kernel)
424 | module_output = RFConv2d(
425 | in_channels=module.in_channels,
426 | out_channels=module.out_channels,
427 | kernel_size=kernel,
428 | stride=module.stride,
429 | padding=(
430 | get_padding(kernel[0], module.stride[0], module.dilation[0]),
431 | get_padding(kernel[1], module.stride[1], module.dilation[1])),
432 | dilation=module.dilation,
433 | groups=module.groups,
434 | bias=module.bias is not None,
435 | rf_mode=self.rf_mode,
436 | pretrained=module_pretrained,
437 | **search_cfgs
438 | )
439 |
440 | for name, child in module.named_children():
441 | fullname = name
442 | if prefix != '':
443 | fullname = prefix + '.' + name
444 | # Replace the conv with rfconv.
445 | module_output.add_module(name, convert_rfconv(child, fullname))
446 | del module
447 | return module_output
448 |
449 | convert_rfconv(self, '')
450 |
451 | if self.rf_mode == 'rfmerge':
452 | # Show the kernel sizes after rfmerge.
453 | rfmerge = dict()
454 | for name, module in self.named_modules():
455 | if isinstance(module, RFConv2d):
456 | rfmerge[name] = module.kernel_size
457 |
458 | print('Merged structure:')
459 | print(rfmerge)
460 | print('RF-Next: convert done.')
461 |
462 | @torch.jit.ignore
463 | def group_matcher(self, coarse=False):
464 | return dict(
465 | stem=r'^stem',
466 | blocks=r'^stages\.(\d+)' if coarse else [
467 | (r'^stages\.(\d+)\.downsample', (0,)), # blocks
468 | (r'^stages\.(\d+)\.blocks\.(\d+)', None),
469 | (r'^norm_pre', (99999,))
470 | ]
471 | )
472 |
473 | @torch.jit.ignore
474 | def set_grad_checkpointing(self, enable=True):
475 | for s in self.stages:
476 | s.grad_checkpointing = enable
477 |
478 | @torch.jit.ignore
479 | def get_classifier(self):
480 | return self.head.fc
481 |
482 | def reset_classifier(self, num_classes=0, global_pool=None):
483 | if global_pool is not None:
484 | self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
485 | self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
486 | self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
487 |
488 | def forward_features(self, x):
489 | x = self.stem(x)
490 | x = self.stages(x)
491 | x = self.norm_pre(x)
492 | return x
493 |
494 | def forward_head(self, x, pre_logits: bool = False):
495 | # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
496 | x = self.head.global_pool(x)
497 | x = self.head.norm(x)
498 | x = self.head.flatten(x)
499 | x = self.head.drop(x)
500 | return x if pre_logits else self.head.fc(x)
501 |
502 | def forward(self, x):
503 | x = self.forward_features(x)
504 | x = self.forward_head(x)
505 | return x
506 |
507 |
508 | def _init_weights(module, name=None, head_init_scale=1.0):
509 | if isinstance(module, nn.Conv2d):
510 | trunc_normal_(module.weight, std=.02)
511 | if module.bias is not None:
512 | nn.init.zeros_(module.bias)
513 | elif isinstance(module, nn.Linear):
514 | trunc_normal_(module.weight, std=.02)
515 | nn.init.zeros_(module.bias)
516 | if name and 'head.' in name:
517 | module.weight.data.mul_(head_init_scale)
518 | module.bias.data.mul_(head_init_scale)
519 |
520 |
521 | def checkpoint_filter_fn(state_dict, model):
522 | """ Remap FB checkpoints -> timm """
523 | if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
524 | return state_dict # non-FB checkpoint
525 | if 'model' in state_dict:
526 | state_dict = state_dict['model']
527 | out_dict = {}
528 | import re
529 | for k, v in state_dict.items():
530 | k = k.replace('downsample_layers.0.', 'stem.')
531 | k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
532 | k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
533 | k = k.replace('dwconv', 'conv_dw')
534 | k = k.replace('pwconv', 'mlp.fc')
535 | k = k.replace('head.', 'head.fc.')
536 | if k.startswith('norm.'):
537 | k = k.replace('norm', 'head.norm')
538 | if v.ndim == 2 and 'head' not in k:
539 | model_shape = model.state_dict()[k].shape
540 | v = v.reshape(model_shape)
541 | if ('current_search_step' in k) or ('counter' in k):
542 | continue
543 | out_dict[k] = v
544 | return out_dict
545 |
546 |
547 | def _create_rfconvnext(variant, pretrained=False, **kwargs):
548 | model = build_model_with_cfg(
549 | RFConvNeXt, variant, pretrained,
550 | pretrained_filter_fn=checkpoint_filter_fn,
551 | feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
552 | **kwargs)
553 | return model
554 |
555 |
556 | def rfconvnext_tiny(pretrained=False, **kwargs):
557 | model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
558 | model = _create_rfconvnext('convnext_tiny', pretrained=pretrained, **model_args)
559 | return model
560 |
561 |
562 | def rfconvnext_small(pretrained=False, **kwargs):
563 | model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
564 | model = _create_rfconvnext('convnext_small', pretrained=pretrained, **model_args)
565 | return model
566 |
567 |
568 | def rfconvnext_base(pretrained=False, **kwargs):
569 | model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
570 | model = _create_rfconvnext('convnext_base', pretrained=pretrained, **model_args)
571 | return model
572 |
573 |
574 | def rfconvnext_large(pretrained=False, **kwargs):
575 | model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs)
576 | model = _create_rfconvnext('convnext_large', pretrained=pretrained, **model_args)
577 | return model
578 |
--------------------------------------------------------------------------------
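
As a usage note for the builders above: the sketch below constructs rfconvnext_tiny and runs a forward pass. It is a minimal, hedged example; the import path, the search_cfgs default, and the printed output shape are assumptions, while the rf_mode values and the other keyword names follow prepare_rfsearch above.

    import torch

    from models.rfconvnext import rfconvnext_tiny  # import path assumed from the repo layout

    # Build a ConvNeXt-T backbone whose convolutions are RFConv2d modules.
    # 'rfsearch' and 'rfmultiple' keep the dilation sample_weights trainable;
    # the other modes freeze them (see the end of __init__ above).
    model = rfconvnext_tiny(
        pretrained=False,
        rf_mode='rfsearch',        # one of: rfsearch, rfsingle, rfmultiple, rfmerge
        pretrained_weights=None,   # optional path to a checkpoint saved after a search
        kernel_cfgs=None,          # optional {conv name: kernel size} found by a search
        search_cfgs={},            # extra RFConv2d search options (default assumed)
    )
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # e.g. torch.Size([1, 1000]) with the default classifier head
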
/util/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # --------------------------------------------------------
10 |
11 | import os
12 | import random
13 |
14 | import numpy as np
15 | import torch
16 | from PIL import Image, ImageFilter
17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18 | from torchvision import datasets, transforms
19 |
20 | import util.transforms as custom_transforms
21 |
22 |
23 | def build_dataset(is_train, args):
24 | transform = build_transform(is_train)
25 | data_root = os.path.join(args.data_path,
26 | 'train-semi' if is_train else 'validation')
27 | gt_root = os.path.join(
28 | args.data_path,
29 | 'train-semi-segmentation' if is_train else 'validation-segmentation')
30 | dataset = SegDataset(data_root, gt_root, transform, is_train)
31 | return dataset
32 |
33 |
34 | def build_transform(is_train):
35 | mean = IMAGENET_DEFAULT_MEAN
36 | std = IMAGENET_DEFAULT_STD
37 | # train transform
38 | if is_train:
39 | # paired crop/flip for image & mask; color distortion and blur on the image only
40 | color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
41 | randomresizedcrop = custom_transforms.RandomResizedCrop(
42 | 224,
43 | scale=(0.14, 1),
44 | )
45 | transform = custom_transforms.Compose([
46 | randomresizedcrop,
47 | custom_transforms.RandomHorizontalFlip(p=0.5),
48 | transforms.Compose(color_transform),
49 | custom_transforms.ToTensor(),
50 | transforms.Normalize(mean=mean, std=std)
51 | ])
52 | return transform
53 |
54 | # eval transform
55 | t = []
56 | t.append(transforms.Resize(256))
57 | t.append(transforms.ToTensor())
58 | t.append(transforms.Normalize(mean, std))
59 | return transforms.Compose(t)
60 |
61 |
62 | class SegDataset(datasets.ImageFolder):
63 | def __init__(self, data_root, gt_root=None, transform=None, is_train=True):
64 | super(SegDataset, self).__init__(data_root)
65 | assert gt_root is not None
66 | self.gt_root = gt_root
67 | self.transform = transform
68 | self.is_train = is_train
69 |
70 | def __getitem__(self, index):
71 | path, _ = self.samples[index]
72 | image = self.loader(path)
73 | segmentation = self.load_segmentation(path)
74 |
75 | if self.is_train:
76 | image, segmentation = self.transform(image, segmentation)
77 | else:
78 | image = self.transform(image)
79 | segmentation = torch.from_numpy(np.array(segmentation))
80 | segmentation = segmentation.long()
81 |
82 | segmentation = segmentation[:, :, 1] * 256 + segmentation[:, :, 0]
83 | return image, segmentation
84 |
85 | def load_segmentation(self, path):
86 | cate, name = path.split('/')[-2:]
87 | name = name.replace('JPEG', 'png')
88 | path = os.path.join(self.gt_root, cate, name)
89 | segmentation = Image.open(path)
90 | return segmentation
91 |
92 |
93 | class PILRandomGaussianBlur(object):
94 | """Apply Gaussian Blur to the PIL image. Take the radius and probability of
95 | application as the parameter.
96 |
97 | This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
98 | """
99 | def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
100 | self.prob = p
101 | self.radius_min = radius_min
102 | self.radius_max = radius_max
103 |
104 | def __call__(self, img):
105 | do_it = np.random.rand() <= self.prob
106 | if not do_it:
107 | return img
108 |
109 | return img.filter(
110 | ImageFilter.GaussianBlur(
111 | radius=random.uniform(self.radius_min, self.radius_max)))
112 |
113 |
114 | def get_color_distortion(s=1.0):
115 | # s is the strength of color distortion.
116 | color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
117 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
118 | rnd_gray = transforms.RandomGrayscale(p=0.2)
119 | color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
120 | return color_distort
121 |
--------------------------------------------------------------------------------
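
The dataset in util/datasets.py above pairs each image with a segmentation PNG whose R and G channels encode the class id (G * 256 + R). A hedged sketch of driving it, assuming an ImageNet-S style directory under data_path with the train-semi / train-semi-segmentation folders used in build_dataset; the path itself is a placeholder:

    import argparse

    import torch

    from util.datasets import build_dataset  # as defined above

    # Only data_path is read by build_dataset; the value here is a placeholder.
    args = argparse.Namespace(data_path='data/imagenet-s')

    train_set = build_dataset(is_train=True, args=args)
    loader = torch.utils.data.DataLoader(train_set, batch_size=4, num_workers=2)

    images, segmentations = next(iter(loader))
    # images: normalized (4, 3, 224, 224); segmentations: (4, 224, 224) with
    # per-pixel class ids already decoded as G * 256 + R in SegDataset.__getitem__.
    print(images.shape, segmentations.shape)
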
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 |
13 | def param_groups_lrd(model,
14 | weight_decay=0.05,
15 | no_weight_decay_list=[],
16 | layer_decay=[.75],
17 | layer_multiplier=[1.0]):
18 | """Parameter groups for layer-wise lr decay Following BEiT: https://github.
19 |
20 | com/microsoft/unilm/blob/master/beit/optim_factory.py#L58.
21 | """
22 | param_group_names = {}
23 | param_groups = {}
24 |
25 | num_layers = model.num_layers
26 |
27 | if isinstance(layer_decay, (float, int)):
28 | layer_decay = [layer_decay]
29 |
30 | layer_scales = [
31 | list(decay**(num_layers - i) for i in range(num_layers + 1)) for decay in layer_decay]
32 |
33 | for n, p in model.named_parameters():
34 | if not p.requires_grad:
35 | continue
36 |
37 | # no decay: all 1D parameters and model specific ones
38 | if p.ndim == 1 or n in no_weight_decay_list:
39 | g_decay = 'no_decay'
40 | this_decay = 0.
41 | else:
42 | g_decay = 'decay'
43 | this_decay = weight_decay
44 |
45 | """
46 | For each layer, the get_layer_id returns (layer_group, layer_id).
47 | According to the layer_group, different parameters are grouped,
48 | and layers in different groups use different decay rates.
49 |
50 | If only the layer_id is returned, the layer_group are set to 0 by default.
51 | """
52 | layer_group_id = model.get_layer_id(n)
53 | if isinstance(layer_group_id, (list, tuple)):
54 | layer_group, layer_id = layer_group_id
55 | elif isinstance(layer_group_id, int):
56 | layer_group, layer_id = 0, layer_group_id
57 | else:
58 | raise NotImplementedError()
59 | group_name = 'layer_%d_%d_%s' % (layer_group, layer_id, g_decay)
60 |
61 | if group_name not in param_group_names:
62 | this_scale = layer_scales[layer_group][layer_id] * layer_multiplier[layer_group]
63 |
64 | param_group_names[group_name] = {
65 | 'lr_scale': this_scale,
66 | 'weight_decay': this_decay,
67 | 'params': [],
68 | }
69 | param_groups[group_name] = {
70 | 'lr_scale': this_scale,
71 | 'weight_decay': this_decay,
72 | 'params': [],
73 | }
74 |
75 | param_group_names[group_name]['params'].append(n)
76 | param_groups[group_name]['params'].append(p)
77 |
78 | return list(param_groups.values())
79 |
--------------------------------------------------------------------------------
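
param_groups_lrd above only relies on the model exposing num_layers and get_layer_id(name); everything else is generic. A hedged sketch with a toy model (ToyModel and its layer mapping are invented for illustration):

    import torch
    import torch.nn as nn

    from util.lr_decay import param_groups_lrd  # as defined above


    class ToyModel(nn.Module):
        """Toy model exposing the interface param_groups_lrd relies on."""

        def __init__(self):
            super().__init__()
            self.stem = nn.Linear(8, 8)
            self.block = nn.Linear(8, 8)
            self.head = nn.Linear(8, 2)
            self.num_layers = 2  # layer i gets scale decay ** (num_layers - i)

        def get_layer_id(self, name):
            # Returning a plain int puts every parameter in layer group 0.
            if name.startswith('stem'):
                return 0
            if name.startswith('block'):
                return 1
            return 2  # head gets the largest scale (decay ** 0 == 1.0)


    model = ToyModel()
    groups = param_groups_lrd(model, weight_decay=0.05, layer_decay=0.75)
    optimizer = torch.optim.AdamW(groups, lr=1e-3)
    for g in optimizer.param_groups:
        print(g['weight_decay'], g['lr_scale'], len(g['params']))
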
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 |
10 | def adjust_learning_rate(optimizer, epoch, args):
11 | """Decay the learning rate with half-cycle cosine after warmup."""
12 | if epoch < args.warmup_epochs:
13 | lr = args.lr * epoch / args.warmup_epochs
14 | else:
15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
16 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) /
17 | (args.epochs - args.warmup_epochs)))
18 | for param_group in optimizer.param_groups:
19 | if 'lr_scale' in param_group:
20 | param_group['lr'] = lr * param_group['lr_scale']
21 | else:
22 | param_group['lr'] = lr
23 | return lr
24 |
--------------------------------------------------------------------------------
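
The schedule in util/lr_sched.py above is a linear warmup for warmup_epochs followed by a half-cycle cosine from lr down towards min_lr, applied per param group via lr_scale. A hedged sketch; the hyper-parameter values in the args namespace are placeholders:

    import argparse

    import torch

    from util.lr_sched import adjust_learning_rate  # as defined above

    # The function only reads lr, min_lr, warmup_epochs and epochs from args.
    args = argparse.Namespace(lr=1e-3, min_lr=1e-6, warmup_epochs=5, epochs=20)

    optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)
    for epoch in range(args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, args)
        print(f'epoch {epoch:2d}: lr = {lr:.6f}')
    # epoch 0 starts at 0, epoch 5 is back at the base lr, and the cosine
    # then decays towards min_lr by the final epoch.
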
/util/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def IoUGPU(output, target, K):
5 | # 'K' classes, output and target sizes are
6 | # N or N * L or N * H * W, each value in range 0 to K - 1.
7 | assert (output.dim() in [1, 2, 3])
8 | assert output.shape == target.shape
9 | output = output.view(-1)
10 | target = target.view(-1)
11 | intersection = output[output == target]
12 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
13 | area_output = torch.histc(output, bins=K, min=0, max=K - 1)
14 | area_target = torch.histc(target, bins=K, min=0, max=K - 1)
15 | return area_intersection, area_output, area_target
16 |
17 |
18 | def FMeasureGPU(output, target, eps=1e-20, beta=0.3):
19 | target = (target > 0) * 1.0
20 | output = (output > 0) * 1.0
21 |
22 | t = torch.sum(target)
23 | p = torch.sum(output)
24 | tp = torch.sum(target * output)
25 | recall = tp / (t + eps)
26 | precision = tp / (p + eps)
27 | f_score = (1 + beta) * precision * recall / (beta * precision + recall +
28 | eps)
29 |
30 | return f_score
31 |
--------------------------------------------------------------------------------
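
IoUGPU above returns per-class intersection and per-class areas, which are normally accumulated over a dataset before dividing; FMeasureGPU treats any value > 0 as foreground. A hedged sketch on random label maps (cast to float because torch.histc does not accept integer tensors on CPU):

    import torch

    from util.metric import FMeasureGPU, IoUGPU  # as defined above

    K = 5  # number of classes
    pred = torch.randint(0, K, (2, 64, 64)).float()
    gt = torch.randint(0, K, (2, 64, 64)).float()

    inter, area_pred, area_gt = IoUGPU(pred, gt, K)
    union = area_pred + area_gt - inter
    iou = inter / (union + 1e-10)  # per-class IoU
    print('per-class IoU:', [round(v, 3) for v in iou.tolist()])
    print('mIoU:', iou.mean().item())

    # F-measure on the binary foreground/background view of the same maps.
    print('F-measure:', FMeasureGPU(pred, gt).item())
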
/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from torch._six import inf
22 |
23 |
24 | class SmoothedValue(object):
25 | """Track a series of values and provide access to smoothed values over a
26 | window or the global series average."""
27 | def __init__(self, window_size=20, fmt=None):
28 | if fmt is None:
29 | fmt = '{median:.4f} ({global_avg:.4f})'
30 | self.deque = deque(maxlen=window_size)
31 | self.total = 0.0
32 | self.count = 0
33 | self.fmt = fmt
34 |
35 | def update(self, value, n=1):
36 | self.deque.append(value)
37 | self.count += n
38 | self.total += value * n
39 |
40 | def synchronize_between_processes(self):
41 | """
42 | Warning: does not synchronize the deque!
43 | """
44 | if not is_dist_avail_and_initialized():
45 | return
46 | t = torch.tensor([self.count, self.total],
47 | dtype=torch.float64,
48 | device='cuda')
49 | dist.barrier()
50 | dist.all_reduce(t)
51 | t = t.tolist()
52 | self.count = int(t[0])
53 | self.total = t[1]
54 |
55 | @property
56 | def median(self):
57 | d = torch.tensor(list(self.deque))
58 | return d.median().item()
59 |
60 | @property
61 | def avg(self):
62 | d = torch.tensor(list(self.deque), dtype=torch.float32)
63 | return d.mean().item()
64 |
65 | @property
66 | def global_avg(self):
67 | return self.total / self.count
68 |
69 | @property
70 | def max(self):
71 | return max(self.deque)
72 |
73 | @property
74 | def value(self):
75 | return self.deque[-1]
76 |
77 | def __str__(self):
78 | return self.fmt.format(median=self.median,
79 | avg=self.avg,
80 | global_avg=self.global_avg,
81 | max=self.max,
82 | value=self.value)
83 |
84 |
85 | class MetricLogger(object):
86 | def __init__(self, delimiter='\t'):
87 | self.meters = defaultdict(SmoothedValue)
88 | self.delimiter = delimiter
89 |
90 | def update(self, **kwargs):
91 | for k, v in kwargs.items():
92 | if v is None:
93 | continue
94 | if isinstance(v, torch.Tensor):
95 | v = v.item()
96 | assert isinstance(v, (float, int))
97 | self.meters[k].update(v)
98 |
99 | def __getattr__(self, attr):
100 | if attr in self.meters:
101 | return self.meters[attr]
102 | if attr in self.__dict__:
103 | return self.__dict__[attr]
104 | raise AttributeError("'{}' object has no attribute '{}'".format(
105 | type(self).__name__, attr))
106 |
107 | def __str__(self):
108 | loss_str = []
109 | for name, meter in self.meters.items():
110 | loss_str.append('{}: {}'.format(name, str(meter)))
111 | return self.delimiter.join(loss_str)
112 |
113 | def synchronize_between_processes(self):
114 | for meter in self.meters.values():
115 | meter.synchronize_between_processes()
116 |
117 | def add_meter(self, name, meter):
118 | self.meters[name] = meter
119 |
120 | def log_every(self, iterable, print_freq, header=None):
121 | i = 0
122 | if not header:
123 | header = ''
124 | start_time = time.time()
125 | end = time.time()
126 | iter_time = SmoothedValue(fmt='{avg:.4f}')
127 | data_time = SmoothedValue(fmt='{avg:.4f}')
128 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
129 | log_msg = [
130 | header, '[{0' + space_fmt + '}/{1}]', 'eta: {eta}', '{meters}',
131 | 'time: {time}', 'data: {data}'
132 | ]
133 | if torch.cuda.is_available():
134 | log_msg.append('max mem: {memory:.0f}')
135 | log_msg = self.delimiter.join(log_msg)
136 | MB = 1024.0 * 1024.0
137 | for obj in iterable:
138 | data_time.update(time.time() - end)
139 | yield obj
140 | iter_time.update(time.time() - end)
141 | if i % print_freq == 0 or i == len(iterable) - 1:
142 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
143 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
144 | if torch.cuda.is_available():
145 | print(
146 | log_msg.format(
147 | i,
148 | len(iterable),
149 | eta=eta_string,
150 | meters=str(self),
151 | time=str(iter_time),
152 | data=str(data_time),
153 | memory=torch.cuda.max_memory_allocated() / MB))
154 | else:
155 | print(
156 | log_msg.format(i,
157 | len(iterable),
158 | eta=eta_string,
159 | meters=str(self),
160 | time=str(iter_time),
161 | data=str(data_time)))
162 | i += 1
163 | end = time.time()
164 | total_time = time.time() - start_time
165 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
166 | print('{} Total time: {} ({:.4f} s / it)'.format(
167 | header, total_time_str, total_time / len(iterable)))
168 |
169 |
170 | def setup_for_distributed(is_master):
171 | """This function disables printing when not in master process."""
172 | builtin_print = builtins.print
173 |
174 | def print(*args, **kwargs):
175 | force = kwargs.pop('force', False)
176 | force = force or (get_world_size() > 8)
177 | if is_master or force:
178 | now = datetime.datetime.now().time()
179 | builtin_print('[{}] '.format(now), end='') # print with time stamp
180 | builtin_print(*args, **kwargs)
181 |
182 | builtins.print = print
183 |
184 |
185 | def is_dist_avail_and_initialized():
186 | if not dist.is_available():
187 | return False
188 | if not dist.is_initialized():
189 | return False
190 | return True
191 |
192 |
193 | def get_world_size():
194 | if not is_dist_avail_and_initialized():
195 | return 1
196 | return dist.get_world_size()
197 |
198 |
199 | def get_rank():
200 | if not is_dist_avail_and_initialized():
201 | return 0
202 | return dist.get_rank()
203 |
204 |
205 | def is_main_process():
206 | return get_rank() == 0
207 |
208 |
209 | def save_on_master(*args, **kwargs):
210 | if is_main_process():
211 | torch.save(*args, **kwargs)
212 |
213 |
214 | def init_distributed_mode(args):
215 | if args.dist_on_itp:
216 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
217 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
218 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
219 | args.dist_url = 'tcp://%s:%s' % (os.environ['MASTER_ADDR'],
220 | os.environ['MASTER_PORT'])
221 | os.environ['LOCAL_RANK'] = str(args.gpu)
222 | os.environ['RANK'] = str(args.rank)
223 | os.environ['WORLD_SIZE'] = str(args.world_size)
224 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
225 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
226 | args.rank = int(os.environ['RANK'])
227 | args.world_size = int(os.environ['WORLD_SIZE'])
228 | args.gpu = int(os.environ['LOCAL_RANK'])
229 | elif 'SLURM_PROCID' in os.environ:
230 | args.rank = int(os.environ['SLURM_PROCID'])
231 | args.gpu = args.rank % torch.cuda.device_count()
232 | else:
233 | print('Not using distributed mode')
234 | setup_for_distributed(is_master=True) # hack
235 | args.distributed = False
236 | return
237 |
238 | args.distributed = True
239 |
240 | torch.cuda.set_device(args.gpu)
241 | args.dist_backend = 'nccl'
242 | print('| distributed init (rank {}): {}, gpu {}'.format(
243 | args.rank, args.dist_url, args.gpu),
244 | flush=True)
245 | torch.distributed.init_process_group(backend=args.dist_backend,
246 | init_method=args.dist_url,
247 | world_size=args.world_size,
248 | rank=args.rank)
249 | torch.distributed.barrier()
250 | setup_for_distributed(args.rank == 0)
251 |
252 |
253 | class NativeScalerWithGradNormCount:
254 | state_dict_key = 'amp_scaler'
255 |
256 | def __init__(self):
257 | self._scaler = torch.cuda.amp.GradScaler()
258 |
259 | def __call__(self,
260 | loss,
261 | optimizer,
262 | clip_grad=None,
263 | parameters=None,
264 | create_graph=False,
265 | update_grad=True):
266 | self._scaler.scale(loss).backward(create_graph=create_graph)
267 | if update_grad:
268 | if clip_grad is not None:
269 | assert parameters is not None
270 | self._scaler.unscale_(optimizer)
271 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
272 | else:
273 | self._scaler.unscale_(optimizer)
274 | norm = get_grad_norm_(parameters)
275 | self._scaler.step(optimizer)
276 | self._scaler.update()
277 | else:
278 | norm = None
279 | return norm
280 |
281 | def state_dict(self):
282 | return self._scaler.state_dict()
283 |
284 | def load_state_dict(self, state_dict):
285 | self._scaler.load_state_dict(state_dict)
286 |
287 |
288 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
289 | if isinstance(parameters, torch.Tensor):
290 | parameters = [parameters]
291 | parameters = [p for p in parameters if p.grad is not None]
292 | norm_type = float(norm_type)
293 | if len(parameters) == 0:
294 | return torch.tensor(0.)
295 | device = parameters[0].grad.device
296 | if norm_type == inf:
297 | total_norm = max(p.grad.detach().abs().max().to(device)
298 | for p in parameters)
299 | else:
300 | total_norm = torch.norm(
301 | torch.stack([
302 | torch.norm(p.grad.detach(), norm_type).to(device)
303 | for p in parameters
304 | ]), norm_type)
305 | return total_norm
306 |
307 |
308 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
309 | output_dir = Path(args.output_dir)
310 | epoch_name = str(epoch)
311 | if loss_scaler is not None:
312 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
313 | for checkpoint_path in checkpoint_paths:
314 | to_save = {
315 | 'model': model_without_ddp.state_dict(),
316 | 'optimizer': optimizer.state_dict(),
317 | 'epoch': epoch,
318 | 'scaler': loss_scaler.state_dict(),
319 | 'args': args,
320 | }
321 |
322 | save_on_master(to_save, checkpoint_path)
323 | else:
324 | client_state = {'epoch': epoch}
325 | model.save_checkpoint(save_dir=args.output_dir,
326 | tag='checkpoint-%s' % epoch_name,
327 | client_state=client_state)
328 |
329 |
330 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
331 | if args.resume:
332 | if args.resume.startswith('https'):
333 | checkpoint = torch.hub.load_state_dict_from_url(args.resume,
334 | map_location='cpu',
335 | check_hash=True)
336 | else:
337 | checkpoint = torch.load(args.resume, map_location='cpu')
338 | model_without_ddp.load_state_dict(checkpoint['model'])
339 | print('Resume checkpoint %s' % args.resume)
340 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (
341 | hasattr(args, 'eval') and args.eval):
342 | optimizer.load_state_dict(checkpoint['optimizer'])
343 | args.start_epoch = checkpoint['epoch'] + 1
344 | if 'scaler' in checkpoint:
345 | loss_scaler.load_state_dict(checkpoint['scaler'])
346 | print('With optim & sched!')
347 |
348 |
349 | def all_reduce_mean(x):
350 | world_size = get_world_size()
351 | if world_size > 1:
352 | x_reduce = torch.tensor(x).cuda()
353 | dist.all_reduce(x_reduce)
354 | x_reduce /= world_size
355 | return x_reduce.item()
356 | else:
357 | return x
358 |
--------------------------------------------------------------------------------
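
NativeScalerWithGradNormCount above combines AMP loss scaling, optional gradient clipping, and grad-norm reporting in one call. A hedged sketch of a single training step; both autocast and the scaler degrade to no-ops on a CPU-only machine:

    import torch
    import torch.nn as nn

    from util.misc import NativeScalerWithGradNormCount  # as defined above

    model = nn.Linear(10, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_scaler = NativeScalerWithGradNormCount()  # wraps torch.cuda.amp.GradScaler

    x = torch.randn(4, 10)
    target = torch.randint(0, 2, (4,))

    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss = nn.functional.cross_entropy(model(x), target)

    optimizer.zero_grad()
    grad_norm = loss_scaler(loss, optimizer,
                            clip_grad=1.0,                 # optional max gradient norm
                            parameters=model.parameters(),
                            update_grad=True)
    print('grad norm (before clipping):', float(grad_norm))
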
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 | import torch
12 |
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer:
18 | # https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
19 | # MoCo v3:
20 | # https://github.com/facebookresearch/moco-v3
21 | # --------------------------------------------------------
22 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
23 | """
24 | grid_size: int of the grid height and width
25 | return:
26 | pos_embed: [grid_size*grid_size, embed_dim]
27 | or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
28 | """
29 | grid_h = np.arange(grid_size, dtype=np.float32)
30 | grid_w = np.arange(grid_size, dtype=np.float32)
31 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
32 | grid = np.stack(grid, axis=0)
33 |
34 | grid = grid.reshape([2, 1, grid_size, grid_size])
35 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
36 | if cls_token:
37 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed],
38 | axis=0)
39 | return pos_embed
40 |
41 |
42 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
43 | assert embed_dim % 2 == 0
44 |
45 | # use half of dimensions to encode grid_h
46 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
47 | grid[0]) # (H*W, D/2)
48 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2,
49 | grid[1]) # (H*W, D/2)
50 |
51 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52 | return emb
53 |
54 |
55 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
56 | """
57 | embed_dim: output dimension for each position
58 | pos: a list of positions to be encoded: size (M,)
59 | out: (M, D)
60 | """
61 | assert embed_dim % 2 == 0
62 | omega = np.arange(embed_dim // 2, dtype=np.float64)  # np.float was removed in NumPy >= 1.24
63 | omega /= embed_dim / 2.
64 | omega = 1. / 10000**omega # (D/2,)
65 |
66 | pos = pos.reshape(-1) # (M,)
67 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
68 |
69 | emb_sin = np.sin(out) # (M, D/2)
70 | emb_cos = np.cos(out) # (M, D/2)
71 |
72 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
73 | return emb
74 |
75 |
76 | # --------------------------------------------------------
77 | # Interpolate position embeddings for high-resolution
78 | # References:
79 | # DeiT: https://github.com/facebookresearch/deit
80 | # --------------------------------------------------------
81 | def interpolate_pos_embed(model, checkpoint_model):
82 | if 'pos_embed' in checkpoint_model:
83 | pos_embed_checkpoint = checkpoint_model['pos_embed']
84 | embedding_size = pos_embed_checkpoint.shape[-1]
85 | num_patches = model.patch_embed.num_patches
86 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
87 | # height (== width) for the checkpoint position embedding
88 | orig_size = int(
89 | (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
90 | # height (== width) for the new position embedding
91 | new_size = int(num_patches**0.5)
92 | # class_token and dist_token are kept unchanged
93 | if orig_size != new_size:
94 | print('Position interpolate from %dx%d to %dx%d' %
95 | (orig_size, orig_size, new_size, new_size))
96 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
97 | # only the position tokens are interpolated
98 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
99 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
100 | embedding_size).permute(
101 | 0, 3, 1, 2).contiguous()
102 | pos_tokens = torch.nn.functional.interpolate(pos_tokens,
103 | size=(new_size,
104 | new_size),
105 | mode='bicubic',
106 | align_corners=False)
107 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
108 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
109 | checkpoint_model['pos_embed'] = new_pos_embed
110 |
--------------------------------------------------------------------------------
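
get_2d_sincos_pos_embed above produces a fixed (non-learned) embedding table; for a ViT-B/16 at 224x224 the patch grid is 14x14. A hedged sketch checking the shape and converting it to a tensor:

    import torch

    from util.pos_embed import get_2d_sincos_pos_embed  # as defined above

    embed_dim, grid_size = 768, 14  # e.g. ViT-B/16 on 224x224 inputs
    pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
    print(pos_embed.shape)  # (1 + 14 * 14, 768); the cls-token row is all zeros

    # Typical use: copy the fixed table into a model buffer / frozen parameter.
    pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0)  # (1, 197, 768)
    print(pos_embed.min().item(), pos_embed.max().item())  # values lie in [-1, 1]
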
/util/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import math
4 | import random
5 | import warnings
6 | from collections.abc import Iterable  # moved out of 'collections' in Python 3.10
7 |
8 | import numpy as np
9 | import torch
10 | from torchvision.transforms import functional as F
11 |
12 | try:
13 | from torchvision.transforms import InterpolationMode
14 |
15 | NEAREST = InterpolationMode.NEAREST
16 | BILINEAR = InterpolationMode.BILINEAR
17 | BICUBIC = InterpolationMode.BICUBIC
18 | LANCZOS = InterpolationMode.LANCZOS
19 | HAMMING = InterpolationMode.HAMMING
21 |
22 | _pil_interpolation_to_str = {
23 | InterpolationMode.NEAREST: 'InterpolationMode.NEAREST',
24 | InterpolationMode.BILINEAR: 'InterpolationMode.BILINEAR',
25 | InterpolationMode.BICUBIC: 'InterpolationMode.BICUBIC',
26 | InterpolationMode.LANCZOS: 'InterpolationMode.LANCZOS',
27 | InterpolationMode.HAMMING: 'InterpolationMode.HAMMING',
28 | InterpolationMode.BOX: 'InterpolationMode.BOX',
29 | }
30 |
31 | except ImportError:
32 | from PIL import Image
33 |
34 | NEAREST = Image.NEAREST
35 | BILINEAR = Image.BILINEAR
36 | BICUBIC = Image.BICUBIC
37 | LANCZOS = Image.LANCZOS
38 | HAMMING = Image.HAMMING
40 |
41 | _pil_interpolation_to_str = {
42 | Image.NEAREST: 'PIL.Image.NEAREST',
43 | Image.BILINEAR: 'PIL.Image.BILINEAR',
44 | Image.BICUBIC: 'PIL.Image.BICUBIC',
45 | Image.LANCZOS: 'PIL.Image.LANCZOS',
46 | Image.HAMMING: 'PIL.Image.HAMMING',
47 | Image.BOX: 'PIL.Image.BOX',
48 | }
49 |
50 | def _get_image_size(img):
51 | if F._is_pil_image(img):
52 | return img.size
53 | elif isinstance(img, torch.Tensor) and img.dim() > 2:
54 | return img.shape[-2:][::-1]
55 | else:
56 | raise TypeError('Unexpected type {}'.format(type(img)))
57 |
58 |
59 | class Compose(object):
60 | """Composes several transforms together.
61 |
62 | Args:
63 | transforms (list of ``Transform`` objects):
64 | list of transforms to compose.
65 |
66 | Example:
67 | >>> transforms.Compose([
68 | >>> transforms.CenterCrop(10),
69 | >>> transforms.ToTensor(),
70 | >>> ])
71 | """
72 | def __init__(self, transforms):
73 | self.transforms = transforms
74 |
75 | def __call__(self, img, gt):
76 | for t in self.transforms:
77 | if 'RandomResizedCrop' in t.__class__.__name__:
78 | img, gt = t(img, gt)
79 | elif 'Flip' in t.__class__.__name__:
80 | img, gt = t(img, gt)
81 | elif 'ToTensor' in t.__class__.__name__:
82 | img, gt = t(img, gt)
83 | else:
84 | img = t(img)
85 | gt = gt.float()
86 |
87 | return img, gt
88 |
89 | def __repr__(self):
90 | format_string = self.__class__.__name__ + '('
91 | for t in self.transforms:
92 | format_string += '\n'
93 | format_string += ' {0}'.format(t)
94 | format_string += '\n)'
95 | return format_string
96 |
97 |
98 | class RandomHorizontalFlip(object):
99 | """Horizontally flip the given PIL Image randomly with a given probability.
100 |
101 | Args:
102 | p (float): probability of the image being flipped. Default value is 0.5
103 | """
104 | def __init__(self, p=0.5):
105 | self.p = p
106 |
107 | def __call__(self, img, gt):
108 | """
109 | Args:
110 | img (PIL Image): Image to be flipped.
111 |
112 | Returns:
113 | PIL Image: Randomly flipped image.
114 | """
115 | if random.random() < self.p:
116 | return F.hflip(img), F.hflip(gt)
117 | return img, gt
118 |
119 | def __repr__(self):
120 | return self.__class__.__name__ + '(p={})'.format(self.p)
121 |
122 |
123 | class RandomResizedCrop(object):
124 | """Crop the given PIL Image to random size and aspect ratio.
125 |
126 | A crop of random size (default: of 0.08 to 1.0) of the original size
127 | and a random aspect ratio (default: of 3/4 to 4/3) of the original
128 | aspect ratio is made. This crop is finally resized to given size.
129 | This is popularly used to train the Inception networks.
130 |
131 | Args:
132 | size: expected output size of each edge
133 | scale: range of size of the origin size cropped
134 | ratio: range of aspect ratio of the origin aspect ratio cropped
135 | interpolation: Default: PIL.Image.BILINEAR
136 | """
137 | def __init__(self,
138 | size,
139 | scale=(0.08, 1.0),
140 | ratio=(3. / 4., 4. / 3.),
141 | interpolation=BILINEAR):
142 | if isinstance(size, (tuple, list)):
143 | self.size = size
144 | else:
145 | self.size = (size, size)
146 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
147 | warnings.warn('range should be of kind (min, max)')
148 |
149 | self.interpolation = interpolation
150 | self.scale = scale
151 | self.ratio = ratio
152 |
153 | @staticmethod
154 | def get_params(img, scale, ratio):
155 | """Get parameters for ``crop`` for a random sized crop.
156 |
157 | Args:
158 | img (PIL Image): Image to be cropped.
159 | scale (tuple):
160 | range of size of the origin size cropped
161 | ratio (tuple):
162 | range of aspect ratio of the origin aspect ratio cropped
163 |
164 | Returns:
165 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random
166 | sized crop.
167 | """
168 | width, height = _get_image_size(img)
169 | area = height * width
170 |
171 | for attempt in range(10):
172 | target_area = random.uniform(*scale) * area
173 | log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
174 | aspect_ratio = math.exp(random.uniform(*log_ratio))
175 |
176 | w = int(round(math.sqrt(target_area * aspect_ratio)))
177 | h = int(round(math.sqrt(target_area / aspect_ratio)))
178 |
179 | if 0 < w <= width and 0 < h <= height:
180 | i = random.randint(0, height - h)
181 | j = random.randint(0, width - w)
182 | return i, j, h, w
183 |
184 | # Fallback to central crop
185 | in_ratio = float(width) / float(height)
186 | if (in_ratio < min(ratio)):
187 | w = width
188 | h = int(round(w / min(ratio)))
189 | elif (in_ratio > max(ratio)):
190 | h = height
191 | w = int(round(h * max(ratio)))
192 | else: # whole image
193 | w = width
194 | h = height
195 | i = (height - h) // 2
196 | j = (width - w) // 2
197 | return i, j, h, w
198 |
199 | def __call__(self, img, gt):
200 | """
201 | Args:
202 | img (PIL Image): Image to be cropped and resized.
203 |
204 | Returns:
205 | PIL Image: Randomly cropped and resized image.
206 | """
207 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
208 | return F.resized_crop(
209 | img, i, j, h, w, self.size, self.interpolation), \
210 | F.resized_crop(
211 | gt, i, j, h, w, self.size, NEAREST)
212 |
213 | def __repr__(self):
214 | interpolate_str = _pil_interpolation_to_str[self.interpolation]
215 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
216 | format_string += ', scale={0}'.format(
217 | tuple(round(s, 4) for s in self.scale))
218 | format_string += ', ratio={0}'.format(
219 | tuple(round(r, 4) for r in self.ratio))
220 | format_string += ', interpolation={0})'.format(interpolate_str)
221 | return format_string
222 |
223 | class ToTensor(object):
224 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
225 |
226 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range
227 | [0, 255] to a torch.FloatTensor of
228 | shape (C x H x W) in the range [0.0, 1.0]
229 | if the PIL Image belongs to one of the
230 | modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
231 | or if the numpy.ndarray has dtype = np.uint8
232 |
233 | In the other cases, tensors are returned without scaling.
234 | """
235 | def __call__(self, pic, gt):
236 | """
237 | Args:
238 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
239 |
240 | Returns:
241 | Tensor: Converted image.
242 | """
243 | return F.to_tensor(pic), torch.from_numpy(np.array(gt))
244 |
245 | def __repr__(self):
246 | return self.__class__.__name__ + '()'
247 |
--------------------------------------------------------------------------------
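
The paired transforms in util/transforms.py keep the image and mask geometrically aligned: Compose routes RandomResizedCrop, flips and ToTensor to both inputs, while every other transform (e.g. color jitter) only sees the image, and the mask is cast to float at the end. A hedged sketch on synthetic PIL inputs:

    import numpy as np
    from PIL import Image
    from torchvision import transforms

    import util.transforms as custom_transforms  # as defined above

    # Synthetic 256x256 RGB image and single-channel label mask.
    img = Image.fromarray(np.random.randint(0, 255, (256, 256, 3), dtype=np.uint8))
    gt = Image.fromarray(np.random.randint(0, 5, (256, 256), dtype=np.uint8))

    paired = custom_transforms.Compose([
        custom_transforms.RandomResizedCrop(224, scale=(0.14, 1.0)),  # same crop for img and gt
        custom_transforms.RandomHorizontalFlip(p=0.5),                # same flip for img and gt
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),                   # image only
        custom_transforms.ToTensor(),                                 # img -> float CHW, gt -> HW
    ])

    img_t, gt_t = paired(img, gt)
    print(img_t.shape, img_t.dtype)  # torch.Size([3, 224, 224]) torch.float32
    print(gt_t.shape, gt_t.dtype)    # torch.Size([224, 224]) torch.float32
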