├── .gitignore ├── LICENSE ├── README.md ├── docker ├── pt_git.Dockerfile ├── pt_pip.Dockerfile ├── tf_git.Dockerfile ├── tf_git_ampere.Dockerfile └── tf_pip.Dockerfile ├── export_pytorch.py ├── jeffnet ├── __init__.py ├── common │ ├── __init__.py │ ├── arch_defs.py │ ├── block_utils.py │ ├── builder.py │ ├── constants.py │ ├── io.py │ ├── loss.py │ ├── lr_schedule.py │ ├── metrics.py │ ├── model_cfgs.py │ ├── model_zoo.py │ ├── optim │ │ ├── __init__.py │ │ ├── helpers.py │ │ ├── lars.py │ │ └── optim_factory.py │ └── padding.py ├── data │ ├── tf_autoaugment.py │ ├── tf_image_ops.py │ ├── tf_imagenet_data.py │ ├── tf_input_pipeline.py │ └── tf_simclr_aug.py ├── linen │ ├── __init__.py │ ├── blocks_linen.py │ ├── efficientnet_linen.py │ ├── ema_state.py │ ├── helpers.py │ └── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── linear.py │ │ ├── mixed_conv.py │ │ ├── normalization.py │ │ └── stochastic.py ├── objax │ ├── __init__.py │ ├── blocks_objax.py │ ├── efficientnet_objax.py │ ├── helpers.py │ └── layers │ │ ├── __init__.py │ │ ├── activations.py │ │ ├── drop_path.py │ │ ├── linear.py │ │ ├── mixed_conv.py │ │ └── normalization.py └── utils │ └── to_tuple.py ├── pt_linen_validate.py ├── pt_objax_validate.py ├── requirements.txt ├── tf_linen_train.py ├── tf_linen_validate.py ├── tf_objax_validate.py └── train_configs ├── default.py ├── pt_efficientnet_b3-tpu_x8.py └── tf_efficientnet_b0-gpu_24gb_x2.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # 
Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Ross Wightman 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientNet JAX - Flax Linen and Objax 2 | 3 | ## Acknowledgements 4 | 5 | Verification of training code was made possible with Cloud TPUs via Google's TPU Research Cloud (TRC) (https://www.tensorflow.org/tfrc) 6 | 7 | ## Intro 8 | This is very much a giant steaming work in progress. Jax, jaxlib, and the NN libraries I'm using are shifting week to week. 
9 | 10 | This code base currently supports: 11 | * Flax Linen (https://github.com/google/flax/tree/master/flax/linen) -- for models, validation w/ pretrained weights, and training from scratch 12 | * Objax (https://github.com/google/objax) -- for model and model validation with pretrained weights 13 | 14 | This is essentially an adaptation of my PyTorch EfficienNet generator code (https://github.com/rwightman/gen-efficientnet-pytorch and also found in https://github.com/rwightman/pytorch-image-models) to JAX. 15 | 16 | I started this to 17 | * learn JAX by working with familiar code / models as a starting point, 18 | * figure out which JAX modelling interface libraries ('frameworks') I liked, 19 | * compare the training / inference runtime traits of non-trivial models across combinations of PyTorch, JAX, GPU and TPU in order to drive cost optimizations for scaling up of future projects 20 | 21 | Where are we at: 22 | * Training works on single node, multi-GPU and TPU v3-8 for Flax Linen variants w/ Tensorflow Datasets based pipeline 23 | * The Objax and Flax Linen (nn.compact) variants of models are working (for inference) 24 | * Weights are ported from PyTorch (my timm training) and Tensorflow (original paper author releases) and are organized in zoo of sorts (borrowed PyTorch code) 25 | * Tensorflow and PyTorch data pipeline based validation scripts work with models and weights. For PT pipeline with PT models and TF pipeline with TF models the results are pretty much exact. 26 | 27 | TODO: 28 | - [x] Fix model weight inits (working for Flax Linen variants) 29 | - [x] Fix dropout/drop path impl and other training specifics (verified for Flax Linen variants) 30 | - [ ] Add more instructions / help in the README on how to get an optimal environment with JAX up and running (with GPU support) 31 | - [x] Add basic training code. The main point of this is to scale up training. 
32 | - [ ] Add more advance data augmentation pipeline 33 | - [ ] Training on lots of GPUs 34 | - [ ] Training on lots of TPUs 35 | 36 | Some odd things: 37 | * Objax layers are reimplemented to make my initial work easier, scratch some itches, make more consistent with PyTorch (because why not?) 38 | * Flax Linen layers are by default fairly consistent with Tensorflow (left as is) 39 | * I use wrappers around Flax Linen layers for some argument consistency and reduced visual noise (no redundant tuples) 40 | * I made a 'LIKE' padding mode, sort of like 'SAME' but different, hence the name. It calculates symmetric padding for PyTorch models. 41 | * Models with Tensorflow 'SAME' padding and TF origin weights are prefixed with `tf_`. Models with PyTorch trained weights and symmetric PyTorch style padding ('LIKE' here) are prefixed with `pt_` 42 | * I use `pt` and `tf` to refer to PyTorch and Tensorflow for both the models and environments. These two do not need to be used together. `pt` models with 'LIKE' padding will work fine running in a Tensorflow based environment and vice versa. I did this to show the full flexibility here, that one can use JAX models with PyTorch data pipelines and datasets or with Tensorflow based data pipelines and TFDS. 
43 | 44 | ## Models 45 | 46 | Supported models and their paper's 47 | * EfficientNet NoisyStudent (B0-B7, L2) - https://arxiv.org/abs/1911.04252 48 | * EfficientNet AdvProp (B0-B8) - https://arxiv.org/abs/1911.09665 49 | * EfficientNet (B0-B7) - https://arxiv.org/abs/1905.11946 50 | * EfficientNet-EdgeTPU (S, M, L) - https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html 51 | * MixNet - https://arxiv.org/abs/1907.09595 52 | * MobileNet-V3 - https://arxiv.org/abs/1905.02244 53 | * MobileNet-V2 - https://arxiv.org/abs/1801.04381 54 | * MNASNet B1, A1 (Squeeze-Excite), and Small - https://arxiv.org/abs/1807.11626 55 | * Single-Path NAS - https://arxiv.org/abs/1904.02877 56 | * FBNet-C - https://arxiv.org/abs/1812.03443 57 | 58 | Models by their config name w/ valid pretrained weights that should be working here: 59 | ``` 60 | pt_mnasnet_100 61 | pt_semnasnet_100 62 | pt_mobilenetv2_100 63 | pt_mobilenetv2_110d 64 | pt_mobilenetv2_120d 65 | pt_mobilenetv2_140 66 | pt_fbnetc_100 67 | pt_spnasnet_100 68 | pt_efficientnet_b0 69 | pt_efficientnet_b1 70 | pt_efficientnet_b2 71 | pt_efficientnet_b3 72 | tf_efficientnet_b0 73 | tf_efficientnet_b1 74 | tf_efficientnet_b2 75 | tf_efficientnet_b3 76 | tf_efficientnet_b4 77 | tf_efficientnet_b5 78 | tf_efficientnet_b6 79 | tf_efficientnet_b7 80 | tf_efficientnet_b8 81 | tf_efficientnet_b0_ap 82 | tf_efficientnet_b1_ap 83 | tf_efficientnet_b2_ap 84 | tf_efficientnet_b3_ap 85 | tf_efficientnet_b4_ap 86 | tf_efficientnet_b5_ap 87 | tf_efficientnet_b6_ap 88 | tf_efficientnet_b7_ap 89 | tf_efficientnet_b8_ap 90 | tf_efficientnet_b0_ns 91 | tf_efficientnet_b1_ns 92 | tf_efficientnet_b2_ns 93 | tf_efficientnet_b3_ns 94 | tf_efficientnet_b4_ns 95 | tf_efficientnet_b5_ns 96 | tf_efficientnet_b6_ns 97 | tf_efficientnet_b7_ns 98 | tf_efficientnet_l2_ns_475 99 | tf_efficientnet_l2_ns 100 | pt_efficientnet_es 101 | pt_efficientnet_em 102 | tf_efficientnet_es 103 | tf_efficientnet_em 104 | tf_efficientnet_el 105 | 
pt_efficientnet_lite0 106 | tf_efficientnet_lite0 107 | tf_efficientnet_lite1 108 | tf_efficientnet_lite2 109 | tf_efficientnet_lite3 110 | tf_efficientnet_lite4 111 | pt_mixnet_s 112 | pt_mixnet_m 113 | pt_mixnet_l 114 | pt_mixnet_xl 115 | tf_mixnet_s 116 | tf_mixnet_m 117 | tf_mixnet_l 118 | pt_mobilenetv3_large_100 119 | tf_mobilenetv3_large_075 120 | tf_mobilenetv3_large_100 121 | tf_mobilenetv3_large_minimal_100 122 | tf_mobilenetv3_small_075 123 | tf_mobilenetv3_small_100 124 | tf_mobilenetv3_small_minimal_100 125 | ``` 126 | 127 | ## Environment 128 | 129 | Working with JAX I've found the best approach for having a working GPU compatible environment that performs well is to use Docker containers based on the latest NVIDIA NGC releases. I've found it challenging or flaky getting local conda/pip venvs or Tensorflow docker containers working well with good GPU performance, proper NCCL distributed support, etc. I use CPU JAX install in conda env for dev/debugging. 130 | 131 | ### Dockerfiles 132 | 133 | There are several container definitions in `docker/`. They use NGC containers as their parent image so you'll need to be setup to pull NGC containers: https://www.nvidia.com/en-us/gpu-cloud/containers/ . I'm currently using recent NGC containers w/ CUDA 11.1 support, the host system will need a very recent NVIDIA driver to support this but doesn't need a matching CUDA 11.1 / cuDNN 8 install. 134 | 135 | Current dockerfiles: 136 | * `pt_git.Dockerfile` - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax. 137 | * `pt_pip.Dockerfile` - PyTorch 20.12 NGC as parent, CUDA 11.1, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax. 138 | * `tf_git.Dockerfile` - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. git (source install) of jaxlib, jax, objax, and flax. 139 | * `tf_pip.Dockerfile` - Tensorflow 2 21.02 NGC as parent, CUDA 11.2, cuDNN 8. pip (latest ver) install of jaxlib, jax, objax, and flax. 
140 | 141 | The 'git' containers take some time to build jaxlib, they pull the masters of all respective repos so are up to the bleeding edge but more likely to have possible regression or incompatibilities that go with that. The pip install containers are quite a bit quicker to get up and running, based on the latest pip versions of all repos. 142 | 143 | ### Docker Usage (GPU) 144 | 145 | 1. Make sure you have a recent version of docker and the NVIDIA Container Toolkit setup (https://github.com/NVIDIA/nvidia-docker) 146 | 2. Build the container `docker build -f docker/tf_pip.Dockerfile -t jax_tf_pip .` 147 | 3. Run the container, ideally map jeffnet and datasets (ImageNet) into the container 148 | * For tf containers, `docker run --gpus all -it -v /path/to/tfds/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_tf_pip` 149 | * For pt containers, `docker run --gpus all -it -v /path/to/imagenet/root:/data/ -v /path/to/efficientnet-jax/:/workspace/jeffnet --rm --ipc=host jax_pt_pip` 150 | 4. Model validation w/ pretrained weights (once inside running container): 151 | * For tf, in `workspace/jeffnet`, `python tf_linen_validate.py /data/ --model tf_efficientnet_b0_ns` 152 | * For pt, in `workspace/jeffnet`, `python pt_objax_validate.py /data/validation --model pt_efficientnet_b0` 153 | 5. Training (within container) 154 | * In `workspace/jeffnet`, `tf_linen_train.py --config train_configs/tf_efficientnet_b0-gpu_24gb_x2.py --config.data_dir /data` 155 | 156 | ### TPU 157 | 158 | I've successfully used this codebase on TPU VM environments as is. Any of the `tpu_x8` training configs should work out of the box on a v3-8 TPU. I have not tackled training with TPU Pods. 
159 | -------------------------------------------------------------------------------- /docker/pt_git.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.12-py3 2 | 3 | WORKDIR /workspace 4 | 5 | RUN pip install --upgrade pip 6 | 7 | RUN git clone https://github.com/google/jax &&\ 8 | cd jax &&\ 9 | python build/build.py --enable_cuda &&\ 10 | pip install dist/*.whl &&\ 11 | pip install -e . &&\ 12 | rm -rf /root/.cache/bazel && \ 13 | cd .. 14 | 15 | RUN git clone https://github.com/google/flax &&\ 16 | cd flax &&\ 17 | pip install -e . &&\ 18 | cd .. 19 | 20 | RUN git clone https://github.com/google/objax &&\ 21 | cd objax &&\ 22 | pip install -e . &&\ 23 | cd .. 24 | 25 | # install timm for PyTorch data pipeline / helpers that I'm familiar with, reinstall SIMD Pillow since 26 | # it never stays installed due to dep issues 27 | RUN pip install timm &&\ 28 | pip uninstall -y pillow &&\ 29 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd -------------------------------------------------------------------------------- /docker/pt_pip.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:20.12-py3 2 | 3 | WORKDIR /workspace 4 | 5 | RUN pip install --upgrade pip 6 | 7 | ENV CUDA_VERSION=11.1 8 | 9 | RUN pip install jaxlib && \ 10 | pip install --upgrade -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g` && \ 11 | pip install jax 12 | 13 | RUN pip install objax 14 | 15 | RUN pip install flax 16 | 17 | # install timm for PyTorch data pipeline / helpers that I'm familiar with, reinstall SIMD Pillow since 18 | # it never stays installed due to dep issues 19 | RUN pip install timm &&\ 20 | pip uninstall -y pillow &&\ 21 | CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 22 | 
-------------------------------------------------------------------------------- /docker/tf_git.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tensorflow:21.02-tf2-py3 2 | 3 | WORKDIR /workspace 4 | 5 | RUN pip install --upgrade pip 6 | 7 | # Default in build.py script 8 | # ENV CUDA_COMPUTE="3.5,5.2,6.0,6.1,7.0" 9 | 10 | # Pascal, Volta, Turing 11 | ENV CUDA_COMPUTE="6.1,7.0,7.5" 12 | 13 | # If you're lucky and you know it... Ampere (A100, RTX 3000) 14 | # ENV CUDA_COMPUTE="8.0,8.6" 15 | 16 | RUN git clone https://github.com/google/jax &&\ 17 | cd jax &&\ 18 | python build/build.py --enable_cuda --cuda_compute_capabilities=$CUDA_COMPUTE &&\ 19 | pip install dist/*.whl &&\ 20 | pip install -e . &&\ 21 | rm -rf /root/.cache/bazel &&\ 22 | cd .. 23 | 24 | RUN git clone https://github.com/google/flax &&\ 25 | cd flax &&\ 26 | pip install -e . &&\ 27 | cd .. 28 | 29 | RUN git clone https://github.com/google/objax &&\ 30 | cd objax &&\ 31 | pip install -e . &&\ 32 | cd .. 33 | 34 | RUN pip install ml_collections 35 | 36 | 37 | -------------------------------------------------------------------------------- /docker/tf_git_ampere.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/tensorflow:21.02-tf2-py3 2 | 3 | WORKDIR /workspace 4 | 5 | RUN pip install --upgrade pip 6 | 7 | # If you're lucky and you know it... Ampere (A100, RTX 3000) 8 | ENV CUDA_COMPUTE="8.0,8.6" 9 | 10 | RUN git clone https://github.com/google/jax &&\ 11 | cd jax &&\ 12 | python build/build.py --enable_cuda --cuda_compute_capabilities=$CUDA_COMPUTE &&\ 13 | pip install dist/*.whl &&\ 14 | pip install -e . &&\ 15 | rm -rf /root/.cache/bazel &&\ 16 | cd .. 17 | 18 | RUN git clone https://github.com/google/flax &&\ 19 | cd flax &&\ 20 | pip install -e . &&\ 21 | cd .. 22 | 23 | RUN git clone https://github.com/google/objax &&\ 24 | cd objax &&\ 25 | pip install -e . 
""" timm -> generic npz export script
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import os
import time
import argparse
import shutil
import tempfile
import hashlib

import numpy as np
import torch
import timm
from jeffnet.common import list_models


parser = argparse.ArgumentParser(description='')
parser.add_argument('--model', '-m', metavar='MODEL', default='efficientnet_b0',
                    help='model architecture (default: efficientnet_b0)')
parser.add_argument('--output', '-o', metavar='DIR', default=None,
                    help='')


def remap_module(module_type, k, v):
    """Remap one timm state-dict entry to this repo's layer naming.

    Args:
        module_type: class name of the parent module the param belongs to.
        k: original state-dict key (dot-separated path).
        v: the tensor value (passed through unchanged).

    Returns:
        Tuple of (remapped key, value).
    """
    # remappings for layers I've renamed since PyTorch impl
    # objax/flax specific naming reqs are handled on load
    if module_type == "ConvBnAct":
        k = k.replace('bn1.', 'bn.')
    elif module_type == "InvertedResidual":
        k = k.replace('conv_pw.', 'conv_exp.')
        k = k.replace('bn1.', 'bn_exp.')
        k = k.replace('bn2.', 'bn_dw.')
        k = k.replace('bn3.', 'bn_pwl.')
    elif module_type == "EdgeResidual":
        # NOTE: a second, unreachable 'EdgeResidual' elif (mapping bn3. -> bn_pwl.)
        # used to follow later in this chain; the first match wins, so it was dead
        # code and has been removed.
        k = k.replace('bn1.', 'bn_exp.')
        k = k.replace('bn2.', 'bn_pwl.')
    elif module_type == 'DepthwiseSeparableConv':
        k = k.replace('bn1.', 'bn_dw.')
        k = k.replace('bn2.', 'bn_pw.')
    elif module_type == 'SqueezeExcite':
        k = k.replace('conv_reduce.', 'reduce.')
        k = k.replace('conv_expand.', 'expand.')
    elif module_type == 'EfficientNet':
        k = k.replace('conv_stem.', 'stem.conv.')
        k = k.replace('bn1.', 'stem.bn.')
        k = k.replace('conv_head.', 'head.conv_pw.')
        k = k.replace('bn2.', 'head.bn.')
        k = k.replace('classifier.', 'head.classifier.')
    elif module_type == "MobileNetV3":
        k = k.replace('conv_stem.', 'stem.conv.')
        k = k.replace('bn1.', 'stem.bn.')
        k = k.replace('conv_head.', 'head.conv_pw.')
        k = k.replace('bn2.', 'head.bn.')
        k = k.replace('classifier.', 'head.classifier.')
    return k, v


def export_model(model_name, output_dir=''):
    """Export one pretrained timm model to an npz checkpoint.

    The npz contains each tensor under a sequential string index, plus parallel
    'names' and 'types' arrays for the remapped key and the owning layer's class
    name. The final file is named '<model_name>-<sha256[:8]>.npz'.

    Args:
        model_name: config name with 'pt_' prefix (stripped for timm lookup).
        output_dir: existing directory to write into; defaults to cwd.
    """
    timm_model_name = model_name.replace('pt_', '')
    m = timm.create_model(timm_model_name, pretrained=True)
    d = dict(m.named_modules())

    data = {}
    names = []
    types = []
    for k, v in m.state_dict().items():
        # batch-norm step counters aren't needed for inference weights
        if 'num_batches' in k:
            continue

        k_split = k.split('.')
        layer_name = '.'.join(k_split[:-1])
        parent_name = '.'.join(k_split[:-2])
        parent_module = d[parent_name]
        parent_type = type(parent_module).__name__
        if 'MixedConv' in parent_type:
            # need to step back another level in hierarchy to get parent block
            parent_name = '.'.join(k_split[:-3])
            parent_module = d[parent_name]
            parent_type = type(parent_module).__name__
        k, v = remap_module(parent_type, k, v)

        type_str = ''
        if layer_name in d:
            type_str = type(d[layer_name]).__name__
            if type_str == 'Conv2dSame':
                # 'SAME'-padded conv is just a Conv2d on the JAX side
                type_str = 'Conv2d'
        types.append(type_str)

        print(k, type_str, v.shape)
        data[str(len(data))] = v.numpy()
        names.append(k)

    # write as npz
    tempf = tempfile.NamedTemporaryFile(delete=False, dir='./')
    np.savez(tempf, names=np.array(names), types=types, **data)
    tempf.close()

    # verify by reading and hashing
    with open(tempf.name, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()

    # move to proper name / location
    if output_dir:
        assert os.path.isdir(output_dir)
    else:
        output_dir = './'
    final_filename = '-'.join([model_name, sha_hash[:8]]) + '.npz'
    shutil.move(tempf.name, os.path.join(output_dir, final_filename))


def main():
    """Export the model named by --model, or every pretrained model if 'all'."""
    args = parser.parse_args()

    all_models = list_models(pretrained=True)
    if args.model == 'all':
        for model_name in all_models:
            export_model(model_name, args.output)
    else:
        export_model(args.model, args.output)


if __name__ == '__main__':
    main()
def make_divisible(v, divisor=8, min_value=None):
    """Round ``v`` to the nearest multiple of ``divisor``.

    The result is never below ``min_value`` (defaults to ``divisor``) and the
    round-down is never allowed to lose more than 10% of ``v``.
    """
    floor = min_value or divisor
    rounded = max(floor, (int(v + divisor / 2) // divisor) * divisor)
    # make sure rounding down does not go down by more than 10%
    return rounded + divisor if rounded < 0.9 * v else rounded


def round_features(features, multiplier=1.0, divisor=8, feat_min=None):
    """Scale a channel count by the depth multiplier and round to a divisible value."""
    if not multiplier:
        # a zero / None multiplier disables scaling and rounding entirely
        return features
    return make_divisible(features * multiplier, divisor, feat_min)
def _log_info_if(msg, condition):
    # log at INFO level only when `condition` is truthy (verbose build mode)
    if condition:
        _logger.info(msg)


class EfficientNetBuilder:
    """ Build Trunk Blocks

    Consumes decoded block definitions (see arch_defs) and a framework-specific
    block factory, producing a list of stages, each a list of built blocks.

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    """
    def __init__(self, in_chs, block_defs, block_factory,
                 feat_multiplier=1.0, feat_divisor=8, feat_min=None,
                 output_stride=32, pad_type='', conv_layer=None, norm_layer=None, se_layer=None,
                 act_fn=None, drop_path_rate=0., feature_location='', verbose=False):
        """
        Args:
            in_chs: channel count produced by the stem (input to the first block).
            block_defs: per-stage lists of (block_type, block_args) tuples.
            block_factory: framework-specific factory providing block classes
                (InvertedResidual, DepthwiseSeparable, ...) and get_act_fn().
            feat_multiplier / feat_divisor / feat_min: channel rounding params
                passed through to round_features().
            output_stride: max network stride; further downsampling is converted
                to dilation once this is reached.
            pad_type / conv_layer / norm_layer / se_layer: defaults overlaid onto
                every block's args.
            act_fn: model-default activation (per-block act_fn overrides it).
            drop_path_rate: max stochastic-depth rate, scaled linearly per block.
            feature_location: one of '', 'bottleneck', 'expansion' (feature
                extraction currently stubbed out below).
            verbose: if True, log block construction details.
        """
        assert output_stride in (32, 16, 8, 4, 2)
        self.in_chs = in_chs  # num input ch from stem
        self.block_defs = block_defs  # block types, arguments w/ structure
        self.block_factory = block_factory  # factory to build framework specific blocks
        self.feat_multiplier = feat_multiplier
        self.feat_divisor = feat_divisor
        self.feat_min = feat_min
        self.output_stride = output_stride
        self.act_fn = act_fn
        self.drop_path_rate = drop_path_rate
        self.default_args = dict(
            pad_type=pad_type,
            conv_layer=conv_layer,
            norm_layer=norm_layer,
            se_layer=se_layer,
        )
        self.feature_location = feature_location
        assert feature_location in ('bottleneck', 'expansion', '')
        self.verbose = verbose

        self.features = []  # information about feature maps, constructed during build

    def _round_channels(self, chs):
        # apply the width multiplier + divisibility rounding to a channel count
        return round_features(chs, self.feat_multiplier, self.feat_divisor, self.feat_min)

    def _make_block(self, block_type, block_args, stage_idx, block_idx, flat_idx, block_count):
        """Instantiate one block via the factory, overlaying builder defaults.

        drop_path_rate ramps linearly with the flattened block index; in_chs is
        advanced to this block's out_chs for the next block.
        """
        drop_path_rate = self.drop_path_rate * flat_idx / block_count
        # NOTE: block act fn overrides the model default
        act_fn = block_args['act_fn'] if block_args['act_fn'] is not None else self.act_fn
        act_fn = self.block_factory.get_act_fn(act_fn)  # map string acts to functions
        ba_overlay = dict(
            in_chs=self.in_chs, out_chs=self._round_channels(block_args['out_chs']), act_fn=act_fn,
            drop_path_rate=drop_path_rate, **self.default_args)
        block_args.update(ba_overlay)
        if 'fake_in_chs' in block_args and block_args['fake_in_chs']:
            # FIXME this is a hack to work around mismatch in origin impl input filters
            block_args['fake_in_chs'] = self._round_channels(block_args['fake_in_chs'])
        assert block_args['act_fn'] is not None

        _log_info_if(f'  {block_type.upper()} {block_idx}, Args: {str(block_args)}', self.verbose)
        if block_type == 'ir':
            block = self.block_factory.InvertedResidual(stage_idx, block_idx, **block_args)
        elif block_type == 'ds' or block_type == 'dsa':
            block = self.block_factory.DepthwiseSeparable(stage_idx, block_idx, **block_args)
        elif block_type == 'er':
            block = self.block_factory.EdgeResidual(stage_idx, block_idx, **block_args)
        elif block_type == 'cn':
            block = self.block_factory.ConvBnAct(stage_idx, block_idx, **block_args)
        else:
            assert False, 'Uknkown block type (%s) while building model.' % block_type
        self.in_chs = block_args['out_chs']  # update in_chs for arg of next block

        return block

    def __call__(self):
        """ Build the blocks
        Return:
             List of stages (each stage being a list of blocks)
        """
        _log_info_if('Building model trunk with %d stages...' % len(self.block_defs), self.verbose)
        num_blocks = sum([len(x) for x in self.block_defs])
        flat_idx = 0
        current_stride = 2  # stem has already downsampled once
        current_dilation = 1
        stages = []
        # if self.block_args[0][0]['stride'] > 1:
        #     # if the first block starts with a stride, we need to extract first level feat from stem
        #     self.features.append(dict(
        #         module='act1', num_chs=self.in_chs, stage=0, reduction=current_stride,
        #         hook_type='forward' if self.feature_location != 'bottleneck' else ''))

        # outer list of block_args defines the stacks
        for stage_idx, stage_defs in enumerate(self.block_defs):
            _log_info_if('Stack: {}'.format(stage_idx), self.verbose)

            blocks = []
            # each stage contains a list of block types and arguments
            for block_idx, block_def in enumerate(stage_defs):
                _log_info_if(' Block: {}'.format(block_idx), self.verbose)
                last_block = block_idx + 1 == len(stage_defs)
                block_type, block_args = block_def
                block_args = dict(**block_args)  # shallow copy so the defs stay reusable

                assert block_args['stride'] in (1, 2)
                if block_idx >= 1:  # only the first block in any stack can have a stride > 1
                    block_args['stride'] = 1

                # extract_features = False
                # if last_block:
                #     next_stage_idx = stage_idx + 1
                #     extract_features = next_stage_idx >= len(self.block_defs) or \
                #         self.block_defs[next_stage_idx][0]['stride'] > 1

                # once output_stride is reached, convert further striding to dilation
                next_dilation = current_dilation
                if block_args['stride'] > 1:
                    next_output_stride = current_stride * block_args['stride']
                    if next_output_stride > self.output_stride:
                        next_dilation = current_dilation * block_args['stride']
                        block_args['stride'] = 1
                        _log_info_if('  Converting stride to dilation to maintain output_stride=={}'.format(
                            self.output_stride), self.verbose)
                    else:
                        current_stride = next_output_stride
                # NOTE: the block gets the *current* dilation; the increased
                # dilation takes effect from the following block onward
                block_args['dilation'] = current_dilation
                if next_dilation != current_dilation:
                    current_dilation = next_dilation

                # create the block
                blocks.append(self._make_block(block_type, block_args, stage_idx, block_idx, flat_idx, num_blocks))

                # stash feature module name and channel info for model feature extraction
                # if extract_features:
                #     feature_info = dict(stage=stage_idx + 1, reduction=current_stride)
                #     module_name = f'blocks.{stage_idx}.{block_idx}'
                #     leaf_name = feature_info.get('module', '')
                #     feature_info['module'] = '.'.join([module_name, leaf_name]) if leaf_name else module_name
                #     self.features.append(feature_info)

                flat_idx += 1  # incr flattened block idx (across all stages)
            stages.append(blocks)
        return stages
# key substrings identifying non-parameter (running) state in a state dict
_STATE_NAMES = ('running_mean', 'running_var', 'moving_mean', 'moving_variance')


def split_state_dict(state_dict):
    """Split a flat state dict into (params, state).

    FIXME currently all non-param state is assumed to be norm running stats,
    identified by any key containing one of ``_STATE_NAMES``.
    """
    params, state = {}, {}
    for key, value in state_dict.items():
        bucket = state if any(name in key for name in _STATE_NAMES) else params
        bucket[key] = value
    return params, state


def get_outdir(path, *paths, retry_inc=False):
    """Join path components and ensure the resulting directory exists.

    With ``retry_inc=True`` an already-existing dir is not reused; instead the
    first free '-N' suffixed sibling (N < 100) is created and returned.
    """
    target = os.path.join(path, *paths)
    if not os.path.exists(target):
        os.makedirs(target)
    elif retry_inc:
        count = 1
        while os.path.exists(f'{target}-{count}'):
            count += 1
            assert count < 100  # sanity cap on incremental output dirs
        target = f'{target}-{count}'
        os.makedirs(target)
    return target
# FIXME ended up with multiple cross entropy loss def here while experimenting
# with diff numeric stability issues... will cleanup someday.


def cross_entropy_loss(logits, labels, label_smoothing=0., dtype=jnp.float32):
    """Compute mean cross entropy for logits and integer labels w/ label smoothing.

    Args:
        logits: [batch, ..., num_classes] float array.
        labels: categorical labels [batch, ...] int array.
        label_smoothing: label smoothing constant, used to determine the on and off values.
        dtype: dtype to perform loss calcs in, including log_softmax.

    Returns:
        Scalar mean cross-entropy loss.
    """
    num_classes = logits.shape[-1]
    labels = jax.nn.one_hot(labels, num_classes, dtype=dtype)
    if label_smoothing > 0:
        # redistribute label_smoothing mass uniformly over all classes
        labels = labels * (1 - label_smoothing) + label_smoothing / num_classes
    logp = jax.nn.log_softmax(logits.astype(dtype))
    return -jnp.mean(jnp.sum(logp * labels, axis=-1))


def onehot(labels, num_classes, on_value=1.0, off_value=0.0, dtype=jnp.float32):
    """One-hot encode integer labels with configurable on/off values.

    Returns an array of shape labels.shape + (num_classes,) in `dtype`.
    """
    x = (labels[..., None] == jnp.arange(num_classes)[None])
    x = lax.select(x, jnp.full(x.shape, on_value), jnp.full(x.shape, off_value))
    return x.astype(dtype)


def weighted_cross_entropy_loss(logits, labels, weights=None, label_smoothing=0.0, dtype=jnp.float32):
    """Compute weighted mean cross entropy for logits and labels w/ label smoothing.

    Args:
        logits: [batch, ..., num_classes] float array.
        labels: categorical labels [batch, ...] int array.
        weights: None or array broadcastable to labels' shape, applied per sample.
        label_smoothing: label smoothing constant, used to determine the on and off values.
        dtype: dtype to perform loss calcs in, including log_softmax.

    Returns:
        Scalar mean (weighted) cross-entropy loss.
        NOTE: docstring previously claimed a (loss, normalizer) tuple; the
        function has always returned only the scalar loss.

    Raises:
        ValueError: if logits does not have exactly one more dim than labels.
    """
    if logits.ndim != labels.ndim + 1:
        raise ValueError(f'Incorrect shapes. Got shape {logits.shape} logits and {labels.shape} targets')
    num_classes = logits.shape[-1]
    off_value = label_smoothing / num_classes
    on_value = 1. - label_smoothing + off_value
    soft_targets = onehot(labels, num_classes, on_value=on_value, off_value=off_value, dtype=dtype)
    logp = jax.nn.log_softmax(logits.astype(dtype))
    loss = jnp.sum(logp * soft_targets, axis=-1)
    if weights is not None:
        loss = loss * weights
    return -loss.mean()
class AverageMeter:
    """Tracks the most recent value and a running (count-weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def correct_topk(logits, labels, topk=(1,)):
    """Count of samples whose label falls within the top-k predictions, per k."""
    k_max = max(topk)
    # indices come back (batch, k_max); transpose so row i is the i-th best guess
    preds = lax.top_k(logits, k_max)[1].transpose()
    hits = preds == labels.reshape(1, -1)
    return [hits[:k].reshape(-1).sum(axis=0) for k in topk]


def acc_topk(logits, labels, topk=(1,)):
    """Top-k accuracy as a percentage of the batch, for each requested k."""
    k_max = max(topk)
    preds = lax.top_k(logits, k_max)[1].transpose()
    hits = preds == labels.reshape(1, -1)
    batch_size = labels.shape[0]
    return [hits[:k].reshape(-1).sum(axis=0) * 100 / batch_size for k in topk]
# matches bfd8deac from resnet18-bfd8deac.pth
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
ENV_JAX_HOME = 'JAX_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
READ_DATA_CHUNK = 8192


def get_jax_dir():
    """Resolve the jax cache dir: $JAX_HOME, else $XDG_CACHE_HOME/jax, else ~/.cache/jax."""
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback = os.path.join(cache_root, 'jax')
    return os.path.expanduser(os.getenv(ENV_JAX_HOME, fallback))
def load_state_dict_from_url(url, model_dir=None, transpose=False, progress=True, check_hash=True, file_name=None):
    r"""Loads the serialized npz object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and returned.

    The default value of `model_dir` is ``<jax_dir>/checkpoints`` where
    ``<jax_dir>`` is the directory returned by :func:`get_jax_dir`.

    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        transpose (bool): transpose the weights from PyTorch (ie OIHW) style layouts to Tensorflow (ie HWIO conv2d)
        progress (bool, optional): whether or not to display a progress bar to stderr. Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file. Default: True
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.

    Returns:
        The result of :func:`load_state_dict` on the cached file — a dict of
        jax arrays, or a (state_dict, type_map) tuple when types are present.
    """
    # Issue warning to move data if old env is set
    # NOTE(review): leftover from the PyTorch hub code this was adapted from —
    # the message references TORCH_HOME even though this codebase uses JAX_HOME.
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        model_dir = os.path.join(get_jax_dir(), 'checkpoints')

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise

    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        # hash prefix is parsed out of the filename (see HASH_REGEX) and
        # verified against the downloaded bytes by download_url_to_file
        hash_prefix = HASH_REGEX.search(filename).group(1) if check_hash else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

    return load_state_dict(cached_file, transpose=transpose)
def update_moment(updates, moments, decay, order):
    """Leafwise EMA of the `order`-th moment: (1 - decay) * g**order + decay * t.

    Uses ``jax.tree_util.tree_map``; the previously used ``jax.tree_multimap``
    alias was deprecated and removed from recent JAX releases.
    """
    return jax.tree_util.tree_map(
        lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)


# A filter function takes a parameter path (tuple of names) and its value and
# returns True when the transformation should apply to that parameter.
FilterFn = Callable[[Tuple[Any], jnp.ndarray], jnp.ndarray]


def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> bool:
    """Filter to exclude biases and normalization weights.

    Matches parameters whose final path component is 'bias' or 'scale';
    returns False for those (excluded), True otherwise.
    """
    del val  # only the parameter path matters
    if path[-1] == "bias" or path[-1] == 'scale':
        return False
    return True
def _partial_update(updates: optax.Updates,
                    new_updates: optax.Updates,
                    params: optax.Params,
                    filter_fn: Optional[FilterFn] = None) -> optax.Updates:
    """Returns new_updates for params where filter_fn is True, else updates."""

    if filter_fn is None:
        # no filter -> apply the new updates everywhere
        return new_updates

    # build a 0/1 mask tree with the same structure as params
    wrapped_filter_fn = lambda x, y: jnp.array(filter_fn(x, y))
    params_to_filter = nest.map_structure_with_path(wrapped_filter_fn, params)

    def _update_fn(g: jnp.ndarray, t: jnp.ndarray, m: jnp.ndarray) -> jnp.ndarray:
        # blend: m == 1 selects the new update t, m == 0 keeps the original g
        m = m.astype(g.dtype)
        return g * (1. - m) + t * m

    # NOTE(review): jax.tree_multimap was removed in newer JAX releases;
    # jax.tree_util.tree_map is the drop-in replacement — confirm pinned version.
    return jax.tree_multimap(_update_fn, updates, new_updates, params_to_filter)


class ScaleByLarsState(NamedTuple):
    # momentum buffer tree, same structure as the params
    mu: jnp.ndarray


def scale_by_lars(
        momentum: float = 0.9,
        eta: float = 0.001,
        filter_fn: Optional[FilterFn] = None) -> optax.GradientTransformation:
    """Rescales updates according to the LARS algorithm.

    Does not include weight decay.
    References:
        [You et al, 2017](https://arxiv.org/abs/1708.03888)

    Args:
        momentum: momentum coeficient.
        eta: LARS coefficient.
        filter_fn: an optional filter function; params for which it returns
            False keep their unadapted updates.

    Returns:
        An (init_fn, update_fn) tuple.
    """

    def init_fn(params: optax.Params) -> ScaleByLarsState:
        mu = jax.tree_multimap(jnp.zeros_like, params)  # momentum
        return ScaleByLarsState(mu=mu)

    def update_fn(updates: optax.Updates, state: ScaleByLarsState,
                  params: optax.Params) -> Tuple[optax.Updates, ScaleByLarsState]:
        def lars_adaptation(
                update: jnp.ndarray,
                param: jnp.ndarray,
        ) -> jnp.ndarray:
            # trust ratio: eta * ||param|| / ||update||, guarded against
            # zero-norm params/updates (falls back to the raw update)
            param_norm = jnp.linalg.norm(param)
            update_norm = jnp.linalg.norm(update)
            return update * jnp.where(
                param_norm > 0.,
                jnp.where(update_norm > 0, (eta * param_norm / update_norm), 1.0), 1.0)

        adapted_updates = jax.tree_multimap(lars_adaptation, updates, params)
        # only apply adaptation to params passing filter_fn
        adapted_updates = _partial_update(updates, adapted_updates, params, filter_fn)
        # classic (non-EMA) momentum accumulation
        mu = jax.tree_multimap(lambda g, t: momentum * g + t, state.mu, adapted_updates)
        return mu, ScaleByLarsState(mu=mu)

    return optax.GradientTransformation(init_fn, update_fn)
def lars(
    learning_rate: ScalarOrSchedule,
    weight_decay: float = 0.,
    momentum: float = 0.9,
    eta: float = 0.001,
    weight_decay_filter: Optional[FilterFn] = None,
    lars_adaptation_filter: Optional[FilterFn] = None,
) -> optax.GradientTransformation:
    """Creates lars optimizer with weight decay.

    References:
        [You et al, 2017](https://arxiv.org/abs/1708.03888)

    Args:
        learning_rate: learning rate coefficient.
        weight_decay: weight decay coefficient.
        momentum: momentum coefficient.
        eta: LARS coefficient.
        weight_decay_filter: optional filter function to only apply the weight
            decay on a subset of parameters. The filter function takes as input
            the parameter path (as a tuple) and its associated update, and
            returns True for params to apply the weight decay and False for
            params to not apply it. When set to None, the default filter
            (exclude_bias_and_norm) skips params whose final path component is
            'bias' or 'scale' (i.e. biases and normalization weights).
        lars_adaptation_filter: similar to weight decay filter but for lars adaptation

    Returns:
        An optax.GradientTransformation, i.e. a (init_fn, update_fn) tuple.
    """

    if weight_decay_filter is None:
        # FIXME always pass in exclusions or assume defaults?
        weight_decay_filter = exclude_bias_and_norm
        #weight_decay_filter = lambda *_: True
    if lars_adaptation_filter is None:
        # FIXME always pass in exclusions or assume defaults?
        lars_adaptation_filter = exclude_bias_and_norm
        #lars_adaptation_filter = lambda *_: True

    # order matters: decay is added to raw grads, then LARS trust-ratio
    # scaling + momentum, then the (negated) learning rate scaling
    return optax.chain(
        add_weight_decay(weight_decay=weight_decay, filter_fn=weight_decay_filter),
        scale_by_lars(momentum=momentum, eta=eta, filter_fn=lars_adaptation_filter),
        scale_by_learning_rate(learning_rate),
    )
167 | """ 168 | 169 | if weight_decay_filter is None: 170 | # FIXME always pass in exclusions or assume defaults? 171 | weight_decay_filter = exclude_bias_and_norm 172 | #weight_decay_filter = lambda *_: True 173 | if lars_adaptation_filter is None: 174 | # FIXME always pass in exclusions or assume defaults? 175 | lars_adaptation_filter = exclude_bias_and_norm 176 | #lars_adaptation_filter = lambda *_: True 177 | 178 | return optax.chain( 179 | add_weight_decay(weight_decay=weight_decay, filter_fn=weight_decay_filter), 180 | scale_by_lars(momentum=momentum, eta=eta, filter_fn=lars_adaptation_filter), 181 | scale_by_learning_rate(learning_rate), 182 | ) 183 | -------------------------------------------------------------------------------- /jeffnet/common/optim/optim_factory.py: -------------------------------------------------------------------------------- 1 | import optax 2 | 3 | from .lars import lars 4 | 5 | 6 | def _rename(kwargs, originals, new): 7 | for o, n in zip(originals, new): 8 | o = kwargs.pop(o, None) 9 | if o is not None: 10 | kwargs[n] = o 11 | 12 | 13 | def _erase(kwargs, names): 14 | for u in names: 15 | kwargs.pop(u, None) 16 | 17 | 18 | def create_optax_optim(name, learning_rate=None, momentum=0.9, weight_decay=0, **kwargs): 19 | """ Optimizer Factory 20 | 21 | Args: 22 | learning_rate (float): specify learning rate or leave up to scheduler / optim if None 23 | weight_decay (float): weight decay to apply to all params, not applied if 0 24 | **kwargs: optional / optimizer specific params that override defaults 25 | 26 | With regards to the kwargs, I've tried to keep the param naming incoming via kwargs from 27 | config file more consistent so there is less variation. Names of common args such as eps, 28 | beta1, beta2 etc will be remapped where possible (even if optimizer impl uses a diff name) 29 | and removed when not needed. 
A list of some common params to use in config files as named: 30 | eps (float): default stability / regularization epsilon value 31 | beta1 (float): moving average / momentum coefficient for gradient 32 | beta2 (float): moving average / momentum coefficient for gradient magnitude (squared grad) 33 | """ 34 | name = name.lower() 35 | opt_args = dict(learning_rate=learning_rate, **kwargs) 36 | _rename(opt_args, ('beta1', 'beta2'), ('b1', 'b2')) 37 | if name == 'sgd' or name == 'momentum' or name == 'nesterov': 38 | _erase(opt_args, ('eps',)) 39 | if name == 'momentum': 40 | optimizer = optax.sgd(momentum=momentum, **opt_args) 41 | elif name == 'nesterov': 42 | optimizer = optax.sgd(momentum=momentum, nesterov=True) 43 | else: 44 | assert name == 'sgd' 45 | optimizer = optax.sgd(momentum=0, **opt_args) 46 | elif name == 'adabelief': 47 | optimizer = optax.adabelief(**opt_args) 48 | elif name == 'adam' or name == 'adamw': 49 | if name == 'adamw': 50 | optimizer = optax.adamw(weight_decay=weight_decay, **opt_args) 51 | else: 52 | optimizer = optax.adam(**opt_args) 53 | elif name == 'lamb': 54 | optimizer = optax.lamb(weight_decay=weight_decay, **opt_args) 55 | elif name == 'lars': 56 | optimizer = lars(weight_decay=weight_decay, **opt_args) 57 | elif name == 'rmsprop': 58 | optimizer = optax.rmsprop(momentum=momentum, **opt_args) 59 | elif name == 'rmsproptf': 60 | optimizer = optax.rmsprop(momentum=momentum, initial_scale=1.0, **opt_args) 61 | else: 62 | assert False, f"Invalid optimizer name specified ({name})" 63 | 64 | return optimizer 65 | -------------------------------------------------------------------------------- /jeffnet/common/padding.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # calculate SAME-like symmetric padding for a convolution 4 | def get_like_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int: 5 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 6 | return 
padding 7 | -------------------------------------------------------------------------------- /jeffnet/data/tf_autoaugment.py: -------------------------------------------------------------------------------- 1 | """ RandAudgment and AutoAugment for TF data pipeline. 2 | 3 | This code is a mish mash of various RA and AA impl for TF, including bits and pieces from: 4 | * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 5 | * https://github.com/google-research/fixmatch/tree/master/imagenet/augment 6 | 7 | AutoAugment Reference: https://arxiv.org/abs/1805.09501 8 | RandAugment Reference: https://arxiv.org/abs/1909.13719 9 | """ 10 | 11 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 12 | # 13 | # Licensed under the Apache License, Version 2.0 (the "License"); 14 | # you may not use this file except in compliance with the License. 15 | # You may obtain a copy of the License at 16 | # 17 | # http://www.apache.org/licenses/LICENSE-2.0 18 | # 19 | # Unless required by applicable law or agreed to in writing, software 20 | # distributed under the License is distributed on an "AS IS" BASIS, 21 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 22 | # See the License for the specific language governing permissions and 23 | # limitations under the License. 24 | # ============================================================================== 25 | import tensorflow as tf 26 | 27 | # This signifies the max integer that the controller RNN could predict for the 28 | # augmentation scheme. 29 | from jeffnet.data.tf_image_ops import cutout, solarize, solarize_add, color, contrast, brightness, posterize, rotate, \ 30 | translate_x, translate_y, shear_x, shear_y, autocontrast, sharpness, equalize, invert, autocontrast_or_tone 31 | 32 | _MAX_LEVEL = 10. 
33 | 34 | IMAGENET_AUG_OPS = [ 35 | 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize', 36 | 'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 37 | 'TranslateX', 'TranslateY', 'SolarizeAdd', 'Identity', 38 | ] 39 | 40 | NAME_TO_FUNC = { 41 | 'AutoContrast': autocontrast_or_tone, 42 | 'Equalize': equalize, 43 | 'Invert': invert, 44 | 'Rotate': rotate, 45 | 'Posterize': posterize, 46 | 'Solarize': solarize, 47 | 'SolarizeAdd': solarize_add, 48 | 'Color': color, 49 | 'Contrast': contrast, 50 | 'Brightness': brightness, 51 | 'Sharpness': sharpness, 52 | 'ShearX': shear_x, 53 | 'ShearY': shear_y, 54 | 'TranslateX': translate_x, 55 | 'TranslateY': translate_y, 56 | 'Cutout': cutout, 57 | } 58 | 59 | 60 | def _randomly_negate_tensor(tensor): 61 | """With 50% prob turn the tensor negative.""" 62 | should_flip = tf.cast(tf.floor(tf.random.uniform([]) + 0.5), tf.bool) 63 | final_tensor = tf.cond(should_flip, lambda: tensor, lambda: -tensor) 64 | return final_tensor 65 | 66 | 67 | def _rotate_level(level): 68 | level = (level / _MAX_LEVEL) * 30. 69 | level = _randomly_negate_tensor(level) 70 | return level, 71 | 72 | 73 | def _shrink_level(level): 74 | """Converts level to ratio by which we shrink the image content.""" 75 | if level == 0: 76 | return 1.0, # if level is zero, do not shrink the image 77 | # Maximum shrinking ratio is 2.9. 78 | level = 2. / (_MAX_LEVEL / level) + 0.9 79 | return level, 80 | 81 | 82 | def _enhance_level(level): 83 | # NOTE original level range doesn't make sense 84 | # M0 -> all degenerate 85 | # M0-M5 -> interpolation 86 | # M5 -> no op 87 | # M5+ -> extrapolation 88 | #level = (level / _MAX_LEVEL) * 1.8 + 0.1 89 | 90 | # this will randomly flip between interpolate and extrapolate with increasing amount based on level 91 | # FIXME what to limit range to? typically makes sense 0. 
- 2., but larger range possible 92 | level = (level / _MAX_LEVEL) * .9 93 | level = 1.0 + _randomly_negate_tensor(level) 94 | level = tf.clip_by_value(level, 0., 3.) 95 | return level, 96 | 97 | 98 | def _shear_level(level): 99 | level = (level / _MAX_LEVEL) * 0.3 100 | # Flip level to negative with 50% chance. 101 | level = _randomly_negate_tensor(level) 102 | return level, 103 | 104 | 105 | def _translate_level(level, translate_const): 106 | level = (level / _MAX_LEVEL) * float(translate_const) 107 | # Flip level to negative with 50% chance. 108 | level = _randomly_negate_tensor(level) 109 | return level, 110 | 111 | 112 | def _get_args_fn(hparams): 113 | return { 114 | 'AutoContrast': lambda level: (), 115 | 'Equalize': lambda level: (), 116 | 'Invert': lambda level: (), 117 | 'Rotate': lambda level: _rotate_level(level) + (hparams['fill_value'],), 118 | # FIXME fix posterize/solarize scale as per timm 119 | 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), 120 | 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), 121 | 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), 122 | 'Color': _enhance_level, 123 | 'Contrast': _enhance_level, 124 | 'Brightness': _enhance_level, 125 | 'Sharpness': _enhance_level, 126 | 'ShearX': lambda level: _shear_level(level) + (hparams['fill_value'],), 127 | 'ShearY': lambda level: _shear_level(level) + (hparams['fill_value'],), 128 | # pylint:disable=g-long-lambda 129 | 'TranslateX': lambda level: _translate_level(level, hparams['translate_const']) + (hparams['fill_value'],), 130 | 'TranslateY': lambda level: _translate_level(level, hparams['translate_const']) + (hparams['fill_value'],), 131 | # FIXME relative translate as per timm 132 | # pylint:enable=g-long-lambda 133 | 'Cutout': lambda level: (), 134 | } 135 | 136 | 137 | class RandAugment: 138 | """Random augment with fixed magnitude. 
139 | FIXME this is a class based impl or RA from fixmatch, it needs some changes before using 140 | """ 141 | 142 | def __init__(self, 143 | num_layers=2, 144 | prob_to_apply=None, 145 | magnitude=None, 146 | num_levels=10, 147 | ): 148 | """Initialized rand augment. 149 | Args: 150 | num_layers: number of augmentation layers, i.e. how many times to do augmentation. 151 | prob_to_apply: probability to apply on each layer. If None then always apply. 152 | magnitude: default magnitude in range [0, 1], if None then magnitude will be chosen randomly. 153 | num_levels: number of levels for quantization of the magnitude. 154 | """ 155 | self.num_layers = num_layers 156 | self.prob_to_apply = float(prob_to_apply) if prob_to_apply is not None else None 157 | self.num_levels = int(num_levels) if num_levels else None 158 | self.level = float(magnitude) if magnitude is not None else None 159 | self.augmentation_hparams = dict( 160 | translate_rel=0.4, 161 | translate_const=100) 162 | 163 | def _get_level(self): 164 | if self.level is not None: 165 | return tf.convert_to_tensor(self.level) 166 | if self.num_levels is None: 167 | return tf.random.uniform(shape=[], dtype=tf.float32) 168 | else: 169 | level = tf.random.uniform(shape=[], maxval=self.num_levels + 1, dtype=tf.int32) 170 | return tf.cast(level, tf.float32) / self.num_levels 171 | 172 | def _apply_one_layer(self, image): 173 | """Applies one level of augmentation to the image.""" 174 | level = self._get_level() 175 | branch_fns = [] 176 | for augment_op_name in IMAGENET_AUG_OPS: 177 | augment_fn = NAME_TO_FUNC[augment_op_name] 178 | args_fn = _get_args_fn(self.augmentation_hparams)[augment_op_name] 179 | 180 | def _branch_fn(image=image, augment_fn=augment_fn, args_fn=args_fn): 181 | args = [image] + list(args_fn(level)) 182 | return augment_fn(*args) 183 | 184 | branch_fns.append(_branch_fn) 185 | 186 | branch_index = tf.random.uniform(shape=[], maxval=len(branch_fns), dtype=tf.int32) 187 | aug_image = 
tf.switch_case(branch_index, branch_fns, default=lambda: image) 188 | if self.prob_to_apply is not None: 189 | return tf.cond( 190 | tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply, 191 | lambda: aug_image, 192 | lambda: image) 193 | else: 194 | return aug_image 195 | 196 | def __call__(self, image, aug_image_key='image'): 197 | output_dict = {} 198 | 199 | if aug_image_key is not None: 200 | aug_image = image 201 | for _ in range(self.num_layers): 202 | aug_image = self._apply_one_layer(aug_image) 203 | output_dict[aug_image_key] = aug_image 204 | 205 | if aug_image_key != 'image': 206 | output_dict['image'] = image 207 | 208 | return output_dict 209 | 210 | 211 | def _parse_policy_info(name, prob, level, augmentation_hparams): 212 | """Return the function that corresponds to `name` and update `level` param.""" 213 | func = NAME_TO_FUNC[name] 214 | args = _get_args_fn(augmentation_hparams)[name](level) 215 | 216 | return func, prob, args 217 | 218 | 219 | def _apply_func_with_prob(func, image, args, prob): 220 | """Apply `func` to image w/ `args` as input with probability `prob`.""" 221 | assert isinstance(args, tuple) 222 | 223 | # Apply the function with probability `prob`. 224 | should_apply_op = tf.cast(tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool) 225 | augmented_image = tf.cond( 226 | should_apply_op, 227 | lambda: func(image, *args), 228 | lambda: image) 229 | return augmented_image 230 | 231 | 232 | def select_and_apply_random_policy(policies, image): 233 | """Select a random policy from `policies` and apply it to `image`.""" 234 | policy_to_select = tf.random.uniform([], maxval=len(policies), dtype=tf.int32) 235 | # Note that using tf.case instead of tf.conds would result in significantly 236 | # larger graphs and would even break export for some larger policies. 
237 | for i, policy in enumerate(policies): 238 | image = tf.cond( 239 | tf.equal(i, policy_to_select), 240 | lambda selected_policy=policy: selected_policy(image), 241 | lambda: image) 242 | return image 243 | 244 | 245 | def distort_image_with_randaugment(image, num_layers, magnitude, fill_value=(128, 128, 128)): 246 | """Applies the RandAugment policy to `image`. 247 | 248 | RandAugment is from the paper https://arxiv.org/abs/1909.13719, 249 | 250 | Args: 251 | image: `Tensor` of shape [height, width, 3] representing an image. 252 | num_layers: Integer, the number of augmentation transformations to apply 253 | sequentially to an image. Represented as (N) in the paper. Usually best 254 | values will be in the range [1, 3]. 255 | magnitude: Integer, shared magnitude across all augmentation operations. 256 | Represented as (M) in the paper. Usually best values are in the range [5, 30]. 257 | 258 | Returns: 259 | The augmented version of `image`. 260 | """ 261 | augmentation_hparams = dict( 262 | translate_rel=0.4, 263 | translate_const=100, 264 | fill_value=fill_value, 265 | ) 266 | available_ops = [ 267 | 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 268 | 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', 269 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'SolarizeAdd'] 270 | 271 | for layer_num in range(num_layers): 272 | op_to_select = tf.random.uniform([], maxval=len(available_ops), dtype=tf.int32) 273 | random_magnitude = float(magnitude) 274 | with tf.name_scope('randaug_layer_{}'.format(layer_num)): 275 | for (i, op_name) in enumerate(available_ops): 276 | prob = tf.random.uniform([], minval=0.2, maxval=0.8, dtype=tf.float32) 277 | func, _, args = _parse_policy_info(op_name, prob, random_magnitude, augmentation_hparams) 278 | image = tf.cond( 279 | tf.equal(i, op_to_select), 280 | # pylint:disable=g-long-lambda 281 | lambda selected_func=func, selected_args=args: selected_func(image, *selected_args), 282 | # 
pylint:enable=g-long-lambda 283 | lambda: image) 284 | return image 285 | -------------------------------------------------------------------------------- /jeffnet/data/tf_image_ops.py: -------------------------------------------------------------------------------- 1 | """ TF image operations 2 | 3 | A collection of image operations from a variety of sources, intended for use by 4 | RandAug / AutoAug / SimCLR policies 5 | 6 | Original sources 7 | * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py 8 | * https://github.com/google-research/fixmatch/tree/master/imagenet/augment 9 | * https://github.com/tensorflow/addons/tree/v0.12.0/tensorflow_addons/image 10 | """ 11 | import math 12 | import tensorflow as tf 13 | from tensorflow_addons import image as tfi 14 | from tensorflow_addons.image.utils import unwrap, wrap 15 | from tensorflow_addons.utils.types import TensorLike, Number 16 | 17 | 18 | equalize = tfi.equalize 19 | cutout = tfi.random_cutout 20 | 21 | 22 | def to_float(x: TensorLike): 23 | # convert int image dtype to float, WITHOUT rescale to [0, 1) 24 | return tf.cast(x, dtype=tf.float32) 25 | 26 | 27 | def to_uint8(x: TensorLike, saturate=True): 28 | return tf.saturate_cast(x, tf.uint8) if saturate else tf.cast(x, tf.uint8) 29 | 30 | 31 | def blend(image1, image2, factor): 32 | """Blend image1 and image2 using 'factor'. 33 | 34 | A value of factor 0.0 means only image1 is used. 35 | A value of 1.0 means only image2 is used. A value between 0.0 and 36 | 1.0 means we linearly interpolate the pixel values between the two 37 | images. A value greater than 1.0 "extrapolates" the difference 38 | between the two pixel values, and we clip the results to values 39 | between 0 and 255. 40 | 41 | Args: 42 | image1: An image Tensor. 43 | image2: An image Tensor. 44 | factor: A floating point value above 0.0. 45 | 46 | Returns: 47 | A blended image Tensor. 
48 | """ 49 | image1 = tf.cast(image1, tf.float32) 50 | image2 = tf.cast(image2, tf.float32) 51 | return to_uint8(image1 + factor * (image2 - image1)) 52 | 53 | 54 | def invert(image): 55 | """Inverts the image pixels.""" 56 | return 255 - image 57 | 58 | 59 | def solarize(image, threshold=128): 60 | # For each pixel in the image, select the pixel 61 | # if the value is less than the threshold. 62 | # Otherwise, subtract 255 from the pixel. 63 | return tf.where(image < threshold, image, invert(image)) 64 | 65 | 66 | def solarize_add(image, addition=0, threshold=128): 67 | # For each pixel in the image less than threshold 68 | # we add 'addition' amount to it and then clip the 69 | # pixel value to be between 0 and 255. The value 70 | # of 'addition' is between -128 and 128. 71 | added_image = tf.cast(image, tf.int64) + addition 72 | added_image = to_uint8(added_image) 73 | return tf.where(image < threshold, added_image, image) 74 | 75 | 76 | def color(image, factor): 77 | """Equivalent of PIL Color.""" 78 | degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image)) 79 | return blend(degenerate, image, factor) 80 | 81 | 82 | def contrast(image, factor): 83 | """Equivalent of PIL Contrast.""" 84 | degenerate = tf.image.rgb_to_grayscale(image) 85 | # Cast before calling tf.histogram. 86 | degenerate = tf.cast(degenerate, tf.int32) 87 | 88 | # Compute the grayscale histogram, then compute the mean pixel value, 89 | # and create a constant image size of that value. Use that as the 90 | # blending degenerate target of the original image. 
91 | hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) 92 | mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 93 | degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean 94 | degenerate = tf.image.grayscale_to_rgb(to_uint8(degenerate)) 95 | return blend(degenerate, image, factor) 96 | 97 | 98 | def brightness(image, factor): 99 | """Equivalent of PIL Brightness.""" 100 | degenerate = tf.zeros_like(image) 101 | return blend(degenerate, image, factor) 102 | 103 | 104 | def sharpness(image, factor): 105 | """Implements Sharpness function from PIL using TF ops.""" 106 | orig_image = image 107 | image = to_float(image) 108 | image_channels = image.shape[-1] 109 | # Make image 4D for conv operation. 110 | image = tf.expand_dims(image, 0) 111 | # SMOOTH PIL Kernel. 112 | kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]) / 13. 113 | # Tile across channel dimension. 114 | kernel = tf.tile(kernel, [1, 1, image_channels, 1]) 115 | degenerate = tf.nn.depthwise_conv2d(image, kernel, strides=[1, 1, 1, 1], padding='VALID') 116 | degenerate = tf.squeeze(to_uint8(degenerate), [0]) 117 | 118 | # For the borders of the resulting image, fill in the values of the 119 | # original image. 120 | mask = tf.ones_like(degenerate) 121 | padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) 122 | padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) 123 | result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) 124 | 125 | # Blend the final result. 126 | return blend(result, orig_image, factor) 127 | 128 | 129 | def posterize(image, num_bits): 130 | """Equivalent of PIL Posterize.""" 131 | shift = 8 - num_bits 132 | return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) 133 | 134 | 135 | def posterize2(image, num_bits): 136 | """Reduces the number of bits used to represent an `image` 137 | for each color channel. 
138 | Args: 139 | image: An int or float tensor of shape `[height, width, num_channels]`. 140 | num_bits: A 0-D int tensor or integer value representing number of bits. 141 | Returns: 142 | A tensor with same shape and type as that of `image`. 143 | """ 144 | num_bits = tf.cast(num_bits, tf.int32) 145 | mask = tf.cast(2 ** (8 - num_bits) - 1, tf.uint8) 146 | mask = tf.bitwise.invert(mask) 147 | 148 | posterized_image = tf.bitwise.bitwise_and(image, mask) 149 | return posterized_image 150 | 151 | 152 | def rotate(image, degrees, fill_value, interpolation='BILINEAR'): 153 | """Rotates the image by degrees either clockwise or counterclockwise. 154 | 155 | Args: 156 | image: An image Tensor of type uint8. 157 | degrees: Float, a scalar angle in degrees to rotate all images by. If degrees is positive the image 158 | will be rotated clockwise otherwise it will be rotated counterclockwise. 159 | fill_value: A one or three value 1D tensor to fill empty pixels caused by the rotate operation. 160 | interpolation: Interpolation method 161 | Returns: 162 | The rotated version of image. 163 | """ 164 | # Convert from degrees to radians. 165 | degrees_to_radians = math.pi / 180.0 166 | radians = degrees * degrees_to_radians 167 | 168 | # In practice, we should randomize the rotation degrees by flipping 169 | # it negatively half the time, but that's done on 'degrees' outside 170 | # of the function. 
171 | image = tfi.rotate(wrap(image), radians, interpolation=interpolation) 172 | return unwrap(image, fill_value) 173 | 174 | 175 | def translate_x(image, pixels, fill_value): 176 | """Equivalent of PIL Translate in X dimension.""" 177 | image = tfi.translate(wrap(image), [-pixels, 0]) 178 | return unwrap(image, fill_value) 179 | 180 | 181 | def translate_y(image, pixels, fill_value): 182 | """Equivalent of PIL Translate in Y dimension.""" 183 | image = tfi.translate(wrap(image), [0, -pixels]) 184 | return unwrap(image, fill_value) 185 | 186 | 187 | def shear_x(image, level, fill_value): 188 | """Equivalent of PIL Shearing in X dimension.""" 189 | # Shear parallel to x axis is a projective transform 190 | # with a matrix form of: 191 | # [1 level 192 | # 0 1]. 193 | image = tfi.transform(wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) 194 | return unwrap(image, fill_value) 195 | 196 | 197 | def shear_y(image, level, fill_value): 198 | """Equivalent of PIL Shearing in Y dimension.""" 199 | # Shear parallel to y axis is a projective transform 200 | # with a matrix form of: 201 | # [1 0 202 | # level 1]. 203 | image = tfi.transform(wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) 204 | return unwrap(image, fill_value) 205 | 206 | 207 | def autotone(image): 208 | """Implements autotone (per channel autocontrast) 209 | Args: 210 | image: A 3D uint8 tensor. 211 | Returns: 212 | The image after it has had autocontrast applied to it and will be of type uint8. 213 | """ 214 | 215 | def scale_channel(chan): 216 | """Scale the 2D image using the autocontrast rule.""" 217 | # A possibly cheaper version can be done using cumsum/unique_with_counts 218 | # over the histogram values, rather than iterating over the entire image. 219 | # to compute mins and maxes. 220 | lo = to_float(tf.reduce_min(chan)) 221 | hi = to_float(tf.reduce_max(chan)) 222 | 223 | # Scale the image, making the lowest value 0 and the highest value 255. 
224 | def scale_values(im): 225 | scale = 255.0 / (hi - lo) 226 | offset = -lo * scale 227 | im = to_float(im) * scale + offset 228 | return to_uint8(im) 229 | 230 | result = tf.cond(hi > lo, lambda: scale_values(chan), lambda: chan) 231 | return result 232 | 233 | # Assumes RGB for now. Scales each channel independently and then stacks the result. 234 | s1 = scale_channel(image[:, :, 0]) 235 | s2 = scale_channel(image[:, :, 1]) 236 | s3 = scale_channel(image[:, :, 2]) 237 | image = tf.stack([s1, s2, s3], 2) 238 | return image 239 | 240 | 241 | def autocontrast(image): 242 | """ Normalizes `image` contrast by remapping the `image` histogram such 243 | that the brightest pixel becomes 1.0 (float) / 255 (unsigned int) and 244 | darkest pixel becomes 0. 245 | Args: 246 | image: An int or float tensor of shape `[height, width, num_channels]`. 247 | Returns: 248 | A tensor with same shape and type as that of `image`. 249 | """ 250 | min_val = tf.reduce_min(image, axis=[0, 1]) 251 | max_val = tf.reduce_max(image, axis=[0, 1]) 252 | norm_image = to_float(image - min_val) / to_float(max_val - min_val) 253 | return to_uint8(norm_image) 254 | 255 | 256 | def autocontrast_or_tone(image): 257 | """ 258 | """ 259 | choice = tf.cast(tf.floor(tf.random.uniform([], dtype=tf.float32) + 0.5), tf.bool) 260 | augmented_image = tf.cond( 261 | choice, 262 | lambda: autocontrast(image), 263 | lambda: autotone(image)) 264 | return augmented_image 265 | -------------------------------------------------------------------------------- /jeffnet/data/tf_imagenet_data.py: -------------------------------------------------------------------------------- 1 | """ Tensorflow ImageNet Data Pipeline 2 | 3 | This code is based on an earlier version of 4 | https://github.com/google/objax/blob/master/examples/image_classification/imagenet_resnet50_data.py 5 | 6 | Original copyrights below. Modifications by Ross Wightman. 
7 | """ 8 | 9 | # Copyright 2020 Google LLC 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # https://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | 23 | """Imagenet dataset reader. 24 | 25 | Code based on https://github.com/deepmind/dm-haiku/blob/master/examples/imagenet/dataset.py 26 | """ 27 | 28 | import enum 29 | from typing import Optional, Sequence, Tuple 30 | 31 | import jax 32 | import numpy as np 33 | import tensorflow.compat.v1 as tf 34 | import tensorflow_datasets as tfds 35 | 36 | MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255) 37 | STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255) 38 | 39 | IMAGE_SIZE = 224 40 | IMAGE_PADDING_FOR_CROP = 32 41 | 42 | 43 | class Split(enum.Enum): 44 | """Imagenet dataset split.""" 45 | TRAIN = 1 46 | TEST = 2 47 | 48 | @property 49 | def num_examples(self): 50 | return {Split.TRAIN: 1281167, Split.TEST: 50000}[self] 51 | 52 | 53 | def _to_tfds_split(split: Split) -> tfds.Split: 54 | """Returns the TFDS split appropriately sharded.""" 55 | if split == Split.TRAIN: 56 | return tfds.Split.TRAIN 57 | else: 58 | assert split == Split.TEST 59 | return tfds.Split.VALIDATION 60 | 61 | 62 | def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int, int]: 63 | """Returns [start, end) for the given shard index.""" 64 | assert shard_index < num_shards 65 | arange = np.arange(split.num_examples) 66 | shard_range = np.array_split(arange, num_shards)[shard_index] 67 | start, end = shard_range[0], (shard_range[-1] + 1) 68 
| return start, end 69 | 70 | 71 | def load( 72 | split: Split, 73 | is_training: bool, 74 | batch_dims: Sequence[int], 75 | image_size: int = IMAGE_SIZE, 76 | chw: bool = False, 77 | dataset_name='imagenet2012:5.0.0', 78 | mean: Optional[Tuple[float]] = None, 79 | std: Optional[Tuple[float]] = None, 80 | interpolation: str = 'bicubic', 81 | tfds_data_dir: Optional[str] = None, 82 | ): 83 | mean = MEAN_RGB if mean is None else mean 84 | std = STDDEV_RGB if std is None else std 85 | """Loads the given split of the dataset.""" 86 | if is_training: 87 | start, end = _shard(split, jax.host_id(), jax.host_count()) 88 | else: 89 | start, end = _shard(split, 0, 1) 90 | tfds_split = tfds.core.ReadInstruction(_to_tfds_split(split), from_=start, to=end, unit='abs') 91 | ds = tfds.load( 92 | dataset_name, 93 | split=tfds_split, 94 | decoders={'image': tfds.decode.SkipDecoding()}, 95 | data_dir=tfds_data_dir) 96 | 97 | total_batch_size = np.prod(batch_dims) 98 | 99 | options = ds.options() 100 | options.experimental_threading.private_threadpool_size = 16 101 | options.experimental_threading.max_intra_op_parallelism = 1 102 | if is_training: 103 | options.experimental_deterministic = False 104 | 105 | if is_training: 106 | ds = ds.repeat() 107 | ds = ds.shuffle(buffer_size=10 * total_batch_size, seed=0) 108 | else: 109 | if split.num_examples % total_batch_size != 0: 110 | raise ValueError(f'Test set size must be divisible by {total_batch_size}') 111 | num_batches = split.num_examples // total_batch_size 112 | 113 | interpolation = tf.image.ResizeMethod.BILINEAR if 'bilinear' in interpolation else tf.image.ResizeMethod.BICUBIC 114 | 115 | def preprocess(example): 116 | image = _preprocess_image( 117 | example['image'], is_training, image_size=image_size, mean=mean, std=std, interpolation=interpolation) 118 | if chw: 119 | image = tf.transpose(image, (2, 0, 1)) # transpose HWC image to CHW format 120 | label = tf.cast(example['label'], tf.int32) 121 | return {'images': image, 
'labels': label} 122 | 123 | ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE) 124 | 125 | for batch_size in reversed(batch_dims): 126 | ds = ds.batch(batch_size) 127 | 128 | ds = ds.prefetch(tf.data.experimental.AUTOTUNE) 129 | 130 | return tfds.as_numpy(ds), num_batches 131 | 132 | 133 | def normalize_image_for_view(image, mean=MEAN_RGB, std=STDDEV_RGB): 134 | """Normalizes dataset image into the format for viewing.""" 135 | image *= np.reshape(mean, (3, 1, 1)) 136 | image += np.reshape(std, (3, 1, 1)) 137 | image = np.transpose(image, (1, 2, 0)) 138 | return image.clip(0, 255).round().astype('uint8') 139 | 140 | 141 | def _preprocess_image( 142 | image_bytes: tf.Tensor, 143 | is_training: bool, 144 | image_size: int = IMAGE_SIZE, 145 | mean=MEAN_RGB, 146 | std=STDDEV_RGB, 147 | interpolation=tf.image.ResizeMethod.BICUBIC, 148 | ) -> tf.Tensor: 149 | """Returns processed and resized images.""" 150 | if is_training: 151 | image = _decode_and_random_crop(image_bytes, image_size=image_size) 152 | image = tf.image.random_flip_left_right(image) 153 | else: 154 | image = _decode_and_center_crop(image_bytes, image_size=image_size) 155 | assert image.dtype == tf.uint8 156 | # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without 157 | # clamping overshoots. This means values returned will be outside the range 158 | # [0.0, 255.0]. 
159 | image = tf.image.resize(image, [image_size, image_size], interpolation) 160 | image = _normalize_image(image, mean=mean, std=std) 161 | return image 162 | 163 | 164 | def _normalize_image(image: tf.Tensor, mean=MEAN_RGB, std=STDDEV_RGB) -> tf.Tensor: 165 | """Normalize the image to zero mean and unit variance.""" 166 | image -= tf.constant(mean, shape=[1, 1, 3], dtype=image.dtype) 167 | image /= tf.constant(std, shape=[1, 1, 3], dtype=image.dtype) 168 | return image 169 | 170 | 171 | def _distorted_bounding_box_crop( 172 | image_bytes: tf.Tensor, 173 | jpeg_shape: tf.Tensor, 174 | bbox: tf.Tensor, 175 | min_object_covered: float, 176 | aspect_ratio_range: Tuple[float, float], 177 | area_range: Tuple[float, float], 178 | max_attempts: int, 179 | ) -> tf.Tensor: 180 | """Generates cropped_image using one of the bboxes randomly distorted.""" 181 | bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box( 182 | jpeg_shape, 183 | bounding_boxes=bbox, 184 | min_object_covered=min_object_covered, 185 | aspect_ratio_range=aspect_ratio_range, 186 | area_range=area_range, 187 | max_attempts=max_attempts, 188 | use_image_if_no_bounding_boxes=True) 189 | 190 | # Crop the image to the specified bounding box. 
191 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 192 | target_height, target_width, _ = tf.unstack(bbox_size) 193 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 194 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 195 | return image 196 | 197 | 198 | def _decode_and_random_crop(image_bytes: tf.Tensor, image_size: int = 224) -> tf.Tensor: 199 | """Make a random crop of image.""" 200 | jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) 201 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 202 | image = _distorted_bounding_box_crop( 203 | image_bytes, 204 | jpeg_shape=jpeg_shape, 205 | bbox=bbox, 206 | min_object_covered=0.1, 207 | aspect_ratio_range=(3 / 4, 4 / 3), 208 | area_range=(0.08, 1.0), 209 | max_attempts=10) 210 | if tf.reduce_all(tf.equal(jpeg_shape, tf.shape(image))): 211 | # If the random crop failed fall back to center crop. 212 | image = _decode_and_center_crop(image_bytes, image_size=image_size, jpeg_shape=jpeg_shape) 213 | return image 214 | 215 | 216 | def _decode_and_center_crop( 217 | image_bytes: tf.Tensor, 218 | image_size: int = 224, 219 | jpeg_shape: Optional[tf.Tensor] = None, 220 | ) -> tf.Tensor: 221 | """Crops to center of image with padding then scales.""" 222 | if jpeg_shape is None: 223 | jpeg_shape = tf.image.extract_jpeg_shape(image_bytes) 224 | image_height = jpeg_shape[0] 225 | image_width = jpeg_shape[1] 226 | 227 | padded_center_crop_size = tf.cast( 228 | (tf.cast(image_size / (image_size + IMAGE_PADDING_FOR_CROP), tf.float32) * 229 | tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) 230 | 231 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 232 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 233 | crop_window = tf.stack([offset_height, offset_width, 234 | padded_center_crop_size, padded_center_crop_size]) 235 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, 
channels=3) 236 | return image 237 | -------------------------------------------------------------------------------- /jeffnet/data/tf_input_pipeline.py: -------------------------------------------------------------------------------- 1 | """ Tensorflow ImageNet Data Pipeline 2 | 3 | ImageNet data pipeline adapted from Flax examples. This is mostly redundant wrt 4 | tf_imagenet_data.py. However I wanted to get off the ground quickly with the 5 | Flax training example scripts and there are a few changes needed to use one in 6 | each use case. 7 | 8 | Eventually there will be one Tensorflow image pipeline + dataset factor that will 9 | support RandAug/AutoAug/AugMix and other datasets. 10 | 11 | Original copyrights below. Modifications by Ross Wightman. 12 | """ 13 | 14 | # Copyright 2020 The Flax Authors. 15 | # 16 | # Licensed under the Apache License, Version 2.0 (the "License"); 17 | # you may not use this file except in compliance with the License. 18 | # You may obtain a copy of the License at 19 | # 20 | # http://www.apache.org/licenses/LICENSE-2.0 21 | # 22 | # Unless required by applicable law or agreed to in writing, software 23 | # distributed under the License is distributed on an "AS IS" BASIS, 24 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 25 | # See the License for the specific language governing permissions and 26 | # limitations under the License. 27 | 28 | """ImageNet input pipeline. 
29 | """ 30 | from typing import Optional, Tuple 31 | 32 | import jax 33 | import tensorflow as tf 34 | import tensorflow_datasets as tfds 35 | from absl import logging 36 | 37 | from .tf_autoaugment import distort_image_with_randaugment 38 | from .tf_image_ops import to_float, to_uint8 39 | 40 | IMAGE_SIZE = 224 41 | CROP_PADDING = 32 42 | MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] 43 | STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] 44 | 45 | 46 | def distorted_bounding_box_crop( 47 | image_bytes, 48 | bbox, 49 | min_object_covered=0.1, 50 | aspect_ratio_range=(0.75, 1.33), 51 | area_range=(0.05, 1.0), 52 | max_attempts=100): 53 | """Generates cropped_image using one of the bboxes randomly distorted. 54 | 55 | See `tf.image.sample_distorted_bounding_box` for more documentation. 56 | 57 | Args: 58 | image_bytes: `Tensor` of binary image data. 59 | bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]` 60 | where each coordinate is [0, 1) and the coordinates are arranged 61 | as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole 62 | image. 63 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 64 | area of the image must contain at least this fraction of any bounding 65 | box supplied. 66 | aspect_ratio_range: An optional list of `float`s. The cropped area of the 67 | image must have an aspect ratio = width / height within this range. 68 | area_range: An optional list of `float`s. The cropped area of the image 69 | must contain a fraction of the supplied image within in this range. 70 | max_attempts: An optional `int`. Number of attempts at generating a cropped 71 | region of the image of the specified constraints. After `max_attempts` 72 | failures, return the entire image. 
73 | Returns: 74 | cropped image `Tensor` 75 | """ 76 | shape = tf.io.extract_jpeg_shape(image_bytes) 77 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 78 | shape, 79 | bounding_boxes=bbox, 80 | min_object_covered=min_object_covered, 81 | aspect_ratio_range=aspect_ratio_range, 82 | area_range=area_range, 83 | max_attempts=max_attempts, 84 | use_image_if_no_bounding_boxes=True) 85 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 86 | 87 | # Crop the image to the specified bounding box. 88 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 89 | target_height, target_width, _ = tf.unstack(bbox_size) 90 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 91 | image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 92 | return image 93 | 94 | 95 | def resize(image, image_size, interpolation=tf.image.ResizeMethod.BICUBIC, antialias=True): 96 | return tf.image.resize([image], [image_size, image_size], method=interpolation, antialias=antialias)[0] 97 | 98 | 99 | def at_least_x_are_equal(a, b, x): 100 | """At least `x` of `a` and `b` `Tensors` are equal.""" 101 | match = tf.equal(a, b) 102 | match = tf.cast(match, tf.int32) 103 | return tf.greater_equal(tf.reduce_sum(match), x) 104 | 105 | 106 | def decode_and_random_crop(image_bytes, image_size, interpolation): 107 | """Make a random crop of image_size.""" 108 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 109 | image = distorted_bounding_box_crop( 110 | image_bytes, 111 | bbox, 112 | min_object_covered=0.1, 113 | aspect_ratio_range=(3. / 4, 4. 
/ 3.), 114 | area_range=(0.08, 1.0), 115 | max_attempts=10) 116 | original_shape = tf.io.extract_jpeg_shape(image_bytes) 117 | bad = at_least_x_are_equal(original_shape, tf.shape(image), 3) 118 | 119 | image = tf.cond( 120 | bad, 121 | lambda: decode_and_center_crop(image_bytes, image_size, interpolation), 122 | lambda: resize(image, image_size, interpolation)) 123 | 124 | return image 125 | 126 | 127 | def decode_and_center_crop(image_bytes, image_size, interpolation): 128 | """Crops to center of image with padding then scales image_size.""" 129 | shape = tf.io.extract_jpeg_shape(image_bytes) 130 | image_height = shape[0] 131 | image_width = shape[1] 132 | 133 | padded_center_crop_size = tf.cast( 134 | (image_size / (image_size + CROP_PADDING)) * tf.cast(tf.minimum(image_height, image_width), tf.float32), 135 | tf.int32) 136 | 137 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 138 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 139 | crop_window = tf.stack([offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]) 140 | image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window) 141 | image = resize(image, image_size, interpolation) 142 | 143 | return image 144 | 145 | 146 | def normalize_image(image, mean=MEAN_RGB, std=STDDEV_RGB): 147 | image -= tf.constant(mean, shape=[1, 1, 3], dtype=image.dtype) 148 | image /= tf.constant(std, shape=[1, 1, 3], dtype=image.dtype) 149 | return image 150 | 151 | 152 | def preprocess_for_train( 153 | image_bytes, 154 | dtype=tf.float32, 155 | image_size=IMAGE_SIZE, 156 | mean=MEAN_RGB, 157 | std=STDDEV_RGB, 158 | interpolation=tf.image.ResizeMethod.BICUBIC, 159 | augment_name=None, 160 | randaug_num_layers=None, 161 | randaug_magnitude=None, 162 | ): 163 | """Preprocesses the given image for training. 164 | 165 | Args: 166 | image_bytes: `Tensor` representing an image binary of arbitrary size. 167 | dtype: data type of the image. 168 | image_size: image size. 
169 | 170 | Returns: 171 | A preprocessed image `Tensor`. 172 | """ 173 | image = decode_and_random_crop(image_bytes, image_size, interpolation) 174 | image = tf.image.random_flip_left_right(image) 175 | image = tf.reshape(image, [image_size, image_size, 3]) 176 | 177 | if augment_name: 178 | logging.info('Apply AutoAugment policy %s', augment_name) 179 | fill_value = [int(round(v)) for v in MEAN_RGB] 180 | # if augment_name == 'autoaugment': 181 | # logging.info('Apply AutoAugment policy %s', augment_name) 182 | # image = distort_image_with_autoaugment(image, 'v0') 183 | image = to_uint8(image, saturate=False) 184 | if augment_name == 'randaugment': 185 | image = distort_image_with_randaugment( 186 | image, randaug_num_layers, randaug_magnitude, fill_value=fill_value) 187 | else: 188 | raise ValueError('Invalid value for augment_name: %s' % augment_name) 189 | image = to_float(image) # float32, [0., 255.) 190 | 191 | image = normalize_image(image, mean=mean, std=std) 192 | image = tf.image.convert_image_dtype(image, dtype=dtype) 193 | return image 194 | 195 | 196 | def preprocess_for_eval( 197 | image_bytes, 198 | dtype=tf.float32, 199 | image_size=IMAGE_SIZE, 200 | mean=MEAN_RGB, 201 | std=STDDEV_RGB, 202 | interpolation=tf.image.ResizeMethod.BICUBIC, 203 | ): 204 | """Preprocesses the given image for evaluation. 205 | 206 | Args: 207 | image_bytes: `Tensor` representing an image binary of arbitrary size. 208 | dtype: data type of the image. 209 | image_size: image size. 210 | 211 | Returns: 212 | A preprocessed image `Tensor`. 
213 | """ 214 | image = decode_and_center_crop(image_bytes, image_size, interpolation) 215 | image = tf.reshape(image, [image_size, image_size, 3]) 216 | image = normalize_image(image, mean=mean, std=std) 217 | image = tf.image.convert_image_dtype(image, dtype=dtype) 218 | return image 219 | 220 | 221 | def create_split( 222 | dataset_builder: tfds.core.DatasetBuilder, 223 | batch_size: int, 224 | train: bool = True, 225 | half_precision: bool = False, 226 | image_size: int = IMAGE_SIZE, 227 | mean: Optional[Tuple[float]] = None, 228 | std: Optional[Tuple[float]] = None, 229 | interpolation: str = 'bicubic', 230 | augment_name: Optional[str] = None, 231 | randaug_num_layers: Optional[int] = None, 232 | randaug_magnitude: Optional[int] = None, 233 | cache: bool = False, 234 | no_repeat: bool = False, 235 | ): 236 | """Creates a split from the ImageNet dataset using TensorFlow Datasets. 237 | 238 | Args: 239 | dataset_builder: TFDS dataset builder for ImageNet. 240 | batch_size: the batch size returned by the data pipeline. 241 | train: Whether to load the train or evaluation split. 242 | half_precision: convert image datatype to half-precision 243 | image_size: The target size of the images (default: 224). 244 | mean: image dataset mean 245 | std: image dataset std-dev 246 | interpolation: interpolation method to use for image resize (default: 'bicubic') 247 | cache: Whether to cache the dataset (default: False). 248 | no_repeat: disable repeat iter for evaluation 249 | Returns: 250 | A `tf.data.Dataset`. 
251 | """ 252 | mean = mean or MEAN_RGB 253 | std = std or STDDEV_RGB 254 | interpolation = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR 255 | platform = jax.local_devices()[0].platform 256 | if half_precision: 257 | if platform == 'tpu': 258 | input_dtype = tf.bfloat16 259 | else: 260 | input_dtype = tf.float16 261 | else: 262 | input_dtype = tf.float32 263 | 264 | if train: 265 | data_size = dataset_builder.info.splits['train'].num_examples 266 | split = 'train' 267 | else: 268 | data_size = dataset_builder.info.splits['validation'].num_examples 269 | split = 'validation' 270 | split_size = data_size // jax.host_count() 271 | start = jax.host_id() * split_size 272 | split = split + '[{}:{}]'.format(start, start + split_size) 273 | 274 | def _decode_example(example): 275 | if train: 276 | image = preprocess_for_train( 277 | example['image'], input_dtype, image_size, mean, std, interpolation, 278 | augment_name=augment_name, 279 | randaug_num_layers=randaug_num_layers, 280 | randaug_magnitude=randaug_magnitude) 281 | else: 282 | image = preprocess_for_eval(example['image'], input_dtype, image_size, mean, std, interpolation) 283 | return {'image': image, 'label': example['label']} 284 | 285 | ds = dataset_builder.as_dataset( 286 | split=split, 287 | decoders={ 288 | 'image': tfds.decode.SkipDecoding() 289 | } 290 | ) 291 | ds.options().experimental_threading.private_threadpool_size = 16 292 | ds.options().experimental_threading.max_intra_op_parallelism = 1 293 | 294 | if cache: 295 | ds = ds.cache() 296 | 297 | if train: 298 | ds = ds.repeat() 299 | ds = ds.shuffle(16 * batch_size, seed=0) 300 | 301 | ds = ds.map(_decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE) 302 | ds = ds.batch(batch_size, drop_remainder=True) 303 | 304 | if not train and not no_repeat: 305 | ds = ds.repeat() 306 | 307 | ds = ds.prefetch(10) 308 | 309 | return ds 310 | 
--------------------------------------------------------------------------------
/jeffnet/linen/__init__.py:
--------------------------------------------------------------------------------
from .efficientnet_linen import EfficientNet, create_model
from .ema_state import EmaState
--------------------------------------------------------------------------------
/jeffnet/linen/blocks_linen.py:
--------------------------------------------------------------------------------
""" EfficientNet, MobileNetV3, etc Blocks for Flax Linen

Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Any, Callable, Union, Optional

from flax import linen as nn
import jax.numpy as jnp

from jeffnet.common.block_utils import *
from .layers import conv2d, batchnorm2d, get_act_fn, linear, DropPath, Dropout, MixedConv

Dtype = Any
ModuleDef = Any


def create_conv(features, kernel_size, conv_layer=None, **kwargs):
    """ Select a convolution implementation based on arguments
    Creates and returns one of Conv, MixedConv, or CondConv (TODO)
    """
    conv_layer = conv2d if conv_layer is None else conv_layer
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
        assert 'groups' not in kwargs  # MixedConv groups are defined by kernel list
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        m = MixedConv(features, kernel_size, conv_layer=conv_layer, **kwargs)
    else:
        depthwise = kwargs.pop('depthwise', False)
        # depthwise conv == grouped conv with groups == features
        groups = features if depthwise else kwargs.pop('groups', 1)
        # if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
        #     m = CondConv(features, kernel_size, groups=groups, conv_layer=conv_layer, **kwargs)
        # else:
        m = conv_layer(features, kernel_size, groups=groups, **kwargs)
    return m


class SqueezeExcite(nn.Module):
    """Squeeze-and-Excitation channel attention module."""
    num_features: int  # features at input to containing block
    block_features: Optional[int] = None  # input feature count of containing block
    se_ratio: float = 0.25
    divisor: int = 1
    reduce_from_block: bool = True  # calc se reduction from containing block's input features

    dtype: Dtype = jnp.float32
    conv_layer: ModuleDef = conv2d
    act_fn: Callable = nn.relu
    bound_act_fn: Optional[Callable] = None  # override the passed in act_fn from parent with a bound fn
    gate_fn: Callable = nn.sigmoid

    @nn.compact
    def __call__(self, x):
        # Global average pool in float32 for numerical stability under mixed precision.
        x_se = jnp.asarray(x, jnp.float32)
        x_se = x_se.mean((1, 2), keepdims=True)
        x_se = jnp.asarray(x_se, self.dtype)
        base_features = self.block_features if self.block_features and self.reduce_from_block else self.num_features
        reduce_features: int = make_divisible(base_features * self.se_ratio, self.divisor)
        act_fn = self.bound_act_fn if self.bound_act_fn is not None else self.act_fn
        x_se = self.conv_layer(reduce_features, 1, stride=1, bias=True, name='reduce')(x_se)
        x_se = act_fn(x_se)
        x_se = self.conv_layer(self.num_features, 1, stride=1, bias=True, name='expand')(x_se)
        # Gate (sigmoid by default) re-weights the input channels.
        return x * self.gate_fn(x_se)


class ConvBnAct(nn.Module):
    """Convolution + BatchNorm + Activation block."""
    out_features: int
    in_features: Optional[int] = None  # not used, currently for generic args support
    kernel_size: int = 3
    stride: int = 1
    dilation: int = 1
    pad_type: str = 'LIKE'

    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = batchnorm2d
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        x = self.conv_layer(
            self.out_features, self.kernel_size, stride=self.stride,
            dilation=self.dilation, padding=self.pad_type, name='conv')(x)
        x = self.norm_layer(name='bn')(x, training=training)
        x = self.act_fn(x)
        return x


class DepthwiseSeparable(nn.Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
    """

    in_features: int
    out_features: int
    dw_kernel_size: int = 3
    pw_kernel_size: int = 1
    stride: int = 1
    dilation: int = 1
    pad_type: str = 'LIKE'
    noskip: bool = False  # disable the residual shortcut even when shapes allow it
    pw_act: bool = False  # apply activation after the pointwise conv
    se_ratio: float = 0.
    drop_path_rate: float = 0.

    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = batchnorm2d
    se_layer: ModuleDef = SqueezeExcite
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        shortcut = x

        x = create_conv(
            self.in_features, self.dw_kernel_size, stride=self.stride, dilation=self.dilation,
            padding=self.pad_type, depthwise=True, conv_layer=self.conv_layer, name='conv_dw')(x)
        x = self.norm_layer(name='bn_dw')(x, training=training)
        x = self.act_fn(x)

        if self.se_layer is not None and self.se_ratio > 0:
            x = self.se_layer(
                num_features=self.in_features, se_ratio=self.se_ratio,
                conv_layer=self.conv_layer, act_fn=self.act_fn, name='se')(x)

        x = create_conv(
            self.out_features, self.pw_kernel_size, padding=self.pad_type,
            conv_layer=self.conv_layer, name='conv_pw')(x)
        x = self.norm_layer(name='bn_pw')(x, training=training)
        if self.pw_act:
            x = self.act_fn(x)

        # Residual shortcut only when spatial + channel dims are unchanged.
        if (self.stride == 1 and self.in_features == self.out_features) and not self.noskip:
            x = DropPath(self.drop_path_rate)(x, training=training)
            x = x + shortcut
        return x


class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE and CondConv routing"""

    in_features: int
    out_features: int
    exp_kernel_size: int = 1
    dw_kernel_size: int = 3
    pw_kernel_size: int = 1
    stride: int = 1
    dilation: int = 1
    pad_type: str = 'LIKE'
    noskip: bool = False
    exp_ratio: float = 1.0  # channel expansion factor for the hidden depthwise stage
    se_ratio: float = 0.
    drop_path_rate: float = 0.

    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = batchnorm2d
    se_layer: ModuleDef = SqueezeExcite
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        shortcut = x

        features = make_divisible(self.in_features * self.exp_ratio)

        # Point-wise expansion (skipped when there is no expansion)
        if self.exp_ratio > 1.:
            x = create_conv(
                features, self.exp_kernel_size, padding=self.pad_type, conv_layer=self.conv_layer, name='conv_exp')(x)
            x = self.norm_layer(name='bn_exp')(x, training=training)
            x = self.act_fn(x)

        x = create_conv(
            features, self.dw_kernel_size, stride=self.stride, dilation=self.dilation,
            padding=self.pad_type, depthwise=True, conv_layer=self.conv_layer, name='conv_dw')(x)
        x = self.norm_layer(name='bn_dw')(x, training=training)
        x = self.act_fn(x)

        if self.se_layer is not None and self.se_ratio > 0:
            x = self.se_layer(
                num_features=features, block_features=self.in_features, se_ratio=self.se_ratio,
                conv_layer=self.conv_layer, act_fn=self.act_fn, name='se')(x)

        # Point-wise linear projection (no activation)
        x = create_conv(
            self.out_features, self.pw_kernel_size, padding=self.pad_type,
            conv_layer=self.conv_layer, name='conv_pwl')(x)
        x = self.norm_layer(name='bn_pwl')(x, training=training)

        if (self.stride == 1 and self.in_features == self.out_features) and not self.noskip:
            x = DropPath(self.drop_path_rate)(x, training=training)
            x = x + shortcut
        return x


class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride"""
    in_features: int
    out_features: int
    exp_kernel_size: int = 1
    dw_kernel_size: int = 3
    pw_kernel_size: int = 1
    stride: int = 1
    dilation: int = 1
    pad_type: str = 'LIKE'
    noskip: bool = False
    exp_ratio: float = 1.0
    se_ratio: float = 0.
    drop_path_rate: float = 0.

    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = batchnorm2d
    se_layer: ModuleDef = SqueezeExcite
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        shortcut = x

        # Unlike other blocks, not using the arch def for in_features since it's not reliable for Edge
        features = make_divisible(x.shape[-1] * self.exp_ratio)

        # Point-wise expansion
        x = create_conv(
            features, self.exp_kernel_size, padding=self.pad_type, conv_layer=self.conv_layer, name='conv_exp')(x)
        x = self.norm_layer(name='bn_exp')(x, training=training)
        x = self.act_fn(x)

        if self.se_layer is not None and self.se_ratio > 0:
            x = self.se_layer(
                num_features=features, block_features=self.in_features, se_ratio=self.se_ratio,
                conv_layer=self.conv_layer, act_fn=self.act_fn, name='se')(x)

        x = create_conv(
            self.out_features, self.pw_kernel_size, stride=self.stride, dilation=self.dilation,
            padding=self.pad_type, conv_layer=self.conv_layer, name='conv_pwl')(x)
        x = self.norm_layer(name='bn_pwl')(x, training=training)

        if (self.stride == 1 and self.in_features == self.out_features) and not self.noskip:
            x = DropPath(self.drop_path_rate)(x, training=training)
            x = x + shortcut
        return x


class Head(nn.Module):
    """ Standard Head from EfficientNet, MixNet, MNasNet, MobileNetV2, etc. """
    num_features: int
    num_classes: int = 1000
    global_pool: str = 'avg'  # FIXME support diff pooling
    drop_rate: float = 0.

    dtype: Dtype = jnp.float32
    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = batchnorm2d
    linear_layer: ModuleDef = linear
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        x = self.conv_layer(self.num_features, 1, name='conv_pw')(x)
        x = self.norm_layer(name='bn')(x, training=training)
        x = self.act_fn(x)
        if self.global_pool == 'avg':
            # pool in float32 for mixed-precision stability
            x = jnp.asarray(x, jnp.float32)
            x = x.mean((1, 2))
            x = jnp.asarray(x, self.dtype)
        x = Dropout(rate=self.drop_rate)(x, training=training)
        if self.num_classes > 0:
            x = self.linear_layer(self.num_classes, bias=True, name='classifier')(x)
        return x


class EfficientHead(nn.Module):
    """ EfficientHead for MobileNetV3. """
    num_features: int
    num_classes: int = 1000
    global_pool: str = 'avg'  # FIXME support diff pooling
    drop_rate: float = 0.

    dtype: Dtype = jnp.float32
    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = None  # ignored, to keep calling code clean
    linear_layer: ModuleDef = linear
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        # Pool BEFORE the final pw conv (v3 head ordering); no norm layer here.
        if self.global_pool == 'avg':
            x = jnp.asarray(x, jnp.float32)
            x = x.mean((1, 2), keepdims=True)
            x = jnp.asarray(x, self.dtype)
        x = self.conv_layer(self.num_features, 1, bias=True, name='conv_pw')(x)
        x = self.act_fn(x)
        x = Dropout(rate=self.drop_rate)(x, training=training)
        if self.num_classes > 0:
            x = self.linear_layer(self.num_classes, bias=True, name='classifier')(x)
        return x


def chan_to_features(kwargs):
    """Rename PyTorch-style in_chs/out_chs arch-def keys to in/out_features."""
    in_chs = kwargs.pop('in_chs', None)
    if in_chs is not None:
        kwargs['in_features'] = in_chs
    out_chs = kwargs.pop('out_chs', None)
    if out_chs is not None:
        kwargs['out_features'] = out_chs
    return kwargs


class BlockFactory:
    """Maps arch-def block type names to the Linen block modules above."""

    @staticmethod
    def CondConv(stage_idx, block_idx, **block_args):
        assert False, "Not currently impl"

    @staticmethod
    def InvertedResidual(stage_idx, block_idx, **block_args):
        block_args = chan_to_features(block_args)
        return InvertedResidual(**block_args, name=f'blocks_{stage_idx}_{block_idx}')

    @staticmethod
    def DepthwiseSeparable(stage_idx, block_idx, **block_args):
        block_args = chan_to_features(block_args)
        return DepthwiseSeparable(**block_args, name=f'blocks_{stage_idx}_{block_idx}')

    @staticmethod
    def EdgeResidual(stage_idx, block_idx, **block_args):
        block_args = chan_to_features(block_args)
        block_args.pop('fake_in_chs')  # not needed for Linen @nn.compact defs, we can access the real in_features
        return EdgeResidual(**block_args, name=f'blocks_{stage_idx}_{block_idx}')

    @staticmethod
    def ConvBnAct(stage_idx, block_idx, **block_args):
        block_args.pop('drop_path_rate', None)
        block_args.pop('se_layer', None)
        block_args = chan_to_features(block_args)
        return ConvBnAct(**block_args, name=f'blocks_{stage_idx}_{block_idx}')

    @staticmethod
    def get_act_fn(act_fn: Union[str, Callable]):
        return get_act_fn(act_fn) if isinstance(act_fn, str) else act_fn
--------------------------------------------------------------------------------
/jeffnet/linen/efficientnet_linen.py:
--------------------------------------------------------------------------------
""" EfficientNet (Flax Linen) Model and Factory

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import re
from typing import Any, Callable, Sequence, Dict
from functools import partial

from flax import linen as nn
import flax.linen.initializers as initializers
import jax
import jax.numpy as jnp

from jeffnet.common import round_features, get_model_cfg, EfficientNetBuilder
from .helpers import load_pretrained
from .layers import conv2d, linear, batchnorm2d, get_act_fn
from .blocks_linen import ConvBnAct, SqueezeExcite, BlockFactory, Head, EfficientHead

ModuleDef = Any
Dtype = Any
# Initializers matching the reference EfficientNet implementation.
effnet_normal = partial(initializers.variance_scaling, 2.0, "fan_out", "normal")
effnet_uniform = partial(initializers.variance_scaling, 1.0/3, "fan_out", "uniform")


class EfficientNet(nn.Module):
    """ EfficientNet (and other MBConvNets)
      * EfficientNet B0-B8, L2
      * EfficientNet-EdgeTPU
      * EfficientNet-Lite
      * MixNet S, M, L, XL
      * MobileNetV3
      * MobileNetV2
      * MnasNet A1, B1, and small
      * FBNet C
      * Single-Path NAS Pixel1
    """

    # model config
    block_defs: Sequence[Sequence[Dict]]
    stem_size: int = 32
    feat_multiplier: float = 1.0
    feat_divisor: int = 8
    feat_min: int = None  # effectively Optional[int]; None disables the minimum
    fix_stem: bool = False
    pad_type: str = 'LIKE'
    output_stride: int = 32

    # classifier / head config
    efficient_head: bool = False
    num_classes: int = 1000
    num_features: int = 1280
    global_pool: str = 'avg'

    # pretrained / data config
    default_cfg: Dict = None

    # regularization
    drop_rate: float = 0.
    drop_path_rate: float = 0.

    dtype: Dtype = jnp.float32
    conv_layer: ModuleDef = conv2d
    norm_layer: ModuleDef = batchnorm2d
    se_layer: ModuleDef = SqueezeExcite
    act_fn: Callable = nn.relu

    @nn.compact
    def __call__(self, x, training: bool):
        # add dtype binding to layers
        # FIXME is there better way to handle dtype? Passing dtype to all child Modules also seems messy...
        lkwargs = dict(
            conv_layer=partial(self.conv_layer, dtype=self.dtype, kernel_init=effnet_normal()),
            norm_layer=partial(self.norm_layer, dtype=self.dtype),
            act_fn=self.act_fn)
        se_layer = partial(self.se_layer, dtype=self.dtype)
        linear_layer = partial(linear, dtype=self.dtype, kernel_init=effnet_uniform())

        stem_features = self.stem_size
        if not self.fix_stem:
            stem_features = round_features(self.stem_size, self.feat_multiplier, self.feat_divisor, self.feat_min)
        x = ConvBnAct(
            out_features=stem_features, kernel_size=3, stride=2, pad_type=self.pad_type,
            **lkwargs, name='stem')(x, training=training)

        # Builder expands block_defs into per-stage lists of block modules.
        blocks = EfficientNetBuilder(
            stem_features, self.block_defs, BlockFactory(),
            feat_multiplier=self.feat_multiplier, feat_divisor=self.feat_divisor, feat_min=self.feat_min,
            output_stride=self.output_stride, pad_type=self.pad_type, se_layer=se_layer, **lkwargs,
            drop_path_rate=self.drop_path_rate)()
        for stage in blocks:
            for block in stage:
                x = block(x, training=training)

        head_layer = EfficientHead if self.efficient_head else Head
        x = head_layer(
            num_features=self.num_features, num_classes=self.num_classes, drop_rate=self.drop_rate,
            **lkwargs, dtype=self.dtype, linear_layer=linear_layer, name='head')(x, training=training)
        return x


def _filter(state_dict):
    """ convert state dict keys from pytorch style origins to flax linen """
    out = {}
    p_blocks = re.compile(r'blocks\.(\d)\.(\d)')
    p_bn_scale = re.compile(r'bn(\w*)\.weight')
    for k, v in state_dict.items():
        k = p_blocks.sub(r'blocks_\1_\2', k)  # blocks.1.2 -> blocks_1_2
        k = p_bn_scale.sub(r'bn\1.scale', k)  # bn weight is 'scale' in Linen
        k = k.replace('running_mean', 'mean')
        k = k.replace('running_var', 'var')
        k = k.replace('.weight', '.kernel')
        out[k] = v
    return out


def create_model(variant, pretrained=False, rng=None, input_shape=None, dtype=jnp.float32, **kwargs):
    """Create an EfficientNet-family model and its initialized variables.

    Args:
        variant: model config name resolved via get_model_cfg.
        pretrained: load pretrained weights from the cfg url when True.
        rng: optional PRNGKey for init; a fixed key (0) is used when None.
        input_shape: optional CHW input size override (cfg default otherwise).
        dtype: compute dtype for the model.
        **kwargs: overrides merged into the generated model args.

    Returns:
        (model, variables) tuple.
    """
    model_cfg = get_model_cfg(variant)
    model_args = model_cfg['arch_fn'](variant, **model_cfg['arch_cfg'])
    model_args.update(kwargs)

    # resolve some special layers and their arguments
    se_args = model_args.pop('se_cfg', {})  # not consumable by model
    if 'se_layer' not in model_args:
        if 'bound_act_fn' in se_args:
            se_args['bound_act_fn'] = get_act_fn(se_args['bound_act_fn'])
        if 'gate_fn' in se_args:
            se_args['gate_fn'] = get_act_fn(se_args['gate_fn'])
        model_args['se_layer'] = partial(SqueezeExcite, **se_args)

    bn_args = model_args.pop('bn_cfg')  # not consumable by model
    if 'norm_layer' not in model_args:
        model_args['norm_layer'] = partial(batchnorm2d, **bn_args)

    model_args['act_fn'] = get_act_fn(model_args.pop('act_fn', 'relu'))  # convert str -> fn

    model = EfficientNet(dtype=dtype, default_cfg=model_cfg['default_cfg'], **model_args)

    rng = jax.random.PRNGKey(0) if rng is None else rng
    params_rng, dropout_rng = jax.random.split(rng)
    input_shape = model_cfg['default_cfg']['input_size'] if input_shape is None else input_shape
    input_shape = (1, input_shape[1], input_shape[2], input_shape[0])  # CHW -> HWC by default

    # FIXME is jiting the init worthwhile for my usage?
    # @jax.jit
    # def init(*args):
    #     return model.init(*args, training=True)

    variables = model.init(
        {'params': params_rng, 'dropout': dropout_rng},
        jnp.ones(input_shape, dtype=dtype),
        training=False)

    if pretrained:
        variables = load_pretrained(variables, default_cfg=model.default_cfg, filter_fn=_filter)

    return model, variables
--------------------------------------------------------------------------------
/jeffnet/linen/ema_state.py:
--------------------------------------------------------------------------------
import flax
import flax.serialization as serialization
import flax.struct as struct
import jax

from typing import Any


@struct.dataclass
class EmaState:
    """Exponential moving average of model variables; decay == 0 disables it."""
    decay: float = struct.field(pytree_node=False, default=0.)
    variables: flax.core.FrozenDict[str, Any] = None

    @staticmethod
    def create(decay, variables):
        """Initialize ema state"""
        if decay == 0.:
            # default state == disabled
            return EmaState()
        # identity tree_map produces an independent pytree copy of the variables
        ema_variables = jax.tree_map(lambda x: x, variables)
        return EmaState(decay, ema_variables)

    def update(self, new_variables):
        """Fold new variables into the EMA: ema = decay * ema + (1 - decay) * new."""
        if self.decay == 0.:
            return self.replace(variables=None)
        new_ema_variables = jax.tree_multimap(
            lambda ema, p: ema * self.decay + (1. - self.decay) * p, self.variables, new_variables)
        return self.replace(variables=new_ema_variables)
--------------------------------------------------------------------------------
/jeffnet/linen/helpers.py:
--------------------------------------------------------------------------------
""" Pretrained State Dict Helpers

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from flax.core import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from jeffnet.common import load_state_dict_from_url, split_state_dict, load_state_dict


def load_pretrained(variables, url='', default_cfg=None, filter_fn=None):
    """Load pretrained weights (from url or default_cfg['url']) into variables."""
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    state_dict = load_state_dict_from_url(url, transpose=True)

    source_params, source_state = split_state_dict(state_dict)
    if filter_fn is not None:
        # filter after split as we may have modified the split criteria (ie bn running vars)
        source_params = filter_fn(source_params)
        source_state = filter_fn(source_state)

    # FIXME better way to do this?
23 | var_unfrozen = unfreeze(variables) 24 | missing_keys = [] 25 | flat_params = flatten_dict(var_unfrozen['params']) 26 | flat_param_keys = set() 27 | for k, v in flat_params.items(): 28 | flat_k = '.'.join(k) 29 | if flat_k in source_params: 30 | assert flat_params[k].shape == v.shape 31 | flat_params[k] = source_params[flat_k] 32 | else: 33 | missing_keys.append(flat_k) 34 | flat_param_keys.add(flat_k) 35 | unexpected_keys = list(set(source_params.keys()).difference(flat_param_keys)) 36 | params = freeze(unflatten_dict(flat_params)) 37 | 38 | flat_state = flatten_dict(var_unfrozen['batch_stats']) 39 | flat_state_keys = set() 40 | for k, v in flat_state.items(): 41 | flat_k = '.'.join(k) 42 | if flat_k in source_state: 43 | assert flat_state[k].shape == v.shape 44 | flat_state[k] = source_state[flat_k] 45 | else: 46 | missing_keys.append(flat_k) 47 | flat_state_keys.add(flat_k) 48 | unexpected_keys.extend(list(set(source_state.keys()).difference(flat_state_keys))) 49 | batch_stats = freeze(unflatten_dict(flat_state)) 50 | 51 | if missing_keys: 52 | print(f' WARNING: {len(missing_keys)} keys missing while loading state_dict. {str(missing_keys)}') 53 | if unexpected_keys: 54 | print(f' WARNING: {len(unexpected_keys)} unexpected keys found while loading state_dict. 
{str(unexpected_keys)}') 55 | 56 | return dict(params=params, batch_stats=batch_stats) 57 | -------------------------------------------------------------------------------- /jeffnet/linen/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import get_act_fn 2 | from .linear import conv2d, linear 3 | from .mixed_conv import MixedConv 4 | from .normalization import batchnorm2d 5 | from .stochastic import Dropout, DropPath, drop_path 6 | -------------------------------------------------------------------------------- /jeffnet/linen/layers/activations.py: -------------------------------------------------------------------------------- 1 | """ Activation Factory 2 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 3 | """ 4 | import flax.linen as nn 5 | import jax.nn as jnn 6 | from functools import partial 7 | 8 | _ACT_FN = dict( 9 | relu=nn.relu, 10 | relu6=jnn.relu6, 11 | leaky_relu=nn.leaky_relu, 12 | gelu=nn.gelu, 13 | elu=nn.elu, 14 | softplus=nn.softplus, 15 | silu=nn.swish, 16 | swish=nn.swish, 17 | sigmoid=nn.sigmoid, 18 | tanh=nn.tanh, 19 | hard_silu=jnn.hard_silu, 20 | hard_swish=jnn.hard_silu, 21 | hard_sigmoid=jnn.hard_sigmoid, 22 | hard_tanh=jnn.hard_tanh, 23 | ) 24 | 25 | 26 | def get_act_fn(name='relu', **kwargs): 27 | name = name.lower() 28 | assert name in _ACT_FN 29 | act_fn = _ACT_FN[name] 30 | if kwargs: 31 | act_fn = partial(act_fn, **kwargs) 32 | return act_fn 33 | 34 | -------------------------------------------------------------------------------- /jeffnet/linen/layers/linear.py: -------------------------------------------------------------------------------- 1 | """ Linear / Conv Layer Wrappers 2 | 3 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 4 | """ 5 | 6 | from typing import Any, Callable, Sequence, Optional, Tuple, Union 7 | 8 | import flax.linen as nn 9 | import flax.linen.initializers as 
initializers 10 | 11 | import jax.numpy as jnp 12 | 13 | from jeffnet.common import get_like_padding 14 | from jeffnet.utils.to_tuple import to_tuple 15 | 16 | PRNGKey = Any 17 | Shape = Tuple[int] 18 | Dtype = Any # this could be a real type? 19 | Array = Any 20 | 21 | default_kernel_init = initializers.kaiming_normal() 22 | 23 | 24 | def conv2d( 25 | features: int, 26 | kernel_size: int, 27 | stride: Optional[int] = None, 28 | padding: Union[str, Tuple[int, int]] = 0, 29 | dilation: Optional[int] = None, 30 | groups: int = 1, 31 | bias: bool = False, 32 | dtype: Dtype = jnp.float32, 33 | precision: Any = None, 34 | name: Optional[str] = None, 35 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init, 36 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros): 37 | 38 | stride = stride or 1 39 | dilation = dilation or 1 40 | if isinstance(padding, str): 41 | if padding == 'LIKE': 42 | padding = get_like_padding(kernel_size, stride, dilation) 43 | padding = to_tuple(padding, 2) 44 | padding = [padding, padding] 45 | else: 46 | padding = to_tuple(padding, 2) 47 | padding = [padding, padding] 48 | return nn.Conv( 49 | features=features, 50 | kernel_size=to_tuple(kernel_size, 2), 51 | strides=to_tuple(stride, 2), 52 | padding=padding, 53 | kernel_dilation=to_tuple(dilation, 2), 54 | feature_group_count=groups, 55 | use_bias=bias, 56 | dtype=dtype, 57 | precision=precision, 58 | name=name, 59 | kernel_init=kernel_init, 60 | bias_init=bias_init, 61 | ) 62 | 63 | 64 | def linear( 65 | features: int, 66 | bias: bool = True, 67 | dtype: Dtype = jnp.float32, 68 | name: str = None, 69 | kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init, 70 | bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros, 71 | ): 72 | return nn.Dense( 73 | features=features, 74 | use_bias=bias, 75 | dtype=dtype, 76 | name=name, 77 | kernel_init=kernel_init, 78 | bias_init=bias_init, 79 | ) 80 | 
# ----- jeffnet/linen/layers/mixed_conv.py -----
from typing import Any, Callable, Sequence, Optional, Tuple, List, Union

import flax.linen as nn
import jax.numpy as jnp
import numpy as np

from .linear import conv2d

ModuleDef = Any


def _split_channels(num_feat, num_groups):
    """Split `num_feat` channels as evenly as possible; remainder goes to group 0."""
    split = [num_feat // num_groups for _ in range(num_groups)]
    split[0] += num_feat - sum(split)
    return split


def _to_list(x):
    """Normalize an int kernel-size spec to a single-element list."""
    if isinstance(x, int):
        return [x]
    return x


class MixedConv(nn.Module):
    """ Mixed Grouped Convolution
    Based on MDConv and GroupedConv in MixNet impl:
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
    """
    features: int
    kernel_size: Union[List[int], int] = 3
    dilation: int = 1
    stride: int = 1
    padding: Union[str, Tuple[int, int]] = 0
    depthwise: bool = False
    bias: bool = False

    conv_layer: ModuleDef = conv2d

    @nn.compact
    def __call__(self, x):
        kernel_sizes = _to_list(self.kernel_size)
        num_groups = len(kernel_sizes)
        # NOTE need to use np not jnp for calculating splits otherwise abstract value error
        in_splits = np.array(_split_channels(x.shape[-1], num_groups)).cumsum()[:-1]
        out_splits = _split_channels(self.features, num_groups)
        x_split = jnp.split(x, in_splits, axis=3)
        # FIX: was `zip(self.kernel_size, out_splits)`, which raises TypeError when
        # kernel_size is a plain int (the declared default, 3); zip over the normalized
        # list used to derive num_groups above.
        x_out = [self.conv_layer(
            feat, kernel_size=k, stride=self.stride, padding=self.padding, dilation=self.dilation,
            groups=feat if self.depthwise else 1, bias=self.bias, name=f'{idx}')(x_split[idx])
            for idx, (k, feat) in enumerate(zip(kernel_sizes, out_splits))]
        x = jnp.concatenate(x_out, axis=3)
        return x
# ----- jeffnet/linen/layers/normalization.py -----
""" BatchNorm Layer Wrapper

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from typing import Any, Callable, Tuple, Optional

from jax import lax
import jax.numpy as jnp
import flax.linen as nn
import flax.linen.initializers as initializers

PRNGKey = Any
Shape = Tuple[int]
Dtype = Any  # this could be a real type?
Array = Any


def _absolute_dims(rank, dims):
    """Convert possibly-negative axis indices to absolute indices for a given rank."""
    return tuple([rank + dim if dim < 0 else dim for dim in dims])


class BatchNorm(nn.Module):
    """BatchNorm Module.

    NOTE: A BatchNorm layer similar to Flax ver, but with diff of squares for var cal for numerical
    comparisons. Also, removed cross-process reduction in this variation (for now).

    Attributes:
        axis: the feature or non-batch axis of the input.
        momentum: decay rate for the exponential moving average of the batch statistics.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        bias: if True, bias (beta) is added.
        scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
    """
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 1e-5
    dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones

    @nn.compact
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            var = jnp.mean((x - mean) ** 2, axis=reduction_axis, keepdims=False)
            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init, reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init, reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)


class FlaxBatchNorm(nn.Module):
    """ FlaxBatchNorm Module.

    NOTE: A copy of the official Flax BN layer, w/ diff of squares variance and cross-process batch stats syncing.

    Attributes:
        axis: the feature or non-batch axis of the input.
        momentum: decay rate for the exponential moving average of the batch statistics.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        bias: if True, bias (beta) is added.
        scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
        axis_name: the axis name used to combine batch statistics from multiple
            devices. See `jax.pmap` for a description of axis names (default: None).
        axis_index_groups: groups of axis indices within that named axis
            representing subsets of devices to reduce over (default: None). For
            example, `[[0, 1], [2, 3]]` would independently batch-normalize over
            the examples on the first two and last two devices. See `jax.lax.psum` for more details.
    """
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 1e-5
    dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
    axis_name: Optional[str] = None
    axis_index_groups: Any = None

    @nn.compact
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, jnp.float32)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
            if self.axis_name is not None and not initializing:
                # sync first and second moments across devices in one pmean
                concatenated_mean = jnp.concatenate([mean, mean2])
                mean, mean2 = jnp.split(
                    lax.pmean(concatenated_mean, axis_name=self.axis_name, axis_index_groups=self.axis_index_groups), 2)
            var = mean2 - lax.square(mean)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

        y = x - mean.reshape(feature_shape)
        mul = lax.rsqrt(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init, reduced_feature_shape).reshape(feature_shape)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init, reduced_feature_shape).reshape(feature_shape)
            y = y + bias
        return jnp.asarray(y, self.dtype)


class L1BatchNorm(nn.Module):
    """L1 BatchNorm Module.

    Attributes:
        axis: the feature or non-batch axis of the input.
        momentum: decay rate for the exponential moving average of the batch statistics.
        epsilon: a small float added to variance to avoid dividing by zero.
        dtype: the dtype of the computation (default: float32).
        bias: if True, bias (beta) is added.
        scale: if True, multiply by scale (gamma).
            When the next layer is linear (also e.g. nn.relu), this can be disabled
            since the scaling will be done by the next layer.
        bias_init: initializer for bias, by default, zero.
        scale_init: initializer for scale, by default, one.
        axis_name: the axis name used to combine batch statistics from multiple
            devices (default: None, no cross-device sync).
        axis_index_groups: groups of axis indices within that named axis to reduce over.
    """
    axis: int = -1
    momentum: float = 0.99
    epsilon: float = 1e-5
    dtype: Dtype = jnp.float32
    use_bias: bool = True
    use_scale: bool = True
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
    scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
    # FIX: __call__ below references self.axis_name / self.axis_index_groups, but these
    # fields were never declared on this dataclass (only on FlaxBatchNorm), so any
    # training-mode call raised AttributeError. Declared with None defaults (no sync),
    # matching FlaxBatchNorm and preserving the construction interface.
    axis_name: Optional[str] = None
    axis_index_groups: Any = None

    @nn.compact
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, self.dtype)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s, jnp.float32), reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var', lambda s: jnp.ones(s, jnp.float32), reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            # L1 deviation scaled by sqrt(pi/2) approximates the std-dev for Gaussian activations
            var = jnp.mean(lax.abs(x - mean), axis=reduction_axis, keepdims=False) * jnp.sqrt(jnp.pi / 2)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, var])
                mean, var = jnp.split(
                    lax.pmean(concatenated_mean, axis_name=self.axis_name, axis_index_groups=self.axis_index_groups), 2)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var

        mean = jnp.asarray(mean, self.dtype)
        var = jnp.asarray(var, self.dtype)
        y = x - mean.reshape(feature_shape)
        mul = lax.reciprocal(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init, reduced_feature_shape).reshape(feature_shape)
            scale = jnp.asarray(scale, self.dtype)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init, reduced_feature_shape).reshape(feature_shape)
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return jnp.asarray(y, self.dtype)


def batchnorm2d(
        eps=1e-3,
        momentum=0.99,
        affine=True,
        dtype: Dtype = jnp.float32,
        name: Optional[str] = None,
        variant: str = '',
        bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros,
        weight_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones,
):
    """PyTorch-style BN factory; `variant` selects '', 'flax', or 'l1' implementations."""
    layer = BatchNorm
    if variant == 'flax':
        layer = FlaxBatchNorm
    elif variant == 'l1':
        layer = L1BatchNorm

    return layer(
        momentum=momentum,
        epsilon=eps,
        use_bias=affine,
        use_scale=affine,
        dtype=dtype,
        name=name,
        bias_init=bias_init,
        scale_init=weight_init,
    )


# ----- jeffnet/linen/layers/stochastic.py -----
""" Dropout, DropPath, DropBLock layers
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from typing import Any

from jax import lax
from jax import random
import jax.numpy as jnp

import flax.linen as nn
# NOTE(review): flax.nn is the deprecated pre-linen API; make_rng() here is only the
# fallback path in drop_path when no rng is passed -- confirm against the flax version pinned
# in requirements.txt before relying on it.
from flax.nn import make_rng


PRNGKey = Any


class Dropout(nn.Module):
    """ Dropout layer.
    Attributes:
        rate: the dropout probability. (_not_ the keep rate!)
    """
    rate: float

    @nn.compact
    def __call__(self, x, training: bool, rng: PRNGKey = None):
        """Applies a random dropout mask to the input.
        Args:
            x: the inputs that should be randomly masked.
            training: if true the inputs are masked and scaled by `1 / (1 - rate)`,
                whereas if false, no mask is applied and the inputs are returned as is.
            rng: an optional `jax.random.PRNGKey`. By default `self.make_rng('dropout')` will be used.
        Returns:
            The masked inputs reweighted to preserve mean.
        """
        if self.rate == 0. or not training:
            return x
        keep_prob = 1. - self.rate
        if rng is None:
            rng = self.make_rng('dropout')
        mask = random.bernoulli(rng, p=keep_prob, shape=x.shape)
        return lax.select(mask, x / keep_prob, jnp.zeros_like(x))


def drop_path(x: jnp.ndarray, drop_rate: float = 0., rng=None) -> jnp.ndarray:
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    """
    if drop_rate == 0.:
        return x
    keep_prob = 1. - drop_rate
    if rng is None:
        rng = make_rng()
    # per-sample mask, broadcast across spatial + channel dims
    mask = random.bernoulli(key=rng, p=keep_prob, shape=(x.shape[0], 1, 1, 1))
    mask = jnp.broadcast_to(mask, x.shape)
    return lax.select(mask, x / keep_prob, jnp.zeros_like(x))


class DropPath(nn.Module):
    """Module wrapper for drop_path; no-op when not training or rate == 0."""
    rate: float = 0.

    @nn.compact
    def __call__(self, x, training: bool, rng: PRNGKey = None):
        if not training or self.rate == 0.:
            return x
        if rng is None:
            rng = self.make_rng('dropout')
        return drop_path(x, self.rate, rng)


# ----- jeffnet/objax/__init__.py -----
from .efficientnet_objax import EfficientNet, create_model


# ----- jeffnet/objax/blocks_objax.py (header; body continues in next chunk) -----
""" EfficientNet, MobileNetV3, etc Blocks for Objax

Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Union, Callable, Optional

import objax.nn as nn
import objax.functional as F
from objax import Module, nn as nn, functional as F
from objax.typing import JaxArray

from jeffnet.common.block_utils import *
from .layers import Conv2d, MixedConv, BatchNorm2d, drop_path, get_act_fn, Linear
def create_conv(in_channels, out_channels, kernel_size, conv_layer=None, **kwargs):
    """ Select a convolution implementation based on arguments
    Creates and returns one of Conv, MixedConv, or CondConv (TODO)
    """
    conv_layer = Conv2d if conv_layer is None else conv_layer
    if isinstance(kernel_size, list):
        assert 'num_experts' not in kwargs  # MixNet + CondConv combo not supported currently
        assert 'groups' not in kwargs  # MixedConv groups are defined by kernel list
        # We're going to use only lists for defining the MixedConv2d kernel groups,
        # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
        m = MixedConv(in_channels, out_channels, kernel_size, conv_layer=conv_layer, **kwargs)
    else:
        depthwise = kwargs.pop('depthwise', False)
        groups = in_channels if depthwise else kwargs.pop('groups', 1)
        # if 'num_experts' in kwargs and kwargs['num_experts'] > 0:
        #     m = CondConv(in_channels, out_channels, kernel_size, groups=groups, conv_layer=conv_layer, **kwargs)
        # else:
        m = conv_layer(in_channels, out_channels, kernel_size, groups=groups, **kwargs)
    return m


class SqueezeExcite(Module):
    """Squeeze-and-Excitation channel attention (reduce -> act -> expand -> gate)."""
    def __init__(self, in_chs, se_ratio=0.25, block_chs=None, reduce_from_block=True,
                 conv_layer=Conv2d, act_fn=F.relu, bound_act_fn=None, gate_fn=F.sigmoid, divisor=1):
        super(SqueezeExcite, self).__init__()
        base_features = block_chs if block_chs and reduce_from_block else in_chs
        reduced_chs = make_divisible(base_features * se_ratio, divisor)
        self.reduce = conv_layer(in_chs, reduced_chs, 1, bias=True)
        self.act_fn = bound_act_fn if bound_act_fn is not None else act_fn
        self.expand = conv_layer(reduced_chs, in_chs, 1, bias=True)
        self.gate_fn = gate_fn

    def __call__(self, x):
        x_se = x.mean((2, 3), keepdims=True)  # global avg pool over H, W (NCHW)
        x_se = self.reduce(x_se)
        x_se = self.act_fn(x_se)
        x_se = self.expand(x_se)
        return x * self.gate_fn(x_se)


class ConvBnAct(Module):
    """Plain conv + norm + activation block."""
    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, dilation=1, pad_type='LIKE', conv_layer=Conv2d, norm_layer=BatchNorm2d, act_fn=F.relu):
        super(ConvBnAct, self).__init__()
        self.conv = conv_layer(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
        self.bn = norm_layer(out_chs)
        self.act_fn = act_fn

    def __call__(self, x, training: bool):
        x = self.conv(x)
        x = self.bn(x, training=training)
        x = self.act_fn(x)
        return x


class DepthwiseSeparable(Module):
    """ DepthwiseSeparable block
    Used for DS convs in MobileNet-V1 and in the place of IR blocks that have no expansion
    (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
    """
    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='LIKE', noskip=False,
                 pw_kernel_size=1, pw_act=False, se_ratio=0.,
                 conv_layer=Conv2d, norm_layer=BatchNorm2d, se_layer=None, act_fn=F.relu, drop_path_rate=0.):
        super(DepthwiseSeparable, self).__init__()
        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.drop_path_rate = drop_path_rate

        self.conv_dw = create_conv(
            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, conv_layer=conv_layer)
        self.bn_dw = norm_layer(in_chs)
        self.act_fn = act_fn

        # Squeeze-and-excitation
        self.se = None
        if se_layer is not None and se_ratio > 0.:
            self.se = se_layer(in_chs, se_ratio=se_ratio, act_fn=act_fn)

        self.conv_pw = create_conv(in_chs, out_chs, pw_kernel_size, padding=pad_type, conv_layer=conv_layer)
        self.bn_pw = norm_layer(out_chs)

    def __call__(self, x, training: bool):
        shortcut = x

        x = self.conv_dw(x)
        x = self.bn_dw(x, training=training)
        x = self.act_fn(x)

        if self.se is not None:
            x = self.se(x)

        x = self.conv_pw(x)
        x = self.bn_pw(x, training=training)
        if self.has_pw_act:
            x = self.act_fn(x)

        if self.has_residual:
            if training:
                x = drop_path(x, drop_prob=self.drop_path_rate)
            x += shortcut
        return x


class InvertedResidual(Module):
    """ Inverted residual block w/ optional SE and CondConv routing"""

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, dilation=1, pad_type='LIKE', noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
                 conv_layer=Conv2d, norm_layer=BatchNorm2d, se_layer=None, act_fn=F.relu, drop_path_rate=0.):
        super(InvertedResidual, self).__init__()
        mid_chs = make_divisible(in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Point-wise expansion
        self.conv_exp = create_conv(in_chs, mid_chs, exp_kernel_size, padding=pad_type, conv_layer=conv_layer)
        self.bn_exp = norm_layer(mid_chs)
        self.act_fn = act_fn

        # Depth-wise convolution
        self.conv_dw = create_conv(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, depthwise=True, conv_layer=conv_layer)
        self.bn_dw = norm_layer(mid_chs)

        # Squeeze-and-excitation
        self.se = None
        if se_layer is not None and se_ratio > 0.:
            self.se = se_layer(mid_chs, block_chs=in_chs, se_ratio=se_ratio, act_fn=act_fn)

        # Point-wise linear projection
        self.conv_pwl = create_conv(mid_chs, out_chs, pw_kernel_size, padding=pad_type, conv_layer=conv_layer)
        self.bn_pwl = norm_layer(out_chs)

    def __call__(self, x, training: bool):
        shortcut = x

        # Point-wise expansion
        x = self.conv_exp(x)
        x = self.bn_exp(x, training=training)
        x = self.act_fn(x)

        # Depth-wise convolution
        x = self.conv_dw(x)
        x = self.bn_dw(x, training=training)
        x = self.act_fn(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn_pwl(x, training=training)

        if self.has_residual:
            if training:
                x = drop_path(x, drop_prob=self.drop_path_rate)
            x += shortcut

        return x


class EdgeResidual(Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride"""

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, dilation=1, pad_type='LIKE', noskip=False, pw_kernel_size=1,
                 se_ratio=0., conv_layer=Conv2d, norm_layer=BatchNorm2d, se_layer=None, act_fn=F.relu,
                 drop_path_rate=0.):
        super(EdgeResidual, self).__init__()
        _in_chs = fake_in_chs if fake_in_chs > 0 else in_chs  # mismatch in arch specs and actual in chs
        mid_chs = make_divisible(_in_chs * exp_ratio)
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.drop_path_rate = drop_path_rate

        # Expansion convolution
        self.conv_exp = create_conv(in_chs, mid_chs, exp_kernel_size, padding=pad_type, conv_layer=conv_layer)
        self.bn_exp = norm_layer(mid_chs)
        self.act_fn = act_fn

        # Squeeze-and-excitation
        self.se = None
        if se_layer is not None and se_ratio > 0.:
            self.se = se_layer(mid_chs, block_chs=in_chs, se_ratio=se_ratio, act_fn=act_fn)

        # Point-wise linear projection
        self.conv_pwl = create_conv(
            mid_chs, out_chs, pw_kernel_size, stride=stride, dilation=dilation,
            padding=pad_type, conv_layer=conv_layer)
        self.bn_pwl = norm_layer(out_chs)

    def __call__(self, x, training: bool):
        shortcut = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn_exp(x, training=training)
        x = self.act_fn(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # Point-wise linear projection
        x = self.conv_pwl(x)
        x = self.bn_pwl(x, training=training)

        if self.has_residual:
            if training:
                x = drop_path(x, drop_prob=self.drop_path_rate)
            x = x + shortcut

        return x


class EfficientHead(Module):
    """ EfficientHead from MobileNetV3 """
    def __init__(self, in_chs: int, num_features: int, num_classes: int = 1000, global_pool: str = 'avg',
                 act_fn='relu', conv_layer=Conv2d, norm_layer=None):
        self.global_pool = global_pool  # FIXME support diff pooling

        self.conv_pw = conv_layer(in_chs, num_features, 1, bias=True)
        # FIX: the default act_fn is the *string* 'relu'; it was stored verbatim and later
        # called, raising TypeError. Resolve strings via get_act_fn (as BlockFactory does);
        # callables pass through unchanged, so explicit callers are unaffected.
        self.act_fn = get_act_fn(act_fn) if isinstance(act_fn, str) else act_fn
        if num_classes > 0:
            self.classifier = Linear(num_features, num_classes, bias=True)
        else:
            self.classifier = None

    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        if self.global_pool == 'avg':
            x = x.mean((2, 3), keepdims=True)
        x = self.conv_pw(x).reshape(x.shape[0], -1)
        x = self.act_fn(x)
        # FIX: classifier may be None when num_classes <= 0 (see __init__); Head already
        # guards this case, EfficientHead crashed with 'NoneType is not callable'.
        if self.classifier is not None:
            x = self.classifier(x)
        return x


class Head(Module):
    """ Standard Head from EfficientNet, MixNet, MNasNet, MobileNetV2, etc. """
    def __init__(self, in_chs: int, num_features: int, num_classes: int = 1000, global_pool: str = 'avg',
                 act_fn=F.relu, conv_layer=Conv2d, norm_layer=BatchNorm2d):
        self.global_pool = global_pool  # FIXME support diff pooling

        self.conv_pw = conv_layer(in_chs, num_features, 1)
        self.bn = norm_layer(num_features)
        self.act_fn = act_fn
        if num_classes > 0:
            self.classifier = Linear(num_features, num_classes, bias=True)
        else:
            self.classifier = None

    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        x = self.conv_pw(x)
        x = self.bn(x, training=training)
        x = self.act_fn(x)
        if self.global_pool == 'avg':
            x = x.mean((2, 3))
        if self.classifier is not None:
            x = self.classifier(x)
        return x


class BlockFactory:
    """Static dispatch used by the shared arch builder to instantiate objax blocks."""

    @staticmethod
    def CondConv(stage_idx, block_idx, **block_args):
        assert False, "Not currently impl"

    @staticmethod
    def InvertedResidual(stage_idx, block_idx, **block_args):
        return InvertedResidual(**block_args)

    @staticmethod
    def DepthwiseSeparable(stage_idx, block_idx, **block_args):
        return DepthwiseSeparable(**block_args)

    @staticmethod
    def EdgeResidual(stage_idx, block_idx, **block_args):
        return EdgeResidual(**block_args)

    @staticmethod
    def ConvBnAct(stage_idx, block_idx, **block_args):
        # ConvBnAct takes no drop_path / se args; drop silently for builder uniformity
        block_args.pop('drop_path_rate', None)
        block_args.pop('se_layer', None)
        return ConvBnAct(**block_args)

    @staticmethod
    def get_act_fn(act_fn: Union[str, Callable]):
        return get_act_fn(act_fn) if isinstance(act_fn, str) else act_fn


# ----- jeffnet/objax/efficientnet_objax.py (header only; file continues past this chunk) -----
# """ EfficientNet (Objax) Model and Factory
#
# Hacked ... (continued in next chunk)
# ---- jeffnet/objax/efficientnet_objax.py (continued from chunk boundary) ----
""" EfficientNet (Objax) Model and Factory

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from typing import Optional
from functools import partial

import objax.nn as nn
import objax.functional as F
from objax import Module
from objax.typing import JaxArray

from jeffnet.common import get_model_cfg, round_features, decode_arch_def, EfficientNetBuilder

from .helpers import load_pretrained
from .layers import Conv2d, BatchNorm2d, get_act_fn
from .blocks_objax import ConvBnAct, SqueezeExcite, BlockFactory, Head, EfficientHead


class EfficientNet(Module):
    """ EfficientNet (and other MBConvNets)
    * EfficientNet B0-B8, L2
    * EfficientNet-EdgeTPU
    * EfficientNet-Lite
    * MixNet S, M, L, XL
    * MobileNetV3
    * MobileNetV2
    * MnasNet A1, B1, and small
    * FBNet C
    * Single-Path NAS Pixel1
    """

    def __init__(self, block_defs,
                 num_classes: int = 1000, num_features: int = 1280, drop_rate: float = 0., global_pool: str = 'avg',
                 feat_multiplier: float = 1.0, feat_divisor: int = 8, feat_min: Optional[int] = None,
                 in_chs: int = 3, stem_size: int = 32, fix_stem: bool = False, output_stride: int = 32,
                 efficient_head: bool = False, pad_type: str = 'LIKE', conv_layer=Conv2d, norm_layer=BatchNorm2d,
                 se_layer=SqueezeExcite, act_fn=F.relu, drop_path_rate: float = 0.):
        super(EfficientNet, self).__init__()

        self.num_classes = num_classes
        self.num_features = num_features
        self.drop_rate = drop_rate

        # Shared conv/bn/act config threaded through stem, blocks, and head.
        cba_kwargs = dict(conv_layer=conv_layer, norm_layer=norm_layer, act_fn=act_fn)

        # Stem: 3x3/2 conv; width scaled with the model multiplier unless fixed.
        if not fix_stem:
            stem_size = round_features(stem_size, feat_multiplier, feat_divisor, feat_min)
        self.stem = ConvBnAct(in_chs, stem_size, 3, stride=2, pad_type=pad_type, **cba_kwargs)

        # Middle stages (IR/ER/DS Blocks), built from the decoded arch defs.
        builder = EfficientNetBuilder(
            stem_size, block_defs, BlockFactory(),
            feat_multiplier=feat_multiplier, feat_divisor=feat_divisor, feat_min=feat_min,
            output_stride=output_stride, pad_type=pad_type, se_layer=se_layer, **cba_kwargs,
            drop_path_rate=drop_path_rate)
        self.blocks = nn.Sequential([nn.Sequential(b) for b in builder()])
        self.feature_info = builder.features
        head_chs = builder.in_chs

        # Head (1x1 conv + pooling + classifier)
        head_layer = EfficientHead if efficient_head else Head
        self.head = head_layer(head_chs, self.num_features, self.num_classes, global_pool=global_pool, **cba_kwargs)

        # how to init?

    def get_classifier(self):
        # The Linear classifier module (or None when num_classes <= 0).
        return self.head.classifier

    def forward_features(self, x: JaxArray, training: bool) -> JaxArray:
        # Backbone features before the classifier head (stem + blocks only).
        x = self.stem(x, training=training)
        x = self.blocks(x, training=training)
        return x

    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        x = self.forward_features(x, training=training)
        x = self.head(x, training=training)
        return x


def create_model(variant, pretrained=False, **kwargs):
    """Build model `variant` from its registered cfg; optionally load weights.

    kwargs override the decoded arch args; se_cfg/bn_cfg are resolved into
    partially-applied layer constructors before model construction.
    """
    model_cfg = get_model_cfg(variant)
    model_args = model_cfg['arch_fn'](variant, **model_cfg['arch_cfg'])
    model_args.update(kwargs)

    # resolve some special layers and their arguments
    se_args = model_args.pop('se_cfg', {})  # not consumable by model
    if 'se_layer' not in model_args:
        # Convert string act names inside the SE config to callables.
        if 'bound_act_fn' in se_args:
            se_args['bound_act_fn'] = get_act_fn(se_args['bound_act_fn'])
        if 'gate_fn' in se_args:
            se_args['gate_fn'] = get_act_fn(se_args['gate_fn'])
        model_args['se_layer'] = partial(SqueezeExcite, **se_args)

    bn_args = model_args.pop('bn_cfg')  # not consumable by model
    if 'norm_layer' not in model_args:
        model_args['norm_layer'] = partial(BatchNorm2d, **bn_args)

    model_args['act_fn'] = get_act_fn(model_args.pop('act_fn', 'relu'))  # convert str -> fn

    model = EfficientNet(**model_args)
    model.default_cfg = model_cfg['default_cfg']

    if pretrained:
        load_pretrained(model, default_cfg=model.default_cfg)

    return model


# ---- jeffnet/objax/helpers.py ----
"""Pretrained State Dict Helpers

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from jeffnet.common import load_state_dict_from_url


def load_pretrained(model, url='', default_cfg=None, filter_fn=None):
    # Downloads a jax-format state dict and assigns it onto `model` in place.
    # Falls back to default_cfg['url'] when no explicit url is given.
    if not url:
        assert default_cfg is not None and default_cfg['url']
        url = default_cfg['url']
    model_vars = model.vars()
    jax_state_dict = load_state_dict_from_url(url=url, transpose=False)
    if filter_fn is not None:
        jax_state_dict = filter_fn(jax_state_dict)
    # FIXME hack, assuming alignment, currently enforced by my layer customizations
    # TODO remap keys
    model_vars.assign(jax_state_dict.values())


# ---- jeffnet/objax/layers/__init__.py ----
from .activations import get_act_fn
from .drop_path import drop_path
from .linear import Conv2d, Linear
from .mixed_conv import MixedConv
from .normalization import BatchNorm2d


# ---- jeffnet/objax/layers/activations.py (head) ----
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import objax.functional as F
import jax.nn as nn
from functools import partial

# Name -> activation fn table. NOTE: the dump truncates here; the remaining
# entries (softplus, silu, swish, sigmoid, tanh, hard_silu, hard_swish,
# hard_sigmoid, hard_tanh) fall on the next chunk boundary.
_ACT_FN = dict(
    relu=F.relu,
    relu6=nn.relu6,
    leaky_relu=F.leaky_relu,
    gelu=nn.gelu,
    elu=F.elu,
)
# ---- jeffnet/objax/layers/activations.py (tail) ----
# The remaining _ACT_FN entries fall on this chunk boundary:
#   softplus=F.softplus, silu=nn.silu, swish=nn.silu, sigmoid=F.sigmoid,
#   tanh=F.tanh, hard_silu=nn.hard_silu, hard_swish=nn.hard_silu,
#   hard_sigmoid=nn.hard_sigmoid, hard_tanh=nn.hard_tanh,


def get_act_fn(name='relu', **kwargs):
    """Resolve an activation name to its callable, optionally binding kwargs."""
    name = name.lower()
    assert name in _ACT_FN
    act_fn = _ACT_FN[name]
    if kwargs:
        act_fn = partial(act_fn, **kwargs)
    return act_fn


# ---- jeffnet/objax/layers/drop_path.py ----
""" Drop Path
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import jax.random as jr

# objax is only needed for the default RNG; guarded so the function itself
# stays usable (with an explicit generator) where objax isn't importable.
try:
    from objax import random as _objax_random
    _DEFAULT_GENERATOR = _objax_random.DEFAULT_GENERATOR
except ImportError:
    _DEFAULT_GENERATOR = None


def drop_path(x, drop_prob: float = 0., generator=_DEFAULT_GENERATOR):
    """Drop paths (Stochastic Depth) per sample, applied in residual main paths.

    Same as the DropConnect impl for EfficientNet-style networks; the
    'Drop Connect' name is avoided deliberately, see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    FIX: the previous implementation computed
        keep_mask = keep_prob + jr.bernoulli(...)
    producing a mask of {keep_prob, 1 + keep_prob} instead of {0, 1} -- no
    sample was ever actually dropped and kept samples were mis-scaled. The
    `keep_prob + noise` / floor trick belongs with uniform noise; a bernoulli
    draw is already the binary keep mask.

    Args:
        x: input of shape (N, ...); masking is per-sample along axis 0.
        drop_prob: probability of dropping a sample's residual path.
        generator: RNG object exposing `.key()` (objax DEFAULT_GENERATOR).

    Returns:
        x unchanged when drop_prob == 0; otherwise x with dropped samples
        zeroed and survivors scaled by 1/keep_prob (expectation preserved).
    """
    if drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    keep_shape = (x.shape[0], 1, 1, 1)
    keep_mask = jr.bernoulli(generator.key(), p=keep_prob, shape=keep_shape)
    return (x / keep_prob) * keep_mask


# ---- jeffnet/objax/layers/linear.py (head) ----
# File header, imports, and the opening of class Conv2d fall on this chunk
# boundary; Conv2d and Linear are reproduced in full in the next chunk.
# ---- jeffnet/objax/layers/linear.py (continued) ----
# The file header and the opening of Conv2d fall at the previous chunk
# boundary; the imports and class are reproduced in full here.
from typing import Callable, Iterable, Tuple, Optional, Union

from jax import numpy as jnp, lax

from objax import functional, random, util
from objax.module import Module
from objax.nn.init import kaiming_normal, xavier_normal
from objax.typing import JaxArray
from objax.variable import TrainVar

from jeffnet.common.padding import get_like_padding


class Conv2d(Module):
    """Applies a 2D convolution on a 4D-input batch of shape (N,C,H,W)."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[Tuple[int, int], int],
                 stride: Union[Tuple[int, int], int] = 1,
                 padding: Union[str, Tuple[int, int], int] = 0,
                 dilation: Union[Tuple[int, int], int] = 1,
                 groups: int = 1,
                 bias: bool = False,
                 kernel_init: Callable = kaiming_normal,
                 bias_init: Callable = jnp.zeros,
                 ):
        """Creates a Conv2D module instance.

        Args:
            in_channels: number of channels of the input tensor.
            out_channels: number of channels of the output tensor.
            kernel_size: kernel size, (height, width) tuple or single int.
            stride: strides, (stride_y, stride_x) tuple or single int.
            padding: 'LIKE' for torch-style SAME-ish padding, or explicit
                (pad_h, pad_w) tuple / single int applied per spatial dim.
            dilation: kernel point spacing (atrous conv), tuple or single int.
            groups: channel group count; in/out channels must be divisible.
            bias: if True the convolution has a bias term.
            kernel_init: initializer for the OIHW kernel.
            bias_init: initializer for the bias vector.
        """
        super().__init__()
        assert in_channels % groups == 0, 'in_chs should be divisible by groups'
        assert out_channels % groups == 0, 'out_chs should be divisible by groups'
        kernel_size = util.to_tuple(kernel_size, 2)
        self.weight = TrainVar(kernel_init((out_channels, in_channels // groups, *kernel_size)))  # OIHW
        self.bias = TrainVar(bias_init((out_channels,))) if bias else None
        self.strides = util.to_tuple(stride, 2)
        self.dilations = util.to_tuple(dilation, 2)
        if isinstance(padding, str):
            if padding == 'LIKE':
                padding = (
                    get_like_padding(kernel_size[0], self.strides[0], self.dilations[0]),
                    get_like_padding(kernel_size[1], self.strides[1], self.dilations[1]))
        else:
            padding = util.to_tuple(padding, 2)
        # FIX: lax.conv_general_dilated takes one (low, high) pair PER spatial
        # dim. The old code built [(ph, pw), (ph, pw)], which pads H by
        # (ph, pw) and W by (ph, pw) -- wrong whenever ph != pw (benign for
        # square kernels, which is why it went unnoticed).
        self.padding = [(padding[0], padding[0]), (padding[1], padding[1])]
        self.groups = groups

    def __call__(self, x: JaxArray) -> JaxArray:
        """Returns the results of applying the convolution to input x."""
        y = lax.conv_general_dilated(
            x, self.weight.value, self.strides, self.padding,
            rhs_dilation=self.dilations, feature_group_count=self.groups,
            dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
        if self.bias is not None:
            # Bias broadcast over N, H, W.
            y += self.bias.value.reshape((1, -1, 1, 1))
        return y


class Linear(Module):
    """Applies a linear transformation on an input batch."""

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            weight_init: Callable = xavier_normal,
            bias_init: Callable = jnp.zeros,
    ):
        """Creates a Linear module instance.

        Args:
            in_features: number of channels of the input tensor.
            out_features: number of channels of the output tensor.
            bias: if True then linear layer will have bias term.
            weight_init: weight initializer (takes an (out, in) shape).
            bias_init: bias initializer.
        """
        super().__init__()
        # Weight stored (out, in), torch-style; transposed at apply time.
        self.weight = TrainVar(weight_init((out_features, in_features)))
        self.bias = TrainVar(bias_init(out_features)) if bias else None

    def __call__(self, x: JaxArray) -> JaxArray:
        """Returns the results of applying the linear transformation to input x."""
        y = jnp.dot(x, self.weight.value.transpose())
        if self.bias is not None:
            y += self.bias.value
        return y


# ---- jeffnet/objax/layers/mixed_conv.py (head) ----
""" Mixed Grouped Convolution

Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
from objax.module import ModuleList


def _split_channels(num_chan, num_groups):
    # Near-even split; any remainder channels go to the first group.
    split = [num_chan // num_groups for _ in range(num_groups)]
    split[0] += num_chan - sum(split)
    return split

# class MixedConv opens at this chunk boundary; reproduced in full in the
# next chunk.
# ---- jeffnet/objax/layers/mixed_conv.py (continued) ----

class MixedConv(ModuleList):
    """ Mixed Grouped Convolution

    Splits the channel dim into groups, each convolved with its own kernel
    size. Based on MDConv and GroupedConv in MixNet impl:
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding='',
                 dilation=1, depthwise=False, conv_layer=None, **kwargs):
        super(MixedConv, self).__init__()
        if conv_layer is None:
            conv_layer = Conv2d
        if not isinstance(kernel_size, list):
            kernel_size = [kernel_size]
        group_count = len(kernel_size)
        in_splits = _split_channels(in_channels, group_count)
        out_splits = _split_channels(out_channels, group_count)
        self.in_channels = sum(in_splits)
        self.out_channels = sum(out_splits)
        # One conv per kernel size; depthwise groups == that group's width.
        for k, chs_in, chs_out in zip(kernel_size, in_splits, out_splits):
            self.append(
                conv_layer(
                    chs_in, chs_out, k, stride=stride,
                    padding=padding, dilation=dilation,
                    groups=chs_out if depthwise else 1, **kwargs))
        # Cumulative split points for jnp.split along the channel axis.
        self.splits = jnp.array(in_splits).cumsum()[:-1]

    def __call__(self, x):
        chunks = jnp.split(x, self.splits, 1)
        outputs = [conv(chunk) for conv, chunk in zip(self, chunks)]
        return jnp.concatenate(outputs, axis=1)


# ---- jeffnet/objax/layers/normalization.py (head) ----
# File docstring, imports, and the opening of class _BatchNorm fall on this
# chunk boundary; the class is reproduced in full in the next chunk.
# ---- jeffnet/objax/layers/normalization.py (continued; the file docstring,
# imports and the class statement fall at the previous chunk boundary) ----
""" Normalization Layer Defs
"""
from typing import Callable, Iterable, Tuple, Optional, Union

from jax import numpy as jnp

from objax import functional
from objax.module import Module
from objax.typing import JaxArray
from objax.variable import TrainVar, StateVar


class _BatchNorm(Module):
    """Applies a batch normalization on different ranks of an input tensor.

    The module follows the operation described in Algorithm 1 of
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    `_.
    """

    def __init__(self, num_features: int, redux: Iterable[int], momentum: float = 0.9, eps: float = 1e-5):
        """Creates a BatchNorm module instance.

        Args:
            num_features: per-channel size of the batch norm state variables.
            redux: list of indices of reduction axes. Batch norm statistics are computed by averaging over these axes.
            momentum: value used to compute exponential moving average of batch statistics.
            eps: small value which is used for numerical stability.
        """
        super().__init__()
        self.num_features = num_features
        self.momentum = momentum
        self.eps = eps
        self.redux = tuple(redux)
        # Learnable affine (gamma/beta) + running stats for eval mode.
        self.weight = TrainVar(jnp.ones(num_features))
        self.bias = TrainVar(jnp.zeros(num_features))
        self.running_mean = StateVar(jnp.zeros(num_features))
        self.running_var = StateVar(jnp.ones(num_features))

    def __call__(self, x: JaxArray, training: bool) -> JaxArray:
        """Performs batch normalization of input tensor.

        Args:
            x: input tensor.
            training: if True compute batch normalization in training mode (accumulating batch statistics),
                otherwise compute in evaluation mode (using already accumulated batch statistics).

        Returns:
            Batch normalized tensor.
        """
        # NCHW broadcast shape for the per-channel parameters/statistics.
        shape = (1, -1, 1, 1)
        weight = self.weight.value.reshape(shape)
        bias = self.bias.value.reshape(shape)
        if training:
            mean = x.mean(self.redux, keepdims=True)
            # Biased batch variance via E[x^2] - E[x]^2.
            var = (x ** 2).mean(self.redux, keepdims=True) - mean ** 2
            # EMA update: running += (1 - m) * (batch - running), i.e.
            # running = m * running + (1 - m) * batch.
            self.running_mean.value += (1 - self.momentum) * (mean.squeeze(axis=self.redux) - self.running_mean.value)
            self.running_var.value += (1 - self.momentum) * (var.squeeze(axis=self.redux) - self.running_var.value)
        else:
            mean, var = self.running_mean.value.reshape(shape), self.running_var.value.reshape(shape)

        y = weight * (x - mean) * functional.rsqrt(var + self.eps) + bias
        return y


class BatchNorm1d(_BatchNorm):
    """Applies a 1D batch normalization on a 3D-input batch of shape (N,C,L).

    The module follows the operation described in Algorithm 1 of
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    `_.
    """

    def __init__(self, num_features: int, momentum: float = 0.9, eps: float = 1e-5):
        """Creates a BatchNorm1D module instance.

        Args:
            num_features: number of features in the input example.
            momentum: value used to compute exponential moving average of batch statistics.
            eps: small value which is used for numerical stability.
        """
        # Reduce over batch (0) and length (2); stats are per channel.
        super().__init__(num_features, (0, 2), momentum, eps)


class BatchNorm2d(_BatchNorm):
    """Applies a 2D batch normalization on a 4D-input batch of shape (N,C,H,W).

    The module follows the operation described in Algorithm 1 of
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
    `_.
    """

    def __init__(self, num_features: int, momentum: float = 0.9, eps: float = 1e-5):
        """Creates a BatchNorm2D module instance.

        Args:
            num_features: number of features in the input example.
            momentum: value used to compute exponential moving average of batch statistics.
            eps: small value which is used for numerical stability.
        """
        # Reduce over batch (0) and spatial dims (2, 3); the closing super()
        # call falls on the next chunk boundary and is reproduced here.
        super().__init__(num_features, (0, 2, 3), momentum, eps)
100 | """ 101 | super().__init__(num_features, (0, 2, 3), momentum, eps) -------------------------------------------------------------------------------- /jeffnet/utils/to_tuple.py: -------------------------------------------------------------------------------- 1 | from numbers import Number 2 | from typing import Callable, List, Union, Tuple, Iterable 3 | 4 | 5 | __all__ = ["to_tuple"] 6 | 7 | 8 | def to_tuple(v: Union[Tuple[Number, ...], Number, Iterable], n: int): 9 | """Converts input to tuple.""" 10 | if isinstance(v, tuple): 11 | return v 12 | elif isinstance(v, Number): 13 | return (v,) * n 14 | else: 15 | return tuple(v) 16 | -------------------------------------------------------------------------------- /pt_linen_validate.py: -------------------------------------------------------------------------------- 1 | """ ImageNet Validation Script 2 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 3 | """ 4 | import argparse 5 | import fnmatch 6 | import os 7 | import time 8 | 9 | import jax 10 | from timm.data import create_dataset, create_loader, resolve_data_config 11 | 12 | from jeffnet.common import get_model_cfg, list_models, correct_topk, AverageMeter 13 | from jeffnet.linen import create_model 14 | 15 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 16 | parser.add_argument('data', metavar='DIR', help='path to dataset') 17 | parser.add_argument('--model', '-m', metavar='MODEL', default='pt_efficientnet_b0', 18 | help='model architecture (default: pt_efficientnet_b0)') 19 | parser.add_argument('-b', '--batch-size', default=250, type=int, 20 | metavar='N', help='mini-batch size (default: 256)') 21 | parser.add_argument('--no-jit', action='store_true', default=False, 22 | help='Disable jit of model (for comparison).') 23 | 24 | 25 | def validate(args): 26 | rng = jax.random.PRNGKey(0) 27 | model, variables = create_model(args.model, pretrained=True, rng=rng) 28 | print(f'Created {args.model} 
# ---- pt_linen_validate.py (continued) ----
# Body of validate(args); the def line and model creation precede this chunk.

    print(f'Created {args.model} model. Validating...')

    # Optionally skip jit for apples-to-apples comparison runs.
    if args.no_jit:
        eval_step = lambda images, labels: eval_forward(model, variables, images, labels)
    else:
        eval_step = jax.jit(lambda images, labels: eval_forward(model, variables, images, labels))

    dataset = create_dataset('imagenet', args.data)

    # Resolve input size / normalization / crop from the model's default cfg.
    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # torch loader yields NCHW tensors; the linen model wants NHWC numpy.
        images = images.numpy().transpose(0, 2, 3, 1)
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))


def eval_forward(model, variables, images, labels):
    # Pure eval forward pass: logits -> top-1/top-5 correct counts.
    logits = model.apply(variables, images, mutable=False, training=False)
    top1_count, top5_count = correct_topk(logits, labels, topk=(1, 5))
    return top1_count, top5_count


def main():
    # CLI entry: validate one model, or a pattern/'all' sweep over the zoo.
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)
    # NOTE(review): enable_omnistaging() is a no-op/removed on modern jax --
    # confirm against the pinned jax version before cleanup.
    jax.config.enable_omnistaging()

    def _try_validate(args):
        # Retry with halved batch size on RuntimeError (typically OOM).
        res = None
        batch_size = args.batch_size
        while res is None:
            try:
                print(f'Setting validation batch size to {batch_size}')
                args.batch_size = batch_size
                res = validate(args)
            except RuntimeError as e:
                if batch_size <= 1:
                    print("Validation failed with no ability to reduce batch size. Exiting.")
                    raise e
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")
        return res

    if get_model_cfg(args.model) is not None:
        # Exact model name given.
        _try_validate(args)
    else:
        # Treat args.model as an fnmatch pattern (or 'all') over pretrained models.
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(f'ERROR: No models found to validate with pattern {args.model}.')
            exit(1)

        print('Validating:', ', '.join(models))
        results = []
        start_batch_size = args.batch_size
        for m in models:
            args.batch_size = start_batch_size  # reset in case reduced for retry
            args.model = m
            res = _try_validate(args)
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")


if __name__ == '__main__':
    main()

# ---- file boundary: pt_objax_validate.py begins in the next chunk ----
# ---- pt_objax_validate.py ----
""" ImageNet Validation Script
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import os
import time
import argparse
import fnmatch

import objax
import jax

from timm.data import create_dataset, create_loader, resolve_data_config
from jeffnet.common import get_model_cfg, list_models, correct_topk, AverageMeter
from jeffnet.objax import create_model


parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--model', '-m', metavar='MODEL', default='pt_efficientnet_b0',
                    help='model architecture (default: pt_efficientnet_b0)')
parser.add_argument('-b', '--batch-size', default=250, type=int,
                    metavar='N', help='mini-batch size (default: 256)')


def validate(args):
    # Validate a single pretrained objax model over ImageNet val.
    model = create_model(args.model, pretrained=True)
    print(f'Created {args.model} model. Validating...')

    # objax.Jit needs the model's VarCollection to trace state correctly.
    eval_step = objax.Jit(
        lambda images, labels: eval_forward(model, images, labels),
        model.vars())

    dataset = create_dataset('imagenet', args.data)

    # Resolve input size / normalization / crop from the model's default cfg.
    data_config = resolve_data_config(vars(args), model=model)
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=False,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=8,
        crop_pct=data_config['crop_pct'])

    batch_time = AverageMeter()
    correct_top1, correct_top5 = 0, 0
    total_examples = 0
    start_time = prev_time = time.time()
    for batch_index, (images, labels) in enumerate(loader):
        # objax model consumes NCHW directly (no transpose, unlike the linen script).
        images = images.numpy()
        labels = labels.numpy()

        top1_count, top5_count = eval_step(images, labels)
        correct_top1 += int(top1_count)
        correct_top5 += int(top5_count)
        total_examples += images.shape[0]

        batch_time.update(time.time() - prev_time)
        if batch_index % 20 == 0 and batch_index > 0:
            print(
                f'Test: [{batch_index:>4d}/{len(loader)}] '
                f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) '
                f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} '
                f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}')
        prev_time = time.time()

    acc_1 = 100 * correct_top1 / total_examples
    acc_5 = 100 * correct_top5 / total_examples
    print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. '
          f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}')
    return dict(top1=float(acc_1), top5=float(acc_5))


def eval_forward(model, images, labels):
    # Pure eval forward pass: logits -> top-1/top-5 correct counts.
    logits = model(images, training=False)
    top1_count, top5_count = correct_topk(logits, labels, topk=(1, 5))
    return top1_count, top5_count


def main():
    # CLI entry: validate one model, or a pattern/'all' sweep over the zoo.
    args = parser.parse_args()
    print('JAX host: %d / %d' % (jax.host_id(), jax.host_count()))
    print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True)

    def _try_validate(args):
        # Retry with halved batch size on RuntimeError (typically OOM).
        res = None
        batch_size = args.batch_size
        while res is None:
            try:
                print(f'Setting validation batch size to {batch_size}')
                args.batch_size = batch_size
                res = validate(args)
            except RuntimeError as e:
                if batch_size <= 1:
                    print("Validation failed with no ability to reduce batch size. Exiting.")
                    raise e
                batch_size = max(batch_size // 2, 1)
                print("Validation failed, reducing batch size by 50%")
        return res

    if get_model_cfg(args.model) is not None:
        # Exact model name given.
        _try_validate(args)
    else:
        # Treat args.model as an fnmatch pattern (or 'all') over pretrained models.
        models = list_models(pretrained=True)
        if args.model != 'all':
            models = fnmatch.filter(models, args.model)
        if not models:
            print(f'ERROR: No models found to validate with pattern {args.model}.')
            exit(1)

        print('Validating:', ', '.join(models))
        results = []
        start_batch_size = args.batch_size
        for m in models:
            args.batch_size = start_batch_size  # reset in case reduced for retry
            args.model = m
            res = _try_validate(args)
            res.update(dict(model=m))
            results.append(res)
        print('Results:')
        for r in results:
            print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}")


if __name__ == '__main__':
    main()

# ---- file boundary: requirements.txt begins in the next chunk ----
-------------------------------------------------------------------------------- 1 | flax 2 | optax>=0.0.6 3 | jax>=0.1.55 4 | jaxlib>=0.1.37 5 | numpy>=1.18.0 6 | -------------------------------------------------------------------------------- /tf_linen_validate.py: -------------------------------------------------------------------------------- 1 | """ ImageNet Validation Script 2 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 3 | """ 4 | import time 5 | import argparse 6 | import fnmatch 7 | 8 | import jax 9 | import flax 10 | import tensorflow_datasets as tfds 11 | 12 | import jeffnet.data.tf_input_pipeline as input_pipeline 13 | from jeffnet.common import correct_topk, AverageMeter, list_models, get_model_cfg 14 | from jeffnet.linen import create_model 15 | 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 18 | parser.add_argument('data', metavar='DIR', help='path to dataset') 19 | parser.add_argument('--model', '-m', metavar='MODEL', default='tf_efficientnet_b0', 20 | help='model architecture (default: tf_efficientnet_b0)') 21 | parser.add_argument('-b', '--batch-size', default=250, type=int, 22 | metavar='N', help='mini-batch size (default: 256)') 23 | parser.add_argument('--no-jit', action='store_true', default=False, 24 | help='Disable jit of model (for comparison).') 25 | parser.add_argument('--half-precision', action='store_true', default=False, 26 | help='Evaluate in half (mixed) precision') 27 | 28 | 29 | def validate(args): 30 | rng = jax.random.PRNGKey(0) 31 | platform = jax.local_devices()[0].platform 32 | if args.half_precision: 33 | if platform == 'tpu': 34 | model_dtype = jax.numpy.bfloat16 35 | else: 36 | model_dtype = jax.numpy.float16 37 | else: 38 | model_dtype = jax.numpy.float32 39 | 40 | model, variables = create_model(args.model, pretrained=True, dtype=model_dtype, rng=rng) 41 | print(f'Created {args.model} model. 
Validating...') 42 | 43 | if args.no_jit: 44 | eval_step = lambda images, labels: eval_forward(model.apply, variables, images, labels) 45 | else: 46 | eval_step = jax.jit(lambda images, labels: eval_forward(model.apply, variables, images, labels)) 47 | 48 | """Runs evaluation and returns top-1 accuracy.""" 49 | image_size = model.default_cfg['input_size'][-1] 50 | 51 | eval_iter, num_batches = create_eval_iter( 52 | args.data, args.batch_size, image_size, 53 | half_precision=args.half_precision, 54 | mean=tuple([x * 255 for x in model.default_cfg['mean']]), 55 | std=tuple([x * 255 for x in model.default_cfg['std']]), 56 | interpolation=model.default_cfg['interpolation'], 57 | ) 58 | 59 | batch_time = AverageMeter() 60 | correct_top1, correct_top5 = 0, 0 61 | total_examples = 0 62 | start_time = prev_time = time.time() 63 | for batch_index, batch in enumerate(eval_iter): 64 | images, labels = batch['image'], batch['label'] 65 | top1_count, top5_count = eval_step(images, labels) 66 | correct_top1 += int(top1_count) 67 | correct_top5 += int(top5_count) 68 | total_examples += images.shape[0] 69 | 70 | batch_time.update(time.time() - prev_time) 71 | if batch_index % 20 == 0 and batch_index > 0: 72 | print( 73 | f'Test: [{batch_index:>4d}/{num_batches}] ' 74 | f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) ' 75 | f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} ' 76 | f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}') 77 | prev_time = time.time() 78 | 79 | acc_1 = 100 * correct_top1 / total_examples 80 | acc_5 = 100 * correct_top5 / total_examples 81 | print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. ' 82 | f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}') 83 | return dict(top1=acc_1, top5=acc_5) 84 | 85 | 86 | def prepare_tf_data(xs): 87 | def _prepare(x): 88 | # Use _numpy() for zero-copy conversion between TF and NumPy. 
89 | x = x._numpy() # pylint: disable=protected-access 90 | return x 91 | return jax.tree_map(_prepare, xs) 92 | 93 | 94 | def create_eval_iter(data_dir, batch_size, image_size, dataset_name='imagenet2012:5.0.0', half_precision=False, 95 | mean=None, std=None, interpolation='bicubic'): 96 | dataset_builder = tfds.builder(dataset_name, data_dir=data_dir) 97 | assert dataset_builder.info.splits['validation'].num_examples % batch_size == 0 98 | num_batches = dataset_builder.info.splits['validation'].num_examples // batch_size 99 | ds = input_pipeline.create_split( 100 | dataset_builder, batch_size, train=False, half_precision=half_precision, 101 | image_size=image_size, mean=mean, std=std, interpolation=interpolation, no_repeat=True) 102 | it = map(prepare_tf_data, ds) 103 | return it, num_batches 104 | 105 | 106 | def eval_forward(apply_fn, variables, images, labels): 107 | logits = apply_fn(variables, images, mutable=False, training=False) 108 | top1_count, top5_count = correct_topk(logits, labels, topk=(1, 5)) 109 | return top1_count, top5_count 110 | 111 | 112 | def main(): 113 | args = parser.parse_args() 114 | print('JAX host: %d / %d' % (jax.host_id(), jax.host_count())) 115 | print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True) 116 | 117 | if get_model_cfg(args.model) is not None: 118 | validate(args) 119 | else: 120 | models = list_models(pretrained=True) 121 | if args.model != 'all': 122 | models = fnmatch.filter(models, args.model) 123 | if not models: 124 | print(f'ERROR: No models found to validate with pattern {args.model}.') 125 | exit(1) 126 | 127 | print('Validating: ', ', '.join(models)) 128 | results = [] 129 | for m in models: 130 | args.model = m 131 | res = validate(args) 132 | res.update(dict(model=m)) 133 | results.append(res) 134 | print('Results:') 135 | for r in results: 136 | print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}") 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | 
-------------------------------------------------------------------------------- /tf_objax_validate.py: -------------------------------------------------------------------------------- 1 | """ ImageNet Validation Script 2 | Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) 3 | """ 4 | import time 5 | import argparse 6 | import fnmatch 7 | 8 | import jax 9 | from absl import logging 10 | 11 | import objax 12 | import jeffnet.data.tf_imagenet_data as imagenet_data 13 | from jeffnet.common import correct_topk, AverageMeter, get_model_cfg, list_models 14 | from jeffnet.objax import create_model 15 | 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') 18 | parser.add_argument('data', metavar='DIR', help='path to dataset') 19 | parser.add_argument('--model', '-m', metavar='MODEL', default='tf_efficientnet_b0', 20 | help='model architecture (default: tf_efficientnet_b0)') 21 | parser.add_argument('-b', '--batch-size', default=250, type=int, 22 | metavar='N', help='mini-batch size (default: 256)') 23 | 24 | 25 | def validate(args): 26 | model = create_model(args.model, pretrained=True) 27 | print(f'Created {args.model} model. 
Validating...') 28 | 29 | eval_step = objax.Jit( 30 | lambda images, labels: eval_forward(model, images, labels), 31 | model.vars()) 32 | 33 | """Runs evaluation and returns top-1 accuracy.""" 34 | image_size = model.default_cfg['input_size'][-1] 35 | test_ds, num_batches = imagenet_data.load( 36 | imagenet_data.Split.TEST, 37 | is_training=False, 38 | image_size=image_size, 39 | batch_dims=[args.batch_size], 40 | chw=True, 41 | mean=tuple([x * 255 for x in model.default_cfg['mean']]), 42 | std=tuple([x * 255 for x in model.default_cfg['std']]), 43 | tfds_data_dir=args.data) 44 | 45 | batch_time = AverageMeter() 46 | correct_top1, correct_top5 = 0, 0 47 | total_examples = 0 48 | start_time = prev_time = time.time() 49 | for batch_index, batch in enumerate(test_ds): 50 | images, labels = batch['images'], batch['labels'] 51 | top1_count, top5_count = eval_step(images, labels) 52 | correct_top1 += int(top1_count) 53 | correct_top5 += int(top5_count) 54 | total_examples += images.shape[0] 55 | 56 | batch_time.update(time.time() - prev_time) 57 | if batch_index % 20 == 0 and batch_index > 0: 58 | print( 59 | f'Test: [{batch_index:>4d}/{num_batches}] ' 60 | f'Rate: {images.shape[0] / batch_time.val:>5.2f}/s ({images.shape[0] / batch_time.avg:>5.2f}/s) ' 61 | f'Acc@1: {100 * correct_top1 / total_examples:>7.3f} ' 62 | f'Acc@5: {100 * correct_top5 / total_examples:>7.3f}') 63 | prev_time = time.time() 64 | 65 | acc_1 = 100 * correct_top1 / total_examples 66 | acc_5 = 100 * correct_top5 / total_examples 67 | print(f'Validation complete. {total_examples / (prev_time - start_time):>5.2f} img/s. 
' 68 | f'Acc@1 {acc_1:>7.3f}, Acc@5 {acc_5:>7.3f}') 69 | return dict(top1=float(acc_1), top5=float(acc_5)) 70 | 71 | 72 | def eval_forward(model, images, labels): 73 | logits = model(images, training=False) 74 | top1_count, top5_count = correct_topk(logits, labels, topk=(1, 5)) 75 | return top1_count, top5_count 76 | 77 | 78 | def main(): 79 | args = parser.parse_args() 80 | logging.set_verbosity(logging.ERROR) 81 | print('JAX host: %d / %d' % (jax.host_id(), jax.host_count())) 82 | print('JAX devices:\n%s' % '\n'.join(str(d) for d in jax.devices()), flush=True) 83 | 84 | if get_model_cfg(args.model) is not None: 85 | validate(args) 86 | else: 87 | models = list_models(pretrained=True) 88 | if args.model != 'all': 89 | models = fnmatch.filter(models, args.model) 90 | if not models: 91 | print(f'ERROR: No models found to validate with pattern ({args.model}).') 92 | exit(1) 93 | 94 | print('Validating:', ', '.join(models)) 95 | results = [] 96 | for m in models: 97 | args.model = m 98 | res = validate(args) 99 | res.update(dict(model=m)) 100 | results.append(res) 101 | print('Results:') 102 | for r in results: 103 | print(f"Model: {r['model']}, Top1: {r['top1']}, Top5: {r['top5']}") 104 | 105 | 106 | if __name__ == '__main__': 107 | main() 108 | -------------------------------------------------------------------------------- /train_configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2020 The Flax Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | """Default Hyperparameter configuration.""" 29 | 30 | import ml_collections 31 | 32 | 33 | def get_config(): 34 | """Get the default hyperparameter configuration.""" 35 | config = ml_collections.ConfigDict() 36 | 37 | # base output directory for experiments (checkpoints, summaries), './output' if not valid 38 | config.output_base_dir = '' 39 | 40 | config.data_dir = '/data/' # use --config.data_dir arg to set without modifying config file 41 | config.dataset = 'imagenet2012:5.0.0' 42 | config.num_classes = 1000 # FIXME not currently used 43 | 44 | config.model = 'tf_efficientnet_b0' 45 | config.image_size = 0 # set from model defaults if 0 46 | config.batch_size = 224 47 | config.eval_batch_size = 100 # set to config.bach_size if 0 48 | config.lr = 0.016 49 | config.label_smoothing = 0.1 50 | config.weight_decay = 1e-5 # l2 weight penalty added to loss 51 | config.ema_decay = .99997 52 | 53 | #config.opt = 'adamw' 54 | #config.opt = 'lars' 55 | #config.opt_eps = 1e-6 56 | #config.opt_beta1 = 0.9 57 | #config.opt_beta2 = 0.999 58 | #config.opt_weight_decay = 0.00001 # by default, weight decay not applied in opt, l2 penalty above is used 59 | 60 | config.opt = 'rmsproptf' 61 | config.opt_eps = .001 62 | 
config.opt_momentum = 0.9 63 | config.opt_decay = 0.9 64 | config.opt_weight_decay = 0. # by default, weight decay not applied in opt, l2 penalty above is used 65 | 66 | config.lr_schedule = 'step' 67 | config.lr_decay_rate = 0.97 68 | config.lr_decay_epochs = 2.4 69 | config.lr_warmup_epochs = 5. 70 | config.lr_minimum = 1e-6 71 | config.num_epochs = 450 72 | 73 | config.autoaugment = None # 'randaugment' 74 | config.randaug_magnitude = 10 75 | config.randaug_num_layers = 2 76 | config.cache = False 77 | config.half_precision = True 78 | 79 | config.drop_rate = 0.2 80 | config.drop_path_rate = 0.2 81 | 82 | # If num_train_steps==-1 then the number of training steps is calculated from 83 | # num_epochs using the entire dataset. Similarly for steps_per_eval. 84 | config.num_train_steps = -1 85 | config.steps_per_eval = -1 86 | return config 87 | -------------------------------------------------------------------------------- /train_configs/pt_efficientnet_b3-tpu_x8.py: -------------------------------------------------------------------------------- 1 | """ EfficientNet-B3 for TPU v3-8 training 2 | """ 3 | 4 | from train_configs import default as default_lib 5 | 6 | 7 | def get_config(): 8 | config = default_lib.get_config() 9 | 10 | config.model = 'pt_efficientnet_b3' 11 | config.batch_size = 2048 12 | config.eval_batch_size = 1000 13 | config.ema_decay = .9999 14 | config.num_epochs = 550 15 | config.drop_rate = 0.3 16 | 17 | return config 18 | -------------------------------------------------------------------------------- /train_configs/tf_efficientnet_b0-gpu_24gb_x2.py: -------------------------------------------------------------------------------- 1 | """ EfficientNet-B0 for 2 x 24GB GPU training 2 | """ 3 | 4 | from train_configs import default as default_lib 5 | 6 | 7 | def get_config(): 8 | config = default_lib.get_config() 9 | 10 | config.batch_size = 500 11 | 12 | return config 13 | 
--------------------------------------------------------------------------------