├── .github └── workflows │ ├── ci.yml │ └── pypi.yml ├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── jmp ├── BUILD ├── __init__.py └── _src │ ├── BUILD │ ├── __init__.py │ ├── build_defs.bzl │ ├── loss_scale.py │ ├── loss_scale_test.py │ ├── policy.py │ └── policy_test.py ├── requirements-jax.txt ├── requirements-test.txt ├── requirements.txt └── setup.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: pytest 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | test-ubuntu: 7 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 8 | runs-on: "${{ matrix.os }}" 9 | 10 | strategy: 11 | matrix: 12 | python-version: [3.8, 3.9, '3.10', '3.11'] 13 | os: [ubuntu-latest] 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v1 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install -r requirements.txt 25 | pip install -r requirements-jax.txt 26 | pip install -r requirements-test.txt 27 | pip install . 28 | - name: Test with pytest 29 | run: | 30 | pip install pytest 31 | pytest jmp 32 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | *.ipynb 3 | *.pyc 4 | .DS_Store 5 | .ipynb_checkpoints/ 6 | .mypy_cache/ 7 | .pytype/ 8 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## How to become a contributor and submit your own code 4 | 5 | ### Contributor License Agreements 6 | 7 | We'd love to accept your patches! Before we can take them, we have to jump a 8 | couple of legal hurdles. 9 | 10 | Please fill out either the individual or corporate Contributor License Agreement 11 | (CLA). 12 | 13 | * If you are an individual writing original source code and you're sure you 14 | own the intellectual property, then you'll need to sign an [individual 15 | CLA](http://code.google.com/legal/individual-cla-v1.0.html). 16 | * If you work for a company that wants to allow you to contribute your work, 17 | then you'll need to sign a [corporate 18 | CLA](http://code.google.com/legal/corporate-cla-v1.0.html). 19 | 20 | Follow either of the two links above to access the appropriate CLA and 21 | instructions for how to sign and return it. Once we receive it, we'll be able to 22 | accept your pull requests. 
23 | 24 | ***NOTE***: Only original source code from you and other people that have signed 25 | the CLA can be accepted into the main repository. 26 | 27 | ### Contributing code 28 | 29 | If you have improvements to JMP, send us your pull requests! For those just 30 | getting started, Github has a 31 | [howto](https://help.github.com/articles/using-pull-requests/). 32 | 33 | If you want to contribute but you're not sure where to start, take a look at the 34 | [issues with the "contributions welcome" 35 | label](https://github.com/deepmind/jmp/labels/stat%3Acontributions%20welcome). 36 | These are issues that we believe are particularly well suited for outside 37 | contributions, often because we probably won't get to them right now. If you 38 | decide to start on an issue, leave a comment so that other people know that 39 | you're working on it. If you want to help out, but not alone, use the issue 40 | comment thread to coordinate. 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements.txt 4 | include requirements-jax.txt 5 | include requirements-test.txt 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixed precision training in [JAX] 2 | 3 | ![Test status](https://github.com/deepmind/jmp/workflows/pytest/badge.svg) 4 | ![PyPI version](https://img.shields.io/pypi/v/jmp) 5 | 6 | [**Installation**](#installation) 7 | | [**Examples**](#examples) 8 | | [**Policies**](#policies) 9 | | [**Loss scaling**](#loss-scaling) 10 | | [**Citing JMP**](#citing-jmp) 11 | | [**References**](#references) 12 | 13 | Mixed precision training [[0]] is a technique that mixes the use of full and 14 | half precision floating point numbers during training to reduce the memory 15 | bandwidth requirements and improve the computational efficiency of a given 16 | model. 17 | 18 | This library implements support for mixed precision training in [JAX] by providing 19 | two key abstractions (mixed precision "policies" and loss scaling). Neural 20 | network libraries (such as [Haiku]) can integrate with `jmp` and provide 21 | "Automatic Mixed Precision (AMP)" support (automating or simplifying applying 22 | policies to modules). 23 | 24 | All code examples below assume the following: 25 | 26 | ```python 27 | import jax 28 | import jax.numpy as jnp 29 | import jmp 30 | 31 | half = jnp.float16 # On TPU this should be jnp.bfloat16. 32 | full = jnp.float32 33 | ``` 34 | 35 | ## Installation 36 | 37 | JMP is written in pure Python, but depends on C++ code via JAX and NumPy. 38 | 39 | Because JAX installation is different depending on your CUDA version, JMP does 40 | not list JAX as a dependency in `requirements.txt`. 41 | 42 | First, follow [these instructions](https://github.com/google/jax#installation) 43 | to install JAX with the relevant accelerator support. 44 | 45 | Then, install JMP using pip: 46 | 47 | ```bash 48 | $ pip install git+https://github.com/deepmind/jmp 49 | ``` 50 | 51 | ## Examples 52 | 53 | You can find a 54 | [fully worked JMP example in Haiku](https://github.com/deepmind/dm-haiku/tree/master/examples/imagenet) 55 | which shows how to use mixed f32/f16 precision to halve training time on GPU and 56 | mixed f32/bf16 to reduce training time on TPU by a third. 57 | 58 | ## Policies 59 | 60 | A mixed precision policy encapsulates the configuration in a mixed precision 61 | experiment. 62 | 63 | ```python 64 | # Our policy specifies that we will store parameters in full precision but will 65 | # compute and return output in half precision. 
66 | my_policy = jmp.Policy(compute_dtype=half,
67 |                        param_dtype=full,
68 |                        output_dtype=half)
69 | ```
70 | 
71 | The policy object can be used to cast pytrees:
72 | 
73 | ```python
74 | def layer(params, x):
75 |   params, x = my_policy.cast_to_compute((params, x))
76 |   w, b = params
77 |   y = x @ w + b
78 |   return my_policy.cast_to_output(y)
79 | 
80 | params = (jnp.ones([3, 4], full), jnp.zeros([4], full))  # (w, b) in the param dtype.
81 | y = layer(params, jnp.ones([2, 3], full))
82 | assert y.dtype == half
83 | ```
84 | 
85 | You can replace the output type of a given policy:
86 | 
87 | ```python
88 | my_policy = my_policy.with_output_dtype(full)
89 | ```
90 | 
91 | You can also define a policy via a string, which may be useful for specifying a
92 | policy as a command-line argument or as a hyperparameter to your experiment:
93 | 
94 | ```python
95 | my_policy = jmp.get_policy("params=float32,compute=float16,output=float32")
96 | float16 = jmp.get_policy("float16")  # Everything in f16.
97 | half = jmp.get_policy("half")        # Everything in half (f16 or bf16).
98 | ```
99 | 
100 | ## Loss scaling
101 | 
102 | When training with reduced precision, consider whether gradients will need to be
103 | shifted into the representable range of the format that you are using. This is
104 | particularly important when training with `float16` and less important for
105 | `bfloat16`. See the NVIDIA mixed precision user guide [[1]] for more details.
106 | 
107 | The easiest way to shift gradients is with loss scaling, which scales your loss
108 | and gradients by `S` and `1/S` respectively.
109 | 
110 | ```python
111 | def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
112 |   loss = ...
113 |   # You should apply regularization etc before scaling.
114 |   loss = loss_scale.scale(loss)
115 |   return loss
116 | 
117 | def train_step(params, loss_scale: jmp.LossScale, ...):
118 |   grads = jax.grad(my_loss_fn)(...)
119 |   grads = loss_scale.unscale(grads)
120 |   # You should put gradient clipping etc after unscaling.
121 |   params = apply_optimizer(params, grads)
122 |   return params
123 | 
124 | loss_scale = jmp.StaticLossScale(2 ** 15)
125 | for _ in range(num_steps):
126 |   params = train_step(params, loss_scale, ...)
127 | ```
128 | 
129 | The appropriate value for `S` depends on your model, loss, batch size and
130 | potentially other factors. You can determine this with trial and error. As a
131 | rule of thumb, you want the largest value of `S` that does not introduce overflow
132 | during backprop. NVIDIA [[1]] recommend computing statistics about the gradients
133 | of your model (in full precision) and picking `S` such that its product with the
134 | maximum norm of your gradients is below `65,504`.
135 | 
136 | We provide a dynamic loss scale, which adjusts the loss scale periodically
137 | during training to find the largest value for `S` that produces finite
138 | gradients. This is more convenient and robust compared with picking a static
139 | loss scale, but has a small performance impact (between 1 and 5%).
140 | 
141 | ```python
142 | def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
143 |   loss = ...
144 |   # You should apply regularization etc before scaling.
145 |   loss = loss_scale.scale(loss)
146 |   return loss
147 | 
148 | def train_step(params, loss_scale: jmp.LossScale, ...):
149 |   grads = jax.grad(my_loss_fn)(...)
150 |   grads = loss_scale.unscale(grads)
151 |   # You should put gradient clipping etc after unscaling.
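  # For example, a minimal global-norm clipping sketch in plain JAX (shown
  # commented out; `max_norm` here is an assumed hyperparameter, not part of
  # the JMP API):
  #   g_norm = jnp.sqrt(sum(jnp.sum(jnp.square(g))
  #                         for g in jax.tree_util.tree_leaves(grads)))
  #   grads = jax.tree_util.tree_map(
  #       lambda g: g * jnp.minimum(1., max_norm / (g_norm + 1e-6)), grads)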
152 | 153 | # You definitely want to skip non-finite updates with the dynamic loss scale, 154 | # but you might also want to consider skipping them when using a static loss 155 | # scale if you experience NaN's when training. 156 | skip_nonfinite_updates = isinstance(loss_scale, jmp.DynamicLossScale) 157 | 158 | if skip_nonfinite_updates: 159 | grads_finite = jmp.all_finite(grads) 160 | # Adjust our loss scale depending on whether gradients were finite. The 161 | # loss scale will be periodically increased if gradients remain finite and 162 | # will be decreased if not. 163 | loss_scale = loss_scale.adjust(grads_finite) 164 | # Only apply our optimizer if grads are finite, if any element of any 165 | # gradient is non-finite the whole update is discarded. 166 | params = jmp.select_tree(grads_finite, apply_optimizer(params, grads), params) 167 | else: 168 | # With static or no loss scaling just apply our optimizer. 169 | params = apply_optimizer(params, grads) 170 | 171 | # Since our loss scale is dynamic we need to return the new value from 172 | # each step. All loss scales are `PyTree`s. 173 | return params, loss_scale 174 | 175 | loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15)) 176 | for _ in range(num_steps): 177 | params, loss_scale = train_step(params, loss_scale, ...) 178 | ``` 179 | 180 | In general using a static loss scale should offer the best speed, but we have 181 | optimized dynamic loss scaling to make it competitive. We recommend you start 182 | with dynamic loss scaling and move to static loss scaling if performance is an 183 | issue. 184 | 185 | We finally offer a no-op loss scale which you can use as a drop in replacement. 186 | It does nothing (apart from implement the `jmp.LossScale` API): 187 | 188 | ```python 189 | loss_scale = jmp.NoOpLossScale() 190 | assert loss is loss_scale.scale(loss) 191 | assert grads is loss_scale.unscale(grads) 192 | assert loss_scale is loss_scale.adjust(grads_finite) 193 | assert loss_scale.loss_scale == 1 194 | ``` 195 | 196 | ## Citing JMP 197 | 198 | This repository is part of the [DeepMind JAX Ecosystem](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research), 199 | to cite JMP please use the [DeepMind JAX Ecosystem citation](https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt). 200 | 201 | ## References 202 | 203 | [[0]] Paulius Micikevicius, Sharan Narang, Jonah Alben, Gregory Diamos, Erich 204 | Elsen, David Garcia, Boris Ginsburg, Michael Houston, Oleksii Kuchaiev, Ganesh 205 | Venkatesh, Hao Wu: "Mixed Precision Training", 2017; arXiv:1710.03740 206 | https://arxiv.org/abs/1710.03740. 207 | 208 | [[1]] "Training With Mixed Precision :: NVIDIA Deep Learning Performance 209 | Documentation". Docs.Nvidia.Com, 2020, 210 | https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/. 211 | 212 | [0]: https://arxiv.org/abs/1710.03740 213 | [1]: https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html 214 | [Haiku]: https://github.com/deepmind/dm-haiku 215 | [JAX]: https://github.com/google/jax 216 | -------------------------------------------------------------------------------- /jmp/BUILD: -------------------------------------------------------------------------------- 1 | load("//third_party/bazel_rules/rules_python/python:py_library.bzl", "py_library") 2 | load("//tools/build_defs/license:license.bzl", "license") 3 | 4 | # Description: JMP is a JAX Mixed Precision library. 
5 | package( 6 | default_applicable_licenses = [":license"], 7 | default_visibility = ["//visibility:private"], 8 | ) 9 | 10 | license( 11 | name = "license", 12 | package_name = "jmp", 13 | ) 14 | 15 | licenses(["notice"]) 16 | 17 | exports_files(["LICENSE"]) 18 | 19 | py_library( 20 | name = "jmp", 21 | srcs = ["__init__.py"], 22 | visibility = ["//visibility:public"], 23 | deps = [ 24 | "//jmp/_src:loss_scale", 25 | "//jmp/_src:policy", 26 | ], 27 | ) 28 | -------------------------------------------------------------------------------- /jmp/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """JMP is a Mixed Precision library for JAX.""" 16 | 17 | from jmp._src.loss_scale import all_finite 18 | from jmp._src.loss_scale import DynamicLossScale 19 | from jmp._src.loss_scale import LossScale 20 | from jmp._src.loss_scale import NoOpLossScale 21 | from jmp._src.loss_scale import select_tree 22 | from jmp._src.loss_scale import StaticLossScale 23 | from jmp._src.policy import cast_to_full 24 | from jmp._src.policy import cast_to_half 25 | from jmp._src.policy import get_policy 26 | from jmp._src.policy import half_dtype 27 | from jmp._src.policy import Policy 28 | 29 | __version__ = "0.0.5.dev" 30 | 31 | __all__ = ( 32 | "all_finite", 33 | "DynamicLossScale", 34 | "LossScale", 35 | "NoOpLossScale", 36 | "select_tree", 37 | "StaticLossScale", 38 | "cast_to_full", 39 | "cast_to_half", 40 | "get_policy", 41 | "half_dtype", 42 | "Policy", 43 | ) 44 | 45 | # _________________________________________ 46 | # / Please don't use symbols in `_src` they \ 47 | # \ are not part of the JMP public API. 
/ 48 | # ----------------------------------------- 49 | # \ ^__^ 50 | # \ (oo)\_______ 51 | # (__)\ )\/\ 52 | # ||----w | 53 | # || || 54 | # 55 | try: 56 | del _src # pylint: disable=undefined-variable 57 | except NameError: 58 | pass 59 | -------------------------------------------------------------------------------- /jmp/_src/BUILD: -------------------------------------------------------------------------------- 1 | load("//jmp/_src:build_defs.bzl", "jmp_py_library", "jmp_py_test") 2 | 3 | package( 4 | default_applicable_licenses = ["//jmp:license"], 5 | default_visibility = ["//jmp:__subpackages__"], 6 | ) 7 | 8 | licenses(["notice"]) 9 | 10 | jmp_py_library( 11 | name = "loss_scale", 12 | srcs = ["loss_scale.py"], 13 | deps = [ 14 | # pip: jax 15 | # pip: numpy 16 | ], 17 | ) 18 | 19 | jmp_py_test( 20 | name = "loss_scale_test", 21 | srcs = ["loss_scale_test.py"], 22 | deps = [ 23 | ":loss_scale", 24 | # pip: absl/testing:absltest 25 | # pip: absl/testing:parameterized 26 | # pip: jax 27 | # pip: numpy 28 | ], 29 | ) 30 | 31 | jmp_py_library( 32 | name = "policy", 33 | srcs = ["policy.py"], 34 | deps = [ 35 | # pip: jax 36 | # pip: numpy 37 | ], 38 | ) 39 | 40 | jmp_py_test( 41 | name = "policy_test", 42 | srcs = ["policy_test.py"], 43 | deps = [ 44 | ":policy", 45 | # pip: absl/testing:absltest 46 | # pip: absl/testing:parameterized 47 | # pip: jax 48 | # pip: numpy 49 | ], 50 | ) 51 | -------------------------------------------------------------------------------- /jmp/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /jmp/_src/build_defs.bzl: -------------------------------------------------------------------------------- 1 | """Build rules for JMP.""" 2 | 3 | # NOTE: Internally we swap these out for macros testing on various 4 | # HW platforms. 5 | jmp_py_binary = native.py_binary 6 | jmp_py_test = native.py_test 7 | jmp_py_library = native.py_library 8 | -------------------------------------------------------------------------------- /jmp/_src/loss_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utilities for loss scaling.""" 16 | 17 | import dataclasses 18 | import functools 19 | from typing import Tuple, TypeVar, Union 20 | import warnings 21 | 22 | import jax 23 | from jax import tree_util 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | # from deepmind.internal import usage_logging 28 | 29 | T = TypeVar("T") 30 | 31 | 32 | def register_empty_pytree(cls): 33 | tree_util.register_pytree_node(cls, lambda x: ((), x), lambda x, _: x) 34 | 35 | 36 | @dataclasses.dataclass(frozen=True) 37 | class NoOpLossScale: 38 | """No-op loss scale does nothing.""" 39 | 40 | @property 41 | def loss_scale(self): 42 | return 1 43 | 44 | def scale(self, tree: T) -> T: 45 | # usage_logging.log_event(usage_logging.Event.JMP, "NoOpLossScale") 46 | return tree 47 | 48 | def unscale(self, tree: T) -> T: 49 | return tree 50 | 51 | def adjust(self, grads_finite: jnp.ndarray): 52 | del grads_finite 53 | return self 54 | 55 | 56 | @dataclasses.dataclass(frozen=True) 57 | class StaticLossScale: 58 | """Scales and unscales by a fixed constant.""" 59 | 60 | loss_scale: jnp.ndarray 61 | 62 | def scale(self, tree: T) -> T: 63 | # usage_logging.log_event(usage_logging.Event.JMP, "StaticLossScale") 64 | return jax.tree_util.tree_map(lambda x: x * self.loss_scale, tree) 65 | 66 | def unscale(self, tree: T) -> T: 67 | inv_loss_scale = 1 / self.loss_scale 68 | return jax.tree_util.tree_map(lambda x: x * inv_loss_scale, tree) 69 | 70 | def adjust(self, grads_finite: jnp.ndarray): 71 | del grads_finite 72 | return self 73 | 74 | _Data = Tuple[jnp.ndarray, ...] 75 | _Meta = Tuple[int, int] 76 | 77 | 78 | @dataclasses.dataclass(frozen=True) 79 | class DynamicLossScale: 80 | """Dynamic loss scale. 81 | 82 | Dynamic loss scaling tries to determine the largest loss scale value that 83 | will keep gradients finite. It does this by increasing the loss scale every 84 | `period` steps by `factor` if the grads remain finite, otherwise it reduces 85 | the loss scale by `1 / factor` and resets the counter. 86 | 87 | loss_scale = 2 ** 15 88 | counter = 0 89 | period = 2000 90 | factor = 2 91 | 92 | for step in range(num_steps): 93 | loss *= loss_scale 94 | grads /= loss_scale 95 | grads_finite = all_finite(grads) 96 | 97 | if grads_finite: 98 | counter += 1 99 | if counter == period: 100 | counter = 0 101 | loss_scale = first_finite(loss_scale * factor, loss_scale) 102 | else: 103 | counter = 0 104 | loss_scale = max(1, loss_scale / factor) 105 | 106 | Typical usage of this class will be something like: 107 | 108 | >>> loss_scale = jmp.DynamicLossScale(jnp.asarray(2. ** 15)) 109 | >>> for _ in range(num_steps): 110 | ... # compute loss 111 | ... loss = loss_scale.scale(loss) 112 | ... # compute grads 113 | ... grads = loss_scale.unscale(grads) 114 | ... grads_finite = jmp.all_finite(grads) 115 | ... loss_scale = loss_scale.adjust(grads_finite) 116 | ... 
# conditionally update params using grads 117 | """ 118 | loss_scale: jnp.ndarray 119 | counter: jnp.ndarray = dataclasses.field( 120 | default_factory=lambda: np.zeros([], np.int32)) 121 | period: int = 2000 122 | factor: int = 2 123 | min_loss_scale: jnp.ndarray = dataclasses.field( 124 | default_factory=lambda: np.ones([], np.float32)) 125 | 126 | def __post_init__(self) -> None: 127 | warn_if_not_floating(self.loss_scale, "loss_scale") 128 | warn_if_not_floating(self.min_loss_scale, "min_loss_scale") 129 | 130 | def scale(self, tree: T) -> T: 131 | # usage_logging.log_event(usage_logging.Event.JMP, "DynamicLossScale") 132 | return jax.tree_util.tree_map(lambda x: x * self.loss_scale, tree) 133 | 134 | def unscale(self, tree: T) -> T: 135 | inv_loss_scale = 1 / self.loss_scale 136 | return jax.tree_util.tree_map(lambda x: x * inv_loss_scale, tree) 137 | 138 | def tree_flatten(self) -> Tuple[_Data, _Meta]: 139 | data = (self.loss_scale, self.counter) 140 | meta = (self.period, self.factor) 141 | return data, meta 142 | 143 | @classmethod 144 | def tree_unflatten(cls, meta: _Meta, data: _Data) -> "DynamicLossScale": 145 | loss_scale, counter = data 146 | period, factor = meta 147 | return cls(loss_scale, counter, period, factor) 148 | 149 | def adjust(self, grads_finite: jnp.ndarray) -> "DynamicLossScale": 150 | """Returns the next state dependent on whether grads are finite.""" 151 | assert grads_finite.ndim == 0, "Expected boolean scalar" 152 | 153 | first_finite = lambda a, b: jax.lax.select(jnp.isfinite(a).all(), a, b) 154 | loss_scale = jax.lax.select( 155 | grads_finite, 156 | 157 | # When grads are finite increase loss scale periodically. 158 | jax.lax.select( 159 | self.counter == (self.period - 1), 160 | first_finite(self.loss_scale * self.factor, 161 | self.loss_scale), 162 | self.loss_scale), 163 | 164 | # If grads are non finite reduce loss scale. 165 | jnp.maximum(self.min_loss_scale, self.loss_scale / self.factor)) 166 | 167 | counter = ((self.counter + 1) % self.period) * grads_finite 168 | 169 | return DynamicLossScale( 170 | loss_scale=loss_scale, 171 | counter=counter, 172 | period=self.period, 173 | factor=self.factor, 174 | min_loss_scale=self.min_loss_scale) 175 | 176 | 177 | register_empty_pytree(NoOpLossScale) 178 | register_empty_pytree(StaticLossScale) 179 | tree_util.register_pytree_node_class(DynamicLossScale) 180 | 181 | LossScale = Union[NoOpLossScale, StaticLossScale, DynamicLossScale] 182 | 183 | 184 | def all_finite(tree) -> jnp.ndarray: 185 | """Returns a scalar ndarray indicating whether the input arrays are finite.""" 186 | leaves = jax.tree_util.tree_leaves(tree) 187 | if not leaves: 188 | return jnp.array(True) 189 | else: 190 | leaves = map(jnp.isfinite, leaves) 191 | leaves = map(jnp.all, leaves) 192 | return jnp.stack(list(leaves)).all() 193 | 194 | 195 | def select_tree(pred: jnp.ndarray, a: T, b: T) -> T: 196 | """Selects a pytree based on the given predicate.""" 197 | assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar" 198 | return jax.tree_util.tree_map(functools.partial(jax.lax.select, pred), a, b) 199 | 200 | 201 | def warn_if_not_floating(x: Union[jnp.ndarray, object], var_name: str) -> None: 202 | """Produces a warning if the given array does not have a floating type. 203 | 204 | This function handles an edgecase where Jax passes in an `object()` to 205 | determine the structure of user defined pytrees during compilation. They 206 | recommend explicitly checking if the array in question has the type `object`. 
207 | 208 | From the Jax documentation: "The __init__ and __new__ methods of custom 209 | PyTree classes should generally avoid doing any array conversion or other 210 | input validation, or else anticipate and handle these special cases." 211 | 212 | See: 213 | https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization 214 | 215 | Args: 216 | x: Any object. 217 | var_name: A useful name to put in error messages. 218 | """ 219 | if type(x) is object: # pylint: disable=unidiomatic-typecheck 220 | return 221 | x_dtype = jax.eval_shape(lambda: x).dtype 222 | if not jnp.issubdtype(x_dtype, jnp.floating): 223 | warnings.warn(f"Expected floating type for {var_name}, got {x_dtype}") 224 | -------------------------------------------------------------------------------- /jmp/_src/loss_scale_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for jmp._src.loss_scale.""" 16 | 17 | import warnings 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | import jax 22 | import jax.numpy as jnp 23 | from jmp._src import loss_scale as jmp 24 | import numpy as np 25 | 26 | 27 | class LossScaleTest(parameterized.TestCase): 28 | 29 | def test_no_op_loss_scale(self): 30 | loss_scale = jmp.NoOpLossScale() 31 | tree = {"a": jnp.ones([])} 32 | self.assertIs(loss_scale.scale(tree), tree) 33 | self.assertIs(loss_scale.unscale(tree), tree) 34 | 35 | @parameterized.named_parameters( 36 | ("StaticLossScale(2)", jmp.StaticLossScale, 2), 37 | ("StaticLossScale(3)", jmp.StaticLossScale, 3), 38 | ("StaticLossScale(4)", jmp.StaticLossScale, 4), 39 | ("DynamicLossScale(2)", jmp.DynamicLossScale, 2.), 40 | ("DynamicLossScale(3)", jmp.DynamicLossScale, 3.), 41 | ("DynamicLossScale(4)", jmp.DynamicLossScale, 4.), 42 | ) 43 | def test_static_loss_scale(self, cls, scale): 44 | loss_scale = cls(scale) 45 | tree = {"a": jnp.array(1.)} 46 | scaled_tree = {"a": jnp.array(1. * scale)} 47 | self.assertEqual(loss_scale.scale(tree), scaled_tree) 48 | self.assertEqual(loss_scale.unscale(scaled_tree), tree) 49 | 50 | @parameterized.named_parameters( 51 | ("NoOpLossScale", jmp.NoOpLossScale), 52 | ("StaticLossScale", lambda: jmp.StaticLossScale(0)), # pytype: disable=wrong-arg-types # jax-ndarray 53 | ) 54 | def test_static_empty_trees(self, create): 55 | loss_scale = create() 56 | self.assertEmpty(jax.tree_util.tree_leaves(loss_scale)) 57 | 58 | def test_dynamic_loss_scale_no_warnings(self): 59 | with warnings.catch_warnings(record=True) as logged_warnings: 60 | jmp.DynamicLossScale(2. 
** 15) # pytype: disable=wrong-arg-types # jax-ndarray 61 | self.assertEmpty(logged_warnings) 62 | 63 | def test_dynamic_loss_scale_tree(self): 64 | scale = jnp.ones([]) 65 | counter = jnp.zeros([], jnp.int32) 66 | period = 2000 67 | factor = 2 68 | loss_scale = jmp.DynamicLossScale(scale, counter, period, factor) 69 | self.assertEqual(jax.tree_util.tree_leaves(loss_scale), [scale, counter]) 70 | self.assertEqual(jax.tree_util.tree_map(lambda x: x, loss_scale), 71 | loss_scale) 72 | 73 | @parameterized.parameters((20, 2), (30, 3)) 74 | def test_dynamic_loss_scale_adjust_increases_on_finite(self, period, factor): 75 | grads_finite = jnp.bool_(True) 76 | loss_scale = jmp.DynamicLossScale(jnp.float32(10), jnp.int32(0), 77 | period, factor) 78 | for i in range(1, period): 79 | loss_scale = loss_scale.adjust(grads_finite) 80 | self.assertEqual(loss_scale.loss_scale, 10) 81 | self.assertEqual(loss_scale.counter, i) 82 | self.assertEqual(loss_scale.period, period) 83 | self.assertEqual(loss_scale.factor, factor) 84 | 85 | # Loss scale should wrap. 86 | loss_scale = loss_scale.adjust(grads_finite) 87 | self.assertEqual(loss_scale.loss_scale, 10 * factor) 88 | self.assertEqual(loss_scale.counter, 0) 89 | self.assertEqual(loss_scale.period, period) 90 | self.assertEqual(loss_scale.factor, factor) 91 | 92 | @parameterized.parameters((20, 2), (30, 3)) 93 | def test_dynamic_loss_scale_adjust_reduce_on_non_finite(self, period, factor): 94 | grads_finite = jnp.bool_(False) 95 | init = np.float32(10) 96 | loss_scale = jmp.DynamicLossScale(jnp.asarray(init), jnp.int32(0), period, 97 | factor) 98 | self.assertLess(init / (factor ** 100), 1, msg="should cover max(1, S)") 99 | for i in range(100): 100 | loss_scale = loss_scale.adjust(grads_finite) 101 | np.testing.assert_allclose(loss_scale.loss_scale, 102 | max(1, init / (factor ** (i + 1))), 103 | rtol=1e-5) 104 | self.assertEqual(loss_scale.counter, 0) 105 | self.assertEqual(loss_scale.period, period) 106 | self.assertEqual(loss_scale.factor, factor) 107 | 108 | @parameterized.parameters((20, 2, .3125), (30, 3, .37), (5., 2., 0.)) 109 | def test_dynamic_loss_scale_explicit_min_loss_scale(self, period, factor, 110 | min_loss_scale): 111 | grads_finite = jnp.bool_(False) 112 | init = np.float32(10) 113 | loss_scale = jmp.DynamicLossScale( 114 | jnp.asarray(init), jnp.int32(0), period, factor, 115 | jnp.asarray(min_loss_scale)) 116 | self.assertLess(init / (factor**100), 1, msg="should cover max(1, S)") 117 | for i in range(100): 118 | loss_scale = loss_scale.adjust(grads_finite) 119 | np.testing.assert_allclose( 120 | loss_scale.loss_scale, 121 | max(min_loss_scale, init / (factor**(i + 1))), 122 | rtol=1e-5) 123 | self.assertEqual(loss_scale.counter, 0) 124 | self.assertEqual(loss_scale.period, period) 125 | self.assertEqual(loss_scale.factor, factor) 126 | 127 | def test_dynamic_loss_scale_adjust_requires_scalar_input(self): 128 | pass 129 | 130 | def test_dynamic_loss_scale_raises_type_error_on_int_loss_scale(self): 131 | expected_message = "Expected floating type for loss_scale" 132 | with self.assertWarnsRegex(Warning, expected_message): 133 | jmp.DynamicLossScale(jnp.asarray(1, dtype=jnp.int32)) 134 | 135 | def test_dynamic_loss_scale_raises_type_error_on_int_min_loss_scale(self): 136 | expected_message = "Expected floating type for min_loss_scale" 137 | with self.assertWarnsRegex(Warning, expected_message): 138 | jmp.DynamicLossScale(jnp.asarray(1, dtype=jnp.float32), 139 | min_loss_scale=jnp.asarray(1, dtype=jnp.int32)) 140 | 141 | 
@parameterized.parameters(jnp.inf, jnp.nan) 142 | def test_all_finite(self, non_finite): 143 | self.assertTrue(jmp.all_finite(None)) 144 | self.assertTrue(jmp.all_finite({})) 145 | self.assertFalse(jmp.all_finite({"a": jnp.array(non_finite)})) 146 | self.assertFalse(jmp.all_finite({"a": jnp.ones([]), 147 | "b": jnp.array(non_finite)})) 148 | self.assertFalse(jmp.all_finite({"a": jnp.array(non_finite), 149 | "b": jnp.ones([])})) 150 | self.assertTrue(jmp.all_finite({"a": jnp.ones([]), "b": jnp.ones([])})) 151 | 152 | def test_select_tree(self): 153 | a = {"a": jnp.ones([]), "b": jnp.zeros([])} 154 | b = {"a": jnp.zeros([]), "b": jnp.ones([])} 155 | self.assertIsNone(jmp.select_tree(jnp.bool_(True), None, None)) 156 | self.assertIsNone(jmp.select_tree(jnp.bool_(False), None, None)) 157 | self.assertEqual(jmp.select_tree(jnp.bool_(True), a, b), a) 158 | self.assertEqual(jmp.select_tree(jnp.bool_(False), a, b), b) 159 | 160 | def test_select_tree_rejects_non_scalar(self): 161 | with self.assertRaisesRegex(AssertionError, "expected boolean scalar"): 162 | jmp.select_tree(jnp.ones([1]), None, None) 163 | 164 | if __name__ == "__main__": 165 | absltest.main() 166 | -------------------------------------------------------------------------------- /jmp/_src/policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Utilities for mixed precision in JAX.""" 16 | 17 | import dataclasses 18 | from typing import TypeVar 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import numpy as np 23 | 24 | T = TypeVar("T") 25 | 26 | 27 | def _cast_floating_to(tree: T, dtype: jnp.dtype) -> T: 28 | def conditional_cast(x): 29 | if (isinstance(x, (np.ndarray, jnp.ndarray)) and 30 | jnp.issubdtype(x.dtype, jnp.floating)): 31 | x = x.astype(dtype) 32 | return x 33 | return jax.tree_util.tree_map(conditional_cast, tree) 34 | 35 | 36 | @dataclasses.dataclass(frozen=True) 37 | class Policy: 38 | """Encapsulates casting for inputs, outputs and parameters.""" 39 | param_dtype: jnp.dtype 40 | compute_dtype: jnp.dtype 41 | output_dtype: jnp.dtype 42 | 43 | def cast_to_param(self, x): 44 | """Converts floating point values to the param dtype.""" 45 | return _cast_floating_to(x, self.param_dtype) 46 | 47 | def cast_to_compute(self, x): 48 | """Converts floating point values to the compute dtype.""" 49 | return _cast_floating_to(x, self.compute_dtype) 50 | 51 | def cast_to_output(self, x): 52 | """Converts floating point values to the output dtype.""" 53 | return _cast_floating_to(x, self.output_dtype) 54 | 55 | def with_output_dtype(self, output_dtype: jnp.dtype) -> "Policy": 56 | return dataclasses.replace(self, output_dtype=output_dtype) 57 | 58 | def __str__(self): 59 | return "p={},c={},o={}".format(dtype_to_names[self.param_dtype][0], 60 | dtype_to_names[self.compute_dtype][0], 61 | dtype_to_names[self.output_dtype][0]) 62 | 63 | 64 | def get_policy(policy_name: str) -> Policy: 65 | """Returns a mixed precision policy parsed from a string.""" 66 | # Loose grammar supporting: 67 | # - "c=f16" (params full, compute+output in f16), 68 | # - "p=f16,c=f16" (params, compute and output in f16). 69 | # - "p=f16,c=bf16" (params in f16, compute in bf16, output in bf16) 70 | # For values that are not specified params defaults to f32, compute follows 71 | # params and output follows compute (e.g. 'c=f16' -> 'p=f32,c=f16,o=f16'). 72 | param_dtype = jnp.float32 73 | compute_dtype = output_dtype = None 74 | if "=" in policy_name: 75 | for part in policy_name.split(","): 76 | key, value = part.split("=", 2) 77 | value = parse_dtype(value) 78 | if key == "p" or key == "params": 79 | param_dtype = value 80 | elif key == "c" or key == "compute": 81 | compute_dtype = value 82 | elif key == "o" or key == "output": 83 | output_dtype = value 84 | else: 85 | raise ValueError(f"Unknown key '{key}' in '{policy_name}' should be " 86 | "'params', 'compute' or 'output'.") 87 | if compute_dtype is None: 88 | compute_dtype = param_dtype 89 | if output_dtype is None: 90 | output_dtype = compute_dtype 91 | else: 92 | # Assume policy name is a dtype (e.g. 'f32' or 'half') that all components 93 | # of the policy should contain. 
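    # For example, get_policy("half") maps params, compute and output to
    # `half_dtype()`, which resolves to bfloat16 on TPU and float16 elsewhere.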
94 | param_dtype = compute_dtype = output_dtype = parse_dtype(policy_name) 95 | 96 | return Policy(param_dtype=param_dtype, compute_dtype=compute_dtype, 97 | output_dtype=output_dtype) 98 | 99 | 100 | def cast_to_full(tree: T) -> T: 101 | """Ensures floating point leaves of the given tree are f32.""" 102 | return _cast_floating_to(tree, jnp.float32) 103 | 104 | 105 | def cast_to_half(tree: T) -> T: 106 | """Ensures floating point leaves of the given tree are half precision.""" 107 | return _cast_floating_to(tree, half_dtype()) 108 | 109 | 110 | def half_dtype() -> jnp.dtype: 111 | """Returns the half precision dtype for the current backend.""" 112 | device0 = jax.local_devices()[0] 113 | on_tpu = device0.platform == "tpu" 114 | return jnp.bfloat16 if on_tpu else jnp.float16 115 | 116 | 117 | dtype_to_names = { 118 | jnp.bfloat16: ("bf16", "bfloat16"), 119 | jnp.float16: ("f16", "float16"), 120 | jnp.float32: ("full", "f32", "float32"), 121 | jnp.float64: ("f64", "float64"), 122 | } 123 | 124 | name_to_dtype = {name: dtype for dtype, names in dtype_to_names.items() # pylint: disable=g-complex-comprehension 125 | for name in names} 126 | 127 | 128 | def parse_dtype(value: str) -> jnp.dtype: 129 | """Parses a string representing a dtype into a dtype object.""" 130 | if value == "half": 131 | return half_dtype() 132 | 133 | try: 134 | return name_to_dtype[value] 135 | except KeyError as e: 136 | raise ValueError( 137 | f"Unknown dtype '{value}' must be full,half,float16,bfloat16 or a " 138 | "contraction thereof (e.g. 'f' for 'full', 'bf16' for 'bfloat16')" 139 | ) from e 140 | -------------------------------------------------------------------------------- /jmp/_src/policy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for jmp._src.policy.""" 16 | 17 | import itertools as it 18 | import unittest 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | import jax 23 | import jax.numpy as jnp 24 | from jmp._src import policy as jmp 25 | import numpy as np 26 | 27 | HALF_DTYPES = (np.float16, jnp.float16, jnp.bfloat16) 28 | FULL_DTYPES = (np.float32, jnp.float32) 29 | DTYPES = HALF_DTYPES + FULL_DTYPES 30 | NUMPYS = (np, jnp) 31 | 32 | 33 | def get_dtype_name(dtype): 34 | names = { 35 | np.float16: "float16", 36 | jnp.bfloat16: "bfloat16", 37 | np.float32: "float32" 38 | } 39 | return names[dtype] 40 | 41 | 42 | def current_platform(): 43 | return jax.local_devices()[0].platform 44 | 45 | 46 | def skip_if_unsupported(dtype): 47 | platform = current_platform() 48 | if ((platform == "gpu" and dtype == jnp.bfloat16) or 49 | (platform == "tpu" and dtype in (np.float16, jnp.float16))): 50 | raise unittest.SkipTest( 51 | f"{get_dtype_name(dtype)} not supported on {platform}") 52 | 53 | 54 | class PolicyTest(parameterized.TestCase): 55 | 56 | def assert_dtypes_equal(self, tree_a, tree_b): 57 | jax.tree_util.tree_map(lambda a, b: self.assertEqual(a.dtype, b.dtype), 58 | tree_a, tree_b) 59 | 60 | @parameterized.parameters(*it.product(DTYPES, NUMPYS)) 61 | def test_policy_cast_to_param(self, dtype, np_): 62 | skip_if_unsupported(dtype) 63 | policy = jmp.Policy(dtype, dtype, dtype) 64 | self.assertEqual(policy.param_dtype, dtype) 65 | tree = {"a": np_.ones([])} 66 | self.assert_dtypes_equal(policy.cast_to_param(tree), 67 | {"a": np_.ones([], dtype)}) 68 | 69 | @parameterized.parameters(*it.product(DTYPES, NUMPYS)) 70 | def test_policy_cast_to_compute(self, dtype, np_): 71 | skip_if_unsupported(dtype) 72 | policy = jmp.Policy(dtype, dtype, dtype) 73 | self.assertEqual(policy.compute_dtype, dtype) 74 | tree = {"a": np_.ones([])} 75 | self.assert_dtypes_equal(policy.cast_to_compute(tree), 76 | {"a": np_.ones([], dtype)}) 77 | 78 | @parameterized.parameters(*it.product(DTYPES, NUMPYS)) 79 | def test_policy_cast_to_output(self, dtype, np_): 80 | skip_if_unsupported(dtype) 81 | policy = jmp.Policy(dtype, dtype, dtype) 82 | self.assertEqual(policy.output_dtype, dtype) 83 | tree = {"a": np_.ones([])} 84 | self.assert_dtypes_equal(policy.cast_to_output(tree), 85 | {"a": np_.ones([], dtype)}) 86 | 87 | @parameterized.parameters(*it.product(DTYPES, NUMPYS)) 88 | def test_policy_with_output_dtype(self, dtype, np_): 89 | policy = jmp.Policy(np_.float32, np_.float32, np_.float32) 90 | policy = policy.with_output_dtype(dtype) 91 | self.assertEqual(policy.output_dtype, dtype) 92 | 93 | @parameterized.parameters(("float16", np.float16), 94 | ("float32", np.float32), 95 | ("bfloat16", jnp.bfloat16)) 96 | def test_get_policy(self, dtype_name, dtype): 97 | policy = jmp.get_policy(dtype_name) 98 | self.assertEqual(policy.param_dtype, dtype) 99 | self.assertEqual(policy.compute_dtype, dtype) 100 | self.assertEqual(policy.output_dtype, dtype) 101 | 102 | def test_get_policy_almost_dtype(self): 103 | with self.assertRaisesRegex(ValueError, "Unknown dtype"): 104 | jmp.get_policy("compute_float16") 105 | 106 | @parameterized.parameters(*it.product(DTYPES, NUMPYS)) 107 | def test_get_policy_mixed(self, dtype, np_): 108 | full = np_.float32 109 | policy = jmp.get_policy(f"c={get_dtype_name(dtype)}") 110 | self.assertEqual(policy.param_dtype, full) 111 | self.assertEqual(policy.compute_dtype, dtype) 112 | 
self.assertEqual(policy.output_dtype, dtype) 113 | 114 | @parameterized.parameters(*it.product(DTYPES, NUMPYS)) 115 | def test_get_policy_compute(self, dtype, np_): 116 | full = np_.float32 117 | policy = jmp.get_policy(f"c={get_dtype_name(dtype)},o=full") 118 | self.assertEqual(policy.param_dtype, full) 119 | self.assertEqual(policy.compute_dtype, dtype) 120 | self.assertEqual(policy.output_dtype, full) 121 | 122 | def test_half_dtype(self): 123 | if current_platform() == "tpu": 124 | self.assertEqual(jmp.half_dtype(), jnp.bfloat16) 125 | else: 126 | self.assertEqual(jmp.half_dtype(), jnp.float16) 127 | 128 | def test_cast_to_full(self): 129 | half_tree = dict(o=object(), 130 | h=jnp.ones([], dtype=jmp.half_dtype()), 131 | f=jnp.ones([]), 132 | i=jnp.ones([], dtype=jnp.int16)) 133 | full_tree = dict(o=half_tree["o"], 134 | h=half_tree["h"].astype(jnp.float32), 135 | f=half_tree["f"], 136 | i=half_tree["i"]) 137 | self.assertEqual(jmp.cast_to_full(half_tree), full_tree) 138 | 139 | def test_cast_to_half(self): 140 | dtype = jmp.half_dtype() 141 | half_tree = dict(o=object(), 142 | h=jnp.ones([], dtype=dtype), 143 | f=jnp.ones([]), 144 | i=jnp.ones([], dtype=jnp.int16)) 145 | full_tree = dict(o=half_tree["o"], 146 | h=half_tree["h"], 147 | f=half_tree["f"].astype(dtype), 148 | i=half_tree["i"]) 149 | self.assertEqual(jmp.cast_to_half(full_tree), half_tree) 150 | 151 | @parameterized.parameters(*it.product(DTYPES)) 152 | def test_str(self, dtype): 153 | policy = jmp.Policy(dtype, dtype, dtype) 154 | policy_str = str(policy) 155 | for str_piece in policy_str.split(","): 156 | dtype_str = str_piece.split("=")[1] 157 | self.assertEqual(dtype_str, jmp.dtype_to_names[dtype][0]) 158 | 159 | if __name__ == "__main__": 160 | absltest.main() 161 | -------------------------------------------------------------------------------- /requirements-jax.txt: -------------------------------------------------------------------------------- 1 | jax>=0.2.20 2 | jaxlib>=0.1.71 3 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest>=6.2.1 2 | absl-py>=1.4.0 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.19.5 2 | dataclasses>=0.7 ; python_version<'3.7' 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | """Setup for pip package."""
16 | 
17 | from setuptools import find_namespace_packages
18 | from setuptools import setup
19 | 
20 | 
21 | def _get_version():
22 |   with open('jmp/__init__.py') as fp:
23 |     for line in fp:
24 |       if line.startswith('__version__'):
25 |         g = {}
26 |         exec(line, g)  # pylint: disable=exec-used
27 |         return g['__version__']
28 |     raise ValueError('`__version__` not defined in `jmp/__init__.py`')
29 | 
30 | 
31 | def _parse_requirements(requirements_txt_path):
32 |   with open(requirements_txt_path) as fp:
33 |     return fp.read().splitlines()
34 | 
35 | 
36 | _VERSION = _get_version()
37 | 
38 | EXTRA_PACKAGES = {
39 |     'jax': _parse_requirements('requirements-jax.txt'),
40 | }
41 | 
42 | setup(
43 |     name='jmp',
44 |     version=_VERSION,
45 |     url='https://github.com/deepmind/jmp',
46 |     license='Apache 2.0',
47 |     author='DeepMind',
48 |     description='JMP is a Mixed Precision library for JAX.',
49 |     long_description=open('README.md').read(),
50 |     long_description_content_type='text/markdown',
51 |     author_email='jmp-dev-os@google.com',
52 |     # Contained modules and scripts.
53 |     packages=find_namespace_packages(exclude=['*_test.py']),
54 |     install_requires=_parse_requirements('requirements.txt'),
55 |     extras_require=EXTRA_PACKAGES,
56 |     tests_require=_parse_requirements('requirements-test.txt'),
57 |     python_requires='>=3.8',
58 |     include_package_data=True,
59 |     zip_safe=False,
60 |     # PyPI package information.
61 |     classifiers=[
62 |         'Development Status :: 4 - Beta',
63 |         'Intended Audience :: Developers',
64 |         'Intended Audience :: Education',
65 |         'Intended Audience :: Science/Research',
66 |         'License :: OSI Approved :: Apache Software License',
67 |         'Programming Language :: Python :: 3',
68 |         'Programming Language :: Python :: 3.8',
69 |         'Programming Language :: Python :: 3.9',
70 |         'Topic :: Scientific/Engineering :: Mathematics',
71 |         'Topic :: Software Development :: Libraries :: Python Modules',
72 |         'Topic :: Software Development :: Libraries',
73 |     ],
74 | )
75 | 
--------------------------------------------------------------------------------