├── LICENSE
├── README.md
├── config.py
├── res
│   ├── doc
│   │   ├── figures
│   │   │   └── teaser.png
│   │   └── notes.txt
│   └── eval
│       ├── info.txt
│       ├── kodak_results.txt
│       ├── kodim01_hat.png
│       ├── kodim02_hat.png
│       ├── kodim03_hat.png
│       ├── kodim04_hat.png
│       ├── kodim05_hat.png
│       ├── kodim06_hat.png
│       ├── kodim07_hat.png
│       ├── kodim08_hat.png
│       ├── kodim09_hat.png
│       ├── kodim10_hat.png
│       ├── kodim11_hat.png
│       ├── kodim12_hat.png
│       ├── kodim13_hat.png
│       ├── kodim14_hat.png
│       ├── kodim15.png
│       ├── kodim15_hat.png
│       ├── kodim16_hat.png
│       ├── kodim17_hat.png
│       ├── kodim18_hat.png
│       ├── kodim19_hat.png
│       ├── kodim20_hat.png
│       ├── kodim21_hat.png
│       ├── kodim22.png
│       ├── kodim22_hat.png
│       ├── kodim23.png
│       ├── kodim23_hat.png
│       └── kodim24_hat.png
├── swin-transformers-tf
│   ├── LICENSE
│   ├── README.md
│   ├── changelog.txt
│   ├── convert.py
│   ├── convert_all_models.py
│   ├── hub_utilities
│   │   ├── README.md
│   │   ├── export_for_hub.py
│   │   └── generate_doc.py
│   ├── in1k-eval
│   │   ├── README.md
│   │   ├── df.ipynb
│   │   ├── eval-swins.ipynb
│   │   ├── swin_224_in1k.csv
│   │   └── swin_384_in1k.csv
│   ├── notebooks
│   │   ├── classification.ipynb
│   │   ├── finetune.ipynb
│   │   ├── ilsvrc2012_wordnet_lemmas.txt
│   │   └── weight-porting.ipynb
│   ├── requirements.txt
│   ├── swins
│   │   ├── __init__.py
│   │   ├── blocks
│   │   │   ├── __init__.py
│   │   │   ├── mlp.py
│   │   │   ├── stage_block.py
│   │   │   ├── swin_transformer_block.py
│   │   │   └── utils.py
│   │   ├── layers
│   │   │   ├── __init__.py
│   │   │   ├── patch_merging.py
│   │   │   ├── patch_splitting.py
│   │   │   ├── sd.py
│   │   │   └── window_attn.py
│   │   ├── model_configs.py
│   │   └── models.py
│   ├── test.py
│   └── utils
│       └── helpers.py
└── zyc2022.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SwinT-ChARM (TensorFlow 2)
2 |
 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1w42soP9daxok4f6jvh3nen1DX0z4KSRb?usp=sharing)
4 |
5 | This repository provides a TensorFlow implementation of SwinT-ChARM based on:
6 |
7 | - [Transformer-Based Transform Coding (ICLR 2022)](https://openreview.net/pdf?id=IDwN6xjHnK8),
8 | - [Channel-wise Autoregressive Entropy Models For Learned Image Compression (ICIP 2020)](https://arxiv.org/pdf/2007.08739.pdf).
9 |
 10 | ![SwinT-ChARM teaser](res/doc/figures/teaser.png)
11 |
12 | [Source](https://openreview.net/pdf?id=IDwN6xjHnK8)
13 |
14 |
15 | ## Updates
16 |
17 | ***10/06/2023***
18 | 1. LIC-TCM (TensorFlow 2) is now available: https://github.com/Nikolai10/LIC-TCM (Liu et al. CVPR 2023 Highlight).
19 |
20 | ***09/06/2023***
21 | 1. The high quality of this reimplementation has been confirmed in [EGIC, Section A.8.](https://arxiv.org/pdf/2309.03244v1.pdf).
22 |
23 | ***10/09/2022***
24 |
25 | 1. The number of model parameters now corresponds exactly to the reported number (32.6 million). We thank the authors for providing us with the official DeepSpeed log files.
26 | 2. SwinT-ChARM now supports compression at different input resolutions (multiples of 256).
27 | 3. We release a pre-trained model as proof of functional correctness.
28 |
29 | ***08/17/2022***
30 |
31 | 1. Initial release of this project (see branch *release_08/17/2022*)
32 |
33 | ## Acknowledgment
34 | This project is based on:
35 |
36 | - [TensorFlow Compression (TFC)](https://github.com/tensorflow/compression), a TF library dedicated to data compression.
37 | - [swin-transformers-tf](https://github.com/sayakpaul/swin-transformers-tf), an unofficial implementation of [Swin-Transformer](https://github.com/microsoft/Swin-Transformer). Functional correctness has been [proven](https://github.com/microsoft/Swin-Transformer/pull/206).
38 |
39 | Note that this repository builds upon the official TF implementation of [Minnen et al.](https://github.com/tensorflow/compression/blob/master/models/ms2020.py), while Zhu et al. base their work on an
40 | unknown (possibly not publicly available) PyTorch reimplementation.
41 |
42 | ## Examples
43 |
 44 | The samples below are taken from the [Kodak dataset](http://r0k.us/graphics/kodak/), which is not part of the training set:
45 |
46 | Original | SwinT-ChARM (β = 0.0003)
47 | :-------------------------:|:-------------------------:
 48 | ![kodim22](res/eval/kodim22.png) | ![kodim22 reconstruction](res/eval/kodim22_hat.png)
49 |
50 | ```python
51 | Mean squared error: 13.7772
52 | PSNR (dB): 36.74
53 | Multiscale SSIM: 0.9871
54 | Multiscale SSIM (dB): 18.88
55 | Bits per pixel: 0.9890
56 | ```
57 |
58 | Original | SwinT-ChARM (β = 0.0003)
59 | :-------------------------:|:-------------------------:
 60 | ![kodim23](res/eval/kodim23.png) | ![kodim23 reconstruction](res/eval/kodim23_hat.png)
61 |
62 | ```python
63 | Mean squared error: 7.1963
64 | PSNR (dB): 39.56
65 | Multiscale SSIM: 0.9903
66 | Multiscale SSIM (dB): 20.13
67 | Bits per pixel: 0.3953
68 | ```
69 |
70 | Original | SwinT-ChARM (β = 0.0003)
71 | :-------------------------:|:-------------------------:
 72 | ![kodim15](res/eval/kodim15.png) | ![kodim15 reconstruction](res/eval/kodim15_hat.png)
73 |
74 | ```python
75 | Mean squared error: 10.1494
76 | PSNR (dB): 38.07
77 | Multiscale SSIM: 0.9888
78 | Multiscale SSIM (dB): 19.49
79 | Bits per pixel: 0.6525
80 | ```
81 | More examples can be found [here](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval).
82 |
 83 | ## Pretrained Models / Performance (TFC 2.8)
84 |
 85 | Our pre-trained model (β = 0.0003) achieves a [PSNR of 37.59 (dB) at an average of 0.93 bpp](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodak_results.txt) on the [Kodak dataset](http://r0k.us/graphics/kodak/), which is very close to the reported numbers (see [paper](https://openreview.net/pdf?id=IDwN6xjHnK8), Figure 3). Worth mentioning: we achieve this result despite training our model from scratch and using less than one-third of the computational resources (1M optimization steps).
86 |
87 |
88 |
89 | | Lagrangian multiplier (β) | SavedModel | Training Instructions |
90 | | ----------- | -------------------------------- | ---------------------- |
 91 | | 0.0003 | [download](https://drive.google.com/drive/folders/1bUWowLEgU8ukYejvVPbhU38_7japhp7C?usp=sharing) | `!python SwinT-ChARM/zyc2022.py -V --model_path <...> train --max_support_slices 10 --lambda 0.0003 --epochs 1000 --batchsize 16 --train_path <...>` |
92 |
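A quick usage sketch (hedged): `zyc2022.py` follows the CLI of TFC's `ms2020.py`, on which this repository builds, so compressing and decompressing a single image should look roughly like this. The `compress`/`decompress` subcommands and the placeholder paths are assumptions based on that CLI, not guarantees:

```sh
# Assumed ms2020-style subcommands; --model_path points to the SavedModel above.
python SwinT-ChARM/zyc2022.py --model_path <...> compress kodim15.png kodim15.tfci
python SwinT-ChARM/zyc2022.py --model_path <...> decompress kodim15.tfci kodim15_hat.png
```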
93 | ## File Structure
94 |
95 | res
 96 | ├── doc/ # additional resources
97 | ├── eval/ # sample images + reconstructions
98 | ├── train_zyc2022/ # model checkpoints + tf.summaries
99 | ├── zyc2022/ # saved model
100 | swin-transformers-tf/ # extended swin-transformers-tf implementation
101 | ├── changelog.txt # summary of changes made to the original work
102 | ├── ...
103 | config.py # model-dependent configurations
104 | zyc2022.py # core of this repo
105 |
106 | ## License
107 | [Apache License 2.0](LICENSE)
108 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Nikolai Körber. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """https://openreview.net/pdf?id=IDwN6xjHnK8, Appendix A2
17 |
18 | "SwinT-Hyperprior, SwinT-ChARM For both SwinT-Hyperprior and SwinT-ChARM, we use the
19 | same configurations: (wg, wh) = (8, 4), (C1, C2, C3, C4, C5, C6) = (128, 192, 256, 320, 192, 192),
20 | (d1, d2, d3, d4, d5, d6) = (2, 2, 6, 2, 5, 1) where C, d, and w are defined in Figure 13 and Figure 2.
21 | The head dim is 32 for all attention layers in SwinT-based models."
22 |
23 | """
24 |
25 | class ConfigGa:
26 | embed_dim = [128, 192, 256, 320]
27 | embed_out_dim = [192, 256, 320, None]
28 | depths = [2, 2, 6, 2]
29 | head_dim = [32, 32, 32, 32]
30 | window_size = [8, 8, 8, 8]
31 | num_layers = len(depths)
32 |
33 | class ConfigHa:
34 | embed_dim = [192, 192]
35 | embed_out_dim = [192, None]
36 | depths = [5, 1]
37 | head_dim = [32, 32]
38 | window_size = [4, 4]
39 | num_layers = len(depths)
40 |
41 | class ConfigHs:
42 | embed_dim = [192, 192]
43 | embed_out_dim = [192, int(2*320)]
44 | depths = [1, 5]
45 | head_dim = [32, 32]
46 | window_size = [4, 4]
47 | num_layers = len(depths)
48 |
49 | class ConfigGs:
50 | embed_dim = [320, 256, 192, 128]
51 | embed_out_dim = [256, 192, 128, 3]
52 | depths = [2, 6, 2, 2]
53 | head_dim = [32, 32, 32, 32]
54 | window_size = [8, 8, 8, 8]
55 | num_layers = len(depths)
56 |
57 | """https://arxiv.org/pdf/2007.08739.pdf, Appendix A
58 |
59 | "To account for the different input depths, each CC [...] transform is
60 | programmatically defined to linearly interpolate between the input and the
 61 | output depth."
62 |
63 | Note: In SwinT-ChARM a slightly different logic is used, which is why the depth
64 | is explicitly hardcoded here.
65 |
 66 | For example, under linear interpolation the tenth slice would have depths 224, 128 and 32:
67 |
68 | import pandas as pd
69 | import numpy as np
70 | a=pd.Series([320, np.nan, np.nan, 32])
71 |
72 | a.interpolate(method='linear')
73 |
74 | vs. 234, 117, 32 (taken from the official deepspeed logfile,
75 | which was provided by the authors).
76 |
77 | """
78 |
79 | class ConfigChARM:
80 | depths_conv0 = [64, 64, 85, 106, 128, 149, 170, 192, 213, 234]
81 | depths_conv1 = [32, 32, 42, 53, 64, 74, 85, 96, 106, 117]
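 84 | 
 85 | if __name__ == "__main__":
 86 |     # Illustration only (not part of the model): reproduce the linearly
 87 |     # interpolated depths from the docstring example with numpy.
 88 |     import numpy as np
 89 | 
 90 |     # Tenth slice under plain linear interpolation (Minnen et al.): 320 -> 32.
 91 |     print(np.linspace(320, 32, num=4).astype(int))  # [320 224 128  32]
 92 |     # SwinT-ChARM instead hardcodes 234 and 117 (see ConfigChARM above).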
--------------------------------------------------------------------------------
/res/doc/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/doc/figures/teaser.png
--------------------------------------------------------------------------------
/res/doc/notes.txt:
--------------------------------------------------------------------------------
1 | ## Technical Differences
2 |
3 | IMPORTANT: only affects branch release_08/17/2022 (the differences have been resolved!)
4 |
5 | In comparison to Zhu et al.:
6 |
 7 | - we keep LRP (lrp_transforms) -> no motivation is given for its removal.
 8 | - for CC and LRP we keep the approach described in Minnen et al., Appendix A; i.e.
 9 | the input to slice_1 is 320, input to slice_2 is 352 etc. rather than 32, 64,
 10 | ... as described in Zhu et al. Appendix A, Figure 12 (see the worked example below).
11 |
12 | Further technical deviations are possible, as we base our work on the official
13 | TF implementation of Minnen et al., while Zhu et al. base their work on an
14 | unknown (possibly not publicly available) PyTorch reimplementation.
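 15 | 
 16 | Worked example (our notation, for clarity): with 10 slices of 32 channels
 17 | each, the Minnen-style CC/LRP input depth for slice i (1-based) is
 18 | 320 + 32*(i-1), i.e. 320, 352, ..., 608, whereas Zhu et al. (Appendix A,
 19 | Figure 12) describe 32, 64, ..., 320.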
--------------------------------------------------------------------------------
/res/eval/info.txt:
--------------------------------------------------------------------------------
1 | Kodak
2 | The Kodak data set is a collection of 24 images with resolution 768x512 (or 512x768).
3 | The images are available as PNG files here: http://r0k.us/graphics/kodak
4 |
5 | @misc{kodak,
6 | title = "Kodak Lossless True Color Image Suite ({PhotoCD PCD0992})",
7 | author = "Eastman Kodak",
8 | url = "http://r0k.us/graphics/kodak",
9 | }
--------------------------------------------------------------------------------
/res/eval/kodak_results.txt:
--------------------------------------------------------------------------------
1 | The following values were generated by the official evaluation scripts provided by TFC (Minnen):
2 | https://github.com/tensorflow/compression/tree/master/results/image_compression
3 |
4 | SwinT-ChARM (beta=0.0003)
5 |
6 | ------------------------------
7 | kodim1
8 | ------------------------------
9 | Mean squared error: 14.7817
10 | PSNR (dB): 36.43
11 | Multiscale SSIM: 0.9941
12 | Multiscale SSIM (dB): 22.31
13 | Bits per pixel: 1.5605
14 |
15 | ------------------------------
16 | kodim2
17 | ------------------------------
18 | Mean squared error: 11.9409
19 | PSNR (dB): 37.36
20 | Multiscale SSIM: 0.9828
21 | Multiscale SSIM (dB): 17.63
22 | Bits per pixel: 0.6922
23 |
24 | ------------------------------
25 | kodim3
26 | ------------------------------
27 | Mean squared error: 6.6958
28 | PSNR (dB): 39.87
29 | Multiscale SSIM: 0.9911
30 | Multiscale SSIM (dB): 20.49
31 | Bits per pixel: 0.4716
32 |
33 | ------------------------------
34 | kodim4
35 | ------------------------------
36 | Mean squared error: 10.5122
37 | PSNR (dB): 37.91
38 | Multiscale SSIM: 0.9878
39 | Multiscale SSIM (dB): 19.15
40 | Bits per pixel: 0.7019
41 |
42 | ------------------------------
43 | kodim5
44 | ------------------------------
45 | Mean squared error: 14.5078
46 | PSNR (dB): 36.51
47 | Multiscale SSIM: 0.9950
48 | Multiscale SSIM (dB): 22.98
49 | Bits per pixel: 1.4607
50 |
51 | ------------------------------
52 | kodim6
53 | ------------------------------
54 | Mean squared error: 11.9955
55 | PSNR (dB): 37.34
56 | Multiscale SSIM: 0.9911
57 | Multiscale SSIM (dB): 20.48
58 | Bits per pixel: 1.1177
59 |
60 | ------------------------------
61 | kodim7
62 | ------------------------------
63 | Mean squared error: 7.3828
64 | PSNR (dB): 39.45
65 | Multiscale SSIM: 0.9943
66 | Multiscale SSIM (dB): 22.42
67 | Bits per pixel: 0.5966
68 |
69 | ------------------------------
70 | kodim8
71 | ------------------------------
72 | Mean squared error: 17.2197
73 | PSNR (dB): 35.77
74 | Multiscale SSIM: 0.9944
75 | Multiscale SSIM (dB): 22.54
76 | Bits per pixel: 1.5817
77 |
78 | ------------------------------
79 | kodim9
80 | ------------------------------
81 | Mean squared error: 8.4156
82 | PSNR (dB): 38.88
83 | Multiscale SSIM: 0.9896
84 | Multiscale SSIM (dB): 19.83
85 | Bits per pixel: 0.4888
86 |
87 | ------------------------------
88 | kodim10
89 | ------------------------------
90 | Mean squared error: 8.8622
91 | PSNR (dB): 38.66
92 | Multiscale SSIM: 0.9895
93 | Multiscale SSIM (dB): 19.78
94 | Bits per pixel: 0.5364
95 |
96 | ------------------------------
97 | kodim11
98 | ------------------------------
99 | Mean squared error: 12.1560
100 | PSNR (dB): 37.28
101 | Multiscale SSIM: 0.9900
102 | Multiscale SSIM (dB): 20.00
103 | Bits per pixel: 0.9873
104 |
105 | ------------------------------
106 | kodim12
107 | ------------------------------
108 | Mean squared error: 8.8393
109 | PSNR (dB): 38.67
110 | Multiscale SSIM: 0.9868
111 | Multiscale SSIM (dB): 18.79
112 | Bits per pixel: 0.5531
113 |
114 | ------------------------------
115 | kodim13
116 | ------------------------------
117 | Mean squared error: 21.1706
118 | PSNR (dB): 34.87
119 | Multiscale SSIM: 0.9932
120 | Multiscale SSIM (dB): 21.68
121 | Bits per pixel: 2.0560
122 |
123 | ------------------------------
124 | kodim14
125 | ------------------------------
126 | Mean squared error: 14.9357
127 | PSNR (dB): 36.39
128 | Multiscale SSIM: 0.9919
129 | Multiscale SSIM (dB): 20.90
130 | Bits per pixel: 1.2558
131 |
132 | ------------------------------
133 | kodim15
134 | ------------------------------
135 | Mean squared error: 10.1494
136 | PSNR (dB): 38.07
137 | Multiscale SSIM: 0.9888
138 | Multiscale SSIM (dB): 19.49
139 | Bits per pixel: 0.6525
140 |
141 | ------------------------------
142 | kodim16
143 | ------------------------------
144 | Mean squared error: 9.3178
145 | PSNR (dB): 38.44
146 | Multiscale SSIM: 0.9902
147 | Multiscale SSIM (dB): 20.09
148 | Bits per pixel: 0.7519
149 |
150 | ------------------------------
151 | kodim17
152 | ------------------------------
153 | Mean squared error: 9.6518
154 | PSNR (dB): 38.28
155 | Multiscale SSIM: 0.9910
156 | Multiscale SSIM (dB): 20.44
157 | Bits per pixel: 0.6467
158 |
159 | ------------------------------
160 | kodim18
161 | ------------------------------
162 | Mean squared error: 17.0509
163 | PSNR (dB): 35.81
164 | Multiscale SSIM: 0.9891
165 | Multiscale SSIM (dB): 19.61
166 | Bits per pixel: 1.2358
167 |
168 | ------------------------------
169 | kodim19
170 | ------------------------------
171 | Mean squared error: 11.6530
172 | PSNR (dB): 37.47
173 | Multiscale SSIM: 0.9891
174 | Multiscale SSIM (dB): 19.63
175 | Bits per pixel: 0.8815
176 |
177 | ------------------------------
178 | kodim20
179 | ------------------------------
180 | Mean squared error: 8.7219
181 | PSNR (dB): 38.72
182 | Multiscale SSIM: 0.9906
183 | Multiscale SSIM (dB): 20.28
184 | Bits per pixel: 0.5912
185 |
186 | ------------------------------
187 | kodim21
188 | ------------------------------
189 | Mean squared error: 11.8035
190 | PSNR (dB): 37.41
191 | Multiscale SSIM: 0.9899
192 | Multiscale SSIM (dB): 19.94
193 | Bits per pixel: 0.9675
194 |
195 | ------------------------------
196 | kodim22
197 | ------------------------------
198 | Mean squared error: 13.7772
199 | PSNR (dB): 36.74
200 | Multiscale SSIM: 0.9871
201 | Multiscale SSIM (dB): 18.88
202 | Bits per pixel: 0.9890
203 |
204 | ------------------------------
205 | kodim23
206 | ------------------------------
207 | Mean squared error: 7.1963
208 | PSNR (dB): 39.56
209 | Multiscale SSIM: 0.9903
210 | Multiscale SSIM (dB): 20.13
211 | Bits per pixel: 0.3953
212 |
213 | ------------------------------
214 | kodim24
215 | ------------------------------
216 | Mean squared error: 15.0891
217 | PSNR (dB): 36.34
218 | Multiscale SSIM: 0.9928
219 | Multiscale SSIM (dB): 21.42
220 | Bits per pixel: 1.1902
221 |
222 |
223 | STATS (mean across kodak)
224 |
225 | bpps = [1.5605, 0.6922, 0.4716, 0.7019, 1.4607, 1.1177, 0.5966, 1.5817, 0.4888, 0.5364, 0.9873, 0.5531, 2.0560, 1.2558, 0.6525, 0.7519, 0.6467, 1.2358, 0.8815, 0.5912, 0.9675, 0.9890, 0.3953, 1.1902]
226 | np.mean(bpps) -> 0.9317458333333334
227 |
228 | psnrs = [36.43, 37.36, 39.87, 37.91, 36.51, 37.34, 39.45, 35.77, 38.88, 38.66, 37.28, 38.67, 34.87, 36.39, 38.07, 38.44, 38.28, 35.81, 37.47, 38.72, 37.41, 36.74, 39.56, 36.34]
229 | np.mean(psnrs) -> 37.59291666666667
--------------------------------------------------------------------------------
/res/eval/kodim01_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim01_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim02_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim02_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim03_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim03_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim04_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim04_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim05_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim05_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim06_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim06_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim07_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim07_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim08_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim08_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim09_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim09_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim10_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim10_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim11_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim11_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim12_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim12_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim13_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim13_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim14_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim14_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim15.png
--------------------------------------------------------------------------------
/res/eval/kodim15_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim15_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim16_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim16_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim17_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim17_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim18_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim18_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim19_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim19_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim20_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim20_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim21_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim21_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim22.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim22.png
--------------------------------------------------------------------------------
/res/eval/kodim22_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim22_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim23.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim23.png
--------------------------------------------------------------------------------
/res/eval/kodim23_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim23_hat.png
--------------------------------------------------------------------------------
/res/eval/kodim24_hat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim24_hat.png
--------------------------------------------------------------------------------
/swin-transformers-tf/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/swin-transformers-tf/README.md:
--------------------------------------------------------------------------------
1 | # Swin for win!
2 |
 3 | [TensorFlow 2.8](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0)
 4 | [Models on TF-Hub](https://tfhub.dev/sayakpaul/collections/swin/1)
5 |
6 | This repository provides TensorFlow / Keras implementations of different Swin Transformer
7 | [1, 2] variants by Liu et al. and Chen et al. It also provides the TensorFlow / Keras models
8 | that have been populated with the original Swin pre-trained params available from [3, 4]. These
 9 | models are not blackbox SavedModels, i.e., they can be fully expanded into `tf.keras.Model`
10 | objects and one can call all the utility functions on them (example: `.summary()`).
11 |
12 | Refer to the ["Using the models"](https://github.com/sayakpaul/swin-transformers-tf#using-the-models)
13 | section to get started.
14 |
15 | I find Swin Transformers interesting because they induce a sense of hierarchies by using ***s***hifted ***win***dows. Multi-scale
16 | representations like that are crucial to get good performance in tasks like object detection and segmentation.
 17 | *(figure: hierarchical feature maps built with shifted windows)*
 18 | Source
19 |
20 | "Swin for win!" however doesn't portray my architecture bias -- I found it cool and hence kept it.
21 |
22 | ## Table of contents
23 |
24 | * [Conversion](https://github.com/sayakpaul/swin-transformers-tf#conversion)
25 | * [Collection of pre-trained models (converted from PyTorch to TensorFlow)](https://github.com/sayakpaul/swin-transformers-tf#models)
26 | * [Results of the converted models](https://github.com/sayakpaul/swin-transformers-tf#results)
27 | * [How to use the models?](https://github.com/sayakpaul/swin-transformers-tf#using-the-models)
28 | * [References](https://github.com/sayakpaul/swin-transformers-tf#references)
29 | * [Acknowledgements](https://github.com/sayakpaul/swin-transformers-tf#acknowledgements)
30 |
31 | ## Conversion
32 |
33 | TensorFlow / Keras implementations are available in `swins/models.py`. All model configurations
34 | are in `swins/model_configs.py`. Conversion utilities are in `convert.py`. To run the conversion
35 | utilities, first install all the dependencies listed in `requirements.txt`. Additionally,
 36 | install `timm` from source:
37 |
38 | ```sh
39 | pip install -q git+https://github.com/rwightman/pytorch-image-models
40 | ```
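
With the dependencies in place, a single model can be converted by invoking `convert.py` with the flags it defines (`-m`/`--model-name`, `-d`/`--dataset`, `-pl`/`--pre-logits`); for example:

```sh
# Convert the ImageNet-1k Swin-Tiny checkpoint to TensorFlow / Keras.
python convert.py -m swin_tiny_patch4_window7_224 -d in1k

# Same model, but export only the pre-logits outputs (no classification head).
python convert.py -m swin_tiny_patch4_window7_224 -d in1k -pl
```

Note that `convert.py` defines `TF_MODEL_ROOT = "gs://swin-tf"`, presumably the save destination, so local runs may need that path adjusted.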
41 |
42 | ## Models
43 |
44 | Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/swin/1. You can fully inspect the
45 | architecture of the TF-Hub models like so:
46 |
47 | ```py
48 | import tensorflow as tf
49 |
50 | model_gcs_path = "gs://tfhub-modules/sayakpaul/swin_tiny_patch4_window7_224/1/uncompressed"
51 | model = tf.keras.models.load_model(model_gcs_path)
52 |
53 | dummy_inputs = tf.ones((2, 224, 224, 3))
54 | _ = model(dummy_inputs)
55 | print(model.summary(expand_nested=True))
56 | ```
57 |
58 | ## Results
59 |
60 | The table below provides a performance summary (ImageNet-1k validation set):
61 |
62 | | model_name | top1_acc(%) | top5_acc(%) | orig_top1_acc(%) |
63 | |:------------------------------:|:-------------:|:-------------:|:------------------:|
64 | | swin_base_patch4_window7_224 | 85.134 | 97.48 | 85.2 |
65 | | swin_large_patch4_window7_224 | 86.252 | 97.878 | 86.3 |
66 | | swin_s3_base_224 | 83.958 | 96.532 | 84 |
67 | | swin_s3_small_224 | 83.648 | 96.358 | 83.7 |
68 | | swin_s3_tiny_224 | 82.034 | 95.864 | 82.1 |
69 | | swin_small_patch4_window7_224 | 83.178 | 96.24 | 83.2 |
70 | | swin_tiny_patch4_window7_224 | 81.184 | 95.512 | 81.2 |
71 | | swin_base_patch4_window12_384 | 86.428 | 98.042 | 86.4 |
72 | | swin_large_patch4_window12_384 | 87.272 | 98.242 | 87.3 |
73 |
74 |
75 | The `base` and `large` models were first pre-trained on the ImageNet-22k dataset and then fine-tuned
76 | on the ImageNet-1k dataset.
77 |
 78 | The [`in1k-eval` directory](https://github.com/sayakpaul/swin-transformers-tf/tree/main/in1k-eval) provides details
79 | on how these numbers were generated. Original scores for all the models except for the `s3` ones were
80 | gathered from [here](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md). Scores
81 | for the `s3` model were gathered from [here](https://github.com/microsoft/Cream/tree/main/AutoFormerV2#model-zoo).
82 |
83 | ## Using the models
84 |
85 | **Pre-trained models**:
86 |
87 | * Off-the-shelf classification: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/swin-transformers-tf/blob/main/notebooks/classification.ipynb)
88 | * Fine-tuning: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/swin-transformers-tf/blob/main/notebooks/finetune.ipynb)
89 |
 90 | When doing transfer learning, try using the models that were pre-trained on the ImageNet-22k dataset. All the
91 | `base` and `large` models listed here were pre-trained on the ImageNet-22k dataset. Refer to the
92 | [model collection page on TF-Hub](https://tfhub.dev/sayakpaul/collections/swin/1) to know more.
93 |
94 | These models also output attention weights from each of the Transformer blocks.
95 | Refer to [this notebook](https://colab.research.google.com/github/sayakpaul/swin-transformers-tf/blob/main/notebooks/classification.ipynb)
96 | for more details. Additionally, the notebook shows how to obtain the attention maps for a given image.
97 |
98 |
99 | **Randomly initialized models**:
100 |
101 | ```py
102 | import tensorflow as tf
103 |
104 | from swins import SwinTransformer
105 |
106 | cfg = dict(
107 | patch_size=4,
108 | window_size=7,
109 | embed_dim=128,
110 | depths=(2, 2, 18, 2),
111 | num_heads=(4, 8, 16, 32),
112 | )
113 |
114 | swin_base_patch4_window7_224 = SwinTransformer(
115 | name="swin_base_patch4_window7_224", **cfg
116 | )
117 | print("Model instantiated, attempting predictions...")
118 | random_tensor = tf.random.normal((2, 224, 224, 3))
119 | outputs = swin_base_patch4_window7_224(random_tensor, training=False)
120 |
121 | print(outputs.shape)
122 |
123 | print(swin_base_patch4_window7_224.count_params() / 1e6)
124 | ```
125 |
 126 | To initialize a network with, say, 5 classes, do:
127 |
128 | ```py
129 | cfg = dict(
130 | patch_size=4,
131 | window_size=7,
132 | embed_dim=128,
133 | depths=(2, 2, 18, 2),
134 | num_heads=(4, 8, 16, 32),
135 | num_classes=5,
136 | )
137 |
138 | swin_base_patch4_window7_224 = SwinTransformer(
139 | name="swin_base_patch4_window7_224", **cfg
140 | )
141 | ```
142 |
143 | To view different model configurations, refer to `swins/model_configs.py`.
144 |
145 | ## References
146 |
147 | [1] [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows Liu et al.](https://arxiv.org/abs/2103.14030)
148 |
149 | [2] [Searching the Search Space of Vision Transformer by Chen et al.](https://arxiv.org/abs/2111.14725)
150 |
151 | [3] [Swin Transformers GitHub](https://github.com/microsoft/Swin-Transformer)
152 |
153 | [4] [AutoFormerV2 GitHub](https://github.com/silent-chen/AutoFormerV2-model-zoo)
154 |
155 | ## Acknowledgements
156 |
157 | * [`timm` library source code](https://github.com/rwightman/pytorch-image-models)
158 | for the awesome codebase. I've copy-pasted and modified a huge chunk of code from there.
 159 | I've also mentioned it in the respective scripts.
160 | * [Willi Gierke](https://ch.linkedin.com/in/willi-gierke) for helping with a non-trivial model serialization hack.
161 | * [ML-GDE program](https://developers.google.com/programs/experts/) for
162 | providing GCP credits that supported my experiments.
163 |
--------------------------------------------------------------------------------
/swin-transformers-tf/changelog.txt:
--------------------------------------------------------------------------------
1 | Changelog relative to https://github.com/sayakpaul/swin-transformers-tf
2 |
3 | - swin-transformers-tf/swins/layers: patch_splitting.py added
4 | - swin-transformers-tf/swins/layers: patch_merging.py modified
5 | - swin-transformers-tf/swins/blocks: stage_block.py modified
6 | - swin-transformers-tf/swins/blocks: swin_transformer_block.py modified
7 | - swin-transformers-tf/swins/layers: __init__.py modified
8 | - swin-transformers-tf/swins/layers: window_attn.py modified
9 |
10 | General comments:
11 | - enabled support for arbitrary input shapes (multiples of 256).
12 | - replaced numpy-based attention mask calculation with a native TensorFlow equivalent.
--------------------------------------------------------------------------------
/swin-transformers-tf/convert.py:
--------------------------------------------------------------------------------
1 | """
2 | Some code is copied from here:
3 | https://github.com/sayakpaul/cait-tf/blob/main/convert.py
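
Example usage (flags as defined in parse_args below; convert_all_models.py builds these commands):
    python convert.py -m swin_base_patch4_window7_224 -d in1k
    python convert.py -m swin_base_patch4_window7_224 -d in1k -pl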
4 | """
5 |
6 | import argparse
7 | import os
8 | import sys
9 | from typing import Dict, List
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | import timm
14 |
15 | sys.path.append("..")
16 |
17 |
18 | from swins import SwinTransformer, model_configs
19 | from swins.blocks import *
20 | from swins.layers import *
21 | from utils import helpers
22 |
23 | TF_MODEL_ROOT = "gs://swin-tf"
24 |
25 | NUM_CLASSES = {"in1k": 1000, "in21k": 21841}
26 |
27 |
28 | def parse_args():
29 | parser = argparse.ArgumentParser(
30 | description="Conversion of the PyTorch pre-trained Swin weights to TensorFlow."
31 | )
32 | parser.add_argument(
33 | "-m",
34 | "--model-name",
35 | default="swin_tiny_patch4_window7_224",
36 | type=str,
37 | choices=model_configs.MODEL_MAP.keys(),
38 | help="Name of the Swin model variant.",
39 | )
40 | parser.add_argument(
41 | "-d",
42 | "--dataset",
43 | default="in1k",
44 | choices=["in1k", "in21k"],
45 | type=str,
46 | )
47 | parser.add_argument(
48 | "-pl",
49 | "--pre-logits",
50 | action="store_true",
51 | help="If we don't need the classification outputs.",
52 | )
53 | return parser.parse_args()
54 |
55 |
56 | def modify_swin_blocks(
57 | np_state_dict: Dict[str, np.ndarray],
58 | pt_weights_prefix: str,
59 | tf_block: List[tf.keras.layers.Layer],
60 | ) -> List[tf.keras.layers.Layer]:
61 | """Main utility to convert params of a swin block."""
62 | # Patch merging.
63 | for layer in tf_block:
64 | if isinstance(layer, PatchMerging):
65 | patch_merging_idx = f"{pt_weights_prefix}.downsample"
66 |
67 | layer.reduction = helpers.modify_tf_block(
68 | layer.reduction,
69 | np_state_dict[f"{patch_merging_idx}.reduction.weight"],
70 | )
71 | layer.norm = helpers.modify_tf_block(
72 | layer.norm,
73 | np_state_dict[f"{patch_merging_idx}.norm.weight"],
74 | np_state_dict[f"{patch_merging_idx}.norm.bias"],
75 | )
76 |
77 | # Swin layers.
78 | common_prefix = f"{pt_weights_prefix}.blocks"
79 | block_idx = 0
80 |
81 | for outer_layer in tf_block:
82 |
83 | layernorm_idx = 1
84 | mlp_layer_idx = 1
85 |
86 | if isinstance(outer_layer, SwinTransformerBlock):
87 | for inner_layer in outer_layer.layers:
88 |
89 | # Layer norm.
90 | if isinstance(inner_layer, tf.keras.layers.LayerNormalization):
91 | layer_norm_prefix = (
92 | f"{common_prefix}.{block_idx}.norm{layernorm_idx}"
93 | )
94 | inner_layer.gamma.assign(
95 | tf.Variable(
96 | np_state_dict[f"{layer_norm_prefix}.weight"]
97 | )
98 | )
99 | inner_layer.beta.assign(
100 | tf.Variable(np_state_dict[f"{layer_norm_prefix}.bias"])
101 | )
102 | layernorm_idx += 1
103 |
104 |                     # Window attention.
105 | elif isinstance(inner_layer, WindowAttention):
106 | attn_prefix = f"{common_prefix}.{block_idx}.attn"
107 |
108 | # Relative position.
109 | inner_layer.relative_position_bias_table = (
110 | helpers.modify_tf_block(
111 | inner_layer.relative_position_bias_table,
112 | np_state_dict[
113 | f"{attn_prefix}.relative_position_bias_table"
114 | ],
115 | )
116 | )
117 | inner_layer.relative_position_index = (
118 | helpers.modify_tf_block(
119 | inner_layer.relative_position_index,
120 | np_state_dict[
121 | f"{attn_prefix}.relative_position_index"
122 | ],
123 | )
124 | )
125 |
126 | # QKV.
127 | inner_layer.qkv = helpers.modify_tf_block(
128 | inner_layer.qkv,
129 | np_state_dict[f"{attn_prefix}.qkv.weight"],
130 | np_state_dict[f"{attn_prefix}.qkv.bias"],
131 | )
132 |
133 | # Projection.
134 | inner_layer.proj = helpers.modify_tf_block(
135 | inner_layer.proj,
136 | np_state_dict[f"{attn_prefix}.proj.weight"],
137 | np_state_dict[f"{attn_prefix}.proj.bias"],
138 | )
139 |
140 | # MLP.
141 | elif isinstance(inner_layer, tf.keras.Model):
142 | mlp_prefix = f"{common_prefix}.{block_idx}.mlp"
143 | for mlp_layer in inner_layer.layers:
144 | if isinstance(mlp_layer, tf.keras.layers.Dense):
145 | mlp_layer = helpers.modify_tf_block(
146 | mlp_layer,
147 | np_state_dict[
148 | f"{mlp_prefix}.fc{mlp_layer_idx}.weight"
149 | ],
150 | np_state_dict[
151 | f"{mlp_prefix}.fc{mlp_layer_idx}.bias"
152 | ],
153 | )
154 | mlp_layer_idx += 1
155 |
156 | block_idx += 1
157 | return tf_block
158 |
159 |
160 | def main(args):
161 | if args.pre_logits:
162 | print(f"Converting {args.model_name} for feature extraction...")
163 | else:
164 | print(f"Converting {args.model_name}...")
165 |
166 | print("Instantiating PyTorch model...")
167 | pt_model = timm.create_model(model_name=args.model_name, pretrained=True)
168 | pt_model.eval()
169 |
170 | print("Instantiating TF model...")
171 | cfg_method = model_configs.MODEL_MAP[args.model_name]
172 | cfg = cfg_method()
173 | tf_model = SwinTransformer(**cfg, pre_logits=args.pre_logits)
174 |
175 | image_size = cfg.get("img_size", 224)
176 | dummy_inputs = tf.ones((2, image_size, image_size, 3))
177 | _ = tf_model(dummy_inputs)
178 |
179 | if not args.pre_logits:
180 | assert tf_model.count_params() == sum(
181 | p.numel() for p in pt_model.parameters()
182 | )
183 |
184 | # Load the PT params.
185 | np_state_dict = pt_model.state_dict()
186 | np_state_dict = {k: np_state_dict[k].numpy() for k in np_state_dict}
187 |
188 | print("Beginning parameter porting process...")
189 |
190 | # Projection.
191 | tf_model.projection.layers[0] = helpers.modify_tf_block(
192 | tf_model.projection.layers[0],
193 | np_state_dict["patch_embed.proj.weight"],
194 | np_state_dict["patch_embed.proj.bias"],
195 | )
196 | tf_model.projection.layers[2] = helpers.modify_tf_block(
197 | tf_model.projection.layers[2],
198 | np_state_dict["patch_embed.norm.weight"],
199 | np_state_dict["patch_embed.norm.bias"],
200 | )
201 |
202 | # Layer norm layers.
203 | ln_idx = -2
204 | tf_model.layers[ln_idx] = helpers.modify_tf_block(
205 | tf_model.layers[ln_idx],
206 | np_state_dict["norm.weight"],
207 | np_state_dict["norm.bias"],
208 | )
209 |
210 | # Head layers.
211 | if not args.pre_logits:
212 | head_layer = tf_model.get_layer("classification_head")
213 | tf_model.layers[-1] = helpers.modify_tf_block(
214 | head_layer,
215 | np_state_dict["head.weight"],
216 | np_state_dict["head.bias"],
217 | )
218 |
219 | # Swin layers.
220 | for i in range(len(cfg["depths"])):
221 | _ = modify_swin_blocks(
222 | np_state_dict,
223 | f"layers.{i}",
224 | tf_model.layers[i + 2].layers,
225 | )
226 |
227 | print("Porting successful, serializing TensorFlow model...")
228 | save_path = os.path.join(TF_MODEL_ROOT, args.model_name)
229 | save_path = f"{save_path}_fe" if args.pre_logits else save_path
230 | tf_model.save(save_path)
231 | print(f"TensorFlow model serialized to: {save_path}...")
232 |
233 |
234 | if __name__ == "__main__":
235 | args = parse_args()
236 | main(args)
237 |
--------------------------------------------------------------------------------
/swin-transformers-tf/convert_all_models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append("..")
5 |
6 | from swins import model_configs
7 |
8 |
9 | def main():
10 | for model_name in model_configs.MODEL_MAP:
11 | if "in22k" in model_name:
12 | dataset = "in21k"
13 | else:
14 | dataset = "in1k"
15 |
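        # Convert each model twice: once with the classification head and once
        # with -pl, which exports the pre-logits (feature-extractor) variant.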
16 | for i in range(2):
17 | command = f"python convert.py -m {model_name} -d {dataset}"
18 | if i == 1:
19 | command += " -pl"
20 | os.system(command)
21 |
22 |
23 | if __name__ == "__main__":
24 | main()
25 |
--------------------------------------------------------------------------------
/swin-transformers-tf/hub_utilities/README.md:
--------------------------------------------------------------------------------
1 | The scripts in this directory are closely related to TF-Hub. The following utilities are provided:
2 |
3 | * `export_for_hub.py`: Exports SavedModels in bulk as the `tar.gz` archives required by TF-Hub.
4 | * `generate_doc.py`: Generates documentation for models in bulk.
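
Both scripts are meant to be run directly. Constants such as the GCS bucket (`TF_MODEL_ROOT`) and the
documentation output path are hard-coded at the top of each script, so adjust them before running:

```sh
python export_for_hub.py
python generate_doc.py
```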
5 |
--------------------------------------------------------------------------------
/swin-transformers-tf/hub_utilities/export_for_hub.py:
--------------------------------------------------------------------------------
1 | """Generates .tar.gz archives from SavedModels and serializes them."""
2 |
3 |
4 | import os
5 | from typing import List
6 |
7 | import tensorflow as tf
8 |
9 | TF_MODEL_ROOT = "gs://swin-tf"
10 | TAR_ARCHIVES = os.path.join(TF_MODEL_ROOT, "tars/")
11 |
12 |
13 | def prepare_archive(model_name: str) -> None:
14 | """Prepares a tar archive."""
15 | archive_name = f"{model_name}.tar.gz"
16 | print(f"Archiving to {archive_name}.")
17 | archive_command = f"cd {model_name} && tar -czvf ../{archive_name} *"
18 | os.system(archive_command)
19 | os.system(f"rm -rf {model_name}")
20 |
21 |
22 | def save_to_gcs(model_paths: List[str]) -> None:
23 | """Prepares tar archives and saves them inside a GCS bucket."""
24 | for path in model_paths:
25 | print(f"Preparing model: {path}.")
26 | model_name = path.strip("/")
27 | abs_model_path = os.path.join(TF_MODEL_ROOT, model_name)
28 |
29 | print(f"Copying from {abs_model_path}.")
30 | os.system(f"gsutil cp -r {abs_model_path} .")
31 | prepare_archive(model_name)
32 |
33 | os.system(f"gsutil -m cp -r *.tar.gz {TAR_ARCHIVES}")
34 | os.system("rm -rf *.tar.gz")
35 |
36 |
37 | model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)
38 | print(f"Total models: {len(model_paths)}.")
39 |
40 | print("Preparing archives for the classification and feature extractor models.")
41 | save_to_gcs(model_paths)
42 | tar_paths = tf.io.gfile.listdir(TAR_ARCHIVES)
43 | print(f"Total tars: {len(tar_paths)}.")
44 |
--------------------------------------------------------------------------------
/swin-transformers-tf/hub_utilities/generate_doc.py:
--------------------------------------------------------------------------------
1 | """Generates model documentation for Swin-TF models.
2 |
3 | Credits: Willi Gierke
4 | """
5 |
6 | import os
7 | from string import Template
8 |
9 | import attr
10 |
11 | template = Template(
12 | """# Module $HANDLE
13 |
14 | Fine-tunable Swin Transformer model pre-trained on the $DATASET_DESCRIPTION.
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Overview
25 |
26 | This model is a Swin Transformer [1] pre-trained on the $DATASET_DESCRIPTION. You can find the complete
27 | collection of Swin models on TF-Hub on [this page](https://tfhub.dev/sayakpaul/collections/swin/1).
28 |
29 | You can use this model for feature extraction and fine-tuning. Please refer to
30 | the Colab Notebook linked on this page for more details.
31 |
32 | ## Notes
33 |
34 | * The original model weights are provided in [2]. They were ported to Keras models
35 |   (`tf.keras.Model`) and then serialized as TensorFlow SavedModels. The porting
36 |   steps are available in [3].
37 | * If the model handle contains `s3`, please refer to [4] for more details on the architecture. Its
38 |   original weights are available in [5].
39 | * The model can be unrolled into a standard Keras model and you can inspect its topology.
40 | To do so, first download the model from TF-Hub and then load it using `tf.keras.models.load_model`
41 | providing the path to the downloaded model folder.
42 |
43 | ## References
44 |
45 | [1] [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows by Liu et al.](https://arxiv.org/abs/2103.14030)
46 |
47 | [2] [Swin Transformers GitHub](https://github.com/microsoft/Swin-Transformer)
48 |
49 | [3] [Swin-TF GitHub](https://github.com/sayakpaul/swin-transformers-tf)
50 |
51 | [4] [Searching the Search Space of Vision Transformer by Chen et al.](https://arxiv.org/abs/2111.14725)
52 |
53 | [5] [AutoFormerV2 GitHub](https://github.com/silent-chen/AutoFormerV2-model-zoo)
54 |
55 | ## Acknowledgements
56 |
57 | * [Willi](https://ch.linkedin.com/in/willi-gierke)
58 | * [ML-GDE program](https://developers.google.com/programs/experts/)
59 |
60 | """
61 | )
62 |
63 |
64 | @attr.s
65 | class Config:
66 | size = attr.ib(type=str)
67 | patch_size = attr.ib(type=int)
68 | window_size = attr.ib(type=int)
69 | single_resolution = attr.ib(type=int)
70 | dataset = attr.ib(type=str)
71 | type = attr.ib(type=str, default="swin")
72 |
73 | def two_d_resolution(self):
74 | return f"{self.single_resolution}x{self.single_resolution}"
75 |
76 | def gcs_folder_name(self):
77 | if self.dataset == "in22k":
78 | return f"swin_{self.size}_patch{self.patch_size}_window{self.window_size}_{self.single_resolution}_{self.dataset}_fe"
79 | elif self.type == "autoformer":
80 | return f"swin_s3_{self.size}_{self.single_resolution}_fe"
81 | else:
82 | return f"swin_{self.size}_patch{self.patch_size}_window{self.window_size}_{self.single_resolution}_fe"
83 |
84 | def handle(self):
85 | return f"sayakpaul/{self.gcs_folder_name()}/1"
86 |
87 | def rel_doc_file_path(self):
88 | """Relative to the tfhub.dev directory."""
89 | return f"assets/docs/{self.handle()}.md"
90 |
91 |
92 | # swin_base_patch4_window12_384, swin_base_patch4_window12_384_in22k
93 | for c in [
94 | Config("tiny", 4, 7, 224, "in1k"),
95 | Config("small", 4, 7, 224, "in1k"),
96 | Config("base", 4, 7, 224, "in1k"),
97 | Config("base", 4, 12, 384, "in1k"),
98 | Config("large", 4, 7, 224, "in1k"),
99 | Config("large", 4, 12, 384, "in1k"),
100 | Config("base", 4, 7, 224, "in22k"),
101 | Config("base", 4, 12, 384, "in22k"),
102 | Config("large", 4, 7, 224, "in22k"),
103 | Config("large", 4, 12, 384, "in22k"),
104 | Config("tiny", 0, 0, 224, "in1k", "autoformer"),
105 | Config("small", 0, 0, 224, "in1k", "autoformer"),
106 | Config("base", 0, 0, 224, "in1k", "autoformer"),
107 | ]:
108 | if c.dataset == "in1k" and not ("large" in c.size or "base" in c.size):
109 | dataset_text = "ImageNet-1k dataset"
110 | elif c.dataset == "in22k":
111 | dataset_text = "ImageNet-22k dataset"
112 | elif c.dataset == "in1k" and ("large" in c.size or "base" in c.size):
113 | dataset_text = (
114 | "ImageNet-22k"
115 | " dataset and"
116 | " was then "
117 | "fine-tuned "
118 | "on the "
119 | "ImageNet-1k "
120 | "dataset"
121 | )
122 |
123 | save_path = os.path.join(
124 | "/Users/sayakpaul/Downloads/", "tfhub.dev", c.rel_doc_file_path()
125 | )
126 | model_folder = save_path.split("/")[-2]
127 | model_abs_path = "/".join(save_path.split("/")[:-1])
128 |
129 | if not os.path.exists(model_abs_path):
130 | os.makedirs(model_abs_path, exist_ok=True)
131 |
132 | with open(save_path, "w") as f:
133 | f.write(
134 | template.substitute(
135 | HANDLE=c.handle(),
136 | DATASET_DESCRIPTION=dataset_text,
137 | INPUT_RESOLUTION=c.two_d_resolution(),
138 | ARCHIVE_NAME=c.gcs_folder_name(),
139 | )
140 | )
141 |
--------------------------------------------------------------------------------
/swin-transformers-tf/in1k-eval/README.md:
--------------------------------------------------------------------------------
1 | This directory provides a notebook to run
2 | evaluation on the ImageNet-1k `val` split using the TF/Keras converted Swin
3 | models. The notebook assumes the following is present in your working
4 | directory and that the dependencies specified in `../requirements.txt` are installed:
5 |
6 | * The `val` split directory of ImageNet-1k.
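
Since the notebook loads the split with `torchvision.datasets.ImageFolder`, the `val` directory must
follow the standard class-per-subdirectory layout (the synset and file names below are illustrative):

```
val/
├── n01440764/
│   ├── ILSVRC2012_val_00000293.JPEG
│   └── ...
├── n01443537/
│   └── ...
└── ...
```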
--------------------------------------------------------------------------------
/swin-transformers-tf/in1k-eval/df.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import pandas as pd"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "data": {
19 | "text/html": [
20 | "\n",
21 | "\n",
34 | "
\n",
35 | " \n",
36 | " \n",
37 | " | \n",
38 | " model_name | \n",
39 | " top1_acc(%) | \n",
40 | " top5_acc(%) | \n",
41 | "
\n",
42 | " \n",
43 | " \n",
44 | " \n",
45 | " 0 | \n",
46 | " swin_base_patch4_window7_224 | \n",
47 | " 85.134 | \n",
48 | " 97.480 | \n",
49 | "
\n",
50 | " \n",
51 | " 1 | \n",
52 | " swin_large_patch4_window7_224 | \n",
53 | " 86.252 | \n",
54 | " 97.878 | \n",
55 | "
\n",
56 | " \n",
57 | " 2 | \n",
58 | " swin_s3_base_224 | \n",
59 | " 83.958 | \n",
60 | " 96.532 | \n",
61 | "
\n",
62 | " \n",
63 | " 3 | \n",
64 | " swin_s3_small_224 | \n",
65 | " 83.648 | \n",
66 | " 96.358 | \n",
67 | "
\n",
68 | " \n",
69 | " 4 | \n",
70 | " swin_s3_tiny_224 | \n",
71 | " 82.034 | \n",
72 | " 95.864 | \n",
73 | "
\n",
74 | " \n",
75 | "
\n",
76 | "
"
77 | ],
78 | "text/plain": [
79 | " model_name top1_acc(%) top5_acc(%)\n",
80 | "0 swin_base_patch4_window7_224 85.134 97.480\n",
81 | "1 swin_large_patch4_window7_224 86.252 97.878\n",
82 | "2 swin_s3_base_224 83.958 96.532\n",
83 | "3 swin_s3_small_224 83.648 96.358\n",
84 | "4 swin_s3_tiny_224 82.034 95.864"
85 | ]
86 | },
87 | "execution_count": 2,
88 | "metadata": {},
89 | "output_type": "execute_result"
90 | }
91 | ],
92 | "source": [
93 | "df_224 = pd.read_csv(\"swin_224_in1k.csv\")\n",
94 | "df_334 = pd.read_csv(\"swin_384_in1k.csv\")\n",
95 | "\n",
96 | "df = pd.concat([df_224, df_334])\n",
97 | "df[\"model_name\"] = df[\"model_name\"].apply(lambda x: x.strip(\"/\"))\n",
98 | "df.head()"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": 3,
104 | "metadata": {},
105 | "outputs": [
106 | {
107 | "data": {
108 | "text/html": [
109 | "\n",
110 | "\n",
123 | "
\n",
124 | " \n",
125 | " \n",
126 | " | \n",
127 | " model_name | \n",
128 | " top1_acc(%) | \n",
129 | " top5_acc(%) | \n",
130 | "
\n",
131 | " \n",
132 | " \n",
133 | " \n",
134 | " 0 | \n",
135 | " swin_base_patch4_window7_224 | \n",
136 | " 85.134 | \n",
137 | " 97.480 | \n",
138 | "
\n",
139 | " \n",
140 | " 1 | \n",
141 | " swin_large_patch4_window7_224 | \n",
142 | " 86.252 | \n",
143 | " 97.878 | \n",
144 | "
\n",
145 | " \n",
146 | " 2 | \n",
147 | " swin_s3_base_224 | \n",
148 | " 83.958 | \n",
149 | " 96.532 | \n",
150 | "
\n",
151 | " \n",
152 | " 3 | \n",
153 | " swin_s3_small_224 | \n",
154 | " 83.648 | \n",
155 | " 96.358 | \n",
156 | "
\n",
157 | " \n",
158 | " 4 | \n",
159 | " swin_s3_tiny_224 | \n",
160 | " 82.034 | \n",
161 | " 95.864 | \n",
162 | "
\n",
163 | " \n",
164 | " 5 | \n",
165 | " swin_small_patch4_window7_224 | \n",
166 | " 83.178 | \n",
167 | " 96.240 | \n",
168 | "
\n",
169 | " \n",
170 | " 6 | \n",
171 | " swin_tiny_patch4_window7_224 | \n",
172 | " 81.184 | \n",
173 | " 95.512 | \n",
174 | "
\n",
175 | " \n",
176 | " 0 | \n",
177 | " swin_base_patch4_window12_384 | \n",
178 | " 86.428 | \n",
179 | " 98.042 | \n",
180 | "
\n",
181 | " \n",
182 | " 1 | \n",
183 | " swin_large_patch4_window12_384 | \n",
184 | " 87.272 | \n",
185 | " 98.242 | \n",
186 | "
\n",
187 | " \n",
188 | "
\n",
189 | "
"
190 | ],
191 | "text/plain": [
192 | " model_name top1_acc(%) top5_acc(%)\n",
193 | "0 swin_base_patch4_window7_224 85.134 97.480\n",
194 | "1 swin_large_patch4_window7_224 86.252 97.878\n",
195 | "2 swin_s3_base_224 83.958 96.532\n",
196 | "3 swin_s3_small_224 83.648 96.358\n",
197 | "4 swin_s3_tiny_224 82.034 95.864\n",
198 | "5 swin_small_patch4_window7_224 83.178 96.240\n",
199 | "6 swin_tiny_patch4_window7_224 81.184 95.512\n",
200 | "0 swin_base_patch4_window12_384 86.428 98.042\n",
201 | "1 swin_large_patch4_window12_384 87.272 98.242"
202 | ]
203 | },
204 | "execution_count": 3,
205 | "metadata": {},
206 | "output_type": "execute_result"
207 | }
208 | ],
209 | "source": [
210 | "df"
211 | ]
212 | },
213 | {
214 | "cell_type": "code",
215 | "execution_count": 4,
216 | "metadata": {},
217 | "outputs": [
218 | {
219 | "data": {
220 | "text/html": [
221 | "\n",
222 | "\n",
235 | "
\n",
236 | " \n",
237 | " \n",
238 | " | \n",
239 | " model_name | \n",
240 | " top1_acc(%) | \n",
241 | " top5_acc(%) | \n",
242 | " orig_top1_acc(%) | \n",
243 | "
\n",
244 | " \n",
245 | " \n",
246 | " \n",
247 | " 0 | \n",
248 | " swin_base_patch4_window7_224 | \n",
249 | " 85.134 | \n",
250 | " 97.480 | \n",
251 | " 85.2 | \n",
252 | "
\n",
253 | " \n",
254 | " 1 | \n",
255 | " swin_large_patch4_window7_224 | \n",
256 | " 86.252 | \n",
257 | " 97.878 | \n",
258 | " 86.3 | \n",
259 | "
\n",
260 | " \n",
261 | " 2 | \n",
262 | " swin_s3_base_224 | \n",
263 | " 83.958 | \n",
264 | " 96.532 | \n",
265 | " 84.0 | \n",
266 | "
\n",
267 | " \n",
268 | " 3 | \n",
269 | " swin_s3_small_224 | \n",
270 | " 83.648 | \n",
271 | " 96.358 | \n",
272 | " 83.7 | \n",
273 | "
\n",
274 | " \n",
275 | " 4 | \n",
276 | " swin_s3_tiny_224 | \n",
277 | " 82.034 | \n",
278 | " 95.864 | \n",
279 | " 82.1 | \n",
280 | "
\n",
281 | " \n",
282 | " 5 | \n",
283 | " swin_small_patch4_window7_224 | \n",
284 | " 83.178 | \n",
285 | " 96.240 | \n",
286 | " 83.2 | \n",
287 | "
\n",
288 | " \n",
289 | " 6 | \n",
290 | " swin_tiny_patch4_window7_224 | \n",
291 | " 81.184 | \n",
292 | " 95.512 | \n",
293 | " 81.2 | \n",
294 | "
\n",
295 | " \n",
296 | " 7 | \n",
297 | " swin_base_patch4_window12_384 | \n",
298 | " 86.428 | \n",
299 | " 98.042 | \n",
300 | " 86.4 | \n",
301 | "
\n",
302 | " \n",
303 | " 8 | \n",
304 | " swin_large_patch4_window12_384 | \n",
305 | " 87.272 | \n",
306 | " 98.242 | \n",
307 | " 87.3 | \n",
308 | "
\n",
309 | " \n",
310 | "
\n",
311 | "
"
312 | ],
313 | "text/plain": [
314 | " model_name top1_acc(%) top5_acc(%) orig_top1_acc(%)\n",
315 | "0 swin_base_patch4_window7_224 85.134 97.480 85.2\n",
316 | "1 swin_large_patch4_window7_224 86.252 97.878 86.3\n",
317 | "2 swin_s3_base_224 83.958 96.532 84.0\n",
318 | "3 swin_s3_small_224 83.648 96.358 83.7\n",
319 | "4 swin_s3_tiny_224 82.034 95.864 82.1\n",
320 | "5 swin_small_patch4_window7_224 83.178 96.240 83.2\n",
321 | "6 swin_tiny_patch4_window7_224 81.184 95.512 81.2\n",
322 | "7 swin_base_patch4_window12_384 86.428 98.042 86.4\n",
323 | "8 swin_large_patch4_window12_384 87.272 98.242 87.3"
324 | ]
325 | },
326 | "execution_count": 4,
327 | "metadata": {},
328 | "output_type": "execute_result"
329 | }
330 | ],
331 | "source": [
332 | "# All models except for s3s: https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md\n",
333 | "# s3s: https://github.com/microsoft/Cream/tree/main/AutoFormerV2#model-zoo\n",
334 | "\n",
335 | "orig_1 = [\"85.2\", \"86.3\", \"84.0\", \"83.7\", \"82.1\", \"83.2\", \"81.2\", \"86.4\", \"87.3\"]\n",
336 | "df[\"orig_top1_acc(%)\"] = orig_1\n",
337 | "df = df.reset_index(drop=True)\n",
338 | "df"
339 | ]
340 | },
341 | {
342 | "cell_type": "code",
343 | "execution_count": 5,
344 | "metadata": {},
345 | "outputs": [
346 | {
347 | "name": "stdout",
348 | "output_type": "stream",
349 | "text": [
350 | "| | model_name | top1_acc(%) | top5_acc(%) | orig_top1_acc(%) |\n",
351 | "|---:|:-------------------------------|--------------:|--------------:|-------------------:|\n",
352 | "| 0 | swin_base_patch4_window7_224 | 85.134 | 97.48 | 85.2 |\n",
353 | "| 1 | swin_large_patch4_window7_224 | 86.252 | 97.878 | 86.3 |\n",
354 | "| 2 | swin_s3_base_224 | 83.958 | 96.532 | 84 |\n",
355 | "| 3 | swin_s3_small_224 | 83.648 | 96.358 | 83.7 |\n",
356 | "| 4 | swin_s3_tiny_224 | 82.034 | 95.864 | 82.1 |\n",
357 | "| 5 | swin_small_patch4_window7_224 | 83.178 | 96.24 | 83.2 |\n",
358 | "| 6 | swin_tiny_patch4_window7_224 | 81.184 | 95.512 | 81.2 |\n",
359 | "| 7 | swin_base_patch4_window12_384 | 86.428 | 98.042 | 86.4 |\n",
360 | "| 8 | swin_large_patch4_window12_384 | 87.272 | 98.242 | 87.3 |\n"
361 | ]
362 | }
363 | ],
364 | "source": [
365 | "print(df.to_markdown())"
366 | ]
367 | }
368 | ],
369 | "metadata": {
370 | "interpreter": {
371 | "hash": "2eeebd8186f946aafebcc49bceba063d6e659c945a9dcc5253ac24fa5b4e04cc"
372 | },
373 | "kernelspec": {
374 | "display_name": "Python 3.8.2 ('pytorch')",
375 | "language": "python",
376 | "name": "python3"
377 | },
378 | "language_info": {
379 | "codemirror_mode": {
380 | "name": "ipython",
381 | "version": 3
382 | },
383 | "file_extension": ".py",
384 | "mimetype": "text/x-python",
385 | "name": "python",
386 | "nbconvert_exporter": "python",
387 | "pygments_lexer": "ipython3",
388 | "version": "3.8.2"
389 | },
390 | "orig_nbformat": 4
391 | },
392 | "nbformat": 4,
393 | "nbformat_minor": 2
394 | }
395 |
--------------------------------------------------------------------------------
/swin-transformers-tf/in1k-eval/eval-swins.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "926bae46",
6 | "metadata": {},
7 | "source": [
8 | "## Imports"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "e038ddf1",
14 | "metadata": {},
15 | "source": [
16 | "Suppress TensorFlow warnings."
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "9ec05d94-4361-49d2-b680-ce41d0376299",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# Copied from:\n",
27 | "# https://weepingfish.github.io/2020/07/22/0722-suppress-tensorflow-warnings/\n",
28 | "\n",
29 | "# Filter tensorflow version warnings\n",
30 | "import os\n",
31 | "\n",
32 | "# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709\n",
33 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # or any {'0', '1', '2'}\n",
34 | "import warnings\n",
35 | "\n",
36 | "# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning\n",
37 | "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n",
38 | "warnings.simplefilter(action=\"ignore\", category=Warning)\n",
39 | "import tensorflow as tf\n",
40 | "\n",
41 | "tf.get_logger().setLevel(\"INFO\")\n",
42 | "tf.autograph.set_verbosity(0)\n",
43 | "import logging\n",
44 | "\n",
45 | "tf.get_logger().setLevel(logging.ERROR)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "id": "bae636ae-24d1-4523-9997-696731318a81",
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "from tensorflow import keras\n",
56 | "\n",
57 | "from torchvision.datasets import ImageFolder\n",
58 | "from torchvision import transforms\n",
59 | "from torch.utils.data import DataLoader\n",
60 | "\n",
61 | "from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "id": "96c4f0a2",
67 | "metadata": {},
68 | "source": [
69 | "## Constants"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "f8238055-08bf-44e1-8f3b-98e7768f1603",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "# Change batch size accordingly in case of OOM.\n",
80 | "# Change the image size to 384 wheb evaluation's done on 224.\n",
81 | "BATCH_SIZE = 256\n",
82 | "IMAGE_SIZE = 224\n",
83 | "TF_MODEL_ROOT = \"gs://swin-tf\""
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "id": "ea42076a",
89 | "metadata": {},
90 | "source": [
91 | "## Swin models "
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "id": "13a7e46e-31b2-48b9-9a57-2873fe27397a",
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)\n",
102 | "model_paths = [p for p in model_paths if str(IMAGE_SIZE) in p and \"fe\" not in p and \"22k\" not in p]\n",
103 | "print(model_paths)"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "60c39720",
109 | "metadata": {},
110 | "source": [
111 | "## Image loader"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "b066a571",
117 | "metadata": {},
118 | "source": [
119 | "To have an apples-to-apples comparison with the original PyTorch models for evaluation, it's important to ensure we use the same transformations."
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": null,
125 | "id": "5cf8eb2e-de82-48af-9292-c4917c237fa8",
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "# Transformations from:\n",
130 | "# (1) https://github.com/microsoft/Swin-Transformer\n",
131 | "# (2) https://github.com/microsoft/Swin-Transformer/tree/main/data\n",
132 | "\n",
133 | "if IMAGE_SIZE == 224:\n",
134 | " size = int((256 / 224) * IMAGE_SIZE)\n",
135 | " transform_chain = transforms.Compose(\n",
136 | " [\n",
137 | " transforms.Resize(size, interpolation=3),\n",
138 | " transforms.CenterCrop(IMAGE_SIZE),\n",
139 | " transforms.ToTensor(),\n",
140 | " transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),\n",
141 | " ]\n",
142 | " )\n",
143 | "else:\n",
144 | " transform_chain = transforms.Compose(\n",
145 | " [\n",
146 | " transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=3),\n",
147 | " transforms.ToTensor(),\n",
148 | " transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),\n",
149 | " ]\n",
150 | " )"
151 | ]
152 | },
153 | {
154 | "cell_type": "code",
155 | "execution_count": null,
156 | "id": "51ae139c-786b-47b3-9840-655c624f86b7",
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "dataset = ImageFolder(\"val\", transform=transform_chain)\n",
161 | "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=6)\n",
162 | "\n",
163 | "batch = next(iter(dataloader))\n",
164 | "print(batch[0].shape)"
165 | ]
166 | },
167 | {
168 | "cell_type": "markdown",
169 | "id": "2b7e8e68",
170 | "metadata": {},
171 | "source": [
172 | "## Run evaluation"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": null,
178 | "id": "63a3da22-a60f-48b8-a0b0-02e54b2d012f",
179 | "metadata": {},
180 | "outputs": [],
181 | "source": [
182 | "def get_model(model_url):\n",
183 | " model = keras.models.load_model(model_url)\n",
184 | " return model"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "id": "cbf8eabf-df6d-4ee5-900c-48b8761329ed",
191 | "metadata": {},
192 | "outputs": [],
193 | "source": [
194 | "# Copied and modified from:\n",
195 | "# https://github.com/sebastian-sz/resnet-rs-keras/blob/main/imagenet_evaluation/main.py\n",
196 | "\n",
197 | "log_file = f\"swin_{IMAGE_SIZE}_in1k.csv\"\n",
198 | "\n",
199 | "if not os.path.exists(log_file):\n",
200 | " with open(log_file, 'w') as f:\n",
201 | " f.write(\n",
202 | " 'model_name,top1_acc(%),top5_acc(%)\\n'\n",
203 | " )\n",
204 | "\n",
205 | "for path in model_paths:\n",
206 | " print(f\"Evaluating {path}.\")\n",
207 | " model = get_model(f\"{TF_MODEL_ROOT}/{path.strip('/')}\")\n",
208 | "\n",
209 | " top1 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name=\"top1\")\n",
210 | " top5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name=\"top5\")\n",
211 | " progbar = tf.keras.utils.Progbar(target=len(dataset) // BATCH_SIZE)\n",
212 | "\n",
213 | " for idx, (images, y_true) in enumerate(dataloader):\n",
214 | " images = images.numpy().transpose(0, 2, 3, 1)\n",
215 | " y_true = y_true.numpy()\n",
216 | " y_pred = model.predict(images)\n",
217 | "\n",
218 | " top1.update_state(y_true=y_true, y_pred=y_pred)\n",
219 | " top5.update_state(y_true=y_true, y_pred=y_pred)\n",
220 | "\n",
221 | " progbar.update(\n",
222 | " idx, [(\"top1\", top1.result().numpy()), (\"top5\", top5.result().numpy())]\n",
223 | " )\n",
224 | "\n",
225 | " print()\n",
226 | " print(f\"TOP1: {top1.result().numpy()}. TOP5: {top5.result().numpy()}\")\n",
227 | " \n",
228 | " top_1 = top1.result().numpy() * 100.\n",
229 | " top_5 = top5.result().numpy() * 100.\n",
230 | " with open(log_file, 'a') as f:\n",
231 | " f.write(\"%s,%0.3f,%0.3f\\n\" % (path, top_1, top_5))"
232 | ]
233 | },
234 | {
235 | "cell_type": "code",
236 | "execution_count": null,
237 | "id": "c845cdcb-1036-43e2-93df-4d47cce020ec",
238 | "metadata": {},
239 | "outputs": [],
240 | "source": [
241 | "!sudo shutdown now"
242 | ]
243 | }
244 | ],
245 | "metadata": {
246 | "environment": {
247 | "kernel": "python3",
248 | "name": "tf2-gpu.2-7.m87",
249 | "type": "gcloud",
250 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87"
251 | },
252 | "kernelspec": {
253 | "display_name": "Python 3",
254 | "language": "python",
255 | "name": "python3"
256 | },
257 | "language_info": {
258 | "codemirror_mode": {
259 | "name": "ipython",
260 | "version": 3
261 | },
262 | "file_extension": ".py",
263 | "mimetype": "text/x-python",
264 | "name": "python",
265 | "nbconvert_exporter": "python",
266 | "pygments_lexer": "ipython3",
267 | "version": "3.7.12"
268 | }
269 | },
270 | "nbformat": 4,
271 | "nbformat_minor": 5
272 | }
273 |
--------------------------------------------------------------------------------
/swin-transformers-tf/in1k-eval/swin_224_in1k.csv:
--------------------------------------------------------------------------------
1 | model_name,top1_acc(%),top5_acc(%)
2 | swin_base_patch4_window7_224/,85.134,97.480
3 | swin_large_patch4_window7_224/,86.252,97.878
4 | swin_s3_base_224/,83.958,96.532
5 | swin_s3_small_224/,83.648,96.358
6 | swin_s3_tiny_224/,82.034,95.864
7 | swin_small_patch4_window7_224/,83.178,96.240
8 | swin_tiny_patch4_window7_224/,81.184,95.512
9 |
--------------------------------------------------------------------------------
/swin-transformers-tf/in1k-eval/swin_384_in1k.csv:
--------------------------------------------------------------------------------
1 | model_name,top1_acc(%),top5_acc(%)
2 | swin_base_patch4_window12_384/,86.428,98.042
3 | swin_large_patch4_window12_384/,87.272,98.242
4 |
--------------------------------------------------------------------------------
/swin-transformers-tf/notebooks/classification.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "zpiEiO2BUeO5",
6 | "metadata": {
7 | "id": "zpiEiO2BUeO5"
8 | },
9 | "source": [
10 | "# Off-the-shelf image classification with Swin Transformers on TF-Hub\n",
11 | "\n",
12 | ""
23 | ]
24 | },
25 | {
26 | "cell_type": "markdown",
27 | "id": "661e6538",
28 | "metadata": {
29 | "id": "661e6538"
30 | },
31 | "source": [
32 | "## Setup"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": null,
38 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5",
39 | "metadata": {
40 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5"
41 | },
42 | "outputs": [],
43 | "source": [
44 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9",
51 | "metadata": {
52 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9"
53 | },
54 | "outputs": [],
55 | "source": [
56 | "import tensorflow as tf\n",
57 | "import tensorflow_hub as hub\n",
58 | "from tensorflow import keras\n",
59 | "\n",
60 | "\n",
61 | "from PIL import Image\n",
62 | "from io import BytesIO\n",
63 | "\n",
64 | "import matplotlib.pyplot as plt\n",
65 | "import numpy as np\n",
66 | "import requests\n",
67 | "import cv2"
68 | ]
69 | },
70 | {
71 | "cell_type": "markdown",
72 | "id": "z5l1cRpiSavW",
73 | "metadata": {
74 | "id": "z5l1cRpiSavW"
75 | },
76 | "source": [
77 | "## Select a [Swin](https://arxiv.org/abs/2103.14030) ImageNet-1k model\n",
78 | "\n",
79 | "Find the entire collection [here](https://tfhub.dev/sayakpaul/collections/swin/1). For inferring with the ImageNet-22k models, please refer [here](https://tfhub.dev/google/bit/m-r50x1/imagenet21k_classification/1#usage)."
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "id": "a0wM8idaSaOq",
86 | "metadata": {
87 | "id": "a0wM8idaSaOq"
88 | },
89 | "outputs": [],
90 | "source": [
91 | "model_name = \"swin_tiny_patch4_window7_224\" #@param [\"swin_tiny_patch4_window7_224\", \"swin_small_patch4_window7_224\", \"swin_base_patch4_window7_224\", \"swin_base_patch4_window12_384\", \"swin_large_patch4_window7_224\", \"swin_large_patch4_window7_384\", \"swin_s3_tiny_224\", \"swin_s3_small_224\", \"swin_s3_base_224\"]\n",
92 | "\n",
93 | "model_handle_map ={\n",
94 | " \"swin_tiny_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_tiny_patch4_window7_224/1\",\n",
95 | " \"swin_small_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_small_patch4_window7_224/1\",\n",
96 | " \"swin_base_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_base_patch4_window7_224/1\",\n",
97 | " \"swin_base_patch4_window12_384\": \"https://tfhub.dev/sayakpaul/swin_base_patch4_window12_384/1\",\n",
98 | " \"swin_large_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_large_patch4_window7_224/1\",\n",
99 | " \"swin_large_patch4_window7_384\": \"https://tfhub.dev/sayakpaul/swin_large_patch4_window7_384/1\",\n",
100 | " \"swin_s3_tiny_224\": \"https://tfhub.dev/sayakpaul/swin_s3_tiny_224/1\",\n",
101 | " \"swin_s3_small_224\": \"https://tfhub.dev/sayakpaul/swin_s3_small_224/1\",\n",
102 | " \"swin_s3_base_224\": \"https://tfhub.dev/sayakpaul/swin_s3_base_224/1\",\n",
103 | "}\n",
104 | "\n",
105 | "input_resolution = int(model_name.split(\"_\")[-1])\n",
106 | "model_handle = model_handle_map[model_name]\n",
107 | "print(f\"Input resolution: {input_resolution} x {input_resolution} x 3.\")\n",
108 | "print(f\"TF-Hub handle: {model_handle}.\")"
109 | ]
110 | },
111 | {
112 | "cell_type": "markdown",
113 | "id": "441b5361",
114 | "metadata": {
115 | "id": "441b5361"
116 | },
117 | "source": [
118 | "## Image preprocessing utilities "
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": null,
124 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b",
125 | "metadata": {
126 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b"
127 | },
128 | "outputs": [],
129 | "source": [
130 | "crop_layer = keras.layers.CenterCrop(input_resolution, input_resolution)\n",
131 | "norm_layer = keras.layers.Normalization(\n",
132 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n",
133 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n",
134 | ")\n",
135 | "\n",
136 | "\n",
137 | "def preprocess_image(image, size=input_resolution):\n",
138 | " image = np.array(image)\n",
139 | " image_resized = tf.expand_dims(image, 0)\n",
140 | " \n",
141 | " if size == 224:\n",
142 | " resize_size = int((256 / 224) * size)\n",
143 | " image_resized = tf.image.resize(image_resized, (resize_size, resize_size), method=\"bicubic\")\n",
144 | " image_resized = crop_layer(image_resized)\n",
145 | " else:\n",
146 | " image_resized = tf.image.resize(image_resized, (size, size), method=\"bicubic\")\n",
147 | " \n",
148 | " return norm_layer(image_resized).numpy()\n",
149 | " \n",
150 | "\n",
151 | "def load_image_from_url(url):\n",
152 | " # Credit: Willi Gierke\n",
153 | " response = requests.get(url)\n",
154 | " image = Image.open(BytesIO(response.content))\n",
155 | " preprocessed_image = preprocess_image(image)\n",
156 | " return image, preprocessed_image"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "id": "8b961e14",
162 | "metadata": {
163 | "id": "8b961e14"
164 | },
165 | "source": [
166 | "## Load ImageNet-1k labels and a demo image"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": null,
172 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b",
173 | "metadata": {
174 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b"
175 | },
176 | "outputs": [],
177 | "source": [
178 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n",
179 | " lines = f.readlines()\n",
180 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n",
181 | "\n",
182 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n",
183 | "image, preprocessed_image = load_image_from_url(img_url)\n",
184 | "\n",
185 | "plt.imshow(image)\n",
186 | "plt.show()"
187 | ]
188 | },
189 | {
190 | "cell_type": "markdown",
191 | "id": "9006a643",
192 | "metadata": {
193 | "id": "9006a643"
194 | },
195 | "source": [
196 | "## Run inference"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": null,
202 | "id": "bHnCyJtAf9el",
203 | "metadata": {
204 | "id": "bHnCyJtAf9el"
205 | },
206 | "outputs": [],
207 | "source": [
208 | "def get_model(model_url: str) -> tf.keras.Model:\n",
209 | " inputs = tf.keras.Input((input_resolution, input_resolution, 3))\n",
210 | " hub_module = hub.KerasLayer(model_url)\n",
211 | "\n",
212 | " outputs = hub_module(inputs)\n",
213 | "\n",
214 | " return tf.keras.Model(inputs, outputs)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": null,
220 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908",
221 | "metadata": {
222 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908"
223 | },
224 | "outputs": [],
225 | "source": [
226 | "classification_model = get_model(model_handle)\n",
227 | "predictions = classification_model.predict(preprocessed_image)\n",
228 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n",
229 | "print(predicted_label)"
230 | ]
231 | },
232 | {
233 | "cell_type": "markdown",
234 | "source": [
235 | "## Obtain attention scores"
236 | ],
237 | "metadata": {
238 | "id": "wPisHE9lMmaN"
239 | },
240 | "id": "wPisHE9lMmaN"
241 | },
242 | {
243 | "cell_type": "code",
244 | "source": [
245 | "swin_tiny_patch4_window7_224_tf = tf.keras.models.load_model(\n",
246 | " f\"gs://tfhub-modules/sayakpaul/{model_name}/1/uncompressed\"\n",
247 | ")\n",
248 | "all_attn_scores = swin_tiny_patch4_window7_224_tf.get_attention_scores(\n",
249 | " preprocessed_image\n",
250 | ")\n",
251 | "print(all_attn_scores.keys())"
252 | ],
253 | "metadata": {
254 | "id": "cRO5v-yPMoEO"
255 | },
256 | "id": "cRO5v-yPMoEO",
257 | "execution_count": null,
258 | "outputs": []
259 | },
260 | {
261 | "cell_type": "code",
262 | "source": [
263 | "# attn score dimensions:\n",
264 | "# (batch_size, nb_attention_heads, seq_length, seq_length)\n",
265 | "print(all_attn_scores[\"swin_stage_3\"].keys()), print(all_attn_scores[\"swin_stage_3\"][\"swin_block_0\"].shape)"
266 | ],
267 | "metadata": {
268 | "id": "GwRtcNR2NLrC"
269 | },
270 | "id": "GwRtcNR2NLrC",
271 | "execution_count": null,
272 | "outputs": []
273 | }
274 | ],
275 | "metadata": {
276 | "accelerator": "GPU",
277 | "colab": {
278 | "machine_shape": "hm",
279 | "name": "classification.ipynb",
280 | "provenance": []
281 | },
282 | "environment": {
283 | "kernel": "python3",
284 | "name": "tf2-gpu.2-7.m87",
285 | "type": "gcloud",
286 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87"
287 | },
288 | "kernelspec": {
289 | "display_name": "Python 3",
290 | "language": "python",
291 | "name": "python3"
292 | },
293 | "language_info": {
294 | "codemirror_mode": {
295 | "name": "ipython",
296 | "version": 3
297 | },
298 | "file_extension": ".py",
299 | "mimetype": "text/x-python",
300 | "name": "python",
301 | "nbconvert_exporter": "python",
302 | "pygments_lexer": "ipython3",
303 | "version": "3.7.12"
304 | }
305 | },
306 | "nbformat": 4,
307 | "nbformat_minor": 5
308 | }
--------------------------------------------------------------------------------
/swin-transformers-tf/notebooks/finetune.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 |         ""
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "89B27-TGiDNB"
17 | },
18 | "source": [
19 | "## Imports"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "metadata": {
26 | "id": "9u3d4Z7uQsmp"
27 | },
28 | "outputs": [],
29 | "source": [
30 | "from tensorflow import keras\n",
31 | "import tensorflow as tf\n",
32 | "import tensorflow_hub as hub\n",
33 | "import tensorflow_datasets as tfds\n",
34 | "\n",
35 | "tfds.disable_progress_bar()\n",
36 | "\n",
37 | "import os\n",
38 | "import sys\n",
39 | "import math\n",
40 | "import numpy as np\n",
41 | "import pandas as pd\n",
42 | "import matplotlib.pyplot as plt"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": null,
48 | "metadata": {
49 | "id": "mdUDP5qIiG05"
50 | },
51 | "outputs": [],
52 | "source": [
53 | "# Copied from:\n",
54 | "# https://weepingfish.github.io/2020/07/22/0722-suppress-tensorflow-warnings/\n",
55 | "\n",
56 | "# Filter tensorflow version warnings\n",
57 | "\n",
58 | "# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709\n",
59 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # or any {'0', '1', '2'}\n",
60 | "import warnings\n",
61 | "\n",
62 | "# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning\n",
63 | "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n",
64 | "warnings.simplefilter(action=\"ignore\", category=Warning)\n",
65 | "\n",
66 | "tf.get_logger().setLevel(\"INFO\")\n",
67 | "tf.autograph.set_verbosity(0)\n",
68 | "import logging\n",
69 | "\n",
70 | "tf.get_logger().setLevel(logging.ERROR)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {
76 | "id": "mPo10cahZXXQ"
77 | },
78 | "source": [
79 | "## GPUs"
80 | ]
81 | },
82 | {
83 | "cell_type": "code",
84 | "execution_count": null,
85 | "metadata": {
86 | "id": "FpvUOuC3j27n"
87 | },
88 | "outputs": [],
89 | "source": [
90 | "try: # detect TPUs\n",
91 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection\n",
92 | " strategy = tf.distribute.TPUStrategy(tpu)\n",
93 | "except ValueError: # detect GPUs\n",
94 | " tpu = False\n",
95 | " strategy = (\n",
96 | " tf.distribute.get_strategy()\n",
97 | " ) # default strategy that works on CPU and single GPU\n",
98 | "print(\"Number of Accelerators: \", strategy.num_replicas_in_sync)"
99 | ]
100 | },
101 | {
102 | "cell_type": "markdown",
103 | "metadata": {
104 | "id": "w9S3uKC_iXY5"
105 | },
106 | "source": [
107 | "## Configuration\n",
108 | "\n",
109 | "Find the list of all fine-tunable models [here](https://tfhub.dev/sayakpaul/collections/swin/1)."
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {
116 | "id": "kCc6tdUGnD4C"
117 | },
118 | "outputs": [],
119 | "source": [
120 | "# Model\n",
121 | "IMAGE_SIZE = [224, 224] # Change this accordingly. \n",
122 | "MODEL_PATH = \"https://tfhub.dev/sayakpaul/swin_tiny_patch4_window7_224_fe\" \n",
123 | "\n",
124 | "# TPU\n",
125 | "if tpu:\n",
126 | " BATCH_SIZE = (\n",
127 | " 16 * strategy.num_replicas_in_sync\n",
128 | " ) # a TPU has 8 cores so this will be 128\n",
129 | "else:\n",
130 | " BATCH_SIZE = 128 # on Colab/GPU, a higher batch size may throw OOM\n",
131 | "\n",
132 | "# Dataset\n",
133 | "CLASSES = [\n",
134 | " \"dandelion\",\n",
135 | " \"daisy\",\n",
136 | " \"tulips\",\n",
137 | " \"sunflowers\",\n",
138 | " \"roses\",\n",
139 | "] # don't change the order\n",
140 | "\n",
141 | "# Other constants\n",
142 | "MEAN = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255]) # imagenet mean\n",
143 | "STD = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255]) # imagenet std\n",
144 | "AUTO = tf.data.AUTOTUNE"
145 | ]
146 | },
147 | {
148 | "cell_type": "markdown",
149 | "metadata": {
150 | "id": "9iTImGI5qMQT"
151 | },
152 | "source": [
153 | "# Data Pipeline\n",
154 | "\n",
155 | "[DeiT authors](https://arxiv.org/abs/2012.12877) use a separate preprocessing pipeline for fine-tuning. But for keeping this walkthrough short and simple, we can just perform the basic ones."
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {
162 | "id": "h29TLx7gqN_7"
163 | },
164 | "outputs": [],
165 | "source": [
166 | "def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = IMAGE_SIZE):\n",
167 | " def preprocess(image, label):\n",
168 | " # for training, do augmentation\n",
169 | " if train:\n",
170 | " if tf.random.uniform(shape=[]) > 0.5:\n",
171 | " image = tf.image.flip_left_right(image)\n",
172 | " image = tf.image.resize(image, size=image_size, method=\"bicubic\")\n",
173 | " image = (image - MEAN) / STD # normalization\n",
174 | " return image, label\n",
175 | "\n",
176 | " if train:\n",
177 | " dataset = dataset.shuffle(BATCH_SIZE * 10)\n",
178 | "\n",
179 | " return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)"
180 | ]
181 | },
182 | {
183 | "cell_type": "markdown",
184 | "metadata": {
185 | "id": "AMQ3Qs9_pddU"
186 | },
187 | "source": [
188 | "# Flower Dataset"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {
195 | "id": "M3G-2aUBQJ-H"
196 | },
197 | "outputs": [],
198 | "source": [
199 | "train_dataset, val_dataset = tfds.load(\n",
200 | " \"tf_flowers\",\n",
201 | " split=[\"train[:90%]\", \"train[90%:]\"],\n",
202 | " as_supervised=True,\n",
203 | " try_gcs=False, # gcs_path is necessary for tpu,\n",
204 | ")\n",
205 | "\n",
206 | "num_train = tf.data.experimental.cardinality(train_dataset)\n",
207 | "num_val = tf.data.experimental.cardinality(val_dataset)\n",
208 | "print(f\"Number of training examples: {num_train}\")\n",
209 | "print(f\"Number of validation examples: {num_val}\")"
210 | ]
211 | },
212 | {
213 | "cell_type": "markdown",
214 | "metadata": {
215 | "id": "l2X7sE3oRLXN"
216 | },
217 | "source": [
218 | "## Prepare dataset"
219 | ]
220 | },
221 | {
222 | "cell_type": "code",
223 | "execution_count": null,
224 | "metadata": {
225 | "id": "oftrfYw1qXei"
226 | },
227 | "outputs": [],
228 | "source": [
229 | "train_dataset = make_dataset(train_dataset, True)\n",
230 | "val_dataset = make_dataset(val_dataset, False)"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "metadata": {
236 | "id": "kNyCCM6PRM8I"
237 | },
238 | "source": [
239 | "## Visualize"
240 | ]
241 | },
242 | {
243 | "cell_type": "code",
244 | "execution_count": null,
245 | "metadata": {
246 | "id": "IaGzFUUVqjaC"
247 | },
248 | "outputs": [],
249 | "source": [
250 | "sample_images, sample_labels = next(iter(train_dataset))\n",
251 | "\n",
252 | "plt.figure(figsize=(5 * 3, 3 * 3))\n",
253 | "for n in range(15):\n",
254 | " ax = plt.subplot(3, 5, n + 1)\n",
255 | " image = (sample_images[n] * STD + MEAN).numpy()\n",
256 | " image = (image - image.min()) / (\n",
257 | " image.max() - image.min()\n",
258 | " ) # convert to [0, 1] for avoiding matplotlib warning\n",
259 | " plt.imshow(image)\n",
260 | " plt.title(CLASSES[sample_labels[n]])\n",
261 | " plt.axis(\"off\")\n",
262 | "plt.tight_layout()\n",
263 | "plt.show()"
264 | ]
265 | },
266 | {
267 | "cell_type": "markdown",
268 | "metadata": {
269 | "id": "Qf6u_7tt8BYy"
270 | },
271 | "source": [
272 | "# LR Scheduler Utility"
273 | ]
274 | },
275 | {
276 | "cell_type": "code",
277 | "execution_count": null,
278 | "metadata": {
279 | "id": "oVTbnkJL79T_"
280 | },
281 | "outputs": [],
282 | "source": [
283 | "# Reference:\n",
284 | "# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2\n",
285 | "\n",
286 | "\n",
287 | "class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):\n",
288 | " def __init__(\n",
289 | " self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps\n",
290 | " ):\n",
291 | " super(WarmUpCosine, self).__init__()\n",
292 | "\n",
293 | " self.learning_rate_base = learning_rate_base\n",
294 | " self.total_steps = total_steps\n",
295 | " self.warmup_learning_rate = warmup_learning_rate\n",
296 | " self.warmup_steps = warmup_steps\n",
297 | " self.pi = tf.constant(np.pi)\n",
298 | "\n",
299 | " def __call__(self, step):\n",
300 | " if self.total_steps < self.warmup_steps:\n",
301 | " raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
302 | " learning_rate = (\n",
303 | " 0.5\n",
304 | " * self.learning_rate_base\n",
305 | " * (\n",
306 | " 1\n",
307 | " + tf.cos(\n",
308 | " self.pi\n",
309 | " * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
310 | " / float(self.total_steps - self.warmup_steps)\n",
311 | " )\n",
312 | " )\n",
313 | " )\n",
314 | "\n",
315 | " if self.warmup_steps > 0:\n",
316 | " if self.learning_rate_base < self.warmup_learning_rate:\n",
317 | " raise ValueError(\n",
318 | " \"Learning_rate_base must be larger or equal to \"\n",
319 | " \"warmup_learning_rate.\"\n",
320 | " )\n",
321 | " slope = (\n",
322 | " self.learning_rate_base - self.warmup_learning_rate\n",
323 | " ) / self.warmup_steps\n",
324 | " warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
325 | " learning_rate = tf.where(\n",
326 | " step < self.warmup_steps, warmup_rate, learning_rate\n",
327 | " )\n",
328 | " return tf.where(\n",
329 | " step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
330 | " )"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "ALtRUlxhw8Vt"
337 | },
338 | "source": [
339 | "# Model Utility"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {
346 | "id": "JD9SI_Q9JdAB"
347 | },
348 | "outputs": [],
349 | "source": [
350 | "def get_model(model_url: str, res: int = IMAGE_SIZE[0], num_classes: int = 5) -> tf.keras.Model:\n",
351 | " inputs = tf.keras.Input((res, res, 3))\n",
352 | " hub_module = hub.KerasLayer(model_url, trainable=True)\n",
353 | "\n",
354 | " x = hub_module(inputs, training=False) \n",
355 | " outputs = keras.layers.Dense(num_classes, activation=\"softmax\")(x)\n",
356 | "\n",
357 | " return tf.keras.Model(inputs, outputs)"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": null,
363 | "metadata": {
364 | "id": "wpZApp9u9_Y-"
365 | },
366 | "outputs": [],
367 | "source": [
368 | "get_model(MODEL_PATH).summary()"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "metadata": {
374 | "id": "dMfenMQcxAAb"
375 | },
376 | "source": [
377 | "# Training Hyperparameters"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {
384 | "id": "1D7Iu7oD8WzX"
385 | },
386 | "outputs": [],
387 | "source": [
388 | "EPOCHS = 10\n",
389 | "WARMUP_STEPS = 10\n",
390 | "INIT_LR = 0.03\n",
391 | "WAMRUP_LR = 0.006\n",
392 | "\n",
393 | "TOTAL_STEPS = int((num_train / BATCH_SIZE) * EPOCHS)\n",
394 | "\n",
395 | "scheduled_lrs = WarmUpCosine(\n",
396 | " learning_rate_base=INIT_LR,\n",
397 | " total_steps=TOTAL_STEPS,\n",
398 | " warmup_learning_rate=WAMRUP_LR,\n",
399 | " warmup_steps=WARMUP_STEPS,\n",
400 | ")"
401 | ]
402 | },
403 | {
404 | "cell_type": "code",
405 | "execution_count": null,
406 | "metadata": {
407 | "id": "M-ID7vP5mIKs"
408 | },
409 | "outputs": [],
410 | "source": [
411 | "optimizer = keras.optimizers.SGD(scheduled_lrs)\n",
412 | "loss = keras.losses.SparseCategoricalCrossentropy()"
413 | ]
414 | },
415 | {
416 | "cell_type": "markdown",
417 | "metadata": {
418 | "id": "E9p4ymNh9y7d"
419 | },
420 | "source": [
421 | "# Training & Validation"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": null,
427 | "metadata": {
428 | "id": "VnZTSd8K90Mq"
429 | },
430 | "outputs": [],
431 | "source": [
432 | "with strategy.scope(): # this line is all that is needed to run on TPU (or multi-GPU, ...)\n",
433 | " model = get_model(MODEL_PATH)\n",
434 | " model.compile(loss=loss, optimizer=optimizer, metrics=[\"accuracy\"])\n",
435 | "\n",
436 | "history = model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": null,
442 | "metadata": {
443 | "id": "jc7LMVz5Cbx6"
444 | },
445 | "outputs": [],
446 | "source": [
447 | "result = pd.DataFrame(history.history)\n",
448 | "fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n",
449 | "result[[\"accuracy\", \"val_accuracy\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[0])\n",
450 | "result[[\"loss\", \"val_loss\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[1])"
451 | ]
452 | },
453 | {
454 | "cell_type": "markdown",
455 | "metadata": {
456 | "id": "MKFMWzh0Yxsq"
457 | },
458 | "source": [
459 | "# Predictions"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": null,
465 | "metadata": {
466 | "id": "yMEsR851VDZb"
467 | },
468 | "outputs": [],
469 | "source": [
470 | "sample_images, sample_labels = next(iter(val_dataset))\n",
471 | "\n",
472 | "predictions = model.predict(sample_images, batch_size=16).argmax(axis=-1)\n",
473 | "evaluations = model.evaluate(sample_images, sample_labels, batch_size=16)\n",
474 | "\n",
475 | "print(\"[val_loss, val_acc]\", evaluations)"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "metadata": {
482 | "id": "qzCCDL1CZFx6"
483 | },
484 | "outputs": [],
485 | "source": [
486 | "plt.figure(figsize=(5 * 3, 3 * 3))\n",
487 | "for n in range(15):\n",
488 | " ax = plt.subplot(3, 5, n + 1)\n",
489 | " image = (sample_images[n] * STD + MEAN).numpy()\n",
490 | " image = (image - image.min()) / (\n",
491 | " image.max() - image.min()\n",
492 | " ) # convert to [0, 1] for avoiding matplotlib warning\n",
493 | " plt.imshow(image)\n",
494 | " target = CLASSES[sample_labels[n]]\n",
495 | " pred = CLASSES[predictions[n]]\n",
496 | " plt.title(\"{} ({})\".format(target, pred))\n",
497 | " plt.axis(\"off\")\n",
498 | "\n",
499 | "plt.tight_layout()\n",
500 | "plt.show()"
501 | ]
502 | },
503 | {
504 | "cell_type": "markdown",
505 | "metadata": {
506 | "id": "2e5oy9zmNNID"
507 | },
508 | "source": [
509 | "# Reference\n",
510 | "* [Keras Flowers on TPU (solution)](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_solution.ipynb)\n",
511 | "\n",
512 | "\n",
513 | "This notebook is copied and modified from [here](https://github.com/sayakpaul/ConvNeXt-TF/blob/main/notebooks/finetune.ipynb). I'm thankful to [awsaf49](https://github.com/awsaf49) who originally worked on that notebook. "
514 | ]
515 | }
516 | ],
517 | "metadata": {
518 | "accelerator": "GPU",
519 | "colab": {
520 | "collapsed_sections": [],
521 | "machine_shape": "hm",
522 | "name": "finetune",
523 | "provenance": [],
524 | "toc_visible": true,
525 | "include_colab_link": true
526 | },
527 | "environment": {
528 | "kernel": "python3",
529 | "name": "tf2-gpu.2-7.m87",
530 | "type": "gcloud",
531 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87"
532 | },
533 | "kernelspec": {
534 | "display_name": "Python 3",
535 | "language": "python",
536 | "name": "python3"
537 | },
538 | "language_info": {
539 | "codemirror_mode": {
540 | "name": "ipython",
541 | "version": 3
542 | },
543 | "file_extension": ".py",
544 | "mimetype": "text/x-python",
545 | "name": "python",
546 | "nbconvert_exporter": "python",
547 | "pygments_lexer": "ipython3",
548 | "version": "3.7.12"
549 | }
550 | },
551 | "nbformat": 4,
552 | "nbformat_minor": 0
553 | }
--------------------------------------------------------------------------------
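
Note: only the tail of the `WarmUpCosine` schedule used in the notebook above survives in this part of the dump. For orientation, here is a minimal sketch of a linear-warmup-plus-cosine-decay schedule that is consistent with the visible `tf.where` clamps; the exact constants and any validation logic in the actual notebook cell may differ:

```python
import math

import tensorflow as tf


class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to learning_rate_base, cosine decay to 0, clamped past total_steps."""

    def __init__(self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps):
        super().__init__()
        self.learning_rate_base = learning_rate_base
        self.total_steps = total_steps
        self.warmup_learning_rate = warmup_learning_rate
        self.warmup_steps = warmup_steps
        self.pi = tf.constant(math.pi)

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        # Cosine decay over the post-warmup portion of training.
        cos_annealed_lr = tf.cos(
            self.pi
            * (step - self.warmup_steps)
            / float(self.total_steps - self.warmup_steps)
        )
        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)
        if self.warmup_steps > 0:
            # Linear ramp from warmup_learning_rate up to learning_rate_base.
            slope = (
                self.learning_rate_base - self.warmup_learning_rate
            ) / self.warmup_steps
            warmup_rate = slope * step + self.warmup_learning_rate
            learning_rate = tf.where(
                step < self.warmup_steps, warmup_rate, learning_rate
            )
        return tf.where(
            step > self.total_steps, 0.0, learning_rate, name="learning_rate"
        )
```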
/swin-transformers-tf/notebooks/ilsvrc2012_wordnet_lemmas.txt:
--------------------------------------------------------------------------------
1 | tench, Tinca_tinca
2 | goldfish, Carassius_auratus
3 | great_white_shark, white_shark, man-eater, man-eating_shark, Carcharodon_carcharias
4 | tiger_shark, Galeocerdo_cuvieri
5 | hammerhead, hammerhead_shark
6 | electric_ray, crampfish, numbfish, torpedo
7 | stingray
8 | cock
9 | hen
10 | ostrich, Struthio_camelus
11 | brambling, Fringilla_montifringilla
12 | goldfinch, Carduelis_carduelis
13 | house_finch, linnet, Carpodacus_mexicanus
14 | junco, snowbird
15 | indigo_bunting, indigo_finch, indigo_bird, Passerina_cyanea
16 | robin, American_robin, Turdus_migratorius
17 | bulbul
18 | jay
19 | magpie
20 | chickadee
21 | water_ouzel, dipper
22 | kite
23 | bald_eagle, American_eagle, Haliaeetus_leucocephalus
24 | vulture
25 | great_grey_owl, great_gray_owl, Strix_nebulosa
26 | European_fire_salamander, Salamandra_salamandra
27 | common_newt, Triturus_vulgaris
28 | eft
29 | spotted_salamander, Ambystoma_maculatum
30 | axolotl, mud_puppy, Ambystoma_mexicanum
31 | bullfrog, Rana_catesbeiana
32 | tree_frog, tree-frog
33 | tailed_frog, bell_toad, ribbed_toad, tailed_toad, Ascaphus_trui
34 | loggerhead, loggerhead_turtle, Caretta_caretta
35 | leatherback_turtle, leatherback, leathery_turtle, Dermochelys_coriacea
36 | mud_turtle
37 | terrapin
38 | box_turtle, box_tortoise
39 | banded_gecko
40 | common_iguana, iguana, Iguana_iguana
41 | American_chameleon, anole, Anolis_carolinensis
42 | whiptail, whiptail_lizard
43 | agama
44 | frilled_lizard, Chlamydosaurus_kingi
45 | alligator_lizard
46 | Gila_monster, Heloderma_suspectum
47 | green_lizard, Lacerta_viridis
48 | African_chameleon, Chamaeleo_chamaeleon
49 | Komodo_dragon, Komodo_lizard, dragon_lizard, giant_lizard, Varanus_komodoensis
50 | African_crocodile, Nile_crocodile, Crocodylus_niloticus
51 | American_alligator, Alligator_mississipiensis
52 | triceratops
53 | thunder_snake, worm_snake, Carphophis_amoenus
54 | ringneck_snake, ring-necked_snake, ring_snake
55 | hognose_snake, puff_adder, sand_viper
56 | green_snake, grass_snake
57 | king_snake, kingsnake
58 | garter_snake, grass_snake
59 | water_snake
60 | vine_snake
61 | night_snake, Hypsiglena_torquata
62 | boa_constrictor, Constrictor_constrictor
63 | rock_python, rock_snake, Python_sebae
64 | Indian_cobra, Naja_naja
65 | green_mamba
66 | sea_snake
67 | horned_viper, cerastes, sand_viper, horned_asp, Cerastes_cornutus
68 | diamondback, diamondback_rattlesnake, Crotalus_adamanteus
69 | sidewinder, horned_rattlesnake, Crotalus_cerastes
70 | trilobite
71 | harvestman, daddy_longlegs, Phalangium_opilio
72 | scorpion
73 | black_and_gold_garden_spider, Argiope_aurantia
74 | barn_spider, Araneus_cavaticus
75 | garden_spider, Aranea_diademata
76 | black_widow, Latrodectus_mactans
77 | tarantula
78 | wolf_spider, hunting_spider
79 | tick
80 | centipede
81 | black_grouse
82 | ptarmigan
83 | ruffed_grouse, partridge, Bonasa_umbellus
84 | prairie_chicken, prairie_grouse, prairie_fowl
85 | peacock
86 | quail
87 | partridge
88 | African_grey, African_gray, Psittacus_erithacus
89 | macaw
90 | sulphur-crested_cockatoo, Kakatoe_galerita, Cacatua_galerita
91 | lorikeet
92 | coucal
93 | bee_eater
94 | hornbill
95 | hummingbird
96 | jacamar
97 | toucan
98 | drake
99 | red-breasted_merganser, Mergus_serrator
100 | goose
101 | black_swan, Cygnus_atratus
102 | tusker
103 | echidna, spiny_anteater, anteater
104 | platypus, duckbill, duckbilled_platypus, duck-billed_platypus, Ornithorhynchus_anatinus
105 | wallaby, brush_kangaroo
106 | koala, koala_bear, kangaroo_bear, native_bear, Phascolarctos_cinereus
107 | wombat
108 | jellyfish
109 | sea_anemone, anemone
110 | brain_coral
111 | flatworm, platyhelminth
112 | nematode, nematode_worm, roundworm
113 | conch
114 | snail
115 | slug
116 | sea_slug, nudibranch
117 | chiton, coat-of-mail_shell, sea_cradle, polyplacophore
118 | chambered_nautilus, pearly_nautilus, nautilus
119 | Dungeness_crab, Cancer_magister
120 | rock_crab, Cancer_irroratus
121 | fiddler_crab
122 | king_crab, Alaska_crab, Alaskan_king_crab, Alaska_king_crab, Paralithodes_camtschatica
123 | American_lobster, Northern_lobster, Maine_lobster, Homarus_americanus
124 | spiny_lobster, langouste, rock_lobster, crawfish, crayfish, sea_crawfish
125 | crayfish, crawfish, crawdad, crawdaddy
126 | hermit_crab
127 | isopod
128 | white_stork, Ciconia_ciconia
129 | black_stork, Ciconia_nigra
130 | spoonbill
131 | flamingo
132 | little_blue_heron, Egretta_caerulea
133 | American_egret, great_white_heron, Egretta_albus
134 | bittern
135 | crane
136 | limpkin, Aramus_pictus
137 | European_gallinule, Porphyrio_porphyrio
138 | American_coot, marsh_hen, mud_hen, water_hen, Fulica_americana
139 | bustard
140 | ruddy_turnstone, Arenaria_interpres
141 | red-backed_sandpiper, dunlin, Erolia_alpina
142 | redshank, Tringa_totanus
143 | dowitcher
144 | oystercatcher, oyster_catcher
145 | pelican
146 | king_penguin, Aptenodytes_patagonica
147 | albatross, mollymawk
148 | grey_whale, gray_whale, devilfish, Eschrichtius_gibbosus, Eschrichtius_robustus
149 | killer_whale, killer, orca, grampus, sea_wolf, Orcinus_orca
150 | dugong, Dugong_dugon
151 | sea_lion
152 | Chihuahua
153 | Japanese_spaniel
154 | Maltese_dog, Maltese_terrier, Maltese
155 | Pekinese, Pekingese, Peke
156 | Shih-Tzu
157 | Blenheim_spaniel
158 | papillon
159 | toy_terrier
160 | Rhodesian_ridgeback
161 | Afghan_hound, Afghan
162 | basset, basset_hound
163 | beagle
164 | bloodhound, sleuthhound
165 | bluetick
166 | black-and-tan_coonhound
167 | Walker_hound, Walker_foxhound
168 | English_foxhound
169 | redbone
170 | borzoi, Russian_wolfhound
171 | Irish_wolfhound
172 | Italian_greyhound
173 | whippet
174 | Ibizan_hound, Ibizan_Podenco
175 | Norwegian_elkhound, elkhound
176 | otterhound, otter_hound
177 | Saluki, gazelle_hound
178 | Scottish_deerhound, deerhound
179 | Weimaraner
180 | Staffordshire_bullterrier, Staffordshire_bull_terrier
181 | American_Staffordshire_terrier, Staffordshire_terrier, American_pit_bull_terrier, pit_bull_terrier
182 | Bedlington_terrier
183 | Border_terrier
184 | Kerry_blue_terrier
185 | Irish_terrier
186 | Norfolk_terrier
187 | Norwich_terrier
188 | Yorkshire_terrier
189 | wire-haired_fox_terrier
190 | Lakeland_terrier
191 | Sealyham_terrier, Sealyham
192 | Airedale, Airedale_terrier
193 | cairn, cairn_terrier
194 | Australian_terrier
195 | Dandie_Dinmont, Dandie_Dinmont_terrier
196 | Boston_bull, Boston_terrier
197 | miniature_schnauzer
198 | giant_schnauzer
199 | standard_schnauzer
200 | Scotch_terrier, Scottish_terrier, Scottie
201 | Tibetan_terrier, chrysanthemum_dog
202 | silky_terrier, Sydney_silky
203 | soft-coated_wheaten_terrier
204 | West_Highland_white_terrier
205 | Lhasa, Lhasa_apso
206 | flat-coated_retriever
207 | curly-coated_retriever
208 | golden_retriever
209 | Labrador_retriever
210 | Chesapeake_Bay_retriever
211 | German_short-haired_pointer
212 | vizsla, Hungarian_pointer
213 | English_setter
214 | Irish_setter, red_setter
215 | Gordon_setter
216 | Brittany_spaniel
217 | clumber, clumber_spaniel
218 | English_springer, English_springer_spaniel
219 | Welsh_springer_spaniel
220 | cocker_spaniel, English_cocker_spaniel, cocker
221 | Sussex_spaniel
222 | Irish_water_spaniel
223 | kuvasz
224 | schipperke
225 | groenendael
226 | malinois
227 | briard
228 | kelpie
229 | komondor
230 | Old_English_sheepdog, bobtail
231 | Shetland_sheepdog, Shetland_sheep_dog, Shetland
232 | collie
233 | Border_collie
234 | Bouvier_des_Flandres, Bouviers_des_Flandres
235 | Rottweiler
236 | German_shepherd, German_shepherd_dog, German_police_dog, alsatian
237 | Doberman, Doberman_pinscher
238 | miniature_pinscher
239 | Greater_Swiss_Mountain_dog
240 | Bernese_mountain_dog
241 | Appenzeller
242 | EntleBucher
243 | boxer
244 | bull_mastiff
245 | Tibetan_mastiff
246 | French_bulldog
247 | Great_Dane
248 | Saint_Bernard, St_Bernard
249 | Eskimo_dog, husky
250 | malamute, malemute, Alaskan_malamute
251 | Siberian_husky
252 | dalmatian, coach_dog, carriage_dog
253 | affenpinscher, monkey_pinscher, monkey_dog
254 | basenji
255 | pug, pug-dog
256 | Leonberg
257 | Newfoundland, Newfoundland_dog
258 | Great_Pyrenees
259 | Samoyed, Samoyede
260 | Pomeranian
261 | chow, chow_chow
262 | keeshond
263 | Brabancon_griffon
264 | Pembroke, Pembroke_Welsh_corgi
265 | Cardigan, Cardigan_Welsh_corgi
266 | toy_poodle
267 | miniature_poodle
268 | standard_poodle
269 | Mexican_hairless
270 | timber_wolf, grey_wolf, gray_wolf, Canis_lupus
271 | white_wolf, Arctic_wolf, Canis_lupus_tundrarum
272 | red_wolf, maned_wolf, Canis_rufus, Canis_niger
273 | coyote, prairie_wolf, brush_wolf, Canis_latrans
274 | dingo, warrigal, warragal, Canis_dingo
275 | dhole, Cuon_alpinus
276 | African_hunting_dog, hyena_dog, Cape_hunting_dog, Lycaon_pictus
277 | hyena, hyaena
278 | red_fox, Vulpes_vulpes
279 | kit_fox, Vulpes_macrotis
280 | Arctic_fox, white_fox, Alopex_lagopus
281 | grey_fox, gray_fox, Urocyon_cinereoargenteus
282 | tabby, tabby_cat
283 | tiger_cat
284 | Persian_cat
285 | Siamese_cat, Siamese
286 | Egyptian_cat
287 | cougar, puma, catamount, mountain_lion, painter, panther, Felis_concolor
288 | lynx, catamount
289 | leopard, Panthera_pardus
290 | snow_leopard, ounce, Panthera_uncia
291 | jaguar, panther, Panthera_onca, Felis_onca
292 | lion, king_of_beasts, Panthera_leo
293 | tiger, Panthera_tigris
294 | cheetah, chetah, Acinonyx_jubatus
295 | brown_bear, bruin, Ursus_arctos
296 | American_black_bear, black_bear, Ursus_americanus, Euarctos_americanus
297 | ice_bear, polar_bear, Ursus_Maritimus, Thalarctos_maritimus
298 | sloth_bear, Melursus_ursinus, Ursus_ursinus
299 | mongoose
300 | meerkat, mierkat
301 | tiger_beetle
302 | ladybug, ladybeetle, lady_beetle, ladybird, ladybird_beetle
303 | ground_beetle, carabid_beetle
304 | long-horned_beetle, longicorn, longicorn_beetle
305 | leaf_beetle, chrysomelid
306 | dung_beetle
307 | rhinoceros_beetle
308 | weevil
309 | fly
310 | bee
311 | ant, emmet, pismire
312 | grasshopper, hopper
313 | cricket
314 | walking_stick, walkingstick, stick_insect
315 | cockroach, roach
316 | mantis, mantid
317 | cicada, cicala
318 | leafhopper
319 | lacewing, lacewing_fly
320 | dragonfly, darning_needle, devil's_darning_needle, sewing_needle, snake_feeder, snake_doctor, mosquito_hawk, skeeter_hawk
321 | damselfly
322 | admiral
323 | ringlet, ringlet_butterfly
324 | monarch, monarch_butterfly, milkweed_butterfly, Danaus_plexippus
325 | cabbage_butterfly
326 | sulphur_butterfly, sulfur_butterfly
327 | lycaenid, lycaenid_butterfly
328 | starfish, sea_star
329 | sea_urchin
330 | sea_cucumber, holothurian
331 | wood_rabbit, cottontail, cottontail_rabbit
332 | hare
333 | Angora, Angora_rabbit
334 | hamster
335 | porcupine, hedgehog
336 | fox_squirrel, eastern_fox_squirrel, Sciurus_niger
337 | marmot
338 | beaver
339 | guinea_pig, Cavia_cobaya
340 | sorrel
341 | zebra
342 | hog, pig, grunter, squealer, Sus_scrofa
343 | wild_boar, boar, Sus_scrofa
344 | warthog
345 | hippopotamus, hippo, river_horse, Hippopotamus_amphibius
346 | ox
347 | water_buffalo, water_ox, Asiatic_buffalo, Bubalus_bubalis
348 | bison
349 | ram, tup
350 | bighorn, bighorn_sheep, cimarron, Rocky_Mountain_bighorn, Rocky_Mountain_sheep, Ovis_canadensis
351 | ibex, Capra_ibex
352 | hartebeest
353 | impala, Aepyceros_melampus
354 | gazelle
355 | Arabian_camel, dromedary, Camelus_dromedarius
356 | llama
357 | weasel
358 | mink
359 | polecat, fitch, foulmart, foumart, Mustela_putorius
360 | black-footed_ferret, ferret, Mustela_nigripes
361 | otter
362 | skunk, polecat, wood_pussy
363 | badger
364 | armadillo
365 | three-toed_sloth, ai, Bradypus_tridactylus
366 | orangutan, orang, orangutang, Pongo_pygmaeus
367 | gorilla, Gorilla_gorilla
368 | chimpanzee, chimp, Pan_troglodytes
369 | gibbon, Hylobates_lar
370 | siamang, Hylobates_syndactylus, Symphalangus_syndactylus
371 | guenon, guenon_monkey
372 | patas, hussar_monkey, Erythrocebus_patas
373 | baboon
374 | macaque
375 | langur
376 | colobus, colobus_monkey
377 | proboscis_monkey, Nasalis_larvatus
378 | marmoset
379 | capuchin, ringtail, Cebus_capucinus
380 | howler_monkey, howler
381 | titi, titi_monkey
382 | spider_monkey, Ateles_geoffroyi
383 | squirrel_monkey, Saimiri_sciureus
384 | Madagascar_cat, ring-tailed_lemur, Lemur_catta
385 | indri, indris, Indri_indri, Indri_brevicaudatus
386 | Indian_elephant, Elephas_maximus
387 | African_elephant, Loxodonta_africana
388 | lesser_panda, red_panda, panda, bear_cat, cat_bear, Ailurus_fulgens
389 | giant_panda, panda, panda_bear, coon_bear, Ailuropoda_melanoleuca
390 | barracouta, snoek
391 | eel
392 | coho, cohoe, coho_salmon, blue_jack, silver_salmon, Oncorhynchus_kisutch
393 | rock_beauty, Holocanthus_tricolor
394 | anemone_fish
395 | sturgeon
396 | gar, garfish, garpike, billfish, Lepisosteus_osseus
397 | lionfish
398 | puffer, pufferfish, blowfish, globefish
399 | abacus
400 | abaya
401 | academic_gown, academic_robe, judge's_robe
402 | accordion, piano_accordion, squeeze_box
403 | acoustic_guitar
404 | aircraft_carrier, carrier, flattop, attack_aircraft_carrier
405 | airliner
406 | airship, dirigible
407 | altar
408 | ambulance
409 | amphibian, amphibious_vehicle
410 | analog_clock
411 | apiary, bee_house
412 | apron
413 | ashcan, trash_can, garbage_can, wastebin, ash_bin, ash-bin, ashbin, dustbin, trash_barrel, trash_bin
414 | assault_rifle, assault_gun
415 | backpack, back_pack, knapsack, packsack, rucksack, haversack
416 | bakery, bakeshop, bakehouse
417 | balance_beam, beam
418 | balloon
419 | ballpoint, ballpoint_pen, ballpen, Biro
420 | Band_Aid
421 | banjo
422 | bannister, banister, balustrade, balusters, handrail
423 | barbell
424 | barber_chair
425 | barbershop
426 | barn
427 | barometer
428 | barrel, cask
429 | barrow, garden_cart, lawn_cart, wheelbarrow
430 | baseball
431 | basketball
432 | bassinet
433 | bassoon
434 | bathing_cap, swimming_cap
435 | bath_towel
436 | bathtub, bathing_tub, bath, tub
437 | beach_wagon, station_wagon, wagon, estate_car, beach_waggon, station_waggon, waggon
438 | beacon, lighthouse, beacon_light, pharos
439 | beaker
440 | bearskin, busby, shako
441 | beer_bottle
442 | beer_glass
443 | bell_cote, bell_cot
444 | bib
445 | bicycle-built-for-two, tandem_bicycle, tandem
446 | bikini, two-piece
447 | binder, ring-binder
448 | binoculars, field_glasses, opera_glasses
449 | birdhouse
450 | boathouse
451 | bobsled, bobsleigh, bob
452 | bolo_tie, bolo, bola_tie, bola
453 | bonnet, poke_bonnet
454 | bookcase
455 | bookshop, bookstore, bookstall
456 | bottlecap
457 | bow
458 | bow_tie, bow-tie, bowtie
459 | brass, memorial_tablet, plaque
460 | brassiere, bra, bandeau
461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty
462 | breastplate, aegis, egis
463 | broom
464 | bucket, pail
465 | buckle
466 | bulletproof_vest
467 | bullet_train, bullet
468 | butcher_shop, meat_market
469 | cab, hack, taxi, taxicab
470 | caldron, cauldron
471 | candle, taper, wax_light
472 | cannon
473 | canoe
474 | can_opener, tin_opener
475 | cardigan
476 | car_mirror
477 | carousel, carrousel, merry-go-round, roundabout, whirligig
478 | carpenter's_kit, tool_kit
479 | carton
480 | car_wheel
481 | cash_machine, cash_dispenser, automated_teller_machine, automatic_teller_machine, automated_teller, automatic_teller, ATM
482 | cassette
483 | cassette_player
484 | castle
485 | catamaran
486 | CD_player
487 | cello, violoncello
488 | cellular_telephone, cellular_phone, cellphone, cell, mobile_phone
489 | chain
490 | chainlink_fence
491 | chain_mail, ring_mail, mail, chain_armor, chain_armour, ring_armor, ring_armour
492 | chain_saw, chainsaw
493 | chest
494 | chiffonier, commode
495 | chime, bell, gong
496 | china_cabinet, china_closet
497 | Christmas_stocking
498 | church, church_building
499 | cinema, movie_theater, movie_theatre, movie_house, picture_palace
500 | cleaver, meat_cleaver, chopper
501 | cliff_dwelling
502 | cloak
503 | clog, geta, patten, sabot
504 | cocktail_shaker
505 | coffee_mug
506 | coffeepot
507 | coil, spiral, volute, whorl, helix
508 | combination_lock
509 | computer_keyboard, keypad
510 | confectionery, confectionary, candy_store
511 | container_ship, containership, container_vessel
512 | convertible
513 | corkscrew, bottle_screw
514 | cornet, horn, trumpet, trump
515 | cowboy_boot
516 | cowboy_hat, ten-gallon_hat
517 | cradle
518 | crane
519 | crash_helmet
520 | crate
521 | crib, cot
522 | Crock_Pot
523 | croquet_ball
524 | crutch
525 | cuirass
526 | dam, dike, dyke
527 | desk
528 | desktop_computer
529 | dial_telephone, dial_phone
530 | diaper, nappy, napkin
531 | digital_clock
532 | digital_watch
533 | dining_table, board
534 | dishrag, dishcloth
535 | dishwasher, dish_washer, dishwashing_machine
536 | disk_brake, disc_brake
537 | dock, dockage, docking_facility
538 | dogsled, dog_sled, dog_sleigh
539 | dome
540 | doormat, welcome_mat
541 | drilling_platform, offshore_rig
542 | drum, membranophone, tympan
543 | drumstick
544 | dumbbell
545 | Dutch_oven
546 | electric_fan, blower
547 | electric_guitar
548 | electric_locomotive
549 | entertainment_center
550 | envelope
551 | espresso_maker
552 | face_powder
553 | feather_boa, boa
554 | file, file_cabinet, filing_cabinet
555 | fireboat
556 | fire_engine, fire_truck
557 | fire_screen, fireguard
558 | flagpole, flagstaff
559 | flute, transverse_flute
560 | folding_chair
561 | football_helmet
562 | forklift
563 | fountain
564 | fountain_pen
565 | four-poster
566 | freight_car
567 | French_horn, horn
568 | frying_pan, frypan, skillet
569 | fur_coat
570 | garbage_truck, dustcart
571 | gasmask, respirator, gas_helmet
572 | gas_pump, gasoline_pump, petrol_pump, island_dispenser
573 | goblet
574 | go-kart
575 | golf_ball
576 | golfcart, golf_cart
577 | gondola
578 | gong, tam-tam
579 | gown
580 | grand_piano, grand
581 | greenhouse, nursery, glasshouse
582 | grille, radiator_grille
583 | grocery_store, grocery, food_market, market
584 | guillotine
585 | hair_slide
586 | hair_spray
587 | half_track
588 | hammer
589 | hamper
590 | hand_blower, blow_dryer, blow_drier, hair_dryer, hair_drier
591 | hand-held_computer, hand-held_microcomputer
592 | handkerchief, hankie, hanky, hankey
593 | hard_disc, hard_disk, fixed_disk
594 | harmonica, mouth_organ, harp, mouth_harp
595 | harp
596 | harvester, reaper
597 | hatchet
598 | holster
599 | home_theater, home_theatre
600 | honeycomb
601 | hook, claw
602 | hoopskirt, crinoline
603 | horizontal_bar, high_bar
604 | horse_cart, horse-cart
605 | hourglass
606 | iPod
607 | iron, smoothing_iron
608 | jack-o'-lantern
609 | jean, blue_jean, denim
610 | jeep, landrover
611 | jersey, T-shirt, tee_shirt
612 | jigsaw_puzzle
613 | jinrikisha, ricksha, rickshaw
614 | joystick
615 | kimono
616 | knee_pad
617 | knot
618 | lab_coat, laboratory_coat
619 | ladle
620 | lampshade, lamp_shade
621 | laptop, laptop_computer
622 | lawn_mower, mower
623 | lens_cap, lens_cover
624 | letter_opener, paper_knife, paperknife
625 | library
626 | lifeboat
627 | lighter, light, igniter, ignitor
628 | limousine, limo
629 | liner, ocean_liner
630 | lipstick, lip_rouge
631 | Loafer
632 | lotion
633 | loudspeaker, speaker, speaker_unit, loudspeaker_system, speaker_system
634 | loupe, jeweler's_loupe
635 | lumbermill, sawmill
636 | magnetic_compass
637 | mailbag, postbag
638 | mailbox, letter_box
639 | maillot
640 | maillot, tank_suit
641 | manhole_cover
642 | maraca
643 | marimba, xylophone
644 | mask
645 | matchstick
646 | maypole
647 | maze, labyrinth
648 | measuring_cup
649 | medicine_chest, medicine_cabinet
650 | megalith, megalithic_structure
651 | microphone, mike
652 | microwave, microwave_oven
653 | military_uniform
654 | milk_can
655 | minibus
656 | miniskirt, mini
657 | minivan
658 | missile
659 | mitten
660 | mixing_bowl
661 | mobile_home, manufactured_home
662 | Model_T
663 | modem
664 | monastery
665 | monitor
666 | moped
667 | mortar
668 | mortarboard
669 | mosque
670 | mosquito_net
671 | motor_scooter, scooter
672 | mountain_bike, all-terrain_bike, off-roader
673 | mountain_tent
674 | mouse, computer_mouse
675 | mousetrap
676 | moving_van
677 | muzzle
678 | nail
679 | neck_brace
680 | necklace
681 | nipple
682 | notebook, notebook_computer
683 | obelisk
684 | oboe, hautboy, hautbois
685 | ocarina, sweet_potato
686 | odometer, hodometer, mileometer, milometer
687 | oil_filter
688 | organ, pipe_organ
689 | oscilloscope, scope, cathode-ray_oscilloscope, CRO
690 | overskirt
691 | oxcart
692 | oxygen_mask
693 | packet
694 | paddle, boat_paddle
695 | paddlewheel, paddle_wheel
696 | padlock
697 | paintbrush
698 | pajama, pyjama, pj's, jammies
699 | palace
700 | panpipe, pandean_pipe, syrinx
701 | paper_towel
702 | parachute, chute
703 | parallel_bars, bars
704 | park_bench
705 | parking_meter
706 | passenger_car, coach, carriage
707 | patio, terrace
708 | pay-phone, pay-station
709 | pedestal, plinth, footstall
710 | pencil_box, pencil_case
711 | pencil_sharpener
712 | perfume, essence
713 | Petri_dish
714 | photocopier
715 | pick, plectrum, plectron
716 | pickelhaube
717 | picket_fence, paling
718 | pickup, pickup_truck
719 | pier
720 | piggy_bank, penny_bank
721 | pill_bottle
722 | pillow
723 | ping-pong_ball
724 | pinwheel
725 | pirate, pirate_ship
726 | pitcher, ewer
727 | plane, carpenter's_plane, woodworking_plane
728 | planetarium
729 | plastic_bag
730 | plate_rack
731 | plow, plough
732 | plunger, plumber's_helper
733 | Polaroid_camera, Polaroid_Land_camera
734 | pole
735 | police_van, police_wagon, paddy_wagon, patrol_wagon, wagon, black_Maria
736 | poncho
737 | pool_table, billiard_table, snooker_table
738 | pop_bottle, soda_bottle
739 | pot, flowerpot
740 | potter's_wheel
741 | power_drill
742 | prayer_rug, prayer_mat
743 | printer
744 | prison, prison_house
745 | projectile, missile
746 | projector
747 | puck, hockey_puck
748 | punching_bag, punch_bag, punching_ball, punchball
749 | purse
750 | quill, quill_pen
751 | quilt, comforter, comfort, puff
752 | racer, race_car, racing_car
753 | racket, racquet
754 | radiator
755 | radio, wireless
756 | radio_telescope, radio_reflector
757 | rain_barrel
758 | recreational_vehicle, RV, R.V.
759 | reel
760 | reflex_camera
761 | refrigerator, icebox
762 | remote_control, remote
763 | restaurant, eating_house, eating_place, eatery
764 | revolver, six-gun, six-shooter
765 | rifle
766 | rocking_chair, rocker
767 | rotisserie
768 | rubber_eraser, rubber, pencil_eraser
769 | rugby_ball
770 | rule, ruler
771 | running_shoe
772 | safe
773 | safety_pin
774 | saltshaker, salt_shaker
775 | sandal
776 | sarong
777 | sax, saxophone
778 | scabbard
779 | scale, weighing_machine
780 | school_bus
781 | schooner
782 | scoreboard
783 | screen, CRT_screen
784 | screw
785 | screwdriver
786 | seat_belt, seatbelt
787 | sewing_machine
788 | shield, buckler
789 | shoe_shop, shoe-shop, shoe_store
790 | shoji
791 | shopping_basket
792 | shopping_cart
793 | shovel
794 | shower_cap
795 | shower_curtain
796 | ski
797 | ski_mask
798 | sleeping_bag
799 | slide_rule, slipstick
800 | sliding_door
801 | slot, one-armed_bandit
802 | snorkel
803 | snowmobile
804 | snowplow, snowplough
805 | soap_dispenser
806 | soccer_ball
807 | sock
808 | solar_dish, solar_collector, solar_furnace
809 | sombrero
810 | soup_bowl
811 | space_bar
812 | space_heater
813 | space_shuttle
814 | spatula
815 | speedboat
816 | spider_web, spider's_web
817 | spindle
818 | sports_car, sport_car
819 | spotlight, spot
820 | stage
821 | steam_locomotive
822 | steel_arch_bridge
823 | steel_drum
824 | stethoscope
825 | stole
826 | stone_wall
827 | stopwatch, stop_watch
828 | stove
829 | strainer
830 | streetcar, tram, tramcar, trolley, trolley_car
831 | stretcher
832 | studio_couch, day_bed
833 | stupa, tope
834 | submarine, pigboat, sub, U-boat
835 | suit, suit_of_clothes
836 | sundial
837 | sunglass
838 | sunglasses, dark_glasses, shades
839 | sunscreen, sunblock, sun_blocker
840 | suspension_bridge
841 | swab, swob, mop
842 | sweatshirt
843 | swimming_trunks, bathing_trunks
844 | swing
845 | switch, electric_switch, electrical_switch
846 | syringe
847 | table_lamp
848 | tank, army_tank, armored_combat_vehicle, armoured_combat_vehicle
849 | tape_player
850 | teapot
851 | teddy, teddy_bear
852 | television, television_system
853 | tennis_ball
854 | thatch, thatched_roof
855 | theater_curtain, theatre_curtain
856 | thimble
857 | thresher, thrasher, threshing_machine
858 | throne
859 | tile_roof
860 | toaster
861 | tobacco_shop, tobacconist_shop, tobacconist
862 | toilet_seat
863 | torch
864 | totem_pole
865 | tow_truck, tow_car, wrecker
866 | toyshop
867 | tractor
868 | trailer_truck, tractor_trailer, trucking_rig, rig, articulated_lorry, semi
869 | tray
870 | trench_coat
871 | tricycle, trike, velocipede
872 | trimaran
873 | tripod
874 | triumphal_arch
875 | trolleybus, trolley_coach, trackless_trolley
876 | trombone
877 | tub, vat
878 | turnstile
879 | typewriter_keyboard
880 | umbrella
881 | unicycle, monocycle
882 | upright, upright_piano
883 | vacuum, vacuum_cleaner
884 | vase
885 | vault
886 | velvet
887 | vending_machine
888 | vestment
889 | viaduct
890 | violin, fiddle
891 | volleyball
892 | waffle_iron
893 | wall_clock
894 | wallet, billfold, notecase, pocketbook
895 | wardrobe, closet, press
896 | warplane, military_plane
897 | washbasin, handbasin, washbowl, lavabo, wash-hand_basin
898 | washer, automatic_washer, washing_machine
899 | water_bottle
900 | water_jug
901 | water_tower
902 | whiskey_jug
903 | whistle
904 | wig
905 | window_screen
906 | window_shade
907 | Windsor_tie
908 | wine_bottle
909 | wing
910 | wok
911 | wooden_spoon
912 | wool, woolen, woollen
913 | worm_fence, snake_fence, snake-rail_fence, Virginia_fence
914 | wreck
915 | yawl
916 | yurt
917 | web_site, website, internet_site, site
918 | comic_book
919 | crossword_puzzle, crossword
920 | street_sign
921 | traffic_light, traffic_signal, stoplight
922 | book_jacket, dust_cover, dust_jacket, dust_wrapper
923 | menu
924 | plate
925 | guacamole
926 | consomme
927 | hot_pot, hotpot
928 | trifle
929 | ice_cream, icecream
930 | ice_lolly, lolly, lollipop, popsicle
931 | French_loaf
932 | bagel, beigel
933 | pretzel
934 | cheeseburger
935 | hotdog, hot_dog, red_hot
936 | mashed_potato
937 | head_cabbage
938 | broccoli
939 | cauliflower
940 | zucchini, courgette
941 | spaghetti_squash
942 | acorn_squash
943 | butternut_squash
944 | cucumber, cuke
945 | artichoke, globe_artichoke
946 | bell_pepper
947 | cardoon
948 | mushroom
949 | Granny_Smith
950 | strawberry
951 | orange
952 | lemon
953 | fig
954 | pineapple, ananas
955 | banana
956 | jackfruit, jak, jack
957 | custard_apple
958 | pomegranate
959 | hay
960 | carbonara
961 | chocolate_sauce, chocolate_syrup
962 | dough
963 | meat_loaf, meatloaf
964 | pizza, pizza_pie
965 | potpie
966 | burrito
967 | red_wine
968 | espresso
969 | cup
970 | eggnog
971 | alp
972 | bubble
973 | cliff, drop, drop-off
974 | coral_reef
975 | geyser
976 | lakeside, lakeshore
977 | promontory, headland, head, foreland
978 | sandbar, sand_bar
979 | seashore, coast, seacoast, sea-coast
980 | valley, vale
981 | volcano
982 | ballplayer, baseball_player
983 | groom, bridegroom
984 | scuba_diver
985 | rapeseed
986 | daisy
987 | yellow_lady's_slipper, yellow_lady-slipper, Cypripedium_calceolus, Cypripedium_parviflorum
988 | corn
989 | acorn
990 | hip, rose_hip, rosehip
991 | buckeye, horse_chestnut, conker
992 | coral_fungus
993 | agaric
994 | gyromitra
995 | stinkhorn, carrion_fungus
996 | earthstar
997 | hen-of-the-woods, hen_of_the_woods, Polyporus_frondosus, Grifola_frondosa
998 | bolete
999 | ear, spike, capitulum
1000 | toilet_tissue, toilet_paper, bathroom_tissue
1001 |
--------------------------------------------------------------------------------
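
Note: the file above maps zero-based ImageNet-1k class index i to line i + 1, which is what the sanity check in the weight-porting notebook below relies on. A minimal lookup sketch (assumes it is run from the notebooks/ directory):

```python
# Index 385 corresponds to line 386 of ilsvrc2012_wordnet_lemmas.txt.
with open("ilsvrc2012_wordnet_lemmas.txt", "r") as f:
    imagenet_int_to_str = [line.rstrip() for line in f]

assert imagenet_int_to_str[385] == "Indian_elephant, Elephas_maximus"
```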
/swin-transformers-tf/notebooks/weight-porting.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "40180afb",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import timm\n",
11 | "\n",
12 | "import tensorflow as tf\n",
13 | "import numpy as np"
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": 2,
19 | "id": "50ef790b",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import sys\n",
24 | "\n",
25 | "sys.path.append(\"..\")\n",
26 | "\n",
27 | "from swins import SwinTransformer\n",
28 | "from swins.layers import *\n",
29 | "from swins.blocks import *\n",
30 | "from utils import helpers"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 3,
36 | "id": "3d2ec87f",
37 | "metadata": {},
38 | "outputs": [],
39 | "source": [
40 | "cfg = dict(\n",
41 | " patch_size=4,\n",
42 | " window_size=7,\n",
43 | " embed_dim=96,\n",
44 | " depths=(2, 2, 6, 2),\n",
45 | " num_heads=(3, 6, 12, 24),\n",
46 | ")"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 4,
52 | "id": "59542c5f",
53 | "metadata": {},
54 | "outputs": [
55 | {
56 | "name": "stderr",
57 | "output_type": "stream",
58 | "text": [
59 | "2022-05-08 18:23:21.505079: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
60 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
61 | ]
62 | },
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "Swin TF model created.\n"
68 | ]
69 | }
70 | ],
71 | "source": [
72 | "swin_tiny_patch4_window7_224_tf = SwinTransformer(\n",
73 | " name=\"swin_tiny_patch4_window7_224\", **cfg\n",
74 | ")\n",
75 | "random_tensor = tf.random.normal((2, 224, 224, 3))\n",
76 | "outputs = swin_tiny_patch4_window7_224_tf(random_tensor, training=False)\n",
77 | "print(\"Swin TF model created.\")"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": 5,
83 | "id": "d29d6661",
84 | "metadata": {},
85 | "outputs": [
86 | {
87 | "name": "stderr",
88 | "output_type": "stream",
89 | "text": [
90 | "/Users/sayakpaul/.local/bin/.virtualenvs/pytorch/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2157.)\n",
91 | " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n"
92 | ]
93 | },
94 | {
95 | "name": "stdout",
96 | "output_type": "stream",
97 | "text": [
98 | "Swin PT model created.\n",
99 | "Number of parameters:\n",
100 | "28.288354\n"
101 | ]
102 | }
103 | ],
104 | "source": [
105 | "swin_tiny_patch4_window7_224_pt = timm.create_model(\n",
106 | " model_name=\"swin_tiny_patch4_window7_224\", pretrained=True\n",
107 | ")\n",
108 | "print(\"Swin PT model created.\")\n",
109 | "print(\"Number of parameters:\")\n",
110 | "num_params = sum(p.numel() for p in swin_tiny_patch4_window7_224_pt.parameters())\n",
111 | "print(num_params / 1e6)\n",
112 | "\n",
113 | "assert swin_tiny_patch4_window7_224_tf.count_params() == num_params"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 6,
119 | "id": "50ee5556",
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "state_dict = swin_tiny_patch4_window7_224_pt.state_dict()\n",
124 | "np_state_dict = {k: state_dict[k].numpy() for k in state_dict}"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": 7,
130 | "id": "cbbc3080",
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "# Projection.\n",
135 | "swin_tiny_patch4_window7_224_tf.projection.layers[0] = helpers.modify_tf_block(\n",
136 | " swin_tiny_patch4_window7_224_tf.projection.layers[0],\n",
137 | " np_state_dict[\"patch_embed.proj.weight\"],\n",
138 | " np_state_dict[\"patch_embed.proj.bias\"],\n",
139 | ")\n",
140 | "swin_tiny_patch4_window7_224_tf.projection.layers[2] = helpers.modify_tf_block(\n",
141 | " swin_tiny_patch4_window7_224_tf.projection.layers[2],\n",
142 | " np_state_dict[\"patch_embed.norm.weight\"],\n",
143 | " np_state_dict[\"patch_embed.norm.bias\"],\n",
144 | ")"
145 | ]
146 | },
147 | {
148 | "cell_type": "code",
149 | "execution_count": 8,
150 | "id": "e80ad36b",
151 | "metadata": {},
152 | "outputs": [],
153 | "source": [
154 | "# Layer norm layers.\n",
155 | "ln_idx = -2\n",
156 | "swin_tiny_patch4_window7_224_tf.layers[ln_idx] = helpers.modify_tf_block(\n",
157 | " swin_tiny_patch4_window7_224_tf.layers[ln_idx],\n",
158 | " np_state_dict[\"norm.weight\"],\n",
159 | " np_state_dict[\"norm.bias\"],\n",
160 | ")\n",
161 | "\n",
162 | "# Head layers.\n",
163 | "head_layer = swin_tiny_patch4_window7_224_tf.get_layer(\"classification_head\")\n",
164 | "swin_tiny_patch4_window7_224_tf.layers[-1] = helpers.modify_tf_block(\n",
165 | " head_layer,\n",
166 | " np_state_dict[\"head.weight\"],\n",
167 | " np_state_dict[\"head.bias\"],\n",
168 | ")"
169 | ]
170 | },
171 | {
172 | "cell_type": "code",
173 | "execution_count": 9,
174 | "id": "03ba0496",
175 | "metadata": {},
176 | "outputs": [
177 | {
178 | "data": {
179 | "text/plain": [
180 | "['layers.0.blocks.0.norm1.weight',\n",
181 | " 'layers.0.blocks.0.norm1.bias',\n",
182 | " 'layers.0.blocks.0.attn.relative_position_bias_table',\n",
183 | " 'layers.0.blocks.0.attn.relative_position_index',\n",
184 | " 'layers.0.blocks.0.attn.qkv.weight',\n",
185 | " 'layers.0.blocks.0.attn.qkv.bias',\n",
186 | " 'layers.0.blocks.0.attn.proj.weight',\n",
187 | " 'layers.0.blocks.0.attn.proj.bias',\n",
188 | " 'layers.0.blocks.0.norm2.weight',\n",
189 | " 'layers.0.blocks.0.norm2.bias',\n",
190 | " 'layers.0.blocks.0.mlp.fc1.weight',\n",
191 | " 'layers.0.blocks.0.mlp.fc1.bias',\n",
192 | " 'layers.0.blocks.0.mlp.fc2.weight',\n",
193 | " 'layers.0.blocks.0.mlp.fc2.bias',\n",
194 | " 'layers.0.blocks.1.attn_mask',\n",
195 | " 'layers.0.blocks.1.norm1.weight',\n",
196 | " 'layers.0.blocks.1.norm1.bias',\n",
197 | " 'layers.0.blocks.1.attn.relative_position_bias_table',\n",
198 | " 'layers.0.blocks.1.attn.relative_position_index',\n",
199 | " 'layers.0.blocks.1.attn.qkv.weight',\n",
200 | " 'layers.0.blocks.1.attn.qkv.bias',\n",
201 | " 'layers.0.blocks.1.attn.proj.weight',\n",
202 | " 'layers.0.blocks.1.attn.proj.bias',\n",
203 | " 'layers.0.blocks.1.norm2.weight',\n",
204 | " 'layers.0.blocks.1.norm2.bias',\n",
205 | " 'layers.0.blocks.1.mlp.fc1.weight',\n",
206 | " 'layers.0.blocks.1.mlp.fc1.bias',\n",
207 | " 'layers.0.blocks.1.mlp.fc2.weight',\n",
208 | " 'layers.0.blocks.1.mlp.fc2.bias',\n",
209 | " 'layers.0.downsample.reduction.weight',\n",
210 | " 'layers.0.downsample.norm.weight',\n",
211 | " 'layers.0.downsample.norm.bias']"
212 | ]
213 | },
214 | "execution_count": 9,
215 | "metadata": {},
216 | "output_type": "execute_result"
217 | }
218 | ],
219 | "source": [
220 | "list(filter(lambda x: \"layers.0\" in x, np_state_dict.keys()))"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 10,
226 | "id": "5c7af809",
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "def modify_swin_blocks(pt_weights_prefix, tf_block):\n",
231 | " # Patch merging.\n",
232 | " for layer in tf_block:\n",
233 | " if isinstance(layer, PatchMerging):\n",
234 | " patch_merging_idx = f\"{pt_weights_prefix}.downsample\"\n",
235 | "\n",
236 | " layer.reduction = helpers.modify_tf_block(\n",
237 | " layer.reduction,\n",
238 | " np_state_dict[f\"{patch_merging_idx}.reduction.weight\"],\n",
239 | " )\n",
240 | " layer.norm = helpers.modify_tf_block(\n",
241 | " layer.norm,\n",
242 | " np_state_dict[f\"{patch_merging_idx}.norm.weight\"],\n",
243 | " np_state_dict[f\"{patch_merging_idx}.norm.bias\"],\n",
244 | " )\n",
245 | "\n",
246 | " # Swin layers.\n",
247 | " common_prefix = f\"{pt_weights_prefix}.blocks\"\n",
248 | " block_idx = 0\n",
249 | "\n",
250 | " for outer_layer in tf_block:\n",
251 | "\n",
252 | " layernorm_idx = 1\n",
253 | " mlp_layer_idx = 1\n",
254 | "\n",
255 | " if isinstance(outer_layer, SwinTransformerBlock):\n",
256 | " for inner_layer in outer_layer.layers:\n",
257 | "\n",
258 | " # Layer norm.\n",
259 | " if isinstance(inner_layer, tf.keras.layers.LayerNormalization):\n",
260 | " layer_norm_prefix = (\n",
261 | " f\"{common_prefix}.{block_idx}.norm{layernorm_idx}\"\n",
262 | " )\n",
263 | " inner_layer.gamma.assign(\n",
264 | " tf.Variable(np_state_dict[f\"{layer_norm_prefix}.weight\"])\n",
265 | " )\n",
266 | " inner_layer.beta.assign(\n",
267 | " tf.Variable(np_state_dict[f\"{layer_norm_prefix}.bias\"])\n",
268 | " )\n",
269 | " layernorm_idx += 1\n",
270 | "\n",
271 | " # Windown attention.\n",
272 | " elif isinstance(inner_layer, WindowAttention):\n",
273 | " attn_prefix = f\"{common_prefix}.{block_idx}.attn\"\n",
274 | "\n",
275 | " # Relative position.\n",
276 | " inner_layer.relative_position_bias_table = helpers.modify_tf_block(\n",
277 | " inner_layer.relative_position_bias_table,\n",
278 | " np_state_dict[f\"{attn_prefix}.relative_position_bias_table\"],\n",
279 | " )\n",
280 | " inner_layer.relative_position_index = helpers.modify_tf_block(\n",
281 | " inner_layer.relative_position_index,\n",
282 | " np_state_dict[f\"{attn_prefix}.relative_position_index\"],\n",
283 | " )\n",
284 | "\n",
285 | " # QKV.\n",
286 | " inner_layer.qkv = helpers.modify_tf_block(\n",
287 | " inner_layer.qkv,\n",
288 | " np_state_dict[f\"{attn_prefix}.qkv.weight\"],\n",
289 | " np_state_dict[f\"{attn_prefix}.qkv.bias\"],\n",
290 | " )\n",
291 | "\n",
292 | " # Projection.\n",
293 | " inner_layer.proj = helpers.modify_tf_block(\n",
294 | " inner_layer.proj,\n",
295 | " np_state_dict[f\"{attn_prefix}.proj.weight\"],\n",
296 | " np_state_dict[f\"{attn_prefix}.proj.bias\"],\n",
297 | " )\n",
298 | "\n",
299 | " # MLP.\n",
300 | " elif isinstance(inner_layer, tf.keras.Model):\n",
301 | " mlp_prefix = f\"{common_prefix}.{block_idx}.mlp\"\n",
302 | " for mlp_layer in inner_layer.layers:\n",
303 | " if isinstance(mlp_layer, tf.keras.layers.Dense):\n",
304 | " mlp_layer = helpers.modify_tf_block(\n",
305 | " mlp_layer,\n",
306 | " np_state_dict[f\"{mlp_prefix}.fc{mlp_layer_idx}.weight\"],\n",
307 | " np_state_dict[f\"{mlp_prefix}.fc{mlp_layer_idx}.bias\"],\n",
308 | " )\n",
309 | " mlp_layer_idx += 1\n",
310 | "\n",
311 | " block_idx += 1\n",
312 | " return tf_block"
313 | ]
314 | },
315 | {
316 | "cell_type": "code",
317 | "execution_count": 11,
318 | "id": "d947e04d",
319 | "metadata": {},
320 | "outputs": [],
321 | "source": [
322 | "_ = modify_swin_blocks(\n",
323 | " \"layers.0\",\n",
324 | " swin_tiny_patch4_window7_224_tf.layers[2].layers,\n",
325 | ")"
326 | ]
327 | },
328 | {
329 | "cell_type": "code",
330 | "execution_count": 12,
331 | "id": "cdb9c2e1",
332 | "metadata": {},
333 | "outputs": [],
334 | "source": [
335 | "tf_block = swin_tiny_patch4_window7_224_tf.layers[2].layers\n",
336 | "pt_weights_prefix = \"layers.0\"\n",
337 | "\n",
338 | "# Patch merging.\n",
339 | "for layer in tf_block:\n",
340 | " if isinstance(layer, PatchMerging):\n",
341 | " patch_merging_idx = f\"{pt_weights_prefix}.downsample\"\n",
342 | " np.testing.assert_allclose(\n",
343 | " np_state_dict[f\"{patch_merging_idx}.reduction.weight\"].transpose(),\n",
344 | " layer.reduction.kernel.numpy(),\n",
345 | " )\n",
346 | " np.testing.assert_allclose(\n",
347 | " np_state_dict[f\"{patch_merging_idx}.norm.weight\"], layer.norm.gamma.numpy()\n",
348 | " )\n",
349 | " np.testing.assert_allclose(\n",
350 | " np_state_dict[f\"{patch_merging_idx}.norm.bias\"], layer.norm.beta.numpy()\n",
351 | " )\n",
352 | "\n",
353 | "# Swin layers.\n",
354 | "common_prefix = f\"{pt_weights_prefix}.blocks\"\n",
355 | "block_idx = 0\n",
356 | "\n",
357 | "for outer_layer in tf_block:\n",
358 | "\n",
359 | " layernorm_idx = 1\n",
360 | " mlp_layer_idx = 1\n",
361 | "\n",
362 | " if isinstance(outer_layer, SwinTransformerBlock):\n",
363 | " for inner_layer in outer_layer.layers:\n",
364 | "\n",
365 | " # Layer norm.\n",
366 | " if isinstance(inner_layer, tf.keras.layers.LayerNormalization):\n",
367 | " layer_norm_prefix = f\"{common_prefix}.{block_idx}.norm{layernorm_idx}\"\n",
368 | " np.testing.assert_allclose(\n",
369 | " np_state_dict[f\"{layer_norm_prefix}.weight\"],\n",
370 | " inner_layer.gamma.numpy(),\n",
371 | " )\n",
372 | " np.testing.assert_allclose(\n",
373 | " np_state_dict[f\"{layer_norm_prefix}.bias\"], inner_layer.beta.numpy()\n",
374 | " )\n",
375 | " layernorm_idx += 1\n",
376 | "\n",
377 | " # Windown attention.\n",
378 | " elif isinstance(inner_layer, WindowAttention):\n",
379 | " attn_prefix = f\"{common_prefix}.{block_idx}.attn\"\n",
380 | "\n",
381 | " # Relative position.\n",
382 | " np.testing.assert_allclose(\n",
383 | " np_state_dict[f\"{attn_prefix}.relative_position_bias_table\"],\n",
384 | " inner_layer.relative_position_bias_table.numpy(),\n",
385 | " )\n",
386 | "\n",
387 | " np.testing.assert_allclose(\n",
388 | " np_state_dict[f\"{attn_prefix}.relative_position_index\"],\n",
389 | " inner_layer.relative_position_index.numpy(),\n",
390 | " )\n",
391 | "\n",
392 | " # QKV.\n",
393 | " np.testing.assert_allclose(\n",
394 | " np_state_dict[f\"{attn_prefix}.qkv.weight\"].transpose(),\n",
395 | " inner_layer.qkv.kernel.numpy(),\n",
396 | " )\n",
397 | " np.testing.assert_allclose(\n",
398 | " np_state_dict[f\"{attn_prefix}.qkv.bias\"],\n",
399 | " inner_layer.qkv.bias.numpy(),\n",
400 | " )\n",
401 | "\n",
402 | " # Projection.\n",
403 | " np.testing.assert_allclose(\n",
404 | " np_state_dict[f\"{attn_prefix}.proj.weight\"].transpose(),\n",
405 | " inner_layer.proj.kernel.numpy(),\n",
406 | " )\n",
407 | " np.testing.assert_allclose(\n",
408 | " np_state_dict[f\"{attn_prefix}.proj.bias\"],\n",
409 | " inner_layer.proj.bias.numpy(),\n",
410 | " )\n",
411 | "\n",
412 | " # MLP.\n",
413 | " elif isinstance(inner_layer, tf.keras.Model):\n",
414 | " mlp_prefix = f\"{common_prefix}.{block_idx}.mlp\"\n",
415 | " for mlp_layer in inner_layer.layers:\n",
416 | " if isinstance(mlp_layer, tf.keras.layers.Dense):\n",
417 | " np.testing.assert_allclose(\n",
418 | " np_state_dict[\n",
419 | " f\"{mlp_prefix}.fc{mlp_layer_idx}.weight\"\n",
420 | " ].transpose(),\n",
421 | " mlp_layer.kernel.numpy(),\n",
422 | " )\n",
423 | " np.testing.assert_allclose(\n",
424 | " np_state_dict[f\"{mlp_prefix}.fc{mlp_layer_idx}.bias\"],\n",
425 | " mlp_layer.bias.numpy(),\n",
426 | " )\n",
427 | "\n",
428 | " mlp_layer_idx += 1\n",
429 | "\n",
430 | " block_idx += 1"
431 | ]
432 | },
433 | {
434 | "cell_type": "code",
435 | "execution_count": 13,
436 | "id": "9743e538",
437 | "metadata": {},
438 | "outputs": [],
439 | "source": [
440 | "for i in range(len(cfg[\"depths\"])):\n",
441 | " _ = modify_swin_blocks(\n",
442 | " f\"layers.{i}\",\n",
443 | " swin_tiny_patch4_window7_224_tf.layers[i+2].layers,\n",
444 | " )"
445 | ]
446 | },
447 | {
448 | "cell_type": "code",
449 | "execution_count": 14,
450 | "id": "71d471b3",
451 | "metadata": {},
452 | "outputs": [],
453 | "source": [
454 | "import requests\n",
455 | "from PIL import Image\n",
456 | "from io import BytesIO\n",
457 | "\n",
458 | "import matplotlib.pyplot as plt"
459 | ]
460 | },
461 | {
462 | "cell_type": "code",
463 | "execution_count": 15,
464 | "id": "672266c0",
465 | "metadata": {},
466 | "outputs": [],
467 | "source": [
468 | "input_resolution = 224\n",
469 | "\n",
470 | "crop_layer = tf.keras.layers.CenterCrop(input_resolution, input_resolution)\n",
471 | "norm_layer = tf.keras.layers.Normalization(\n",
472 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n",
473 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n",
474 | ")\n",
475 | "\n",
476 | "\n",
477 | "def preprocess_image(image, size=input_resolution):\n",
478 | " image = np.array(image)\n",
479 | " image_resized = tf.expand_dims(image, 0)\n",
480 | " resize_size = int((256 / 224) * size)\n",
481 | " image_resized = tf.image.resize(\n",
482 | " image_resized, (resize_size, resize_size), method=\"bicubic\"\n",
483 | " )\n",
484 | " image_resized = crop_layer(image_resized)\n",
485 | " return norm_layer(image_resized).numpy()\n",
486 | "\n",
487 | "\n",
488 | "def load_image_from_url(url):\n",
489 | " # Credit: Willi Gierke\n",
490 | " response = requests.get(url)\n",
491 | " image = Image.open(BytesIO(response.content))\n",
492 | " preprocessed_image = preprocess_image(image)\n",
493 | " return image, preprocessed_image"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "execution_count": 16,
499 | "id": "c3f28a73",
500 | "metadata": {},
501 | "outputs": [],
502 | "source": [
503 | "# !wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt"
504 | ]
505 | },
506 | {
507 | "cell_type": "code",
508 | "execution_count": 17,
509 | "id": "43627c4d",
510 | "metadata": {},
511 | "outputs": [],
512 | "source": [
513 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n",
514 | " lines = f.readlines()\n",
515 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n",
516 | "\n",
517 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n",
518 | "image, preprocessed_image = load_image_from_url(img_url)"
519 | ]
520 | },
521 | {
522 | "cell_type": "code",
523 | "execution_count": 18,
524 | "id": "7adfa5e2",
525 | "metadata": {},
526 | "outputs": [],
527 | "source": [
528 | "predictions = swin_tiny_patch4_window7_224_tf.predict(preprocessed_image)\n",
529 | "logits = predictions[0]\n",
530 | "predicted_label = imagenet_int_to_str[int(np.argmax(logits))]\n",
531 | "expected_label = \"Indian_elephant, Elephas_maximus\"\n",
532 | "assert (\n",
533 | " predicted_label == expected_label\n",
534 | "), f\"Expected {expected_label} but was {predicted_label}\""
535 | ]
536 | },
537 | {
538 | "cell_type": "code",
539 | "execution_count": 19,
540 | "id": "3cb44e4e",
541 | "metadata": {},
542 | "outputs": [
543 | {
544 | "data": {
545 | "text/plain": [
546 | "dict_keys(['swin_stage_0', 'swin_stage_1', 'swin_stage_2', 'swin_stage_3'])"
547 | ]
548 | },
549 | "execution_count": 19,
550 | "metadata": {},
551 | "output_type": "execute_result"
552 | }
553 | ],
554 | "source": [
555 | "all_attn_scores = swin_tiny_patch4_window7_224_tf.get_attention_scores(\n",
556 | " preprocessed_image\n",
557 | ")\n",
558 | "all_attn_scores.keys()"
559 | ]
560 | },
561 | {
562 | "cell_type": "code",
563 | "execution_count": 20,
564 | "id": "7a1244ba",
565 | "metadata": {},
566 | "outputs": [
567 | {
568 | "data": {
569 | "text/plain": [
570 | "dict_keys(['swin_block_0', 'swin_block_1'])"
571 | ]
572 | },
573 | "execution_count": 20,
574 | "metadata": {},
575 | "output_type": "execute_result"
576 | }
577 | ],
578 | "source": [
579 | "all_attn_scores[\"swin_stage_3\"].keys()"
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": 21,
585 | "id": "a6d1b5c3",
586 | "metadata": {},
587 | "outputs": [
588 | {
589 | "data": {
590 | "text/plain": [
591 | "TensorShape([1, 24, 49, 49])"
592 | ]
593 | },
594 | "execution_count": 21,
595 | "metadata": {},
596 | "output_type": "execute_result"
597 | }
598 | ],
599 | "source": [
600 | "all_attn_scores[\"swin_stage_3\"][\"swin_block_0\"].shape"
601 | ]
602 | },
603 | {
604 | "cell_type": "code",
605 | "execution_count": 22,
606 | "id": "f69a8d93",
607 | "metadata": {},
608 | "outputs": [
609 | {
610 | "name": "stderr",
611 | "output_type": "stream",
612 | "text": [
613 | "2022-05-08 18:23:42.809960: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n",
614 | "WARNING:absl:Found untraced functions such as layer_normalization_5_layer_call_fn, layer_normalization_5_layer_call_and_return_conditional_losses, dense_4_layer_call_fn, dense_4_layer_call_and_return_conditional_losses, layer_normalization_10_layer_call_fn while saving (showing 5 of 108). These functions will not be directly callable after loading.\n"
615 | ]
616 | },
617 | {
618 | "name": "stdout",
619 | "output_type": "stream",
620 | "text": [
621 | "INFO:tensorflow:Assets written to: gs://swin-tf/swin_tiny_patch4_window7_224_tf/assets\n"
622 | ]
623 | },
624 | {
625 | "name": "stderr",
626 | "output_type": "stream",
627 | "text": [
628 | "INFO:tensorflow:Assets written to: gs://swin-tf/swin_tiny_patch4_window7_224_tf/assets\n"
629 | ]
630 | }
631 | ],
632 | "source": [
633 | "swin_tiny_patch4_window7_224_tf.save(\"gs://swin-tf/swin_tiny_patch4_window7_224_tf\")"
634 | ]
635 | }
636 | ],
637 | "metadata": {
638 | "kernelspec": {
639 | "display_name": "Python 3 (ipykernel)",
640 | "language": "python",
641 | "name": "python3"
642 | },
643 | "language_info": {
644 | "codemirror_mode": {
645 | "name": "ipython",
646 | "version": 3
647 | },
648 | "file_extension": ".py",
649 | "mimetype": "text/x-python",
650 | "name": "python",
651 | "nbconvert_exporter": "python",
652 | "pygments_lexer": "ipython3",
653 | "version": "3.8.2"
654 | }
655 | },
656 | "nbformat": 4,
657 | "nbformat_minor": 5
658 | }
659 |
--------------------------------------------------------------------------------
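
Note: the porting notebook above relies on `helpers.modify_tf_block` from `utils/helpers.py`, which is not shown in this part of the dump. Judging by the transpose checks in the verification cell, its effect on Dense layers is presumably equivalent to the sketch below; treat this as an assumption about the helper, not its actual implementation:

```python
from typing import Optional

import numpy as np
import tensorflow as tf


def copy_dense_from_torch(
    tf_dense: tf.keras.layers.Dense,
    pt_weight: np.ndarray,
    pt_bias: Optional[np.ndarray] = None,
) -> tf.keras.layers.Dense:
    # PyTorch nn.Linear stores weights as (out_features, in_features);
    # Keras Dense kernels are (in_features, out_features), hence the transpose.
    tf_dense.kernel.assign(pt_weight.transpose())
    if pt_bias is not None:
        tf_dense.bias.assign(pt_bias)
    return tf_dense
```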
/swin-transformers-tf/requirements.txt:
--------------------------------------------------------------------------------
1 | # timm needs to be installed from source.
2 | torch==1.11.0
3 | tensorflow==2.8.0
4 | ml_collections==0.1.1
--------------------------------------------------------------------------------
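
Note: since `timm` is not pinned above, a source install along the lines of `pip install git+https://github.com/rwightman/pytorch-image-models` is the usual route; the exact revision the notebooks were run against is not recorded in this file.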
/swin-transformers-tf/swins/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import SwinTransformer
2 |
--------------------------------------------------------------------------------
/swin-transformers-tf/swins/blocks/__init__.py:
--------------------------------------------------------------------------------
1 | from .mlp import mlp_block
2 | from .stage_block import BasicLayer
3 | from .swin_transformer_block import SwinTransformerBlock
4 |
--------------------------------------------------------------------------------
/swin-transformers-tf/swins/blocks/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import tensorflow as tf
4 | from tensorflow import keras
5 | from tensorflow.keras import layers
6 |
7 |
8 | def mlp_block(dropout_rate: float, hidden_units: List[int], name: str = "mlp"):
9 | """FFN for a Transformer block."""
10 | ffn = keras.Sequential(name=name)
11 |     for idx, units in enumerate(hidden_units):
12 | ffn.add(
13 | layers.Dense(
14 | units,
15 | activation=tf.nn.gelu if idx == 0 else None,
16 | bias_initializer=keras.initializers.RandomNormal(stddev=1e-6),
17 | )
18 | )
19 | ffn.add(layers.Dropout(dropout_rate))
20 | return ffn
21 |
--------------------------------------------------------------------------------
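
A quick usage sketch for `mlp_block`; the dimensions are illustrative assumptions (embedding size 96, MLP ratio 4, one 7x7 window of 49 tokens), chosen to mirror the tiny config used elsewhere in the repo:

```python
import tensorflow as tf

from swins.blocks import mlp_block

# hidden_units=[4 * dim, dim] reproduces the usual Transformer FFN expansion.
ffn = mlp_block(dropout_rate=0.0, hidden_units=[4 * 96, 96])
tokens = tf.random.normal((2, 49, 96))
out = ffn(tokens)  # GELU after the first Dense only; shape stays (2, 49, 96)
```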
/swin-transformers-tf/swins/blocks/stage_block.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Nikolai Körber. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Code copied and modified from
18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
19 | """
20 |
21 | from functools import partial
22 | from typing import Dict, Union
23 | import tensorflow as tf
24 | from tensorflow import keras
25 | from tensorflow.keras import layers as L
26 | from .swin_transformer_block import SwinTransformerBlock
27 |
28 |
29 | class BasicLayer(keras.Model):
30 | """A basic Swin Transformer layer for one stage.
31 |
32 | Args:
33 | dim (int): Number of input channels.
34 | depth (int): Number of blocks.
35 | num_heads (int): Number of attention heads.
36 | head_dim (int): Channels per head (dim // num_heads if not set)
37 | window_size (int): Local window size.
38 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
39 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
40 | drop (float, optional): Dropout rate. Default: 0.0
41 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
42 | drop_path (float | list[float], optional): Stochastic depth rate. Default: 0.0
43 | norm_layer (layers.Layer, optional): Normalization layer. Default: layers.LayerNormalization
44 | downsample (layers.Layer | None, optional): Downsample layer at the end of the layer. Default: None
45 | upsample (layers.Layer | None, optional): Upsample layer at the end of the layer. Default: None
46 | """
47 |
48 | def __init__(
49 | self,
50 | dim,
51 | out_dim,
52 | depth,
53 | num_heads=4,
54 | head_dim=None,
55 | window_size=7,
56 | mlp_ratio=4.0,
57 | qkv_bias=True,
58 | drop=0.0,
59 | attn_drop=0.0,
60 | drop_path=0.0,
61 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5),
62 | downsample=None,
63 | upsample=None,
64 | **kwargs,
65 | ):
66 |
67 |         super().__init__(**kwargs)
68 | self.dim = dim
69 | self.depth = depth
70 |
71 | # build blocks
72 | blocks = [
73 | SwinTransformerBlock(
74 | dim=dim,
75 | num_heads=num_heads,
76 | head_dim=head_dim,
77 | window_size=window_size,
78 | shift_size=0 if (i % 2 == 0) else window_size // 2,
79 | mlp_ratio=mlp_ratio,
80 | qkv_bias=qkv_bias,
81 | drop=drop,
82 | attn_drop=attn_drop,
83 | drop_path=drop_path[i]
84 | if isinstance(drop_path, list)
85 | else drop_path,
86 | norm_layer=norm_layer,
87 | name=f"swin_transformer_block_{i}",
88 | )
89 | for i in range(depth)
90 | ]
91 | self.blocks = blocks
92 | # patch merging layer
93 | if downsample is not None:
94 | self.downsample = downsample(
95 | dim=dim,
96 | out_dim=out_dim,
97 | norm_layer=norm_layer,
98 | )
99 | else:
100 | self.downsample = None
101 | # patch splitting layer
102 | if upsample is not None:
103 | self.upsample = upsample(
104 | dim=dim,
105 | out_dim=out_dim,
106 | norm_layer=norm_layer,
107 | )
108 | else:
109 | self.upsample = None
110 | def call(
111 | self, x, return_attns=False
112 | ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
113 | if return_attns:
114 | attention_scores = {}
115 |
116 | for i, block in enumerate(self.blocks):
117 | if not return_attns:
118 | x = block(x)
119 | else:
120 | x, attns = block(x, return_attns)
121 | attention_scores.update({f"swin_block_{i}": attns})
122 | if self.downsample is not None:
123 | x = self.downsample(x)
124 |
125 | if self.upsample is not None:
126 | x = self.upsample(x)
127 |
128 | if return_attns:
129 | return x, attention_scores
130 | else:
131 | return x
--------------------------------------------------------------------------------
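
A standalone sketch of one stage end to end (assuming the package is
importable): blocks alternate W-MSA (shift 0) and SW-MSA (shift
window_size // 2), and the optional downsample halves the resolution at the
end of the stage. An even window size is used here, matching the assumption
of the shifted-window mask construction (see swin_transformer_block.py below).

    import tensorflow as tf
    from swins.blocks import BasicLayer
    from swins.layers import PatchMerging

    stage = BasicLayer(
        dim=96, out_dim=192, depth=2, num_heads=3, window_size=8,
        downsample=PatchMerging, name="stage_0",
    )
    x = tf.random.normal((1, 64, 64, 96))   # H, W divisible by window_size
    y = stage(x)
    print(y.shape)                          # expected: (1, 32, 32, 192)
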
/swin-transformers-tf/swins/blocks/swin_transformer_block.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Nikolai Körber. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Code copied and modified from
18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
19 | """
20 |
21 | import collections.abc
22 | from functools import partial
23 | from typing import Dict, Union
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 | from tensorflow import keras
28 | from tensorflow.keras import layers as L
29 |
30 | from ..layers import StochasticDepth, WindowAttention
31 | from . import utils
32 | from .mlp import mlp_block
33 |
34 |
35 | class SwinTransformerBlock(keras.Model):
36 | """Swin Transformer Block.
37 |
38 | Args:
39 | dim (int): Number of input channels.
40 | window_size (int): Window size.
41 | num_heads (int): Number of attention heads.
42 | head_dim (int): Enforce the number of channels per head
43 | shift_size (int): Shift size for SW-MSA.
44 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
45 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
46 | drop (float, optional): Dropout rate. Default: 0.0
47 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
48 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
49 | norm_layer (layers.Layer, optional): Normalization layer. Default: layers.LayerNormalization
50 | """
51 |
52 | def __init__(
53 | self,
54 | dim,
55 | num_heads=4,
56 | head_dim=None,
57 | window_size=7,
58 | shift_size=0,
59 | mlp_ratio=4.0,
60 | qkv_bias=True,
61 | drop=0.0,
62 | attn_drop=0.0,
63 | drop_path=0.0,
64 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5),
65 | **kwargs,
66 | ):
67 | super().__init__(**kwargs)
68 | self.dim = dim
69 | self.window_size = window_size
70 | self.shift_size = shift_size
71 | self.mlp_ratio = mlp_ratio
72 | assert (
73 | 0 <= self.shift_size < self.window_size
74 |         ), "shift_size must be in [0, window_size)"
75 | self.norm1 = norm_layer()
76 | self.attn = WindowAttention(
77 | dim,
78 | num_heads=num_heads,
79 | head_dim=head_dim,
80 | window_size=window_size
81 | if isinstance(window_size, collections.abc.Iterable)
82 | else (window_size, window_size),
83 | qkv_bias=qkv_bias,
84 | attn_drop=attn_drop,
85 | proj_drop=drop,
86 | name="window_attention",
87 | )
88 | self.drop_path = (
89 | StochasticDepth(drop_path) if drop_path > 0.0 else tf.identity
90 | )
91 | self.norm2 = norm_layer()
92 | self.mlp = mlp_block(
93 | dropout_rate=drop, hidden_units=[int(dim * mlp_ratio), dim]
94 | )
95 | self.attn_mask = None
96 |
97 | def get_img_mask(self):
98 |         # Calculate the image mask for SW-MSA. The nine regions below assume
99 |         # shift_size == window_size // 2 with an even window_size. Since
100 |         # TensorFlow does not support item assignment, the mask is built by
101 |         # concatenation instead of slice assignment. See
102 |         # https://github.com/microsoft/Swin-Transformer/blob/e43ac64ce8abfe133ae582741ccaf6761eea05f7/models/swin_transformer.py#L222
103 |
104 | H, W = self.input_resolution
105 | window_size = self.window_size
106 |
107 | mask_0 = tf.zeros((1, H-window_size, W-window_size, 1))
108 | mask_1 = tf.ones((1, H-window_size, window_size//2, 1))
109 | mask_2 = tf.ones((1, H-window_size, window_size//2, 1))
110 | mask_2 = mask_2+1
111 | mask_3 = tf.ones((1, window_size//2, W-window_size, 1))
112 | mask_3 = mask_3+2
113 | mask_4 = tf.ones((1, window_size//2, window_size//2, 1))
114 | mask_4 = mask_4+3
115 | mask_5 = tf.ones((1, window_size//2, window_size//2, 1))
116 | mask_5 = mask_5+4
117 | mask_6 = tf.ones((1, window_size//2, W-window_size, 1))
118 | mask_6 = mask_6+5
119 | mask_7 = tf.ones((1, window_size//2, window_size//2, 1))
120 | mask_7 = mask_7+6
121 | mask_8 = tf.ones((1, window_size//2, window_size//2, 1))
122 | mask_8 = mask_8+7
123 |
124 | mask_012 = tf.concat([mask_0, mask_1, mask_2], axis=2)
125 | mask_345 = tf.concat([mask_3, mask_4, mask_5], axis=2)
126 | mask_678 = tf.concat([mask_6, mask_7, mask_8], axis=2)
127 |
128 | img_mask = tf.concat([mask_012, mask_345, mask_678], axis=1)
129 | return img_mask
130 |
131 |
132 | def get_attn_mask(self):
133 | # calculate attention mask for SW-MSA
134 | mask_windows = utils.window_partition(
135 | self.img_mask, self.window_size
136 | ) # [num_win, window_size, window_size, 1]
137 | mask_windows = tf.reshape(
138 | mask_windows, (-1, self.window_size * self.window_size)
139 | )
140 | attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims(
141 | mask_windows, 2
142 | )
143 | attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
144 | return tf.where(attn_mask == 0, 0.0, attn_mask)
145 |
146 | def call(
147 | self, x, return_attns=False
148 | ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]:
149 |
150 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
151 | self.input_resolution = (H, W)
152 |
153 | if self.shift_size > 0:
154 | self.img_mask = tf.stop_gradient(self.get_img_mask())
155 | self.attn_mask = self.get_attn_mask()
156 |
157 | x = tf.reshape(x, (-1, H*W, C))
158 |
159 | shortcut = x
160 | x = self.norm1(x)
161 | x = tf.reshape(x, (-1, H, W, C))
162 |
163 | # cyclic shift
164 | if self.shift_size > 0:
165 | shifted_x = tf.roll(
166 | x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)
167 | )
168 | else:
169 | shifted_x = x
170 |
171 | # partition windows
172 | x_windows = utils.window_partition(
173 | shifted_x, self.window_size
174 | ) # [num_win*B, window_size, window_size, C]
175 | x_windows = tf.reshape(
176 | x_windows, (-1, self.window_size * self.window_size, C)
177 | ) # [num_win*B, window_size*window_size, C]
178 |
179 | # W-MSA/SW-MSA
180 | if not return_attns:
181 | attn_windows = self.attn(
182 | x_windows, mask=self.attn_mask
183 | ) # [num_win*B, window_size*window_size, C]
184 | else:
185 | attn_windows, attn_scores = self.attn(
186 | x_windows, mask=self.attn_mask, return_attns=True
187 | ) # [num_win*B, window_size*window_size, C]
188 | # merge windows
189 | attn_windows = tf.reshape(
190 | attn_windows, (-1, self.window_size, self.window_size, C)
191 | )
192 | shifted_x = utils.window_reverse(
193 | attn_windows, self.window_size, H, W
194 | ) # [B, H', W', C]
195 |
196 | # reverse cyclic shift
197 | if self.shift_size > 0:
198 | x = tf.roll(
199 | shifted_x,
200 | shift=(self.shift_size, self.shift_size),
201 | axis=(1, 2),
202 | )
203 | else:
204 | x = shifted_x
205 |
206 | x = tf.reshape(x, (-1, H * W, C))
207 |
208 | # FFN
209 | x = shortcut + self.drop_path(x)
210 | x = x + self.drop_path(self.mlp(self.norm2(x)))
211 |
212 | x = tf.reshape(x, (-1, H, W, C))
213 |
214 | if return_attns:
215 | return x, attn_scores
216 | else:
217 | return x
--------------------------------------------------------------------------------
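
For intuition, a standalone NumPy sketch of the nine-region image mask that
get_img_mask assembles by concatenation. The upstream Swin code builds the
same mask by slice assignment; the two agree when window_size is even and
shift_size == window_size // 2.

    import numpy as np

    H = W = 16
    ws = 8                          # window size (even)
    s = ws // 2                     # shift size
    img_mask = np.zeros((H, W), dtype=np.int32)
    cnt = 0
    for h_slice in (slice(0, -ws), slice(-ws, -s), slice(-s, None)):
        for w_slice in (slice(0, -ws), slice(-ws, -s), slice(-s, None)):
            img_mask[h_slice, w_slice] = cnt
            cnt += 1
    print(np.unique(img_mask))      # [0 1 2 3 4 5 6 7 8]
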
/swin-transformers-tf/swins/blocks/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Code copied and modified from
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
4 | """
5 |
6 |
7 | import tensorflow as tf
8 |
9 |
10 | def window_partition(x: tf.Tensor, window_size: int):
11 | """
12 | Args:
13 | x: (B, H, W, C)
14 | window_size (int): window size
15 | Returns:
16 | windows: (num_windows*B, window_size, window_size, C)
17 | """
18 | B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
19 |
20 | x = tf.reshape(
21 | x, (B, H // window_size, window_size, W // window_size, window_size, C)
22 | )
23 | windows = tf.transpose(x, [0, 1, 3, 2, 4, 5])
24 | windows = tf.reshape(windows, (-1, window_size, window_size, C))
25 | return windows
26 |
27 |
28 | def window_reverse(windows: tf.Tensor, window_size: int, H: int, W: int):
29 | """
30 | Args:
31 | windows: (num_windows*B, window_size, window_size, C)
32 | window_size (int): Window size
33 | H (int): Height of image
34 | W (int): Width of image
35 | Returns:
36 | x: (B, H, W, C)
37 | """
38 | B = tf.shape(windows)[0] // tf.cast(
39 | H * W / window_size / window_size, dtype="int32"
40 | )
41 |
42 | x = tf.reshape(
43 | windows,
44 | (
45 | B,
46 | H // window_size,
47 | W // window_size,
48 | window_size,
49 | window_size,
50 | -1,
51 | ),
52 | )
53 | x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
54 | return tf.reshape(x, (B, H, W, -1))
55 |
--------------------------------------------------------------------------------
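
A standalone round-trip check (assuming the package is importable):
window_reverse is the exact inverse of window_partition whenever H and W are
divisible by the window size.

    import tensorflow as tf
    from swins.blocks.utils import window_partition, window_reverse

    x = tf.random.normal((2, 56, 56, 96))
    windows = window_partition(x, 7)         # 2 images * 8 * 8 windows
    print(windows.shape)                     # (128, 7, 7, 96)
    x_back = window_reverse(windows, 7, 56, 56)
    tf.debugging.assert_near(x, x_back)      # round trip is the identity
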
/swin-transformers-tf/swins/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .patch_merging import PatchMerging
2 | from .patch_splitting import PatchSplitting, PatchUnpack
3 | from .sd import StochasticDepth
4 | from .window_attn import WindowAttention
5 |
--------------------------------------------------------------------------------
/swin-transformers-tf/swins/layers/patch_merging.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Nikolai Körber. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Code copied and modified from
18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
19 | """
20 |
21 | from functools import partial
22 |
23 | import tensorflow as tf
24 | from tensorflow.keras import layers as L
25 |
26 |
27 | class PatchMerging(L.Layer):
28 | """Patch Merging Layer.
29 |
30 | Args:
31 | dim (int): Number of input channels.
32 | """
33 |
34 | def __init__(
35 | self,
36 | dim,
37 | out_dim=None,
38 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5),
39 | **kwargs
40 | ):
41 | super().__init__(**kwargs)
42 | self.dim = dim
43 | self.out_dim = out_dim or 2 * dim
44 | self.norm = norm_layer()
45 | self.reduction = L.Dense(self.out_dim, use_bias=False)
46 |
47 | def call(self, x):
48 | """
49 | x: B, H, W, C
50 | """
51 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
52 |
53 | x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
54 | x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
55 | x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
56 | x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
57 | x = tf.concat([x0, x1, x2, x3], axis=-1) # [B, H/2, W/2, 4*C]
58 | x = tf.reshape(x, (-1, H//2*W//2, 4 * C)) # [B, H/2*W/2, 4*C]
59 |
60 | x = self.norm(x)
61 | x = self.reduction(x)
62 | x = tf.reshape(x, (-1, H//2, W//2, self.out_dim))
63 | return x
64 |
65 | def get_config(self):
66 | config = super().get_config()
67 | config.update(
68 | {
69 | "dim": self.dim,
70 | "out_dim": self.out_dim,
71 |                 # the norm layer is recreated in __init__ and is not serialized
72 | }
73 | )
74 | return config
--------------------------------------------------------------------------------
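
A standalone shape sketch: PatchMerging halves H and W and, by default,
doubles the channel count.

    import tensorflow as tf
    from swins.layers import PatchMerging

    merge = PatchMerging(dim=96)             # out_dim defaults to 2 * dim = 192
    x = tf.random.normal((1, 56, 56, 96))
    y = merge(x)
    print(y.shape)                           # expected: (1, 28, 28, 192)
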
/swin-transformers-tf/swins/layers/patch_splitting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Nikolai Körber. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | from functools import partial
17 |
18 | import tensorflow as tf
19 | from tensorflow.keras import layers as L
20 |
21 |
22 | class PatchSplitting(L.Layer):
23 | """Patch Splitting Layer as described in
24 | https://openreview.net/pdf?id=IDwN6xjHnK8 (section 3.1)
25 | # Patch Split = [Linear, LayerNorm, Depth-to-Space (for upsampling)]
26 | Args:
27 | dim (int): Number of input channels.
28 | """
29 |
30 | def __init__(
31 | self,
32 | dim,
33 | out_dim=None,
34 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5),
35 | **kwargs
36 | ):
37 | super().__init__(**kwargs)
38 | self.dim = dim
39 | self.out_dim = out_dim or dim
40 | self.norm = norm_layer()
41 |         self.reduction = L.Dense(self.out_dim * 4, use_bias=False)  # expands channels for depth_to_space
42 |
43 | def call(self, x):
44 | """
45 | x: B, H, W, C
46 | """
47 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
48 | x = tf.reshape(x, (-1, H * W, C))
49 |
50 | x = self.reduction(x)
51 | x = self.norm(x)
52 |
53 | x = tf.reshape(x, (-1, H, W, self.out_dim * 4))
54 | x = tf.nn.depth_to_space(x, 2, data_format='NHWC')
55 | x = tf.reshape(x, (-1, 2 * H, 2 * W, self.out_dim))
56 |
57 | return x
58 |
59 | def get_config(self):
60 | config = super().get_config()
61 | config.update(
62 | {
63 | "dim": self.dim,
64 | "out_dim": self.out_dim,
65 |                 # the norm layer is recreated in __init__ and is not serialized
66 | }
67 | )
68 | return config
69 |
70 |
71 | class PatchUnpack(L.Layer):
72 | """Patch Unpack Layer
73 | # PatchUnpack = [Linear, Depth-to-Space (for upsampling)]
74 |
75 | Key differences to PatchSplitting:
76 | - no LayerNorm
77 | - use_bias=True (self.reduction)
78 |
79 | Args:
80 |         dim (int): Number of input channels.
81 |         out_dim (int, optional): Number of output channels. Default: dim.
82 | """
83 |
84 | def __init__(
85 | self,
86 | dim,
87 | out_dim=None,
88 | norm_layer=None,
89 | **kwargs
90 | ):
91 | super().__init__(**kwargs)
92 | self.dim = dim
93 | self.out_dim = out_dim or dim
94 |         self.reduction = L.Dense(self.out_dim * 4, use_bias=True)  # expands channels for depth_to_space
95 |
96 | def call(self, x):
97 | """
98 | x: B, H, W, C
99 | """
100 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
101 | x = tf.reshape(x, (-1, H * W, C))
102 |
103 | x = self.reduction(x)
104 |
105 | x = tf.reshape(x, (-1, H, W, self.out_dim * 4))
106 | x = tf.nn.depth_to_space(x, 2, data_format='NHWC')
107 | x = tf.reshape(x, (-1, 2 * H, 2 * W, self.out_dim))
108 |
109 | return x
110 |
111 | def get_config(self):
112 | config = super().get_config()
113 | config.update(
114 | {
115 | "dim": self.dim,
116 | "out_dim": self.out_dim,
117 | }
118 | )
119 | return config
120 |
--------------------------------------------------------------------------------
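
A standalone shape sketch: PatchSplitting doubles H and W via depth_to_space
while projecting to out_dim channels; PatchUnpack behaves the same spatially,
just without the LayerNorm and with a bias on the Dense.

    import tensorflow as tf
    from swins.layers import PatchSplitting

    split = PatchSplitting(dim=192, out_dim=96)
    x = tf.random.normal((1, 28, 28, 192))
    y = split(x)
    print(y.shape)                           # expected: (1, 56, 56, 96)
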
/swin-transformers-tf/swins/layers/sd.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import layers
3 |
4 |
5 | # Referred from: github.com:rwightman/pytorch-image-models.
6 | class StochasticDepth(layers.Layer):
7 |     def __init__(self, drop_prob, **kwargs):
8 |         super().__init__(**kwargs)
9 |         self.drop_prob = float(drop_prob)
10 |
11 | def call(self, x, training=False):
12 | if training:
13 | keep_prob = 1 - self.drop_prob
14 |             shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
15 | random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)
16 | random_tensor = tf.floor(random_tensor)
17 | return (x / keep_prob) * random_tensor
18 | return x
19 |
20 | def get_config(self):
21 | config = super().get_config()
22 | config.update({"drop_prob": self.drop_prob})
23 | return config
24 |
--------------------------------------------------------------------------------
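
A standalone behavior sketch: stochastic depth is the identity at inference;
during training it zeroes entire samples with probability drop_prob and
rescales the survivors by 1 / keep_prob so the expected value is unchanged.

    import tensorflow as tf
    from swins.layers import StochasticDepth

    sd = StochasticDepth(drop_prob=0.5)
    x = tf.ones((4, 3))
    print(sd(x, training=False))   # identical to x
    print(sd(x, training=True))    # each row is either all 0.0 or all 2.0
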
/swin-transformers-tf/swins/layers/window_attn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Nikolai Körber. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """
17 | Code copied and modified from
18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
19 | """
20 |
21 |
22 | import collections.abc
23 | from typing import Tuple, Union
24 |
25 | import tensorflow as tf
26 | from tensorflow.keras import layers
27 |
28 |
29 | def get_relative_position_index(win_h, win_w):
30 | # get pair-wise relative position index for each token inside the window
31 | xx, yy = tf.meshgrid(range(win_h), range(win_w))
32 | coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww]
33 | coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww]
34 |
35 | relative_coords = (
36 | coords_flatten[:, :, None] - coords_flatten[:, None, :]
37 | ) # [2, Wh*Ww, Wh*Ww]
38 | relative_coords = tf.transpose(
39 | relative_coords, perm=[1, 2, 0]
40 | ) # [Wh*Ww, Wh*Ww, 2]
41 |
42 | xx = (relative_coords[:, :, 0] + win_h - 1) * (2 * win_w - 1)
43 | yy = relative_coords[:, :, 1] + win_w - 1
44 | relative_coords = tf.stack([xx, yy], axis=-1)
45 |
46 | return tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww]
47 |
48 |
49 | class WindowAttention(layers.Layer):
50 | """Window based multi-head self attention (W-MSA) module with relative position bias.
51 | It supports both of shifted and non-shifted window.
52 |
53 | Args:
54 | dim (int): Number of input channels.
55 | num_heads (int): Number of attention heads.
56 | head_dim (int): Number of channels per head (dim // num_heads if not set)
57 | window_size (tuple[int]): The height and width of the window.
58 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
59 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
60 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
61 | """
62 |
63 | def __init__(
64 | self,
65 | dim,
66 | num_heads,
67 | head_dim=None,
68 | window_size=7,
69 | qkv_bias=True,
70 | attn_drop=0.0,
71 | proj_drop=0.0,
72 | **kwargs,
73 | ):
74 |
75 | super().__init__(**kwargs)
76 |
77 | self.dim = dim
78 | self.window_size = (
79 | window_size
80 | if isinstance(window_size, collections.abc.Iterable)
81 | else (window_size, window_size)
82 | ) # Wh, Ww
83 | self.win_h, self.win_w = self.window_size
84 | self.window_area = self.win_h * self.win_w
85 | self.num_heads = num_heads
86 | self.head_dim = head_dim or (dim // num_heads)
87 | self.attn_dim = self.head_dim * num_heads
88 | self.scale = self.head_dim ** -0.5
89 |
90 | # get pair-wise relative position index for each token inside the window
91 | self.relative_position_index = get_relative_position_index(
92 | self.win_h, self.win_w
93 | )
94 |
95 | self.qkv = layers.Dense(
96 | self.attn_dim * 3, use_bias=qkv_bias, name="attention_qkv"
97 | )
98 | self.attn_drop = layers.Dropout(attn_drop)
99 | self.proj = layers.Dense(dim, name="attention_projection")
100 | self.proj_drop = layers.Dropout(proj_drop)
101 |
102 | def build(self, input_shape):
103 | self.relative_position_bias_table = self.add_weight(
104 | shape=((2 * self.win_h - 1) * (2 * self.win_w - 1), self.num_heads),
105 | initializer="zeros",
106 | trainable=True,
107 | name="relative_position_bias_table",
108 | )
109 | super().build(input_shape)
110 |
111 | def _get_rel_pos_bias(self) -> tf.Tensor:
112 | relative_position_bias = tf.gather(
113 | self.relative_position_bias_table,
114 | self.relative_position_index,
115 | axis=0,
116 | )
117 | return tf.transpose(relative_position_bias, [2, 0, 1])
118 |
119 | def call(
120 | self, x, mask=None, return_attns=False
121 | ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]:
122 | """
123 | Args:
124 | x: input features with shape of (num_windows*B, N, C)
125 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
126 | """
127 | B_, N, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
128 | qkv = self.qkv(x)
129 | qkv = tf.reshape(qkv, (B_, N, 3, self.num_heads, -1))
130 | qkv = tf.transpose(qkv, (2, 0, 3, 1, 4))
131 |
132 | q, k, v = tf.unstack(qkv, 3)
133 |
134 | scale = tf.cast(self.scale, dtype=qkv.dtype)
135 | q = q * scale
136 | attn = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2]))
137 | attn = attn + self._get_rel_pos_bias()
138 |
139 | if mask is not None:
140 | num_win = tf.shape(mask)[0]
141 | attn = tf.reshape(
142 | attn, (B_ // num_win, num_win, self.num_heads, N, N)
143 | )
144 | attn = attn + tf.expand_dims(mask, 1)[None, ...]
145 |
146 | attn = tf.reshape(attn, (-1, self.num_heads, N, N))
147 | attn = tf.nn.softmax(attn, -1)
148 | else:
149 | attn = tf.nn.softmax(attn, -1)
150 |
151 | attn = self.attn_drop(attn)
152 |
153 | x = tf.matmul(attn, v)
154 | x = tf.transpose(x, perm=[0, 2, 1, 3])
155 | x = tf.reshape(x, (B_, N, C))
156 |
157 | x = self.proj(x)
158 | x = self.proj_drop(x)
159 |
160 | if return_attns:
161 | return x, attn
162 | else:
163 | return x
164 |
165 | def get_config(self):
166 | config = super().get_config()
167 | config.update(
168 | {
169 | "dim": self.dim,
170 | "window_size": self.window_size,
171 | "win_h": self.win_h,
172 | "win_w": self.win_w,
173 | "num_heads": self.num_heads,
174 | "head_dim": self.head_dim,
175 | "attn_dim": self.attn_dim,
176 | "scale": self.scale,
177 | }
178 | )
179 | return config
--------------------------------------------------------------------------------
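
A standalone shape check: a 7x7 window holds 49 tokens, and every ordered
token pair indexes into one of (2*7 - 1) * (2*7 - 1) = 169 learned
relative-position bias entries.

    import tensorflow as tf
    from swins.layers.window_attn import get_relative_position_index

    idx = get_relative_position_index(7, 7)
    print(idx.shape)                 # (49, 49): one entry per token pair
    print(int(tf.reduce_max(idx)))   # 168 == 13 * 13 - 1
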
/swin-transformers-tf/swins/model_configs.py:
--------------------------------------------------------------------------------
1 | # Take from here:
2 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
3 |
4 |
5 | def swin_base_patch4_window12_384():
6 | """Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k"""
7 | cfg = dict(
8 | img_size=384,
9 | patch_size=4,
10 | window_size=12,
11 | embed_dim=128,
12 | depths=(2, 2, 18, 2),
13 | num_heads=(4, 8, 16, 32),
14 | name="swin_base_patch4_window12_384",
15 | )
16 | return cfg
17 |
18 |
19 | def swin_base_patch4_window7_224():
20 | """Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k"""
21 | cfg = dict(
22 | patch_size=4,
23 | window_size=7,
24 | embed_dim=128,
25 | depths=(2, 2, 18, 2),
26 | num_heads=(4, 8, 16, 32),
27 | name="swin_base_patch4_window7_224",
28 | )
29 | return cfg
30 |
31 |
32 | def swin_large_patch4_window12_384():
33 | """Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k"""
34 | cfg = dict(
35 | img_size=384,
36 | patch_size=4,
37 | window_size=12,
38 | embed_dim=192,
39 | depths=(2, 2, 18, 2),
40 | num_heads=(6, 12, 24, 48),
41 | name="swin_large_patch4_window12_384",
42 | )
43 | return cfg
44 |
45 |
46 | def swin_large_patch4_window7_224():
47 | """Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k"""
48 | cfg = dict(
49 | patch_size=4,
50 | window_size=7,
51 | embed_dim=192,
52 | depths=(2, 2, 18, 2),
53 | num_heads=(6, 12, 24, 48),
54 | name="swin_large_patch4_window7_224",
55 | )
56 | return cfg
57 |
58 |
59 | def swin_small_patch4_window7_224():
60 | """Swin-S @ 224x224, trained ImageNet-1k"""
61 | cfg = dict(
62 | patch_size=4,
63 | window_size=7,
64 | embed_dim=96,
65 | depths=(2, 2, 18, 2),
66 | num_heads=(3, 6, 12, 24),
67 | name="swin_small_patch4_window7_224",
68 | )
69 | return cfg
70 |
71 |
72 | def swin_tiny_patch4_window7_224():
73 | """Swin-T @ 224x224, trained ImageNet-1k"""
74 | cfg = dict(
75 | patch_size=4,
76 | window_size=7,
77 | embed_dim=96,
78 | depths=(2, 2, 6, 2),
79 | num_heads=(3, 6, 12, 24),
80 | name="swin_tiny_patch4_window7_224",
81 | )
82 | return cfg
83 |
84 |
85 | def swin_base_patch4_window12_384_in22k():
86 | """Swin-B @ 384x384, trained ImageNet-22k"""
87 | cfg = dict(
88 | img_size=384,
89 | patch_size=4,
90 | window_size=12,
91 | embed_dim=128,
92 | depths=(2, 2, 18, 2),
93 | num_heads=(4, 8, 16, 32),
94 | name="swin_base_patch4_window12_384_in22k",
95 | num_classes=21841,
96 | )
97 | return cfg
98 |
99 |
100 | def swin_base_patch4_window7_224_in22k():
101 | """Swin-B @ 224x224, trained ImageNet-22k"""
102 | cfg = dict(
103 | patch_size=4,
104 | window_size=7,
105 | embed_dim=128,
106 | depths=(2, 2, 18, 2),
107 | num_heads=(4, 8, 16, 32),
108 | name="swin_base_patch4_window7_224_in22k",
109 | num_classes=21841,
110 | )
111 | return cfg
112 |
113 |
114 | def swin_large_patch4_window12_384_in22k():
115 | """Swin-L @ 384x384, trained ImageNet-22k"""
116 | cfg = dict(
117 | img_size=384,
118 | patch_size=4,
119 | window_size=12,
120 | embed_dim=192,
121 | depths=(2, 2, 18, 2),
122 | num_heads=(6, 12, 24, 48),
123 | name="swin_large_patch4_window12_384_in22k",
124 | num_classes=21841,
125 | )
126 | return cfg
127 |
128 |
129 | def swin_large_patch4_window7_224_in22k():
130 | """Swin-L @ 224x224, trained ImageNet-22k"""
131 | cfg = dict(
132 | patch_size=4,
133 | window_size=7,
134 | embed_dim=192,
135 | depths=(2, 2, 18, 2),
136 | num_heads=(6, 12, 24, 48),
137 | name="swin_large_patch4_window7_224_in22k",
138 | num_classes=21841,
139 | )
140 | return cfg
141 |
142 |
143 | def swin_s3_tiny_224():
144 | """Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725"""
145 | cfg = dict(
146 | patch_size=4,
147 | window_size=(7, 7, 14, 7),
148 | embed_dim=96,
149 | depths=(2, 2, 6, 2),
150 | num_heads=(3, 6, 12, 24),
151 | name="swin_s3_tiny_224",
152 | )
153 | return cfg
154 |
155 |
156 | def swin_s3_small_224():
157 | """Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725"""
158 | cfg = dict(
159 | patch_size=4,
160 | window_size=(14, 14, 14, 7),
161 | embed_dim=96,
162 | depths=(2, 2, 18, 2),
163 | num_heads=(3, 6, 12, 24),
164 | name="swin_s3_small_224",
165 | )
166 | return cfg
167 |
168 |
169 | def swin_s3_base_224():
170 | """Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725"""
171 | cfg = dict(
172 | patch_size=4,
173 | window_size=(7, 7, 14, 7),
174 | embed_dim=96,
175 | depths=(2, 2, 30, 2),
176 | num_heads=(3, 6, 12, 24),
177 | name="swin_s3_base_224",
178 | )
179 | return cfg
180 |
181 |
182 | MODEL_MAP = {
183 | "swin_base_patch4_window12_384": swin_base_patch4_window12_384,
184 | "swin_base_patch4_window7_224": swin_base_patch4_window7_224,
185 | "swin_large_patch4_window12_384": swin_large_patch4_window12_384,
186 | "swin_large_patch4_window7_224": swin_large_patch4_window7_224,
187 | "swin_small_patch4_window7_224": swin_small_patch4_window7_224,
188 | "swin_tiny_patch4_window7_224": swin_tiny_patch4_window7_224,
189 | "swin_base_patch4_window12_384_in22k": swin_base_patch4_window12_384_in22k,
190 | "swin_base_patch4_window7_224_in22k": swin_base_patch4_window7_224_in22k,
191 | "swin_large_patch4_window12_384_in22k": swin_large_patch4_window12_384_in22k,
192 | "swin_large_patch4_window7_224_in22k": swin_large_patch4_window7_224_in22k,
193 | "swin_s3_tiny_224": swin_s3_tiny_224,
194 | "swin_s3_small_224": swin_s3_small_224,
195 | "swin_s3_base_224": swin_s3_base_224,
196 | }
197 |
--------------------------------------------------------------------------------
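
Each entry in MODEL_MAP returns a plain kwargs dict, so a model can be
instantiated directly from a config name (a standalone sketch; the forward
pass itself is exercised in test.py below):

    from swins import SwinTransformer
    from swins.model_configs import MODEL_MAP

    cfg = MODEL_MAP["swin_tiny_patch4_window7_224"]()
    print(cfg["depths"], cfg["num_heads"])   # (2, 2, 6, 2) (3, 6, 12, 24)
    model = SwinTransformer(**cfg)           # ready for weight porting
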
/swin-transformers-tf/swins/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Code copied and modified from
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py
4 | """
5 |
6 | import collections.abc
7 | from functools import partial
8 | from itertools import repeat
9 | from typing import Dict
10 |
11 | import tensorflow as tf
12 | from tensorflow import keras
13 | from tensorflow.keras import layers as L
14 |
15 | from .blocks import BasicLayer
16 | from .layers import PatchMerging
17 |
18 |
19 | # https://github.com/rwightman/pytorch-image-models/blob/6d4665bb52390974e0cf9674c60c41946d2f4ee2/timm/models/layers/helpers.py#L10
20 | def to_ntuple(n):
21 | def parse(x):
22 | if isinstance(x, collections.abc.Iterable):
23 | return x
24 | return tuple(repeat(x, n))
25 |
26 | return parse
27 |
28 |
29 | class SwinTransformer(keras.Model):
30 | """Swin Transformer
31 | A TensorFlow impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
32 | https://arxiv.org/pdf/2103.14030
33 |
34 | Args:
35 |         img_size (int | tuple(int)): Input image size. Default: 224
36 | patch_size (int | tuple(int)): Patch size. Default: 4
37 | num_classes (int): Number of classes for classification head. Default: 1000
38 | embed_dim (int): Patch embedding dimension. Default: 96
39 | depths (tuple(int)): Depth of each Swin Transformer layer.
40 | num_heads (tuple(int)): Number of attention heads in different layers.
41 |         head_dim (int, tuple(int)): Channels per head (dim // num_heads if not set).
42 | window_size (int): Window size. Default: 7
43 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
44 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
45 | drop_rate (float): Dropout rate. Default: 0
46 | attn_drop_rate (float): Attention dropout rate. Default: 0
47 | drop_path_rate (float): Stochastic depth rate. Default: 0.1
48 | norm_layer (layers.Layer): Normalization layer. Default: layers.LayerNormalization.
49 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
50 | patch_norm (bool): If True, add normalization after patch embedding. Default: True
51 | pre_logits (bool): If True, return model without classification head. Default: False
52 | """
53 |
54 | def __init__(
55 | self,
56 | img_size=224,
57 | patch_size=4,
58 | num_classes=1000,
59 | global_pool="avg",
60 | embed_dim=96,
61 | depths=(2, 2, 6, 2),
62 | num_heads=(3, 6, 12, 24),
63 | head_dim=None,
64 | window_size=7,
65 | mlp_ratio=4.0,
66 | qkv_bias=True,
67 | drop_rate=0.0,
68 | attn_drop_rate=0.0,
69 | drop_path_rate=0.1,
70 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5),
71 | ape=False,
72 | patch_norm=True,
73 | pre_logits=False,
74 | **kwargs,
75 | ):
76 | super().__init__(**kwargs)
77 |
78 | self.img_size = (
79 | img_size
80 | if isinstance(img_size, collections.abc.Iterable)
81 | else (img_size, img_size)
82 | )
83 | self.patch_size = (
84 | patch_size
85 | if isinstance(patch_size, collections.abc.Iterable)
86 | else (patch_size, patch_size)
87 | )
88 |
89 | self.num_classes = num_classes
90 | self.global_pool = global_pool
91 | self.num_layers = len(depths)
92 | self.embed_dim = embed_dim
93 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
94 | self.ape = ape
95 |
96 | # split image into non-overlapping patches
97 | self.projection = keras.Sequential(
98 | [
99 | L.Conv2D(
100 | filters=embed_dim,
101 | kernel_size=(patch_size, patch_size),
102 | strides=(patch_size, patch_size),
103 | padding="VALID",
104 | name="conv_projection",
105 | kernel_initializer="lecun_normal",
106 | ),
107 | L.Reshape(
108 | target_shape=(-1, embed_dim),
109 | name="flatten_projection",
110 | ),
111 | ],
112 | name="projection",
113 | )
114 | if patch_norm:
115 | self.projection.add(norm_layer())
116 |
117 | self.patch_grid = (
118 | self.img_size[0] // self.patch_size[0],
119 | self.img_size[1] // self.patch_size[1],
120 | )
121 | self.num_patches = self.patch_grid[0] * self.patch_grid[1]
122 |
123 | # absolute position embedding
124 | if self.ape:
125 | self.absolute_pos_embed = tf.Variable(
126 | tf.zeros((1, self.num_patches, self.embed_dim)),
127 | trainable=True,
128 | name="absolute_pos_embed",
129 | )
130 | else:
131 | self.absolute_pos_embed = None
132 | self.pos_drop = L.Dropout(drop_rate)
133 |
134 | # build layers
135 | if not isinstance(self.embed_dim, (tuple, list)):
136 | self.embed_dim = [
137 | int(self.embed_dim * 2 ** i) for i in range(self.num_layers)
138 | ]
139 | embed_out_dim = self.embed_dim[1:] + [None]
140 | head_dim = to_ntuple(self.num_layers)(head_dim)
141 | window_size = to_ntuple(self.num_layers)(window_size)
142 | mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio)
143 | dpr = [
144 | float(x) for x in tf.linspace(0.0, drop_path_rate, sum(depths))
145 | ] # stochastic depth decay rule
146 |
147 | layers = [
148 | BasicLayer(
149 | dim=self.embed_dim[i],
150 | out_dim=embed_out_dim[i],
151 | input_resolution=(
152 | self.patch_grid[0] // (2 ** i),
153 | self.patch_grid[1] // (2 ** i),
154 | ),
155 | depth=depths[i],
156 | num_heads=num_heads[i],
157 | head_dim=head_dim[i],
158 | window_size=window_size[i],
159 | mlp_ratio=mlp_ratio[i],
160 | qkv_bias=qkv_bias,
161 | drop=drop_rate,
162 | attn_drop=attn_drop_rate,
163 | drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
164 | norm_layer=norm_layer,
165 | downsample=PatchMerging if (i < self.num_layers - 1) else None,
166 | name=f"basic_layer_{i}",
167 | )
168 | for i in range(self.num_layers)
169 | ]
170 | self.swin_layers = layers
171 |
172 | self.norm = norm_layer()
173 |
174 | self.pre_logits = pre_logits
175 | if not self.pre_logits:
176 | self.head = L.Dense(num_classes, name="classification_head")
177 |
178 | def forward_features(self, x):
179 | x = self.projection(x)
180 | if self.absolute_pos_embed is not None:
181 | x = x + self.absolute_pos_embed
182 | x = self.pos_drop(x)
183 |
184 | for swin_layer in self.swin_layers:
185 | x = swin_layer(x)
186 |
187 | x = self.norm(x) # [B, L, C]
188 | return x
189 |
190 | def forward_head(self, x):
191 | if self.global_pool == "avg":
192 | x = tf.reduce_mean(x, axis=1)
193 | return x if self.pre_logits else self.head(x)
194 |
195 | def call(self, x):
196 | x = self.forward_features(x)
197 | x = self.forward_head(x)
198 | return x
199 |
200 | # Thanks to Willi Gierke for this suggestion.
201 | @tf.function(
202 | input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)]
203 | )
204 | def get_attention_scores(
205 | self, x: tf.Tensor
206 | ) -> Dict[str, Dict[str, tf.Tensor]]:
207 | all_attention_scores = {}
208 |
209 | x = self.projection(x)
210 | if self.absolute_pos_embed is not None:
211 | x = x + self.absolute_pos_embed
212 | x = self.pos_drop(x)
213 |
214 | for i, swin_layer in enumerate(self.swin_layers):
215 | x, attention_scores = swin_layer(x, return_attns=True)
216 | all_attention_scores.update({f"swin_stage_{i}": attention_scores})
217 |
218 | return all_attention_scores
219 |
--------------------------------------------------------------------------------
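
A standalone sketch of get_attention_scores, which returns a nested dict keyed
first by stage and then by block. An even window size (with a matching
img_size) is used here, in line with the shifted-window mask construction.

    import tensorflow as tf
    from swins import SwinTransformer

    model = SwinTransformer(img_size=256, window_size=8, pre_logits=True)
    x = tf.random.normal((1, 256, 256, 3))
    scores = model.get_attention_scores(x)
    print(list(scores.keys()))     # ['swin_stage_0', ..., 'swin_stage_3']
    # per-block maps: [num_windows * batch, num_heads, tokens, tokens]
    print(scores["swin_stage_0"]["swin_block_0"].shape)
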
/swin-transformers-tf/test.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from swins import SwinTransformer
4 |
5 | cfg = dict(
6 | patch_size=4,
7 | window_size=7,
8 | embed_dim=128,
9 | depths=(2, 2, 18, 2),
10 | num_heads=(4, 8, 16, 32),
11 | )
12 |
13 | swin_base_patch4_window7_224 = SwinTransformer(
14 | name="swin_base_patch4_window7_224", **cfg
15 | )
16 | print("Model instantiated, attempting predictions...")
17 | random_tensor = tf.random.normal((2, 224, 224, 3))
18 | outputs = swin_base_patch4_window7_224(random_tensor, training=False)
19 |
20 | print(outputs.shape)
21 |
22 | print(swin_base_patch4_window7_224.count_params() / 1e6)
23 |
--------------------------------------------------------------------------------
/swin-transformers-tf/utils/helpers.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | import numpy as np
4 | import tensorflow as tf
5 |
6 |
7 | def conv_transpose(w: np.ndarray) -> np.ndarray:
8 |     """Transpose the weights of a PT conv layer so that it's compatible with TF."""
9 | return w.transpose(2, 3, 1, 0)
10 |
11 |
12 | def modify_tf_block(
13 | tf_component: Union[tf.keras.layers.Layer, tf.Variable, tf.Tensor],
14 | pt_weight: np.ndarray,
15 | pt_bias: np.ndarray = None,
16 | is_attn: bool = False,
17 | ) -> Union[tf.keras.layers.Layer, tf.Variable, tf.Tensor]:
18 | """General utility for modifying PT parameters for TF compatibility.
19 | Applicable for Conv2D, Dense, tf.Variable, and LayerNormalization.
20 | """
21 | pt_weight = (
22 | conv_transpose(pt_weight)
23 | if isinstance(tf_component, tf.keras.layers.Conv2D)
24 | else pt_weight
25 | )
26 | pt_weight = (
27 | pt_weight.transpose()
28 | if isinstance(tf_component, tf.keras.layers.Dense) and not is_attn
29 | else pt_weight
30 | )
31 |
32 | if isinstance(
33 | tf_component, (tf.keras.layers.Dense, tf.keras.layers.Conv2D)
34 | ):
35 | tf_component.kernel.assign(tf.Variable(pt_weight))
36 | if pt_bias is not None:
37 | tf_component.bias.assign(tf.Variable(pt_bias))
38 | elif isinstance(tf_component, tf.keras.layers.LayerNormalization):
39 | tf_component.gamma.assign(tf.Variable(pt_weight))
40 | tf_component.beta.assign(tf.Variable(pt_bias))
41 | elif isinstance(tf_component, (tf.Variable)):
42 | # For regular variables (tf.Variable).
43 | tf_component.assign(tf.Variable(pt_weight))
44 | else:
45 | return tf.convert_to_tensor(pt_weight)
46 |
47 | return tf_component
48 |
--------------------------------------------------------------------------------
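
A standalone sketch of the intended use (run from the swin-transformers-tf
directory): porting a PyTorch Linear, whose weight is stored as [out, in],
into a built Keras Dense, whose kernel is [in, out].

    import numpy as np
    import tensorflow as tf
    from utils.helpers import modify_tf_block

    dense = tf.keras.layers.Dense(4)
    dense.build((None, 3))                               # kernel: [3, 4]
    pt_weight = np.random.randn(4, 3).astype("float32")  # PT layout: [out, in]
    pt_bias = np.random.randn(4).astype("float32")
    dense = modify_tf_block(dense, pt_weight, pt_bias)
    np.testing.assert_allclose(dense.kernel.numpy(), pt_weight.T)
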