├── LICENSE ├── README.md ├── config.py ├── res ├── doc │ ├── figures │ │ └── teaser.png │ └── notes.txt └── eval │ ├── info.txt │ ├── kodak_results.txt │ ├── kodim01_hat.png │ ├── kodim02_hat.png │ ├── kodim03_hat.png │ ├── kodim04_hat.png │ ├── kodim05_hat.png │ ├── kodim06_hat.png │ ├── kodim07_hat.png │ ├── kodim08_hat.png │ ├── kodim09_hat.png │ ├── kodim10_hat.png │ ├── kodim11_hat.png │ ├── kodim12_hat.png │ ├── kodim13_hat.png │ ├── kodim14_hat.png │ ├── kodim15.png │ ├── kodim15_hat.png │ ├── kodim16_hat.png │ ├── kodim17_hat.png │ ├── kodim18_hat.png │ ├── kodim19_hat.png │ ├── kodim20_hat.png │ ├── kodim21_hat.png │ ├── kodim22.png │ ├── kodim22_hat.png │ ├── kodim23.png │ ├── kodim23_hat.png │ └── kodim24_hat.png ├── swin-transformers-tf ├── LICENSE ├── README.md ├── changelog.txt ├── convert.py ├── convert_all_models.py ├── hub_utilities │ ├── README.md │ ├── export_for_hub.py │ └── generate_doc.py ├── in1k-eval │ ├── README.md │ ├── df.ipynb │ ├── eval-swins.ipynb │ ├── swin_224_in1k.csv │ └── swin_384_in1k.csv ├── notebooks │ ├── classification.ipynb │ ├── finetune.ipynb │ ├── ilsvrc2012_wordnet_lemmas.txt │ └── weight-porting.ipynb ├── requirements.txt ├── swins │ ├── __init__.py │ ├── blocks │ │ ├── __init__.py │ │ ├── mlp.py │ │ ├── stage_block.py │ │ ├── swin_transformer_block.py │ │ └── utils.py │ ├── layers │ │ ├── __init__.py │ │ ├── patch_merging.py │ │ ├── patch_splitting.py │ │ ├── sd.py │ │ └── window_attn.py │ ├── model_configs.py │ └── models.py ├── test.py └── utils │ └── helpers.py └── zyc2022.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SwinT-ChARM (TensorFlow 2) 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1w42soP9daxok4f6jvh3nen1DX0z4KSRb?usp=sharing) 4 | 5 | This repository provides a TensorFlow implementation of SwinT-ChARM based on: 6 | 7 | - [Transformer-Based Transform Coding (ICLR 2022)](https://openreview.net/pdf?id=IDwN6xjHnK8), 8 | - [Channel-wise Autoregressive Entropy Models For Learned Image Compression (ICIP 2020)](https://arxiv.org/pdf/2007.08739.pdf). 9 | 10 | ![SwinT-ChARM net arch](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/doc/figures/teaser.png) 11 | 12 | [Source](https://openreview.net/pdf?id=IDwN6xjHnK8) 13 | 14 | 15 | ## Updates 16 | 17 | ***10/06/2023*** 18 | 1. LIC-TCM (TensorFlow 2) is now available: https://github.com/Nikolai10/LIC-TCM (Liu et al. CVPR 2023 Highlight). 19 | 20 | ***09/06/2023*** 21 | 1. The high quality of this reimplementation has been confirmed in [EGIC, Section A.8.](https://arxiv.org/pdf/2309.03244v1.pdf). 22 | 23 | ***10/09/2022*** 24 | 25 | 1. The number of model parameters now corresponds exactly to the reported number (32.6 million). We thank the authors for providing us with the official DeepSpeed log files. 26 | 2. SwinT-ChARM now supports compression at different input resolutions (multiples of 256). 27 | 3. We release a pre-trained model as proof of functional correctness. 28 | 29 | ***08/17/2022*** 30 | 31 | 1. Initial release of this project (see branch *release_08/17/2022*) 32 | 33 | ## Acknowledgment 34 | This project is based on: 35 | 36 | - [TensorFlow Compression (TFC)](https://github.com/tensorflow/compression), a TF library dedicated to data compression. 37 | - [swin-transformers-tf](https://github.com/sayakpaul/swin-transformers-tf), an unofficial implementation of [Swin-Transformer](https://github.com/microsoft/Swin-Transformer). Functional correctness has been [proven](https://github.com/microsoft/Swin-Transformer/pull/206). 38 | 39 | Note that this repository builds upon the official TF implementation of [Minnen et al.](https://github.com/tensorflow/compression/blob/master/models/ms2020.py), while Zhu et al. base their work on an 40 | unknown (possibly not publicly available) PyTorch reimplementation.
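For reference, the per-image metrics reported in the examples below follow the TFC evaluation conventions. A minimal sketch of how they are computed (`x` and `x_hat` are hypothetical `uint8` RGB tensors; bits per pixel is measured on the compressed bitstream, not derivable from the images alone):

```python
import tensorflow as tf

def eval_metrics(x, x_hat):
    # x, x_hat: uint8 RGB tensors of shape (H, W, 3).
    x, x_hat = tf.cast(x, tf.float32), tf.cast(x_hat, tf.float32)
    mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
    psnr = 10.0 * tf.math.log(255.0 ** 2 / mse) / tf.math.log(10.0)
    msssim = tf.squeeze(tf.image.ssim_multiscale(x[None], x_hat[None], 255.0))
    msssim_db = -10.0 * tf.math.log(1.0 - msssim) / tf.math.log(10.0)
    return mse, psnr, msssim, msssim_db

# e.g. kodim22 below: MSE 13.7772 <-> PSNR 36.74 dB; MS-SSIM (dB) = -10*log10(1 - MS-SSIM).
```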
41 | 42 | ## Examples 43 | 44 | The samples below are taken from the [Kodak dataset](http://r0k.us/graphics/kodak/), external to the training set: 45 | 46 | Original | SwinT-ChARM (β = 0.0003) 47 | :-------------------------:|:-------------------------: 48 | ![kodim22.png](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodim22.png) | ![kodim22_hat.png](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodim22_hat.png) 49 | 50 | ```python 51 | Mean squared error: 13.7772 52 | PSNR (dB): 36.74 53 | Multiscale SSIM: 0.9871 54 | Multiscale SSIM (dB): 18.88 55 | Bits per pixel: 0.9890 56 | ``` 57 | 58 | Original | SwinT-ChARM (β = 0.0003) 59 | :-------------------------:|:-------------------------: 60 | ![kodim23.png](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodim23.png) | ![kodim23_hat.png](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodim23_hat.png) 61 | 62 | ```python 63 | Mean squared error: 7.1963 64 | PSNR (dB): 39.56 65 | Multiscale SSIM: 0.9903 66 | Multiscale SSIM (dB): 20.13 67 | Bits per pixel: 0.3953 68 | ``` 69 | 70 | Original | SwinT-ChARM (β = 0.0003) 71 | :-------------------------:|:-------------------------: 72 | ![kodim15.png](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodim15.png) | ![kodim15_hat.png](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodim15_hat.png) 73 | 74 | ```python 75 | Mean squared error: 10.1494 76 | PSNR (dB): 38.07 77 | Multiscale SSIM: 0.9888 78 | Multiscale SSIM (dB): 19.49 79 | Bits per pixel: 0.6525 80 | ``` 81 | More examples can be found [here](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval). 82 | 83 | ## Pretrained Models / Performance (TFC 2.8) 84 | 85 | Our pre-trained model (β = 0.0003) achieves a [PSNR of 37.59 (dB) using an average of 0.93 bpp](https://github.com/Nikolai10/SwinT-ChARM/blob/master/res/eval/kodak_results.txt) on the [Kodak dataset](http://r0k.us/graphics/kodak/), which is very close to the reported numbers (see [paper](https://openreview.net/pdf?id=IDwN6xjHnK8), Figure 3). Worth mentioning: we achieve this result despite training our model from scratch and using less than one-third of the computational resources (1M optimization steps). 86 | 87 | 88 | 89 | | Lagrangian multiplier (β) | SavedModel | Training Instructions | 90 | | ----------- | -------------------------------- | ---------------------- | 91 | | 0.0003 | [download](https://drive.google.com/drive/folders/1bUWowLEgU8ukYejvVPbhU38_7japhp7C?usp=sharing) |
`!python SwinT-ChARM/zyc2022.py -V --model_path <...> train --max_support_slices 10 --lambda 0.0003 --epochs 1000 --batchsize 16 --train_path <...>`
| 92 | 93 | ## File Structure 94 | 95 | res 96 | ├── doc/ # additional resources 97 | ├── eval/ # sample images + reconstructions 98 | ├── train_zyc2022/ # model checkpoints + tf.summaries 99 | ├── zyc2022/ # saved model 100 | swin-transformers-tf/ # extended swin-transformers-tf implementation 101 | ├── changelog.txt # summary of changes made to the original work 102 | ├── ... 103 | config.py # model-dependent configurations 104 | zyc2022.py # core of this repo 105 | 106 | ## License 107 | [Apache License 2.0](LICENSE) 108 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Nikolai Körber. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """https://openreview.net/pdf?id=IDwN6xjHnK8, Appendix A2 17 | 18 | "SwinT-Hyperprior, SwinT-ChARM For both SwinT-Hyperprior and SwinT-ChARM, we use the 19 | same configurations: (wg, wh) = (8, 4), (C1, C2, C3, C4, C5, C6) = (128, 192, 256, 320, 192, 192), 20 | (d1, d2, d3, d4, d5, d6) = (2, 2, 6, 2, 5, 1) where C, d, and w are defined in Figure 13 and Figure 2. 21 | The head dim is 32 for all attention layers in SwinT-based models." 22 | 23 | """ 24 | 25 | class ConfigGa: 26 | embed_dim = [128, 192, 256, 320] 27 | embed_out_dim = [192, 256, 320, None] 28 | depths = [2, 2, 6, 2] 29 | head_dim = [32, 32, 32, 32] 30 | window_size = [8, 8, 8, 8] 31 | num_layers = len(depths) 32 | 33 | class ConfigHa: 34 | embed_dim = [192, 192] 35 | embed_out_dim = [192, None] 36 | depths = [5, 1] 37 | head_dim = [32, 32] 38 | window_size = [4, 4] 39 | num_layers = len(depths) 40 | 41 | class ConfigHs: 42 | embed_dim = [192, 192] 43 | embed_out_dim = [192, int(2*320)] 44 | depths = [1, 5] 45 | head_dim = [32, 32] 46 | window_size = [4, 4] 47 | num_layers = len(depths) 48 | 49 | class ConfigGs: 50 | embed_dim = [320, 256, 192, 128] 51 | embed_out_dim = [256, 192, 128, 3] 52 | depths = [2, 6, 2, 2] 53 | head_dim = [32, 32, 32, 32] 54 | window_size = [8, 8, 8, 8] 55 | num_layers = len(depths) 56 | 57 | """https://arxiv.org/pdf/2007.08739.pdf, Appendix A 58 | 59 | "To account for the different input depths, each CC [...] transform is 60 | programmatically defined to linearly interpolate between the input and the 61 | output depth." 62 | 63 | Note: In SwinT-ChARM a slightly different logic is used, which is why the depth 64 | is explicitly hardcoded here. 65 | 66 | For example, the tenth slice should have depths: 224, 128 and 32. 67 | 68 | import pandas as pd 69 | import numpy as np 70 | a=pd.Series([320, np.nan, np.nan, 32]) 71 | 72 | a.interpolate(method='linear') 73 | 74 | vs. 234, 117, 32 (taken from the official DeepSpeed log file, 75 | which was provided by the authors).
76 | 77 | """ 78 | 79 | class ConfigChARM: 80 | depths_conv0 = [64, 64, 85, 106, 128, 149, 170, 192, 213, 234] 81 | depths_conv1 = [32, 32, 42, 53, 64, 74, 85, 96, 106, 117] -------------------------------------------------------------------------------- /res/doc/figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/doc/figures/teaser.png -------------------------------------------------------------------------------- /res/doc/notes.txt: -------------------------------------------------------------------------------- 1 | ## Technical Differences 2 | 3 | IMPORTANT: only affects branch release_08/17/2022 (the differences have been resolved!) 4 | 5 | In comparison to Zhu et al.: 6 | 7 | - we keep LRP (lrp_transforms) -> no motivation given for removal. 8 | - for CC, LRP we keep the approach described in Minnen et al., Appendix A; i.e. 9 | the input to slice_1 is 320, input to slice_2 is 352 etc. rather than 32, 64, 10 | ... as described in Zhu et al. Appendix A, Figure 12. 11 | 12 | Further technical deviations are possible, as we base our work on the official 13 | TF implementation of Minnen et al., while Zhu et al. base their work on an 14 | unknown (possibly not publicly available) PyTorch reimplementation. -------------------------------------------------------------------------------- /res/eval/info.txt: -------------------------------------------------------------------------------- 1 | Kodak 2 | The Kodak data set is a collection of 24 images with resolution 768x512 (or 512x768). 3 | The images are available as PNG files here: http://r0k.us/graphics/kodak 4 | 5 | @misc{kodak, 6 | title = "Kodak Lossless True Color Image Suite ({PhotoCD PCD0992})", 7 | author = "Eastman Kodak", 8 | url = "http://r0k.us/graphics/kodak", 9 | } -------------------------------------------------------------------------------- /res/eval/kodak_results.txt: -------------------------------------------------------------------------------- 1 | The following values were generated by the official evaluation scripts provided by TFC (Minnen): 2 | https://github.com/tensorflow/compression/tree/master/results/image_compression 3 | 4 | SwinT-ChARM (beta=0.0003) 5 | 6 | ------------------------------ 7 | kodim1 8 | ------------------------------ 9 | Mean squared error: 14.7817 10 | PSNR (dB): 36.43 11 | Multiscale SSIM: 0.9941 12 | Multiscale SSIM (dB): 22.31 13 | Bits per pixel: 1.5605 14 | 15 | ------------------------------ 16 | kodim2 17 | ------------------------------ 18 | Mean squared error: 11.9409 19 | PSNR (dB): 37.36 20 | Multiscale SSIM: 0.9828 21 | Multiscale SSIM (dB): 17.63 22 | Bits per pixel: 0.6922 23 | 24 | ------------------------------ 25 | kodim3 26 | ------------------------------ 27 | Mean squared error: 6.6958 28 | PSNR (dB): 39.87 29 | Multiscale SSIM: 0.9911 30 | Multiscale SSIM (dB): 20.49 31 | Bits per pixel: 0.4716 32 | 33 | ------------------------------ 34 | kodim4 35 | ------------------------------ 36 | Mean squared error: 10.5122 37 | PSNR (dB): 37.91 38 | Multiscale SSIM: 0.9878 39 | Multiscale SSIM (dB): 19.15 40 | Bits per pixel: 0.7019 41 | 42 | ------------------------------ 43 | kodim5 44 | ------------------------------ 45 | Mean squared error: 14.5078 46 | PSNR (dB): 36.51 47 | Multiscale SSIM: 0.9950 48 | Multiscale SSIM (dB): 22.98 49 | Bits per pixel: 1.4607 50 | 51 | ------------------------------ 52 | kodim6 53 |
------------------------------ 54 | Mean squared error: 11.9955 55 | PSNR (dB): 37.34 56 | Multiscale SSIM: 0.9911 57 | Multiscale SSIM (dB): 20.48 58 | Bits per pixel: 1.1177 59 | 60 | ------------------------------ 61 | kodim7 62 | ------------------------------ 63 | Mean squared error: 7.3828 64 | PSNR (dB): 39.45 65 | Multiscale SSIM: 0.9943 66 | Multiscale SSIM (dB): 22.42 67 | Bits per pixel: 0.5966 68 | 69 | ------------------------------ 70 | kodim8 71 | ------------------------------ 72 | Mean squared error: 17.2197 73 | PSNR (dB): 35.77 74 | Multiscale SSIM: 0.9944 75 | Multiscale SSIM (dB): 22.54 76 | Bits per pixel: 1.5817 77 | 78 | ------------------------------ 79 | kodim9 80 | ------------------------------ 81 | Mean squared error: 8.4156 82 | PSNR (dB): 38.88 83 | Multiscale SSIM: 0.9896 84 | Multiscale SSIM (dB): 19.83 85 | Bits per pixel: 0.4888 86 | 87 | ------------------------------ 88 | kodim10 89 | ------------------------------ 90 | Mean squared error: 8.8622 91 | PSNR (dB): 38.66 92 | Multiscale SSIM: 0.9895 93 | Multiscale SSIM (dB): 19.78 94 | Bits per pixel: 0.5364 95 | 96 | ------------------------------ 97 | kodim11 98 | ------------------------------ 99 | Mean squared error: 12.1560 100 | PSNR (dB): 37.28 101 | Multiscale SSIM: 0.9900 102 | Multiscale SSIM (dB): 20.00 103 | Bits per pixel: 0.9873 104 | 105 | ------------------------------ 106 | kodim12 107 | ------------------------------ 108 | Mean squared error: 8.8393 109 | PSNR (dB): 38.67 110 | Multiscale SSIM: 0.9868 111 | Multiscale SSIM (dB): 18.79 112 | Bits per pixel: 0.5531 113 | 114 | ------------------------------ 115 | kodim13 116 | ------------------------------ 117 | Mean squared error: 21.1706 118 | PSNR (dB): 34.87 119 | Multiscale SSIM: 0.9932 120 | Multiscale SSIM (dB): 21.68 121 | Bits per pixel: 2.0560 122 | 123 | ------------------------------ 124 | kodim14 125 | ------------------------------ 126 | Mean squared error: 14.9357 127 | PSNR (dB): 36.39 128 | Multiscale SSIM: 0.9919 129 | Multiscale SSIM (dB): 20.90 130 | Bits per pixel: 1.2558 131 | 132 | ------------------------------ 133 | kodim15 134 | ------------------------------ 135 | Mean squared error: 10.1494 136 | PSNR (dB): 38.07 137 | Multiscale SSIM: 0.9888 138 | Multiscale SSIM (dB): 19.49 139 | Bits per pixel: 0.6525 140 | 141 | ------------------------------ 142 | kodim16 143 | ------------------------------ 144 | Mean squared error: 9.3178 145 | PSNR (dB): 38.44 146 | Multiscale SSIM: 0.9902 147 | Multiscale SSIM (dB): 20.09 148 | Bits per pixel: 0.7519 149 | 150 | ------------------------------ 151 | kodim17 152 | ------------------------------ 153 | Mean squared error: 9.6518 154 | PSNR (dB): 38.28 155 | Multiscale SSIM: 0.9910 156 | Multiscale SSIM (dB): 20.44 157 | Bits per pixel: 0.6467 158 | 159 | ------------------------------ 160 | kodim18 161 | ------------------------------ 162 | Mean squared error: 17.0509 163 | PSNR (dB): 35.81 164 | Multiscale SSIM: 0.9891 165 | Multiscale SSIM (dB): 19.61 166 | Bits per pixel: 1.2358 167 | 168 | ------------------------------ 169 | kodim19 170 | ------------------------------ 171 | Mean squared error: 11.6530 172 | PSNR (dB): 37.47 173 | Multiscale SSIM: 0.9891 174 | Multiscale SSIM (dB): 19.63 175 | Bits per pixel: 0.8815 176 | 177 | ------------------------------ 178 | kodim20 179 | ------------------------------ 180 | Mean squared error: 8.7219 181 | PSNR (dB): 38.72 182 | Multiscale SSIM: 0.9906 183 | Multiscale SSIM (dB): 20.28 184 | Bits per pixel: 0.5912 185 | 186 | 
------------------------------ 187 | kodim21 188 | ------------------------------ 189 | Mean squared error: 11.8035 190 | PSNR (dB): 37.41 191 | Multiscale SSIM: 0.9899 192 | Multiscale SSIM (dB): 19.94 193 | Bits per pixel: 0.9675 194 | 195 | ------------------------------ 196 | kodim22 197 | ------------------------------ 198 | Mean squared error: 13.7772 199 | PSNR (dB): 36.74 200 | Multiscale SSIM: 0.9871 201 | Multiscale SSIM (dB): 18.88 202 | Bits per pixel: 0.9890 203 | 204 | ------------------------------ 205 | kodim23 206 | ------------------------------ 207 | Mean squared error: 7.1963 208 | PSNR (dB): 39.56 209 | Multiscale SSIM: 0.9903 210 | Multiscale SSIM (dB): 20.13 211 | Bits per pixel: 0.3953 212 | 213 | ------------------------------ 214 | kodim24 215 | ------------------------------ 216 | Mean squared error: 15.0891 217 | PSNR (dB): 36.34 218 | Multiscale SSIM: 0.9928 219 | Multiscale SSIM (dB): 21.42 220 | Bits per pixel: 1.1902 221 | 222 | 223 | STATS (mean across kodak) 224 | 225 | bpps = [1.5605, 0.6922, 0.4716, 0.7019, 1.4607, 1.1177, 0.5966, 1.5817, 0.4888, 0.5364, 0.9873, 0.5531, 2.0560, 1.2558, 0.6525, 0.7519, 0.6467, 1.2358, 0.8815, 0.5912, 0.9675, 0.9890, 0.3953, 1.1902] 226 | np.mean(bpps) -> 0.9317458333333334 227 | 228 | psnrs = [36.43, 37.36, 39.87, 37.91, 36.51, 37.34, 39.45, 35.77, 38.88, 38.66, 37.28, 38.67, 34.87, 36.39, 38.07, 38.44, 38.28, 35.81, 37.47, 38.72, 37.41, 36.74, 39.56, 36.34] 229 | np.mean(psnrs) -> 37.59291666666667 -------------------------------------------------------------------------------- /res/eval/kodim01_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim01_hat.png -------------------------------------------------------------------------------- /res/eval/kodim02_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim02_hat.png -------------------------------------------------------------------------------- /res/eval/kodim03_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim03_hat.png -------------------------------------------------------------------------------- /res/eval/kodim04_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim04_hat.png -------------------------------------------------------------------------------- /res/eval/kodim05_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim05_hat.png -------------------------------------------------------------------------------- /res/eval/kodim06_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim06_hat.png -------------------------------------------------------------------------------- /res/eval/kodim07_hat.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim07_hat.png -------------------------------------------------------------------------------- /res/eval/kodim08_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim08_hat.png -------------------------------------------------------------------------------- /res/eval/kodim09_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim09_hat.png -------------------------------------------------------------------------------- /res/eval/kodim10_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim10_hat.png -------------------------------------------------------------------------------- /res/eval/kodim11_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim11_hat.png -------------------------------------------------------------------------------- /res/eval/kodim12_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim12_hat.png -------------------------------------------------------------------------------- /res/eval/kodim13_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim13_hat.png -------------------------------------------------------------------------------- /res/eval/kodim14_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim14_hat.png -------------------------------------------------------------------------------- /res/eval/kodim15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim15.png -------------------------------------------------------------------------------- /res/eval/kodim15_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim15_hat.png -------------------------------------------------------------------------------- /res/eval/kodim16_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim16_hat.png -------------------------------------------------------------------------------- /res/eval/kodim17_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim17_hat.png 
-------------------------------------------------------------------------------- /res/eval/kodim18_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim18_hat.png -------------------------------------------------------------------------------- /res/eval/kodim19_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim19_hat.png -------------------------------------------------------------------------------- /res/eval/kodim20_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim20_hat.png -------------------------------------------------------------------------------- /res/eval/kodim21_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim21_hat.png -------------------------------------------------------------------------------- /res/eval/kodim22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim22.png -------------------------------------------------------------------------------- /res/eval/kodim22_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim22_hat.png -------------------------------------------------------------------------------- /res/eval/kodim23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim23.png -------------------------------------------------------------------------------- /res/eval/kodim23_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim23_hat.png -------------------------------------------------------------------------------- /res/eval/kodim24_hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nikolai10/SwinT-ChARM/fe989c1c75a4b8513b989465bca2f9c8a2606651/res/eval/kodim24_hat.png -------------------------------------------------------------------------------- /swin-transformers-tf/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /swin-transformers-tf/README.md: -------------------------------------------------------------------------------- 1 | # Swin for win! 2 | 3 | [![TensorFlow 2.8](https://img.shields.io/badge/TensorFlow-2.8-FF6F00?logo=tensorflow)](https://github.com/tensorflow/tensorflow/releases/tag/v2.8.0) 4 | [![Models on TF-Hub](https://img.shields.io/badge/TF--Hub-Models%20on%20TF--Hub-orange)](https://tfhub.dev/sayakpaul/collections/swin/1) 5 | 6 | This repository provides TensorFlow / Keras implementations of different Swin Transformer 7 | [1, 2] variants by Liu et al. and Chen et al. It also provides the TensorFlow / Keras models 8 | that have been populated with the original Swin pre-trained params available from [3, 4]. These 9 | models are not blackbox SavedModels, i.e., they can be fully expanded into `tf.keras.Model` 10 | objects and one can call all the utility functions on them (example: `.summary()`). 11 | 12 | Refer to the ["Using the models"](https://github.com/sayakpaul/swin-transformers-tf#using-the-models) 13 | section to get started. 14 | 15 | I find Swin Transformers interesting because they induce a sense of hierarchies by using ***s***hifted ***win***dows. Multi-scale 16 | representations like that are crucial to get good performance in tasks like object detection and segmentation. 17 | ![teaser](https://github.com/microsoft/Swin-Transformer/raw/main/figures/teaser.png) 18 | [Source](https://github.com/microsoft/Swin-Transformer) 19 | 20 | "Swin for win!" however doesn't portray my architecture bias -- I found it cool and hence kept it. 21 | 22 | ## Table of contents 23 | 24 | * [Conversion](https://github.com/sayakpaul/swin-transformers-tf#conversion) 25 | * [Collection of pre-trained models (converted from PyTorch to TensorFlow)](https://github.com/sayakpaul/swin-transformers-tf#models) 26 | * [Results of the converted models](https://github.com/sayakpaul/swin-transformers-tf#results) 27 | * [How to use the models?](https://github.com/sayakpaul/swin-transformers-tf#using-the-models) 28 | * [References](https://github.com/sayakpaul/swin-transformers-tf#references) 29 | * [Acknowledgements](https://github.com/sayakpaul/swin-transformers-tf#acknowledgements) 30 | 31 | ## Conversion 32 | 33 | TensorFlow / Keras implementations are available in `swins/models.py`. All model configurations 34 | are in `swins/model_configs.py`. Conversion utilities are in `convert.py`. To run the conversion 35 | utilities, first install all the dependencies listed in `requirements.txt`. Additionally, 36 | install `timm` from source: 37 | 38 | ```sh 39 | pip install -q git+https://github.com/rwightman/pytorch-image-models 40 | ``` 41 | 42 | ## Models 43 | 44 | Find the models on TF-Hub here: https://tfhub.dev/sayakpaul/collections/swin/1.
You can fully inspect the 45 | architecture of the TF-Hub models like so: 46 | 47 | ```py 48 | import tensorflow as tf 49 | 50 | model_gcs_path = "gs://tfhub-modules/sayakpaul/swin_tiny_patch4_window7_224/1/uncompressed" 51 | model = tf.keras.models.load_model(model_gcs_path) 52 | 53 | dummy_inputs = tf.ones((2, 224, 224, 3)) 54 | _ = model(dummy_inputs) 55 | print(model.summary(expand_nested=True)) 56 | ``` 57 | 58 | ## Results 59 | 60 | The table below provides a performance summary (ImageNet-1k validation set): 61 | 62 | | model_name | top1_acc(%) | top5_acc(%) | orig_top1_acc(%) | 63 | |:------------------------------:|:-------------:|:-------------:|:------------------:| 64 | | swin_base_patch4_window7_224 | 85.134 | 97.48 | 85.2 | 65 | | swin_large_patch4_window7_224 | 86.252 | 97.878 | 86.3 | 66 | | swin_s3_base_224 | 83.958 | 96.532 | 84 | 67 | | swin_s3_small_224 | 83.648 | 96.358 | 83.7 | 68 | | swin_s3_tiny_224 | 82.034 | 95.864 | 82.1 | 69 | | swin_small_patch4_window7_224 | 83.178 | 96.24 | 83.2 | 70 | | swin_tiny_patch4_window7_224 | 81.184 | 95.512 | 81.2 | 71 | | swin_base_patch4_window12_384 | 86.428 | 98.042 | 86.4 | 72 | | swin_large_patch4_window12_384 | 87.272 | 98.242 | 87.3 | 73 | 74 | 75 | The `base` and `large` models were first pre-trained on the ImageNet-22k dataset and then fine-tuned 76 | on the ImageNet-1k dataset. 77 | 78 | [`in1k-eval` directory](https://github.com/sayakpaul/swin-transformers-tf/tree/main/in1k-eval) provides details 79 | on how these numbers were generated. Original scores for all the models except for the `s3` ones were 80 | gathered from [here](https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md). Scores 81 | for the `s3` models were gathered from [here](https://github.com/microsoft/Cream/tree/main/AutoFormerV2#model-zoo). 82 | 83 | ## Using the models 84 | 85 | **Pre-trained models**: 86 | 87 | * Off-the-shelf classification: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/swin-transformers-tf/blob/main/notebooks/classification.ipynb) 88 | * Fine-tuning: [Colab Notebook](https://colab.research.google.com/github/sayakpaul/swin-transformers-tf/blob/main/notebooks/finetune.ipynb) 89 | 90 | When doing transfer learning, try using the models that were pre-trained on the ImageNet-22k dataset. All the 91 | `base` and `large` models listed here were pre-trained on the ImageNet-22k dataset. Refer to the 92 | [model collection page on TF-Hub](https://tfhub.dev/sayakpaul/collections/swin/1) to know more. 93 | 94 | These models also output attention weights from each of the Transformer blocks. 95 | Refer to [this notebook](https://colab.research.google.com/github/sayakpaul/swin-transformers-tf/blob/main/notebooks/classification.ipynb) 96 | for more details. Additionally, the notebook shows how to obtain the attention maps for a given image.
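For a quick off-the-shelf classification run outside the notebooks, here is a minimal sketch reusing the GCS path from above. The plain `[0, 1]` input scaling and the `cat.jpg` file are assumptions for illustration; see the classification notebook for the exact preprocessing pipeline:

```py
import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/swin_tiny_patch4_window7_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)

# Hypothetical input image; preprocessing here is a simplification.
image = tf.io.decode_image(tf.io.read_file("cat.jpg"), channels=3)
image = tf.image.resize(tf.cast(image, tf.float32) / 255.0, (224, 224))

logits = model(image[None], training=False)  # ImageNet-1k logits, shape (1, 1000)
print(tf.argmax(logits, axis=-1))
```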
97 | 98 | 99 | **Randomly initialized models**: 100 | 101 | ```py 102 | import tensorflow as tf 103 | 104 | from swins import SwinTransformer 105 | 106 | cfg = dict( 107 | patch_size=4, 108 | window_size=7, 109 | embed_dim=128, 110 | depths=(2, 2, 18, 2), 111 | num_heads=(4, 8, 16, 32), 112 | ) 113 | 114 | swin_base_patch4_window7_224 = SwinTransformer( 115 | name="swin_base_patch4_window7_224", **cfg 116 | ) 117 | print("Model instantiated, attempting predictions...") 118 | random_tensor = tf.random.normal((2, 224, 224, 3)) 119 | outputs = swin_base_patch4_window7_224(random_tensor, training=False) 120 | 121 | print(outputs.shape) 122 | 123 | print(swin_base_patch4_window7_224.count_params() / 1e6) 124 | ``` 125 | 126 | To initialize a network with, say, 5 classes, do: 127 | 128 | ```py 129 | cfg = dict( 130 | patch_size=4, 131 | window_size=7, 132 | embed_dim=128, 133 | depths=(2, 2, 18, 2), 134 | num_heads=(4, 8, 16, 32), 135 | num_classes=5, 136 | ) 137 | 138 | swin_base_patch4_window7_224 = SwinTransformer( 139 | name="swin_base_patch4_window7_224", **cfg 140 | ) 141 | ``` 142 | 143 | To view different model configurations, refer to `swins/model_configs.py`. 144 | 145 | ## References 146 | 147 | [1] [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows by Liu et al.](https://arxiv.org/abs/2103.14030) 148 | 149 | [2] [Searching the Search Space of Vision Transformer by Chen et al.](https://arxiv.org/abs/2111.14725) 150 | 151 | [3] [Swin Transformers GitHub](https://github.com/microsoft/Swin-Transformer) 152 | 153 | [4] [AutoFormerV2 GitHub](https://github.com/silent-chen/AutoFormerV2-model-zoo) 154 | 155 | ## Acknowledgements 156 | 157 | * [`timm` library source code](https://github.com/rwightman/pytorch-image-models) 158 | for the awesome codebase. I've copy-pasted and modified a huge chunk of code from there. 159 | I've also mentioned it in the respective scripts. 160 | * [Willi Gierke](https://ch.linkedin.com/in/willi-gierke) for helping with a non-trivial model serialization hack. 161 | * [ML-GDE program](https://developers.google.com/programs/experts/) for 162 | providing GCP credits that supported my experiments. 163 | -------------------------------------------------------------------------------- /swin-transformers-tf/changelog.txt: -------------------------------------------------------------------------------- 1 | Changelog to https://github.com/sayakpaul/swin-transformers-tf 2 | 3 | - swin-transformers-tf/swins/layers: patch_splitting.py added 4 | - swin-transformers-tf/swins/layers: patch_merging.py modified 5 | - swin-transformers-tf/swins/blocks: stage_block.py modified 6 | - swin-transformers-tf/swins/blocks: swin_transformer_block.py modified 7 | - swin-transformers-tf/swins/layers: __init__.py modified 8 | - swin-transformers-tf/swins/layers: window_attn.py modified 9 | 10 | General comments: 11 | - enabled support of arbitrary input shapes (multiples of 256). 12 | - replaced numpy-based attention mask calculation with a native TensorFlow equivalent.
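The last changelog item refers to building the shifted-window attention mask with TensorFlow ops instead of numpy, so the computation also works for varying input sizes. A minimal sketch of that idea (not the repository's exact code; `-100.0` is the usual additive masking constant):

```py
import tensorflow as tf

def shifted_window_attn_mask(H, W, window_size, shift_size):
    # Region ids along one axis: 0 for the bulk, 1 and 2 for the shifted bands.
    def axis_ids(length):
        return tf.concat([
            tf.zeros(length - window_size, tf.int32),
            tf.ones(window_size - shift_size, tf.int32),
            tf.fill([shift_size], 2),
        ], axis=0)

    # Unique region id per pixel (3 x 3 regions, as in the usual numpy version).
    img_mask = axis_ids(H)[:, None] * 3 + axis_ids(W)[None, :]  # (H, W)

    # Partition into non-overlapping windows.
    m = tf.reshape(img_mask, (H // window_size, window_size, W // window_size, window_size))
    m = tf.transpose(m, (0, 2, 1, 3))
    m = tf.reshape(m, (-1, window_size * window_size))  # (num_windows, ws*ws)

    # Position pairs from different regions must not attend to each other.
    diff = m[:, None, :] - m[:, :, None]
    return tf.where(diff != 0, -100.0, 0.0)  # additive attention bias
```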
--------------------------------------------------------------------------------
/swin-transformers-tf/convert.py:
--------------------------------------------------------------------------------
1 | """
2 | Some code is copied from here:
3 | https://github.com/sayakpaul/cait-tf/blob/main/convert.py
4 | """
5 |
6 | import argparse
7 | import os
8 | import sys
9 | from typing import Dict, List
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | import timm
14 |
15 | sys.path.append("..")
16 |
17 |
18 | from swins import SwinTransformer, model_configs
19 | from swins.blocks import *
20 | from swins.layers import *
21 | from utils import helpers
22 |
23 | TF_MODEL_ROOT = "gs://swin-tf"
24 |
25 | NUM_CLASSES = {"in1k": 1000, "in21k": 21841}
26 |
27 |
28 | def parse_args():
29 |     parser = argparse.ArgumentParser(
30 |         description="Conversion of the PyTorch pre-trained Swin weights to TensorFlow."
31 |     )
32 |     parser.add_argument(
33 |         "-m",
34 |         "--model-name",
35 |         default="swin_tiny_patch4_window7_224",
36 |         type=str,
37 |         choices=model_configs.MODEL_MAP.keys(),
38 |         help="Name of the Swin model variant.",
39 |     )
40 |     parser.add_argument(
41 |         "-d",
42 |         "--dataset",
43 |         default="in1k",
44 |         choices=["in1k", "in21k"],
45 |         type=str,
46 |     )
47 |     parser.add_argument(
48 |         "-pl",
49 |         "--pre-logits",
50 |         action="store_true",
51 |         help="Export the model without the classification outputs (feature extraction).",
52 |     )
53 |     return parser.parse_args()
54 |
55 |
56 | def modify_swin_blocks(
57 |     np_state_dict: Dict[str, np.ndarray],
58 |     pt_weights_prefix: str,
59 |     tf_block: List[tf.keras.layers.Layer],
60 | ) -> List[tf.keras.layers.Layer]:
61 |     """Main utility to convert the params of a Swin block."""
62 |     # Patch merging.
63 |     for layer in tf_block:
64 |         if isinstance(layer, PatchMerging):
65 |             patch_merging_idx = f"{pt_weights_prefix}.downsample"
66 |
67 |             layer.reduction = helpers.modify_tf_block(
68 |                 layer.reduction,
69 |                 np_state_dict[f"{patch_merging_idx}.reduction.weight"],
70 |             )
71 |             layer.norm = helpers.modify_tf_block(
72 |                 layer.norm,
73 |                 np_state_dict[f"{patch_merging_idx}.norm.weight"],
74 |                 np_state_dict[f"{patch_merging_idx}.norm.bias"],
75 |             )
76 |
77 |     # Swin layers.
78 |     common_prefix = f"{pt_weights_prefix}.blocks"
79 |     block_idx = 0
80 |
81 |     for outer_layer in tf_block:
82 |
83 |         layernorm_idx = 1
84 |         mlp_layer_idx = 1
85 |
86 |         if isinstance(outer_layer, SwinTransformerBlock):
87 |             for inner_layer in outer_layer.layers:
88 |
89 |                 # Layer norm.
90 |                 if isinstance(inner_layer, tf.keras.layers.LayerNormalization):
91 |                     layer_norm_prefix = (
92 |                         f"{common_prefix}.{block_idx}.norm{layernorm_idx}"
93 |                     )
94 |                     inner_layer.gamma.assign(
95 |                         tf.Variable(
96 |                             np_state_dict[f"{layer_norm_prefix}.weight"]
97 |                         )
98 |                     )
99 |                     inner_layer.beta.assign(
100 |                         tf.Variable(np_state_dict[f"{layer_norm_prefix}.bias"])
101 |                     )
102 |                     layernorm_idx += 1
103 |
104 |                 # Window attention.
105 |                 elif isinstance(inner_layer, WindowAttention):
106 |                     attn_prefix = f"{common_prefix}.{block_idx}.attn"
107 |
108 |                     # Relative position.
109 |                     inner_layer.relative_position_bias_table = (
110 |                         helpers.modify_tf_block(
111 |                             inner_layer.relative_position_bias_table,
112 |                             np_state_dict[
113 |                                 f"{attn_prefix}.relative_position_bias_table"
114 |                             ],
115 |                         )
116 |                     )
117 |                     inner_layer.relative_position_index = (
118 |                         helpers.modify_tf_block(
119 |                             inner_layer.relative_position_index,
120 |                             np_state_dict[
121 |                                 f"{attn_prefix}.relative_position_index"
122 |                             ],
123 |                         )
124 |                     )
125 |
126 |                     # QKV.
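                    # Note: `helpers.modify_tf_block` is assumed to handle the
                    # layout differences here, e.g. transposing PyTorch's
                    # (out_features, in_features) Dense kernels into TF's
                    # (in_features, out_features) layout before assignment.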
127 |                     inner_layer.qkv = helpers.modify_tf_block(
128 |                         inner_layer.qkv,
129 |                         np_state_dict[f"{attn_prefix}.qkv.weight"],
130 |                         np_state_dict[f"{attn_prefix}.qkv.bias"],
131 |                     )
132 |
133 |                     # Projection.
134 |                     inner_layer.proj = helpers.modify_tf_block(
135 |                         inner_layer.proj,
136 |                         np_state_dict[f"{attn_prefix}.proj.weight"],
137 |                         np_state_dict[f"{attn_prefix}.proj.bias"],
138 |                     )
139 |
140 |                 # MLP.
141 |                 elif isinstance(inner_layer, tf.keras.Model):
142 |                     mlp_prefix = f"{common_prefix}.{block_idx}.mlp"
143 |                     for mlp_layer in inner_layer.layers:
144 |                         if isinstance(mlp_layer, tf.keras.layers.Dense):
145 |                             mlp_layer = helpers.modify_tf_block(
146 |                                 mlp_layer,
147 |                                 np_state_dict[
148 |                                     f"{mlp_prefix}.fc{mlp_layer_idx}.weight"
149 |                                 ],
150 |                                 np_state_dict[
151 |                                     f"{mlp_prefix}.fc{mlp_layer_idx}.bias"
152 |                                 ],
153 |                             )
154 |                             mlp_layer_idx += 1
155 |
156 |         block_idx += 1
157 |     return tf_block
158 |
159 |
160 | def main(args):
161 |     if args.pre_logits:
162 |         print(f"Converting {args.model_name} for feature extraction...")
163 |     else:
164 |         print(f"Converting {args.model_name}...")
165 |
166 |     print("Instantiating PyTorch model...")
167 |     pt_model = timm.create_model(model_name=args.model_name, pretrained=True)
168 |     pt_model.eval()
169 |
170 |     print("Instantiating TF model...")
171 |     cfg_method = model_configs.MODEL_MAP[args.model_name]
172 |     cfg = cfg_method()
173 |     tf_model = SwinTransformer(**cfg, pre_logits=args.pre_logits)
174 |
175 |     image_size = cfg.get("img_size", 224)
176 |     dummy_inputs = tf.ones((2, image_size, image_size, 3))
177 |     _ = tf_model(dummy_inputs)
178 |
179 |     if not args.pre_logits:
180 |         assert tf_model.count_params() == sum(
181 |             p.numel() for p in pt_model.parameters()
182 |         )
183 |
184 |     # Load the PT params.
185 |     np_state_dict = pt_model.state_dict()
186 |     np_state_dict = {k: np_state_dict[k].numpy() for k in np_state_dict}
187 |
188 |     print("Beginning parameter porting process...")
189 |
190 |     # Projection.
191 |     tf_model.projection.layers[0] = helpers.modify_tf_block(
192 |         tf_model.projection.layers[0],
193 |         np_state_dict["patch_embed.proj.weight"],
194 |         np_state_dict["patch_embed.proj.bias"],
195 |     )
196 |     tf_model.projection.layers[2] = helpers.modify_tf_block(
197 |         tf_model.projection.layers[2],
198 |         np_state_dict["patch_embed.norm.weight"],
199 |         np_state_dict["patch_embed.norm.bias"],
200 |     )
201 |
202 |     # Layer norm layers.
203 |     ln_idx = -2
204 |     tf_model.layers[ln_idx] = helpers.modify_tf_block(
205 |         tf_model.layers[ln_idx],
206 |         np_state_dict["norm.weight"],
207 |         np_state_dict["norm.bias"],
208 |     )
209 |
210 |     # Head layers.
211 |     if not args.pre_logits:
212 |         head_layer = tf_model.get_layer("classification_head")
213 |         tf_model.layers[-1] = helpers.modify_tf_block(
214 |             head_layer,
215 |             np_state_dict["head.weight"],
216 |             np_state_dict["head.bias"],
217 |         )
218 |
219 |     # Swin layers.
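    # Indexing note: `tf_model.layers[i + 2]` below assumes that the first two
    # layers of the Keras model come before the Swin stages (e.g. the
    # patch-embedding projection), so PyTorch stage `i` maps to Keras layer `i + 2`.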
220 |     for i in range(len(cfg["depths"])):
221 |         _ = modify_swin_blocks(
222 |             np_state_dict,
223 |             f"layers.{i}",
224 |             tf_model.layers[i + 2].layers,
225 |         )
226 |
227 |     print("Porting successful, serializing TensorFlow model...")
228 |     save_path = os.path.join(TF_MODEL_ROOT, args.model_name)
229 |     save_path = f"{save_path}_fe" if args.pre_logits else save_path
230 |     tf_model.save(save_path)
231 |     print(f"TensorFlow model serialized to: {save_path}...")
232 |
233 |
234 | if __name__ == "__main__":
235 |     args = parse_args()
236 |     main(args)
237 |
--------------------------------------------------------------------------------
/swin-transformers-tf/convert_all_models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append("..")
5 |
6 | from swins import model_configs
7 |
8 |
9 | def main():
10 |     for model_name in model_configs.MODEL_MAP:
11 |         if "in22k" in model_name:
12 |             dataset = "in21k"
13 |         else:
14 |             dataset = "in1k"
15 |
16 |         for i in range(2):
17 |             command = f"python convert.py -m {model_name} -d {dataset}"
18 |             if i == 1:
19 |                 command += " -pl"
20 |             os.system(command)
21 |
22 |
23 | if __name__ == "__main__":
24 |     main()
25 |
--------------------------------------------------------------------------------
/swin-transformers-tf/hub_utilities/README.md:
--------------------------------------------------------------------------------
1 | The scripts contained in this directory are closely related to TF-Hub. The following utilities are provided:
2 |
3 | * `export_for_hub.py`: Exports SavedModels in bulk as the `tar.gz` archives needed by TF-Hub.
4 | * `generate_doc.py`: Generates documentation for the models in bulk.
5 |
--------------------------------------------------------------------------------
/swin-transformers-tf/hub_utilities/export_for_hub.py:
--------------------------------------------------------------------------------
1 | """Generates .tar.gz archives from SavedModels and uploads them to a GCS bucket."""
2 |
3 |
4 | import os
5 | from typing import List
6 |
7 | import tensorflow as tf
8 |
9 | TF_MODEL_ROOT = "gs://swin-tf"
10 | TAR_ARCHIVES = os.path.join(TF_MODEL_ROOT, "tars/")
11 |
12 |
13 | def prepare_archive(model_name: str) -> None:
14 |     """Prepares a tar archive."""
15 |     archive_name = f"{model_name}.tar.gz"
16 |     print(f"Archiving to {archive_name}.")
17 |     archive_command = f"cd {model_name} && tar -czvf ../{archive_name} *"
18 |     os.system(archive_command)
19 |     os.system(f"rm -rf {model_name}")
20 |
21 |
22 | def save_to_gcs(model_paths: List[str]) -> None:
23 |     """Prepares tar archives and saves them inside a GCS bucket."""
24 |     for path in model_paths:
25 |         print(f"Preparing model: {path}.")
26 |         model_name = path.strip("/")
27 |         abs_model_path = os.path.join(TF_MODEL_ROOT, model_name)
28 |
29 |         print(f"Copying from {abs_model_path}.")
30 |         os.system(f"gsutil cp -r {abs_model_path} .")
31 |         prepare_archive(model_name)
32 |
33 |     os.system(f"gsutil -m cp -r *.tar.gz {TAR_ARCHIVES}")
34 |     os.system("rm -rf *.tar.gz")
35 |
36 |
37 | model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)
38 | print(f"Total models: {len(model_paths)}.")
39 |
40 | print("Preparing archives for the classification and feature extractor models.")
41 | save_to_gcs(model_paths)
42 | tar_paths = tf.io.gfile.listdir(TAR_ARCHIVES)
43 | print(f"Total tars: {len(tar_paths)}.")
44 |
--------------------------------------------------------------------------------
/swin-transformers-tf/hub_utilities/generate_doc.py:
--------------------------------------------------------------------------------
1 | """Generates model documentation for Swin-TF models.
2 |
3 | Credits: Willi Gierke
4 | """
5 |
6 | import os
7 | from string import Template
8 |
9 | import attr
10 |
11 | template = Template(
12 |     """# Module $HANDLE
13 |
14 | Fine-tunable Swin Transformer model pre-trained on the $DATASET_DESCRIPTION.
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## Overview
25 |
26 | This model is a Swin Transformer [1] pre-trained on the $DATASET_DESCRIPTION. You can find the complete
27 | collection of Swin models on TF-Hub on [this page](https://tfhub.dev/sayakpaul/collections/swin/1).
28 |
29 | You can use this model for feature extraction and fine-tuning. Please refer to
30 | the Colab Notebook linked on this page for more details.
31 |
32 | ## Notes
33 |
34 | * The original model weights are provided in [2]. They were ported to Keras models
35 | (`tf.keras.Model`) and then serialized as TensorFlow SavedModels. The porting
36 | steps are available in [3].
37 | * If the model handle contains `s3`, please refer to [4] for more details on the architecture. Its
38 | original weights are available in [5].
39 | * The model can be unrolled into a standard Keras model and you can inspect its topology.
40 | To do so, first download the model from TF-Hub and then load it using `tf.keras.models.load_model`,
41 | providing the path to the downloaded model folder.
42 |
43 | ## References
44 |
45 | [1] [Swin Transformer: Hierarchical Vision Transformer using Shifted Windows by Liu et al.](https://arxiv.org/abs/2103.14030)
46 |
47 | [2] [Swin Transformers GitHub](https://github.com/microsoft/Swin-Transformer)
48 |
49 | [3] [Swin-TF GitHub](https://github.com/sayakpaul/swin-transformers-tf)
50 |
51 | [4] [Searching the Search Space of Vision Transformer by Chen et al.](https://arxiv.org/abs/2111.14725)
52 |
53 | [5] [AutoFormerV2 GitHub](https://github.com/silent-chen/AutoFormerV2-model-zoo)
54 |
55 | ## Acknowledgements
56 |
57 | * [Willi](https://ch.linkedin.com/in/willi-gierke)
58 | * [ML-GDE program](https://developers.google.com/programs/experts/)
59 |
60 | """
61 | )
62 |
63 |
64 | @attr.s
65 | class Config:
66 |     size = attr.ib(type=str)
67 |     patch_size = attr.ib(type=int)
68 |     window_size = attr.ib(type=int)
69 |     single_resolution = attr.ib(type=int)
70 |     dataset = attr.ib(type=str)
71 |     type = attr.ib(type=str, default="swin")
72 |
73 |     def two_d_resolution(self):
74 |         return f"{self.single_resolution}x{self.single_resolution}"
75 |
76 |     def gcs_folder_name(self):
77 |         if self.dataset == "in22k":
78 |             return f"swin_{self.size}_patch{self.patch_size}_window{self.window_size}_{self.single_resolution}_{self.dataset}_fe"
79 |         elif self.type == "autoformer":
80 |             return f"swin_s3_{self.size}_{self.single_resolution}_fe"
81 |         else:
82 |             return f"swin_{self.size}_patch{self.patch_size}_window{self.window_size}_{self.single_resolution}_fe"
83 |
84 |     def handle(self):
85 |         return f"sayakpaul/{self.gcs_folder_name()}/1"
86 |
87 |     def rel_doc_file_path(self):
88 |         """Relative to the tfhub.dev directory."""
89 |         return f"assets/docs/{self.handle()}.md"
90 |
91 |
92 | # swin_base_patch4_window12_384, swin_base_patch4_window12_384_in22k
93 | for c in [
94 |     Config("tiny", 4, 7, 224, "in1k"),
95 |     Config("small", 4, 7, 224, "in1k"),
96 |     Config("base", 4, 7, 224, "in1k"),
97 |     Config("base", 4, 12, 384, "in1k"),
98 |     Config("large", 4, 7, 224, "in1k"),
99 |     Config("large", 4, 12, 384, "in1k"),
100 |     Config("base", 4, 7, 224, "in22k"),
"in22k"), 101 | Config("base", 4, 12, 384, "in22k"), 102 | Config("large", 4, 7, 224, "in22k"), 103 | Config("large", 4, 12, 384, "in22k"), 104 | Config("tiny", 0, 0, 224, "in1k", "autoformer"), 105 | Config("small", 0, 0, 224, "in1k", "autoformer"), 106 | Config("base", 0, 0, 224, "in1k", "autoformer"), 107 | ]: 108 | if c.dataset == "in1k" and not ("large" in c.size or "base" in c.size): 109 | dataset_text = "ImageNet-1k dataset" 110 | elif c.dataset == "in22k": 111 | dataset_text = "ImageNet-22k dataset" 112 | elif c.dataset == "in1k" and ("large" in c.size or "base" in c.size): 113 | dataset_text = ( 114 | "ImageNet-22k" 115 | " dataset and" 116 | " was then " 117 | "fine-tuned " 118 | "on the " 119 | "ImageNet-1k " 120 | "dataset" 121 | ) 122 | 123 | save_path = os.path.join( 124 | "/Users/sayakpaul/Downloads/", "tfhub.dev", c.rel_doc_file_path() 125 | ) 126 | model_folder = save_path.split("/")[-2] 127 | model_abs_path = "/".join(save_path.split("/")[:-1]) 128 | 129 | if not os.path.exists(model_abs_path): 130 | os.makedirs(model_abs_path, exist_ok=True) 131 | 132 | with open(save_path, "w") as f: 133 | f.write( 134 | template.substitute( 135 | HANDLE=c.handle(), 136 | DATASET_DESCRIPTION=dataset_text, 137 | INPUT_RESOLUTION=c.two_d_resolution(), 138 | ARCHIVE_NAME=c.gcs_folder_name(), 139 | ) 140 | ) 141 | -------------------------------------------------------------------------------- /swin-transformers-tf/in1k-eval/README.md: -------------------------------------------------------------------------------- 1 | This directory provides a notebook to run 2 | evaluation on the ImageNet-1k `val` split using the TF/Keras converted Swin 3 | models. The notebook assumes the following files are present in your working 4 | directory and the dependencies specified in `../requirements.txt` are installed: 5 | 6 | * The `val` split directory of ImageNet-1k. -------------------------------------------------------------------------------- /swin-transformers-tf/in1k-eval/df.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import pandas as pd" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/html": [ 20 | "
\n", 21 | "\n", 34 | "\n", 35 | " \n", 36 | " \n", 37 | " \n", 38 | " \n", 39 | " \n", 40 | " \n", 41 | " \n", 42 | " \n", 43 | " \n", 44 | " \n", 45 | " \n", 46 | " \n", 47 | " \n", 48 | " \n", 49 | " \n", 50 | " \n", 51 | " \n", 52 | " \n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | "
model_nametop1_acc(%)top5_acc(%)
0swin_base_patch4_window7_22485.13497.480
1swin_large_patch4_window7_22486.25297.878
2swin_s3_base_22483.95896.532
3swin_s3_small_22483.64896.358
4swin_s3_tiny_22482.03495.864
\n", 76 | "
" 77 | ], 78 | "text/plain": [ 79 | " model_name top1_acc(%) top5_acc(%)\n", 80 | "0 swin_base_patch4_window7_224 85.134 97.480\n", 81 | "1 swin_large_patch4_window7_224 86.252 97.878\n", 82 | "2 swin_s3_base_224 83.958 96.532\n", 83 | "3 swin_s3_small_224 83.648 96.358\n", 84 | "4 swin_s3_tiny_224 82.034 95.864" 85 | ] 86 | }, 87 | "execution_count": 2, 88 | "metadata": {}, 89 | "output_type": "execute_result" 90 | } 91 | ], 92 | "source": [ 93 | "df_224 = pd.read_csv(\"swin_224_in1k.csv\")\n", 94 | "df_334 = pd.read_csv(\"swin_384_in1k.csv\")\n", 95 | "\n", 96 | "df = pd.concat([df_224, df_334])\n", 97 | "df[\"model_name\"] = df[\"model_name\"].apply(lambda x: x.strip(\"/\"))\n", 98 | "df.head()" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/html": [ 109 | "
\n", 110 | "\n", 123 | "\n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | "
model_nametop1_acc(%)top5_acc(%)
0swin_base_patch4_window7_22485.13497.480
1swin_large_patch4_window7_22486.25297.878
2swin_s3_base_22483.95896.532
3swin_s3_small_22483.64896.358
4swin_s3_tiny_22482.03495.864
5swin_small_patch4_window7_22483.17896.240
6swin_tiny_patch4_window7_22481.18495.512
0swin_base_patch4_window12_38486.42898.042
1swin_large_patch4_window12_38487.27298.242
\n", 189 | "
" 190 | ], 191 | "text/plain": [ 192 | " model_name top1_acc(%) top5_acc(%)\n", 193 | "0 swin_base_patch4_window7_224 85.134 97.480\n", 194 | "1 swin_large_patch4_window7_224 86.252 97.878\n", 195 | "2 swin_s3_base_224 83.958 96.532\n", 196 | "3 swin_s3_small_224 83.648 96.358\n", 197 | "4 swin_s3_tiny_224 82.034 95.864\n", 198 | "5 swin_small_patch4_window7_224 83.178 96.240\n", 199 | "6 swin_tiny_patch4_window7_224 81.184 95.512\n", 200 | "0 swin_base_patch4_window12_384 86.428 98.042\n", 201 | "1 swin_large_patch4_window12_384 87.272 98.242" 202 | ] 203 | }, 204 | "execution_count": 3, 205 | "metadata": {}, 206 | "output_type": "execute_result" 207 | } 208 | ], 209 | "source": [ 210 | "df" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 4, 216 | "metadata": {}, 217 | "outputs": [ 218 | { 219 | "data": { 220 | "text/html": [ 221 | "
\n", 222 | "\n", 235 | "\n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | "
model_nametop1_acc(%)top5_acc(%)orig_top1_acc(%)
0swin_base_patch4_window7_22485.13497.48085.2
1swin_large_patch4_window7_22486.25297.87886.3
2swin_s3_base_22483.95896.53284.0
3swin_s3_small_22483.64896.35883.7
4swin_s3_tiny_22482.03495.86482.1
5swin_small_patch4_window7_22483.17896.24083.2
6swin_tiny_patch4_window7_22481.18495.51281.2
7swin_base_patch4_window12_38486.42898.04286.4
8swin_large_patch4_window12_38487.27298.24287.3
\n", 311 | "
" 312 | ], 313 | "text/plain": [ 314 | " model_name top1_acc(%) top5_acc(%) orig_top1_acc(%)\n", 315 | "0 swin_base_patch4_window7_224 85.134 97.480 85.2\n", 316 | "1 swin_large_patch4_window7_224 86.252 97.878 86.3\n", 317 | "2 swin_s3_base_224 83.958 96.532 84.0\n", 318 | "3 swin_s3_small_224 83.648 96.358 83.7\n", 319 | "4 swin_s3_tiny_224 82.034 95.864 82.1\n", 320 | "5 swin_small_patch4_window7_224 83.178 96.240 83.2\n", 321 | "6 swin_tiny_patch4_window7_224 81.184 95.512 81.2\n", 322 | "7 swin_base_patch4_window12_384 86.428 98.042 86.4\n", 323 | "8 swin_large_patch4_window12_384 87.272 98.242 87.3" 324 | ] 325 | }, 326 | "execution_count": 4, 327 | "metadata": {}, 328 | "output_type": "execute_result" 329 | } 330 | ], 331 | "source": [ 332 | "# All models except for s3s: https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md\n", 333 | "# s3s: https://github.com/microsoft/Cream/tree/main/AutoFormerV2#model-zoo\n", 334 | "\n", 335 | "orig_1 = [\"85.2\", \"86.3\", \"84.0\", \"83.7\", \"82.1\", \"83.2\", \"81.2\", \"86.4\", \"87.3\"]\n", 336 | "df[\"orig_top1_acc(%)\"] = orig_1\n", 337 | "df = df.reset_index(drop=True)\n", 338 | "df" 339 | ] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "execution_count": 5, 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "name": "stdout", 348 | "output_type": "stream", 349 | "text": [ 350 | "| | model_name | top1_acc(%) | top5_acc(%) | orig_top1_acc(%) |\n", 351 | "|---:|:-------------------------------|--------------:|--------------:|-------------------:|\n", 352 | "| 0 | swin_base_patch4_window7_224 | 85.134 | 97.48 | 85.2 |\n", 353 | "| 1 | swin_large_patch4_window7_224 | 86.252 | 97.878 | 86.3 |\n", 354 | "| 2 | swin_s3_base_224 | 83.958 | 96.532 | 84 |\n", 355 | "| 3 | swin_s3_small_224 | 83.648 | 96.358 | 83.7 |\n", 356 | "| 4 | swin_s3_tiny_224 | 82.034 | 95.864 | 82.1 |\n", 357 | "| 5 | swin_small_patch4_window7_224 | 83.178 | 96.24 | 83.2 |\n", 358 | "| 6 | swin_tiny_patch4_window7_224 | 81.184 | 95.512 | 81.2 |\n", 359 | "| 7 | swin_base_patch4_window12_384 | 86.428 | 98.042 | 86.4 |\n", 360 | "| 8 | swin_large_patch4_window12_384 | 87.272 | 98.242 | 87.3 |\n" 361 | ] 362 | } 363 | ], 364 | "source": [ 365 | "print(df.to_markdown())" 366 | ] 367 | } 368 | ], 369 | "metadata": { 370 | "interpreter": { 371 | "hash": "2eeebd8186f946aafebcc49bceba063d6e659c945a9dcc5253ac24fa5b4e04cc" 372 | }, 373 | "kernelspec": { 374 | "display_name": "Python 3.8.2 ('pytorch')", 375 | "language": "python", 376 | "name": "python3" 377 | }, 378 | "language_info": { 379 | "codemirror_mode": { 380 | "name": "ipython", 381 | "version": 3 382 | }, 383 | "file_extension": ".py", 384 | "mimetype": "text/x-python", 385 | "name": "python", 386 | "nbconvert_exporter": "python", 387 | "pygments_lexer": "ipython3", 388 | "version": "3.8.2" 389 | }, 390 | "orig_nbformat": 4 391 | }, 392 | "nbformat": 4, 393 | "nbformat_minor": 2 394 | } 395 | -------------------------------------------------------------------------------- /swin-transformers-tf/in1k-eval/eval-swins.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "926bae46", 6 | "metadata": {}, 7 | "source": [ 8 | "## Imports" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "e038ddf1", 14 | "metadata": {}, 15 | "source": [ 16 | "Suppress TensorFlow warnings." 
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "9ec05d94-4361-49d2-b680-ce41d0376299",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "# Copied from:\n",
27 | "# https://weepingfish.github.io/2020/07/22/0722-suppress-tensorflow-warnings/\n",
28 | "\n",
29 | "# Filter tensorflow version warnings\n",
30 | "import os\n",
31 | "\n",
32 | "# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709\n",
33 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"  # or any {'0', '1', '2'}\n",
34 | "import warnings\n",
35 | "\n",
36 | "# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning\n",
37 | "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n",
38 | "warnings.simplefilter(action=\"ignore\", category=Warning)\n",
39 | "import tensorflow as tf\n",
40 | "\n",
41 | "tf.get_logger().setLevel(\"INFO\")\n",
42 | "tf.autograph.set_verbosity(0)\n",
43 | "import logging\n",
44 | "\n",
45 | "tf.get_logger().setLevel(logging.ERROR)"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "id": "bae636ae-24d1-4523-9997-696731318a81",
52 | "metadata": {},
53 | "outputs": [],
54 | "source": [
55 | "from tensorflow import keras\n",
56 | "\n",
57 | "from torchvision.datasets import ImageFolder\n",
58 | "from torchvision import transforms\n",
59 | "from torch.utils.data import DataLoader\n",
60 | "\n",
61 | "from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD"
62 | ]
63 | },
64 | {
65 | "cell_type": "markdown",
66 | "id": "96c4f0a2",
67 | "metadata": {},
68 | "source": [
69 | "## Constants"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": null,
75 | "id": "f8238055-08bf-44e1-8f3b-98e7768f1603",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "# Change batch size accordingly in case of OOM.\n",
80 | "# Change the image size to 384 once evaluation is done on 224.\n",
81 | "BATCH_SIZE = 256\n",
82 | "IMAGE_SIZE = 224\n",
83 | "TF_MODEL_ROOT = \"gs://swin-tf\""
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "id": "ea42076a",
89 | "metadata": {},
90 | "source": [
91 | "## Swin models "
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": null,
97 | "id": "13a7e46e-31b2-48b9-9a57-2873fe27397a",
98 | "metadata": {},
99 | "outputs": [],
100 | "source": [
101 | "model_paths = tf.io.gfile.listdir(TF_MODEL_ROOT)\n",
102 | "model_paths = [p for p in model_paths if str(IMAGE_SIZE) in p and \"fe\" not in p and \"22k\" not in p]\n",
103 | "print(model_paths)"
104 | ]
105 | },
106 | {
107 | "cell_type": "markdown",
108 | "id": "60c39720",
109 | "metadata": {},
110 | "source": [
111 | "## Image loader"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "b066a571",
117 | "metadata": {},
118 | "source": [
119 | "To have an apples-to-apples comparison with the original PyTorch models for evaluation, it's important to ensure we use the same transformations."
120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "id": "5cf8eb2e-de82-48af-9292-c4917c237fa8", 126 | "metadata": {}, 127 | "outputs": [], 128 | "source": [ 129 | "# Transformations from:\n", 130 | "# (1) https://github.com/microsoft/Swin-Transformer\n", 131 | "# (2) https://github.com/microsoft/Swin-Transformer/tree/main/data\n", 132 | "\n", 133 | "if IMAGE_SIZE == 224:\n", 134 | " size = int((256 / 224) * IMAGE_SIZE)\n", 135 | " transform_chain = transforms.Compose(\n", 136 | " [\n", 137 | " transforms.Resize(size, interpolation=3),\n", 138 | " transforms.CenterCrop(IMAGE_SIZE),\n", 139 | " transforms.ToTensor(),\n", 140 | " transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),\n", 141 | " ]\n", 142 | " )\n", 143 | "else:\n", 144 | " transform_chain = transforms.Compose(\n", 145 | " [\n", 146 | " transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=3),\n", 147 | " transforms.ToTensor(),\n", 148 | " transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),\n", 149 | " ]\n", 150 | " )" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "51ae139c-786b-47b3-9840-655c624f86b7", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "dataset = ImageFolder(\"val\", transform=transform_chain)\n", 161 | "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=6)\n", 162 | "\n", 163 | "batch = next(iter(dataloader))\n", 164 | "print(batch[0].shape)" 165 | ] 166 | }, 167 | { 168 | "cell_type": "markdown", 169 | "id": "2b7e8e68", 170 | "metadata": {}, 171 | "source": [ 172 | "## Run evaluation" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "63a3da22-a60f-48b8-a0b0-02e54b2d012f", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "def get_model(model_url):\n", 183 | " model = keras.models.load_model(model_url)\n", 184 | " return model" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "cbf8eabf-df6d-4ee5-900c-48b8761329ed", 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "# Copied and modified from:\n", 195 | "# https://github.com/sebastian-sz/resnet-rs-keras/blob/main/imagenet_evaluation/main.py\n", 196 | "\n", 197 | "log_file = f\"swin_{IMAGE_SIZE}_in1k.csv\"\n", 198 | "\n", 199 | "if not os.path.exists(log_file):\n", 200 | " with open(log_file, 'w') as f:\n", 201 | " f.write(\n", 202 | " 'model_name,top1_acc(%),top5_acc(%)\\n'\n", 203 | " )\n", 204 | "\n", 205 | "for path in model_paths:\n", 206 | " print(f\"Evaluating {path}.\")\n", 207 | " model = get_model(f\"{TF_MODEL_ROOT}/{path.strip('/')}\")\n", 208 | "\n", 209 | " top1 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=1, name=\"top1\")\n", 210 | " top5 = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5, name=\"top5\")\n", 211 | " progbar = tf.keras.utils.Progbar(target=len(dataset) // BATCH_SIZE)\n", 212 | "\n", 213 | " for idx, (images, y_true) in enumerate(dataloader):\n", 214 | " images = images.numpy().transpose(0, 2, 3, 1)\n", 215 | " y_true = y_true.numpy()\n", 216 | " y_pred = model.predict(images)\n", 217 | "\n", 218 | " top1.update_state(y_true=y_true, y_pred=y_pred)\n", 219 | " top5.update_state(y_true=y_true, y_pred=y_pred)\n", 220 | "\n", 221 | " progbar.update(\n", 222 | " idx, [(\"top1\", top1.result().numpy()), (\"top5\", top5.result().numpy())]\n", 223 | " )\n", 224 | "\n", 225 | " print()\n", 226 | " print(f\"TOP1: {top1.result().numpy()}. 
TOP5: {top5.result().numpy()}\")\n", 227 | " \n", 228 | " top_1 = top1.result().numpy() * 100.\n", 229 | " top_5 = top5.result().numpy() * 100.\n", 230 | " with open(log_file, 'a') as f:\n", 231 | " f.write(\"%s,%0.3f,%0.3f\\n\" % (path, top_1, top_5))" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "c845cdcb-1036-43e2-93df-4d47cce020ec", 238 | "metadata": {}, 239 | "outputs": [], 240 | "source": [ 241 | "!sudo shutdown now" 242 | ] 243 | } 244 | ], 245 | "metadata": { 246 | "environment": { 247 | "kernel": "python3", 248 | "name": "tf2-gpu.2-7.m87", 249 | "type": "gcloud", 250 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87" 251 | }, 252 | "kernelspec": { 253 | "display_name": "Python 3", 254 | "language": "python", 255 | "name": "python3" 256 | }, 257 | "language_info": { 258 | "codemirror_mode": { 259 | "name": "ipython", 260 | "version": 3 261 | }, 262 | "file_extension": ".py", 263 | "mimetype": "text/x-python", 264 | "name": "python", 265 | "nbconvert_exporter": "python", 266 | "pygments_lexer": "ipython3", 267 | "version": "3.7.12" 268 | } 269 | }, 270 | "nbformat": 4, 271 | "nbformat_minor": 5 272 | } 273 | -------------------------------------------------------------------------------- /swin-transformers-tf/in1k-eval/swin_224_in1k.csv: -------------------------------------------------------------------------------- 1 | model_name,top1_acc(%),top5_acc(%) 2 | swin_base_patch4_window7_224/,85.134,97.480 3 | swin_large_patch4_window7_224/,86.252,97.878 4 | swin_s3_base_224/,83.958,96.532 5 | swin_s3_small_224/,83.648,96.358 6 | swin_s3_tiny_224/,82.034,95.864 7 | swin_small_patch4_window7_224/,83.178,96.240 8 | swin_tiny_patch4_window7_224/,81.184,95.512 9 | -------------------------------------------------------------------------------- /swin-transformers-tf/in1k-eval/swin_384_in1k.csv: -------------------------------------------------------------------------------- 1 | model_name,top1_acc(%),top5_acc(%) 2 | swin_base_patch4_window12_384/,86.428,98.042 3 | swin_large_patch4_window12_384/,87.272,98.242 4 | -------------------------------------------------------------------------------- /swin-transformers-tf/notebooks/classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "zpiEiO2BUeO5", 6 | "metadata": { 7 | "id": "zpiEiO2BUeO5" 8 | }, 9 | "source": [ 10 | "# Off-the-shelf image classification with Swin Transformers on TF-Hub\n", 11 | "\n", 12 | "\n", 13 | " \n", 16 | " \n", 19 | " \n", 22 | "
\n", 14 | " Run in Google Colab\n", 15 | " \n", 17 | " View on GitHub\n", 18 | " \n", 20 | " See TF Hub models\n", 21 | "
" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "661e6538", 28 | "metadata": { 29 | "id": "661e6538" 30 | }, 31 | "source": [ 32 | "## Setup" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5", 39 | "metadata": { 40 | "id": "f2b73e50-6538-4af5-9878-ed99489409f5" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "!wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9", 51 | "metadata": { 52 | "id": "43974820-4eeb-4b3a-90b4-9ddfa00d1cb9" 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "import tensorflow as tf\n", 57 | "import tensorflow_hub as hub\n", 58 | "from tensorflow import keras\n", 59 | "\n", 60 | "\n", 61 | "from PIL import Image\n", 62 | "from io import BytesIO\n", 63 | "\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "import numpy as np\n", 66 | "import requests\n", 67 | "import cv2" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "id": "z5l1cRpiSavW", 73 | "metadata": { 74 | "id": "z5l1cRpiSavW" 75 | }, 76 | "source": [ 77 | "## Select a [Swin](https://arxiv.org/abs/2103.14030) ImageNet-1k model\n", 78 | "\n", 79 | "Find the entire collection [here](https://tfhub.dev/sayakpaul/collections/swin/1). For inferring with the ImageNet-22k models, please refer [here](https://tfhub.dev/google/bit/m-r50x1/imagenet21k_classification/1#usage)." 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "id": "a0wM8idaSaOq", 86 | "metadata": { 87 | "id": "a0wM8idaSaOq" 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "model_name = \"swin_tiny_patch4_window7_224\" #@param [\"swin_tiny_patch4_window7_224\", \"swin_small_patch4_window7_224\", \"swin_base_patch4_window7_224\", \"swin_base_patch4_window12_384\", \"swin_large_patch4_window7_224\", \"swin_large_patch4_window7_384\", \"swin_s3_tiny_224\", \"swin_s3_small_224\", \"swin_s3_base_224\"]\n", 92 | "\n", 93 | "model_handle_map ={\n", 94 | " \"swin_tiny_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_tiny_patch4_window7_224/1\",\n", 95 | " \"swin_small_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_small_patch4_window7_224/1\",\n", 96 | " \"swin_base_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_base_patch4_window7_224/1\",\n", 97 | " \"swin_base_patch4_window12_384\": \"https://tfhub.dev/sayakpaul/swin_base_patch4_window12_384/1\",\n", 98 | " \"swin_large_patch4_window7_224\": \"https://tfhub.dev/sayakpaul/swin_large_patch4_window7_224/1\",\n", 99 | " \"swin_large_patch4_window7_384\": \"https://tfhub.dev/sayakpaul/swin_large_patch4_window7_384/1\",\n", 100 | " \"swin_s3_tiny_224\": \"https://tfhub.dev/sayakpaul/swin_s3_tiny_224/1\",\n", 101 | " \"swin_s3_small_224\": \"https://tfhub.dev/sayakpaul/swin_s3_small_224/1\",\n", 102 | " \"swin_s3_base_224\": \"https://tfhub.dev/sayakpaul/swin_s3_base_224/1\",\n", 103 | "}\n", 104 | "\n", 105 | "input_resolution = int(model_name.split(\"_\")[-1])\n", 106 | "model_handle = model_handle_map[model_name]\n", 107 | "print(f\"Input resolution: {input_resolution} x {input_resolution} x 3.\")\n", 108 | "print(f\"TF-Hub handle: {model_handle}.\")" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "441b5361", 114 | "metadata": { 115 | "id": "441b5361" 116 | }, 117 | "source": [ 118 | "## Image preprocessing utilities " 119 | ] 120 | }, 121 | { 
122 | "cell_type": "code", 123 | "execution_count": null, 124 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b", 125 | "metadata": { 126 | "id": "63e76ff1-e1e0-4c6a-91b2-4114aad60e5b" 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "crop_layer = keras.layers.CenterCrop(input_resolution, input_resolution)\n", 131 | "norm_layer = keras.layers.Normalization(\n", 132 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n", 133 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n", 134 | ")\n", 135 | "\n", 136 | "\n", 137 | "def preprocess_image(image, size=input_resolution):\n", 138 | " image = np.array(image)\n", 139 | " image_resized = tf.expand_dims(image, 0)\n", 140 | " \n", 141 | " if size == 224:\n", 142 | " resize_size = int((256 / 224) * size)\n", 143 | " image_resized = tf.image.resize(image_resized, (resize_size, resize_size), method=\"bicubic\")\n", 144 | " image_resized = crop_layer(image_resized)\n", 145 | " else:\n", 146 | " image_resized = tf.image.resize(image_resized, (size, size), method=\"bicubic\")\n", 147 | " \n", 148 | " return norm_layer(image_resized).numpy()\n", 149 | " \n", 150 | "\n", 151 | "def load_image_from_url(url):\n", 152 | " # Credit: Willi Gierke\n", 153 | " response = requests.get(url)\n", 154 | " image = Image.open(BytesIO(response.content))\n", 155 | " preprocessed_image = preprocess_image(image)\n", 156 | " return image, preprocessed_image" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "8b961e14", 162 | "metadata": { 163 | "id": "8b961e14" 164 | }, 165 | "source": [ 166 | "## Load ImageNet-1k labels and a demo image" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b", 173 | "metadata": { 174 | "id": "8dc9250a-5eb6-4547-8893-dd4c746ab53b" 175 | }, 176 | "outputs": [], 177 | "source": [ 178 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n", 179 | " lines = f.readlines()\n", 180 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n", 181 | "\n", 182 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n", 183 | "image, preprocessed_image = load_image_from_url(img_url)\n", 184 | "\n", 185 | "plt.imshow(image)\n", 186 | "plt.show()" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "9006a643", 192 | "metadata": { 193 | "id": "9006a643" 194 | }, 195 | "source": [ 196 | "## Run inference" 197 | ] 198 | }, 199 | { 200 | "cell_type": "code", 201 | "execution_count": null, 202 | "id": "bHnCyJtAf9el", 203 | "metadata": { 204 | "id": "bHnCyJtAf9el" 205 | }, 206 | "outputs": [], 207 | "source": [ 208 | "def get_model(model_url: str) -> tf.keras.Model:\n", 209 | " inputs = tf.keras.Input((input_resolution, input_resolution, 3))\n", 210 | " hub_module = hub.KerasLayer(model_url)\n", 211 | "\n", 212 | " outputs = hub_module(inputs)\n", 213 | "\n", 214 | " return tf.keras.Model(inputs, outputs)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": null, 220 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908", 221 | "metadata": { 222 | "id": "8dfd2c7d-e454-48da-a40b-cd5d6f6c4908" 223 | }, 224 | "outputs": [], 225 | "source": [ 226 | "classification_model = get_model(model_handle)\n", 227 | "predictions = classification_model.predict(preprocessed_image)\n", 228 | "predicted_label = imagenet_int_to_str[int(np.argmax(predictions))]\n", 229 | "print(predicted_label)" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "source": [ 235 | "## 
Obtain attention scores" 236 | ], 237 | "metadata": { 238 | "id": "wPisHE9lMmaN" 239 | }, 240 | "id": "wPisHE9lMmaN" 241 | }, 242 | { 243 | "cell_type": "code", 244 | "source": [ 245 | "swin_tiny_patch4_window7_224_tf = tf.keras.models.load_model(\n", 246 | " f\"gs://tfhub-modules/sayakpaul/{model_name}/1/uncompressed\"\n", 247 | ")\n", 248 | "all_attn_scores = swin_tiny_patch4_window7_224_tf.get_attention_scores(\n", 249 | " preprocessed_image\n", 250 | ")\n", 251 | "print(all_attn_scores.keys())" 252 | ], 253 | "metadata": { 254 | "id": "cRO5v-yPMoEO" 255 | }, 256 | "id": "cRO5v-yPMoEO", 257 | "execution_count": null, 258 | "outputs": [] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "source": [ 263 | "# attn score dimensions:\n", 264 | "# (batch_size, nb_attention_heads, seq_length, seq_length)\n", 265 | "print(all_attn_scores[\"swin_stage_3\"].keys()), print(all_attn_scores[\"swin_stage_3\"][\"swin_block_0\"].shape)" 266 | ], 267 | "metadata": { 268 | "id": "GwRtcNR2NLrC" 269 | }, 270 | "id": "GwRtcNR2NLrC", 271 | "execution_count": null, 272 | "outputs": [] 273 | } 274 | ], 275 | "metadata": { 276 | "accelerator": "GPU", 277 | "colab": { 278 | "machine_shape": "hm", 279 | "name": "classification.ipynb", 280 | "provenance": [] 281 | }, 282 | "environment": { 283 | "kernel": "python3", 284 | "name": "tf2-gpu.2-7.m87", 285 | "type": "gcloud", 286 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87" 287 | }, 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.7.12" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 5 308 | } -------------------------------------------------------------------------------- /swin-transformers-tf/notebooks/finetune.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": { 16 | "id": "89B27-TGiDNB" 17 | }, 18 | "source": [ 19 | "## Imports" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": null, 25 | "metadata": { 26 | "id": "9u3d4Z7uQsmp" 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "from tensorflow import keras\n", 31 | "import tensorflow as tf\n", 32 | "import tensorflow_hub as hub\n", 33 | "import tensorflow_datasets as tfds\n", 34 | "\n", 35 | "tfds.disable_progress_bar()\n", 36 | "\n", 37 | "import os\n", 38 | "import sys\n", 39 | "import math\n", 40 | "import numpy as np\n", 41 | "import pandas as pd\n", 42 | "import matplotlib.pyplot as plt" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": { 49 | "id": "mdUDP5qIiG05" 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "# Copied from:\n", 54 | "# https://weepingfish.github.io/2020/07/22/0722-suppress-tensorflow-warnings/\n", 55 | "\n", 56 | "# Filter tensorflow version warnings\n", 57 | "\n", 58 | "# https://stackoverflow.com/questions/40426502/is-there-a-way-to-suppress-the-messages-tensorflow-prints/40426709\n", 59 | "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\" # or any {'0', '1', '2'}\n", 60 | 
"import warnings\n", 61 | "\n", 62 | "# https://stackoverflow.com/questions/15777951/how-to-suppress-pandas-future-warning\n", 63 | "warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n", 64 | "warnings.simplefilter(action=\"ignore\", category=Warning)\n", 65 | "\n", 66 | "tf.get_logger().setLevel(\"INFO\")\n", 67 | "tf.autograph.set_verbosity(0)\n", 68 | "import logging\n", 69 | "\n", 70 | "tf.get_logger().setLevel(logging.ERROR)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": { 76 | "id": "mPo10cahZXXQ" 77 | }, 78 | "source": [ 79 | "## GPUs" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": { 86 | "id": "FpvUOuC3j27n" 87 | }, 88 | "outputs": [], 89 | "source": [ 90 | "try: # detect TPUs\n", 91 | " tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() # TPU detection\n", 92 | " strategy = tf.distribute.TPUStrategy(tpu)\n", 93 | "except ValueError: # detect GPUs\n", 94 | " tpu = False\n", 95 | " strategy = (\n", 96 | " tf.distribute.get_strategy()\n", 97 | " ) # default strategy that works on CPU and single GPU\n", 98 | "print(\"Number of Accelerators: \", strategy.num_replicas_in_sync)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": { 104 | "id": "w9S3uKC_iXY5" 105 | }, 106 | "source": [ 107 | "## Configuration\n", 108 | "\n", 109 | "Find the list of all fine-tunable models [here](https://tfhub.dev/sayakpaul/collections/swin/1)." 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": { 116 | "id": "kCc6tdUGnD4C" 117 | }, 118 | "outputs": [], 119 | "source": [ 120 | "# Model\n", 121 | "IMAGE_SIZE = [224, 224] # Change this accordingly. \n", 122 | "MODEL_PATH = \"https://tfhub.dev/sayakpaul/swin_tiny_patch4_window7_224_fe\" \n", 123 | "\n", 124 | "# TPU\n", 125 | "if tpu:\n", 126 | " BATCH_SIZE = (\n", 127 | " 16 * strategy.num_replicas_in_sync\n", 128 | " ) # a TPU has 8 cores so this will be 128\n", 129 | "else:\n", 130 | " BATCH_SIZE = 128 # on Colab/GPU, a higher batch size may throw OOM\n", 131 | "\n", 132 | "# Dataset\n", 133 | "CLASSES = [\n", 134 | " \"dandelion\",\n", 135 | " \"daisy\",\n", 136 | " \"tulips\",\n", 137 | " \"sunflowers\",\n", 138 | " \"roses\",\n", 139 | "] # don't change the order\n", 140 | "\n", 141 | "# Other constants\n", 142 | "MEAN = tf.constant([0.485 * 255, 0.456 * 255, 0.406 * 255]) # imagenet mean\n", 143 | "STD = tf.constant([0.229 * 255, 0.224 * 255, 0.225 * 255]) # imagenet std\n", 144 | "AUTO = tf.data.AUTOTUNE" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": { 150 | "id": "9iTImGI5qMQT" 151 | }, 152 | "source": [ 153 | "# Data Pipeline\n", 154 | "\n", 155 | "[DeiT authors](https://arxiv.org/abs/2012.12877) use a separate preprocessing pipeline for fine-tuning. But for keeping this walkthrough short and simple, we can just perform the basic ones." 
156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": { 162 | "id": "h29TLx7gqN_7" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "def make_dataset(dataset: tf.data.Dataset, train: bool, image_size: int = IMAGE_SIZE):\n", 167 | " def preprocess(image, label):\n", 168 | " # for training, do augmentation\n", 169 | " if train:\n", 170 | " if tf.random.uniform(shape=[]) > 0.5:\n", 171 | " image = tf.image.flip_left_right(image)\n", 172 | " image = tf.image.resize(image, size=image_size, method=\"bicubic\")\n", 173 | " image = (image - MEAN) / STD # normalization\n", 174 | " return image, label\n", 175 | "\n", 176 | " if train:\n", 177 | " dataset = dataset.shuffle(BATCH_SIZE * 10)\n", 178 | "\n", 179 | " return dataset.map(preprocess, AUTO).batch(BATCH_SIZE).prefetch(AUTO)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": { 185 | "id": "AMQ3Qs9_pddU" 186 | }, 187 | "source": [ 188 | "# Flower Dataset" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": { 195 | "id": "M3G-2aUBQJ-H" 196 | }, 197 | "outputs": [], 198 | "source": [ 199 | "train_dataset, val_dataset = tfds.load(\n", 200 | " \"tf_flowers\",\n", 201 | " split=[\"train[:90%]\", \"train[90%:]\"],\n", 202 | " as_supervised=True,\n", 203 | " try_gcs=False, # gcs_path is necessary for tpu,\n", 204 | ")\n", 205 | "\n", 206 | "num_train = tf.data.experimental.cardinality(train_dataset)\n", 207 | "num_val = tf.data.experimental.cardinality(val_dataset)\n", 208 | "print(f\"Number of training examples: {num_train}\")\n", 209 | "print(f\"Number of validation examples: {num_val}\")" 210 | ] 211 | }, 212 | { 213 | "cell_type": "markdown", 214 | "metadata": { 215 | "id": "l2X7sE3oRLXN" 216 | }, 217 | "source": [ 218 | "## Prepare dataset" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": { 225 | "id": "oftrfYw1qXei" 226 | }, 227 | "outputs": [], 228 | "source": [ 229 | "train_dataset = make_dataset(train_dataset, True)\n", 230 | "val_dataset = make_dataset(val_dataset, False)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "id": "kNyCCM6PRM8I" 237 | }, 238 | "source": [ 239 | "## Visualize" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": { 246 | "id": "IaGzFUUVqjaC" 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "sample_images, sample_labels = next(iter(train_dataset))\n", 251 | "\n", 252 | "plt.figure(figsize=(5 * 3, 3 * 3))\n", 253 | "for n in range(15):\n", 254 | " ax = plt.subplot(3, 5, n + 1)\n", 255 | " image = (sample_images[n] * STD + MEAN).numpy()\n", 256 | " image = (image - image.min()) / (\n", 257 | " image.max() - image.min()\n", 258 | " ) # convert to [0, 1] for avoiding matplotlib warning\n", 259 | " plt.imshow(image)\n", 260 | " plt.title(CLASSES[sample_labels[n]])\n", 261 | " plt.axis(\"off\")\n", 262 | "plt.tight_layout()\n", 263 | "plt.show()" 264 | ] 265 | }, 266 | { 267 | "cell_type": "markdown", 268 | "metadata": { 269 | "id": "Qf6u_7tt8BYy" 270 | }, 271 | "source": [ 272 | "# LR Scheduler Utility" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "metadata": { 279 | "id": "oVTbnkJL79T_" 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "# Reference:\n", 284 | "# https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2\n", 285 | "\n", 286 | "\n", 287 | "class 
WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):\n",
288 | "    def __init__(\n",
289 | "        self, learning_rate_base, total_steps, warmup_learning_rate, warmup_steps\n",
290 | "    ):\n",
291 | "        super(WarmUpCosine, self).__init__()\n",
292 | "\n",
293 | "        self.learning_rate_base = learning_rate_base\n",
294 | "        self.total_steps = total_steps\n",
295 | "        self.warmup_learning_rate = warmup_learning_rate\n",
296 | "        self.warmup_steps = warmup_steps\n",
297 | "        self.pi = tf.constant(np.pi)\n",
298 | "\n",
299 | "    def __call__(self, step):\n",
300 | "        if self.total_steps < self.warmup_steps:\n",
301 | "            raise ValueError(\"total_steps must be larger than or equal to warmup_steps.\")\n",
302 | "        learning_rate = (\n",
303 | "            0.5\n",
304 | "            * self.learning_rate_base\n",
305 | "            * (\n",
306 | "                1\n",
307 | "                + tf.cos(\n",
308 | "                    self.pi\n",
309 | "                    * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
310 | "                    / float(self.total_steps - self.warmup_steps)\n",
311 | "                )\n",
312 | "            )\n",
313 | "        )\n",
314 | "\n",
315 | "        if self.warmup_steps > 0:\n",
316 | "            if self.learning_rate_base < self.warmup_learning_rate:\n",
317 | "                raise ValueError(\n",
318 | "                    \"learning_rate_base must be larger than or equal to \"\n",
319 | "                    \"warmup_learning_rate.\"\n",
320 | "                )\n",
321 | "            slope = (\n",
322 | "                self.learning_rate_base - self.warmup_learning_rate\n",
323 | "            ) / self.warmup_steps\n",
324 | "            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
325 | "            learning_rate = tf.where(\n",
326 | "                step < self.warmup_steps, warmup_rate, learning_rate\n",
327 | "            )\n",
328 | "        return tf.where(\n",
329 | "            step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
330 | "        )"
331 | ]
332 | },
333 | {
334 | "cell_type": "markdown",
335 | "metadata": {
336 | "id": "ALtRUlxhw8Vt"
337 | },
338 | "source": [
339 | "# Model Utility"
340 | ]
341 | },
342 | {
343 | "cell_type": "code",
344 | "execution_count": null,
345 | "metadata": {
346 | "id": "JD9SI_Q9JdAB"
347 | },
348 | "outputs": [],
349 | "source": [
350 | "def get_model(model_url: str, res: int = IMAGE_SIZE[0], num_classes: int = 5) -> tf.keras.Model:\n",
351 | "    inputs = tf.keras.Input((res, res, 3))\n",
352 | "    hub_module = hub.KerasLayer(model_url, trainable=True)\n",
353 | "\n",
354 | "    x = hub_module(inputs, training=False)\n",
355 | "    outputs = keras.layers.Dense(num_classes, activation=\"softmax\")(x)\n",
356 | "\n",
357 | "    return tf.keras.Model(inputs, outputs)"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": null,
363 | "metadata": {
364 | "id": "wpZApp9u9_Y-"
365 | },
366 | "outputs": [],
367 | "source": [
368 | "get_model(MODEL_PATH).summary()"
369 | ]
370 | },
371 | {
372 | "cell_type": "markdown",
373 | "metadata": {
374 | "id": "dMfenMQcxAAb"
375 | },
376 | "source": [
377 | "# Training Hyperparameters"
378 | ]
379 | },
380 | {
381 | "cell_type": "code",
382 | "execution_count": null,
383 | "metadata": {
384 | "id": "1D7Iu7oD8WzX"
385 | },
386 | "outputs": [],
387 | "source": [
388 | "EPOCHS = 10\n",
389 | "WARMUP_STEPS = 10\n",
390 | "INIT_LR = 0.03\n",
391 | "WARMUP_LR = 0.006\n",
392 | "\n",
393 | "TOTAL_STEPS = int((num_train / BATCH_SIZE) * EPOCHS)\n",
394 | "\n",
395 | "scheduled_lrs = WarmUpCosine(\n",
396 | "    learning_rate_base=INIT_LR,\n",
397 | "    total_steps=TOTAL_STEPS,\n",
398 | "    warmup_learning_rate=WARMUP_LR,\n",
399 | "    warmup_steps=WARMUP_STEPS,\n",
400 | ")"
401 | ]
402 | },
403 | {
404 | "cell_type": "code",
405 | "execution_count": null,
406 | "metadata": {
407 | "id": "M-ID7vP5mIKs" 408 | }, 409 | "outputs": [], 410 | "source": [ 411 | "optimizer = keras.optimizers.SGD(scheduled_lrs)\n", 412 | "loss = keras.losses.SparseCategoricalCrossentropy()" 413 | ] 414 | }, 415 | { 416 | "cell_type": "markdown", 417 | "metadata": { 418 | "id": "E9p4ymNh9y7d" 419 | }, 420 | "source": [ 421 | "# Training & Validation" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": { 428 | "id": "VnZTSd8K90Mq" 429 | }, 430 | "outputs": [], 431 | "source": [ 432 | "with strategy.scope(): # this line is all that is needed to run on TPU (or multi-GPU, ...)\n", 433 | " model = get_model(MODEL_PATH)\n", 434 | " model.compile(loss=loss, optimizer=optimizer, metrics=[\"accuracy\"])\n", 435 | "\n", 436 | "history = model.fit(train_dataset, validation_data=val_dataset, epochs=EPOCHS)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": null, 442 | "metadata": { 443 | "id": "jc7LMVz5Cbx6" 444 | }, 445 | "outputs": [], 446 | "source": [ 447 | "result = pd.DataFrame(history.history)\n", 448 | "fig, ax = plt.subplots(2, 1, figsize=(10, 10))\n", 449 | "result[[\"accuracy\", \"val_accuracy\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[0])\n", 450 | "result[[\"loss\", \"val_loss\"]].plot(xlabel=\"epoch\", ylabel=\"score\", ax=ax[1])" 451 | ] 452 | }, 453 | { 454 | "cell_type": "markdown", 455 | "metadata": { 456 | "id": "MKFMWzh0Yxsq" 457 | }, 458 | "source": [ 459 | "# Predictions" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": null, 465 | "metadata": { 466 | "id": "yMEsR851VDZb" 467 | }, 468 | "outputs": [], 469 | "source": [ 470 | "sample_images, sample_labels = next(iter(val_dataset))\n", 471 | "\n", 472 | "predictions = model.predict(sample_images, batch_size=16).argmax(axis=-1)\n", 473 | "evaluations = model.evaluate(sample_images, sample_labels, batch_size=16)\n", 474 | "\n", 475 | "print(\"[val_loss, val_acc]\", evaluations)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "metadata": { 482 | "id": "qzCCDL1CZFx6" 483 | }, 484 | "outputs": [], 485 | "source": [ 486 | "plt.figure(figsize=(5 * 3, 3 * 3))\n", 487 | "for n in range(15):\n", 488 | " ax = plt.subplot(3, 5, n + 1)\n", 489 | " image = (sample_images[n] * STD + MEAN).numpy()\n", 490 | " image = (image - image.min()) / (\n", 491 | " image.max() - image.min()\n", 492 | " ) # convert to [0, 1] for avoiding matplotlib warning\n", 493 | " plt.imshow(image)\n", 494 | " target = CLASSES[sample_labels[n]]\n", 495 | " pred = CLASSES[predictions[n]]\n", 496 | " plt.title(\"{} ({})\".format(target, pred))\n", 497 | " plt.axis(\"off\")\n", 498 | "\n", 499 | "plt.tight_layout()\n", 500 | "plt.show()" 501 | ] 502 | }, 503 | { 504 | "cell_type": "markdown", 505 | "metadata": { 506 | "id": "2e5oy9zmNNID" 507 | }, 508 | "source": [ 509 | "# Reference\n", 510 | "* [Keras Flowers on TPU (solution)](https://colab.research.google.com/github/GoogleCloudPlatform/training-data-analyst/blob/master/courses/fast-and-lean-data-science/07_Keras_Flowers_TPU_solution.ipynb)\n", 511 | "\n", 512 | "\n", 513 | "This notebook is copied and modified from [here](https://github.com/sayakpaul/ConvNeXt-TF/blob/main/notebooks/finetune.ipynb). I'm thankful to [awsaf49](https://github.com/awsaf49) who originally worked on that notebook. 
" 514 | ] 515 | } 516 | ], 517 | "metadata": { 518 | "accelerator": "GPU", 519 | "colab": { 520 | "collapsed_sections": [], 521 | "machine_shape": "hm", 522 | "name": "finetune", 523 | "provenance": [], 524 | "toc_visible": true, 525 | "include_colab_link": true 526 | }, 527 | "environment": { 528 | "kernel": "python3", 529 | "name": "tf2-gpu.2-7.m87", 530 | "type": "gcloud", 531 | "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-7:m87" 532 | }, 533 | "kernelspec": { 534 | "display_name": "Python 3", 535 | "language": "python", 536 | "name": "python3" 537 | }, 538 | "language_info": { 539 | "codemirror_mode": { 540 | "name": "ipython", 541 | "version": 3 542 | }, 543 | "file_extension": ".py", 544 | "mimetype": "text/x-python", 545 | "name": "python", 546 | "nbconvert_exporter": "python", 547 | "pygments_lexer": "ipython3", 548 | "version": "3.7.12" 549 | } 550 | }, 551 | "nbformat": 4, 552 | "nbformat_minor": 0 553 | } -------------------------------------------------------------------------------- /swin-transformers-tf/notebooks/ilsvrc2012_wordnet_lemmas.txt: -------------------------------------------------------------------------------- 1 | tench, Tinca_tinca 2 | goldfish, Carassius_auratus 3 | great_white_shark, white_shark, man-eater, man-eating_shark, Carcharodon_carcharias 4 | tiger_shark, Galeocerdo_cuvieri 5 | hammerhead, hammerhead_shark 6 | electric_ray, crampfish, numbfish, torpedo 7 | stingray 8 | cock 9 | hen 10 | ostrich, Struthio_camelus 11 | brambling, Fringilla_montifringilla 12 | goldfinch, Carduelis_carduelis 13 | house_finch, linnet, Carpodacus_mexicanus 14 | junco, snowbird 15 | indigo_bunting, indigo_finch, indigo_bird, Passerina_cyanea 16 | robin, American_robin, Turdus_migratorius 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water_ouzel, dipper 22 | kite 23 | bald_eagle, American_eagle, Haliaeetus_leucocephalus 24 | vulture 25 | great_grey_owl, great_gray_owl, Strix_nebulosa 26 | European_fire_salamander, Salamandra_salamandra 27 | common_newt, Triturus_vulgaris 28 | eft 29 | spotted_salamander, Ambystoma_maculatum 30 | axolotl, mud_puppy, Ambystoma_mexicanum 31 | bullfrog, Rana_catesbeiana 32 | tree_frog, tree-frog 33 | tailed_frog, bell_toad, ribbed_toad, tailed_toad, Ascaphus_trui 34 | loggerhead, loggerhead_turtle, Caretta_caretta 35 | leatherback_turtle, leatherback, leathery_turtle, Dermochelys_coriacea 36 | mud_turtle 37 | terrapin 38 | box_turtle, box_tortoise 39 | banded_gecko 40 | common_iguana, iguana, Iguana_iguana 41 | American_chameleon, anole, Anolis_carolinensis 42 | whiptail, whiptail_lizard 43 | agama 44 | frilled_lizard, Chlamydosaurus_kingi 45 | alligator_lizard 46 | Gila_monster, Heloderma_suspectum 47 | green_lizard, Lacerta_viridis 48 | African_chameleon, Chamaeleo_chamaeleon 49 | Komodo_dragon, Komodo_lizard, dragon_lizard, giant_lizard, Varanus_komodoensis 50 | African_crocodile, Nile_crocodile, Crocodylus_niloticus 51 | American_alligator, Alligator_mississipiensis 52 | triceratops 53 | thunder_snake, worm_snake, Carphophis_amoenus 54 | ringneck_snake, ring-necked_snake, ring_snake 55 | hognose_snake, puff_adder, sand_viper 56 | green_snake, grass_snake 57 | king_snake, kingsnake 58 | garter_snake, grass_snake 59 | water_snake 60 | vine_snake 61 | night_snake, Hypsiglena_torquata 62 | boa_constrictor, Constrictor_constrictor 63 | rock_python, rock_snake, Python_sebae 64 | Indian_cobra, Naja_naja 65 | green_mamba 66 | sea_snake 67 | horned_viper, cerastes, sand_viper, horned_asp, Cerastes_cornutus 68 | diamondback, 
diamondback_rattlesnake, Crotalus_adamanteus 69 | sidewinder, horned_rattlesnake, Crotalus_cerastes 70 | trilobite 71 | harvestman, daddy_longlegs, Phalangium_opilio 72 | scorpion 73 | black_and_gold_garden_spider, Argiope_aurantia 74 | barn_spider, Araneus_cavaticus 75 | garden_spider, Aranea_diademata 76 | black_widow, Latrodectus_mactans 77 | tarantula 78 | wolf_spider, hunting_spider 79 | tick 80 | centipede 81 | black_grouse 82 | ptarmigan 83 | ruffed_grouse, partridge, Bonasa_umbellus 84 | prairie_chicken, prairie_grouse, prairie_fowl 85 | peacock 86 | quail 87 | partridge 88 | African_grey, African_gray, Psittacus_erithacus 89 | macaw 90 | sulphur-crested_cockatoo, Kakatoe_galerita, Cacatua_galerita 91 | lorikeet 92 | coucal 93 | bee_eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted_merganser, Mergus_serrator 100 | goose 101 | black_swan, Cygnus_atratus 102 | tusker 103 | echidna, spiny_anteater, anteater 104 | platypus, duckbill, duckbilled_platypus, duck-billed_platypus, Ornithorhynchus_anatinus 105 | wallaby, brush_kangaroo 106 | koala, koala_bear, kangaroo_bear, native_bear, Phascolarctos_cinereus 107 | wombat 108 | jellyfish 109 | sea_anemone, anemone 110 | brain_coral 111 | flatworm, platyhelminth 112 | nematode, nematode_worm, roundworm 113 | conch 114 | snail 115 | slug 116 | sea_slug, nudibranch 117 | chiton, coat-of-mail_shell, sea_cradle, polyplacophore 118 | chambered_nautilus, pearly_nautilus, nautilus 119 | Dungeness_crab, Cancer_magister 120 | rock_crab, Cancer_irroratus 121 | fiddler_crab 122 | king_crab, Alaska_crab, Alaskan_king_crab, Alaska_king_crab, Paralithodes_camtschatica 123 | American_lobster, Northern_lobster, Maine_lobster, Homarus_americanus 124 | spiny_lobster, langouste, rock_lobster, crawfish, crayfish, sea_crawfish 125 | crayfish, crawfish, crawdad, crawdaddy 126 | hermit_crab 127 | isopod 128 | white_stork, Ciconia_ciconia 129 | black_stork, Ciconia_nigra 130 | spoonbill 131 | flamingo 132 | little_blue_heron, Egretta_caerulea 133 | American_egret, great_white_heron, Egretta_albus 134 | bittern 135 | crane 136 | limpkin, Aramus_pictus 137 | European_gallinule, Porphyrio_porphyrio 138 | American_coot, marsh_hen, mud_hen, water_hen, Fulica_americana 139 | bustard 140 | ruddy_turnstone, Arenaria_interpres 141 | red-backed_sandpiper, dunlin, Erolia_alpina 142 | redshank, Tringa_totanus 143 | dowitcher 144 | oystercatcher, oyster_catcher 145 | pelican 146 | king_penguin, Aptenodytes_patagonica 147 | albatross, mollymawk 148 | grey_whale, gray_whale, devilfish, Eschrichtius_gibbosus, Eschrichtius_robustus 149 | killer_whale, killer, orca, grampus, sea_wolf, Orcinus_orca 150 | dugong, Dugong_dugon 151 | sea_lion 152 | Chihuahua 153 | Japanese_spaniel 154 | Maltese_dog, Maltese_terrier, Maltese 155 | Pekinese, Pekingese, Peke 156 | Shih-Tzu 157 | Blenheim_spaniel 158 | papillon 159 | toy_terrier 160 | Rhodesian_ridgeback 161 | Afghan_hound, Afghan 162 | basset, basset_hound 163 | beagle 164 | bloodhound, sleuthhound 165 | bluetick 166 | black-and-tan_coonhound 167 | Walker_hound, Walker_foxhound 168 | English_foxhound 169 | redbone 170 | borzoi, Russian_wolfhound 171 | Irish_wolfhound 172 | Italian_greyhound 173 | whippet 174 | Ibizan_hound, Ibizan_Podenco 175 | Norwegian_elkhound, elkhound 176 | otterhound, otter_hound 177 | Saluki, gazelle_hound 178 | Scottish_deerhound, deerhound 179 | Weimaraner 180 | Staffordshire_bullterrier, Staffordshire_bull_terrier 181 | American_Staffordshire_terrier, 
Staffordshire_terrier, American_pit_bull_terrier, pit_bull_terrier 182 | Bedlington_terrier 183 | Border_terrier 184 | Kerry_blue_terrier 185 | Irish_terrier 186 | Norfolk_terrier 187 | Norwich_terrier 188 | Yorkshire_terrier 189 | wire-haired_fox_terrier 190 | Lakeland_terrier 191 | Sealyham_terrier, Sealyham 192 | Airedale, Airedale_terrier 193 | cairn, cairn_terrier 194 | Australian_terrier 195 | Dandie_Dinmont, Dandie_Dinmont_terrier 196 | Boston_bull, Boston_terrier 197 | miniature_schnauzer 198 | giant_schnauzer 199 | standard_schnauzer 200 | Scotch_terrier, Scottish_terrier, Scottie 201 | Tibetan_terrier, chrysanthemum_dog 202 | silky_terrier, Sydney_silky 203 | soft-coated_wheaten_terrier 204 | West_Highland_white_terrier 205 | Lhasa, Lhasa_apso 206 | flat-coated_retriever 207 | curly-coated_retriever 208 | golden_retriever 209 | Labrador_retriever 210 | Chesapeake_Bay_retriever 211 | German_short-haired_pointer 212 | vizsla, Hungarian_pointer 213 | English_setter 214 | Irish_setter, red_setter 215 | Gordon_setter 216 | Brittany_spaniel 217 | clumber, clumber_spaniel 218 | English_springer, English_springer_spaniel 219 | Welsh_springer_spaniel 220 | cocker_spaniel, English_cocker_spaniel, cocker 221 | Sussex_spaniel 222 | Irish_water_spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old_English_sheepdog, bobtail 231 | Shetland_sheepdog, Shetland_sheep_dog, Shetland 232 | collie 233 | Border_collie 234 | Bouvier_des_Flandres, Bouviers_des_Flandres 235 | Rottweiler 236 | German_shepherd, German_shepherd_dog, German_police_dog, alsatian 237 | Doberman, Doberman_pinscher 238 | miniature_pinscher 239 | Greater_Swiss_Mountain_dog 240 | Bernese_mountain_dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull_mastiff 245 | Tibetan_mastiff 246 | French_bulldog 247 | Great_Dane 248 | Saint_Bernard, St_Bernard 249 | Eskimo_dog, husky 250 | malamute, malemute, Alaskan_malamute 251 | Siberian_husky 252 | dalmatian, coach_dog, carriage_dog 253 | affenpinscher, monkey_pinscher, monkey_dog 254 | basenji 255 | pug, pug-dog 256 | Leonberg 257 | Newfoundland, Newfoundland_dog 258 | Great_Pyrenees 259 | Samoyed, Samoyede 260 | Pomeranian 261 | chow, chow_chow 262 | keeshond 263 | Brabancon_griffon 264 | Pembroke, Pembroke_Welsh_corgi 265 | Cardigan, Cardigan_Welsh_corgi 266 | toy_poodle 267 | miniature_poodle 268 | standard_poodle 269 | Mexican_hairless 270 | timber_wolf, grey_wolf, gray_wolf, Canis_lupus 271 | white_wolf, Arctic_wolf, Canis_lupus_tundrarum 272 | red_wolf, maned_wolf, Canis_rufus, Canis_niger 273 | coyote, prairie_wolf, brush_wolf, Canis_latrans 274 | dingo, warrigal, warragal, Canis_dingo 275 | dhole, Cuon_alpinus 276 | African_hunting_dog, hyena_dog, Cape_hunting_dog, Lycaon_pictus 277 | hyena, hyaena 278 | red_fox, Vulpes_vulpes 279 | kit_fox, Vulpes_macrotis 280 | Arctic_fox, white_fox, Alopex_lagopus 281 | grey_fox, gray_fox, Urocyon_cinereoargenteus 282 | tabby, tabby_cat 283 | tiger_cat 284 | Persian_cat 285 | Siamese_cat, Siamese 286 | Egyptian_cat 287 | cougar, puma, catamount, mountain_lion, painter, panther, Felis_concolor 288 | lynx, catamount 289 | leopard, Panthera_pardus 290 | snow_leopard, ounce, Panthera_uncia 291 | jaguar, panther, Panthera_onca, Felis_onca 292 | lion, king_of_beasts, Panthera_leo 293 | tiger, Panthera_tigris 294 | cheetah, chetah, Acinonyx_jubatus 295 | brown_bear, bruin, Ursus_arctos 296 | American_black_bear, black_bear, Ursus_americanus, Euarctos_americanus 297 | ice_bear, 
polar_bear, Ursus_Maritimus, Thalarctos_maritimus 298 | sloth_bear, Melursus_ursinus, Ursus_ursinus 299 | mongoose 300 | meerkat, mierkat 301 | tiger_beetle 302 | ladybug, ladybeetle, lady_beetle, ladybird, ladybird_beetle 303 | ground_beetle, carabid_beetle 304 | long-horned_beetle, longicorn, longicorn_beetle 305 | leaf_beetle, chrysomelid 306 | dung_beetle 307 | rhinoceros_beetle 308 | weevil 309 | fly 310 | bee 311 | ant, emmet, pismire 312 | grasshopper, hopper 313 | cricket 314 | walking_stick, walkingstick, stick_insect 315 | cockroach, roach 316 | mantis, mantid 317 | cicada, cicala 318 | leafhopper 319 | lacewing, lacewing_fly 320 | dragonfly, darning_needle, devil's_darning_needle, sewing_needle, snake_feeder, snake_doctor, mosquito_hawk, skeeter_hawk 321 | damselfly 322 | admiral 323 | ringlet, ringlet_butterfly 324 | monarch, monarch_butterfly, milkweed_butterfly, Danaus_plexippus 325 | cabbage_butterfly 326 | sulphur_butterfly, sulfur_butterfly 327 | lycaenid, lycaenid_butterfly 328 | starfish, sea_star 329 | sea_urchin 330 | sea_cucumber, holothurian 331 | wood_rabbit, cottontail, cottontail_rabbit 332 | hare 333 | Angora, Angora_rabbit 334 | hamster 335 | porcupine, hedgehog 336 | fox_squirrel, eastern_fox_squirrel, Sciurus_niger 337 | marmot 338 | beaver 339 | guinea_pig, Cavia_cobaya 340 | sorrel 341 | zebra 342 | hog, pig, grunter, squealer, Sus_scrofa 343 | wild_boar, boar, Sus_scrofa 344 | warthog 345 | hippopotamus, hippo, river_horse, Hippopotamus_amphibius 346 | ox 347 | water_buffalo, water_ox, Asiatic_buffalo, Bubalus_bubalis 348 | bison 349 | ram, tup 350 | bighorn, bighorn_sheep, cimarron, Rocky_Mountain_bighorn, Rocky_Mountain_sheep, Ovis_canadensis 351 | ibex, Capra_ibex 352 | hartebeest 353 | impala, Aepyceros_melampus 354 | gazelle 355 | Arabian_camel, dromedary, Camelus_dromedarius 356 | llama 357 | weasel 358 | mink 359 | polecat, fitch, foulmart, foumart, Mustela_putorius 360 | black-footed_ferret, ferret, Mustela_nigripes 361 | otter 362 | skunk, polecat, wood_pussy 363 | badger 364 | armadillo 365 | three-toed_sloth, ai, Bradypus_tridactylus 366 | orangutan, orang, orangutang, Pongo_pygmaeus 367 | gorilla, Gorilla_gorilla 368 | chimpanzee, chimp, Pan_troglodytes 369 | gibbon, Hylobates_lar 370 | siamang, Hylobates_syndactylus, Symphalangus_syndactylus 371 | guenon, guenon_monkey 372 | patas, hussar_monkey, Erythrocebus_patas 373 | baboon 374 | macaque 375 | langur 376 | colobus, colobus_monkey 377 | proboscis_monkey, Nasalis_larvatus 378 | marmoset 379 | capuchin, ringtail, Cebus_capucinus 380 | howler_monkey, howler 381 | titi, titi_monkey 382 | spider_monkey, Ateles_geoffroyi 383 | squirrel_monkey, Saimiri_sciureus 384 | Madagascar_cat, ring-tailed_lemur, Lemur_catta 385 | indri, indris, Indri_indri, Indri_brevicaudatus 386 | Indian_elephant, Elephas_maximus 387 | African_elephant, Loxodonta_africana 388 | lesser_panda, red_panda, panda, bear_cat, cat_bear, Ailurus_fulgens 389 | giant_panda, panda, panda_bear, coon_bear, Ailuropoda_melanoleuca 390 | barracouta, snoek 391 | eel 392 | coho, cohoe, coho_salmon, blue_jack, silver_salmon, Oncorhynchus_kisutch 393 | rock_beauty, Holocanthus_tricolor 394 | anemone_fish 395 | sturgeon 396 | gar, garfish, garpike, billfish, Lepisosteus_osseus 397 | lionfish 398 | puffer, pufferfish, blowfish, globefish 399 | abacus 400 | abaya 401 | academic_gown, academic_robe, judge's_robe 402 | accordion, piano_accordion, squeeze_box 403 | acoustic_guitar 404 | aircraft_carrier, carrier, flattop, attack_aircraft_carrier 405 
| airliner 406 | airship, dirigible 407 | altar 408 | ambulance 409 | amphibian, amphibious_vehicle 410 | analog_clock 411 | apiary, bee_house 412 | apron 413 | ashcan, trash_can, garbage_can, wastebin, ash_bin, ash-bin, ashbin, dustbin, trash_barrel, trash_bin 414 | assault_rifle, assault_gun 415 | backpack, back_pack, knapsack, packsack, rucksack, haversack 416 | bakery, bakeshop, bakehouse 417 | balance_beam, beam 418 | balloon 419 | ballpoint, ballpoint_pen, ballpen, Biro 420 | Band_Aid 421 | banjo 422 | bannister, banister, balustrade, balusters, handrail 423 | barbell 424 | barber_chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel, cask 429 | barrow, garden_cart, lawn_cart, wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing_cap, swimming_cap 435 | bath_towel 436 | bathtub, bathing_tub, bath, tub 437 | beach_wagon, station_wagon, wagon, estate_car, beach_waggon, station_waggon, waggon 438 | beacon, lighthouse, beacon_light, pharos 439 | beaker 440 | bearskin, busby, shako 441 | beer_bottle 442 | beer_glass 443 | bell_cote, bell_cot 444 | bib 445 | bicycle-built-for-two, tandem_bicycle, tandem 446 | bikini, two-piece 447 | binder, ring-binder 448 | binoculars, field_glasses, opera_glasses 449 | birdhouse 450 | boathouse 451 | bobsled, bobsleigh, bob 452 | bolo_tie, bolo, bola_tie, bola 453 | bonnet, poke_bonnet 454 | bookcase 455 | bookshop, bookstore, bookstall 456 | bottlecap 457 | bow 458 | bow_tie, bow-tie, bowtie 459 | brass, memorial_tablet, plaque 460 | brassiere, bra, bandeau 461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | breastplate, aegis, egis 463 | broom 464 | bucket, pail 465 | buckle 466 | bulletproof_vest 467 | bullet_train, bullet 468 | butcher_shop, meat_market 469 | cab, hack, taxi, taxicab 470 | caldron, cauldron 471 | candle, taper, wax_light 472 | cannon 473 | canoe 474 | can_opener, tin_opener 475 | cardigan 476 | car_mirror 477 | carousel, carrousel, merry-go-round, roundabout, whirligig 478 | carpenter's_kit, tool_kit 479 | carton 480 | car_wheel 481 | cash_machine, cash_dispenser, automated_teller_machine, automatic_teller_machine, automated_teller, automatic_teller, ATM 482 | cassette 483 | cassette_player 484 | castle 485 | catamaran 486 | CD_player 487 | cello, violoncello 488 | cellular_telephone, cellular_phone, cellphone, cell, mobile_phone 489 | chain 490 | chainlink_fence 491 | chain_mail, ring_mail, mail, chain_armor, chain_armour, ring_armor, ring_armour 492 | chain_saw, chainsaw 493 | chest 494 | chiffonier, commode 495 | chime, bell, gong 496 | china_cabinet, china_closet 497 | Christmas_stocking 498 | church, church_building 499 | cinema, movie_theater, movie_theatre, movie_house, picture_palace 500 | cleaver, meat_cleaver, chopper 501 | cliff_dwelling 502 | cloak 503 | clog, geta, patten, sabot 504 | cocktail_shaker 505 | coffee_mug 506 | coffeepot 507 | coil, spiral, volute, whorl, helix 508 | combination_lock 509 | computer_keyboard, keypad 510 | confectionery, confectionary, candy_store 511 | container_ship, containership, container_vessel 512 | convertible 513 | corkscrew, bottle_screw 514 | cornet, horn, trumpet, trump 515 | cowboy_boot 516 | cowboy_hat, ten-gallon_hat 517 | cradle 518 | crane 519 | crash_helmet 520 | crate 521 | crib, cot 522 | Crock_Pot 523 | croquet_ball 524 | crutch 525 | cuirass 526 | dam, dike, dyke 527 | desk 528 | desktop_computer 529 | dial_telephone, dial_phone 530 | diaper, nappy, napkin 531 | digital_clock 532 | digital_watch 533 | dining_table, 
board 534 | dishrag, dishcloth 535 | dishwasher, dish_washer, dishwashing_machine 536 | disk_brake, disc_brake 537 | dock, dockage, docking_facility 538 | dogsled, dog_sled, dog_sleigh 539 | dome 540 | doormat, welcome_mat 541 | drilling_platform, offshore_rig 542 | drum, membranophone, tympan 543 | drumstick 544 | dumbbell 545 | Dutch_oven 546 | electric_fan, blower 547 | electric_guitar 548 | electric_locomotive 549 | entertainment_center 550 | envelope 551 | espresso_maker 552 | face_powder 553 | feather_boa, boa 554 | file, file_cabinet, filing_cabinet 555 | fireboat 556 | fire_engine, fire_truck 557 | fire_screen, fireguard 558 | flagpole, flagstaff 559 | flute, transverse_flute 560 | folding_chair 561 | football_helmet 562 | forklift 563 | fountain 564 | fountain_pen 565 | four-poster 566 | freight_car 567 | French_horn, horn 568 | frying_pan, frypan, skillet 569 | fur_coat 570 | garbage_truck, dustcart 571 | gasmask, respirator, gas_helmet 572 | gas_pump, gasoline_pump, petrol_pump, island_dispenser 573 | goblet 574 | go-kart 575 | golf_ball 576 | golfcart, golf_cart 577 | gondola 578 | gong, tam-tam 579 | gown 580 | grand_piano, grand 581 | greenhouse, nursery, glasshouse 582 | grille, radiator_grille 583 | grocery_store, grocery, food_market, market 584 | guillotine 585 | hair_slide 586 | hair_spray 587 | half_track 588 | hammer 589 | hamper 590 | hand_blower, blow_dryer, blow_drier, hair_dryer, hair_drier 591 | hand-held_computer, hand-held_microcomputer 592 | handkerchief, hankie, hanky, hankey 593 | hard_disc, hard_disk, fixed_disk 594 | harmonica, mouth_organ, harp, mouth_harp 595 | harp 596 | harvester, reaper 597 | hatchet 598 | holster 599 | home_theater, home_theatre 600 | honeycomb 601 | hook, claw 602 | hoopskirt, crinoline 603 | horizontal_bar, high_bar 604 | horse_cart, horse-cart 605 | hourglass 606 | iPod 607 | iron, smoothing_iron 608 | jack-o'-lantern 609 | jean, blue_jean, denim 610 | jeep, landrover 611 | jersey, T-shirt, tee_shirt 612 | jigsaw_puzzle 613 | jinrikisha, ricksha, rickshaw 614 | joystick 615 | kimono 616 | knee_pad 617 | knot 618 | lab_coat, laboratory_coat 619 | ladle 620 | lampshade, lamp_shade 621 | laptop, laptop_computer 622 | lawn_mower, mower 623 | lens_cap, lens_cover 624 | letter_opener, paper_knife, paperknife 625 | library 626 | lifeboat 627 | lighter, light, igniter, ignitor 628 | limousine, limo 629 | liner, ocean_liner 630 | lipstick, lip_rouge 631 | Loafer 632 | lotion 633 | loudspeaker, speaker, speaker_unit, loudspeaker_system, speaker_system 634 | loupe, jeweler's_loupe 635 | lumbermill, sawmill 636 | magnetic_compass 637 | mailbag, postbag 638 | mailbox, letter_box 639 | maillot 640 | maillot, tank_suit 641 | manhole_cover 642 | maraca 643 | marimba, xylophone 644 | mask 645 | matchstick 646 | maypole 647 | maze, labyrinth 648 | measuring_cup 649 | medicine_chest, medicine_cabinet 650 | megalith, megalithic_structure 651 | microphone, mike 652 | microwave, microwave_oven 653 | military_uniform 654 | milk_can 655 | minibus 656 | miniskirt, mini 657 | minivan 658 | missile 659 | mitten 660 | mixing_bowl 661 | mobile_home, manufactured_home 662 | Model_T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito_net 671 | motor_scooter, scooter 672 | mountain_bike, all-terrain_bike, off-roader 673 | mountain_tent 674 | mouse, computer_mouse 675 | mousetrap 676 | moving_van 677 | muzzle 678 | nail 679 | neck_brace 680 | necklace 681 | nipple 682 | notebook, notebook_computer 683 | 
obelisk 684 | oboe, hautboy, hautbois 685 | ocarina, sweet_potato 686 | odometer, hodometer, mileometer, milometer 687 | oil_filter 688 | organ, pipe_organ 689 | oscilloscope, scope, cathode-ray_oscilloscope, CRO 690 | overskirt 691 | oxcart 692 | oxygen_mask 693 | packet 694 | paddle, boat_paddle 695 | paddlewheel, paddle_wheel 696 | padlock 697 | paintbrush 698 | pajama, pyjama, pj's, jammies 699 | palace 700 | panpipe, pandean_pipe, syrinx 701 | paper_towel 702 | parachute, chute 703 | parallel_bars, bars 704 | park_bench 705 | parking_meter 706 | passenger_car, coach, carriage 707 | patio, terrace 708 | pay-phone, pay-station 709 | pedestal, plinth, footstall 710 | pencil_box, pencil_case 711 | pencil_sharpener 712 | perfume, essence 713 | Petri_dish 714 | photocopier 715 | pick, plectrum, plectron 716 | pickelhaube 717 | picket_fence, paling 718 | pickup, pickup_truck 719 | pier 720 | piggy_bank, penny_bank 721 | pill_bottle 722 | pillow 723 | ping-pong_ball 724 | pinwheel 725 | pirate, pirate_ship 726 | pitcher, ewer 727 | plane, carpenter's_plane, woodworking_plane 728 | planetarium 729 | plastic_bag 730 | plate_rack 731 | plow, plough 732 | plunger, plumber's_helper 733 | Polaroid_camera, Polaroid_Land_camera 734 | pole 735 | police_van, police_wagon, paddy_wagon, patrol_wagon, wagon, black_Maria 736 | poncho 737 | pool_table, billiard_table, snooker_table 738 | pop_bottle, soda_bottle 739 | pot, flowerpot 740 | potter's_wheel 741 | power_drill 742 | prayer_rug, prayer_mat 743 | printer 744 | prison, prison_house 745 | projectile, missile 746 | projector 747 | puck, hockey_puck 748 | punching_bag, punch_bag, punching_ball, punchball 749 | purse 750 | quill, quill_pen 751 | quilt, comforter, comfort, puff 752 | racer, race_car, racing_car 753 | racket, racquet 754 | radiator 755 | radio, wireless 756 | radio_telescope, radio_reflector 757 | rain_barrel 758 | recreational_vehicle, RV, R.V. 
759 | reel 760 | reflex_camera 761 | refrigerator, icebox 762 | remote_control, remote 763 | restaurant, eating_house, eating_place, eatery 764 | revolver, six-gun, six-shooter 765 | rifle 766 | rocking_chair, rocker 767 | rotisserie 768 | rubber_eraser, rubber, pencil_eraser 769 | rugby_ball 770 | rule, ruler 771 | running_shoe 772 | safe 773 | safety_pin 774 | saltshaker, salt_shaker 775 | sandal 776 | sarong 777 | sax, saxophone 778 | scabbard 779 | scale, weighing_machine 780 | school_bus 781 | schooner 782 | scoreboard 783 | screen, CRT_screen 784 | screw 785 | screwdriver 786 | seat_belt, seatbelt 787 | sewing_machine 788 | shield, buckler 789 | shoe_shop, shoe-shop, shoe_store 790 | shoji 791 | shopping_basket 792 | shopping_cart 793 | shovel 794 | shower_cap 795 | shower_curtain 796 | ski 797 | ski_mask 798 | sleeping_bag 799 | slide_rule, slipstick 800 | sliding_door 801 | slot, one-armed_bandit 802 | snorkel 803 | snowmobile 804 | snowplow, snowplough 805 | soap_dispenser 806 | soccer_ball 807 | sock 808 | solar_dish, solar_collector, solar_furnace 809 | sombrero 810 | soup_bowl 811 | space_bar 812 | space_heater 813 | space_shuttle 814 | spatula 815 | speedboat 816 | spider_web, spider's_web 817 | spindle 818 | sports_car, sport_car 819 | spotlight, spot 820 | stage 821 | steam_locomotive 822 | steel_arch_bridge 823 | steel_drum 824 | stethoscope 825 | stole 826 | stone_wall 827 | stopwatch, stop_watch 828 | stove 829 | strainer 830 | streetcar, tram, tramcar, trolley, trolley_car 831 | stretcher 832 | studio_couch, day_bed 833 | stupa, tope 834 | submarine, pigboat, sub, U-boat 835 | suit, suit_of_clothes 836 | sundial 837 | sunglass 838 | sunglasses, dark_glasses, shades 839 | sunscreen, sunblock, sun_blocker 840 | suspension_bridge 841 | swab, swob, mop 842 | sweatshirt 843 | swimming_trunks, bathing_trunks 844 | swing 845 | switch, electric_switch, electrical_switch 846 | syringe 847 | table_lamp 848 | tank, army_tank, armored_combat_vehicle, armoured_combat_vehicle 849 | tape_player 850 | teapot 851 | teddy, teddy_bear 852 | television, television_system 853 | tennis_ball 854 | thatch, thatched_roof 855 | theater_curtain, theatre_curtain 856 | thimble 857 | thresher, thrasher, threshing_machine 858 | throne 859 | tile_roof 860 | toaster 861 | tobacco_shop, tobacconist_shop, tobacconist 862 | toilet_seat 863 | torch 864 | totem_pole 865 | tow_truck, tow_car, wrecker 866 | toyshop 867 | tractor 868 | trailer_truck, tractor_trailer, trucking_rig, rig, articulated_lorry, semi 869 | tray 870 | trench_coat 871 | tricycle, trike, velocipede 872 | trimaran 873 | tripod 874 | triumphal_arch 875 | trolleybus, trolley_coach, trackless_trolley 876 | trombone 877 | tub, vat 878 | turnstile 879 | typewriter_keyboard 880 | umbrella 881 | unicycle, monocycle 882 | upright, upright_piano 883 | vacuum, vacuum_cleaner 884 | vase 885 | vault 886 | velvet 887 | vending_machine 888 | vestment 889 | viaduct 890 | violin, fiddle 891 | volleyball 892 | waffle_iron 893 | wall_clock 894 | wallet, billfold, notecase, pocketbook 895 | wardrobe, closet, press 896 | warplane, military_plane 897 | washbasin, handbasin, washbowl, lavabo, wash-hand_basin 898 | washer, automatic_washer, washing_machine 899 | water_bottle 900 | water_jug 901 | water_tower 902 | whiskey_jug 903 | whistle 904 | wig 905 | window_screen 906 | window_shade 907 | Windsor_tie 908 | wine_bottle 909 | wing 910 | wok 911 | wooden_spoon 912 | wool, woolen, woollen 913 | worm_fence, snake_fence, snake-rail_fence, Virginia_fence 914 | 
wreck 915 | yawl 916 | yurt 917 | web_site, website, internet_site, site 918 | comic_book 919 | crossword_puzzle, crossword 920 | street_sign 921 | traffic_light, traffic_signal, stoplight 922 | book_jacket, dust_cover, dust_jacket, dust_wrapper 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot_pot, hotpot 928 | trifle 929 | ice_cream, icecream 930 | ice_lolly, lolly, lollipop, popsicle 931 | French_loaf 932 | bagel, beigel 933 | pretzel 934 | cheeseburger 935 | hotdog, hot_dog, red_hot 936 | mashed_potato 937 | head_cabbage 938 | broccoli 939 | cauliflower 940 | zucchini, courgette 941 | spaghetti_squash 942 | acorn_squash 943 | butternut_squash 944 | cucumber, cuke 945 | artichoke, globe_artichoke 946 | bell_pepper 947 | cardoon 948 | mushroom 949 | Granny_Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple, ananas 955 | banana 956 | jackfruit, jak, jack 957 | custard_apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate_sauce, chocolate_syrup 962 | dough 963 | meat_loaf, meatloaf 964 | pizza, pizza_pie 965 | potpie 966 | burrito 967 | red_wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff, drop, drop-off 974 | coral_reef 975 | geyser 976 | lakeside, lakeshore 977 | promontory, headland, head, foreland 978 | sandbar, sand_bar 979 | seashore, coast, seacoast, sea-coast 980 | valley, vale 981 | volcano 982 | ballplayer, baseball_player 983 | groom, bridegroom 984 | scuba_diver 985 | rapeseed 986 | daisy 987 | yellow_lady's_slipper, yellow_lady-slipper, Cypripedium_calceolus, Cypripedium_parviflorum 988 | corn 989 | acorn 990 | hip, rose_hip, rosehip 991 | buckeye, horse_chestnut, conker 992 | coral_fungus 993 | agaric 994 | gyromitra 995 | stinkhorn, carrion_fungus 996 | earthstar 997 | hen-of-the-woods, hen_of_the_woods, Polyporus_frondosus, Grifola_frondosa 998 | bolete 999 | ear, spike, capitulum 1000 | toilet_tissue, toilet_paper, bathroom_tissue 1001 | -------------------------------------------------------------------------------- /swin-transformers-tf/notebooks/weight-porting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "40180afb", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import timm\n", 11 | "\n", 12 | "import tensorflow as tf\n", 13 | "import numpy as np" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 2, 19 | "id": "50ef790b", 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import sys\n", 24 | "\n", 25 | "sys.path.append(\"..\")\n", 26 | "\n", 27 | "from swins import SwinTransformer\n", 28 | "from swins.layers import *\n", 29 | "from swins.blocks import *\n", 30 | "from utils import helpers" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 3, 36 | "id": "3d2ec87f", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "cfg = dict(\n", 41 | " patch_size=4,\n", 42 | " window_size=7,\n", 43 | " embed_dim=96,\n", 44 | " depths=(2, 2, 6, 2),\n", 45 | " num_heads=(3, 6, 12, 24),\n", 46 | ")" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 4, 52 | "id": "59542c5f", 53 | "metadata": {}, 54 | "outputs": [ 55 | { 56 | "name": "stderr", 57 | "output_type": "stream", 58 | "text": [ 59 | "2022-05-08 18:23:21.505079: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU 
instructions in performance-critical operations: AVX2 FMA\n", 60 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" 61 | ] 62 | }, 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "Swin TF model created.\n" 68 | ] 69 | } 70 | ], 71 | "source": [ 72 | "swin_tiny_patch4_window7_224_tf = SwinTransformer(\n", 73 | " name=\"swin_tiny_patch4_window7_224\", **cfg\n", 74 | ")\n", 75 | "random_tensor = tf.random.normal((2, 224, 224, 3))\n", 76 | "outputs = swin_tiny_patch4_window7_224_tf(random_tensor, training=False)\n", 77 | "print(\"Swin TF model created.\")" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 5, 83 | "id": "d29d6661", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stderr", 88 | "output_type": "stream", 89 | "text": [ 90 | "/Users/sayakpaul/.local/bin/.virtualenvs/pytorch/lib/python3.8/site-packages/torch/functional.py:445: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:2157.)\n", 91 | " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" 92 | ] 93 | }, 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "Swin PT model created.\n", 99 | "Number of parameters:\n", 100 | "28.288354\n" 101 | ] 102 | } 103 | ], 104 | "source": [ 105 | "swin_tiny_patch4_window7_224_pt = timm.create_model(\n", 106 | " model_name=\"swin_tiny_patch4_window7_224\", pretrained=True\n", 107 | ")\n", 108 | "print(\"Swin PT model created.\")\n", 109 | "print(\"Number of parameters:\")\n", 110 | "num_params = sum(p.numel() for p in swin_tiny_patch4_window7_224_pt.parameters())\n", 111 | "print(num_params / 1e6)\n", 112 | "\n", 113 | "assert swin_tiny_patch4_window7_224_tf.count_params() == num_params" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 6, 119 | "id": "50ee5556", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "state_dict = swin_tiny_patch4_window7_224_pt.state_dict()\n", 124 | "np_state_dict = {k: state_dict[k].numpy() for k in state_dict}" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 7, 130 | "id": "cbbc3080", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# Projection.\n", 135 | "swin_tiny_patch4_window7_224_tf.projection.layers[0] = helpers.modify_tf_block(\n", 136 | " swin_tiny_patch4_window7_224_tf.projection.layers[0],\n", 137 | " np_state_dict[\"patch_embed.proj.weight\"],\n", 138 | " np_state_dict[\"patch_embed.proj.bias\"],\n", 139 | ")\n", 140 | "swin_tiny_patch4_window7_224_tf.projection.layers[2] = helpers.modify_tf_block(\n", 141 | " swin_tiny_patch4_window7_224_tf.projection.layers[2],\n", 142 | " np_state_dict[\"patch_embed.norm.weight\"],\n", 143 | " np_state_dict[\"patch_embed.norm.bias\"],\n", 144 | ")" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 8, 150 | "id": "e80ad36b", 151 | "metadata": {}, 152 | "outputs": [], 153 | "source": [ 154 | "# Layer norm layers.\n", 155 | "ln_idx = -2\n", 156 | "swin_tiny_patch4_window7_224_tf.layers[ln_idx] = helpers.modify_tf_block(\n", 157 | " swin_tiny_patch4_window7_224_tf.layers[ln_idx],\n", 158 | " np_state_dict[\"norm.weight\"],\n", 159 | " np_state_dict[\"norm.bias\"],\n", 160 | ")\n", 161 | "\n", 162 | "# Head layers.\n", 163 | "head_layer = swin_tiny_patch4_window7_224_tf.get_layer(\"classification_head\")\n", 164 | 
"swin_tiny_patch4_window7_224_tf.layers[-1] = helpers.modify_tf_block(\n", 165 | " head_layer,\n", 166 | " np_state_dict[\"head.weight\"],\n", 167 | " np_state_dict[\"head.bias\"],\n", 168 | ")" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 9, 174 | "id": "03ba0496", 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "data": { 179 | "text/plain": [ 180 | "['layers.0.blocks.0.norm1.weight',\n", 181 | " 'layers.0.blocks.0.norm1.bias',\n", 182 | " 'layers.0.blocks.0.attn.relative_position_bias_table',\n", 183 | " 'layers.0.blocks.0.attn.relative_position_index',\n", 184 | " 'layers.0.blocks.0.attn.qkv.weight',\n", 185 | " 'layers.0.blocks.0.attn.qkv.bias',\n", 186 | " 'layers.0.blocks.0.attn.proj.weight',\n", 187 | " 'layers.0.blocks.0.attn.proj.bias',\n", 188 | " 'layers.0.blocks.0.norm2.weight',\n", 189 | " 'layers.0.blocks.0.norm2.bias',\n", 190 | " 'layers.0.blocks.0.mlp.fc1.weight',\n", 191 | " 'layers.0.blocks.0.mlp.fc1.bias',\n", 192 | " 'layers.0.blocks.0.mlp.fc2.weight',\n", 193 | " 'layers.0.blocks.0.mlp.fc2.bias',\n", 194 | " 'layers.0.blocks.1.attn_mask',\n", 195 | " 'layers.0.blocks.1.norm1.weight',\n", 196 | " 'layers.0.blocks.1.norm1.bias',\n", 197 | " 'layers.0.blocks.1.attn.relative_position_bias_table',\n", 198 | " 'layers.0.blocks.1.attn.relative_position_index',\n", 199 | " 'layers.0.blocks.1.attn.qkv.weight',\n", 200 | " 'layers.0.blocks.1.attn.qkv.bias',\n", 201 | " 'layers.0.blocks.1.attn.proj.weight',\n", 202 | " 'layers.0.blocks.1.attn.proj.bias',\n", 203 | " 'layers.0.blocks.1.norm2.weight',\n", 204 | " 'layers.0.blocks.1.norm2.bias',\n", 205 | " 'layers.0.blocks.1.mlp.fc1.weight',\n", 206 | " 'layers.0.blocks.1.mlp.fc1.bias',\n", 207 | " 'layers.0.blocks.1.mlp.fc2.weight',\n", 208 | " 'layers.0.blocks.1.mlp.fc2.bias',\n", 209 | " 'layers.0.downsample.reduction.weight',\n", 210 | " 'layers.0.downsample.norm.weight',\n", 211 | " 'layers.0.downsample.norm.bias']" 212 | ] 213 | }, 214 | "execution_count": 9, 215 | "metadata": {}, 216 | "output_type": "execute_result" 217 | } 218 | ], 219 | "source": [ 220 | "list(filter(lambda x: \"layers.0\" in x, np_state_dict.keys()))" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 10, 226 | "id": "5c7af809", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "def modify_swin_blocks(pt_weights_prefix, tf_block):\n", 231 | " # Patch merging.\n", 232 | " for layer in tf_block:\n", 233 | " if isinstance(layer, PatchMerging):\n", 234 | " patch_merging_idx = f\"{pt_weights_prefix}.downsample\"\n", 235 | "\n", 236 | " layer.reduction = helpers.modify_tf_block(\n", 237 | " layer.reduction,\n", 238 | " np_state_dict[f\"{patch_merging_idx}.reduction.weight\"],\n", 239 | " )\n", 240 | " layer.norm = helpers.modify_tf_block(\n", 241 | " layer.norm,\n", 242 | " np_state_dict[f\"{patch_merging_idx}.norm.weight\"],\n", 243 | " np_state_dict[f\"{patch_merging_idx}.norm.bias\"],\n", 244 | " )\n", 245 | "\n", 246 | " # Swin layers.\n", 247 | " common_prefix = f\"{pt_weights_prefix}.blocks\"\n", 248 | " block_idx = 0\n", 249 | "\n", 250 | " for outer_layer in tf_block:\n", 251 | "\n", 252 | " layernorm_idx = 1\n", 253 | " mlp_layer_idx = 1\n", 254 | "\n", 255 | " if isinstance(outer_layer, SwinTransformerBlock):\n", 256 | " for inner_layer in outer_layer.layers:\n", 257 | "\n", 258 | " # Layer norm.\n", 259 | " if isinstance(inner_layer, tf.keras.layers.LayerNormalization):\n", 260 | " layer_norm_prefix = (\n", 261 | " 
f\"{common_prefix}.{block_idx}.norm{layernorm_idx}\"\n", 262 | " )\n", 263 | " inner_layer.gamma.assign(\n", 264 | " tf.Variable(np_state_dict[f\"{layer_norm_prefix}.weight\"])\n", 265 | " )\n", 266 | " inner_layer.beta.assign(\n", 267 | " tf.Variable(np_state_dict[f\"{layer_norm_prefix}.bias\"])\n", 268 | " )\n", 269 | " layernorm_idx += 1\n", 270 | "\n", 271 | " # Windown attention.\n", 272 | " elif isinstance(inner_layer, WindowAttention):\n", 273 | " attn_prefix = f\"{common_prefix}.{block_idx}.attn\"\n", 274 | "\n", 275 | " # Relative position.\n", 276 | " inner_layer.relative_position_bias_table = helpers.modify_tf_block(\n", 277 | " inner_layer.relative_position_bias_table,\n", 278 | " np_state_dict[f\"{attn_prefix}.relative_position_bias_table\"],\n", 279 | " )\n", 280 | " inner_layer.relative_position_index = helpers.modify_tf_block(\n", 281 | " inner_layer.relative_position_index,\n", 282 | " np_state_dict[f\"{attn_prefix}.relative_position_index\"],\n", 283 | " )\n", 284 | "\n", 285 | " # QKV.\n", 286 | " inner_layer.qkv = helpers.modify_tf_block(\n", 287 | " inner_layer.qkv,\n", 288 | " np_state_dict[f\"{attn_prefix}.qkv.weight\"],\n", 289 | " np_state_dict[f\"{attn_prefix}.qkv.bias\"],\n", 290 | " )\n", 291 | "\n", 292 | " # Projection.\n", 293 | " inner_layer.proj = helpers.modify_tf_block(\n", 294 | " inner_layer.proj,\n", 295 | " np_state_dict[f\"{attn_prefix}.proj.weight\"],\n", 296 | " np_state_dict[f\"{attn_prefix}.proj.bias\"],\n", 297 | " )\n", 298 | "\n", 299 | " # MLP.\n", 300 | " elif isinstance(inner_layer, tf.keras.Model):\n", 301 | " mlp_prefix = f\"{common_prefix}.{block_idx}.mlp\"\n", 302 | " for mlp_layer in inner_layer.layers:\n", 303 | " if isinstance(mlp_layer, tf.keras.layers.Dense):\n", 304 | " mlp_layer = helpers.modify_tf_block(\n", 305 | " mlp_layer,\n", 306 | " np_state_dict[f\"{mlp_prefix}.fc{mlp_layer_idx}.weight\"],\n", 307 | " np_state_dict[f\"{mlp_prefix}.fc{mlp_layer_idx}.bias\"],\n", 308 | " )\n", 309 | " mlp_layer_idx += 1\n", 310 | "\n", 311 | " block_idx += 1\n", 312 | " return tf_block" 313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 11, 318 | "id": "d947e04d", 319 | "metadata": {}, 320 | "outputs": [], 321 | "source": [ 322 | "_ = modify_swin_blocks(\n", 323 | " \"layers.0\",\n", 324 | " swin_tiny_patch4_window7_224_tf.layers[2].layers,\n", 325 | ")" 326 | ] 327 | }, 328 | { 329 | "cell_type": "code", 330 | "execution_count": 12, 331 | "id": "cdb9c2e1", 332 | "metadata": {}, 333 | "outputs": [], 334 | "source": [ 335 | "tf_block = swin_tiny_patch4_window7_224_tf.layers[2].layers\n", 336 | "pt_weights_prefix = \"layers.0\"\n", 337 | "\n", 338 | "# Patch merging.\n", 339 | "for layer in tf_block:\n", 340 | " if isinstance(layer, PatchMerging):\n", 341 | " patch_merging_idx = f\"{pt_weights_prefix}.downsample\"\n", 342 | " np.testing.assert_allclose(\n", 343 | " np_state_dict[f\"{patch_merging_idx}.reduction.weight\"].transpose(),\n", 344 | " layer.reduction.kernel.numpy(),\n", 345 | " )\n", 346 | " np.testing.assert_allclose(\n", 347 | " np_state_dict[f\"{patch_merging_idx}.norm.weight\"], layer.norm.gamma.numpy()\n", 348 | " )\n", 349 | " np.testing.assert_allclose(\n", 350 | " np_state_dict[f\"{patch_merging_idx}.norm.bias\"], layer.norm.beta.numpy()\n", 351 | " )\n", 352 | "\n", 353 | "# Swin layers.\n", 354 | "common_prefix = f\"{pt_weights_prefix}.blocks\"\n", 355 | "block_idx = 0\n", 356 | "\n", 357 | "for outer_layer in tf_block:\n", 358 | "\n", 359 | " layernorm_idx = 1\n", 360 | " mlp_layer_idx = 1\n", 
361 | "\n", 362 | " if isinstance(outer_layer, SwinTransformerBlock):\n", 363 | " for inner_layer in outer_layer.layers:\n", 364 | "\n", 365 | " # Layer norm.\n", 366 | " if isinstance(inner_layer, tf.keras.layers.LayerNormalization):\n", 367 | " layer_norm_prefix = f\"{common_prefix}.{block_idx}.norm{layernorm_idx}\"\n", 368 | " np.testing.assert_allclose(\n", 369 | " np_state_dict[f\"{layer_norm_prefix}.weight\"],\n", 370 | " inner_layer.gamma.numpy(),\n", 371 | " )\n", 372 | " np.testing.assert_allclose(\n", 373 | " np_state_dict[f\"{layer_norm_prefix}.bias\"], inner_layer.beta.numpy()\n", 374 | " )\n", 375 | " layernorm_idx += 1\n", 376 | "\n", 377 | " # Windown attention.\n", 378 | " elif isinstance(inner_layer, WindowAttention):\n", 379 | " attn_prefix = f\"{common_prefix}.{block_idx}.attn\"\n", 380 | "\n", 381 | " # Relative position.\n", 382 | " np.testing.assert_allclose(\n", 383 | " np_state_dict[f\"{attn_prefix}.relative_position_bias_table\"],\n", 384 | " inner_layer.relative_position_bias_table.numpy(),\n", 385 | " )\n", 386 | "\n", 387 | " np.testing.assert_allclose(\n", 388 | " np_state_dict[f\"{attn_prefix}.relative_position_index\"],\n", 389 | " inner_layer.relative_position_index.numpy(),\n", 390 | " )\n", 391 | "\n", 392 | " # QKV.\n", 393 | " np.testing.assert_allclose(\n", 394 | " np_state_dict[f\"{attn_prefix}.qkv.weight\"].transpose(),\n", 395 | " inner_layer.qkv.kernel.numpy(),\n", 396 | " )\n", 397 | " np.testing.assert_allclose(\n", 398 | " np_state_dict[f\"{attn_prefix}.qkv.bias\"],\n", 399 | " inner_layer.qkv.bias.numpy(),\n", 400 | " )\n", 401 | "\n", 402 | " # Projection.\n", 403 | " np.testing.assert_allclose(\n", 404 | " np_state_dict[f\"{attn_prefix}.proj.weight\"].transpose(),\n", 405 | " inner_layer.proj.kernel.numpy(),\n", 406 | " )\n", 407 | " np.testing.assert_allclose(\n", 408 | " np_state_dict[f\"{attn_prefix}.proj.bias\"],\n", 409 | " inner_layer.proj.bias.numpy(),\n", 410 | " )\n", 411 | "\n", 412 | " # MLP.\n", 413 | " elif isinstance(inner_layer, tf.keras.Model):\n", 414 | " mlp_prefix = f\"{common_prefix}.{block_idx}.mlp\"\n", 415 | " for mlp_layer in inner_layer.layers:\n", 416 | " if isinstance(mlp_layer, tf.keras.layers.Dense):\n", 417 | " np.testing.assert_allclose(\n", 418 | " np_state_dict[\n", 419 | " f\"{mlp_prefix}.fc{mlp_layer_idx}.weight\"\n", 420 | " ].transpose(),\n", 421 | " mlp_layer.kernel.numpy(),\n", 422 | " )\n", 423 | " np.testing.assert_allclose(\n", 424 | " np_state_dict[f\"{mlp_prefix}.fc{mlp_layer_idx}.bias\"],\n", 425 | " mlp_layer.bias.numpy(),\n", 426 | " )\n", 427 | "\n", 428 | " mlp_layer_idx += 1\n", 429 | "\n", 430 | " block_idx += 1" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": 13, 436 | "id": "9743e538", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [ 440 | "for i in range(len(cfg[\"depths\"])):\n", 441 | " _ = modify_swin_blocks(\n", 442 | " f\"layers.{i}\",\n", 443 | " swin_tiny_patch4_window7_224_tf.layers[i+2].layers,\n", 444 | " )" 445 | ] 446 | }, 447 | { 448 | "cell_type": "code", 449 | "execution_count": 14, 450 | "id": "71d471b3", 451 | "metadata": {}, 452 | "outputs": [], 453 | "source": [ 454 | "import requests\n", 455 | "from PIL import Image\n", 456 | "from io import BytesIO\n", 457 | "\n", 458 | "import matplotlib.pyplot as plt" 459 | ] 460 | }, 461 | { 462 | "cell_type": "code", 463 | "execution_count": 15, 464 | "id": "672266c0", 465 | "metadata": {}, 466 | "outputs": [], 467 | "source": [ 468 | "input_resolution = 224\n", 469 | "\n", 470 | "crop_layer 
= tf.keras.layers.CenterCrop(input_resolution, input_resolution)\n", 471 | "norm_layer = tf.keras.layers.Normalization(\n", 472 | " mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n", 473 | " variance=[(0.229 * 255) ** 2, (0.224 * 255) ** 2, (0.225 * 255) ** 2],\n", 474 | ")\n", 475 | "\n", 476 | "\n", 477 | "def preprocess_image(image, size=input_resolution):\n", 478 | " image = np.array(image)\n", 479 | " image_resized = tf.expand_dims(image, 0)\n", 480 | " resize_size = int((256 / 224) * size)\n", 481 | " image_resized = tf.image.resize(\n", 482 | " image_resized, (resize_size, resize_size), method=\"bicubic\"\n", 483 | " )\n", 484 | " image_resized = crop_layer(image_resized)\n", 485 | " return norm_layer(image_resized).numpy()\n", 486 | "\n", 487 | "\n", 488 | "def load_image_from_url(url):\n", 489 | " # Credit: Willi Gierke\n", 490 | " response = requests.get(url)\n", 491 | " image = Image.open(BytesIO(response.content))\n", 492 | " preprocessed_image = preprocess_image(image)\n", 493 | " return image, preprocessed_image" 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "execution_count": 16, 499 | "id": "c3f28a73", 500 | "metadata": {}, 501 | "outputs": [], 502 | "source": [ 503 | "# !wget https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt -O ilsvrc2012_wordnet_lemmas.txt" 504 | ] 505 | }, 506 | { 507 | "cell_type": "code", 508 | "execution_count": 17, 509 | "id": "43627c4d", 510 | "metadata": {}, 511 | "outputs": [], 512 | "source": [ 513 | "with open(\"ilsvrc2012_wordnet_lemmas.txt\", \"r\") as f:\n", 514 | " lines = f.readlines()\n", 515 | "imagenet_int_to_str = [line.rstrip() for line in lines]\n", 516 | "\n", 517 | "img_url = \"https://p0.pikrepo.com/preview/853/907/close-up-photo-of-gray-elephant.jpg\"\n", 518 | "image, preprocessed_image = load_image_from_url(img_url)" 519 | ] 520 | }, 521 | { 522 | "cell_type": "code", 523 | "execution_count": 18, 524 | "id": "7adfa5e2", 525 | "metadata": {}, 526 | "outputs": [], 527 | "source": [ 528 | "predictions = swin_tiny_patch4_window7_224_tf.predict(preprocessed_image)\n", 529 | "logits = predictions[0]\n", 530 | "predicted_label = imagenet_int_to_str[int(np.argmax(logits))]\n", 531 | "expected_label = \"Indian_elephant, Elephas_maximus\"\n", 532 | "assert (\n", 533 | " predicted_label == expected_label\n", 534 | "), f\"Expected {expected_label} but was {predicted_label}\"" 535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": 19, 540 | "id": "3cb44e4e", 541 | "metadata": {}, 542 | "outputs": [ 543 | { 544 | "data": { 545 | "text/plain": [ 546 | "dict_keys(['swin_stage_0', 'swin_stage_1', 'swin_stage_2', 'swin_stage_3'])" 547 | ] 548 | }, 549 | "execution_count": 19, 550 | "metadata": {}, 551 | "output_type": "execute_result" 552 | } 553 | ], 554 | "source": [ 555 | "all_attn_scores = swin_tiny_patch4_window7_224_tf.get_attention_scores(\n", 556 | " preprocessed_image\n", 557 | ")\n", 558 | "all_attn_scores.keys()" 559 | ] 560 | }, 561 | { 562 | "cell_type": "code", 563 | "execution_count": 20, 564 | "id": "7a1244ba", 565 | "metadata": {}, 566 | "outputs": [ 567 | { 568 | "data": { 569 | "text/plain": [ 570 | "dict_keys(['swin_block_0', 'swin_block_1'])" 571 | ] 572 | }, 573 | "execution_count": 20, 574 | "metadata": {}, 575 | "output_type": "execute_result" 576 | } 577 | ], 578 | "source": [ 579 | "all_attn_scores[\"swin_stage_3\"].keys()" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": 21, 585 | "id": "a6d1b5c3", 586 | "metadata": {}, 587 | "outputs": [ 
588 | { 589 | "data": { 590 | "text/plain": [ 591 | "TensorShape([1, 24, 49, 49])" 592 | ] 593 | }, 594 | "execution_count": 21, 595 | "metadata": {}, 596 | "output_type": "execute_result" 597 | } 598 | ], 599 | "source": [ 600 | "all_attn_scores[\"swin_stage_3\"][\"swin_block_0\"].shape" 601 | ] 602 | }, 603 | { 604 | "cell_type": "code", 605 | "execution_count": 22, 606 | "id": "f69a8d93", 607 | "metadata": {}, 608 | "outputs": [ 609 | { 610 | "name": "stderr", 611 | "output_type": "stream", 612 | "text": [ 613 | "2022-05-08 18:23:42.809960: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n", 614 | "WARNING:absl:Found untraced functions such as layer_normalization_5_layer_call_fn, layer_normalization_5_layer_call_and_return_conditional_losses, dense_4_layer_call_fn, dense_4_layer_call_and_return_conditional_losses, layer_normalization_10_layer_call_fn while saving (showing 5 of 108). These functions will not be directly callable after loading.\n" 615 | ] 616 | }, 617 | { 618 | "name": "stdout", 619 | "output_type": "stream", 620 | "text": [ 621 | "INFO:tensorflow:Assets written to: gs://swin-tf/swin_tiny_patch4_window7_224_tf/assets\n" 622 | ] 623 | }, 624 | { 625 | "name": "stderr", 626 | "output_type": "stream", 627 | "text": [ 628 | "INFO:tensorflow:Assets written to: gs://swin-tf/swin_tiny_patch4_window7_224_tf/assets\n" 629 | ] 630 | } 631 | ], 632 | "source": [ 633 | "swin_tiny_patch4_window7_224_tf.save(\"gs://swin-tf/swin_tiny_patch4_window7_224_tf\")" 634 | ] 635 | } 636 | ], 637 | "metadata": { 638 | "kernelspec": { 639 | "display_name": "Python 3 (ipykernel)", 640 | "language": "python", 641 | "name": "python3" 642 | }, 643 | "language_info": { 644 | "codemirror_mode": { 645 | "name": "ipython", 646 | "version": 3 647 | }, 648 | "file_extension": ".py", 649 | "mimetype": "text/x-python", 650 | "name": "python", 651 | "nbconvert_exporter": "python", 652 | "pygments_lexer": "ipython3", 653 | "version": "3.8.2" 654 | } 655 | }, 656 | "nbformat": 4, 657 | "nbformat_minor": 5 658 | } 659 | -------------------------------------------------------------------------------- /swin-transformers-tf/requirements.txt: -------------------------------------------------------------------------------- 1 | # timm needs to be installed from source. 
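# One common way to do that, as an unpinned example command: pip install git+https://github.com/rwightman/pytorch-image-models.git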
2 | torch==1.11.0 3 | tensorflow==2.8.0 4 | ml_collections==0.1.1 -------------------------------------------------------------------------------- /swin-transformers-tf/swins/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import SwinTransformer 2 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/blocks/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import mlp_block 2 | from .stage_block import BasicLayer 3 | from .swin_transformer_block import SwinTransformerBlock 4 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/blocks/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | from tensorflow.keras import layers 6 | 7 | 8 | def mlp_block(dropout_rate: float, hidden_units: List[int], name: str = "mlp"): 9 | """FFN for a Transformer block.""" 10 | ffn = keras.Sequential(name=name) 11 | for (idx, units) in enumerate(hidden_units): 12 | ffn.add( 13 | layers.Dense( 14 | units, 15 | activation=tf.nn.gelu if idx == 0 else None, 16 | bias_initializer=keras.initializers.RandomNormal(stddev=1e-6), 17 | ) 18 | ) 19 | ffn.add(layers.Dropout(dropout_rate)) 20 | return ffn 21 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/blocks/stage_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Nikolai Körber. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Code copied and modified from 18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 19 | """ 20 | 21 | from functools import partial 22 | from typing import Dict, Union 23 | import tensorflow as tf 24 | from tensorflow import keras 25 | from tensorflow.keras import layers as L 26 | from .swin_transformer_block import SwinTransformerBlock 27 | 28 | 29 | class BasicLayer(keras.Model): 30 | """A basic Swin Transformer layer for one stage. 31 | 32 | Args: 33 | dim (int): Number of input channels. 34 | depth (int): Number of blocks. 35 | num_heads (int): Number of attention heads. 36 | head_dim (int): Channels per head (dim // num_heads if not set) 37 | window_size (int): Local window size. 38 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 39 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 40 | drop (float, optional): Dropout rate. Default: 0.0 41 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 42 | drop_path (float | list[float], optional): Stochastic depth rate. 
Default: 0.0 43 | norm_layer (layers.Layer, optional): Normalization layer. Default: layers.LayerNormalization 44 | downsample (layers.Layer | None, optional): Downsample layer at the end of the layer. Default: None 45 | upsample (layers.Layer | None, optional): Upsample layer at the end of the layer. Default: None 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dim, 51 | out_dim, 52 | depth, 53 | num_heads=4, 54 | head_dim=None, 55 | window_size=7, 56 | mlp_ratio=4.0, 57 | qkv_bias=True, 58 | drop=0.0, 59 | attn_drop=0.0, 60 | drop_path=0.0, 61 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5), 62 | downsample=None, 63 | upsample=None, 64 | **kwargs, 65 | ): 66 | 67 | super().__init__(**kwargs) 68 | self.dim = dim 69 | self.depth = depth 70 | 71 | # build blocks 72 | blocks = [ 73 | SwinTransformerBlock( 74 | dim=dim, 75 | num_heads=num_heads, 76 | head_dim=head_dim, 77 | window_size=window_size, 78 | shift_size=0 if (i % 2 == 0) else window_size // 2, 79 | mlp_ratio=mlp_ratio, 80 | qkv_bias=qkv_bias, 81 | drop=drop, 82 | attn_drop=attn_drop, 83 | drop_path=drop_path[i] 84 | if isinstance(drop_path, list) 85 | else drop_path, 86 | norm_layer=norm_layer, 87 | name=f"swin_transformer_block_{i}", 88 | ) 89 | for i in range(depth) 90 | ] 91 | self.blocks = blocks 92 | # patch merging layer 93 | if downsample is not None: 94 | self.downsample = downsample( 95 | dim=dim, 96 | out_dim=out_dim, 97 | norm_layer=norm_layer, 98 | ) 99 | else: 100 | self.downsample = None 101 | # patch splitting layer 102 | if upsample is not None: 103 | self.upsample = upsample( 104 | dim=dim, 105 | out_dim=out_dim, 106 | norm_layer=norm_layer, 107 | ) 108 | else: 109 | self.upsample = None 110 | def call( 111 | self, x, return_attns=False 112 | ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]: 113 | if return_attns: 114 | attention_scores = {} 115 | 116 | for i, block in enumerate(self.blocks): 117 | if not return_attns: 118 | x = block(x) 119 | else: 120 | x, attns = block(x, return_attns) 121 | attention_scores.update({f"swin_block_{i}": attns}) 122 | if self.downsample is not None: 123 | x = self.downsample(x) 124 | 125 | if self.upsample is not None: 126 | x = self.upsample(x) 127 | 128 | if return_attns: 129 | return x, attention_scores 130 | else: 131 | return x -------------------------------------------------------------------------------- /swin-transformers-tf/swins/blocks/swin_transformer_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Nikolai Körber. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """ 17 | Code copied and modified from 18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 19 | """ 20 | 21 | import collections.abc 22 | from functools import partial 23 | from typing import Dict, Union 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | from tensorflow import keras 28 | from tensorflow.keras import layers as L 29 | 30 | from ..layers import StochasticDepth, WindowAttention 31 | from . import utils 32 | from .mlp import mlp_block 33 | 34 | 35 | class SwinTransformerBlock(keras.Model): 36 | """Swin Transformer Block. 37 | 38 | Args: 39 | dim (int): Number of input channels. 40 | window_size (int): Window size. 41 | num_heads (int): Number of attention heads. 42 | head_dim (int): Enforce the number of channels per head 43 | shift_size (int): Shift size for SW-MSA. 44 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 45 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 46 | drop (float, optional): Dropout rate. Default: 0.0 47 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 48 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 49 | norm_layer (layers.Layer, optional): Normalization layer. Default: layers.LayerNormalization 50 | """ 51 | 52 | def __init__( 53 | self, 54 | dim, 55 | num_heads=4, 56 | head_dim=None, 57 | window_size=7, 58 | shift_size=0, 59 | mlp_ratio=4.0, 60 | qkv_bias=True, 61 | drop=0.0, 62 | attn_drop=0.0, 63 | drop_path=0.0, 64 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5), 65 | **kwargs, 66 | ): 67 | super().__init__(**kwargs) 68 | self.dim = dim 69 | self.window_size = window_size 70 | self.shift_size = shift_size 71 | self.mlp_ratio = mlp_ratio 72 | assert ( 73 | 0 <= self.shift_size < self.window_size 74 | ), "shift_size must be in [0, window_size)" 75 | self.norm1 = norm_layer() 76 | self.attn = WindowAttention( 77 | dim, 78 | num_heads=num_heads, 79 | head_dim=head_dim, 80 | window_size=window_size 81 | if isinstance(window_size, collections.abc.Iterable) 82 | else (window_size, window_size), 83 | qkv_bias=qkv_bias, 84 | attn_drop=attn_drop, 85 | proj_drop=drop, 86 | name="window_attention", 87 | ) 88 | self.drop_path = ( 89 | StochasticDepth(drop_path) if drop_path > 0.0 else tf.identity 90 | ) 91 | self.norm2 = norm_layer() 92 | self.mlp = mlp_block( 93 | dropout_rate=drop, hidden_units=[int(dim * mlp_ratio), dim] 94 | ) 95 | self.attn_mask = None 96 | 97 | def get_img_mask(self): 98 | # calculate image mask for SW-MSA 99 | # since TensorFlow does not support item assignment, we use a 100 | # "hacky" solution. See 101 | # https://github.com/microsoft/Swin-Transformer/blob/e43ac64ce8abfe133ae582741ccaf6761eea05f7/models/swin_transformer.py#L222 102 | # for more information.
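# The nine concatenated blocks below tile the (1, H, W, 1) mask, with the
# band sizes following the reference PyTorch slices
# (0, -window_size), (-window_size, -shift), (-shift, None):
#
#            cols: W - window_size | window_size - shift | shift
#           +---------------------+---------------------+-------+
#   rows:   |          0          |          1          |   2   |  H - window_size
#           |          3          |          4          |   5   |  window_size - shift
#           |          6          |          7          |   8   |  shift
#           +---------------------+---------------------+-------+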
103 | 104 | H, W = self.input_resolution 105 | window_size = self.window_size 106 | shift = self.shift_size  # window_size // 2 for the shifted (SW-MSA) blocks 107 | mask_0 = tf.zeros((1, H-window_size, W-window_size, 1)) 108 | mask_1 = tf.ones((1, H-window_size, window_size-shift, 1)) 109 | mask_2 = tf.ones((1, H-window_size, shift, 1)) 110 | mask_2 = mask_2+1 111 | mask_3 = tf.ones((1, window_size-shift, W-window_size, 1)) 112 | mask_3 = mask_3+2 113 | mask_4 = tf.ones((1, window_size-shift, window_size-shift, 1)) 114 | mask_4 = mask_4+3 115 | mask_5 = tf.ones((1, window_size-shift, shift, 1)) 116 | mask_5 = mask_5+4 117 | mask_6 = tf.ones((1, shift, W-window_size, 1)) 118 | mask_6 = mask_6+5 119 | mask_7 = tf.ones((1, shift, window_size-shift, 1)) 120 | mask_7 = mask_7+6 121 | mask_8 = tf.ones((1, shift, shift, 1)) 122 | mask_8 = mask_8+7 123 | 124 | mask_012 = tf.concat([mask_0, mask_1, mask_2], axis=2) 125 | mask_345 = tf.concat([mask_3, mask_4, mask_5], axis=2) 126 | mask_678 = tf.concat([mask_6, mask_7, mask_8], axis=2) 127 | 128 | img_mask = tf.concat([mask_012, mask_345, mask_678], axis=1) 129 | return img_mask 130 | 131 | 132 | def get_attn_mask(self): 133 | # calculate attention mask for SW-MSA 134 | mask_windows = utils.window_partition( 135 | self.img_mask, self.window_size 136 | ) # [num_win, window_size, window_size, 1] 137 | mask_windows = tf.reshape( 138 | mask_windows, (-1, self.window_size * self.window_size) 139 | ) 140 | attn_mask = tf.expand_dims(mask_windows, 1) - tf.expand_dims( 141 | mask_windows, 2 142 | ) 143 | attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask) 144 | return tf.where(attn_mask == 0, 0.0, attn_mask) 145 | 146 | def call( 147 | self, x, return_attns=False 148 | ) -> Union[tf.Tensor, Dict[str, tf.Tensor]]: 149 | 150 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3] 151 | self.input_resolution = (H, W) 152 | 153 | if self.shift_size > 0: 154 | self.img_mask = tf.stop_gradient(self.get_img_mask()) 155 | self.attn_mask = self.get_attn_mask() 156 | 157 | x = tf.reshape(x, (-1, H*W, C)) 158 | 159 | shortcut = x 160 | x = self.norm1(x) 161 | x = tf.reshape(x, (-1, H, W, C)) 162 | 163 | # cyclic shift 164 | if self.shift_size > 0: 165 | shifted_x = tf.roll( 166 | x, shift=(-self.shift_size, -self.shift_size), axis=(1, 2) 167 | ) 168 | else: 169 | shifted_x = x 170 | 171 | # partition windows 172 | x_windows = utils.window_partition( 173 | shifted_x, self.window_size 174 | ) # [num_win*B, window_size, window_size, C] 175 | x_windows = tf.reshape( 176 | x_windows, (-1, self.window_size * self.window_size, C) 177 | ) # [num_win*B, window_size*window_size, C] 178 | 179 | # W-MSA/SW-MSA 180 | if not return_attns: 181 | attn_windows = self.attn( 182 | x_windows, mask=self.attn_mask 183 | ) # [num_win*B, window_size*window_size, C] 184 | else: 185 | attn_windows, attn_scores = self.attn( 186 | x_windows, mask=self.attn_mask, return_attns=True 187 | ) # [num_win*B, window_size*window_size, C] 188 | # merge windows 189 | attn_windows = tf.reshape( 190 | attn_windows, (-1, self.window_size, self.window_size, C) 191 | ) 192 | shifted_x = utils.window_reverse( 193 | attn_windows, self.window_size, H, W 194 | ) # [B, H', W', C] 195 | 196 | # reverse cyclic shift 197 | if self.shift_size > 0: 198 | x = tf.roll( 199 | shifted_x, 200 | shift=(self.shift_size, self.shift_size), 201 | axis=(1, 2), 202 | ) 203 | else: 204 | x = shifted_x 205 | 206 | x = tf.reshape(x, (-1, H * W, C)) 207 | 208 | # FFN 209 | x = shortcut + self.drop_path(x) 210 | x = x + self.drop_path(self.mlp(self.norm2(x))) 211 | 212
| x = tf.reshape(x, (-1, H, W, C)) 213 | 214 | if return_attns: 215 | return x, attn_scores 216 | else: 217 | return x -------------------------------------------------------------------------------- /swin-transformers-tf/swins/blocks/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code copied and modified from 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 4 | """ 5 | 6 | 7 | import tensorflow as tf 8 | 9 | 10 | def window_partition(x: tf.Tensor, window_size: int): 11 | """ 12 | Args: 13 | x: (B, H, W, C) 14 | window_size (int): window size 15 | Returns: 16 | windows: (num_windows*B, window_size, window_size, C) 17 | """ 18 | B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3] 19 | 20 | x = tf.reshape( 21 | x, (B, H // window_size, window_size, W // window_size, window_size, C) 22 | ) 23 | windows = tf.transpose(x, [0, 1, 3, 2, 4, 5]) 24 | windows = tf.reshape(windows, (-1, window_size, window_size, C)) 25 | return windows 26 | 27 | 28 | def window_reverse(windows: tf.Tensor, window_size: int, H: int, W: int): 29 | """ 30 | Args: 31 | windows: (num_windows*B, window_size, window_size, C) 32 | window_size (int): Window size 33 | H (int): Height of image 34 | W (int): Width of image 35 | Returns: 36 | x: (B, H, W, C) 37 | """ 38 | B = tf.shape(windows)[0] // tf.cast( 39 | H * W / window_size / window_size, dtype="int32" 40 | ) 41 | 42 | x = tf.reshape( 43 | windows, 44 | ( 45 | B, 46 | H // window_size, 47 | W // window_size, 48 | window_size, 49 | window_size, 50 | -1, 51 | ), 52 | ) 53 | x = tf.transpose(x, [0, 1, 3, 2, 4, 5]) 54 | return tf.reshape(x, (B, H, W, -1)) 55 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .patch_merging import PatchMerging 2 | from .patch_splitting import PatchSplitting, PatchUnpack 3 | from .sd import StochasticDepth 4 | from .window_attn import WindowAttention 5 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/layers/patch_merging.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Nikolai Körber. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Code copied and modified from 18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 19 | """ 20 | 21 | from functools import partial 22 | 23 | import tensorflow as tf 24 | from tensorflow.keras import layers as L 25 | 26 | 27 | class PatchMerging(L.Layer): 28 | """Patch Merging Layer. 29 | 30 | Args: 31 | dim (int): Number of input channels. 
32 | """ 33 | 34 | def __init__( 35 | self, 36 | dim, 37 | out_dim=None, 38 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5), 39 | **kwargs 40 | ): 41 | super().__init__(**kwargs) 42 | self.dim = dim 43 | self.out_dim = out_dim or 2 * dim 44 | self.norm = norm_layer() 45 | self.reduction = L.Dense(self.out_dim, use_bias=False) 46 | 47 | def call(self, x): 48 | """ 49 | x: B, H, W, C 50 | """ 51 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3] 52 | 53 | x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C] 54 | x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C] 55 | x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C] 56 | x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C] 57 | x = tf.concat([x0, x1, x2, x3], axis=-1) # [B, H/2, W/2, 4*C] 58 | x = tf.reshape(x, (-1, H//2*W//2, 4 * C)) # [B, H/2*W/2, 4*C] 59 | 60 | x = self.norm(x) 61 | x = self.reduction(x) 62 | x = tf.reshape(x, (-1, H//2, W//2, self.out_dim)) 63 | return x 64 | 65 | def get_config(self): 66 | config = super().get_config() 67 | config.update( 68 | { 69 | "dim": self.dim, 70 | "out_dim": self.out_dim, 71 | "norm": self.norm, 72 | } 73 | ) 74 | return config -------------------------------------------------------------------------------- /swin-transformers-tf/swins/layers/patch_splitting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Nikolai Körber. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from functools import partial 17 | 18 | import tensorflow as tf 19 | from tensorflow.keras import layers as L 20 | 21 | 22 | class PatchSplitting(L.Layer): 23 | """Patch Splitting Layer as described in 24 | https://openreview.net/pdf?id=IDwN6xjHnK8 (section 3.1) 25 | # Patch Split = [Linear, LayerNorm, Depth-to-Space (for upsampling)] 26 | Args: 27 | dim (int): Number of input channels. 
28 | """ 29 | 30 | def __init__( 31 | self, 32 | dim, 33 | out_dim=None, 34 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5), 35 | **kwargs 36 | ): 37 | super().__init__(**kwargs) 38 | self.dim = dim 39 | self.out_dim = out_dim or dim 40 | self.norm = norm_layer() 41 | self.reduction = L.Dense(self.out_dim * 4, use_bias=False) 42 | 43 | def call(self, x): 44 | """ 45 | x: B, H, W, C 46 | """ 47 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3] 48 | x = tf.reshape(x, (-1, H * W, C)) 49 | 50 | x = self.reduction(x) 51 | x = self.norm(x) 52 | 53 | x = tf.reshape(x, (-1, H, W, self.out_dim * 4)) 54 | x = tf.nn.depth_to_space(x, 2, data_format='NHWC') 55 | x = tf.reshape(x, (-1, 2 * H, 2 * W, self.out_dim)) 56 | 57 | return x 58 | 59 | def get_config(self): 60 | config = super().get_config() 61 | config.update( 62 | { 63 | "dim": self.dim, 64 | "out_dim": self.out_dim, 65 | "norm": self.norm, 66 | } 67 | ) 68 | return config 69 | 70 | 71 | class PatchUnpack(L.Layer): 72 | """Patch Unpack Layer 73 | # PatchUnpack = [Linear, Depth-to-Space (for upsampling)] 74 | 75 | Key differences to PatchSplitting: 76 | - no LayerNorm 77 | - use_bias=True (self.reduction) 78 | 79 | Args: 80 | input_resolution (tuple[int]): Resolution of input feature. 81 | dim (int): Number of input channels. 82 | """ 83 | 84 | def __init__( 85 | self, 86 | dim, 87 | out_dim=None, 88 | norm_layer=None, 89 | **kwargs 90 | ): 91 | super().__init__(**kwargs) 92 | self.dim = dim 93 | self.out_dim = out_dim or dim 94 | self.reduction = L.Dense(self.out_dim * 4, use_bias=True) 95 | 96 | def call(self, x): 97 | """ 98 | x: B, H, W, C 99 | """ 100 | H, W, C = tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3] 101 | x = tf.reshape(x, (-1, H * W, C)) 102 | 103 | x = self.reduction(x) 104 | 105 | x = tf.reshape(x, (-1, H, W, self.out_dim * 4)) 106 | x = tf.nn.depth_to_space(x, 2, data_format='NHWC') 107 | x = tf.reshape(x, (-1, 2 * H, 2 * W, self.out_dim)) 108 | 109 | return x 110 | 111 | def get_config(self): 112 | config = super().get_config() 113 | config.update( 114 | { 115 | "dim": self.dim, 116 | "out_dim": self.out_dim, 117 | } 118 | ) 119 | return config 120 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/layers/sd.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import layers 3 | 4 | 5 | # Referred from: github.com:rwightman/pytorch-image-models. 6 | class StochasticDepth(layers.Layer): 7 | def __init__(self, drop_prop, **kwargs): 8 | super().__init__(**kwargs) 9 | self.drop_prob = float(drop_prop) 10 | 11 | def call(self, x, training=False): 12 | if training: 13 | keep_prob = 1 - self.drop_prob 14 | shape = (tf.shape(x)[0],) + (1,) * (tf.shape(tf.shape(x)) - 1) 15 | random_tensor = keep_prob + tf.random.uniform(shape, 0, 1) 16 | random_tensor = tf.floor(random_tensor) 17 | return (x / keep_prob) * random_tensor 18 | return x 19 | 20 | def get_config(self): 21 | config = super().get_config() 22 | config.update({"drop_prob": self.drop_prob}) 23 | return config 24 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/layers/window_attn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Nikolai Körber. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """ 17 | Code copied and modified from 18 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 19 | """ 20 | 21 | 22 | import collections.abc 23 | from typing import Tuple, Union 24 | 25 | import tensorflow as tf 26 | from tensorflow.keras import layers 27 | 28 | 29 | def get_relative_position_index(win_h, win_w): 30 | # get pair-wise relative position index for each token inside the window 31 | xx, yy = tf.meshgrid(range(win_h), range(win_w)) 32 | coords = tf.stack([yy, xx], axis=0) # [2, Wh, Ww] 33 | coords_flatten = tf.reshape(coords, [2, -1]) # [2, Wh*Ww] 34 | 35 | relative_coords = ( 36 | coords_flatten[:, :, None] - coords_flatten[:, None, :] 37 | ) # [2, Wh*Ww, Wh*Ww] 38 | relative_coords = tf.transpose( 39 | relative_coords, perm=[1, 2, 0] 40 | ) # [Wh*Ww, Wh*Ww, 2] 41 | 42 | xx = (relative_coords[:, :, 0] + win_h - 1) * (2 * win_w - 1) 43 | yy = relative_coords[:, :, 1] + win_w - 1 44 | relative_coords = tf.stack([xx, yy], axis=-1) 45 | 46 | return tf.reduce_sum(relative_coords, axis=-1) # [Wh*Ww, Wh*Ww] 47 | 48 | 49 | class WindowAttention(layers.Layer): 50 | """Window based multi-head self attention (W-MSA) module with relative position bias. 51 | It supports both shifted and non-shifted windows. 52 | 53 | Args: 54 | dim (int): Number of input channels. 55 | num_heads (int): Number of attention heads. 56 | head_dim (int): Number of channels per head (dim // num_heads if not set) 57 | window_size (tuple[int]): The height and width of the window. 58 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 59 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 60 | proj_drop (float, optional): Dropout ratio of output.
Default: 0.0 61 | """ 62 | 63 | def __init__( 64 | self, 65 | dim, 66 | num_heads, 67 | head_dim=None, 68 | window_size=7, 69 | qkv_bias=True, 70 | attn_drop=0.0, 71 | proj_drop=0.0, 72 | **kwargs, 73 | ): 74 | 75 | super().__init__(**kwargs) 76 | 77 | self.dim = dim 78 | self.window_size = ( 79 | window_size 80 | if isinstance(window_size, collections.abc.Iterable) 81 | else (window_size, window_size) 82 | ) # Wh, Ww 83 | self.win_h, self.win_w = self.window_size 84 | self.window_area = self.win_h * self.win_w 85 | self.num_heads = num_heads 86 | self.head_dim = head_dim or (dim // num_heads) 87 | self.attn_dim = self.head_dim * num_heads 88 | self.scale = self.head_dim ** -0.5 89 | 90 | # get pair-wise relative position index for each token inside the window 91 | self.relative_position_index = get_relative_position_index( 92 | self.win_h, self.win_w 93 | ) 94 | 95 | self.qkv = layers.Dense( 96 | self.attn_dim * 3, use_bias=qkv_bias, name="attention_qkv" 97 | ) 98 | self.attn_drop = layers.Dropout(attn_drop) 99 | self.proj = layers.Dense(dim, name="attention_projection") 100 | self.proj_drop = layers.Dropout(proj_drop) 101 | 102 | def build(self, input_shape): 103 | self.relative_position_bias_table = self.add_weight( 104 | shape=((2 * self.win_h - 1) * (2 * self.win_w - 1), self.num_heads), 105 | initializer="zeros", 106 | trainable=True, 107 | name="relative_position_bias_table", 108 | ) 109 | super().build(input_shape) 110 | 111 | def _get_rel_pos_bias(self) -> tf.Tensor: 112 | relative_position_bias = tf.gather( 113 | self.relative_position_bias_table, 114 | self.relative_position_index, 115 | axis=0, 116 | ) 117 | return tf.transpose(relative_position_bias, [2, 0, 1]) 118 | 119 | def call( 120 | self, x, mask=None, return_attns=False 121 | ) -> Union[tf.Tensor, Tuple[tf.Tensor, tf.Tensor]]: 122 | """ 123 | Args: 124 | x: input features with shape of (num_windows*B, N, C) 125 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 126 | """ 127 | B_, N, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2] 128 | qkv = self.qkv(x) 129 | qkv = tf.reshape(qkv, (B_, N, 3, self.num_heads, -1)) 130 | qkv = tf.transpose(qkv, (2, 0, 3, 1, 4)) 131 | 132 | q, k, v = tf.unstack(qkv, 3) 133 | 134 | scale = tf.cast(self.scale, dtype=qkv.dtype) 135 | q = q * scale 136 | attn = tf.matmul(q, tf.transpose(k, perm=[0, 1, 3, 2])) 137 | attn = attn + self._get_rel_pos_bias() 138 | 139 | if mask is not None: 140 | num_win = tf.shape(mask)[0] 141 | attn = tf.reshape( 142 | attn, (B_ // num_win, num_win, self.num_heads, N, N) 143 | ) 144 | attn = attn + tf.expand_dims(mask, 1)[None, ...] 
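# Broadcasting note on the masked branch above: `mask` is (num_win, N, N);
# expand_dims + [None, ...] lift it to (1, num_win, 1, N, N) so the same
# additive 0/-100 mask is applied across the batch dimension and every head.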
145 | 146 | attn = tf.reshape(attn, (-1, self.num_heads, N, N)) 147 | attn = tf.nn.softmax(attn, -1) 148 | else: 149 | attn = tf.nn.softmax(attn, -1) 150 | 151 | attn = self.attn_drop(attn) 152 | 153 | x = tf.matmul(attn, v) 154 | x = tf.transpose(x, perm=[0, 2, 1, 3]) 155 | x = tf.reshape(x, (B_, N, C)) 156 | 157 | x = self.proj(x) 158 | x = self.proj_drop(x) 159 | 160 | if return_attns: 161 | return x, attn 162 | else: 163 | return x 164 | 165 | def get_config(self): 166 | config = super().get_config() 167 | config.update( 168 | { 169 | "dim": self.dim, 170 | "window_size": self.window_size, 171 | "win_h": self.win_h, 172 | "win_w": self.win_w, 173 | "num_heads": self.num_heads, 174 | "head_dim": self.head_dim, 175 | "attn_dim": self.attn_dim, 176 | "scale": self.scale, 177 | } 178 | ) 179 | return config -------------------------------------------------------------------------------- /swin-transformers-tf/swins/model_configs.py: -------------------------------------------------------------------------------- 1 | # Take from here: 2 | # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 3 | 4 | 5 | def swin_base_patch4_window12_384(): 6 | """Swin-B @ 384x384, pretrained ImageNet-22k, fine tune 1k""" 7 | cfg = dict( 8 | img_size=384, 9 | patch_size=4, 10 | window_size=12, 11 | embed_dim=128, 12 | depths=(2, 2, 18, 2), 13 | num_heads=(4, 8, 16, 32), 14 | name="swin_base_patch4_window12_384", 15 | ) 16 | return cfg 17 | 18 | 19 | def swin_base_patch4_window7_224(): 20 | """Swin-B @ 224x224, pretrained ImageNet-22k, fine tune 1k""" 21 | cfg = dict( 22 | patch_size=4, 23 | window_size=7, 24 | embed_dim=128, 25 | depths=(2, 2, 18, 2), 26 | num_heads=(4, 8, 16, 32), 27 | name="swin_base_patch4_window7_224", 28 | ) 29 | return cfg 30 | 31 | 32 | def swin_large_patch4_window12_384(): 33 | """Swin-L @ 384x384, pretrained ImageNet-22k, fine tune 1k""" 34 | cfg = dict( 35 | img_size=384, 36 | patch_size=4, 37 | window_size=12, 38 | embed_dim=192, 39 | depths=(2, 2, 18, 2), 40 | num_heads=(6, 12, 24, 48), 41 | name="swin_large_patch4_window12_384", 42 | ) 43 | return cfg 44 | 45 | 46 | def swin_large_patch4_window7_224(): 47 | """Swin-L @ 224x224, pretrained ImageNet-22k, fine tune 1k""" 48 | cfg = dict( 49 | patch_size=4, 50 | window_size=7, 51 | embed_dim=192, 52 | depths=(2, 2, 18, 2), 53 | num_heads=(6, 12, 24, 48), 54 | name="swin_large_patch4_window7_224", 55 | ) 56 | return cfg 57 | 58 | 59 | def swin_small_patch4_window7_224(): 60 | """Swin-S @ 224x224, trained ImageNet-1k""" 61 | cfg = dict( 62 | patch_size=4, 63 | window_size=7, 64 | embed_dim=96, 65 | depths=(2, 2, 18, 2), 66 | num_heads=(3, 6, 12, 24), 67 | name="swin_small_patch4_window7_224", 68 | ) 69 | return cfg 70 | 71 | 72 | def swin_tiny_patch4_window7_224(): 73 | """Swin-T @ 224x224, trained ImageNet-1k""" 74 | cfg = dict( 75 | patch_size=4, 76 | window_size=7, 77 | embed_dim=96, 78 | depths=(2, 2, 6, 2), 79 | num_heads=(3, 6, 12, 24), 80 | name="swin_tiny_patch4_window7_224", 81 | ) 82 | return cfg 83 | 84 | 85 | def swin_base_patch4_window12_384_in22k(): 86 | """Swin-B @ 384x384, trained ImageNet-22k""" 87 | cfg = dict( 88 | img_size=384, 89 | patch_size=4, 90 | window_size=12, 91 | embed_dim=128, 92 | depths=(2, 2, 18, 2), 93 | num_heads=(4, 8, 16, 32), 94 | name="swin_base_patch4_window12_384_in22k", 95 | num_classes=21841, 96 | ) 97 | return cfg 98 | 99 | 100 | def swin_base_patch4_window7_224_in22k(): 101 | """Swin-B @ 224x224, trained ImageNet-22k""" 102 | cfg = dict( 103 | 
patch_size=4, 104 | window_size=7, 105 | embed_dim=128, 106 | depths=(2, 2, 18, 2), 107 | num_heads=(4, 8, 16, 32), 108 | name="swin_base_patch4_window7_224_in22k", 109 | num_classes=21841, 110 | ) 111 | return cfg 112 | 113 | 114 | def swin_large_patch4_window12_384_in22k(): 115 | """Swin-L @ 384x384, trained ImageNet-22k""" 116 | cfg = dict( 117 | img_size=384, 118 | patch_size=4, 119 | window_size=12, 120 | embed_dim=192, 121 | depths=(2, 2, 18, 2), 122 | num_heads=(6, 12, 24, 48), 123 | name="swin_large_patch4_window12_384_in22k", 124 | num_classes=21841, 125 | ) 126 | return cfg 127 | 128 | 129 | def swin_large_patch4_window7_224_in22k(): 130 | """Swin-L @ 224x224, trained ImageNet-22k""" 131 | cfg = dict( 132 | patch_size=4, 133 | window_size=7, 134 | embed_dim=192, 135 | depths=(2, 2, 18, 2), 136 | num_heads=(6, 12, 24, 48), 137 | name="swin_large_patch4_window7_224_in22k", 138 | num_classes=21841, 139 | ) 140 | return cfg 141 | 142 | 143 | def swin_s3_tiny_224(): 144 | """Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725""" 145 | cfg = dict( 146 | patch_size=4, 147 | window_size=(7, 7, 14, 7), 148 | embed_dim=96, 149 | depths=(2, 2, 6, 2), 150 | num_heads=(3, 6, 12, 24), 151 | name="swin_s3_tiny_224", 152 | ) 153 | return cfg 154 | 155 | 156 | def swin_s3_small_224(): 157 | """Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725""" 158 | cfg = dict( 159 | patch_size=4, 160 | window_size=(14, 14, 14, 7), 161 | embed_dim=96, 162 | depths=(2, 2, 18, 2), 163 | num_heads=(3, 6, 12, 24), 164 | name="swin_s3_small_224", 165 | ) 166 | return cfg 167 | 168 | 169 | def swin_s3_base_224(): 170 | """Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725""" 171 | cfg = dict( 172 | patch_size=4, 173 | window_size=(7, 7, 14, 7), 174 | embed_dim=96, 175 | depths=(2, 2, 30, 2), 176 | num_heads=(3, 6, 12, 24), 177 | name="swin_s3_base_224", 178 | ) 179 | return cfg 180 | 181 | 182 | MODEL_MAP = { 183 | "swin_base_patch4_window12_384": swin_base_patch4_window12_384, 184 | "swin_base_patch4_window7_224": swin_base_patch4_window7_224, 185 | "swin_large_patch4_window12_384": swin_large_patch4_window12_384, 186 | "swin_large_patch4_window7_224": swin_large_patch4_window7_224, 187 | "swin_small_patch4_window7_224": swin_small_patch4_window7_224, 188 | "swin_tiny_patch4_window7_224": swin_tiny_patch4_window7_224, 189 | "swin_base_patch4_window12_384_in22k": swin_base_patch4_window12_384_in22k, 190 | "swin_base_patch4_window7_224_in22k": swin_base_patch4_window7_224_in22k, 191 | "swin_large_patch4_window12_384_in22k": swin_large_patch4_window12_384_in22k, 192 | "swin_large_patch4_window7_224_in22k": swin_large_patch4_window7_224_in22k, 193 | "swin_s3_tiny_224": swin_s3_tiny_224, 194 | "swin_s3_small_224": swin_s3_small_224, 195 | "swin_s3_base_224": swin_s3_base_224, 196 | } 197 | -------------------------------------------------------------------------------- /swin-transformers-tf/swins/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code copied and modified from 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/swin_transformer.py 4 | """ 5 | 6 | import collections.abc 7 | from functools import partial 8 | from itertools import repeat 9 | from typing import Dict 10 | 11 | import tensorflow as tf 12 | from tensorflow import keras 13 | from tensorflow.keras import layers as L 14 | 15 | from .blocks import BasicLayer 16 | from .layers import PatchMerging 17 | 18 | 19 
| # https://github.com/rwightman/pytorch-image-models/blob/6d4665bb52390974e0cf9674c60c41946d2f4ee2/timm/models/layers/helpers.py#L10 20 | def to_ntuple(n): 21 | def parse(x): 22 | if isinstance(x, collections.abc.Iterable): 23 | return x 24 | return tuple(repeat(x, n)) 25 | 26 | return parse 27 | 28 | 29 | class SwinTransformer(keras.Model): 30 | """Swin Transformer 31 | A TensorFlow impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 32 | https://arxiv.org/pdf/2103.14030 33 | 34 | Args: 35 | img_size (int | tuple(int)): Input image size. Default 224 36 | patch_size (int | tuple(int)): Patch size. Default: 4 37 | num_classes (int): Number of classes for classification head. Default: 1000 38 | embed_dim (int): Patch embedding dimension. Default: 96 39 | depths (tuple(int)): Depth of each Swin Transformer layer. 40 | num_heads (tuple(int)): Number of attention heads in different layers. 41 | head_dim (int, tuple(int)): Channels per attention head (dim // num_heads if not set). 42 | window_size (int): Window size. Default: 7 43 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 44 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 45 | drop_rate (float): Dropout rate. Default: 0 46 | attn_drop_rate (float): Attention dropout rate. Default: 0 47 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 48 | norm_layer (layers.Layer): Normalization layer. Default: layers.LayerNormalization. 49 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 50 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 51 | pre_logits (bool): If True, return model without classification head. Default: False 52 | """ 53 | 54 | def __init__( 55 | self, 56 | img_size=224, 57 | patch_size=4, 58 | num_classes=1000, 59 | global_pool="avg", 60 | embed_dim=96, 61 | depths=(2, 2, 6, 2), 62 | num_heads=(3, 6, 12, 24), 63 | head_dim=None, 64 | window_size=7, 65 | mlp_ratio=4.0, 66 | qkv_bias=True, 67 | drop_rate=0.0, 68 | attn_drop_rate=0.0, 69 | drop_path_rate=0.1, 70 | norm_layer=partial(L.LayerNormalization, epsilon=1e-5), 71 | ape=False, 72 | patch_norm=True, 73 | pre_logits=False, 74 | **kwargs, 75 | ): 76 | super().__init__(**kwargs) 77 | 78 | self.img_size = ( 79 | img_size 80 | if isinstance(img_size, collections.abc.Iterable) 81 | else (img_size, img_size) 82 | ) 83 | self.patch_size = ( 84 | patch_size 85 | if isinstance(patch_size, collections.abc.Iterable) 86 | else (patch_size, patch_size) 87 | ) 88 | 89 | self.num_classes = num_classes 90 | self.global_pool = global_pool 91 | self.num_layers = len(depths) 92 | self.embed_dim = embed_dim 93 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 94 | self.ape = ape 95 | 96 | # split image into non-overlapping patches 97 | self.projection = keras.Sequential( 98 | [ 99 | L.Conv2D( 100 | filters=embed_dim, 101 | kernel_size=(patch_size, patch_size), 102 | strides=(patch_size, patch_size), 103 | padding="VALID", 104 | name="conv_projection", 105 | kernel_initializer="lecun_normal", 106 | ), 107 | L.Reshape( 108 | target_shape=(-1, embed_dim), 109 | name="flatten_projection", 110 | ), 111 | ], 112 | name="projection", 113 | ) 114 | if patch_norm: 115 | self.projection.add(norm_layer()) 116 | 117 | self.patch_grid = ( 118 | self.img_size[0] // self.patch_size[0], 119 | self.img_size[1] // self.patch_size[1], 120 | ) 121 | self.num_patches = self.patch_grid[0] * self.patch_grid[1] 122 | 123 | # absolute position embedding 124 | if self.ape: 125 |
self.absolute_pos_embed = tf.Variable( 126 | tf.zeros((1, self.num_patches, self.embed_dim)), 127 | trainable=True, 128 | name="absolute_pos_embed", 129 | ) 130 | else: 131 | self.absolute_pos_embed = None 132 | self.pos_drop = L.Dropout(drop_rate) 133 | 134 | # build layers 135 | if not isinstance(self.embed_dim, (tuple, list)): 136 | self.embed_dim = [ 137 | int(self.embed_dim * 2 ** i) for i in range(self.num_layers) 138 | ] 139 | embed_out_dim = self.embed_dim[1:] + [None] 140 | head_dim = to_ntuple(self.num_layers)(head_dim) 141 | window_size = to_ntuple(self.num_layers)(window_size) 142 | mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) 143 | dpr = [ 144 | float(x) for x in tf.linspace(0.0, drop_path_rate, sum(depths)) 145 | ] # stochastic depth decay rule 146 | 147 | layers = [ 148 | BasicLayer( 149 | dim=self.embed_dim[i], 150 | out_dim=embed_out_dim[i], 151 | input_resolution=( 152 | self.patch_grid[0] // (2 ** i), 153 | self.patch_grid[1] // (2 ** i), 154 | ), 155 | depth=depths[i], 156 | num_heads=num_heads[i], 157 | head_dim=head_dim[i], 158 | window_size=window_size[i], 159 | mlp_ratio=mlp_ratio[i], 160 | qkv_bias=qkv_bias, 161 | drop=drop_rate, 162 | attn_drop=attn_drop_rate, 163 | drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])], 164 | norm_layer=norm_layer, 165 | downsample=PatchMerging if (i < self.num_layers - 1) else None, 166 | name=f"basic_layer_{i}", 167 | ) 168 | for i in range(self.num_layers) 169 | ] 170 | self.swin_layers = layers 171 | 172 | self.norm = norm_layer() 173 | 174 | self.pre_logits = pre_logits 175 | if not self.pre_logits: 176 | self.head = L.Dense(num_classes, name="classification_head") 177 | 178 | def forward_features(self, x): 179 | x = self.projection(x) 180 | if self.absolute_pos_embed is not None: 181 | x = x + self.absolute_pos_embed 182 | x = self.pos_drop(x) 183 | x = tf.reshape(x, (-1, self.patch_grid[0], self.patch_grid[1], self.embed_dim[0]))  # tokens -> NHWC, as expected by the blocks 184 | for swin_layer in self.swin_layers: 185 | x = swin_layer(x) 186 | x = tf.reshape(x, (-1, tf.shape(x)[1] * tf.shape(x)[2], self.num_features))  # NHWC -> tokens 187 | x = self.norm(x) # [B, L, C] 188 | return x 189 | 190 | def forward_head(self, x): 191 | if self.global_pool == "avg": 192 | x = tf.reduce_mean(x, axis=1) 193 | return x if self.pre_logits else self.head(x) 194 | 195 | def call(self, x): 196 | x = self.forward_features(x) 197 | x = self.forward_head(x) 198 | return x 199 | 200 | # Thanks to Willi Gierke for this suggestion.
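# The fully dynamic (None, None, None, 3) input signature below keeps the
# traced graph shape-agnostic, so the attention maps can be served from an
# exported SavedModel for any input resolution the window/patch grid divides.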
201 | @tf.function( 202 | input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)] 203 | ) 204 | def get_attention_scores( 205 | self, x: tf.Tensor 206 | ) -> Dict[str, Dict[str, tf.Tensor]]: 207 | all_attention_scores = {} 208 | grid_hw = tf.shape(x)[1:3] // self.patch_size  # patch-grid size of this input 209 | x = self.projection(x) 210 | if self.absolute_pos_embed is not None: 211 | x = x + self.absolute_pos_embed 212 | x = self.pos_drop(x) 213 | x = tf.reshape(x, (-1, grid_hw[0], grid_hw[1], self.embed_dim[0]))  # tokens -> NHWC for the blocks 214 | for i, swin_layer in enumerate(self.swin_layers): 215 | x, attention_scores = swin_layer(x, return_attns=True) 216 | all_attention_scores.update({f"swin_stage_{i}": attention_scores}) 217 | 218 | return all_attention_scores 219 | -------------------------------------------------------------------------------- /swin-transformers-tf/test.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from swins import SwinTransformer 4 | 5 | cfg = dict( 6 | patch_size=4, 7 | window_size=7, 8 | embed_dim=128, 9 | depths=(2, 2, 18, 2), 10 | num_heads=(4, 8, 16, 32), 11 | ) 12 | 13 | swin_base_patch4_window7_224 = SwinTransformer( 14 | name="swin_base_patch4_window7_224", **cfg 15 | ) 16 | print("Model instantiated, attempting predictions...") 17 | random_tensor = tf.random.normal((2, 224, 224, 3)) 18 | outputs = swin_base_patch4_window7_224(random_tensor, training=False) 19 | 20 | print(outputs.shape) 21 | 22 | print(swin_base_patch4_window7_224.count_params() / 1e6) 23 | -------------------------------------------------------------------------------- /swin-transformers-tf/utils/helpers.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def conv_transpose(w: np.ndarray) -> np.ndarray: 8 | """Transpose the weights of a PT conv layer so that it's compatible with TF.""" 9 | return w.transpose(2, 3, 1, 0) 10 | 11 | 12 | def modify_tf_block( 13 | tf_component: Union[tf.keras.layers.Layer, tf.Variable, tf.Tensor], 14 | pt_weight: np.ndarray, 15 | pt_bias: np.ndarray = None, 16 | is_attn: bool = False, 17 | ) -> Union[tf.keras.layers.Layer, tf.Variable, tf.Tensor]: 18 | """General utility for modifying PT parameters for TF compatibility. 19 | Applicable for Conv2D, Dense, tf.Variable, and LayerNormalization. 20 | """ 21 | pt_weight = ( 22 | conv_transpose(pt_weight) 23 | if isinstance(tf_component, tf.keras.layers.Conv2D) 24 | else pt_weight 25 | ) 26 | pt_weight = ( 27 | pt_weight.transpose() 28 | if isinstance(tf_component, tf.keras.layers.Dense) and not is_attn 29 | else pt_weight 30 | ) 31 | 32 | if isinstance( 33 | tf_component, (tf.keras.layers.Dense, tf.keras.layers.Conv2D) 34 | ): 35 | tf_component.kernel.assign(tf.Variable(pt_weight)) 36 | if pt_bias is not None: 37 | tf_component.bias.assign(tf.Variable(pt_bias)) 38 | elif isinstance(tf_component, tf.keras.layers.LayerNormalization): 39 | tf_component.gamma.assign(tf.Variable(pt_weight)) 40 | tf_component.beta.assign(tf.Variable(pt_bias)) 41 | elif isinstance(tf_component, (tf.Variable)): 42 | # For regular variables (tf.Variable). 43 | tf_component.assign(tf.Variable(pt_weight)) 44 | else: 45 | return tf.convert_to_tensor(pt_weight) 46 | 47 | return tf_component 48 | --------------------------------------------------------------------------------
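Usage sketch: a minimal example of how the pieces above fit together, mirroring test.py. It assumes only that the `swins` package from this repository is importable; everything it calls is defined in the files above.

import tensorflow as tf

from swins import SwinTransformer
from swins.model_configs import MODEL_MAP

# Every MODEL_MAP entry returns a plain dict of constructor kwargs.
cfg = MODEL_MAP["swin_tiny_patch4_window7_224"]()
name = cfg.pop("name")
model = SwinTransformer(name=name, **cfg)

images = tf.random.normal((1, 224, 224, 3))  # NHWC float32
logits = model(images, training=False)       # (1, 1000) with the default head

# Per-stage, per-block attention maps, keyed as
# {"swin_stage_i": {"swin_block_j": tensor}}.
attn_scores = model.get_attention_scores(images)
print(logits.shape, list(attn_scores.keys()))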