├── .gitattributes ├── .github └── workflows │ ├── publish.yaml │ └── tests.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── docs ├── JAX FP8 matmul tutorial.ipynb ├── PyTorch FP8 matmul tutorial.ipynb ├── img │ └── fp-formats.webp └── operators.md ├── examples ├── cifar10 │ ├── cifar10_training.py │ ├── cifar10_training_with_optax.py │ └── dataset_cifar10.py ├── mnist │ ├── datasets.py │ ├── flax │ │ ├── README.md │ │ ├── configs │ │ │ ├── __init__.py │ │ │ └── default.py │ │ ├── main.py │ │ ├── requirements.txt │ │ └── train.py │ ├── mnist_classifier_from_scratch.py │ ├── mnist_classifier_from_scratch_fp8.py │ └── mnist_classifier_mlp_flax.py └── scalify-quickstart.ipynb ├── jax_scalify ├── __init__.py ├── core │ ├── __init__.py │ ├── datatype.py │ ├── debug.py │ ├── interpreters.py │ ├── pow2.py │ ├── typing.py │ └── utils.py ├── lax │ ├── __init__.py │ ├── base_scaling_primitives.py │ ├── scaled_ops_common.py │ └── scaled_ops_l2.py ├── ops │ ├── __init__.py │ ├── cast.py │ ├── debug.py │ ├── rescaling.py │ └── utils.py ├── quantization │ ├── __init__.py │ └── scale.py ├── tree │ ├── __init__.py │ └── tree_util.py └── utils │ ├── __init__.py │ └── hlo.py ├── pyproject.toml ├── setup.cfg ├── test-requirements.txt └── tests ├── core ├── test_datatype.py ├── test_interpreter.py ├── test_pow2.py └── test_utils.py ├── lax ├── test_base_scaling_primitives.py ├── test_numpy_integration.py ├── test_scaled_ops_common.py ├── test_scaled_ops_l2.py └── test_scipy_integration.py ├── ops ├── test_cast.py ├── test_debug.py └── test_rescaling.py ├── quantization └── test_scale.py ├── tree └── test_tree_util.py └── utils └── test_hlo.py /.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/f9aa9af123ce3969cf533212a673d99ad6823dbe/.gitattributes -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | permissions: 8 | contents: read 9 | 10 | jobs: 11 | release-build: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.x" 20 | 21 | - name: Build release distributions 22 | run: | 23 | python -m pip install build 24 | python -m build 25 | 26 | - name: Upload distributions 27 | uses: actions/upload-artifact@v4 28 | with: 29 | name: release-dists 30 | path: dist/ 31 | 32 | pypi-publish: 33 | runs-on: ubuntu-latest 34 | 35 | needs: 36 | - release-build 37 | 38 | permissions: 39 | # IMPORTANT: this permission is mandatory for trusted publishing 40 | id-token: write 41 | 42 | # Dedicated environments with protections for publishing are strongly recommended. 
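    # Note: with PyPI "trusted publishing" no API-token secret is stored in the repository;
    # the `id-token: write` permission above lets PyPI verify this workflow via OIDC
    # (assuming the project is registered as a trusted publisher on PyPI).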
43 | environment: 44 | name: pypi 45 | # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status: 46 | url: https://pypi.org/p/jax-scalify 47 | 48 | steps: 49 | - name: Retrieve release distributions 50 | uses: actions/download-artifact@v4 51 | with: 52 | name: release-dists 53 | path: dist/ 54 | 55 | - name: Publish release distributions to PyPI 56 | uses: pypa/gh-action-pypi-publish@release/v1 57 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | 3 | env: 4 | GIT_MAIN_BRANCH: "main" 5 | 6 | # Controls when the workflow will run. 7 | on: 8 | push: 9 | branches: [ "main" ] 10 | pull_request: 11 | branches: [ "main" ] 12 | 13 | # Allows you to run this workflow manually from the Actions tab. 14 | workflow_dispatch: 15 | 16 | jobs: 17 | lint_and_typecheck: 18 | runs-on: ubuntu-latest 19 | timeout-minutes: 10 20 | steps: 21 | - name: Cancel previous 22 | uses: styfle/cancel-workflow-action@0.11.0 23 | with: 24 | access_token: ${{ github.token }} 25 | if: ${{github.ref != 'refs/head/main'}} 26 | - uses: actions/checkout@v3 27 | - name: Set up Python 3.10 28 | uses: actions/setup-python@v4 29 | with: 30 | python-version: "3.10" 31 | - uses: pre-commit/action@v3.0.0 32 | 33 | unit_tests: 34 | runs-on: ubuntu-latest 35 | timeout-minutes: 10 36 | steps: 37 | - name: Cancel previous 38 | uses: styfle/cancel-workflow-action@0.11.0 39 | with: 40 | access_token: ${{ github.token }} 41 | if: ${{github.ref != 'refs/head/main'}} 42 | - uses: actions/checkout@v3 43 | - name: Update pip 44 | id: pip-cache 45 | run: | 46 | python3 -m pip install --upgrade pip 47 | - name: Local install & test requirements 48 | run: | 49 | pip3 install -e ./ 50 | pip3 install -r ./test-requirements.txt 51 | # Run repository unit tests on latest JAX 52 | - name: Run unit tests JAX latest 53 | run: | 54 | pytest --tb=short -v --log-cli-level=INFO ./ 55 | - name: JAX 0.3.16 installation 56 | run: | 57 | pip3 install numpy==1.24.3 scipy==1.10.1 58 | pip3 install chex==0.1.6 jax==0.3.16 jaxlib==0.3.15 -f https://storage.googleapis.com/jax-releases/jax_releases.html 59 | # Run repository unit tests on JAX 0.3 60 | - name: Run unit tests JAX 0.3.16 61 | run: | 62 | pytest --tb=short -v --log-cli-level=INFO ./ 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | *.rendered.*.cpp 9 | *.gp 10 | *.so.lock 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib64/ 21 | parts/ 22 | reports/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | jax_scalify/version.py 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | .jupyter_ystore.db 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # IDEs 137 | .vscode 138 | 139 | # ML tensorboard 140 | *events.out* 141 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-merge-conflict 6 | - id: check-toml 7 | - id: check-yaml 8 | - id: debug-statements 9 | - id: end-of-file-fixer 10 | # Issue on Github action CI? 11 | # - id: no-commit-to-branch 12 | # args: [--branch, main] 13 | - id: requirements-txt-fixer 14 | - id: trailing-whitespace 15 | - repo: https://github.com/PyCQA/isort 16 | rev: 5.13.2 17 | hooks: 18 | - id: isort 19 | args: [--profile, black] 20 | - repo: https://github.com/asottile/pyupgrade 21 | rev: v3.16.0 22 | hooks: 23 | - id: pyupgrade 24 | args: [--py38-plus] 25 | - repo: https://github.com/PyCQA/flake8 26 | rev: 7.0.0 27 | hooks: 28 | - id: flake8 29 | args: ['--ignore=E501,E203,E731,W503'] 30 | - repo: https://github.com/psf/black 31 | rev: 24.4.2 32 | hooks: 33 | - id: black 34 | - repo: https://github.com/pre-commit/mirrors-mypy 35 | rev: v1.10.0 36 | hooks: 37 | - id: mypy 38 | additional_dependencies: [types-dataclasses, numpy] 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # JAX Scalify: end-to-end scaled arithmetic 2 | 3 | [![tests](https://github.com/graphcore-research/jax-scalify/actions/workflows/tests.yaml/badge.svg)](https://github.com/graphcore-research/jax-scalify/actions/workflows/tests-public.yaml) 4 | ![PyPI version](https://img.shields.io/pypi/v/jax-scalify) 5 | [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/graphcore-research/jax-scalify/blob/main/LICENSE) 6 | [![GitHub Repo stars](https://img.shields.io/github/stars/graphcore-research/jax-scalify)](https://github.com/graphcore-research/jax-scalify/stargazers) 7 | 8 | 9 | [**Installation**](#installation) 10 | | [**Quickstart**](#quickstart) 11 | | [**Documentation**](#documentation) 12 | 13 | **📣 Scalify** has been accepted to [**ICML 2024 workshop WANT**](https://openreview.net/forum?id=4IWCHWlb6K)! 📣 14 | 15 | **JAX Scalify** is a library implementing end-to-end scale propagation and scaled arithmetic, allowing easy training and inference of 16 | deep neural networks in low precision (BF16, FP16, FP8). 17 | 18 | Loss scaling, tensor scaling and block scaling have been widely used in the deep learning literature to unlock training and inference at lower precision. Most of these works focus on ad-hoc approaches around scaling of matrix multiplications (and sometimes reduction operations). `Scalify` is adopting a more systematic approach with end-to-end scale propagation, i.e. transforming the full computational graph into a `ScaledArray` graph where every operation has `ScaledArray` inputs and returns `ScaledArray`: 19 | 20 | ```python 21 | @dataclass 22 | class ScaledArray: 23 | # Main data component, in low precision. 24 | data: Array 25 | # Scale, usually scalar, in FP32 or E8M0. 26 | scale: Array 27 | 28 | def __array__(self) -> Array: 29 | # Tensor represented as a `ScaledArray`. 
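        # (Materializing it as a regular array computes `data * scale`, with the scale cast to the data dtype.)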
30 |         return self.data * self.scale.astype(self.data.dtype)
31 | ```
32 |
33 | The main benefits of the `scalify` approach are:
34 |
35 | * Agnostic to neural-net model definition;
36 | * Decoupling scaling from low-precision formats, reducing the computational overhead of dynamic rescaling;
37 | * FP8 matrix multiplications and reductions as simple as a cast;
38 | * Out-of-the-box support of FP16 (scaled) master weights and optimizer state;
39 | * Composable with the JAX ecosystem: [Flax](https://github.com/google/flax), [Optax](https://github.com/google-deepmind/optax), ...
40 |
41 | ## Installation
42 |
43 | JAX Scalify can be installed directly from PyPI:
44 | ```bash
45 | pip install jax-scalify
46 | ```
47 | Please follow the [JAX documentation](https://github.com/google/jax/blob/main/README.md#installation) for a proper JAX installation on GPU/TPU.
48 |
49 | The latest version of JAX Scalify is available directly from GitHub:
50 | ```bash
51 | pip install git+https://github.com/graphcore-research/jax-scalify.git
52 | ```
53 |
54 | ## Quickstart
55 |
56 | A typical JAX training loop just requires a couple of modifications to take advantage of `scalify`. More specifically:
57 |
58 | * Represent inputs and state as `ScaledArray`, using the `as_scaled_array` method (or variations of it);
59 | * Propagate scales end-to-end in the `update` training method using the `scalify` decorator;
60 | * (Optionally) add `dynamic_rescale` calls to improve low-precision accuracy and stability.
61 |
62 |
63 | The following (simplified) example shows how `scalify` can be incorporated into a JAX training loop.
64 | ```python
65 | import jax_scalify as jsa
66 |
67 | # Scalify transform on FWD + BWD + optimizer.
68 | # Propagating scale in the computational graph.
69 | @jsa.scalify
70 | def update(state, data, labels):
71 |     # Forward and backward pass on the NN model.
72 |     loss, grads = \
73 |         jax.value_and_grad(model)(state, data, labels)
74 |     # Optimizer applied on scaled state.
75 |     state = optimizer.apply(state, grads)
76 |     return loss, state
77 |
78 | # Model + optimizer state.
79 | state = (model.init(...), optimizer.init(...))
80 | # Transform state to scaled array(s).
81 | sc_state = jsa.as_scaled_array(state)
82 |
83 | for (data, labels) in dataset:
84 |     # If necessary (e.g. images), scale input data.
85 |     data = jsa.as_scaled_array(data)
86 |     # State update, with full scale propagation.
87 |     sc_state = update(sc_state, data, labels)
88 |     # Optional dynamic rescaling of state.
89 |     sc_state = jsa.ops.dynamic_rescale_l2(sc_state)
90 | ```
91 | As presented in the code above, the model state is represented as a JAX PyTree of `ScaledArray`, propagated end-to-end through the model (forward and backward passes) as well as the optimizer.
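As a minimal, self-contained sketch of what this looks like in practice (it only uses the calls shown above and in the bundled examples — `scalify`, `as_scaled_array`, `asarray` — so treat it as illustrative rather than a reference for exact signatures):

```python
import jax.numpy as jnp
import numpy as np

import jax_scalify as jsa


@jsa.scalify
def matmul(a, b):
    # Inside `scalify`, every operation consumes and produces `ScaledArray`s.
    return jnp.dot(a, b)


# Wrap plain arrays as ScaledArray: low-precision `data` + scalar `scale`.
a = jsa.as_scaled_array(np.random.randn(32, 32).astype(np.float16))
b = jsa.as_scaled_array(np.random.randn(32, 32).astype(np.float16))

out = matmul(a, b)
print(out.data.dtype, out.scale)  # FP16 data, scalar scale.

# Descale back to a regular array (data * scale), e.g. for evaluation or logging.
print(jsa.asarray(out, dtype=np.float32))
```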
92 | 93 | 94 | A full collection of examples is available: 95 | * [Scalify quickstart notebook](./examples/scalify-quickstart.ipynb): basics of `ScaledArray` and `scalify` transform; 96 | * [MNIST FP16 training example](./examples/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`; 97 | * [MNIST FP8 training example](./examples/mnist/mnist_classifier_from_scratch_fp8.py): easy FP8 support in `scalify`; 98 | * [MNIST Flax example](./examples/mnist/mnist_classifier_mlp_flax.py): `scalify` Flax training, with Optax optimizer integration; 99 | 100 | ## Documentation 101 | 102 | * [**Scalify ICML 2024 workshop WANT paper**](https://openreview.net/forum?id=4IWCHWlb6K) 103 | * [Operators coverage in JAX `scalify`](docs/operators.md) 104 | 105 | ## Development 106 | 107 | For a local development setup, we recommend an interactive install: 108 | ```bash 109 | git clone git@github.com:graphcore-research/jax-scalify.git 110 | pip install -e ./ 111 | ``` 112 | 113 | Running `pre-commit` and `pytest` on the JAX Scalify repository: 114 | ```bash 115 | pip install pre-commit 116 | pre-commit run --all-files 117 | pytest -v ./tests 118 | ``` 119 | Python wheel can be built with the usual command `python -m build`. 120 | -------------------------------------------------------------------------------- /docs/PyTorch FP8 matmul tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7ae3e6c9-01d2-4a34-a4a8-88d36c7e9b3f", 6 | "metadata": {}, 7 | "source": [ 8 | "# PyTorch FP8 (fused) matmul tutorial" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 13, 14 | "id": "4c9500fc-648d-46d3-95ea-e74a0ee43fe6", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/plain": [ 20 | "(device(type='cuda', index=0), 'NVIDIA H100 PCIe')" 21 | ] 22 | }, 23 | "execution_count": 13, 24 | "metadata": {}, 25 | "output_type": "execute_result" 26 | } 27 | ], 28 | "source": [ 29 | "import numpy as np\n", 30 | "import torch\n", 31 | "\n", 32 | "# Local GPU device\n", 33 | "torch.device(0), torch.cuda.get_device_name(0)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "id": "bdbfe673-a8d8-4ba5-afb5-3c4f7eb5b0e7", 39 | "metadata": {}, 40 | "source": [ 41 | "### `_scaled_mm` FP8 matmul wrapper\n", 42 | "\n", 43 | "PyTorch `_scaled_mm` defintion: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/Blas.cpp#L1176C1-L1176C16\n", 44 | "\n", 45 | "`cublasLtMatmul` not supported `E5M2 @ E5M2` matmuls: https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp8#cublasltmatmul \n", 46 | "\n", 47 | "TorchAO is using `_scaled_mm` function for FP8 integration: https://github.com/pytorch/ao/blob/main/torchao/float8/float8_python_api.py" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 31, 53 | "id": "87bcf537-3c09-4241-8ab7-f5c2a55c3ed2", 54 | "metadata": {}, 55 | "outputs": [ 56 | { 57 | "ename": "RuntimeError", 58 | "evalue": "Multiplication of two Float8_e5m2 matrices is not supported", 59 | "output_type": "error", 60 | "traceback": [ 61 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 62 | "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", 63 | "Cell \u001b[0;32mIn[31], line 18\u001b[0m\n\u001b[1;32m 15\u001b[0m b_scale \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mones((), 
dtype\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mfloat32, device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;66;03m# FP8 matmul\u001b[39;00m\n\u001b[0;32m---> 18\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_scaled_mm\u001b[49m\u001b[43m(\u001b[49m\u001b[43ma_fp8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mb_fp8\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[1;32m 19\u001b[0m \u001b[43m \u001b[49m\u001b[43mout_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfloat16\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 20\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_a\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43ma_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_b\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mb_scale\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_fast_accum\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mbias\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m 24\u001b[0m \u001b[43m \u001b[49m\u001b[43mscale_result\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n", 64 | "\u001b[0;31mRuntimeError\u001b[0m: Multiplication of two Float8_e5m2 matrices is not supported" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "M, N, K = 128, 64, 256\n", 70 | "\n", 71 | "a = torch.randn((M, K), dtype=torch.float16, device='cuda')\n", 72 | "# Transpose as cuBLASLt requires column major on `rhs`\n", 73 | "b = torch.randn((N, K), dtype=torch.float16, device='cuda').t()\n", 74 | "\n", 75 | "# FP8 inputs & scales\n", 76 | "# a_fp8 = a.to(torch.float8_e4m3fn)\n", 77 | "# b_fp8 = b.to(torch.float8_e4m3fn)\n", 78 | "\n", 79 | "a_fp8 = a.to(torch.float8_e5m2)\n", 80 | "b_fp8 = b.to(torch.float8_e5m2)\n", 81 | "\n", 82 | "a_scale = torch.ones((), dtype=torch.float32, device='cuda')\n", 83 | "b_scale = torch.ones((), dtype=torch.float32, device='cuda')\n", 84 | "\n", 85 | "# FP8 matmul\n", 86 | "out = torch._scaled_mm(a_fp8, b_fp8, \n", 87 | " out_dtype=torch.float16,\n", 88 | " scale_a=a_scale,\n", 89 | " scale_b=b_scale,\n", 90 | " use_fast_accum=True,\n", 91 | " bias=None,\n", 92 | " scale_result=None)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": 28, 98 | "id": "50a320ec-769e-4dc8-b933-29610918d395", 99 | "metadata": {}, 100 | "outputs": [ 101 | { 102 | "data": { 103 | "text/plain": [ 104 | "(torch.Size([128, 64]), torch.float16)" 105 | ] 106 | }, 107 | "execution_count": 28, 108 | "metadata": {}, 109 | "output_type": "execute_result" 110 | } 111 | ], 112 | "source": [ 113 | "out.shape, out.dtype" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "6d3ef1a6-4322-4f87-901a-7e54185cd4f5", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [] 123 | } 124 | ], 125 | "metadata": { 126 | "kernelspec": { 127 | "display_name": "Python 3 (ipykernel)", 128 | "language": "python", 129 | "name": "python3" 130 | }, 131 | "language_info": { 132 | "codemirror_mode": { 133 | "name": "ipython", 134 | "version": 3 135 | }, 136 | 
"file_extension": ".py", 137 | "mimetype": "text/x-python", 138 | "name": "python", 139 | "nbconvert_exporter": "python", 140 | "pygments_lexer": "ipython3", 141 | "version": "3.10.12" 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 5 146 | } 147 | -------------------------------------------------------------------------------- /docs/img/fp-formats.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/f9aa9af123ce3969cf533212a673d99ad6823dbe/docs/img/fp-formats.webp -------------------------------------------------------------------------------- /docs/operators.md: -------------------------------------------------------------------------------- 1 | # JAX Scaled Operators coverage 2 | 3 | Summary of JAX LAX operators supported in `scalify` graph transformation. 4 | 5 | ## [JAX LAX operations](https://jax.readthedocs.io/en/latest/jax.lax.html) 6 | 7 | | Operation | Supported | Remarks | 8 | | ---------------------- | ------------------ |-------- | 9 | | `abs` | :white_check_mark: | | 10 | | `add` | :white_check_mark: | | 11 | | `acos` | :x: | | 12 | | `approx_max_k` | :x: | | 13 | | `approx_min_k` | :x: | | 14 | | `argmax` | :white_check_mark: | | 15 | | `argmin` | :white_check_mark: | | 16 | | `asin` | :x: | | 17 | | `atan` | :x: | | 18 | | `atan2` | :x: | | 19 | | `batch_matmul` | :x: | | 20 | | `bessel_i0e` | :x: | | 21 | | `bessel_i1e` | :x: | | 22 | | `betainc` | :x: | | 23 | | `bitcast_convert_type` | :white_check_mark: | | 24 | | `bitwise_not` | :x: | | 25 | | `bitwise_and` | :x: | | 26 | | `bitwise_or` | :x: | | 27 | | `bitwise_xor` | :x: | | 28 | | `population_count` | :x: | | 29 | | `broadcast` | :white_check_mark: | | 30 | | `broadcast_in_dim` | :white_check_mark: | | 31 | | `cbrt` | :x: | | 32 | | `ceil` | :x: | | 33 | | `clamp` | :x: | | 34 | | `collapse` | :white_check_mark: | | 35 | | `complex` | :x: | | 36 | | `concatenate` | :white_check_mark: | | 37 | | `conj` | :x: | | 38 | | `conv` | :white_check_mark: | | 39 | | `convert_element_type` | :white_check_mark: | | 40 | | `conv_general_dilated` | :white_check_mark: | | 41 | | `conv_transpose` | :white_check_mark: | | 42 | | `cos` | :white_check_mark: | | 43 | | `cosh` | :x: | | 44 | | `cummax` | :x: | | 45 | | `cummin` | :x: | | 46 | | `cumprod` | :x: | | 47 | | `cumsum` | :x: | | 48 | | `digamma` | :x: | | 49 | | `div` | :white_check_mark: | | 50 | | `dot` | :white_check_mark: | | 51 | | `dot_general` | :white_check_mark: | Limited set of configurations. See below. 
| 52 | | `dynamic_slice` | :x: | | 53 | | `dynamic_update_slice` | :x: | | 54 | | `eq` | :white_check_mark: | | 55 | | `erf` | :x: | | 56 | | `erfc` | :x: | | 57 | | `erf_inv` | :x: | | 58 | | `exp` | :white_check_mark: | | 59 | | `expand_dims` | :white_check_mark: | | 60 | | `expm1` | :x: | | 61 | | `fft` | :x: | | 62 | | `floor` | :x: | | 63 | | `full` | :question: | | 64 | | `full_like` | :question: | | 65 | | `gather` | :x: | | 66 | | `ge` | :white_check_mark: | | 67 | | `gt` | :white_check_mark: | | 68 | | `igamma` | :x: | | 69 | | `igammac` | :x: | | 70 | | `imag` | :x: | | 71 | | `index_in_dim` | :x: | | 72 | | `index_take` | :x: | | 73 | | `iota` | :white_check_mark: | | 74 | | `is_finite` | :white_check_mark: | | 75 | | `le` | :white_check_mark: | | 76 | | `lt` | :white_check_mark: | | 77 | | `lgamma` | :x: | | 78 | | `log` | :white_check_mark: | | 79 | | `log1p` | :x: | | 80 | | `logistic` | :x: | | 81 | | `max` | :white_check_mark: | | 82 | | `min` | :white_check_mark: | | 83 | | `mul` | :white_check_mark: | | 84 | | `ne` | :white_check_mark: | | 85 | | `neg` | :white_check_mark: | | 86 | | `nextafter` | :x: | | 87 | | `pad` | :white_check_mark: | | 88 | | `polygamma` | :x: | | 89 | | `pow` | :x: | | 90 | | `real` | :x: | | 91 | | `reciprocal` | :x: | | 92 | | `reduce` | :white_check_mark: | | 93 | | `reduce_precision` | :white_check_mark: | | 94 | | `reduce_window` | :white_check_mark: | | 95 | | `reshape` | :white_check_mark: | | 96 | | `rem` | :x: | | 97 | | `rev` | :white_check_mark: | | 98 | | `round` | :x: | | 99 | | `rsqrt` | :x: | | 100 | | `scatter` | :x: | | 101 | | `scatter_add` | :x: | | 102 | | `scatter_max` | :x: | | 103 | | `scatter_min` | :x: | | 104 | | `scatter_mul` | :x: | | 105 | | `select` | :white_check_mark: | | 106 | | `shift_left` | :x: | | 107 | | `shift_right_arithmetic`| :x: | | 108 | | `shift_right_logical` | :x: | | 109 | | `slice` | :white_check_mark: | | 110 | | `slice_in_dim` | :white_check_mark: | | 111 | | `sign` | :x: | | 112 | | `sin` | :white_check_mark: | | 113 | | `sinh` | :x: | | 114 | | `sort` | :x: | | 115 | | `sort_key_val` | :x: | | 116 | | `sqrt` | :x: | | 117 | | `square` | :x: | | 118 | | `squeeze` | :white_check_mark: | | 119 | | `sub` | :white_check_mark: | | 120 | | `tan` | :x: | | 121 | | `tie_in` | :x: | Deprecated in JAX | 122 | | `top_k` | :x: | | 123 | | `transpose` | :white_check_mark: | | 124 | | `zeta` | :x: | | 125 | -------------------------------------------------------------------------------- /examples/cifar10/cifar10_training.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified by Graphcore Ltd 2024. 15 | 16 | """A basic CIFAR10 example using Numpy and JAX. 17 | 18 | CIFAR10 training using MLP network + raw SGD optimizer. 
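Set `use_scalify = False` in the `__main__` block to run the same MLP with plain (unscaled)
arithmetic for comparison; `training_dtype` and `scale_dtype` control the compute and scale precisions.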
19 | """ 20 | import time 21 | 22 | import dataset_cifar10 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import numpy.random as npr 27 | from jax import grad, jit, lax 28 | 29 | import jax_scalify as jsa 30 | 31 | 32 | def logsumexp(a, axis=None, keepdims=False): 33 | dims = (axis,) 34 | amax = jnp.max(a, axis=dims, keepdims=keepdims) 35 | # FIXME: not proper scale propagation, introducing NaNs 36 | # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) 37 | amax = lax.stop_gradient(amax) 38 | out = lax.sub(a, amax) 39 | out = lax.exp(out) 40 | out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax) 41 | return out 42 | 43 | 44 | def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): 45 | return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] 46 | 47 | 48 | def print_mean_std(name, v): 49 | data, scale = jsa.lax.get_data_scale(v) 50 | # Always use np.float32, to avoid floating errors in descaling + stats. 51 | v = jsa.asarray(data, dtype=np.float32) 52 | m, s = np.mean(v), np.std(v) 53 | # print(data) 54 | print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})") 55 | 56 | 57 | def predict(params, inputs): 58 | activations = inputs 59 | for w, b in params[:-1]: 60 | # Matmul + relu 61 | outputs = jnp.dot(activations, w) + b 62 | activations = jnp.maximum(outputs, 0) 63 | 64 | final_w, final_b = params[-1] 65 | logits = jnp.dot(activations, final_w) + final_b 66 | 67 | # Dynamic rescaling of the gradient, as logits gradient not properly scaled. 68 | logits = jsa.ops.dynamic_rescale_l2_grad(logits) 69 | output = logits - logsumexp(logits, axis=1, keepdims=True) 70 | 71 | return output 72 | 73 | 74 | def loss(params, batch): 75 | inputs, targets = batch 76 | preds = predict(params, inputs) 77 | return -jnp.mean(jnp.sum(preds * targets, axis=1)) 78 | 79 | 80 | def accuracy(params, batch): 81 | inputs, targets = batch 82 | target_class = jnp.argmax(targets, axis=1) 83 | predicted_class = jnp.argmax(predict(params, inputs), axis=1) 84 | return jnp.mean(predicted_class == target_class) 85 | 86 | 87 | if __name__ == "__main__": 88 | width = 2048 89 | lr = 1e-4 90 | use_scalify = True 91 | scalify = jsa.scalify if use_scalify else lambda f: f 92 | 93 | layer_sizes = [3072, width, width, 10] 94 | param_scale = 1.0 95 | 96 | step_size = lr 97 | num_epochs = 10 98 | batch_size = 128 99 | training_dtype = np.float16 100 | scale_dtype = np.float32 101 | 102 | train_images, train_labels, test_images, test_labels = dataset_cifar10.cifar() 103 | num_train = train_images.shape[0] 104 | num_complete_batches, leftover = divmod(num_train, batch_size) 105 | num_batches = num_complete_batches + bool(leftover) 106 | 107 | def data_stream(): 108 | rng = npr.RandomState(0) 109 | while True: 110 | perm = rng.permutation(num_train) 111 | for i in range(num_batches): 112 | batch_idx = perm[i * batch_size : (i + 1) * batch_size] 113 | yield train_images[batch_idx], train_labels[batch_idx] 114 | 115 | batches = data_stream() 116 | params = init_random_params(param_scale, layer_sizes) 117 | # Transform parameters to `ScaledArray` and proper dtype. 
118 | if use_scalify: 119 | params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) 120 | params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) 121 | 122 | @jit 123 | @scalify 124 | def update(params, batch): 125 | grads = grad(loss)(params, batch) 126 | return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] 127 | 128 | for epoch in range(num_epochs): 129 | start_time = time.time() 130 | for _ in range(num_batches): 131 | batch = next(batches) 132 | # Scaled micro-batch + training dtype cast. 133 | if use_scalify: 134 | batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale)) 135 | batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) 136 | 137 | with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): 138 | params = update(params, batch) 139 | 140 | epoch_time = time.time() - start_time 141 | 142 | # Evaluation in float32, for consistency. 143 | raw_params = jsa.asarray(params, dtype=np.float32) 144 | train_acc = accuracy(raw_params, (train_images, train_labels)) 145 | test_acc = accuracy(raw_params, (test_images, test_labels)) 146 | print(f"Epoch {epoch} in {epoch_time:0.2f} sec") 147 | print(f"Training set accuracy {train_acc:0.5f}") 148 | print(f"Test set accuracy {test_acc:0.5f}") 149 | -------------------------------------------------------------------------------- /examples/cifar10/cifar10_training_with_optax.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified by Graphcore Ltd 2024. 15 | 16 | """A basic CIFAR10 example using Numpy and JAX. 17 | """ 18 | 19 | 20 | import time 21 | 22 | import dataset_cifar10 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import numpy.random as npr 27 | import optax 28 | from jax import grad, jit, lax 29 | 30 | import jax_scalify as jsa 31 | 32 | 33 | def logsumexp(a, axis=None, keepdims=False): 34 | dims = (axis,) 35 | amax = jnp.max(a, axis=dims, keepdims=keepdims) 36 | # FIXME: not proper scale propagation, introducing NaNs 37 | # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) 38 | amax = lax.stop_gradient(amax) 39 | out = lax.sub(a, amax) 40 | out = lax.exp(out) 41 | out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax) 42 | return out 43 | 44 | 45 | def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): 46 | return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] 47 | 48 | 49 | def print_mean_std(name, v): 50 | data, scale = jsa.lax.get_data_scale(v) 51 | # Always use np.float32, to avoid floating errors in descaling + stats. 
52 | v = jsa.asarray(data, dtype=np.float32) 53 | m, s = np.mean(v), np.std(v) 54 | # print(data) 55 | print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / SCALE({scale:.4f})") 56 | 57 | 58 | def predict(params, inputs): 59 | activations = inputs 60 | for w, b in params[:-1]: 61 | # Matmul + relu 62 | outputs = jnp.dot(activations, w) + b 63 | activations = jnp.maximum(outputs, 0) 64 | 65 | final_w, final_b = params[-1] 66 | logits = jnp.dot(activations, final_w) + final_b 67 | # Dynamic rescaling of the gradient, as logits gradient not properly scaled. 68 | logits = jsa.ops.dynamic_rescale_l2_grad(logits) 69 | output = logits - logsumexp(logits, axis=1, keepdims=True) 70 | 71 | return output 72 | 73 | 74 | def loss(params, batch): 75 | inputs, targets = batch 76 | preds = predict(params, inputs) 77 | return -jnp.mean(jnp.sum(preds * targets, axis=1)) 78 | 79 | 80 | def accuracy(params, batch): 81 | inputs, targets = batch 82 | target_class = jnp.argmax(targets, axis=1) 83 | predicted_class = jnp.argmax(predict(params, inputs), axis=1) 84 | return jnp.mean(predicted_class == target_class) 85 | 86 | 87 | if __name__ == "__main__": 88 | width = 256 89 | lr = 1e-3 90 | use_scalify = False 91 | training_dtype = np.float32 92 | scalify = jsa.scalify if use_scalify else lambda f: f 93 | 94 | layer_sizes = [3072, width, width, 10] 95 | param_scale = 1.0 96 | num_epochs = 10 97 | batch_size = 128 98 | scale_dtype = np.float32 99 | 100 | train_images, train_labels, test_images, test_labels = dataset_cifar10.cifar() 101 | num_train = train_images.shape[0] 102 | num_complete_batches, leftover = divmod(num_train, batch_size) 103 | num_batches = num_complete_batches + bool(leftover) 104 | # num_batches = 2 105 | 106 | def data_stream(): 107 | rng = npr.RandomState(0) 108 | while True: 109 | perm = rng.permutation(num_train) 110 | for i in range(num_batches): 111 | batch_idx = perm[i * batch_size : (i + 1) * batch_size] 112 | yield train_images[batch_idx], train_labels[batch_idx] 113 | 114 | batches = data_stream() 115 | params = init_random_params(param_scale, layer_sizes) 116 | params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params) 117 | # Transform parameters to `ScaledArray` and proper dtype. 118 | optimizer = optax.adam(learning_rate=lr, eps=1e-5) 119 | opt_state = optimizer.init(params) 120 | 121 | if use_scalify: 122 | params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) 123 | 124 | params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) 125 | 126 | @jit 127 | @scalify 128 | def update(params, batch, opt_state): 129 | grads = grad(loss)(params, batch) 130 | updates, opt_state = optimizer.update(grads, opt_state) 131 | params = optax.apply_updates(params, updates) 132 | return params, opt_state 133 | 134 | for epoch in range(num_epochs): 135 | start_time = time.time() 136 | for _ in range(num_batches): 137 | batch = next(batches) 138 | # Scaled micro-batch + training dtype cast. 139 | if use_scalify: 140 | batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale)) 141 | batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) 142 | 143 | with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): 144 | params, opt_state = update(params, batch, opt_state) 145 | 146 | epoch_time = time.time() - start_time 147 | 148 | # Evaluation in float32, for consistency. 
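        # `jsa.asarray` descales the whole PyTree: each `ScaledArray` leaf is materialized as
        # `data * scale` in float32, so the plain (unscaled) `accuracy` function can consume it.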
149 | raw_params = jsa.asarray(params, dtype=np.float32) 150 | train_acc = accuracy(raw_params, (train_images, train_labels)) 151 | test_acc = accuracy(raw_params, (test_images, test_labels)) 152 | print(f"Epoch {epoch} in {epoch_time:0.2f} sec") 153 | print(f"Training set accuracy {train_acc:0.5f}") 154 | print(f"Test set accuracy {test_acc:0.5f}") 155 | -------------------------------------------------------------------------------- /examples/cifar10/dataset_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified by Graphcore Ltd 2024. 15 | 16 | """Datasets used in examples.""" 17 | 18 | 19 | import array 20 | import gzip 21 | import os 22 | import pickle 23 | import struct 24 | import tarfile 25 | import urllib.request 26 | from os import path 27 | 28 | import numpy as np 29 | 30 | _DATA = "/tmp/jax_example_data/" 31 | 32 | 33 | def _download(url, filename): 34 | """Download a url to a file in the JAX data temp directory.""" 35 | if not path.exists(_DATA): 36 | os.makedirs(_DATA) 37 | out_file = path.join(_DATA, filename) 38 | if not path.isfile(out_file): 39 | urllib.request.urlretrieve(url, out_file) 40 | print(f"downloaded {url} to {_DATA}") 41 | 42 | 43 | def _partial_flatten(x): 44 | """Flatten all but the first dimension of an ndarray.""" 45 | return np.reshape(x, (x.shape[0], -1)) 46 | 47 | 48 | def _one_hot(x, k, dtype=np.float32): 49 | """Create a one-hot encoding of x of size k.""" 50 | return np.array(x[:, None] == np.arange(k), dtype) 51 | 52 | 53 | def _unzip(file): 54 | file = tarfile.open(file) 55 | file.extractall(_DATA) 56 | file.close() 57 | return 58 | 59 | 60 | def _unpickle(file): 61 | with open(file, "rb") as fo: 62 | dict = pickle.load(fo, encoding="bytes") 63 | return dict 64 | 65 | 66 | def mnist_raw(): 67 | """Download and parse the raw MNIST dataset.""" 68 | # CVDF mirror of http://yann.lecun.com/exdb/mnist/ 69 | base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" 70 | 71 | def parse_labels(filename): 72 | with gzip.open(filename, "rb") as fh: 73 | _ = struct.unpack(">II", fh.read(8)) 74 | return np.array(array.array("B", fh.read()), dtype=np.uint8) 75 | 76 | def parse_images(filename): 77 | with gzip.open(filename, "rb") as fh: 78 | _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) 79 | return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols) 80 | 81 | for filename in [ 82 | "train-images-idx3-ubyte.gz", 83 | "train-labels-idx1-ubyte.gz", 84 | "t10k-images-idx3-ubyte.gz", 85 | "t10k-labels-idx1-ubyte.gz", 86 | ]: 87 | _download(base_url + filename, filename) 88 | 89 | train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) 90 | train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) 91 | test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) 92 | test_labels = 
parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) 93 | 94 | return train_images, train_labels, test_images, test_labels 95 | 96 | 97 | def mnist(permute_train=False): 98 | """Download, parse and process MNIST data to unit scale and one-hot labels.""" 99 | train_images, train_labels, test_images, test_labels = mnist_raw() 100 | 101 | train_images = _partial_flatten(train_images) / np.float32(255.0) 102 | test_images = _partial_flatten(test_images) / np.float32(255.0) 103 | train_labels = _one_hot(train_labels, 10) 104 | test_labels = _one_hot(test_labels, 10) 105 | 106 | if permute_train: 107 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 108 | train_images = train_images[perm] 109 | train_labels = train_labels[perm] 110 | 111 | return train_images, train_labels, test_images, test_labels 112 | 113 | 114 | def cifar_raw(): 115 | """Download, unzip and parse the raw cifar dataset.""" 116 | 117 | filename = "cifar-10-python.tar.gz" 118 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 119 | _download(url, filename) 120 | _unzip(path.join(_DATA, filename)) 121 | 122 | data_batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"] 123 | data = [] 124 | labels = [] 125 | for batch in data_batches: 126 | tmp_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", batch)) 127 | data.append(tmp_dict[b"data"]) 128 | labels.append(tmp_dict[b"labels"]) 129 | train_images = np.concatenate(data) 130 | train_labels = np.concatenate(labels) 131 | 132 | test_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", "test_batch")) 133 | test_images = test_dict[b"data"] 134 | test_labels = np.array(test_dict[b"labels"]) 135 | 136 | return train_images, train_labels, test_images, test_labels 137 | 138 | 139 | def cifar(permute_train=False): 140 | """Download, parse and process cifar data to unit scale and one-hot labels.""" 141 | 142 | train_images, train_labels, test_images, test_labels = cifar_raw() 143 | 144 | train_images = train_images / np.float32(255.0) 145 | test_images = test_images / np.float32(255.0) 146 | train_labels = _one_hot(train_labels, 10) 147 | test_labels = _one_hot(test_labels, 10) 148 | 149 | if permute_train: 150 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 151 | train_images = train_images[perm] 152 | train_labels = train_labels[perm] 153 | 154 | return train_images, train_labels, test_images, test_labels 155 | -------------------------------------------------------------------------------- /examples/mnist/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified by Graphcore Ltd 2024. 
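# NOTE: this module duplicates examples/cifar10/dataset_cifar10.py
# (shared MNIST / CIFAR-10 download and preprocessing helpers).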
15 | 16 | 17 | """Datasets used in examples.""" 18 | 19 | 20 | import array 21 | import gzip 22 | import os 23 | import pickle 24 | import struct 25 | import tarfile 26 | import urllib.request 27 | from os import path 28 | 29 | import numpy as np 30 | 31 | _DATA = "/tmp/jax_example_data/" 32 | 33 | 34 | def _download(url, filename): 35 | """Download a url to a file in the JAX data temp directory.""" 36 | if not path.exists(_DATA): 37 | os.makedirs(_DATA) 38 | out_file = path.join(_DATA, filename) 39 | if not path.isfile(out_file): 40 | urllib.request.urlretrieve(url, out_file) 41 | print(f"downloaded {url} to {_DATA}") 42 | 43 | 44 | def _partial_flatten(x): 45 | """Flatten all but the first dimension of an ndarray.""" 46 | return np.reshape(x, (x.shape[0], -1)) 47 | 48 | 49 | def _one_hot(x, k, dtype=np.float32): 50 | """Create a one-hot encoding of x of size k.""" 51 | return np.array(x[:, None] == np.arange(k), dtype) 52 | 53 | 54 | def _unzip(file): 55 | file = tarfile.open(file) 56 | file.extractall(_DATA) 57 | file.close() 58 | return 59 | 60 | 61 | def _unpickle(file): 62 | with open(file, "rb") as fo: 63 | dict = pickle.load(fo, encoding="bytes") 64 | return dict 65 | 66 | 67 | def mnist_raw(): 68 | """Download and parse the raw MNIST dataset.""" 69 | # CVDF mirror of http://yann.lecun.com/exdb/mnist/ 70 | base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" 71 | 72 | def parse_labels(filename): 73 | with gzip.open(filename, "rb") as fh: 74 | _ = struct.unpack(">II", fh.read(8)) 75 | return np.array(array.array("B", fh.read()), dtype=np.uint8) 76 | 77 | def parse_images(filename): 78 | with gzip.open(filename, "rb") as fh: 79 | _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) 80 | return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols) 81 | 82 | for filename in [ 83 | "train-images-idx3-ubyte.gz", 84 | "train-labels-idx1-ubyte.gz", 85 | "t10k-images-idx3-ubyte.gz", 86 | "t10k-labels-idx1-ubyte.gz", 87 | ]: 88 | _download(base_url + filename, filename) 89 | 90 | train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) 91 | train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) 92 | test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) 93 | test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) 94 | 95 | return train_images, train_labels, test_images, test_labels 96 | 97 | 98 | def mnist(permute_train=False): 99 | """Download, parse and process MNIST data to unit scale and one-hot labels.""" 100 | train_images, train_labels, test_images, test_labels = mnist_raw() 101 | 102 | train_images = _partial_flatten(train_images) / np.float32(255.0) 103 | test_images = _partial_flatten(test_images) / np.float32(255.0) 104 | train_labels = _one_hot(train_labels, 10) 105 | test_labels = _one_hot(test_labels, 10) 106 | 107 | if permute_train: 108 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 109 | train_images = train_images[perm] 110 | train_labels = train_labels[perm] 111 | 112 | return train_images, train_labels, test_images, test_labels 113 | 114 | 115 | def cifar_raw(): 116 | """Download, unzip and parse the raw cifar dataset.""" 117 | 118 | filename = "cifar-10-python.tar.gz" 119 | url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 120 | _download(url, filename) 121 | _unzip(path.join(_DATA, filename)) 122 | 123 | data_batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"] 124 | data 
= [] 125 | labels = [] 126 | for batch in data_batches: 127 | tmp_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", batch)) 128 | data.append(tmp_dict[b"data"]) 129 | labels.append(tmp_dict[b"labels"]) 130 | train_images = np.concatenate(data) 131 | train_labels = np.concatenate(labels) 132 | 133 | test_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", "test_batch")) 134 | test_images = test_dict[b"data"] 135 | test_labels = np.array(test_dict[b"labels"]) 136 | 137 | return train_images, train_labels, test_images, test_labels 138 | 139 | 140 | def cifar(permute_train=False): 141 | """Download, parse and process cifar data to unit scale and one-hot labels.""" 142 | 143 | train_images, train_labels, test_images, test_labels = cifar_raw() 144 | 145 | train_images = train_images / np.float32(255.0) 146 | test_images = test_images / np.float32(255.0) 147 | train_labels = _one_hot(train_labels, 10) 148 | test_labels = _one_hot(test_labels, 10) 149 | 150 | if permute_train: 151 | perm = np.random.RandomState(0).permutation(train_images.shape[0]) 152 | train_images = train_images[perm] 153 | train_labels = train_labels[perm] 154 | 155 | return train_images, train_labels, test_images, test_labels 156 | -------------------------------------------------------------------------------- /examples/mnist/flax/README.md: -------------------------------------------------------------------------------- 1 | ## MNIST classification 2 | 3 | Trains a simple convolutional network on the MNIST dataset. 4 | 5 | You can run this code and even modify it directly in Google Colab, no 6 | installation required: 7 | 8 | https://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb 9 | 10 | ### Requirements 11 | * TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary 12 | 13 | ### Example output 14 | 15 | | Name | Epochs | Walltime | Top-1 accuracy | Metrics | Workdir | 16 | | :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- | 17 | | default | 10 | 7.7m | 99.17% | [tfhub.dev] | [gs://flax_public/examples/mnist/default] | 18 | 19 | [tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0®exInput=default 20 | [gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default 21 | 22 | ``` 23 | I0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69 24 | I0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 99.14 25 | ``` 26 | 27 | ### How to run 28 | 29 | `python main.py --workdir=/tmp/mnist --config=configs/default.py` 30 | 31 | #### Overriding Hyperparameter configurations 32 | 33 | MNIST example allows specifying a hyperparameter configuration by the means of 34 | setting `--config` flag. Configuration flag is defined using 35 | [config_flags](https://github.com/google/ml_collections/tree/master#config-flags). 36 | `config_flags` allows overriding configuration fields. 
This can be done as 37 | follows: 38 | 39 | ```shell 40 | python main.py \ 41 | --workdir=/tmp/mnist --config=configs/default.py \ 42 | --config.learning_rate=0.05 --config.num_epochs=5 43 | ``` 44 | -------------------------------------------------------------------------------- /examples/mnist/flax/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/graphcore-research/jax-scalify/f9aa9af123ce3969cf533212a673d99ad6823dbe/examples/mnist/flax/configs/__init__.py -------------------------------------------------------------------------------- /examples/mnist/flax/configs/default.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Default Hyperparameter configuration.""" 16 | 17 | import ml_collections 18 | 19 | 20 | def get_config(): 21 | """Get the default hyperparameter configuration.""" 22 | config = ml_collections.ConfigDict() 23 | 24 | config.learning_rate = 0.1 25 | config.momentum = 0.9 26 | config.batch_size = 128 27 | config.num_epochs = 10 28 | return config 29 | 30 | 31 | def metrics(): 32 | return [] 33 | -------------------------------------------------------------------------------- /examples/mnist/flax/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Main file for running the MNIST example. 16 | 17 | This file is intentionally kept short. The majority of logic is in libraries 18 | than can be easily tested and imported in Colab. 19 | """ 20 | 21 | import jax 22 | 23 | # import tensorflow as tf 24 | import train 25 | from absl import app, flags, logging 26 | from clu import platform 27 | from ml_collections import config_flags 28 | 29 | FLAGS = flags.FLAGS 30 | 31 | flags.DEFINE_string("workdir", None, "Directory to store model data.") 32 | config_flags.DEFINE_config_file( 33 | "config", 34 | None, 35 | "File path to the training hyperparameter configuration.", 36 | lock_config=True, 37 | ) 38 | 39 | 40 | def main(argv): 41 | if len(argv) > 1: 42 | raise app.UsageError("Too many command-line arguments.") 43 | 44 | # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make 45 | # it unavailable to JAX. 
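    # NOTE: the call below is left commented out, presumably because this example does
    # not import TensorFlow directly (see the commented `import tensorflow as tf` above).
    # If a TF-based input pipeline is re-enabled, uncommenting both lines keeps GPU
    # memory free for JAX.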
46 | # tf.config.experimental.set_visible_devices([], "GPU") 47 | 48 | logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count()) 49 | logging.info("JAX local devices: %r", jax.local_devices()) 50 | 51 | # Add a note so that we can tell which task is which JAX host. 52 | # (Depending on the platform task 0 is not guaranteed to be host 0) 53 | platform.work_unit().set_task_status( 54 | f"process_index: {jax.process_index()}, " f"process_count: {jax.process_count()}" 55 | ) 56 | platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY, FLAGS.workdir, "workdir") 57 | 58 | train.train_and_evaluate(FLAGS.config, FLAGS.workdir) 59 | 60 | 61 | if __name__ == "__main__": 62 | flags.mark_flags_as_required(["config", "workdir"]) 63 | app.run(main) 64 | -------------------------------------------------------------------------------- /examples/mnist/flax/requirements.txt: -------------------------------------------------------------------------------- 1 | clu 2 | flax 3 | ml-collections 4 | optax 5 | tensorflow-datasets 6 | -------------------------------------------------------------------------------- /examples/mnist/flax/train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Flax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """MNIST example. 16 | 17 | Library file which executes the training and evaluation loop for MNIST. 18 | The data is loaded using tensorflow_datasets. 19 | """ 20 | 21 | # See issue #620. 
22 | # pytype: disable=wrong-keyword-args 23 | 24 | import jax 25 | import jax.numpy as jnp 26 | import ml_collections 27 | import numpy as np 28 | import optax 29 | import tensorflow_datasets as tfds 30 | from absl import logging 31 | from flax import linen as nn # type:ignore 32 | 33 | # from flax.metrics import tensorboard 34 | from flax.training import train_state 35 | 36 | import jax_scalify as jsa 37 | 38 | 39 | class CNN(nn.Module): 40 | """A simple CNN model.""" 41 | 42 | @nn.compact 43 | def __call__(self, x): 44 | x = nn.Conv(features=32, kernel_size=(3, 3))(x) 45 | x = nn.relu(x) 46 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 47 | x = nn.Conv(features=64, kernel_size=(3, 3))(x) 48 | x = nn.relu(x) 49 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) 50 | x = x.reshape((x.shape[0], -1)) # flatten 51 | x = nn.Dense(features=256)(x) 52 | x = nn.relu(x) 53 | x = nn.Dense(features=10)(x) 54 | return x 55 | 56 | 57 | @jax.jit 58 | def apply_model(state, images, labels): 59 | """Computes gradients, loss and accuracy for a single batch.""" 60 | 61 | def loss_fn(params): 62 | logits = state.apply_fn({"params": params}, images) 63 | one_hot = jax.nn.one_hot(labels, 10) 64 | loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) 65 | return loss, logits 66 | 67 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 68 | (loss, logits), grads = grad_fn(state.params) 69 | accuracy = jnp.mean(jnp.argmax(logits, -1) == labels) 70 | return grads, loss, accuracy 71 | 72 | 73 | @jax.jit 74 | def update_model(state, grads): 75 | return state.apply_gradients(grads=grads) 76 | 77 | 78 | @jax.jit 79 | @jsa.scalify 80 | def apply_and_update_model(state, batch_images, batch_labels): 81 | # Jitting together forward + backward + update. 82 | grads, loss, accuracy = apply_model(state, batch_images, batch_labels) 83 | state = update_model(state, grads) 84 | return state, loss, accuracy 85 | 86 | 87 | def train_epoch(state, train_ds, batch_size, rng): 88 | """Train for a single epoch.""" 89 | train_ds_size = len(train_ds["image"]) 90 | steps_per_epoch = train_ds_size // batch_size 91 | 92 | perms = jax.random.permutation(rng, len(train_ds["image"])) 93 | perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch 94 | perms = perms.reshape((steps_per_epoch, batch_size)) 95 | 96 | epoch_loss = [] 97 | epoch_accuracy = [] 98 | 99 | for perm in perms: 100 | batch_images = train_ds["image"][perm, ...] 101 | batch_labels = train_ds["label"][perm, ...] 102 | # Transform batch to ScaledArray 103 | batch_images = jsa.as_scaled_array(batch_images) 104 | # Apply & update stages in scaled mode. 
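        # At this point `batch_images` is a `ScaledArray`: the image data plus a scalar
        # scale (1 by default here), while the integer `batch_labels` remain plain arrays.
        # A rough debugging sketch (not part of the original loop) to peek at the scale:
        #   data, scale = jsa.lax.get_data_scale(batch_images)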
105 | state, loss, accuracy = apply_and_update_model(state, batch_images, batch_labels) 106 | 107 | epoch_loss.append(np.asarray(loss)) 108 | epoch_accuracy.append(np.asarray(accuracy)) 109 | 110 | train_loss = np.mean(epoch_loss) 111 | train_accuracy = np.mean(epoch_accuracy) 112 | return state, train_loss, train_accuracy 113 | 114 | 115 | def get_datasets(): 116 | """Load MNIST train and test datasets into memory.""" 117 | ds_builder = tfds.builder("mnist") 118 | ds_builder.download_and_prepare() 119 | train_ds = tfds.as_numpy(ds_builder.as_dataset(split="train", batch_size=-1)) 120 | test_ds = tfds.as_numpy(ds_builder.as_dataset(split="test", batch_size=-1)) 121 | train_ds["image"] = jnp.float32(train_ds["image"]) / 255.0 122 | test_ds["image"] = jnp.float32(test_ds["image"]) / 255.0 123 | return train_ds, test_ds 124 | 125 | 126 | def create_train_state(rng, config): 127 | """Creates initial `TrainState`.""" 128 | cnn = CNN() 129 | params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))["params"] 130 | tx = optax.sgd(config.learning_rate, config.momentum) 131 | return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx) 132 | 133 | 134 | def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train_state.TrainState: 135 | """Execute model training and evaluation loop. 136 | 137 | Args: 138 | config: Hyperparameter configuration for training and evaluation. 139 | workdir: Directory where the tensorboard summaries are written to. 140 | 141 | Returns: 142 | The train state (which includes the `.params`). 143 | """ 144 | train_ds, test_ds = get_datasets() 145 | rng = jax.random.key(0) 146 | 147 | # summary_writer = tensorboard.SummaryWriter(workdir) 148 | # summary_writer.hparams(dict(config)) 149 | 150 | rng, init_rng = jax.random.split(rng) 151 | init_rng = jax.random.PRNGKey(1) 152 | 153 | state = create_train_state(init_rng, config) 154 | # Convert model & optimizer states to `ScaledArray`` 155 | state = jsa.as_scaled_array(state) 156 | 157 | logging.info("Start Flax MNIST training...") 158 | 159 | for epoch in range(1, config.num_epochs + 1): 160 | rng, input_rng = jax.random.split(rng) 161 | state, train_loss, train_accuracy = train_epoch(state, train_ds, config.batch_size, input_rng) 162 | # NOTE: running evaluation on the plain normal arrays. 163 | _, test_loss, test_accuracy = apply_model(jsa.asarray(state), test_ds["image"], test_ds["label"]) 164 | 165 | logging.info( 166 | "epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f," 167 | " test_accuracy: %.2f" 168 | % ( 169 | epoch, 170 | train_loss, 171 | train_accuracy * 100, 172 | test_loss, 173 | test_accuracy * 100, 174 | ) 175 | ) 176 | 177 | # summary_writer.scalar("train_loss", train_loss, epoch) 178 | # summary_writer.scalar("train_accuracy", train_accuracy, epoch) 179 | # summary_writer.scalar("test_loss", test_loss, epoch) 180 | # summary_writer.scalar("test_accuracy", test_accuracy, epoch) 181 | 182 | # summary_writer.flush() 183 | return state 184 | -------------------------------------------------------------------------------- /examples/mnist/mnist_classifier_from_scratch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified by Graphcore Ltd 2024. 15 | 16 | """A basic MNIST example using Numpy and JAX. 17 | 18 | The primary aim here is simplicity and minimal dependencies. 19 | """ 20 | 21 | 22 | import time 23 | 24 | import datasets 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | import numpy.random as npr 29 | from jax import grad, jit, lax 30 | 31 | import jax_scalify as jsa 32 | 33 | # from jax.scipy.special import logsumexp 34 | 35 | 36 | def logsumexp(a, axis=None, keepdims=False): 37 | dims = (axis,) 38 | amax = jnp.max(a, axis=dims, keepdims=keepdims) 39 | # FIXME: not proper scale propagation, introducing NaNs 40 | # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) 41 | amax = lax.stop_gradient(amax) 42 | out = lax.sub(a, amax) 43 | out = lax.exp(out) 44 | out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax) 45 | return out 46 | 47 | 48 | def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): 49 | return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] 50 | 51 | 52 | def predict(params, inputs): 53 | activations = inputs 54 | for w, b in params[:-1]: 55 | # Matmul + relu 56 | outputs = jnp.dot(activations, w) + b 57 | activations = jax.nn.relu(outputs) 58 | 59 | final_w, final_b = params[-1] 60 | logits = jnp.dot(activations, final_w) + final_b 61 | # Dynamic rescaling of the gradient, as logits gradient not properly scaled. 62 | # logits = jsa.ops.dynamic_rescale_l2_grad(logits) 63 | logits = logits - logsumexp(logits, axis=1, keepdims=True) 64 | return logits 65 | 66 | 67 | def loss(params, batch): 68 | inputs, targets = batch 69 | preds = predict(params, inputs) 70 | targets = jsa.lax.rebalance(targets, np.float32(1 / 8)) 71 | return -jnp.mean(jnp.sum(preds * targets, axis=1)) 72 | 73 | 74 | def accuracy(params, batch): 75 | inputs, targets = batch 76 | target_class = jnp.argmax(targets, axis=1) 77 | predicted_class = jnp.argmax(predict(params, inputs), axis=1) 78 | return jnp.mean(predicted_class == target_class) 79 | 80 | 81 | if __name__ == "__main__": 82 | layer_sizes = [784, 512, 512, 10] 83 | param_scale = 0.1 84 | step_size = 0.1 85 | num_epochs = 10 86 | batch_size = 128 87 | 88 | training_dtype = np.float16 89 | scale_dtype = np.float32 90 | 91 | train_images, train_labels, test_images, test_labels = datasets.mnist() 92 | num_train = train_images.shape[0] 93 | num_complete_batches, leftover = divmod(num_train, batch_size) 94 | num_batches = num_complete_batches + bool(leftover) 95 | 96 | def data_stream(): 97 | rng = npr.RandomState(0) 98 | while True: 99 | perm = rng.permutation(num_train) 100 | for i in range(num_batches): 101 | batch_idx = perm[i * batch_size : (i + 1) * batch_size] 102 | yield train_images[batch_idx], train_labels[batch_idx] 103 | 104 | batches = data_stream() 105 | params = init_random_params(param_scale, layer_sizes) 106 | # Transform parameters to `ScaledArray` and proper dtype. 
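    # Rough intuition (sketch): a weight tensor initialised as `param_scale * randn(...)`
    # becomes ScaledArray(data ~ randn(...), scale = param_scale), i.e. the stored data
    # stays O(1) while the scalar scale carries the magnitude.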
107 | params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) 108 | params = jsa.tree.astype(params, training_dtype) 109 | 110 | @jit 111 | @jsa.scalify 112 | def update(params, batch): 113 | grads = grad(loss)(params, batch) 114 | return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] 115 | 116 | for epoch in range(num_epochs): 117 | start_time = time.time() 118 | for _ in range(num_batches): 119 | batch = next(batches) 120 | # Scaled micro-batch + training dtype cast. 121 | batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) 122 | batch = jsa.tree.astype(batch, training_dtype) 123 | 124 | with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): 125 | params = update(params, batch) 126 | 127 | epoch_time = time.time() - start_time 128 | 129 | # Evaluation in normal/unscaled float32, for consistency. 130 | raw_params = jsa.asarray(params, dtype=np.float32) 131 | train_acc = accuracy(raw_params, (train_images, train_labels)) 132 | test_acc = accuracy(raw_params, (test_images, test_labels)) 133 | print(f"Epoch {epoch} in {epoch_time:0.2f} sec") 134 | print(f"Training set accuracy {train_acc:0.5f}") 135 | print(f"Test set accuracy {test_acc:0.5f}") 136 | -------------------------------------------------------------------------------- /examples/mnist/mnist_classifier_from_scratch_fp8.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The JAX Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Modified by Graphcore Ltd 2024. 15 | 16 | """A basic MNIST example using Numpy and JAX. 17 | 18 | The primary aim here is simplicity and minimal dependencies. 19 | """ 20 | 21 | 22 | import time 23 | 24 | import datasets 25 | import jax.numpy as jnp 26 | import ml_dtypes 27 | import numpy as np 28 | import numpy.random as npr 29 | from jax import grad, jit, lax 30 | 31 | import jax_scalify as jsa 32 | 33 | # from functools import partial 34 | 35 | 36 | def print_mean_std(name, v): 37 | """Debugging method/tool for JAX Scalify.""" 38 | data, scale = jsa.lax.get_data_scale(v) 39 | # Always use np.float32, to avoid floating errors in descaling + stats. 
40 | data = jsa.asarray(data, dtype=np.float32) 41 | m, s, min, max = np.mean(data), np.std(data), np.min(data), np.max(data) 42 | print(f"{name}: MEAN({m:.4f}) / STD({s:.4f}) / MIN({min:.4f}) / MAX({max:.4f}) / SCALE({scale:.4f})") 43 | 44 | 45 | def logsumexp(a, axis=None, keepdims=False): 46 | dims = (axis,) 47 | amax = jnp.max(a, axis=dims, keepdims=keepdims) 48 | # FIXME: not proper scale propagation, introducing NaNs 49 | # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) 50 | amax = lax.stop_gradient(amax) 51 | out = lax.sub(a, amax) 52 | out = lax.exp(out) 53 | out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax) 54 | return out 55 | 56 | 57 | def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)): 58 | return [(scale * rng.randn(m, n), scale * rng.randn(n)) for m, n, in zip(layer_sizes[:-1], layer_sizes[1:])] 59 | 60 | 61 | def predict(params, inputs, use_fp8=True): 62 | reduce_precision_on_forward = jsa.ops.reduce_precision_on_forward if use_fp8 else lambda x, d: x 63 | reduce_precision_on_backward = jsa.ops.reduce_precision_on_backward if use_fp8 else lambda x, d: x 64 | 65 | activations = inputs 66 | for w, b in params[:-1]: 67 | # Forward FP8 casting. 68 | w = reduce_precision_on_forward(w, ml_dtypes.float8_e4m3fn) 69 | activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn) 70 | # Matmul 71 | outputs = jnp.dot(activations, w) 72 | # Backward FP8 casting 73 | outputs = reduce_precision_on_backward(outputs, ml_dtypes.float8_e5m2) 74 | 75 | # Bias + relu 76 | outputs = outputs + b 77 | activations = jnp.maximum(outputs, 0) 78 | 79 | final_w, final_b = params[-1] 80 | # Forward FP8 casting. 81 | # final_w = jsa.ops.reduce_precision_on_forward(final_w, ml_dtypes.float8_e4m3fn) 82 | activations = reduce_precision_on_forward(activations, ml_dtypes.float8_e4m3fn) 83 | logits = jnp.dot(activations, final_w) 84 | # Backward FP8 casting 85 | logits = reduce_precision_on_backward(logits, ml_dtypes.float8_e5m2) 86 | 87 | logits = logits + final_b 88 | 89 | # Dynamic rescaling of the gradient, as logits gradient not properly scaled. 
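    # Roughly: `dynamic_rescale_l2_grad` is an identity on the forward pass and, on the
    # backward pass, re-normalises the incoming logits gradient (L2-norm based) so its
    # ScaledArray data stays well within the FP8/FP16 representable range.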
90 | logits = jsa.ops.dynamic_rescale_l2_grad(logits) 91 | logits = logits - logsumexp(logits, axis=1, keepdims=True) 92 | return logits 93 | 94 | 95 | def loss(params, batch): 96 | inputs, targets = batch 97 | preds = predict(params, inputs) 98 | return -jnp.mean(jnp.sum(preds * targets, axis=1)) 99 | 100 | 101 | def accuracy(params, batch): 102 | inputs, targets = batch 103 | target_class = jnp.argmax(targets, axis=1) 104 | predicted_class = jnp.argmax(predict(params, inputs, use_fp8=False), axis=1) 105 | return jnp.mean(predicted_class == target_class) 106 | 107 | 108 | if __name__ == "__main__": 109 | layer_sizes = [784, 512, 512, 10] 110 | param_scale = 0.1 111 | step_size = 0.1 112 | num_epochs = 10 113 | batch_size = 128 114 | 115 | training_dtype = np.float16 116 | scale_dtype = np.float32 117 | 118 | train_images, train_labels, test_images, test_labels = datasets.mnist() 119 | num_train = train_images.shape[0] 120 | num_complete_batches, leftover = divmod(num_train, batch_size) 121 | num_batches = num_complete_batches + bool(leftover) 122 | 123 | def data_stream(): 124 | rng = npr.RandomState(0) 125 | while True: 126 | perm = rng.permutation(num_train) 127 | for i in range(num_batches): 128 | batch_idx = perm[i * batch_size : (i + 1) * batch_size] 129 | yield train_images[batch_idx], train_labels[batch_idx] 130 | 131 | batches = data_stream() 132 | params = init_random_params(param_scale, layer_sizes) 133 | # Transform parameters to `ScaledArray` and proper dtype. 134 | params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) 135 | params = jsa.tree.astype(params, training_dtype) 136 | 137 | @jit 138 | @jsa.scalify 139 | def update(params, batch): 140 | grads = grad(loss)(params, batch) 141 | return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)] 142 | 143 | for epoch in range(num_epochs): 144 | start_time = time.time() 145 | for _ in range(num_batches): 146 | batch = next(batches) 147 | # Scaled micro-batch + training dtype cast. 148 | batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) 149 | batch = jsa.tree.astype(batch, training_dtype) 150 | 151 | with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): 152 | params = update(params, batch) 153 | 154 | epoch_time = time.time() - start_time 155 | 156 | # Evaluation in float32, for consistency. 157 | raw_params = jsa.asarray(params, dtype=np.float32) 158 | train_acc = accuracy(raw_params, (train_images, train_labels)) 159 | test_acc = accuracy(raw_params, (test_images, test_labels)) 160 | print(f"Epoch {epoch} in {epoch_time:0.2f} sec") 161 | print(f"Training set accuracy {train_acc:0.5f}") 162 | print(f"Test set accuracy {test_acc:0.5f}") 163 | -------------------------------------------------------------------------------- /examples/mnist/mnist_classifier_mlp_flax.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | """A basic MNIST MLP training example using Flax and Optax. 3 | 4 | Similar to JAX MNIST from scratch, but using Flax and Optax libraries. 5 | 6 | This example aim is to show how Scalify can integrate with common 7 | NN libraries such as Flax and Optax. 
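
Run it directly (sketch): `python mnist_classifier_mlp_flax.py`. Setting
`use_scalify = False` in the `__main__` block trains the same MLP without the
ScaledArray transformation, which gives a convenient baseline for comparison.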
8 | """ 9 | import time 10 | from functools import partial 11 | 12 | import datasets 13 | import jax 14 | import jax.numpy as jnp 15 | import numpy as np 16 | import optax 17 | from flax import linen as nn # type:ignore 18 | 19 | import jax_scalify as jsa 20 | 21 | # from jax.scipy.special import logsumexp 22 | 23 | 24 | def logsumexp(a, axis=None, keepdims=False): 25 | from jax import lax 26 | 27 | dims = (axis,) 28 | amax = jnp.max(a, axis=dims, keepdims=keepdims) 29 | # FIXME: not proper scale propagation, introducing NaNs 30 | # amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0))) 31 | amax = lax.stop_gradient(amax) 32 | out = lax.sub(a, amax) 33 | out = lax.exp(out) 34 | out = lax.add(lax.log(jnp.sum(out, axis=dims, keepdims=keepdims)), amax) 35 | return out 36 | 37 | 38 | class MLP(nn.Module): 39 | """A simple 3 layers MLP model.""" 40 | 41 | @nn.compact 42 | def __call__(self, x): 43 | x = nn.Dense(features=512, use_bias=True)(x) 44 | x = nn.relu(x) 45 | x = nn.Dense(features=512, use_bias=True)(x) 46 | x = nn.relu(x) 47 | x = nn.Dense(features=10, use_bias=True)(x) 48 | logprobs = x - logsumexp(x, axis=1, keepdims=True) 49 | return logprobs 50 | 51 | 52 | def loss(model, params, batch): 53 | inputs, targets = batch 54 | preds = model.apply(params, inputs) 55 | # targets = jsa.lax.rebalance(targets, np.float32(1 / 8)) 56 | return -jnp.mean(jnp.sum(preds * targets, axis=1)) 57 | 58 | 59 | def accuracy(model, params, batch): 60 | inputs, targets = batch 61 | target_class = jnp.argmax(targets, axis=1) 62 | preds = model.apply(params, inputs) 63 | predicted_class = jnp.argmax(preds, axis=1) 64 | return jnp.mean(predicted_class == target_class) 65 | 66 | 67 | def update(model, optimizer, model_state, opt_state, batch): 68 | grads = jax.grad(partial(loss, model))(model_state, batch) 69 | # Optimizer update (state & gradients). 70 | updates, opt_state = optimizer.update(grads, opt_state, model_state) 71 | model_state = optax.apply_updates(model_state, updates) 72 | return model_state, opt_state 73 | 74 | 75 | if __name__ == "__main__": 76 | step_size = 0.001 77 | num_epochs = 10 78 | batch_size = 128 79 | key = jax.random.PRNGKey(42) 80 | use_scalify: bool = True 81 | 82 | training_dtype = np.dtype(np.float16) 83 | optimizer_dtype = np.dtype(np.float16) 84 | scale_dtype = np.float32 85 | 86 | train_images, train_labels, test_images, test_labels = datasets.mnist() 87 | num_train = train_images.shape[0] 88 | num_complete_batches, leftover = divmod(num_train, batch_size) 89 | num_batches = num_complete_batches + bool(leftover) 90 | mnist_img_size = train_images.shape[-1] 91 | 92 | def data_stream(): 93 | rng = np.random.RandomState(0) 94 | while True: 95 | perm = rng.permutation(num_train) 96 | for i in range(num_batches): 97 | batch_idx = perm[i * batch_size : (i + 1) * batch_size] 98 | yield train_images[batch_idx], train_labels[batch_idx] 99 | 100 | # Build model & initialize model parameters. 101 | model = MLP() 102 | model_state = model.init(key, np.zeros((batch_size, mnist_img_size), dtype=training_dtype)) 103 | # Optimizer & optimizer state. 104 | # opt = optax.sgd(learning_rate=step_size) 105 | opt = optax.adam(learning_rate=step_size, eps=2**-16) 106 | opt_state = opt.init(model_state) 107 | # Freeze model, optimizer (with step size). 108 | update_fn = partial(update, model, opt) 109 | 110 | if use_scalify: 111 | # Transform parameters to `ScaledArray`. 
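        # The scales below are starting points rather than tuned values: model weights
        # get a unit scale, while the optimizer state is given a small scale (1e-4),
        # presumably to keep Adam's moment buffers representable once cast to float16.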
112 | model_state = jsa.as_scaled_array(model_state, scale=scale_dtype(1.0)) 113 | opt_state = jsa.as_scaled_array(opt_state, scale=scale_dtype(0.0001)) 114 | # Scalify the update function as well. 115 | update_fn = jsa.scalify(update_fn) 116 | # Convert the model state (weights) & optimizer state to proper dtype. 117 | model_state = jsa.tree.astype(model_state, training_dtype) 118 | opt_state = jsa.tree.astype(opt_state, optimizer_dtype, floating_only=True) 119 | 120 | print(f"Using Scalify: {use_scalify}") 121 | print(f"Training data format: {training_dtype.name}") 122 | print(f"Optimizer data format: {optimizer_dtype.name}") 123 | print("") 124 | 125 | update_fn = jax.jit(update_fn) 126 | 127 | batches = data_stream() 128 | for epoch in range(num_epochs): 129 | start_time = time.time() 130 | 131 | for _ in range(num_batches): 132 | batch = next(batches) 133 | # Scaled micro-batch + training dtype cast. 134 | batch = jsa.tree.astype(batch, training_dtype) 135 | if use_scalify: 136 | batch = jsa.as_scaled_array(batch, scale=scale_dtype(1)) 137 | with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): 138 | model_state, opt_state = update_fn(model_state, opt_state, batch) 139 | 140 | epoch_time = time.time() - start_time 141 | 142 | # Evaluation in normal/unscaled float32, for consistency. 143 | unscaled_params = jsa.asarray(model_state, dtype=np.float32) 144 | train_acc = accuracy(model, unscaled_params, (train_images, train_labels)) 145 | test_acc = accuracy(model, unscaled_params, (test_images, test_labels)) 146 | print(f"Epoch {epoch} in {epoch_time:0.2f} sec") 147 | print(f"Training set accuracy {train_acc:0.5f}") 148 | print(f"Test set accuracy {test_acc:0.5f}") 149 | -------------------------------------------------------------------------------- /jax_scalify/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from . import core, lax, ops, tree, utils 3 | from .core import ( # noqa: F401 4 | Pow2RoundMode, 5 | ScaledArray, 6 | ScalifyConfig, 7 | as_scaled_array, 8 | asarray, 9 | debug_callback, 10 | scaled_array, 11 | scalify, 12 | ) 13 | from .version import __version__ 14 | -------------------------------------------------------------------------------- /jax_scalify/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
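# Public building blocks of `jax_scalify.core`, re-exported below: the `ScaledArray`
# container and factories, the `scalify` interpreter and its config, power-of-two
# rounding/decompose helpers, typing shims and scale-safe scalar utilities.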
2 | from .datatype import ( # noqa: F401 3 | DTypeLike, 4 | ScaledArray, 5 | Shape, 6 | as_scaled_array, 7 | asarray, 8 | get_scale_dtype, 9 | is_scaled_leaf, 10 | is_static_anyscale, 11 | is_static_one_scalar, 12 | is_static_zero, 13 | make_scaled_scalar, 14 | scaled_array, 15 | ) 16 | from .debug import debug_callback # noqa: F401 17 | from .interpreters import ( # noqa: F401 18 | ScaledPrimitiveType, 19 | ScalifyConfig, 20 | find_registered_scaled_op, 21 | get_scalify_config, 22 | register_scaled_lax_op, 23 | register_scaled_op, 24 | scalify, 25 | ) 26 | from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401 27 | from .typing import Array, ArrayTypes, Sharding, get_numpy_api # noqa: F401 28 | from .utils import safe_div, safe_reciprocal # noqa: F401 29 | -------------------------------------------------------------------------------- /jax_scalify/core/datatype.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Any, Optional, Union 4 | 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | from chex import Shape 9 | from jax.core import ShapedArray 10 | from jax.tree_util import register_pytree_node_class 11 | from numpy.typing import ArrayLike, DTypeLike, NDArray 12 | 13 | from .pow2 import Pow2RoundMode, pow2_decompose 14 | from .typing import Array, ArrayTypes 15 | 16 | if TYPE_CHECKING: 17 | GenericArray = Union[Array, np.ndarray[Any, Any]] 18 | else: 19 | GenericArray = Union[Array, np.ndarray] 20 | 21 | 22 | @register_pytree_node_class 23 | @dataclass 24 | class ScaledArray: 25 | """ScaledArray: dataclass associating data and a scale. 26 | 27 | JAX Scaled Arithmetics provides a consistent JAX LAX implementation 28 | propagating scaling for low-precision arithmetics. 29 | 30 | Semantics: `ScaledArray` represents an array with the following values: 31 | self.data * self.scale 32 | where `self.scale` is always assumed to be broadcastable to `self.data`. 33 | 34 | Notes: 35 | 1. Current implementation only supports `scale` being a scalar. 36 | 2. `data` and `scale` can have different dtypes. `data` dtype is used as the 37 | reference dtype. Meaning a power of 2 `scale` is just a dtype `E8M0` for instance. 38 | 39 | Args: 40 | data: Un-scaled data array. 41 | scale: Scale array (scalar only supported at the moment). 42 | If `scale` is None, equivalent to a normal array. 43 | """ 44 | 45 | data: GenericArray 46 | scale: GenericArray 47 | 48 | def __post_init__(self): 49 | # Always have a Numpy array as `data`. 50 | if isinstance(self.data, np.number): 51 | object.__setattr__(self, "data", np.array(self.data)) 52 | # TODO/FIXME: support number as data? 53 | assert isinstance(self.data, (*ArrayTypes, np.ndarray)) 54 | assert isinstance(self.scale, (*ArrayTypes, np.ndarray, np.number)) 55 | # Only supporting scale scalar for now. 56 | assert self.scale.shape == () 57 | 58 | def tree_flatten(self): 59 | # See official JAX documentation on extending PyTrees. 60 | # Note: using explicit tree flatten instead of chex for MyPy compatibility. 61 | children = (self.data, self.scale) 62 | return (children, None) 63 | 64 | @classmethod 65 | def tree_unflatten(cls, aux_data, children): 66 | # See official JAX documentation on extending PyTrees. 
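        # `children` is the `(data, scale)` pair produced by `tree_flatten` above;
        # `aux_data` is always None for this class.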
67 | assert len(children) == 2 68 | return cls(children[0], children[1]) 69 | 70 | @property 71 | def dtype(self) -> DTypeLike: 72 | return self.data.dtype 73 | 74 | @property 75 | def shape(self) -> Shape: 76 | return self.data.shape 77 | 78 | @property 79 | def size(self) -> int: 80 | return self.data.size 81 | 82 | def to_array(self, dtype: DTypeLike = None) -> GenericArray: 83 | """Convert to the scaled array to a Numpy/JAX array. 84 | 85 | Args: 86 | dtype: Optional conversion dtype. `data.dtype` by default. 87 | """ 88 | dtype = self.data.dtype if dtype is None else dtype 89 | data = self.data.astype(dtype) 90 | scale = self.scale.astype(dtype) 91 | values = data * scale 92 | return values 93 | 94 | def __array__(self, dtype: DTypeLike = None) -> NDArray[Any]: 95 | """Numpy array interface support.""" 96 | return np.asarray(self.to_array(dtype)) 97 | 98 | @property 99 | def aval(self) -> ShapedArray: 100 | """Abstract value of the scaled array, i.e. shape and dtype.""" 101 | return ShapedArray(self.data.shape, self.data.dtype) 102 | 103 | def astype(self, dtype: DTypeLike) -> "ScaledArray": 104 | """Convert the ScaledArray to a dtype. 105 | NOTE: only impacting `data` field, not the `scale` tensor. 106 | """ 107 | return ScaledArray(self.data.astype(dtype), self.scale) 108 | 109 | 110 | def make_scaled_scalar(val: Array, scale_dtype: Optional[DTypeLike] = None) -> ScaledArray: 111 | """Make a scaled scalar (array), from a single value. 112 | 113 | The returned scalar will always be built such that: 114 | - data is scalar in [1, 2) 115 | - scale is a power-of-2 value. 116 | 117 | NOTE: data is chosen in [1, 2) instead of [0, 1) in order to 118 | keep any value representable in the same dtype, without overflowing. 119 | 120 | NOTE bis: only supporting floating point input. 121 | """ 122 | # FIXME: implicit conversion from float64 to float32??? 123 | if isinstance(val, float): 124 | val = np.float32(val) 125 | assert np.ndim(val) == 0 126 | assert np.issubdtype(val.dtype, np.floating) 127 | # Scale dtype to use. TODO: check the scale dtype is valid? 128 | scale_dtype = scale_dtype or val.dtype 129 | # Split mantissa and exponent in data and scale components. 130 | scale, mantissa = pow2_decompose(val, scale_dtype=scale_dtype, mode=Pow2RoundMode.DOWN) 131 | return ScaledArray(mantissa, scale) 132 | 133 | 134 | def is_scaled_leaf(val: Any) -> bool: 135 | """Is input a normal JAX PyTree leaf (i.e. `Array`) or `ScaledArray1. 136 | 137 | This function is useful for JAX PyTree handling with `jax.tree` methods where 138 | the user wants to keep the ScaledArray data structures (i.e. not flattened as a 139 | pair of arrays). 140 | 141 | See `jax_scalify.tree` for PyTree `jax.tree` methods compatible with `ScaledArray`. 142 | """ 143 | # TODO: check Numpy scalars as well? 144 | return np.isscalar(val) or isinstance(val, (Array, np.ndarray, ScaledArray)) 145 | 146 | 147 | def scaled_array_base(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray: 148 | """ScaledArray (helper) base factory method, similar to `(j)np.array`.""" 149 | data = npapi.asarray(data, dtype=dtype) 150 | scale = npapi.asarray(scale) 151 | return ScaledArray(data, scale) 152 | 153 | 154 | def scaled_array(data: ArrayLike, scale: ArrayLike, dtype: DTypeLike = None, npapi: Any = jnp) -> ScaledArray: 155 | """ScaledArray (helper) factory method, similar to `(j)np.array`. 156 | 157 | Args: 158 | data: Main data/values. 159 | scale: Scale tensor. 
160 | dtype: Optional dtype to use for the data. 161 | npapi: Numpy API to use. 162 | Returns: 163 | Scaled array instance. 164 | """ 165 | return scaled_array_base(data, scale, dtype, npapi) 166 | 167 | 168 | def as_scaled_array_base( 169 | val: Any, scale: Optional[ArrayLike] = None, scale_dtype: Optional[DTypeLike] = None 170 | ) -> Union[Array, ScaledArray]: 171 | """ScaledArray (helper) base factory method, similar to `(j)np.array`. 172 | 173 | Args: 174 | val: Value to convert to scaled array. 175 | scale: Optional scale value. 176 | scale_dtype: Optional (default) scale dtype. 177 | """ 178 | if isinstance(val, ScaledArray): 179 | return val 180 | 181 | assert scale is None or scale_dtype is None 182 | # Simple case => when can ignore the scaling factor (i.e. 1 implicitely). 183 | is_static_one_scale: bool = scale is None or is_static_one_scalar(scale) # type:ignore 184 | # Trivial cases: bool, int, float. 185 | if is_static_one_scale and isinstance(val, (bool, int)): 186 | return val 187 | if is_static_one_scale and isinstance(val, float): 188 | return make_scaled_scalar(np.float32(val), scale_dtype) 189 | 190 | # Ignored dtypes by default: int and bool 191 | ignored_dtype = np.issubdtype(val.dtype, np.integer) or np.issubdtype(val.dtype, np.bool_) 192 | if ignored_dtype: 193 | return val 194 | # Floating point scalar 195 | if val.ndim == 0 and is_static_one_scale: 196 | return make_scaled_scalar(val, scale_dtype) 197 | 198 | scale_dtype = scale_dtype or val.dtype 199 | scale = np.array(1, dtype=scale_dtype) if scale is None else scale 200 | if isinstance(val, (np.ndarray, *ArrayTypes)): 201 | if is_static_one_scale: 202 | return ScaledArray(val, scale) 203 | else: 204 | return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore 205 | 206 | # TODO: fix bug when scale is not 1. 207 | raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") # type:ignore 208 | # return scaled_array_base(val, scale) 209 | 210 | 211 | def as_scaled_array(val: Any, scale: Optional[ArrayLike] = None) -> ScaledArray: 212 | """ScaledArray (helper) factory method, similar to `(j)np.array`. 213 | 214 | NOTE: by default, int and bool values/arrays will be returned unchanged, as 215 | in most cases, there is no value representing these as scaled arrays. 216 | 217 | Compatible with JAX PyTree. 218 | 219 | Args: 220 | val: Main data/values or existing ScaledArray. 221 | scale: Optional scale to use when (potentially) converting. 222 | Returns: 223 | Scaled array instance. 224 | """ 225 | return jax.tree_util.tree_map(lambda x: as_scaled_array_base(x, scale), val, is_leaf=is_scaled_leaf) 226 | 227 | 228 | def asarray_base(val: Any, dtype: DTypeLike = None) -> GenericArray: 229 | """Convert back to a common JAX/Numpy array, base function.""" 230 | if isinstance(val, ScaledArray): 231 | return val.to_array(dtype=dtype) 232 | elif isinstance(val, (*ArrayTypes, np.ndarray)): 233 | if dtype is None: 234 | return val 235 | return val.astype(dtype=dtype) 236 | # Convert to Numpy all other cases? 237 | return np.asarray(val, dtype=dtype) 238 | 239 | 240 | def asarray(val: Any, dtype: DTypeLike = None) -> GenericArray: 241 | """Convert back to a common JAX/Numpy array. 242 | 243 | Compatible with JAX PyTree. 244 | 245 | Args: 246 | dtype: Optional dtype of the final array. 
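
    Example (sketch; the exact array type/repr depends on the JAX version):
        sa = scaled_array([2.0, 3.0], 0.5, dtype=np.float32)  # represents [1.0, 1.5]
        asarray(sa)               # plain array with values [1.0, 1.5]
        asarray(sa, np.float16)   # same values, cast to float16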
247 | """ 248 | return jax.tree_util.tree_map(lambda x: asarray_base(x, dtype), val, is_leaf=is_scaled_leaf) 249 | 250 | 251 | def is_numpy_scalar_or_array(val): 252 | return isinstance(val, np.ndarray) or np.isscalar(val) 253 | 254 | 255 | def is_static_zero(val: Union[Array, ScaledArray]) -> Array: 256 | """Is a scaled array a static zero value (i.e. zero during JAX tracing as well)? 257 | 258 | Returns a boolean Numpy array of the shape of the input. 259 | """ 260 | if is_numpy_scalar_or_array(val): 261 | return np.equal(val, 0) 262 | if isinstance(val, ScaledArray): 263 | data_mask = ( 264 | np.equal(val.data, 0) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_) 265 | ) 266 | scale_mask = ( 267 | np.equal(val.scale, 0) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_) 268 | ) 269 | return np.logical_or(data_mask, scale_mask) 270 | # By default: can't decide. 271 | return np.zeros(val.shape, dtype=np.bool_) 272 | 273 | 274 | def is_static_anyscale(val: Union[Array, ScaledArray]) -> Array: 275 | """Is a scaled array a static anyscale values (i.e. 0/inf/-inf during JAX tracing as well)? 276 | 277 | Returns a boolean Numpy array of the shape of the input. 278 | """ 279 | 280 | def np_anyscale(arr): 281 | # Check if 0, np.inf or -np.inf 282 | absarr = np.abs(arr) 283 | return np.logical_or(np.equal(absarr, 0), np.equal(absarr, np.inf)) 284 | 285 | if is_numpy_scalar_or_array(val): 286 | return np_anyscale(val) 287 | if isinstance(val, ScaledArray): 288 | # TODO: deal with 0 * inf issue? 289 | data_mask = ( 290 | np_anyscale(val.data) if is_numpy_scalar_or_array(val.data) else np.zeros(val.data.shape, dtype=np.bool_) 291 | ) 292 | scale_mask = ( 293 | np_anyscale(val.scale) if is_numpy_scalar_or_array(val.scale) else np.zeros(val.scale.shape, dtype=np.bool_) 294 | ) 295 | return np.logical_or(data_mask, scale_mask) 296 | # By default: can't decide. 297 | return np.zeros(val.shape, dtype=np.bool_) 298 | 299 | 300 | def is_static_one_scalar(val: Array) -> Union[bool, np.bool_]: 301 | """Is a scaled array a static one scalar value (i.e. one during JAX tracing as well)?""" 302 | if isinstance(val, (int, float)): 303 | return val == 1 304 | elif is_numpy_scalar_or_array(val) and val.size == 1: 305 | return np.all(np.equal(val, 1)) 306 | elif isinstance(val, ScaledArray) and val.size == 1: 307 | return is_static_one_scalar(val.data) and is_static_one_scalar(val.scale) 308 | return False 309 | 310 | 311 | def get_scale_dtype(val: Any) -> DTypeLike: 312 | """Get the scale dtype. Compatible with arrays and scaled arrays.""" 313 | if isinstance(val, ScaledArray): 314 | return val.scale.dtype 315 | return val.dtype 316 | -------------------------------------------------------------------------------- /jax_scalify/core/debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
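# Scalify-aware `debug_callback`: same user-facing behaviour as `jax.debug.callback`,
# but the original callback and its input PyTree are stashed on the flat callback so
# the scalify interpreter can rebuild and forward `ScaledArray` arguments to the host
# (see `scaled_debug_callback` below).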
2 | from typing import Any, Callable, Dict 3 | 4 | from jax import tree_util 5 | from jax._src.debugging import debug_callback as debug_callback_orig 6 | from jax._src.debugging import debug_callback_p 7 | 8 | from .interpreters import ScaledArray, register_scaled_op 9 | 10 | 11 | def get_debug_callback_effect(ordered: bool) -> Any: 12 | """Backward compatible effect factory method.""" 13 | try: 14 | from jax._src.debugging import debug_effect, ordered_debug_effect 15 | 16 | return ordered_debug_effect if ordered else debug_effect 17 | except ImportError: 18 | from jax._src.debugging import DebugEffect 19 | 20 | return DebugEffect.ORDERED_PRINT if ordered else DebugEffect.PRINT 21 | 22 | 23 | def debug_callback(callback: Callable[..., Any], *args: Any, ordered: bool = False, **kwargs: Any) -> None: 24 | # We need our custom version of `debug_callback` to deal with 25 | # changing JAX pytrees. 26 | # FIXME: probably patch `debug_callback` in JAX. 27 | flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) 28 | effect = get_debug_callback_effect(ordered) 29 | 30 | def _flat_callback(*flat_args): 31 | args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) 32 | callback(*args, **kwargs) 33 | return [] 34 | 35 | # Storing in original PyTree and callback function. 36 | # Allowing custom interpreters to retrieve and modify this information. 37 | _flat_callback.__callback_fn = callback # type:ignore 38 | _flat_callback.__callback_in_tree = in_tree # type:ignore 39 | debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect) 40 | 41 | 42 | debug_callback.__doc__ = debug_callback_orig.__doc__ 43 | 44 | 45 | def scaled_debug_callback(*args: ScaledArray, **params: Dict[str, Any]) -> Any: 46 | """Scaled `debug_callback`: properly forwarding ScaledArrays 47 | to host callback. 48 | """ 49 | flat_callback_fn = params["callback"] 50 | if not hasattr(flat_callback_fn, "__callback_fn"): 51 | raise NotImplementedError("Please use `jsa.debug_callback` function instead of original JAX function.") 52 | callback_fn = flat_callback_fn.__callback_fn 53 | in_pytree = flat_callback_fn.__callback_in_tree # type:ignore 54 | # Re-build original input, with scaled arrays. 55 | scaled_args, scaled_kwargs = tree_util.tree_unflatten(in_pytree, args) 56 | # Re-build ordered boolean, in a backward compatible way. 57 | ordered = "ordered" in str(params["effect"]).lower() 58 | debug_callback(callback_fn, *scaled_args, ordered=ordered, **scaled_kwargs) 59 | return [] 60 | 61 | 62 | register_scaled_op(debug_callback_p, scaled_debug_callback) 63 | -------------------------------------------------------------------------------- /jax_scalify/core/pow2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import logging 3 | from enum import IntEnum 4 | from functools import partial 5 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 6 | 7 | import jax.numpy as jnp 8 | import ml_dtypes 9 | import numpy as np 10 | from jax import core 11 | from jax.interpreters import mlir 12 | from jax.interpreters.mlir import LoweringRuleContext, ir 13 | from numpy.typing import DTypeLike, NDArray 14 | 15 | from .typing import Array, get_numpy_api 16 | 17 | # Exponent bits masking. 
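# For instance, the packed float16 mask below works out to 0x7C00 (on a little-endian
# host), i.e. the bit pattern of float16 +inf: exponent bits set, sign & mantissa zeroed.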
18 | _exponent_bits_mask: Dict[Any, NDArray[Any]] = { 19 | np.dtype(jnp.bfloat16): np.packbits( 20 | np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8) 21 | ).view(np.int16), 22 | # Copy for ml_dtypes.bfloat16, distinct in older JAX versions. 23 | np.dtype(ml_dtypes.bfloat16): np.packbits( 24 | np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], dtype=np.uint8) 25 | ).view(np.int16), 26 | np.dtype(np.float16): np.packbits(np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], dtype=np.uint8)).view( 27 | np.int16 28 | ), 29 | np.dtype(np.float32): np.packbits( 30 | np.array( 31 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1], 32 | dtype=np.uint8, 33 | ) 34 | ).view(np.int32), 35 | np.dtype(np.float64): np.array(np.inf, np.float64).view(np.int64), 36 | } 37 | """Exponents bit masking: explicit bitmask to keep only exponent bits in floating point values. 38 | 39 | NOTE: normally should also correspond to `np.inf` value for FP16 and FP32. 40 | """ 41 | 42 | 43 | def dtype_exponent_mask(dtype: DTypeLike, sign_bit: bool = False) -> NDArray[Any]: 44 | """Get the exponent mask for a given Numpy/JAX dtype. 45 | 46 | Args: 47 | dtype: Numpy/JAX dtype. 48 | sign_bit: Include sign bit in the mask. 49 | Returns: 50 | Array mask as integer dtype. 51 | """ 52 | mask = _exponent_bits_mask[dtype] 53 | if sign_bit: 54 | # Negative value to add sign. 55 | intdtype = mask.dtype 56 | mask = (-mask.view(dtype)).view(intdtype) 57 | return mask 58 | return mask 59 | 60 | 61 | def pow2_decompose_round_down_impl(vin: Array, scale_dtype: DTypeLike) -> Array: 62 | """Pow-2 decompose with rounding down. 63 | 64 | Returns: 65 | (scale, vout) such that vin = scale * vout 66 | """ 67 | np_api = get_numpy_api(vin) 68 | # Perform all computations in FP32, to support FP16 submormals. 69 | # NOTE: `jnp.frexp` is buggy for subnormals. 70 | dtype = np.dtype(np.float32) 71 | minval = np.finfo(dtype).smallest_normal 72 | exponent_mask = dtype_exponent_mask(dtype) 73 | intdtype = exponent_mask.dtype 74 | val = vin.astype(dtype) 75 | # Masking mantissa bits, keeping only the exponents ones. 76 | scale_pow2 = np_api.bitwise_and(val.view(intdtype), exponent_mask).view(val.dtype).reshape(val.shape) 77 | # Get the mantissa in float32. Make sure we don't divide by zero, and handle nan/inf. 78 | normal_scale_val = np_api.logical_and(np_api.isfinite(scale_pow2), scale_pow2 != 0) 79 | scale_renorm = np_api.where(normal_scale_val, scale_pow2, minval) 80 | mantissa = val / scale_renorm 81 | return scale_pow2.astype(scale_dtype), mantissa.astype(vin.dtype) 82 | 83 | 84 | class Pow2RoundMode(IntEnum): 85 | """Power-of-two supported rounded mode.""" 86 | 87 | NONE = 0 88 | DOWN = 1 89 | UP = 2 90 | STOCHASTIC = 3 91 | 92 | 93 | pow2_decompose_p = core.Primitive("pow2_decompose") 94 | """`pow2_decompose` pow2 decompose JAX primitive. 95 | """ 96 | 97 | 98 | def pow2_decompose( 99 | vin: Array, scale_dtype: Optional[DTypeLike] = None, mode: Pow2RoundMode = Pow2RoundMode.DOWN 100 | ) -> Tuple[Array, Array]: 101 | """Power-2 decompose, i.e. vin = s * vout where s is a power-of 2 scaling. 102 | 103 | Args: 104 | vin: Input array. 105 | scale_dtype: Scale dtype to use. 106 | mode: Pow2 rounding. 107 | Returns: 108 | (scale, vout) such that vin = scale * vout 109 | """ 110 | scale_dtype = np.dtype(scale_dtype or vin.dtype) 111 | # A couple of checks on dtypes. 
112 | assert np.issubdtype(vin.dtype, np.floating) 113 | assert np.issubdtype(scale_dtype, np.floating) 114 | if scale_dtype == np.float16: 115 | logging.warning("`pow2_decompose` does not support FP16 sub-normals when using FP16 scale dtype.") 116 | out = pow2_decompose_p.bind(vin, scale_dtype=scale_dtype, mode=mode) 117 | return out 118 | 119 | 120 | def pow2_decompose_eager_impl( 121 | vin: Array, scale_dtype: Optional[DTypeLike] = None, mode: Pow2RoundMode = Pow2RoundMode.DOWN 122 | ) -> Tuple[Array, Array]: 123 | """Eager mode implementation, on JAX/Numpy arrays.""" 124 | if mode == Pow2RoundMode.DOWN: 125 | return pow2_decompose_round_down_impl(vin, scale_dtype) 126 | raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.") 127 | 128 | 129 | def pow2_decompose_abstract_eval( 130 | vin: core.ShapedArray, scale_dtype: Optional[DTypeLike] = None, mode: Pow2RoundMode = Pow2RoundMode.DOWN 131 | ) -> Tuple[core.ShapedArray, core.ShapedArray]: 132 | scale_dtype = scale_dtype or vin.dtype 133 | sout = core.ShapedArray(vin.shape, dtype=scale_dtype) 134 | return (sout, vin) 135 | 136 | 137 | def pow2_decompose_mlir_lowering( 138 | ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params: Dict[str, Any] 139 | ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: 140 | scale_dtype = params["scale_dtype"] 141 | mode = params["mode"] 142 | pow2_decompose_fn = partial(pow2_decompose_eager_impl, scale_dtype=scale_dtype, mode=mode) 143 | outputs = mlir.lower_fun(pow2_decompose_fn, multiple_results=True)(ctx, *args) 144 | return outputs 145 | 146 | 147 | # Register as standard JAX primitive 148 | pow2_decompose_p.multiple_results = True 149 | pow2_decompose_p.def_abstract_eval(pow2_decompose_abstract_eval) 150 | pow2_decompose_p.def_impl(pow2_decompose_eager_impl) 151 | # Default lowering on GPU, TPU, ... 152 | mlir.register_lowering(pow2_decompose_p, pow2_decompose_mlir_lowering) 153 | 154 | 155 | def pow2_round_down(val: Array) -> Array: 156 | """Round down to the closest power of 2.""" 157 | # Keep only the scale component of `pow2_decompose` 158 | pow2_val, _ = pow2_decompose(val, scale_dtype=val.dtype, mode=Pow2RoundMode.DOWN) 159 | return pow2_val 160 | 161 | 162 | def pow2_round_up(val: Array) -> Array: 163 | """Round up to the closest power of 2. 164 | NOTE: may overflow to inf. 165 | """ 166 | # FIXME: rounding when already a power of 2. 167 | # Should do additional masking to check that. 168 | pow2_val = pow2_round_down(val) * np.array(2, dtype=val.dtype) 169 | return pow2_val 170 | 171 | 172 | def pow2_round(val: Array, mode: Pow2RoundMode = Pow2RoundMode.DOWN) -> Array: 173 | """Power-of-two rounding.""" 174 | if mode == Pow2RoundMode.NONE: 175 | return val 176 | elif mode == Pow2RoundMode.DOWN: 177 | return pow2_round_down(val) 178 | elif mode == Pow2RoundMode.UP: 179 | return pow2_round_up(val) 180 | raise NotImplementedError(f"Unsupported power-of-2 rounding mode '{mode}'.") 181 | 182 | 183 | def get_mantissa(val: Array) -> Array: 184 | """Extract the mantissa of an array, masking the exponent. 185 | 186 | Similar to `numpy.frexp`, but with implicit bit to be consistent with 187 | `pow2_round_down`. 
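
    Example (float32 values, rough sketch):
        get_mantissa(np.float32(3.0))     # -> 1.5, since 3.0 == 2**1 * 1.5
        pow2_round_down(np.float32(3.0))  # -> 2.0, the matching power-of-two scale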
188 | """ 189 | _, mantissa = pow2_decompose(val, scale_dtype=val.dtype, mode=Pow2RoundMode.DOWN) 190 | return mantissa 191 | -------------------------------------------------------------------------------- /jax_scalify/core/typing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from typing import Any, Tuple 3 | 4 | # import chex 5 | import jax 6 | import jax.numpy as jnp 7 | import numpy as np 8 | 9 | # Type aliasing. To be compatible with JAX 0.3 as well. 10 | try: 11 | from jax import Array 12 | from jax.sharding import Sharding 13 | 14 | ArrayTypes: Tuple[Any, ...] = (Array,) 15 | except ImportError: 16 | from jaxlib.xla_extension import DeviceArray as Array 17 | 18 | Sharding = Any 19 | # Older version of JAX <0.4 20 | ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer) 21 | 22 | try: 23 | from jax.stages import ArgInfo 24 | 25 | # Additional ArgInfo in recent JAX versions. 26 | ArrayTypes = (*ArrayTypes, ArgInfo) 27 | except ImportError: 28 | pass 29 | 30 | 31 | def get_numpy_api(val: Any) -> Any: 32 | """Get the Numpy API corresponding to an array. 33 | 34 | Using the NumPy API whenever possible when tracing a JAX graph 35 | allows for simple constant folding optimization. 36 | 37 | JAX or classic Numpy supported. 38 | """ 39 | if isinstance(val, (np.ndarray, np.number)): 40 | return np 41 | if isinstance(val, ArrayTypes): 42 | return jnp 43 | raise NotImplementedError(f"Unsupported input type '{type(val)}'. No matching Numpy API.") 44 | -------------------------------------------------------------------------------- /jax_scalify/core/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from typing import Any 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | import numpy as np 7 | 8 | from .typing import Array 9 | 10 | 11 | def safe_div(lhs: Array, rhs: Array) -> Array: 12 | """Safe (scalar) div: if rhs is zero, returns zero.""" 13 | assert lhs.shape == () 14 | assert rhs.shape == () 15 | # assert lhs.dtype == rhs.dtype 16 | # Numpy inputs => direct computation. 17 | is_npy_inputs = isinstance(lhs, (np.number, np.ndarray)) and isinstance(rhs, (np.number, np.ndarray)) 18 | if is_npy_inputs: 19 | return np.divide(lhs, rhs, out=np.array(0, dtype=rhs.dtype), where=rhs != 0) 20 | # JAX general implementation. 21 | return jax.lax.select(rhs == 0, rhs, jnp.divide(lhs, rhs)) 22 | 23 | 24 | def safe_reciprocal(val: Array) -> Array: 25 | """Safe (scalar) reciprocal: if val is zero, returns zero.""" 26 | assert val.shape == () 27 | # Numpy inputs => direct computation. 28 | if isinstance(val, (np.number, np.ndarray)): 29 | return np.reciprocal(val, out=np.array(0, dtype=val.dtype), where=val != 0) 30 | # JAX general implementation. 31 | return jax.lax.select(val == 0, val, jax.lax.reciprocal(val)) 32 | 33 | 34 | def python_scalar_as_numpy(val: Any) -> Any: 35 | """Convert Python scalar to Numpy scalar, if possible. 36 | 37 | Using by default JAX 32 bits precision, instead of 64 bits. 38 | 39 | Returning unchanged value if not any (bool, int, float). 
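
    Examples (mirroring the branches below):
        python_scalar_as_numpy(True) -> np.bool_(True)
        python_scalar_as_numpy(3)    -> np.int32(3)
        python_scalar_as_numpy(2.5)  -> np.float32(2.5)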
40 | """ 41 | if isinstance(val, bool): 42 | return np.bool_(val) 43 | elif isinstance(val, int): 44 | return np.int32(val) 45 | elif isinstance(val, float): 46 | return np.float32(val) 47 | return val 48 | -------------------------------------------------------------------------------- /jax_scalify/lax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from .base_scaling_primitives import ( # noqa: F401 3 | get_data_scale, 4 | get_data_scale_p, 5 | rebalance, 6 | set_scaling, 7 | set_scaling_p, 8 | stop_scaling, 9 | stop_scaling_p, 10 | ) 11 | from .scaled_ops_common import * # noqa: F401, F403 12 | from .scaled_ops_l2 import * # noqa: F401, F403 13 | -------------------------------------------------------------------------------- /jax_scalify/lax/base_scaling_primitives.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import logging 3 | from typing import Any, Dict, Optional, Sequence, Union 4 | 5 | import numpy as np 6 | from jax import core 7 | from jax.interpreters import mlir 8 | from jax.interpreters.mlir import LoweringRuleContext, ir, ir_constant 9 | 10 | from jax_scalify.core import ( 11 | Array, 12 | DTypeLike, 13 | ScaledArray, 14 | ScaledPrimitiveType, 15 | asarray, 16 | get_scalify_config, 17 | is_static_one_scalar, 18 | register_scaled_op, 19 | safe_div, 20 | safe_reciprocal, 21 | ) 22 | 23 | set_scaling_p = core.Primitive("set_scaling_p") 24 | """`set_scaling` JAX primitive. 25 | 26 | In standard JAX, this is just an identity operation, ignoring the `scale` 27 | input, just returning unchanged the `data` component. 28 | 29 | In JAX Scalify mode, it will rebalance the data term to 30 | return a ScaledArray semantically equivalent. 31 | 32 | NOTE: there is specific corner case of passing zero to `set_scaling`. In this 33 | situation, the tensor is assumed to be zeroed by the user. 34 | """ 35 | 36 | 37 | def set_scaling(values: Array, scale: Array) -> Array: 38 | """`set_scaling` primitive call method.""" 39 | return set_scaling_p.bind(values, scale) 40 | 41 | 42 | def set_scaling_impl(values: Array, scale: Array) -> Array: 43 | """Set scaling general implementation. 44 | 45 | Need to work on the following inputs combinations: 46 | - (Array, Array) -> Array 47 | - (ScaledArray, Array) -> ScaledArray 48 | - (ScaledArray, ScaledArray) -> ScaledArray 49 | with Numpy or JAX arrays (and keep the proper type for output). 50 | """ 51 | assert scale.shape == () 52 | if isinstance(values, ScaledArray): 53 | # Automatic promotion should ensure we always get a scaled scalar here! 54 | scale_value = asarray(scale) 55 | # Rebalancing data tensor using the new scale. 56 | data = values.data * safe_div(values.scale, scale_value).astype(values.dtype) 57 | return ScaledArray(data, scale_value) 58 | # No scaled array => no-op. 59 | return values 60 | 61 | 62 | def set_scaling_abstract_eval(values: core.ShapedArray, scale: core.ShapedArray) -> core.ShapedArray: 63 | return values 64 | 65 | 66 | def set_scaling_mlir_lowering( 67 | ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]] 68 | ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: 69 | # Just forwarding `values` term, ignoring the `scale`. 
70 | return (args[0],) 71 | 72 | 73 | def scaled_set_scaling(values: ScaledArray, scale: ScaledArray) -> ScaledArray: 74 | """Scaled `set_scaling` implementation: rebalancing the data using the new scale value.""" 75 | # Trivial case of scale == 1 76 | if is_static_one_scalar(scale): 77 | if isinstance(values, ScaledArray): 78 | return values 79 | return ScaledArray(values, scale) 80 | assert scale.shape == () 81 | # Automatic promotion should ensure we always get a scaled scalar here! 82 | scale_value = asarray(scale) 83 | if not isinstance(values, ScaledArray): 84 | # Simple case, with no pre-existing scale. 85 | return ScaledArray(values * safe_reciprocal(scale_value.astype(values.dtype)), scale_value) 86 | # Rebalancing data tensor using the new scale. 87 | data = values.data * safe_div(values.scale, scale_value).astype(values.dtype) 88 | return ScaledArray(data, scale_value) 89 | 90 | 91 | # Register as standard JAX primitive 92 | set_scaling_p.multiple_results = False 93 | set_scaling_p.def_abstract_eval(set_scaling_abstract_eval) 94 | set_scaling_p.def_impl(set_scaling_impl) 95 | mlir.register_lowering(set_scaling_p, set_scaling_mlir_lowering) 96 | # Register "scaled" translation. 97 | register_scaled_op(set_scaling_p, scaled_set_scaling, ScaledPrimitiveType.ALWAYS_SCALE) 98 | 99 | 100 | stop_scaling_p = core.Primitive("stop_scaling_p") 101 | """`stop_scaling` JAX primitive. 102 | 103 | In standard JAX, this is just an identity operation (with optional casting). 104 | 105 | In JAX Scalify mode, it will return the value tensor, with optional casting. 106 | 107 | Similar in principle to `jax.lax.stop_gradient` 108 | """ 109 | 110 | 111 | def stop_scaling(values: Array, dtype: Optional[DTypeLike] = None) -> Array: 112 | """`stop_scaling` primitive call method.""" 113 | return stop_scaling_p.bind(values, dtype=dtype) 114 | 115 | 116 | def stop_scaling_impl(values: Array, dtype: Optional[DTypeLike]) -> Array: 117 | if isinstance(values, ScaledArray): 118 | return values.to_array(dtype=dtype) 119 | if dtype is not None: 120 | values = values.astype(dtype) 121 | return values 122 | 123 | 124 | def stop_scaling_abstract_eval(values: core.ShapedArray, dtype: Optional[DTypeLike]) -> core.ShapedArray: 125 | return values.update(dtype=dtype) 126 | 127 | 128 | def stop_scaling_mlir_lowering( 129 | ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]], **params: Dict[str, Any] 130 | ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: 131 | dtype = params.get("dtype", None) 132 | if dtype is not None: 133 | # TODO: caching of the MLIR lowered function? 134 | stop_scaling_mlir_fn = mlir.lower_fun(lambda x: x.astype(dtype), multiple_results=False) 135 | return stop_scaling_mlir_fn(ctx, *args) 136 | # By default: forward tensor. 137 | return (args[0],) 138 | 139 | 140 | def scaled_stop_scaling(values: ScaledArray, dtype: Optional[DTypeLike] = None) -> Array: 141 | """Scaled `stop_scaling` implementation: returning tensor values (with optional cast).""" 142 | assert isinstance(values, ScaledArray) 143 | # TODO/FIXME: how to handle not scaled input? 144 | return values.to_array(dtype=dtype) 145 | 146 | 147 | # Register as standard JAX primitive 148 | stop_scaling_p.multiple_results = False 149 | stop_scaling_p.def_abstract_eval(stop_scaling_abstract_eval) 150 | stop_scaling_p.def_impl(stop_scaling_impl) 151 | mlir.register_lowering(stop_scaling_p, stop_scaling_mlir_lowering) 152 | # Register "scaled" translation. 
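# (The scalify tracer hence dispatches `stop_scaling_p` to `scaled_stop_scaling` above,
# materializing a ScaledArray back into a plain array, with optional cast.)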
153 | register_scaled_op(stop_scaling_p, scaled_stop_scaling) 154 | 155 | 156 | get_data_scale_p = core.Primitive("get_data_scale_p") 157 | """`get_data_scale` unbundling JAX primitive: return a tuple of data and scale 158 | arrays. 159 | 160 | In standard JAX, this is just an operation returning the input array and a constant scalar(1). 161 | 162 | In JAX Scalify mode, it will return the pair of data and scale tensors 163 | from a ScaledArray. 164 | """ 165 | 166 | 167 | def get_scale_dtype() -> Optional[DTypeLike]: 168 | """Get the scale dtype, if set in the Scalify config.""" 169 | return get_scalify_config().scale_dtype 170 | 171 | 172 | def get_data_scale(values: Array) -> Array: 173 | """`get_data_scale` primitive call method.""" 174 | return get_data_scale_p.bind(values) 175 | 176 | 177 | def get_data_scale_impl(values: Array) -> Array: 178 | if isinstance(values, ScaledArray): 179 | return (values.data, values.scale) 180 | # Use array dtype for scale by default. 181 | scale_dtype = get_scale_dtype() or values.dtype 182 | scale = np.ones((), dtype=scale_dtype) 183 | return values, scale 184 | 185 | 186 | def get_data_scale_abstract_eval(values: core.ShapedArray) -> core.ShapedArray: 187 | if isinstance(values, ScaledArray): 188 | return (values.data, values.scale) 189 | # Use array dtype for scale by default. 190 | scale_dtype = get_scale_dtype() or values.dtype 191 | return values, core.ShapedArray((), dtype=scale_dtype) 192 | 193 | 194 | def get_data_scale_mlir_lowering( 195 | ctx: LoweringRuleContext, *args: Union[ir.Value, Sequence[ir.Value]] 196 | ) -> Sequence[Union[ir.Value, Sequence[ir.Value]]]: 197 | # Just forwarding `values` term, adding a constant scalar scale(1). 198 | assert len(args) == 1 199 | assert len(ctx.avals_in) == 1 200 | assert len(ctx.avals_out) == 2 201 | # Scale dtype "decided" during initial JAX tracing. 202 | scale_dtype = ctx.avals_out[1].dtype 203 | scale = ir_constant(np.ones((), dtype=scale_dtype)) 204 | return (args[0], scale) 205 | 206 | 207 | def scaled_get_data_scale(values: ScaledArray) -> Array: 208 | """Scaled `get_data_scale` implementation: return scale tensor.""" 209 | scale_dtype = get_scale_dtype() 210 | # Mis-match may potentially create issues (i.e. not equivalent scale dtype after scalify tracer)! 211 | if scale_dtype != values.scale.dtype: 212 | logging.warning( 213 | f"Scalify config scale dtype not matching ScaledArray scale dtype: '{values.scale.dtype}' vs '{scale_dtype}'. Scalify graph transformation may fail because of that." 214 | ) 215 | return values.data, values.scale 216 | 217 | 218 | # Register as standard JAX primitive 219 | get_data_scale_p.multiple_results = True 220 | get_data_scale_p.def_abstract_eval(get_data_scale_abstract_eval) 221 | get_data_scale_p.def_impl(get_data_scale_impl) 222 | mlir.register_lowering(get_data_scale_p, get_data_scale_mlir_lowering) 223 | # Register "scaled" translation. 224 | register_scaled_op(get_data_scale_p, scaled_get_data_scale) 225 | 226 | 227 | def rebalance(values: ScaledArray, rebalance_scale: Array) -> ScaledArray: 228 | """Rebalance a ScaledArray with a scale component, i.e. 229 | respectively divide/multiply the data/scale by the rebalance scale. 230 | 231 | NOTE: no-op in normal JAX mode. 232 | 233 | Args: 234 | values: Array/ScaledArray to rebalance. 235 | rebalance_scale: Rebalancing scale value. 236 | Returns: 237 | Array/ScaledArray rebalanced. 
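
    Example (illustrative sketch, assuming a float16 ScaledArray input):
        # Halve the scale and double the data; represented values are unchanged.
        out = rebalance(values, np.float16(0.5))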
238 | """ 239 | _, scale = get_data_scale(values) 240 | assert scale.dtype == rebalance_scale.dtype 241 | out_scale = scale * rebalance_scale 242 | return set_scaling(values, out_scale) 243 | -------------------------------------------------------------------------------- /jax_scalify/lax/scaled_ops_l2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from typing import Any, Dict, Optional, Sequence, Tuple 3 | 4 | import jax.numpy as jnp 5 | import numpy as np 6 | from jax import lax 7 | from jax._src.ad_util import add_any_p 8 | 9 | from jax_scalify import core 10 | from jax_scalify.core import DTypeLike, ScaledArray, get_scalify_config, pow2_round, register_scaled_op, safe_div 11 | 12 | from .scaled_ops_common import check_scalar_scales, promote_scale_types 13 | 14 | 15 | def scaled_add_sub(A: ScaledArray, B: ScaledArray, binary_op: Any) -> ScaledArray: 16 | """Scaled add/sub generic implementation.""" 17 | # TODO: understand when promotion is really required? 18 | # A, B = as_scaled_array((A, B)) # type:ignore 19 | check_scalar_scales(A, B) 20 | A, B = promote_scale_types(A, B) 21 | assert np.issubdtype(A.scale.dtype, np.floating) 22 | # Pow2 rounding for unit scaling "rule". 23 | pow2_rounding_mode = get_scalify_config().rounding_mode 24 | # TODO: what happens to `sqrt` for non-floating scale? 25 | # More stable than direct L2 norm, to avoid scale overflow. 26 | ABscale_max = lax.max(A.scale, B.scale) 27 | ABscale_min = lax.min(A.scale, B.scale) 28 | ABscale_ratio = safe_div(ABscale_min, ABscale_max) 29 | output_scale = ABscale_max * lax.sqrt(1 + ABscale_ratio * ABscale_ratio) 30 | # Transform back to power-of-2 31 | output_scale = pow2_round(output_scale, pow2_rounding_mode) 32 | # Output dtype => promotion of A and B dtypes. 33 | outdtype = jnp.promote_types(A.dtype, B.dtype) 34 | Arescale = safe_div(A.scale, output_scale).astype(outdtype) 35 | Brescale = safe_div(B.scale, output_scale).astype(outdtype) 36 | # check correct type output if mismatch between data and scale precision 37 | output_data = binary_op(Arescale * A.data, Brescale * B.data) 38 | return ScaledArray(output_data, output_scale) 39 | 40 | 41 | @core.register_scaled_lax_op 42 | def scaled_add(A: ScaledArray, B: ScaledArray) -> ScaledArray: 43 | return scaled_add_sub(A, B, lax.add) 44 | 45 | 46 | # TODO: understand difference between `add` and `add_anys` 47 | register_scaled_op(add_any_p, scaled_add) 48 | 49 | 50 | @core.register_scaled_lax_op 51 | def scaled_sub(A: ScaledArray, B: ScaledArray) -> ScaledArray: 52 | return scaled_add_sub(A, B, lax.sub) 53 | 54 | 55 | @core.register_scaled_lax_op 56 | def scaled_dot_general( 57 | lhs: ScaledArray, 58 | rhs: ScaledArray, 59 | dimension_numbers: Tuple[Tuple[Sequence[int], Sequence[int]], Tuple[Sequence[int], Sequence[int]]], 60 | precision: Any = None, 61 | preferred_element_type: Optional[DTypeLike] = None, 62 | ) -> ScaledArray: 63 | # Checks on `dot_general` arguments. Only supporting a subset right now. 64 | ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)) = dimension_numbers 65 | assert len(lhs_batch_dims) == 0 66 | assert len(rhs_batch_dims) == 0 67 | assert len(lhs_contracting_dims) == 1 68 | assert len(rhs_contracting_dims) == 1 69 | 70 | # Pow2 rounding for unit scaling "rule". 
71 | pow2_rounding_mode = get_scalify_config().rounding_mode 72 | contracting_dim_size = lhs.shape[lhs_contracting_dims[0]] 73 | # "unit scaling" rule, based on the contracting axis. 74 | outscale_dtype = jnp.promote_types(lhs.scale.dtype, rhs.scale.dtype) 75 | contracting_rescale = np.sqrt(contracting_dim_size).astype(outscale_dtype) 76 | contracting_rescale = pow2_round(contracting_rescale, pow2_rounding_mode) 77 | # Keeping power of 2 scale. 78 | output_scale = lhs.scale * rhs.scale * contracting_rescale.astype(outscale_dtype) 79 | # NOTE: need to be a bit careful about scale promotion? 80 | output_data = lax.dot_general( 81 | lhs.data, 82 | rhs.data, 83 | dimension_numbers=dimension_numbers, 84 | precision=precision, 85 | preferred_element_type=preferred_element_type, 86 | ) 87 | output_data = output_data / contracting_rescale.astype(output_data.dtype) 88 | return ScaledArray(output_data, output_scale) 89 | 90 | 91 | @core.register_scaled_lax_op 92 | def scaled_conv_general_dilated(lhs: ScaledArray, rhs: ScaledArray, **params: Dict[str, Any]) -> ScaledArray: 93 | assert isinstance(lhs, ScaledArray) 94 | assert isinstance(rhs, ScaledArray) 95 | data = lax.conv_general_dilated_p.bind(lhs.data, rhs.data, **params) 96 | # FIXME: should we change scaling if e.g. window > 3? 97 | return ScaledArray(data, lhs.scale * rhs.scale) 98 | 99 | 100 | @core.register_scaled_lax_op 101 | def scaled_reduce_sum(val: ScaledArray, axes: Tuple[int]) -> ScaledArray: 102 | assert isinstance(val, ScaledArray) 103 | shape = val.shape 104 | scale_dtype = val.scale.dtype 105 | axes_size = np.array([shape[idx] for idx in axes]) 106 | # Pow2 rounding for unit scaling "rule". 107 | pow2_rounding_mode = get_scalify_config().rounding_mode 108 | # Rescale data component following reduction axes & round to power of 2 value. 109 | axes_rescale = np.sqrt(np.prod(axes_size)).astype(scale_dtype) 110 | axes_rescale = pow2_round(axes_rescale, pow2_rounding_mode) 111 | data = lax.reduce_sum_p.bind(val.data, axes=axes) / axes_rescale.astype(val.data.dtype) 112 | outscale = val.scale * axes_rescale.astype(scale_dtype) 113 | return ScaledArray(data, outscale) 114 | 115 | 116 | @core.register_scaled_lax_op 117 | def scaled_reduce_prod(val: ScaledArray, axes: Tuple[int]) -> ScaledArray: 118 | assert isinstance(val, ScaledArray) 119 | shape = val.shape 120 | data = lax.reduce_prod_p.bind(val.data, axes=axes) 121 | axes_size = np.prod(np.array([shape[idx] for idx in axes])) 122 | # Stable for power of 2. 123 | scale = lax.integer_pow(val.scale, axes_size) 124 | return ScaledArray(data, scale) 125 | 126 | 127 | @core.register_scaled_lax_op 128 | def scaled_reduce_max(val: ScaledArray, axes: Tuple[int]) -> ScaledArray: 129 | assert isinstance(val, ScaledArray) 130 | data = lax.reduce_max_p.bind(val.data, axes=axes) 131 | # unchanged scaling. 132 | return ScaledArray(data, val.scale) 133 | 134 | 135 | @core.register_scaled_lax_op 136 | def scaled_reduce_min(val: ScaledArray, axes: Tuple[int]) -> ScaledArray: 137 | assert isinstance(val, ScaledArray) 138 | data = lax.reduce_min_p.bind(val.data, axes=axes) 139 | # unchanged scaling. 
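    # (Min/max reductions pick existing elements, so with a positive scale
    # `scale * min(data) == min(scale * data)` and no rescaling is required.)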
140 | return ScaledArray(data, val.scale) 141 | 142 | 143 | @core.register_scaled_lax_op 144 | def scaled_reduce_window_sum( 145 | val: ScaledArray, 146 | window_dimensions: Any, 147 | window_strides: Any, 148 | padding: Any, 149 | base_dilation: Any, 150 | window_dilation: Any, 151 | ) -> ScaledArray: 152 | assert isinstance(val, ScaledArray) 153 | data = lax.reduce_window_sum_p.bind( 154 | val.data, 155 | window_dimensions=window_dimensions, 156 | window_strides=window_strides, 157 | padding=padding, 158 | base_dilation=base_dilation, 159 | window_dilation=window_dilation, 160 | ) 161 | # FIXME: should we change scaling if e.g. window > 3? 162 | return ScaledArray(data, val.scale) 163 | 164 | 165 | @core.register_scaled_lax_op 166 | def scaled_reduce_window_min( 167 | val: ScaledArray, 168 | window_dimensions: Any, 169 | window_strides: Any, 170 | padding: Any, 171 | base_dilation: Any, 172 | window_dilation: Any, 173 | ) -> ScaledArray: 174 | assert isinstance(val, ScaledArray) 175 | data = lax.reduce_window_min_p.bind( 176 | val.data, 177 | window_dimensions=window_dimensions, 178 | window_strides=window_strides, 179 | padding=padding, 180 | base_dilation=base_dilation, 181 | window_dilation=window_dilation, 182 | ) 183 | # unchanged scaling. 184 | return ScaledArray(data, val.scale) 185 | 186 | 187 | @core.register_scaled_lax_op 188 | def scaled_reduce_window_max( 189 | val: ScaledArray, 190 | window_dimensions: Any, 191 | window_strides: Any, 192 | padding: Any, 193 | base_dilation: Any, 194 | window_dilation: Any, 195 | ) -> ScaledArray: 196 | assert isinstance(val, ScaledArray) 197 | data = lax.reduce_window_max_p.bind( 198 | val.data, 199 | window_dimensions=window_dimensions, 200 | window_strides=window_strides, 201 | padding=padding, 202 | base_dilation=base_dilation, 203 | window_dilation=window_dilation, 204 | ) 205 | # unchanged scaling. 206 | return ScaledArray(data, val.scale) 207 | -------------------------------------------------------------------------------- /jax_scalify/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from .cast import ( # noqa: F401 3 | cast_on_backward, 4 | cast_on_forward, 5 | reduce_precision_on_backward, 6 | reduce_precision_on_forward, 7 | ) 8 | from .debug import debug_callback, debug_callback_grad, debug_print, debug_print_grad # noqa: F401 9 | from .rescaling import ( # noqa: F401 10 | dynamic_rescale_l1, 11 | dynamic_rescale_l1_grad, 12 | dynamic_rescale_l2, 13 | dynamic_rescale_l2_grad, 14 | dynamic_rescale_max, 15 | dynamic_rescale_max_grad, 16 | ) 17 | -------------------------------------------------------------------------------- /jax_scalify/ops/cast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | 4 | import jax 5 | import ml_dtypes 6 | 7 | from jax_scalify.core import Array, DTypeLike 8 | 9 | from .utils import map_on_backward, map_on_forward 10 | 11 | 12 | def reduce_precision_dtype_base(arr: Array, dtype: DTypeLike) -> Array: 13 | """`Fake` cast to an ML dtype (e.g. 
FP8), using JAX LAX `reduce_precision` operator.""" 14 | info = ml_dtypes.finfo(dtype) 15 | return jax.lax.reduce_precision(arr, exponent_bits=info.nexp, mantissa_bits=info.nmant) 16 | 17 | 18 | def reduce_precision_on_forward(arr: Array, dtype: DTypeLike) -> Array: 19 | """`Fake` cast to an ML dtype, on the forward pass (no-op on backward pass).""" 20 | return partial(map_on_forward, lambda v: reduce_precision_dtype_base(v, dtype))(arr) 21 | 22 | 23 | def reduce_precision_on_backward(arr: Array, dtype: DTypeLike) -> Array: 24 | """`Fake` cast to an ML dtype on the backward pass (no-op on forward pass).""" 25 | return partial(map_on_backward, lambda v: reduce_precision_dtype_base(v, dtype))(arr) 26 | 27 | 28 | def cast_on_forward(arr: Array, dtype: DTypeLike) -> Array: 29 | """Cast input array only on the forward pass (no-op on the backward pass). 30 | 31 | Useful for implementation `DenseGeneral` FP8 matmuls. 32 | """ 33 | return partial(map_on_forward, lambda v: jax.lax.convert_element_type(v, dtype))(arr) 34 | 35 | 36 | def cast_on_backward(arr: Array, dtype: DTypeLike) -> Array: 37 | """Cast input array only on the backward pass (no-op on the forward pass). 38 | 39 | Useful for implementation `DenseGeneral` FP8 matmuls. 40 | """ 41 | return partial(map_on_backward, lambda v: jax.lax.convert_element_type(v, dtype))(arr) 42 | -------------------------------------------------------------------------------- /jax_scalify/ops/debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | from typing import Sequence 4 | 5 | import jax 6 | 7 | from jax_scalify.core import Array, debug_callback 8 | 9 | 10 | @partial(jax.custom_vjp, nondiff_argnums=(0,)) 11 | def debug_callback_grad(f, *args): 12 | """Custom callback, called on gradients.""" 13 | return args 14 | 15 | 16 | def debug_callback_grad_fwd(f, *args): 17 | return args, None 18 | 19 | 20 | def debug_callback_grad_bwd(f, _, args_grad): 21 | debug_callback(f, *args_grad) 22 | return args_grad 23 | 24 | 25 | debug_callback_grad.defvjp(debug_callback_grad_fwd, debug_callback_grad_bwd) 26 | 27 | 28 | def debug_print(fmt: str, *args: Array) -> Sequence[Array]: 29 | """Debug print of a collection of tensors.""" 30 | debug_callback(lambda *args: print(fmt.format(*args)), *args) 31 | return args 32 | 33 | 34 | def debug_print_grad(fmt: str, *args: Array) -> Sequence[Array]: 35 | """Debug print of gradients of a collection of tensors.""" 36 | return debug_callback_grad(lambda *args: print(fmt.format(*args)), *args) 37 | -------------------------------------------------------------------------------- /jax_scalify/ops/rescaling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | 4 | import jax 5 | import numpy as np 6 | 7 | from jax_scalify.core import ScaledArray, pow2_round, pow2_round_down 8 | from jax_scalify.lax import get_data_scale, rebalance 9 | 10 | from .utils import map_on_backward, map_on_forward 11 | 12 | 13 | def dynamic_rescale_max_base(arr: ScaledArray) -> ScaledArray: 14 | """Dynamic rescaling of a ScaledArray, using abs-max.""" 15 | # Similarly to ML norms => need some epsilon for training stability! 
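    # NOTE: pow2_round_down(np.float32(1e-4)) == 2**-14, i.e. the epsilon is itself a
    # power of two, so the `max(norm, eps)` clamp below keeps the scale a power of two.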
16 | eps = pow2_round_down(np.float32(1e-4)) 17 | 18 | data, scale = get_data_scale(arr) 19 | data_sq = jax.lax.abs(data) 20 | axes = tuple(range(data.ndim)) 21 | # Get MAX norm + pow2 rounding. 22 | norm = jax.lax.reduce_max_p.bind(data_sq, axes=axes) 23 | norm = jax.lax.max(pow2_round(norm).astype(scale.dtype), eps.astype(scale.dtype)) 24 | # Rebalancing based on norm. 25 | return rebalance(arr, norm) 26 | 27 | 28 | def dynamic_rescale_l1_base(arr: ScaledArray) -> ScaledArray: 29 | """Dynamic rescaling of a ScaledArray, using L1 norm. 30 | 31 | NOTE: by default, computing L1 norm in FP32. 32 | """ 33 | # Similarly to ML norms => need some epsilon for training stability! 34 | norm_dtype = np.float32 35 | eps = pow2_round_down(norm_dtype(1e-4)) 36 | 37 | data, scale = get_data_scale(arr) 38 | data_sq = jax.lax.abs(data.astype(np.float32)) 39 | axes = tuple(range(data.ndim)) 40 | # Get L1 norm + pow2 rounding. 41 | norm = jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size 42 | norm = jax.lax.max(pow2_round(norm), eps).astype(scale.dtype) 43 | # Rebalancing based on norm. 44 | return rebalance(arr, norm) 45 | 46 | 47 | def dynamic_rescale_l2_base(arr: ScaledArray) -> ScaledArray: 48 | """Dynamic rescaling of a ScaledArray, using L2 norm. 49 | 50 | NOTE: by default, computing L2 norm in FP32. 51 | """ 52 | # Similarly to ML norms => need some epsilon for training stability! 53 | norm_dtype = np.float32 54 | eps = pow2_round_down(norm_dtype(1e-4)) 55 | 56 | data, scale = get_data_scale(arr) 57 | data_sq = jax.lax.integer_pow(data.astype(norm_dtype), 2) 58 | axes = tuple(range(data.ndim)) 59 | # Get L2 norm + pow2 rounding. 60 | norm = jax.lax.sqrt(jax.lax.reduce_sum_p.bind(data_sq, axes=axes) / data.size) 61 | # Make sure we don't "underflow" too much on the norm. 62 | norm = jax.lax.max(pow2_round(norm), eps).astype(scale.dtype) 63 | # Rebalancing based on norm. 64 | return rebalance(arr, norm) 65 | 66 | 67 | # Dynamic rescale on fwd arrays. 68 | dynamic_rescale_max = partial(map_on_forward, dynamic_rescale_max_base) 69 | dynamic_rescale_l1 = partial(map_on_forward, dynamic_rescale_l1_base) 70 | dynamic_rescale_l2 = partial(map_on_forward, dynamic_rescale_l2_base) 71 | 72 | # Dynamic rescale on gradients. 73 | dynamic_rescale_max_grad = partial(map_on_backward, dynamic_rescale_max_base) 74 | dynamic_rescale_l1_grad = partial(map_on_backward, dynamic_rescale_l1_base) 75 | dynamic_rescale_l2_grad = partial(map_on_backward, dynamic_rescale_l2_base) 76 | 77 | 78 | # Backward compatibility. DEPRECATED, WILL BE REMOVED! 79 | fn_fwd_identity_bwd = map_on_forward 80 | fn_bwd_identity_fwd = map_on_backward 81 | -------------------------------------------------------------------------------- /jax_scalify/ops/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | 4 | import jax 5 | 6 | 7 | @partial(jax.custom_vjp, nondiff_argnums=(0,)) 8 | def map_on_forward(f, arg): 9 | """Map a function on a forward pass only. No-op/identity on backward pass.""" 10 | return f(arg) 11 | 12 | 13 | def map_on_forward_fwd(f, arg): 14 | return arg, None 15 | 16 | 17 | def map_on_forward_bwd(f, _, grad): 18 | return (grad,) 19 | 20 | 21 | map_on_forward.defvjp(map_on_forward_fwd, map_on_forward_bwd) 22 | 23 | 24 | @partial(jax.custom_vjp, nondiff_argnums=(0,)) 25 | def map_on_backward(f, arg): 26 | """Map a function on the gradient/backward pass. 
No-op/identity on forward.""" 27 | return arg 28 | 29 | 30 | def map_on_backward_fwd(f, arg): 31 | return arg, None 32 | 33 | 34 | def map_on_backward_bwd(f, _, grad): 35 | return (f(grad),) 36 | 37 | 38 | map_on_backward.defvjp(map_on_backward_fwd, map_on_backward_bwd) 39 | -------------------------------------------------------------------------------- /jax_scalify/quantization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | from .scale import as_e8m0 # noqa: F401 3 | -------------------------------------------------------------------------------- /jax_scalify/quantization/scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | import jax.numpy as jnp 3 | import ml_dtypes 4 | import numpy as np 5 | 6 | from jax_scalify.core import Array, DTypeLike, get_numpy_api 7 | from jax_scalify.core.pow2 import dtype_exponent_mask 8 | 9 | 10 | def pow2_truncate(arr: Array) -> Array: 11 | """Convert an Array to a power of 2, using mantissa truncation. 12 | 13 | NOTE: all sub-normals values are flushed to zero. 14 | """ 15 | np_api = get_numpy_api(arr) 16 | # Masking mantissa & sign-bit, keeping only exponent values. 17 | exponent_mask = dtype_exponent_mask(arr.dtype, sign_bit=True) 18 | intdtype = exponent_mask.dtype 19 | # Masking mantissa bits, keeping only the exponents ones. 20 | arr_pow2 = np_api.bitwise_and(arr.view(intdtype), exponent_mask).view(arr.dtype).reshape(arr.shape) 21 | return arr_pow2 22 | 23 | 24 | def as_e8m0(arr: Array) -> Array: 25 | """Convert an Array to e8m0 format (i.e. power of two values). 26 | 27 | This function is only implementing a truncation + saturation variant, in line with 28 | the MX OCP format. 29 | 30 | Args: 31 | arr: Input array (FP16, FP32 or BF16). 32 | Returns: 33 | E8M0 array (as uint8). 34 | """ 35 | np_api = get_numpy_api(arr) 36 | # assert len(arr.shape) < 2 37 | assert arr.dtype in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)} 38 | # Saturation => negative values saturating to min value (i.e. zero bits) in E8M0. 39 | arr = np_api.maximum(arr, np.array(0, arr.dtype)) 40 | arr = pow2_truncate(arr) 41 | 42 | # Bit masking to extract the exponent as uint8 array. 43 | arr_u8 = arr.view(np.uint8).reshape((*arr.shape, -1)) 44 | arr_e8m0 = np_api.bitwise_or(np_api.left_shift(arr_u8[..., -1], 1), np_api.right_shift(arr_u8[..., -2], 7)) 45 | return arr_e8m0 46 | 47 | 48 | def from_e8m0(arr: Array, dtype: DTypeLike) -> Array: 49 | """Convert an Array of e8m0 values (i.e. power of two values) to a given dtype. 50 | 51 | Args: 52 | arr: E8M0 array (assuming uint8 storage dtype). 53 | dtype: Output dtype. FP32 or BF16 supported. 54 | Returns: 55 | Converted output. 56 | """ 57 | np_api = get_numpy_api(arr) 58 | assert arr.dtype == np.uint8 59 | assert np.dtype(dtype) in {np.dtype(jnp.bfloat16), np.dtype(ml_dtypes.bfloat16), np.dtype(jnp.float32)} 60 | # Avoid issues with 7 mantissa bits in BF16. 61 | # TODO: more efficient implementation! 62 | arr = np_api.exp2(arr.astype(np.float32) - 127) 63 | return arr.astype(dtype) 64 | -------------------------------------------------------------------------------- /jax_scalify/tree/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 
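# ScaledArray-aware counterparts of the `jax.tree_util` helpers, treating `ScaledArray`
# leaves as leaves. Illustrative usage sketch (`params` standing for any pytree of arrays):
#   from jax_scalify.tree import astype
#   params_fp16 = astype(params, np.float16, floating_only=True)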
2 | from .tree_util import all, astype, flatten, leaves, map, structure, unflatten # noqa: F401 3 | -------------------------------------------------------------------------------- /jax_scalify/tree/tree_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | from typing import Any, Callable 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | from jax import tree_util 7 | 8 | from jax_scalify.core import DTypeLike, is_scaled_leaf 9 | 10 | Leaf = Any 11 | 12 | 13 | def astype(tree: Any, dtype: DTypeLike, floating_only: bool = False) -> Any: 14 | """Map `astype` method to all pytree leaves, `Array` or `ScaledArray`. 15 | 16 | Args: 17 | tree: the pytree to cast. 18 | dtype: Dtype to cast to. 19 | floating_only: Only convert leaves with floating datatype. 20 | 21 | Returns: 22 | A new PyTree with the same structure, with casting to new dtype. 23 | """ 24 | if floating_only: 25 | # Convert only leaves with floating dtype. 26 | cast_fn = lambda v: v.astype(dtype) if jnp.issubdtype(v.dtype, jnp.floating) else v 27 | return tree_util.tree_map(cast_fn, tree, is_leaf=is_scaled_leaf) 28 | return tree_util.tree_map(lambda v: v.astype(dtype), tree, is_leaf=is_scaled_leaf) 29 | 30 | 31 | def all(tree: Any) -> bool: 32 | """Call all() over the leaves of a tree, `Array` or `ScaledArray` 33 | 34 | Args: 35 | tree: the pytree to evaluate 36 | Returns: 37 | result: boolean True or False 38 | """ 39 | return all(jax.tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf)) 40 | 41 | 42 | def flatten(tree: Any) -> tuple[list[Leaf], tree_util.PyTreeDef]: 43 | """Flattens a pytree, with `Array` or `ScaledArray` leaves. 44 | 45 | The flattening order (i.e. the order of elements in the output list) 46 | is deterministic, corresponding to a left-to-right depth-first tree 47 | traversal. 48 | 49 | Args: 50 | tree: a pytree to flatten. 51 | 52 | Returns: 53 | A pair where the first element is a list of leaf values and the second 54 | element is a treedef representing the structure of the flattened tree. 55 | 56 | See Also: 57 | - :func:`jax_scalify.tree.leaves` 58 | - :func:`jax_scalify.tree.structure` 59 | - :func:`jax_scalify.tree.unflatten` 60 | """ 61 | return tree_util.tree_flatten(tree, is_leaf=is_scaled_leaf) 62 | 63 | 64 | def leaves( 65 | tree: Any, 66 | ) -> list[Leaf]: 67 | """Gets the leaves (`Array` or `ScaledArray`) of a pytree. 68 | 69 | Args: 70 | tree: the pytree for which to get the leaves 71 | 72 | Returns: 73 | leaves: a list of tree leaves. 74 | 75 | See Also: 76 | - :func:`jax_scalify.tree.flatten` 77 | - :func:`jax_scalify.tree.structure` 78 | - :func:`jax_scalify.tree.unflatten` 79 | """ 80 | return tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf) 81 | 82 | 83 | def map(f: Callable[..., Any], tree: Any, *rest: Any) -> Any: 84 | """Maps a multi-input function over pytree args to produce a new pytree. 85 | 86 | Args: 87 | f: function that takes ``1 + len(rest)`` arguments, to be applied at the 88 | corresponding leaves of the pytrees. 89 | tree: a pytree to be mapped over, with each leaf providing the first 90 | positional argument to ``f``. 91 | rest: a tuple of pytrees, each of which has the same structure as ``tree`` 92 | or has ``tree`` as a prefix. 
93 | 94 | Returns: 95 | A new pytree with the same structure as ``tree`` but with the value at each 96 | leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding 97 | leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in 98 | ``rest``. 99 | 100 | See Also: 101 | - :func:`jax_scalify.tree.leaves` 102 | - :func:`jax_scalify.tree.reduce` 103 | """ 104 | return tree_util.tree_map(f, tree, *rest, is_leaf=is_scaled_leaf) 105 | 106 | 107 | def structure(tree: Any) -> tree_util.PyTreeDef: 108 | """Gets the treedef for a pytree, with `Array` or `ScaledArray` leaves. 109 | 110 | Args: 111 | tree: the pytree for which to get the leaves 112 | 113 | Returns: 114 | pytreedef: a PyTreeDef representing the structure of the tree. 115 | 116 | See Also: 117 | - :func:`jax_scalify.tree.flatten` 118 | - :func:`jax_scalify.tree.leaves` 119 | - :func:`jax_scalify.tree.unflatten` 120 | """ 121 | return tree_util.tree_structure(tree, is_leaf=is_scaled_leaf) 122 | 123 | 124 | # Alias of JAX tree unflatten. 125 | unflatten = jax.tree_util.tree_unflatten 126 | -------------------------------------------------------------------------------- /jax_scalify/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | from .hlo import parse_hlo_module, print_hlo_module # noqa: F401 3 | -------------------------------------------------------------------------------- /jax_scalify/utils/hlo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | import json 3 | import textwrap 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, List 6 | 7 | from jax.stages import Compiled, Lowered 8 | 9 | 10 | @dataclass 11 | class HloOperationInfo: 12 | """HLO module operation (raw) info. 13 | 14 | Parse from raw `as_text` compiled HloModule. 15 | 16 | Args: 17 | cmd: Raw HLO operation (function + inputs/outputs). 18 | metadata: JAX metadata (line, ...) 19 | backend_config: Optional backend config dictionary. 20 | """ 21 | 22 | cmd: str 23 | indent: int = 0 24 | metadata: str | None = None 25 | backend_config: Dict[Any, Any] | None = None 26 | 27 | def as_text(self, metadata: bool = False, backend_cfg: bool = False, indent: int = 2) -> str: 28 | """Convert to raw text, with formatting issues.""" 29 | indent_txt = " " * (indent * self.indent) 30 | line = indent_txt + self.cmd 31 | if backend_cfg and self.backend_config: 32 | # A bit hacky text formating of backend config! 33 | backend_cfg_raw = json.dumps(self.backend_config, indent=indent) 34 | backend_cfg_raw = "backend_cfg: " + backend_cfg_raw 35 | backend_cfg_raw = textwrap.indent(backend_cfg_raw, indent_txt + " " * indent) 36 | line += "\n" + backend_cfg_raw 37 | return line 38 | 39 | 40 | def parse_hlo_operation_raw_line(raw_line: str) -> HloOperationInfo: 41 | """Very crude and ugly parsing of an Hlo operation raw line! 42 | 43 | Returns: 44 | Parsed Hlo operation line. 45 | """ 46 | metadata: str | None = None 47 | backend_cfg = None 48 | 49 | # Parse "metadata={...}" block. 50 | metadata_prefix = ", metadata={" 51 | lidx = raw_line.find(metadata_prefix) 52 | if lidx >= 0: 53 | ridx = raw_line[lidx:].find("}") + lidx 54 | metadata = raw_line[lidx : ridx + 1] 55 | raw_line = raw_line.replace(metadata, "") 56 | metadata = metadata[2:] 57 | 58 | # Parse "backend_config={...}" block. 
59 | backend_cfg_prefix = ", backend_config=" 60 | lidx = raw_line.find(backend_cfg_prefix) 61 | if lidx >= 0: 62 | backend_cfg_str = raw_line[lidx + len(backend_cfg_prefix) :] 63 | # TODO: deal with exception raised. 64 | backend_cfg = json.loads(backend_cfg_str) 65 | raw_line = raw_line[:lidx] 66 | 67 | # Clean the raw line. 68 | raw_line = raw_line.rstrip() 69 | size = len(raw_line) 70 | raw_line = raw_line.lstrip() 71 | indent = (size - len(raw_line)) // 2 72 | return HloOperationInfo(raw_line, indent, metadata, backend_cfg) 73 | 74 | 75 | def parse_hlo_module(module: Lowered | Compiled) -> List[HloOperationInfo]: 76 | """Parse an Hlo module, to be human-readable! 77 | 78 | Note: `m.hlo_modules()[0].computations()[0].render_html()` 79 | is also generating a nice HTML output! 80 | 81 | Args: 82 | module: HLO module or JAX stages compiled instance. 83 | Returns: 84 | List of HLO operation info. 85 | """ 86 | assert isinstance(module, (Lowered, Compiled)) 87 | if isinstance(module, Lowered): 88 | module = module.compile() 89 | module_raw_txt = module.as_text() 90 | module_lines = module_raw_txt.split("\n") 91 | ops = [parse_hlo_operation_raw_line(line) for line in module_lines] 92 | return ops 93 | 94 | 95 | def print_hlo_module( 96 | module: Lowered | Compiled, metadata: bool = False, backend_cfg: bool = False, indent: int = 2 97 | ) -> None: 98 | """Human-readable Hlo module printing. 99 | 100 | Args: 101 | module: AOT Lowered or Compiled JAX module. 102 | metadata: Print op metadata as well. 103 | backend_cfg: Print op backend config as well. 104 | """ 105 | cmds = parse_hlo_module(module) 106 | for c in cmds: 107 | print(c.as_text(metadata=metadata, backend_cfg=backend_cfg, indent=indent)) 108 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # Configuration inspired by official pypa example: 2 | # https://github.com/pypa/sampleproject/blob/main/pyproject.toml 3 | 4 | [build-system] 5 | requires = ["setuptools", "setuptools-scm"] 6 | build-backend = "setuptools.build_meta" 7 | 8 | [project] 9 | name = "jax_scalify" 10 | description="JAX Scalify: end-to-end scaled arithmetic." 11 | readme = "README.md" 12 | authors = [ 13 | { name = "Paul Balanca", email = "paulb@graphcore.ai" }, 14 | ] 15 | requires-python = ">=3.10" 16 | classifiers = [ 17 | "Development Status :: 3 - Alpha", 18 | "Intended Audience :: Developers", 19 | "Intended Audience :: Science/Research", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Programming Language :: Python :: 3.10", 22 | "Programming Language :: Python :: 3.11", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | ] 25 | dependencies = [ 26 | "chex>=0.1.6", 27 | "jax>=0.3.16", 28 | "jaxlib>=0.3.15", 29 | "ml_dtypes", 30 | "numpy>=1.22.4" 31 | ] 32 | dynamic = ["version"] 33 | 34 | [project.urls] 35 | "Homepage" = "https://github.com/graphcore-research/jax-scalify/#readme" 36 | "Bug Reports" = "https://github.com/graphcore-research/jax-scalify/issues" 37 | "Source" = "https://github.com/graphcore-research/jax-scalify/" 38 | 39 | [project.optional-dependencies] 40 | dev = ["check-manifest"] 41 | test = ["pytest"] 42 | 43 | # Relying on the default setuptools. 
44 | # In case of an issue, can use the following options 45 | # [tool.setuptools] 46 | # packages = ["jax_scalify", "jax_scalify.core", "jax_scalify.lax", "jax_scalify.ops", "jax_scalify.tree"] 47 | # [tool.setuptools.packages] 48 | # find = {namespaces = false} 49 | 50 | [tool.setuptools.dynamic] 51 | version = {attr = "jax_scalify.version.__version__"} 52 | 53 | [tool.setuptools_scm] 54 | version_file = "jax_scalify/version.py" 55 | 56 | [tool.pytest.ini_options] 57 | minversion = "6.0" 58 | addopts = ["-ra", "--showlocals", "--strict-config", "-p no:hypothesispytest"] 59 | xfail_strict = true 60 | filterwarnings = [ 61 | "error", 62 | "ignore:(ast.Str|Attribute s|ast.NameConstant|ast.Num) is deprecated:DeprecationWarning:_pytest", # Python 3.12 63 | ] 64 | testpaths = ["tests"] 65 | 66 | [tool.black] 67 | line-length = 120 68 | target-version = ['py38', 'py39', 'py310'] 69 | 70 | [tool.isort] 71 | line_length = 120 72 | known_first_party = "jax_scalify" 73 | 74 | [tool.mypy] 75 | python_version = "3.10" 76 | plugins = ["numpy.typing.mypy_plugin"] 77 | # Config heavily inspired by Pydantic! 78 | show_error_codes = true 79 | # strict_optional = true 80 | warn_redundant_casts = true 81 | warn_unused_ignores = true 82 | warn_unused_configs = true 83 | check_untyped_defs = true 84 | disallow_any_generics = true 85 | no_implicit_optional = false 86 | disallow_incomplete_defs = true 87 | # disallow_untyped_decorators = true 88 | # disallow_untyped_calls = true 89 | # # disallow_subclassing_any = true 90 | # # for strict mypy: (this is the tricky one :-)) 91 | # disallow_untyped_defs = true 92 | exclude = ['examples'] 93 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | max-complexity = 20 4 | min_python_version = 3.8 5 | ignore = F401 6 | per-file-ignores = 7 | jax_scalify/__init__.py: F401 8 | -------------------------------------------------------------------------------- /test-requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | -------------------------------------------------------------------------------- /tests/core/test_pow2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | 4 | import chex 5 | import jax.numpy as jnp 6 | import numpy as np 7 | import numpy.testing as npt 8 | from absl.testing import parameterized 9 | 10 | from jax_scalify.core import Pow2RoundMode, pow2_decompose, pow2_round_down, pow2_round_up 11 | from jax_scalify.core.pow2 import _exponent_bits_mask, get_mantissa 12 | 13 | 14 | class Pow2DecomposePrimitveTests(chex.TestCase): 15 | @parameterized.parameters( 16 | {"dtype": np.float16}, 17 | {"dtype": np.float32}, 18 | ) 19 | def test__exponent_bitmask__inf_value(self, dtype): 20 | val = _exponent_bits_mask[np.dtype(dtype)].view(dtype) 21 | expected_val = dtype(np.inf) 22 | npt.assert_equal(val, expected_val) 23 | 24 | @parameterized.product( 25 | val_exp=[ 26 | (0, 0), 27 | (1, 1), 28 | (2.1, 2), 29 | (0.3, 0.25), 30 | (0.51, 0.5), 31 | (65500, 32768), 32 | # Test float16 sub-normals. 
33 | (np.finfo(np.float16).smallest_normal, np.finfo(np.float16).smallest_normal), 34 | (np.finfo(np.float16).smallest_subnormal, np.finfo(np.float16).smallest_subnormal), 35 | (np.float16(3.123283386230469e-05), 3.0517578e-05), 36 | ], 37 | dtype=[np.float16, np.float32], 38 | scale_dtype=[np.float16, np.float32], 39 | ) 40 | def test__pow2_decompose_round_down__numpy_implementation__proper_result(self, val_exp, dtype, scale_dtype): 41 | scale_dtype = np.float32 42 | vin, exp_scale = dtype(val_exp[0]), scale_dtype(val_exp[1]) 43 | scale, vout = pow2_decompose(vin, scale_dtype, Pow2RoundMode.DOWN) 44 | 45 | assert isinstance(scale, (np.ndarray, np.number)) 46 | assert isinstance(vout, (np.ndarray, np.number)) 47 | assert scale.dtype == scale_dtype 48 | assert vout.dtype == vin.dtype 49 | # Always accurate when casting up to scale dtype. 50 | npt.assert_equal(scale * vout.astype(scale_dtype), vin.astype(scale_dtype)) 51 | npt.assert_equal(scale, exp_scale) 52 | 53 | @chex.variants(with_jit=True, without_jit=True) 54 | @parameterized.product( 55 | val_exp=[ 56 | (0, 0), 57 | (1, 1), 58 | (2.1, 2), 59 | (0.3, 0.25), 60 | (0.51, 0.5), 61 | (65500, 32768), 62 | # Test float16 sub-normals. 63 | (np.finfo(np.float16).smallest_normal, np.finfo(np.float16).smallest_normal), 64 | (np.finfo(np.float16).smallest_subnormal, np.finfo(np.float16).smallest_subnormal), 65 | (np.float16(3.123283386230469e-05), 3.0517578e-05), 66 | # Test float32 sub-normals: known bug! 67 | # (np.finfo(np.float32).smallest_normal, np.finfo(np.float32).smallest_normal), 68 | # (np.finfo(np.float32).smallest_subnormal, np.finfo(np.float32).smallest_subnormal), 69 | ], 70 | dtype=[np.float16, np.float32], 71 | scale_dtype=[np.float16, np.float32], 72 | ) 73 | def test__pow2_decompose_round_down__jax_numpy__proper_result(self, val_exp, dtype, scale_dtype): 74 | vin, exp_scale = dtype(val_exp[0]), scale_dtype(val_exp[1]) 75 | vin = jnp.array(vin) 76 | scale, vout = self.variant(lambda v: pow2_decompose(v, scale_dtype, Pow2RoundMode.DOWN))(vin) 77 | 78 | assert isinstance(scale, jnp.ndarray) 79 | assert isinstance(vout, jnp.ndarray) 80 | assert scale.dtype == scale_dtype 81 | assert vout.dtype == vin.dtype 82 | # Always accurate when casting up to scale dtype. 83 | npt.assert_equal(np.asarray(scale), exp_scale) 84 | npt.assert_equal(scale * np.array(vout, scale_dtype), np.asarray(vin, scale_dtype)) 85 | 86 | @chex.variants(with_jit=True, without_jit=True) 87 | @parameterized.product( 88 | val_exp=[ 89 | (+np.inf, np.inf, +np.inf), 90 | (-np.inf, np.inf, -np.inf), 91 | (np.nan, np.inf, np.nan), # FIXME? scale == np.inf? 
92 | ], 93 | dtype=[np.float16, np.float32], 94 | scale_dtype=[np.float16, np.float32], 95 | ) 96 | def test__pow2_decompose_round_down__special_values(self, val_exp, dtype, scale_dtype): 97 | vin, exp_scale, exp_vout = dtype(val_exp[0]), scale_dtype(val_exp[1]), dtype(val_exp[2]) 98 | scale, vout = self.variant(partial(pow2_decompose, scale_dtype=scale_dtype, mode=Pow2RoundMode.DOWN))(vin) 99 | npt.assert_equal(np.ravel(scale)[0], exp_scale) 100 | npt.assert_equal(np.ravel(vout)[0], exp_vout) 101 | 102 | @parameterized.product( 103 | val_exp=[(0, 0), (1, 1), (2.1, 2), (0.3, 0.25), (0.51, 0.5), (65500, 32768)], 104 | dtype=[np.float16, np.float32, np.float64], 105 | ) 106 | def test__pow2_round_down__proper_rounding__multi_dtypes(self, val_exp, dtype): 107 | val, exp = dtype(val_exp[0]), dtype(val_exp[1]) 108 | pow2_val = pow2_round_down(val) 109 | assert pow2_val.dtype == val.dtype 110 | assert pow2_val.shape == () 111 | assert type(pow2_val) in {type(val), np.ndarray} 112 | npt.assert_equal(pow2_val, exp) 113 | 114 | @parameterized.product( 115 | val_exp=[(2.1, 4), (0.3, 0.5), (0.51, 1), (17000, 32768)], 116 | dtype=[np.float16], 117 | ) 118 | def test__pow2_round_up__proper_rounding__multi_dtypes(self, val_exp, dtype): 119 | val, exp = dtype(val_exp[0]), dtype(val_exp[1]) 120 | pow2_val = pow2_round_up(val) 121 | assert pow2_val.dtype == val.dtype 122 | assert type(pow2_val) in {type(val), np.ndarray} 123 | npt.assert_equal(pow2_val, exp) 124 | 125 | @parameterized.product( 126 | val_mant=[(1, 1), (2.1, 1.05), (0, 0), (0.51, 1.02), (65504, 1.9990234375)], 127 | dtype=[np.float16, np.float32], # FIXME: float64 support in pure Numpy 128 | ) 129 | def test__get_mantissa__proper_value__multi_dtypes(self, val_mant, dtype): 130 | val, mant = dtype(val_mant[0]), dtype(val_mant[1]) 131 | val_mant = get_mantissa(val) 132 | assert val_mant.dtype == val.dtype 133 | assert val_mant.shape == () 134 | assert type(val_mant) in {type(val), np.ndarray} 135 | npt.assert_equal(val_mant, mant) 136 | # Should be consistent with `pow2_round_down`. bitwise, not approximation. 137 | npt.assert_equal(mant * pow2_round_down(val), val) 138 | -------------------------------------------------------------------------------- /tests/core/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
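# Tests below cover the zero-denominator semantics: `safe_div` and `safe_reciprocal`
# return zero (instead of inf/nan) when dividing by zero, for both NumPy and JAX scalars.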
2 | import chex 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import numpy.testing as npt 6 | from absl.testing import parameterized 7 | 8 | from jax_scalify.core.utils import Array, python_scalar_as_numpy, safe_div, safe_reciprocal 9 | 10 | 11 | class SafeDivOpTests(chex.TestCase): 12 | @parameterized.parameters( 13 | {"lhs": np.float16(0), "rhs": np.float16(0)}, 14 | {"lhs": np.float32(0), "rhs": np.float32(0)}, 15 | {"lhs": np.float16(2), "rhs": np.float16(0)}, 16 | {"lhs": np.float32(4), "rhs": np.float32(0)}, 17 | ) 18 | def test__safe_div__zero_div__numpy_inputs(self, lhs, rhs): 19 | out = safe_div(lhs, rhs) 20 | assert isinstance(out, (np.number, np.ndarray)) 21 | assert out.dtype == lhs.dtype 22 | npt.assert_equal(out, 0) 23 | 24 | @parameterized.parameters( 25 | {"lhs": np.float16(0), "rhs": jnp.float16(0)}, 26 | {"lhs": jnp.float32(0), "rhs": np.float32(0)}, 27 | {"lhs": jnp.float16(2), "rhs": np.float16(0)}, 28 | {"lhs": np.float32(4), "rhs": jnp.float32(0)}, 29 | ) 30 | def test__safe_div__zero_div__jax_inputs(self, lhs, rhs): 31 | out = safe_div(lhs, rhs) 32 | assert isinstance(out, Array) 33 | assert out.dtype == lhs.dtype 34 | npt.assert_almost_equal(out, 0) 35 | 36 | @parameterized.parameters( 37 | {"val": np.float16(0)}, 38 | {"val": jnp.float16(0)}, 39 | ) 40 | def test__safe_reciprocal__zero_div(self, val): 41 | out = safe_reciprocal(val) 42 | assert out.dtype == val.dtype 43 | npt.assert_almost_equal(out, 0) 44 | 45 | 46 | def test__python_scalar_as_numpy__proper_convertion(): 47 | npt.assert_equal(python_scalar_as_numpy(False), np.bool_(False)) 48 | npt.assert_equal(python_scalar_as_numpy(4), np.int32(4)) 49 | npt.assert_equal(python_scalar_as_numpy(3.2), np.float32(3.2)) 50 | -------------------------------------------------------------------------------- /tests/lax/test_base_scaling_primitives.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import chex 3 | import jax.numpy as jnp 4 | import numpy as np 5 | import numpy.testing as npt 6 | from absl.testing import parameterized 7 | from numpy.typing import NDArray 8 | 9 | from jax_scalify.core import Array, ScaledArray, ScalifyConfig, scaled_array, scalify 10 | from jax_scalify.lax.base_scaling_primitives import ( 11 | get_data_scale, 12 | rebalance, 13 | scaled_set_scaling, 14 | set_scaling, 15 | stop_scaling, 16 | ) 17 | 18 | 19 | class SetScalingPrimitiveTests(chex.TestCase): 20 | @parameterized.parameters( 21 | {"npapi": np}, 22 | {"npapi": jnp}, 23 | ) 24 | def test__set_scaling_primitive__scaled_array__eager_mode(self, npapi): 25 | values = scaled_array([-1.0, 2.0], 2.0, dtype=np.float16, npapi=npapi) 26 | scale = npapi.float16(4) 27 | output = set_scaling(values, scale) 28 | assert isinstance(output, ScaledArray) 29 | # No implicit promotion to JAX array when not necessary. 
30 | assert isinstance(output.data, npapi.ndarray) 31 | assert isinstance(output.scale, (npapi.number, npapi.ndarray)) 32 | npt.assert_equal(output.scale, npapi.float16(4)) 33 | npt.assert_array_equal(output, values) 34 | 35 | @chex.variants(with_jit=True, without_jit=True) 36 | @parameterized.parameters( 37 | {"arr": np.array([-1.0, 2.0], dtype=np.float32)}, 38 | {"arr": scaled_array([-1.0, 2.0], 1.0, dtype=np.float16)}, 39 | {"arr": scaled_array([-1.0, 2.0], 0.0, dtype=np.float32)}, 40 | ) 41 | def test__set_scaling_primitive__zero_scaling(self, arr): 42 | def fn(arr, scale): 43 | return set_scaling(arr, scale) 44 | 45 | scale = np.array(0, dtype=arr.dtype) 46 | out = self.variant(scalify(fn))(arr, scale) 47 | assert isinstance(out, ScaledArray) 48 | npt.assert_array_almost_equal(out.scale, 0) 49 | npt.assert_array_almost_equal(out.data, 0) 50 | 51 | @chex.variants(with_jit=True, without_jit=True) 52 | def test__set_scaling_primitive__proper_result_without_scalify(self): 53 | def fn(arr, scale): 54 | return set_scaling(arr, scale) 55 | 56 | fn = self.variant(fn) 57 | arr = jnp.array([2, 3], dtype=np.float32) 58 | scale = jnp.array(4, dtype=np.float32) 59 | out = fn(arr, scale) 60 | # Scale input ignored => forward same array. 61 | assert isinstance(out, jnp.ndarray) 62 | npt.assert_array_equal(out, arr) 63 | 64 | @chex.variants(with_jit=True, without_jit=False) 65 | @parameterized.parameters( 66 | # Testing different combination of scaled/unscaled inputs. 67 | {"arr": np.array([-1.0, 2.0], dtype=np.float32), "scale": np.array(4.0, dtype=np.float32)}, 68 | {"arr": np.array([-1.0, 2.0], dtype=np.float16), "scale": np.array(4.0, dtype=np.float32)}, 69 | {"arr": scaled_array([-1.0, 2.0], 1.0, dtype=np.float16), "scale": np.array(4.0, dtype=np.float32)}, 70 | {"arr": scaled_array([-1.0, 2.0], 2.0, dtype=np.float32), "scale": scaled_array(1.0, 4.0, dtype=np.float32)}, 71 | {"arr": scaled_array([-1.0, 2.0], 2.0, dtype=np.float16), "scale": scaled_array(1.0, 4.0, dtype=np.float32)}, 72 | ) 73 | def test__set_scaling_primitive__proper_result_with_scalify(self, arr, scale): 74 | def fn(arr, scale): 75 | return set_scaling(arr, scale) 76 | 77 | fn = self.variant(scalify(fn)) 78 | out = fn(arr, scale) 79 | # Unchanged output tensor, with proper dtype. 
80 | assert isinstance(out, ScaledArray) 81 | assert out.dtype == arr.dtype 82 | npt.assert_array_equal(out.scale, scale) 83 | npt.assert_array_equal(out, arr) 84 | 85 | @parameterized.parameters( 86 | {"scale": 1}, 87 | {"scale": np.int32(1)}, 88 | {"scale": np.float32(1)}, 89 | ) 90 | def test__scaled_set_scaling__unchanged_scaled_array(self, scale): 91 | val = scaled_array([-1.0, 2.0], 2.0, dtype=np.float16) 92 | assert scaled_set_scaling(val, scale) is val 93 | 94 | @parameterized.parameters( 95 | {"scale": np.int32(1)}, 96 | {"scale": np.float32(1)}, 97 | ) 98 | def test__scaled_set_scaling__unchanged_data_scaled_array(self, scale): 99 | val: NDArray[np.float16] = np.array([-1.0, 2.0], dtype=np.float16) 100 | out = scaled_set_scaling(val, scale) # type:ignore 101 | assert isinstance(out, ScaledArray) 102 | assert out.data is val 103 | 104 | 105 | class StopScalingPrimitiveTests(chex.TestCase): 106 | @parameterized.parameters( 107 | {"npapi": np}, 108 | {"npapi": jnp}, 109 | ) 110 | def test__stop_scaling_primitive__scaled_array__eager_mode(self, npapi): 111 | values = scaled_array([-1.0, 2.0], 2.0, dtype=np.float16, npapi=npapi) 112 | output = stop_scaling(values) 113 | assert isinstance(output, npapi.ndarray) 114 | npt.assert_array_equal(output, values) 115 | 116 | @chex.variants(with_jit=True, without_jit=True) 117 | def test__stop_scaling_primitive__proper_result_without_scalify(self): 118 | def fn(arr): 119 | # Testing both variants. 120 | return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) 121 | 122 | arr = jnp.array([2, 3], dtype=np.float32) 123 | out0, out1 = self.variant(fn)(arr) 124 | assert out0.dtype == arr.dtype 125 | assert out1.dtype == np.float16 126 | npt.assert_array_equal(out0, arr) 127 | npt.assert_array_almost_equal(out1, arr) 128 | 129 | @chex.variants(with_jit=True, without_jit=True) 130 | def test__stop_scaling_primitive__proper_result_with_scalify(self): 131 | def fn(arr): 132 | # Testing both variants. 133 | return stop_scaling(arr), stop_scaling(arr, dtype=np.float16) 134 | 135 | fn = self.variant(scalify(fn)) 136 | arr = scaled_array([-1.0, 2.0], 3.0, dtype=np.float32) 137 | out0, out1 = fn(arr) 138 | assert isinstance(out0, Array) 139 | assert isinstance(out1, Array) 140 | assert out0.dtype == arr.dtype 141 | assert out1.dtype == np.float16 142 | npt.assert_array_equal(out0, arr) 143 | npt.assert_array_almost_equal(out1, arr) 144 | 145 | 146 | class GetDataScalePrimitiveTests(chex.TestCase): 147 | @chex.variants(with_jit=True, without_jit=True) 148 | def test__get_data_scale_primitive__proper_result_without_scalify(self): 149 | def fn(arr): 150 | # Set a default scale dtype. 
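            # (`ScalifyConfig` acts as a context manager here, so `get_data_scale` below
            # picks up `scale_dtype=np.float32` for the returned scale.)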
151 | with ScalifyConfig(scale_dtype=np.float32): 152 | return get_data_scale(arr) 153 | 154 | fn = self.variant(fn) 155 | arr = jnp.array([2, 3], dtype=np.float16) 156 | data, scale = fn(arr) 157 | assert data.dtype == np.float16 158 | assert scale.dtype == np.float32 159 | npt.assert_array_equal(data, arr) 160 | npt.assert_equal(scale, np.array(1, np.float32)) 161 | 162 | @chex.variants(with_jit=True, without_jit=True) 163 | def test__get_data_scale_primitive__proper_result_with_scalify(self): 164 | def fn(arr): 165 | return get_data_scale(arr) 166 | 167 | fn = self.variant(scalify(fn)) 168 | arr = scaled_array([2, 3], np.float16(4), dtype=np.float16) 169 | data, scale = fn(arr) 170 | npt.assert_array_equal(data, arr.data) 171 | npt.assert_equal(scale, arr.scale) 172 | 173 | def test__get_data_scale_primitive__numpy_input(self): 174 | arr = scaled_array([2, 3], 4, dtype=np.float16) 175 | # ScaledArray input. 176 | data, scale = get_data_scale(arr) 177 | npt.assert_array_equal(data, arr.data) 178 | npt.assert_array_equal(scale, arr.scale) 179 | # Normal numpy array input. 180 | data, scale = get_data_scale(np.asarray(arr)) 181 | npt.assert_array_equal(data, arr) 182 | npt.assert_almost_equal(scale, 1) 183 | 184 | 185 | class RebalancingOpsTests(chex.TestCase): 186 | @parameterized.parameters( 187 | {"npapi": np}, 188 | {"npapi": jnp}, 189 | ) 190 | def test__rebalance_op__normal_array__eagermode(self, npapi): 191 | values = npapi.array([-1, 2], dtype=np.float16) 192 | output = rebalance(values, np.float16(2)) 193 | # Same Python array object forwarded. 194 | assert output is values 195 | 196 | def test__rebalance_op__scaled_array__eagermode(self): 197 | arr_in = scaled_array([2, 3], np.float16(4), dtype=np.float16) 198 | scale_in = np.float16(0.5) 199 | arr_out = rebalance(arr_in, scale_in) 200 | assert isinstance(arr_out, ScaledArray) 201 | npt.assert_array_equal(arr_out.scale, arr_in.scale * scale_in) 202 | npt.assert_array_equal(arr_out, arr_in) 203 | -------------------------------------------------------------------------------- /tests/lax/test_numpy_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import chex 3 | import jax 4 | import jax.numpy as jnp 5 | import numpy as np 6 | 7 | from jax_scalify.core import ScaledArray, scaled_array, scalify 8 | 9 | 10 | class ScaledJaxNumpyFunctions(chex.TestCase): 11 | def setUp(self): 12 | super().setUp() 13 | # Use random state for reproducibility! 14 | self.rs = np.random.RandomState(42) 15 | 16 | @chex.variants(with_jit=True, without_jit=False) 17 | def test__numpy_mean__proper_gradient_scale_propagation(self): 18 | def mean_fn(x): 19 | # Taking the square to "force" ScaledArray gradient. 20 | # Numpy mean constant rescaling creating trouble on backward pass! 21 | return jax.grad(lambda v: jnp.mean(v * v))(x) 22 | 23 | # size = 8 * 16 24 | input_scaled = scaled_array(self.rs.rand(8, 16).astype(np.float32), np.float32(1)) 25 | output_grad_scaled = self.variant(scalify(mean_fn))(input_scaled) 26 | 27 | assert isinstance(output_grad_scaled, ScaledArray) 28 | # Proper scale propagation on the backward pass (rough interval) 29 | assert np.std(output_grad_scaled.data) >= 0.25 30 | assert np.std(output_grad_scaled.data) <= 1.0 31 | # "small" scale. 
32 | assert output_grad_scaled.scale <= 0.01 33 | -------------------------------------------------------------------------------- /tests/lax/test_scaled_ops_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import chex 3 | import numpy as np 4 | import numpy.testing as npt 5 | from absl.testing import parameterized 6 | from jax import lax 7 | 8 | from jax_scalify.core import Array, ScaledArray, debug_callback, find_registered_scaled_op, scaled_array, scalify 9 | from jax_scalify.lax import ( 10 | scaled_broadcast_in_dim, 11 | scaled_concatenate, 12 | scaled_convert_element_type, 13 | scaled_is_finite, 14 | scaled_pad, 15 | scaled_reduce_precision, 16 | scaled_reshape, 17 | scaled_rev, 18 | scaled_select_n, 19 | scaled_sign, 20 | scaled_slice, 21 | scaled_transpose, 22 | ) 23 | 24 | 25 | class ScaledTranslationPrimitivesTests(chex.TestCase): 26 | def setUp(self): 27 | super().setUp() 28 | # Use random state for reproducibility! 29 | self.rs = np.random.RandomState(42) 30 | 31 | @chex.variants(with_jit=True, without_jit=True) 32 | def test__scaled_debug_callback__proper_forwarding(self): 33 | host_values = [] 34 | 35 | def callback(*args): 36 | for v in args: 37 | host_values.append(v) 38 | 39 | def fn(a): 40 | # NOTE: multiplying by a power of 2 to simplify test. 41 | debug_callback(callback, a, a * 4) 42 | return a 43 | 44 | x = scaled_array(self.rs.rand(5), 2, dtype=np.float16) 45 | fn = self.variant(scalify(fn)) 46 | fn(x) 47 | 48 | assert len(host_values) == 2 49 | for sv in host_values: 50 | assert isinstance(sv, ScaledArray) 51 | npt.assert_array_equal(sv.data, x.data) 52 | npt.assert_array_equal(host_values[0].scale, x.scale) 53 | npt.assert_array_equal(host_values[1].scale, x.scale * 4) 54 | 55 | def test__scaled_broadcast_in_dim__proper_scaling(self): 56 | x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) 57 | z = scaled_broadcast_in_dim(x, shape=(5, 1), broadcast_dimensions=(0,)) 58 | assert isinstance(z, ScaledArray) 59 | npt.assert_array_equal(z.scale, x.scale) 60 | npt.assert_array_almost_equal(z.data, x.data.reshape((5, 1))) 61 | 62 | def test__scaled_reshape__proper_scaling(self): 63 | x = scaled_array(self.rs.rand(8), 2, dtype=np.float32) 64 | z = scaled_reshape(x, new_sizes=(4, 2), dimensions=None) 65 | assert isinstance(z, ScaledArray) 66 | npt.assert_array_equal(z.scale, x.scale) 67 | npt.assert_array_almost_equal(z.data, x.data.reshape((4, 2))) 68 | 69 | def test__scaled_concatenate__proper_scaling(self): 70 | x = scaled_array(self.rs.rand(2, 3), 0.5, dtype=np.float16) 71 | y = scaled_array(self.rs.rand(5, 3), 2, dtype=np.float16) 72 | z = scaled_concatenate([x, y], dimension=0) 73 | assert isinstance(z, ScaledArray) 74 | assert z.dtype == x.dtype 75 | npt.assert_array_equal(z.scale, y.scale) 76 | npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0)) 77 | 78 | def test__scaled_concatenate__zero_input_scales(self): 79 | x = scaled_array(self.rs.rand(2, 3), 0.0, dtype=np.float16) 80 | y = scaled_array(self.rs.rand(5, 3), 0.0, dtype=np.float16) 81 | z = scaled_concatenate([x, y], dimension=0) 82 | assert isinstance(z, ScaledArray) 83 | assert z.dtype == x.dtype 84 | npt.assert_array_equal(z.scale, 0) 85 | npt.assert_array_almost_equal(z, lax.concatenate([np.asarray(x), np.asarray(y)], dimension=0)) 86 | 87 | def test__scaled_convert_element_type__proper_scaling(self): 88 | x = scaled_array(self.rs.rand(5), 2, 
dtype=np.float32) 89 | z = scaled_convert_element_type(x, new_dtype=np.float16) 90 | assert isinstance(z, ScaledArray) 91 | npt.assert_array_equal(z.scale, x.scale) 92 | npt.assert_array_almost_equal(z.data, x.data.astype(z.dtype)) 93 | 94 | def test__scaled_transpose__proper_scaling(self): 95 | x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32) 96 | z = scaled_transpose(x, (1, 0)) 97 | assert isinstance(z, ScaledArray) 98 | assert z.scale == x.scale 99 | npt.assert_array_almost_equal(z.data, x.data.T) 100 | 101 | def test__scaled_sign__proper_scaling(self): 102 | x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float32) 103 | z = scaled_sign(x) 104 | assert isinstance(z, ScaledArray) 105 | assert z.scale == 1 106 | npt.assert_array_equal(z.data, lax.sign(x.data)) 107 | 108 | def test__scaled_rev__proper_scaling(self): 109 | x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) 110 | z = scaled_rev(x, dimensions=(0,)) 111 | assert isinstance(z, ScaledArray) 112 | assert z.scale == x.scale 113 | npt.assert_array_almost_equal(z.data, x.data[::-1]) 114 | 115 | def test__scaled_pad__proper_scaling(self): 116 | x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) 117 | z = scaled_pad(x, 0.0, [(1, 2, 0)]) 118 | assert isinstance(z, ScaledArray) 119 | assert z.scale == x.scale 120 | npt.assert_array_almost_equal(z.data, lax.pad(x.data, 0.0, [(1, 2, 0)])) 121 | 122 | def test__scaled_reduce_precision__proper_result(self): 123 | x = scaled_array(self.rs.rand(3, 5), 2, dtype=np.float16) 124 | # Reduction to pseudo FP8 format. 125 | z = scaled_reduce_precision(x, exponent_bits=4, mantissa_bits=3) 126 | assert isinstance(z, ScaledArray) 127 | assert z.dtype == x.dtype 128 | assert z.scale == x.scale 129 | npt.assert_array_almost_equal(z.data, lax.reduce_precision(x.data, exponent_bits=4, mantissa_bits=3)) 130 | 131 | def test__scaled_slice__proper_scaling(self): 132 | x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) 133 | z = scaled_slice(x, (1,), (4,), (2,)) 134 | assert isinstance(z, ScaledArray) 135 | assert z.scale == x.scale 136 | npt.assert_array_almost_equal(z.data, x.data[1:4:2]) 137 | 138 | @parameterized.parameters({"prim": lax.argmax_p}, {"prim": lax.argmin_p}) 139 | def test__scaled_argminmax__proper_scaling(self, prim): 140 | x = scaled_array(self.rs.rand(5), 2, dtype=np.float32) 141 | expected_out = prim.bind(x.to_array(), axes=(0,), index_dtype=np.int32) 142 | scaled_translation, _ = find_registered_scaled_op(prim) 143 | out = scaled_translation(x, axes=(0,), index_dtype=np.int32) 144 | assert isinstance(out, Array) 145 | npt.assert_array_equal(out, expected_out) 146 | 147 | 148 | class ScaledTranslationBooleanPrimitivesTests(chex.TestCase): 149 | def setUp(self): 150 | super().setUp() 151 | # Use random state for reproducibility! 152 | self.rs = np.random.RandomState(42) 153 | 154 | @parameterized.parameters( 155 | {"val": scaled_array([2, 3], 2.0, dtype=np.float32), "expected_out": [True, True]}, 156 | # Supporting `int` scale as well. 
157 | {"val": scaled_array([2, np.inf], 2, dtype=np.float32), "expected_out": [True, False]}, 158 | {"val": scaled_array([2, 3], np.nan, dtype=np.float32), "expected_out": [False, False]}, 159 | {"val": scaled_array([np.nan, 3], 3.0, dtype=np.float32), "expected_out": [False, True]}, 160 | ) 161 | def test__scaled_is_finite__proper_result(self, val, expected_out): 162 | out = scaled_is_finite(val) 163 | assert isinstance(out, Array) 164 | assert out.dtype == np.bool_ 165 | npt.assert_array_equal(out, expected_out) 166 | 167 | @parameterized.parameters( 168 | {"bool_prim": lax.eq_p}, 169 | {"bool_prim": lax.ne_p}, 170 | {"bool_prim": lax.lt_p}, 171 | {"bool_prim": lax.le_p}, 172 | {"bool_prim": lax.gt_p}, 173 | {"bool_prim": lax.ge_p}, 174 | ) 175 | def test__scaled_boolean_binary_op__proper_result(self, bool_prim): 176 | lhs = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32) 177 | rhs = scaled_array(self.rs.rand(5), 3.0, dtype=np.float32) 178 | scaled_bool_op, _ = find_registered_scaled_op(bool_prim) 179 | out0 = scaled_bool_op(lhs, rhs) 180 | out1 = scaled_bool_op(lhs, lhs) 181 | assert isinstance(out0, Array) 182 | assert out0.dtype == np.bool_ 183 | npt.assert_array_equal(out0, bool_prim.bind(lhs.to_array(), rhs.to_array())) 184 | npt.assert_array_equal(out1, bool_prim.bind(lhs.to_array(), lhs.to_array())) 185 | 186 | def test__scaled_select_n__proper_result(self): 187 | mask = self.rs.rand(5) > 0.5 188 | lhs = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32) 189 | rhs = scaled_array(self.rs.rand(5), 4.0, dtype=np.float32) 190 | out = scaled_select_n(mask, lhs, rhs) 191 | assert isinstance(out, ScaledArray) 192 | assert out.dtype == np.float32 193 | # Max scale used. 194 | npt.assert_almost_equal(out.scale, 4) 195 | npt.assert_array_equal(out, np.where(mask, rhs, lhs)) 196 | 197 | @parameterized.parameters( 198 | {"scale": 0.25}, 199 | {"scale": 8.0}, 200 | ) 201 | def test__scaled_select__relu_grad_example(self, scale): 202 | @scalify 203 | def relu_grad(g): 204 | return lax.select(g > 0, g, lax.full_like(g, 0)) 205 | 206 | # Gradient with some scale. 207 | gin = scaled_array([1.0, 0.5], np.float32(scale), dtype=np.float32) 208 | gout = relu_grad(gin) 209 | # Same scale should be propagated to gradient output. 210 | assert isinstance(gout, ScaledArray) 211 | npt.assert_array_equal(gout.scale, gin.scale) 212 | -------------------------------------------------------------------------------- /tests/lax/test_scaled_ops_l2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import chex 3 | import numpy as np 4 | import numpy.testing as npt 5 | from absl.testing import parameterized 6 | from jax import lax 7 | 8 | from jax_scalify.core import ScaledArray, find_registered_scaled_op, scaled_array 9 | from jax_scalify.lax import scaled_div, scaled_dot_general, scaled_mul, scaled_reduce_window_sum 10 | 11 | 12 | class ScaledTranslationDotPrimitivesTests(chex.TestCase): 13 | def setUp(self): 14 | super().setUp() 15 | # Use random state for reproducibility! 
16 | self.rs = np.random.RandomState(42) 17 | 18 | @parameterized.parameters( 19 | {"ldtype": np.float32, "rdtype": np.float32}, 20 | # {"ldtype": np.float32, "rdtype": np.float16}, # Not supported in JAX 0.3.x 21 | # {"ldtype": np.float16, "rdtype": np.float32}, 22 | {"ldtype": np.float16, "rdtype": np.float16}, 23 | ) 24 | def test__scaled_dot_general__proper_scaling(self, ldtype, rdtype): 25 | # Reduction dimension: 5 => sqrt(5) ~ 2 26 | lhs = scaled_array(self.rs.rand(3, 5), 2.0, dtype=ldtype) 27 | rhs = scaled_array(self.rs.rand(5, 2), 4.0, dtype=rdtype) 28 | 29 | dimension_numbers = (((1,), (0,)), ((), ())) 30 | out = scaled_dot_general(lhs, rhs, dimension_numbers) 31 | expected_out = lax.dot_general(np.asarray(lhs), np.asarray(rhs), dimension_numbers) 32 | 33 | assert isinstance(out, ScaledArray) 34 | assert out.dtype == expected_out.dtype 35 | assert out.scale.dtype == np.float32 # TODO: more test coverage. 36 | npt.assert_almost_equal(out.scale, lhs.scale * rhs.scale * 2) 37 | npt.assert_array_almost_equal(out, expected_out, decimal=2) 38 | 39 | 40 | class ScaledTranslationUnaryOpsTests(chex.TestCase): 41 | def setUp(self): 42 | super().setUp() 43 | # Use random state for reproducibility! 44 | self.rs = np.random.RandomState(42) 45 | 46 | @chex.variants(with_jit=True, without_jit=True) 47 | @parameterized.parameters( 48 | {"prim": lax.exp_p, "dtype": np.float16, "expected_scale": 1.0}, # FIXME! 49 | {"prim": lax.log_p, "dtype": np.float16, "expected_scale": 1.0}, # FIXME! 50 | {"prim": lax.neg_p, "dtype": np.float16, "expected_scale": 2.0}, 51 | {"prim": lax.abs_p, "dtype": np.float16, "expected_scale": 2.0}, 52 | {"prim": lax.cos_p, "dtype": np.float16, "expected_scale": 1.0}, 53 | {"prim": lax.sin_p, "dtype": np.float16, "expected_scale": 1.0}, 54 | ) 55 | def test__scaled_unary_op__proper_result_and_scaling(self, prim, dtype, expected_scale): 56 | scaled_op, _ = find_registered_scaled_op(prim) 57 | val = scaled_array(self.rs.rand(3, 5), 2.0, dtype=dtype) 58 | out = self.variant(scaled_op)(val) 59 | expected_output = prim.bind(np.asarray(val)) 60 | assert isinstance(out, ScaledArray) 61 | assert out.dtype == val.dtype 62 | assert out.scale.dtype == val.scale.dtype 63 | npt.assert_almost_equal(out.scale, expected_scale) 64 | # FIXME: higher precision for `log`? 65 | npt.assert_array_almost_equal(out, expected_output, decimal=3) 66 | 67 | def test__scaled_exp__large_scale_zero_values(self): 68 | scaled_op, _ = find_registered_scaled_op(lax.exp_p) 69 | # Scaled array, with values < 0 and scale overflowing in float16. 70 | val = scaled_array(np.array([0, -1, -2, -32768], np.float16), np.float32(32768 * 16)) 71 | out = scaled_op(val) 72 | # Zero value should not be a NaN! 73 | npt.assert_array_almost_equal(out, [1, 0, 0, 0], decimal=2) 74 | 75 | def test__scaled_log__zero_large_values_large_scale(self): 76 | scaled_op, _ = find_registered_scaled_op(lax.log_p) 77 | # 0 + large values => proper log values, without NaN/overflow. 78 | val = scaled_array(np.array([0, 1], np.float16), np.float32(32768 * 16)) 79 | out = scaled_op(val) 80 | # No NaN value + not overflowing! 81 | npt.assert_array_almost_equal(out, lax.log(val.to_array(np.float32)), decimal=2) 82 | 83 | 84 | class ScaledTranslationBinaryOpsTests(chex.TestCase): 85 | def setUp(self): 86 | super().setUp() 87 | # Use random state for reproducibility! 
88 | self.rs = np.random.RandomState(42) 89 | 90 | @chex.variants(with_jit=True, without_jit=True) 91 | @parameterized.product( 92 | prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.div_p, lax.min_p, lax.max_p], 93 | dtype=[np.float16, np.float32], 94 | sdtype=[np.float16, np.float32], 95 | ) 96 | def test__scaled_binary_op__proper_result_and_promotion(self, prim, dtype, sdtype): 97 | scaled_op, _ = find_registered_scaled_op(prim) 98 | # NOTE: direct construction to avoid weirdness between NumPy array and scalar! 99 | x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(8.0)) 100 | y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(2.0)) 101 | # Ensure scale factor has the right dtype. 102 | assert x.scale.dtype == sdtype 103 | assert y.scale.dtype == sdtype 104 | 105 | z = self.variant(scaled_op)(x, y) 106 | expected_z = prim.bind(np.asarray(x), np.asarray(y)) 107 | 108 | assert z.dtype == x.dtype 109 | assert z.scale.dtype == sdtype 110 | npt.assert_array_almost_equal(z, expected_z, decimal=4) 111 | 112 | @chex.variants(with_jit=True, without_jit=True) 113 | @parameterized.product( 114 | prim=[lax.add_p, lax.sub_p, lax.mul_p, lax.min_p, lax.max_p], 115 | dtype=[np.float16, np.float32], 116 | sdtype=[np.float16, np.float32], 117 | ) 118 | def test__scaled_binary_op__proper_zero_scale_handling(self, prim, dtype, sdtype): 119 | scaled_op, _ = find_registered_scaled_op(prim) 120 | # NOTE: direct construction to avoid weirdness between NumPy array and scalar! 121 | x = ScaledArray(np.array([-1.0, 2.0], dtype), sdtype(0.0)) 122 | y = ScaledArray(np.array([1.5, 4.5], dtype), sdtype(0.0)) 123 | # Ensure scale factor has the right dtype. 124 | assert x.scale.dtype == sdtype 125 | assert y.scale.dtype == sdtype 126 | 127 | z = self.variant(scaled_op)(x, y) 128 | expected_z = prim.bind(np.asarray(x), np.asarray(y)) 129 | 130 | assert z.dtype == x.dtype 131 | assert z.scale.dtype == sdtype 132 | npt.assert_array_almost_equal(z, expected_z, decimal=4) 133 | 134 | @parameterized.parameters( 135 | {"prim": lax.add_p}, 136 | {"prim": lax.sub_p}, 137 | ) 138 | def test__scaled_addsub__proper_scaling(self, prim): 139 | scaled_op, _ = find_registered_scaled_op(prim) 140 | x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32) 141 | y = scaled_array([1.5, 4.5], 2.0, dtype=np.float32) 142 | z = scaled_op(x, y) 143 | assert isinstance(z, ScaledArray) 144 | assert z.dtype == x.dtype 145 | # Round down to power-of-2 146 | npt.assert_almost_equal(z.scale, 4) 147 | 148 | @parameterized.parameters( 149 | {"prim": lax.add_p}, 150 | {"prim": lax.sub_p}, 151 | ) 152 | def test__scaled_addsub__not_overflowing_scale(self, prim): 153 | scaled_op, _ = find_registered_scaled_op(prim) 154 | x = scaled_array([-1.0, 2.0], np.float16(2.0), dtype=np.float16) 155 | y = scaled_array([1.5, 4.0], np.float16(1024.0), dtype=np.float16) 156 | z = scaled_op(x, y) 157 | assert z.scale.dtype == np.float16 158 | assert np.isfinite(z.scale) 159 | npt.assert_array_almost_equal(z, prim.bind(np.asarray(x, np.float32), np.asarray(y, np.float32)), decimal=6) 160 | 161 | @parameterized.product( 162 | prim=[lax.min_p, lax.max_p], 163 | ) 164 | def test__scaled_minmax__static_zero_scale_propagation(self, prim): 165 | scaled_op, _ = find_registered_scaled_op(prim) 166 | x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32) 167 | y = scaled_array([1.5, 4.5], 0.0, dtype=np.float32) 168 | z = scaled_op(x, y) 169 | assert isinstance(z, ScaledArray) 170 | assert z.dtype == x.dtype 171 | # Keep the lhs scale.
172 | npt.assert_almost_equal(z.scale, 4.0) 173 | 174 | @parameterized.product( 175 | prim=[lax.min_p, lax.max_p], 176 | ) 177 | def test__scaled_minmax__static_inf_scale_propagation(self, prim): 178 | scaled_op, _ = find_registered_scaled_op(prim) 179 | x = scaled_array([-1.0, 2.0], 4.0, dtype=np.float32, npapi=np) 180 | y = scaled_array([-np.inf, np.inf], np.inf, dtype=np.float32, npapi=np) 181 | z = scaled_op(x, y) 182 | assert isinstance(z, ScaledArray) 183 | assert z.dtype == x.dtype 184 | # Keep the lhs scale. 185 | npt.assert_almost_equal(z.scale, 4.0) 186 | 187 | def test__scaled_mul__proper_scaling(self): 188 | x = scaled_array([-2.0, 2.0], 3, dtype=np.float32) 189 | y = scaled_array([1.5, 1.5], 2, dtype=np.float32) 190 | z = scaled_mul(x, y) 191 | assert isinstance(z, ScaledArray) 192 | assert z.scale == 6 193 | npt.assert_array_almost_equal(z, np.asarray(x) * np.asarray(y)) 194 | 195 | def test__scaled_div__proper_scaling(self): 196 | x = scaled_array([-2.0, 2.0], 3.0, dtype=np.float32) 197 | y = scaled_array([1.5, 1.5], 2.0, dtype=np.float32) 198 | z = scaled_div(x, y) 199 | assert isinstance(z, ScaledArray) 200 | assert z.scale == 1.5 201 | npt.assert_array_almost_equal(z, np.asarray(x) / np.asarray(y)) 202 | 203 | 204 | class ScaledTranslationReducePrimitivesTests(chex.TestCase): 205 | def setUp(self): 206 | super().setUp() 207 | # Use random state for reproducibility! 208 | self.rs = np.random.RandomState(42) 209 | 210 | @parameterized.parameters( 211 | {"reduce_prim": lax.reduce_sum_p, "expected_scale": 2 * 2}, 212 | {"reduce_prim": lax.reduce_prod_p, "expected_scale": 2**5}, 213 | {"reduce_prim": lax.reduce_min_p, "expected_scale": 2}, 214 | {"reduce_prim": lax.reduce_max_p, "expected_scale": 2}, 215 | ) 216 | def test__scaled_reduce__single_axis__proper_scaling(self, reduce_prim, expected_scale): 217 | axes = (0,) 218 | # NOTE: float16 useful for checking dtype promotion! 219 | val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float16) 220 | scaled_reduce_op, _ = find_registered_scaled_op(reduce_prim) 221 | out = scaled_reduce_op(val, axes=axes) 222 | 223 | assert isinstance(out, ScaledArray) 224 | assert out.shape == () 225 | assert out.dtype == val.dtype 226 | npt.assert_almost_equal(out.scale, expected_scale) 227 | npt.assert_array_almost_equal(out, reduce_prim.bind(np.asarray(val), axes=axes)) 228 | 229 | def test__scaled_reduce_window_sum__proper_result(self): 230 | val = scaled_array(self.rs.rand(5), 2.0, dtype=np.float32) 231 | out = scaled_reduce_window_sum( 232 | val, 233 | window_dimensions=(3,), 234 | window_strides=(1,), 235 | padding=((1, 0),), 236 | base_dilation=(1,), 237 | window_dilation=(1,), 238 | ) 239 | assert isinstance(out, ScaledArray) 240 | assert out.shape == (4,) 241 | assert out.dtype == val.dtype 242 | npt.assert_almost_equal(out.scale, val.scale) 243 | -------------------------------------------------------------------------------- /tests/lax/test_scipy_integration.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import chex 3 | import numpy as np 4 | import numpy.testing as npt 5 | from absl.testing import parameterized 6 | from jax import lax 7 | 8 | from jax_scalify.core import scaled_array, scalify 9 | 10 | 11 | class ScaledScipyHighLevelMethodsTests(chex.TestCase): 12 | def setUp(self): 13 | super().setUp() 14 | # Use random state for reproducibility! 
15 | self.rs = np.random.RandomState(42) 16 | 17 | def test__lax_full_like__zero_scale(self): 18 | def fn(a): 19 | return lax.full_like(a, 0) 20 | 21 | a = scaled_array(np.random.rand(3, 5).astype(np.float32), np.float32(1)) 22 | scalify(fn)(a) 23 | # FIXME/TODO: what should be the expected result? 24 | 25 | @chex.variants(with_jit=False, without_jit=True) 26 | @parameterized.parameters( 27 | {"dtype": np.float32}, 28 | {"dtype": np.float16}, 29 | ) 30 | def test__scipy_logsumexp__accurate_scaled_op(self, dtype): 31 | from jax.scipy.special import logsumexp 32 | 33 | input_scaled = scaled_array(self.rs.rand(10), 4.0, dtype=dtype) 34 | # JAX `logsumexp` Jaxpr is a non-trivial graph! 35 | out_scaled = self.variant(scalify(logsumexp))(input_scaled) 36 | out_expected = logsumexp(np.asarray(input_scaled)) 37 | assert out_scaled.dtype == out_expected.dtype 38 | # Proper accuracy + keep the same scale. 39 | npt.assert_array_equal(out_scaled.scale, input_scaled.scale) 40 | npt.assert_array_almost_equal(out_scaled, out_expected, decimal=5) 41 | -------------------------------------------------------------------------------- /tests/ops/test_cast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | from functools import partial 3 | 4 | import chex 5 | import jax 6 | import jax.numpy as jnp 7 | import ml_dtypes 8 | import numpy as np 9 | import numpy.testing as npt 10 | from absl.testing import parameterized 11 | from numpy.typing import NDArray 12 | 13 | from jax_scalify.core import scaled_array, scalify 14 | from jax_scalify.ops import cast_on_backward, cast_on_forward, reduce_precision_on_forward 15 | 16 | 17 | class ReducePrecisionDtypeTests(chex.TestCase): 18 | @parameterized.parameters( 19 | {"ml_dtype": ml_dtypes.float8_e4m3fn}, 20 | {"ml_dtype": ml_dtypes.float8_e5m2}, 21 | ) 22 | def test__reduce_precision_on_forward__consistent_rounding_down(self, ml_dtype): 23 | # Values potentially "problematic" in FP8. 24 | values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) 25 | out = reduce_precision_on_forward(values, dtype=ml_dtype) 26 | expected_out = values.astype(ml_dtype) 27 | assert out.dtype == values.dtype 28 | npt.assert_array_equal(out, expected_out) 29 | 30 | @parameterized.parameters( 31 | {"ml_dtype": ml_dtypes.float8_e4m3fn}, 32 | {"ml_dtype": ml_dtypes.float8_e5m2}, 33 | ) 34 | def test__reduce_precision_on_forward__scalify_compatibility(self, ml_dtype): 35 | values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) 36 | arr = scaled_array(values, np.float32(1)) 37 | out = scalify(partial(reduce_precision_on_forward, dtype=ml_dtype))(arr) 38 | 39 | npt.assert_array_equal(out.scale, arr.scale) 40 | npt.assert_array_equal(out, np.asarray(arr.data).astype(ml_dtype)) 41 | 42 | 43 | class CastOnForwardBackwardTests(chex.TestCase): 44 | @chex.variants(with_jit=True, without_jit=True) 45 | @parameterized.parameters( 46 | {"dtype": jnp.float16}, 47 | # TODO: uncomment when JAX 0.4+ used 48 | # {"dtype": jnp.float8_e4m3fn}, 49 | # {"dtype": jnp.float8_e5m2}, 50 | ) 51 | def test__cast_on_forward_backward__proper_results(self, dtype): 52 | # Values potentially "problematic" in FP8.
53 | values: NDArray[np.float16] = np.array([17, -17, 8, 1, 9, 11, 18], np.float16) 54 | out_on_fwd = self.variant(partial(cast_on_forward, dtype=dtype))(values) 55 | out_on_bwd = self.variant(partial(cast_on_backward, dtype=dtype))(values) 56 | 57 | assert out_on_fwd.dtype == dtype 58 | assert out_on_bwd.dtype == values.dtype 59 | npt.assert_array_equal(out_on_fwd, jax.lax.convert_element_type(values, dtype)) 60 | npt.assert_array_equal(out_on_bwd, values) 61 | 62 | @parameterized.parameters( 63 | {"dtype": jnp.float16}, 64 | # TODO: uncomment when JAX 0.4+ used 65 | # {"dtype": jnp.float8_e4m3fn}, 66 | # {"dtype": jnp.float8_e5m2}, 67 | ) 68 | def test__cast_on_backward__grad__proper_results(self, dtype): 69 | def fn(val, with_cast): 70 | if with_cast: 71 | val = cast_on_backward(val, dtype=dtype) 72 | val = val * val 73 | return jax.lax.reduce_sum_p.bind(val, axes=(0,)) 74 | 75 | # Values potentially "problematic" in FP8. 76 | values: NDArray[np.float32] = np.array([17, -17, 8, 1, 9, 11, 18], np.float32) 77 | # Backward pass => gradient. 78 | grads = jax.grad(partial(fn, with_cast=True))(values) 79 | grads_ref = jax.grad(partial(fn, with_cast=False))(values) 80 | 81 | assert grads.dtype == dtype 82 | assert grads_ref.dtype == values.dtype 83 | npt.assert_array_equal(grads, jax.lax.convert_element_type(grads_ref, dtype)) 84 | -------------------------------------------------------------------------------- /tests/ops/test_debug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 2 | import numpy as np 3 | 4 | from jax_scalify.core import scaled_array, scalify 5 | from jax_scalify.ops import debug_print 6 | 7 | 8 | def test__debug_print__scaled_arrays(capfd): 9 | fmt = "INPUTS: {} + {}" 10 | 11 | def debug_print_fn(x): 12 | debug_print(fmt, x, x) 13 | 14 | input_scaled = scaled_array([2, 3], 2.0, dtype=np.float32) 15 | scalify(debug_print_fn)(input_scaled) 16 | # Check captured stdout and stderr! 17 | captured = capfd.readouterr() 18 | assert len(captured.err) == 0 19 | assert captured.out.strip() == fmt.format(input_scaled, input_scaled) 20 | -------------------------------------------------------------------------------- /tests/ops/test_rescaling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
2 | import chex 3 | import numpy as np 4 | import numpy.testing as npt 5 | from absl.testing import parameterized 6 | 7 | from jax_scalify.core import ScaledArray, scaled_array 8 | from jax_scalify.ops import dynamic_rescale_l1, dynamic_rescale_l2, dynamic_rescale_max 9 | 10 | 11 | class DynamicRescaleOpsTests(chex.TestCase): 12 | def test__dynamic_rescale_max__proper_max_rescale_pow2_rounding(self): 13 | arr_in = scaled_array([2, -3], np.float16(4), dtype=np.float16) 14 | arr_out = dynamic_rescale_max(arr_in) 15 | 16 | assert isinstance(arr_out, ScaledArray) 17 | assert arr_out.dtype == arr_in.dtype 18 | npt.assert_array_equal(arr_out.scale, np.float16(8)) 19 | npt.assert_array_equal(arr_out, arr_in) 20 | 21 | def test__dynamic_rescale_l1__proper_l1_rescale_pow2_rounding(self): 22 | # L1 norm = 2 23 | arr_in = scaled_array([1, -6], np.float16(4), dtype=np.float16) 24 | arr_out = dynamic_rescale_l1(arr_in) 25 | 26 | assert isinstance(arr_out, ScaledArray) 27 | assert arr_out.dtype == arr_in.dtype 28 | npt.assert_array_equal(arr_out.scale, np.float16(8)) 29 | npt.assert_array_equal(arr_out, arr_in) 30 | 31 | def test__dynamic_rescale_l2__proper_max_rescale_pow2_rounding(self): 32 | # L2 norm = 8.945 33 | arr_in = scaled_array([4, -8], np.float16(4), dtype=np.float16) 34 | arr_out = dynamic_rescale_l2(arr_in) 35 | 36 | assert isinstance(arr_out, ScaledArray) 37 | assert arr_out.dtype == arr_in.dtype 38 | npt.assert_array_equal(arr_out.scale, np.float16(16)) 39 | npt.assert_array_equal(arr_out, arr_in) 40 | 41 | @parameterized.parameters( 42 | {"dynamic_rescale_fn": dynamic_rescale_max}, 43 | {"dynamic_rescale_fn": dynamic_rescale_l1}, 44 | {"dynamic_rescale_fn": dynamic_rescale_l2}, 45 | ) 46 | def test__dynamic_rescale__epsilon_norm_value(self, dynamic_rescale_fn): 47 | arr_in = scaled_array([0, 0], np.float32(1), dtype=np.float16) 48 | arr_out = dynamic_rescale_fn(arr_in) 49 | # Rough bounds on the epsilon value. 50 | assert arr_out.scale > 0.0 51 | assert arr_out.scale < 0.001 52 | -------------------------------------------------------------------------------- /tests/quantization/test_scale.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 
2 | import chex 3 | import ml_dtypes 4 | import numpy as np 5 | import numpy.testing as npt 6 | from absl.testing import parameterized 7 | 8 | from jax_scalify.quantization.scale import as_e8m0, from_e8m0, pow2_truncate 9 | 10 | 11 | class QuantizationScaleTests(chex.TestCase): 12 | @parameterized.parameters( 13 | {"dtype": np.float16}, 14 | {"dtype": np.float32}, 15 | {"dtype": ml_dtypes.bfloat16}, 16 | ) 17 | def test__pow2_truncate__proper_result(self, dtype): 18 | vin = np.array([-2, 0, 2, 1, 9, 15]).astype(dtype) 19 | vout = pow2_truncate(vin) 20 | assert vout.dtype == vin.dtype 21 | npt.assert_array_equal(vout, [-2.0, 0.0, 2.0, 1.0, 8.0, 8.0]) 22 | 23 | @parameterized.parameters( 24 | # {"dtype": np.float16}, 25 | {"dtype": np.float32}, 26 | {"dtype": ml_dtypes.bfloat16}, 27 | ) 28 | def test__as_e8m0__positive_values(self, dtype): 29 | vin = np.array([0.6, 2, 1, 9, 15, 127]).astype(dtype).reshape((-1, 2)) 30 | vout = as_e8m0(vin) 31 | assert vout.dtype == np.uint8 32 | assert vout.shape == vin.shape 33 | npt.assert_array_equal(vout, np.log2(pow2_truncate(vin)) + 127) 34 | 35 | @parameterized.parameters( 36 | # {"dtype": np.float16}, 37 | {"dtype": np.float32}, 38 | {"dtype": ml_dtypes.bfloat16}, 39 | ) 40 | def test__as_e8m0__negative_values(self, dtype): 41 | vin = np.array([-0.1, -3, 0, 2**-127]).astype(dtype) 42 | vout = as_e8m0(vin) 43 | assert vout.dtype == np.uint8 44 | # NOTE: uint8(0) is the smallest positive scale in E8M0. 45 | npt.assert_array_equal(vout, np.uint8(0)) 46 | 47 | @parameterized.parameters( 48 | # {"dtype": np.float16}, 49 | {"dtype": np.float32}, 50 | {"dtype": ml_dtypes.bfloat16}, 51 | ) 52 | def test__from_e8m0(self, dtype): 53 | vin = np.array([2**-127, 0.25, 1, 2, 8, 2**127.0]).astype(dtype).reshape((-1, 2)) 54 | vin_e8m0 = as_e8m0(vin) 55 | vout = from_e8m0(vin_e8m0, dtype) 56 | assert vin.dtype == vout.dtype 57 | assert vout.shape == vin.shape 58 | npt.assert_array_equal(vout, vin) 59 | -------------------------------------------------------------------------------- /tests/tree/test_tree_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 
2 | import chex 3 | import numpy as np 4 | 5 | import jax_scalify as jsa 6 | 7 | 8 | class ScalifyTreeUtilTests(chex.TestCase): 9 | def test__tree_flatten__proper_result(self): 10 | values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 11 | outputs, _ = jsa.tree.flatten(values) 12 | assert len(outputs) == 2 13 | assert outputs[0] == 2 14 | assert isinstance(outputs[1], jsa.ScaledArray) 15 | assert np.asarray(outputs[1]) == 1.5 16 | 17 | def test__tree_leaves__proper_result(self): 18 | values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 19 | outputs = jsa.tree.leaves(values) 20 | assert len(outputs) == 2 21 | assert outputs[0] == 2 22 | assert isinstance(outputs[1], jsa.ScaledArray) 23 | assert np.asarray(outputs[1]) == 1.5 24 | 25 | def test__tree_structure__proper_result(self): 26 | values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 27 | pytree = jsa.tree.structure(values) 28 | assert pytree == jsa.tree.flatten(values)[1] 29 | 30 | def test__tree_unflatten__proper_result(self): 31 | values_in = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 32 | outputs, pytree = jsa.tree.flatten(values_in) 33 | values_out = jsa.tree.unflatten(pytree, outputs) 34 | assert values_out == values_in 35 | 36 | def test__tree_map__proper_result(self): 37 | values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 38 | outputs = jsa.tree.map(lambda v: v.dtype, values) 39 | assert outputs == {"a": np.int32, "b": np.float32} 40 | 41 | def test__tree_astype__all_leaves_casting(self): 42 | values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 43 | outputs = jsa.tree.astype(values, dtype=np.float16) 44 | dtypes = jsa.tree.map(lambda v: v.dtype, outputs) 45 | assert dtypes == {"a": np.float16, "b": np.float16} 46 | 47 | def test__tree_astype__only_float_casting(self): 48 | values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} 49 | outputs = jsa.tree.astype(values, dtype=np.float16, floating_only=True) 50 | dtypes = jsa.tree.map(lambda v: v.dtype, outputs) 51 | assert dtypes == {"a": np.int32, "b": np.float16} 52 | -------------------------------------------------------------------------------- /tests/utils/test_hlo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Graphcore Ltd. All rights reserved. 2 | import chex 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | import jax_scalify as jsa 7 | 8 | 9 | class ScalifyHloUtilsTests(chex.TestCase): 10 | def test__hlo_util__parse_hlo_module(self): 11 | def matmul_fn(a, b): 12 | return jax.lax.dot(a, b.T) 13 | 14 | a = jax.core.ShapedArray((16, 48), jnp.float16) 15 | b = jax.core.ShapedArray((32, 48), jnp.float16) 16 | 17 | matmul_lowered = jax.jit(matmul_fn).lower(a, b) 18 | matmul_compiled = matmul_lowered.compile() 19 | 20 | ops = jsa.utils.parse_hlo_module(matmul_compiled) 21 | assert len(ops) >= 6 22 | # TODO: other tests??? 23 | # jsa.utils.print_hlo_module(matmul_compiled, backend_cfg=True, indent=2) 24 | --------------------------------------------------------------------------------
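The tests above all revolve around the same small API surface: `scaled_array` wraps a value into a `ScaledArray` (a `data` array plus a `scale` factor, with the represented value being `data * scale`), and `scalify` re-traces a JAX function so that scales are propagated through it. As a quick orientation, here is a minimal usage sketch assembled from the call patterns in the tests (illustrative only, not a file of the repository; the shapes and values are arbitrary choices, and the resulting output scale depends on the library's scale-propagation rules):

import jax
import jax.numpy as jnp
import numpy as np

from jax_scalify.core import ScaledArray, scaled_array, scalify


def mean_grad_fn(x):
    # Gradient of a mean of squares, mirroring the numpy-integration test above.
    return jax.grad(lambda v: jnp.mean(v * v))(x)


# Wrap a float32 array into a ScaledArray with unit scale.
x = scaled_array(np.random.rand(8, 16).astype(np.float32), np.float32(1))
out = scalify(mean_grad_fn)(x)
# As in the tests, a scalified function applied to ScaledArray inputs
# returns ScaledArray outputs.
assert isinstance(out, ScaledArray)
print(out.data.dtype, out.scale)  # represented value is out.data * out.scale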